diff --git a/lib/src/obdd.rs b/lib/src/obdd.rs index 0a653c4..10acea2 100644 --- a/lib/src/obdd.rs +++ b/lib/src/obdd.rs @@ -19,6 +19,10 @@ pub struct Bdd { cache: HashMap, #[serde(skip, default = "Bdd::default_count_cache")] count_cache: RefCell>, + #[serde(skip)] + ite_cache: HashMap<(Term, Term, Term), Term>, + #[serde(skip)] + restrict_cache: HashMap<(Term, Var, bool), Term>, } impl Display for Bdd { @@ -49,6 +53,8 @@ impl Bdd { var_deps: vec![HashSet::new(), HashSet::new()], cache: HashMap::new(), count_cache: RefCell::new(HashMap::new()), + ite_cache: HashMap::new(), + restrict_cache: HashMap::new(), } } #[cfg(feature = "adhoccounting")] @@ -59,6 +65,8 @@ impl Bdd { var_deps: vec![HashSet::new(), HashSet::new()], cache: HashMap::new(), count_cache: RefCell::new(HashMap::new()), + ite_cache: HashMap::new(), + restrict_cache: HashMap::new(), }; result .count_cache @@ -176,26 +184,34 @@ impl Bdd { /// Restrict the value of a given [variable][crate::datatypes::Var] to **val**. pub fn restrict(&mut self, tree: Term, var: Var, val: bool) -> Term { - let node = self.nodes[tree.0]; - #[cfg(feature = "variablelist")] - { - if !self.var_deps[tree.value()].contains(&var) { - return tree; - } - } - #[allow(clippy::collapsible_else_if)] - // Readability of algorithm > code-elegance - if node.var() > var || node.var() >= Var::BOT { - tree - } else if node.var() < var { - let lonode = self.restrict(node.lo(), var, val); - let hinode = self.restrict(node.hi(), var, val); - self.node(node.var(), lonode, hinode) + if let Some(result) = self.restrict_cache.get(&(tree, var, val)) { + *result } else { - if val { - self.restrict(node.hi(), var, val) + let node = self.nodes[tree.0]; + #[cfg(feature = "variablelist")] + { + if !self.var_deps[tree.value()].contains(&var) { + return tree; + } + } + #[allow(clippy::collapsible_else_if)] + // Readability of algorithm > code-elegance + if node.var() > var || node.var() >= Var::BOT { + tree + } else if node.var() < var { + let lonode = self.restrict(node.lo(), var, val); + let hinode = self.restrict(node.hi(), var, val); + self.node(node.var(), lonode, hinode) } else { - self.restrict(node.lo(), var, val) + if val { + let result = self.restrict(node.hi(), var, val); + self.restrict_cache.insert((tree, var, val), result); + result + } else { + let result = self.restrict(node.lo(), var, val); + self.restrict_cache.insert((tree, var, val), result); + result + } } } } @@ -210,6 +226,8 @@ impl Bdd { t } else if t == Term::TOP && e == Term::BOT { i + } else if let Some(result) = self.ite_cache.get(&(i, t, e)) { + *result } else { let minvar = Var(min( self.nodes[i.value()].var().value(), @@ -227,7 +245,9 @@ impl Bdd { let top_ite = self.if_then_else(itop, ttop, etop); let bot_ite = self.if_then_else(ibot, tbot, ebot); - self.node(minvar, bot_ite, top_ite) + let result = self.node(minvar, bot_ite, top_ite); + self.ite_cache.insert((i, t, e), result); + result } }