From be68af7d059ee65f5d62dbe9e20ba00e5e910c10 Mon Sep 17 00:00:00 2001 From: Stefan Ellmauthaler Date: Wed, 1 Jun 2022 13:07:56 +0200 Subject: [PATCH] Refactor bdd to use rc refcell --- Cargo.lock | 17 +--- bin/Cargo.toml | 2 +- lib/src/adf.rs | 2 +- lib/src/obdd.rs | 214 +++++++++++++++++++++++++----------------------- 4 files changed, 113 insertions(+), 122 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4192abd..201253f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,26 +19,11 @@ dependencies = [ "test-log", ] -[[package]] -name = "adf_bdd" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e781519ea5434514f014476c02ccee777b28e600ad58fadca195715acb194c69" -dependencies = [ - "biodivine-lib-bdd", - "derivative", - "lexical-sort", - "log", - "nom", - "serde", - "serde_json", -] - [[package]] name = "adf_bdd-solver" version = "0.2.4" dependencies = [ - "adf_bdd 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)", + "adf_bdd", "assert_cmd", "assert_fs", "clap", diff --git a/bin/Cargo.toml b/bin/Cargo.toml index 587d4f1..41e7894 100644 --- a/bin/Cargo.toml +++ b/bin/Cargo.toml @@ -14,7 +14,7 @@ name = "adf-bdd" path = "src/main.rs" [dependencies] -adf_bdd = { version="0.2.4", default-features = false } +adf_bdd = { version="0.2.4", path="../lib", default-features = false } clap = {version = "3.1.14", features = [ "derive", "cargo", "env" ]} log = { version = "0.4", features = [ "max_level_trace", "release_max_level_info" ] } serde = { version = "1.0", features = ["derive","rc"] } diff --git a/lib/src/adf.rs b/lib/src/adf.rs index 03ebffc..ddc14f5 100644 --- a/lib/src/adf.rs +++ b/lib/src/adf.rs @@ -946,7 +946,7 @@ mod test { "s(a). s(b). s(c). s(d). ac(a,c(v)). ac(b,b). ac(c,and(a,b)). ac(d,neg(b)).", ) .unwrap(); - let mut adf = Adf::from_parser(&parser); + let adf = Adf::from_parser(&parser); let mut v = adf.ac.clone(); let mut fcs = adf.facet_count(&v); diff --git a/lib/src/obdd.rs b/lib/src/obdd.rs index 97f9268..85e9d38 100644 --- a/lib/src/obdd.rs +++ b/lib/src/obdd.rs @@ -4,6 +4,7 @@ pub mod vectorize; use crate::datatypes::*; use serde::{Deserialize, Serialize}; use std::collections::HashSet; +use std::rc::Rc; use std::{cell::RefCell, cmp::min, collections::HashMap, fmt::Display}; /// Contains the data of (possibly) multiple roBDDs, managed over one collection of nodes. @@ -11,20 +12,21 @@ use std::{cell::RefCell, cmp::min, collections::HashMap, fmt::Display}; /// Each roBDD is identified by its corresponding [`Term`], which implicitly identifies the root node of a roBDD. #[derive(Debug, Serialize, Deserialize)] pub struct Bdd { - pub(crate) nodes: Vec, + nodes: Rc>>, #[cfg(feature = "variablelist")] #[serde(skip)] - var_deps: Vec>, - #[serde(with = "vectorize")] - cache: HashMap, + var_deps: Rc>>>, + //#[serde(with = "vectorize")] + #[serde(skip)] + cache: Rc>>, #[serde(skip, default = "Bdd::default_count_cache")] - count_cache: RefCell>, + count_cache: Rc>>, } impl Display for Bdd { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { writeln!(f, " ")?; - for (idx, elem) in self.nodes.iter().enumerate() { + for (idx, elem) in self.nodes.borrow().iter().enumerate() { writeln!(f, "{} {}", idx, *elem)?; } Ok(()) @@ -47,18 +49,18 @@ impl Bdd { nodes: vec![BddNode::bot_node(), BddNode::top_node()], #[cfg(feature = "variablelist")] var_deps: vec![HashSet::new(), HashSet::new()], - cache: HashMap::new(), + cache: Rc::new(RefCell::new(HashMap::new())), count_cache: RefCell::new(HashMap::new()), } } #[cfg(feature = "adhoccounting")] { let result = Self { - nodes: vec![BddNode::bot_node(), BddNode::top_node()], + nodes: Rc::new(RefCell::new(vec![BddNode::bot_node(), BddNode::top_node()])), #[cfg(feature = "variablelist")] - var_deps: vec![HashSet::new(), HashSet::new()], - cache: HashMap::new(), - count_cache: RefCell::new(HashMap::new()), + var_deps: Rc::new(RefCell::new(vec![HashSet::new(), HashSet::new()])), + cache: Rc::new(RefCell::new(HashMap::new())), + count_cache: Rc::new(RefCell::new(HashMap::new())), }; result .count_cache @@ -72,8 +74,8 @@ impl Bdd { } } - fn default_count_cache() -> RefCell> { - RefCell::new(HashMap::new()) + fn default_count_cache() -> Rc>> { + Rc::new(RefCell::new(HashMap::new())) } /// Instantiates a [variable][crate::datatypes::Var] and returns the representing roBDD as a [`Term`][crate::datatypes::Term]. @@ -134,7 +136,7 @@ impl Bdd { positive: &[Var], ) -> Vec<(Vec, Vec)> { let mut result = Vec::new(); - let node = self.nodes[tree.value()]; + let node = self.nodes.borrow()[tree.value()]; let var = node.var(); if tree.is_truth_value() { return Vec::new(); @@ -175,11 +177,11 @@ impl Bdd { } /// Restrict the value of a given [variable][crate::datatypes::Var] to **val**. - pub fn restrict(&mut self, tree: Term, var: Var, val: bool) -> Term { - let node = self.nodes[tree.0]; + pub fn restrict(&self, tree: Term, var: Var, val: bool) -> Term { + let node = self.nodes.borrow()[tree.0]; #[cfg(feature = "variablelist")] { - if !self.var_deps[tree.value()].contains(&var) { + if !self.var_deps.borrow()[tree.value()].contains(&var) { return tree; } } @@ -201,7 +203,7 @@ impl Bdd { } /// Creates an roBDD, based on the relation of three roBDDs, which are in an `if-then-else` relation. - fn if_then_else(&mut self, i: Term, t: Term, e: Term) -> Term { + fn if_then_else(&self, i: Term, t: Term, e: Term) -> Term { if i == Term::TOP { t } else if i == Term::BOT { @@ -212,10 +214,10 @@ impl Bdd { i } else { let minvar = Var(min( - self.nodes[i.value()].var().value(), + self.nodes.borrow()[i.value()].var().value(), min( - self.nodes[t.value()].var().value(), - self.nodes[e.value()].var().value(), + self.nodes.borrow()[t.value()].var().value(), + self.nodes.borrow()[e.value()].var().value(), ), )); let itop = self.restrict(i, minvar, true); @@ -233,76 +235,74 @@ impl Bdd { /// Creates a new node in the roBDD. /// It will not create duplicate nodes and uses already existing nodes, if applicable. - pub fn node(&mut self, var: Var, lo: Term, hi: Term) -> Term { + pub fn node(&self, var: Var, lo: Term, hi: Term) -> Term { if lo == hi { lo } else { let node = BddNode::new(var, lo, hi); - match self.cache.get(&node) { - Some(t) => *t, - None => { - let new_term = Term(self.nodes.len()); - self.nodes.push(node); - self.cache.insert(node, new_term); - #[cfg(feature = "variablelist")] - { - let mut var_set: HashSet = self.var_deps[lo.value()] - .union(&self.var_deps[hi.value()]) - .copied() - .collect(); - var_set.insert(var); - self.var_deps.push(var_set); - } - log::debug!("newterm: {} as {:?}", new_term, node); - #[cfg(feature = "adhoccounting")] - { - let mut count_cache = self.count_cache.borrow_mut(); - let (lo_counts, lo_paths, lodepth) = - *count_cache.get(&lo).expect("Cache corrupted"); - let (hi_counts, hi_paths, hidepth) = - *count_cache.get(&hi).expect("Cache corrupted"); - log::debug!( - "lo (cm: {}, mo: {}, p-: {}, p+: {}, dp: {})", - lo_counts.cmodels, - lo_counts.models, - lo_paths.cmodels, - lo_paths.models, - lodepth - ); - log::debug!( - "hi (cm: {}, mo: {}, p-: {}, p+: {}, dp: {})", - hi_counts.cmodels, - hi_counts.models, - hi_paths.cmodels, - hi_paths.models, - hidepth - ); - let (lo_exp, hi_exp) = if lodepth > hidepth { - (1, 2usize.pow((lodepth - hidepth) as u32)) - } else { - (2usize.pow((hidepth - lodepth) as u32), 1) - }; - log::debug!("lo_exp {}, hi_exp {}", lo_exp, hi_exp); - count_cache.insert( - new_term, - ( - ( - lo_counts.cmodels * lo_exp + hi_counts.cmodels * hi_exp, - lo_counts.models * lo_exp + hi_counts.models * hi_exp, - ) - .into(), - ( - lo_paths.cmodels + hi_paths.cmodels, - lo_paths.models + hi_paths.models, - ) - .into(), - std::cmp::max(lodepth, hidepth) + 1, - ), - ); - } - new_term - } + if let Some(t) = self.cache.borrow().get(&node) { + return *t; } + let new_term = Term(self.nodes.borrow().len()); + self.nodes.borrow_mut().push(node); + self.cache.borrow_mut().insert(node, new_term); + #[cfg(feature = "variablelist")] + { + let mut var_set: HashSet = self.var_deps.borrow()[lo.value()] + .union(&self.var_deps.borrow()[hi.value()]) + .copied() + .collect(); + var_set.insert(var); + self.var_deps.borrow_mut().push(var_set); + } + log::debug!("newterm: {} as {:?}", new_term, node); + #[cfg(feature = "adhoccounting")] + { + let mut count_cache = self.count_cache.borrow_mut(); + let (lo_counts, lo_paths, lodepth) = + *count_cache.get(&lo).expect("Cache corrupted"); + let (hi_counts, hi_paths, hidepth) = + *count_cache.get(&hi).expect("Cache corrupted"); + log::debug!( + "lo (cm: {}, mo: {}, p-: {}, p+: {}, dp: {})", + lo_counts.cmodels, + lo_counts.models, + lo_paths.cmodels, + lo_paths.models, + lodepth + ); + log::debug!( + "hi (cm: {}, mo: {}, p-: {}, p+: {}, dp: {})", + hi_counts.cmodels, + hi_counts.models, + hi_paths.cmodels, + hi_paths.models, + hidepth + ); + let (lo_exp, hi_exp) = if lodepth > hidepth { + (1, 2usize.pow((lodepth - hidepth) as u32)) + } else { + (2usize.pow((hidepth - lodepth) as u32), 1) + }; + log::debug!("lo_exp {}, hi_exp {}", lo_exp, hi_exp); + count_cache.insert( + new_term, + ( + ( + lo_counts.cmodels * lo_exp + hi_counts.cmodels * hi_exp, + lo_counts.models * lo_exp + hi_counts.models * hi_exp, + ) + .into(), + ( + lo_paths.cmodels + hi_paths.cmodels, + lo_paths.models + hi_paths.models, + ) + .into(), + std::cmp::max(lodepth, hidepth) + 1, + ), + ); + } + new_term } } @@ -369,7 +369,7 @@ impl Bdd { } else if term == Term::BOT { (ModelCounts::bot(), ModelCounts::bot(), 0) } else { - let node = &self.nodes[term.0]; + let node = &self.nodes.borrow()[term.0]; let mut lo_exp = 0u32; let mut hi_exp = 0u32; let (lo_counts, lo_paths, lodepth) = self.modelcount_naive(node.lo()); @@ -405,7 +405,7 @@ impl Bdd { return *result; } let result = { - let node = &self.nodes[term.0]; + let node = &self.nodes.borrow()[term.0]; let mut lo_exp = 0u32; let mut hi_exp = 0u32; let (lo_counts, lo_paths, lodepth) = self.modelcount_memoization(node.lo()); @@ -437,7 +437,7 @@ impl Bdd { } /// Repairs the internal structures after an import. - pub fn fix_import(&mut self) { + pub fn fix_import(&self) { self.generate_var_dependencies(); #[cfg(feature = "adhoccounting")] { @@ -447,25 +447,25 @@ impl Bdd { self.count_cache .borrow_mut() .insert(Term::BOT, (ModelCounts::bot(), ModelCounts::bot(), 0)); - for i in 0..self.nodes.len() { + for i in 0..self.nodes.borrow().len() { log::debug!("fixing Term({})", i); self.modelcount_memoization(Term(i)); } } } - fn generate_var_dependencies(&mut self) { + fn generate_var_dependencies(&self) { #[cfg(feature = "variablelist")] - self.nodes.iter().for_each(|node| { + self.nodes.borrow().iter().for_each(|node| { if node.var() >= Var::BOT { - self.var_deps.push(HashSet::new()); + self.var_deps.borrow_mut().push(HashSet::new()); } else { - let mut var_set: HashSet = self.var_deps[node.lo().value()] - .union(&self.var_deps[node.hi().value()]) + let mut var_set: HashSet = self.var_deps.borrow()[node.lo().value()] + .union(&self.var_deps.borrow()[node.hi().value()]) .copied() .collect(); var_set.insert(node.var()); - self.var_deps.push(var_set); + self.var_deps.borrow_mut().push(var_set); } }); } @@ -474,7 +474,7 @@ impl Bdd { pub fn var_dependencies(&self, tree: Term) -> HashSet { #[cfg(feature = "variablelist")] { - self.var_deps[tree.value()].clone() + self.var_deps.borrow()[tree.value()].clone() } #[cfg(not(feature = "variablelist"))] { @@ -525,7 +525,7 @@ mod test { #[test] fn newbdd() { let bdd = Bdd::new(); - assert_eq!(bdd.nodes.len(), 2); + assert_eq!(bdd.nodes.borrow().len(), 2); } #[test] @@ -534,7 +534,7 @@ mod test { assert_eq!(Bdd::constant(true), Term::TOP); assert_eq!(Bdd::constant(false), Term::BOT); - assert_eq!(bdd.nodes.len(), 2); + assert_eq!(bdd.nodes.borrow().len(), 2); } #[test] @@ -643,10 +643,15 @@ mod test { let formula4 = bdd.and(v3, formula2); assert_eq!(bdd.models(v1, false), (1, 1).into()); - let mut x = bdd.count_cache.get_mut().iter().collect::>(); + let mut x = bdd + .count_cache + .borrow_mut() + .iter() + .map(|(x, y)| (*x, *y)) + .collect::>(); x.sort(); log::debug!("{:?}", formula1); - for x in bdd.nodes.iter().enumerate() { + for x in bdd.nodes.borrow().iter().enumerate() { log::debug!("{:?}", x); } log::debug!("{:?}", x); @@ -757,14 +762,15 @@ mod test { bdd.not(formula4); let constructed = bdd.var_deps.clone(); - bdd.var_deps = Vec::new(); + bdd.var_deps = Rc::new(RefCell::new(Vec::new())); bdd.generate_var_dependencies(); constructed + .borrow() .iter() - .zip(bdd.var_deps.iter()) + .zip(bdd.var_deps.borrow().iter()) .for_each(|(left, right)| { - assert!(left == right); + assert!(*left == *right); }); assert_eq!(