mirror of
https://github.com/ellmau/adf-obdd.git
synced 2025-12-19 09:29:36 +01:00
Refactor bdd to use rc refcell
This commit is contained in:
parent
78949b18e5
commit
be68af7d05
17
Cargo.lock
generated
17
Cargo.lock
generated
@ -19,26 +19,11 @@ dependencies = [
|
||||
"test-log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "adf_bdd"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e781519ea5434514f014476c02ccee777b28e600ad58fadca195715acb194c69"
|
||||
dependencies = [
|
||||
"biodivine-lib-bdd",
|
||||
"derivative",
|
||||
"lexical-sort",
|
||||
"log",
|
||||
"nom",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "adf_bdd-solver"
|
||||
version = "0.2.4"
|
||||
dependencies = [
|
||||
"adf_bdd 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"adf_bdd",
|
||||
"assert_cmd",
|
||||
"assert_fs",
|
||||
"clap",
|
||||
|
||||
@ -14,7 +14,7 @@ name = "adf-bdd"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
adf_bdd = { version="0.2.4", default-features = false }
|
||||
adf_bdd = { version="0.2.4", path="../lib", default-features = false }
|
||||
clap = {version = "3.1.14", features = [ "derive", "cargo", "env" ]}
|
||||
log = { version = "0.4", features = [ "max_level_trace", "release_max_level_info" ] }
|
||||
serde = { version = "1.0", features = ["derive","rc"] }
|
||||
|
||||
@ -946,7 +946,7 @@ mod test {
|
||||
"s(a). s(b). s(c). s(d). ac(a,c(v)). ac(b,b). ac(c,and(a,b)). ac(d,neg(b)).",
|
||||
)
|
||||
.unwrap();
|
||||
let mut adf = Adf::from_parser(&parser);
|
||||
let adf = Adf::from_parser(&parser);
|
||||
|
||||
let mut v = adf.ac.clone();
|
||||
let mut fcs = adf.facet_count(&v);
|
||||
|
||||
214
lib/src/obdd.rs
214
lib/src/obdd.rs
@ -4,6 +4,7 @@ pub mod vectorize;
|
||||
use crate::datatypes::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashSet;
|
||||
use std::rc::Rc;
|
||||
use std::{cell::RefCell, cmp::min, collections::HashMap, fmt::Display};
|
||||
|
||||
/// Contains the data of (possibly) multiple roBDDs, managed over one collection of nodes.
|
||||
@ -11,20 +12,21 @@ use std::{cell::RefCell, cmp::min, collections::HashMap, fmt::Display};
|
||||
/// Each roBDD is identified by its corresponding [`Term`], which implicitly identifies the root node of a roBDD.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Bdd {
|
||||
pub(crate) nodes: Vec<BddNode>,
|
||||
nodes: Rc<RefCell<Vec<BddNode>>>,
|
||||
#[cfg(feature = "variablelist")]
|
||||
#[serde(skip)]
|
||||
var_deps: Vec<HashSet<Var>>,
|
||||
#[serde(with = "vectorize")]
|
||||
cache: HashMap<BddNode, Term>,
|
||||
var_deps: Rc<RefCell<Vec<HashSet<Var>>>>,
|
||||
//#[serde(with = "vectorize")]
|
||||
#[serde(skip)]
|
||||
cache: Rc<RefCell<HashMap<BddNode, Term>>>,
|
||||
#[serde(skip, default = "Bdd::default_count_cache")]
|
||||
count_cache: RefCell<HashMap<Term, CountNode>>,
|
||||
count_cache: Rc<RefCell<HashMap<Term, CountNode>>>,
|
||||
}
|
||||
|
||||
impl Display for Bdd {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
writeln!(f, " ")?;
|
||||
for (idx, elem) in self.nodes.iter().enumerate() {
|
||||
for (idx, elem) in self.nodes.borrow().iter().enumerate() {
|
||||
writeln!(f, "{} {}", idx, *elem)?;
|
||||
}
|
||||
Ok(())
|
||||
@ -47,18 +49,18 @@ impl Bdd {
|
||||
nodes: vec![BddNode::bot_node(), BddNode::top_node()],
|
||||
#[cfg(feature = "variablelist")]
|
||||
var_deps: vec![HashSet::new(), HashSet::new()],
|
||||
cache: HashMap::new(),
|
||||
cache: Rc::new(RefCell::new(HashMap::new())),
|
||||
count_cache: RefCell::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "adhoccounting")]
|
||||
{
|
||||
let result = Self {
|
||||
nodes: vec![BddNode::bot_node(), BddNode::top_node()],
|
||||
nodes: Rc::new(RefCell::new(vec![BddNode::bot_node(), BddNode::top_node()])),
|
||||
#[cfg(feature = "variablelist")]
|
||||
var_deps: vec![HashSet::new(), HashSet::new()],
|
||||
cache: HashMap::new(),
|
||||
count_cache: RefCell::new(HashMap::new()),
|
||||
var_deps: Rc::new(RefCell::new(vec![HashSet::new(), HashSet::new()])),
|
||||
cache: Rc::new(RefCell::new(HashMap::new())),
|
||||
count_cache: Rc::new(RefCell::new(HashMap::new())),
|
||||
};
|
||||
result
|
||||
.count_cache
|
||||
@ -72,8 +74,8 @@ impl Bdd {
|
||||
}
|
||||
}
|
||||
|
||||
fn default_count_cache() -> RefCell<HashMap<Term, CountNode>> {
|
||||
RefCell::new(HashMap::new())
|
||||
fn default_count_cache() -> Rc<RefCell<HashMap<Term, CountNode>>> {
|
||||
Rc::new(RefCell::new(HashMap::new()))
|
||||
}
|
||||
|
||||
/// Instantiates a [variable][crate::datatypes::Var] and returns the representing roBDD as a [`Term`][crate::datatypes::Term].
|
||||
@ -134,7 +136,7 @@ impl Bdd {
|
||||
positive: &[Var],
|
||||
) -> Vec<(Vec<Var>, Vec<Var>)> {
|
||||
let mut result = Vec::new();
|
||||
let node = self.nodes[tree.value()];
|
||||
let node = self.nodes.borrow()[tree.value()];
|
||||
let var = node.var();
|
||||
if tree.is_truth_value() {
|
||||
return Vec::new();
|
||||
@ -175,11 +177,11 @@ impl Bdd {
|
||||
}
|
||||
|
||||
/// Restrict the value of a given [variable][crate::datatypes::Var] to **val**.
|
||||
pub fn restrict(&mut self, tree: Term, var: Var, val: bool) -> Term {
|
||||
let node = self.nodes[tree.0];
|
||||
pub fn restrict(&self, tree: Term, var: Var, val: bool) -> Term {
|
||||
let node = self.nodes.borrow()[tree.0];
|
||||
#[cfg(feature = "variablelist")]
|
||||
{
|
||||
if !self.var_deps[tree.value()].contains(&var) {
|
||||
if !self.var_deps.borrow()[tree.value()].contains(&var) {
|
||||
return tree;
|
||||
}
|
||||
}
|
||||
@ -201,7 +203,7 @@ impl Bdd {
|
||||
}
|
||||
|
||||
/// Creates an roBDD, based on the relation of three roBDDs, which are in an `if-then-else` relation.
|
||||
fn if_then_else(&mut self, i: Term, t: Term, e: Term) -> Term {
|
||||
fn if_then_else(&self, i: Term, t: Term, e: Term) -> Term {
|
||||
if i == Term::TOP {
|
||||
t
|
||||
} else if i == Term::BOT {
|
||||
@ -212,10 +214,10 @@ impl Bdd {
|
||||
i
|
||||
} else {
|
||||
let minvar = Var(min(
|
||||
self.nodes[i.value()].var().value(),
|
||||
self.nodes.borrow()[i.value()].var().value(),
|
||||
min(
|
||||
self.nodes[t.value()].var().value(),
|
||||
self.nodes[e.value()].var().value(),
|
||||
self.nodes.borrow()[t.value()].var().value(),
|
||||
self.nodes.borrow()[e.value()].var().value(),
|
||||
),
|
||||
));
|
||||
let itop = self.restrict(i, minvar, true);
|
||||
@ -233,76 +235,74 @@ impl Bdd {
|
||||
|
||||
/// Creates a new node in the roBDD.
|
||||
/// It will not create duplicate nodes and uses already existing nodes, if applicable.
|
||||
pub fn node(&mut self, var: Var, lo: Term, hi: Term) -> Term {
|
||||
pub fn node(&self, var: Var, lo: Term, hi: Term) -> Term {
|
||||
if lo == hi {
|
||||
lo
|
||||
} else {
|
||||
let node = BddNode::new(var, lo, hi);
|
||||
match self.cache.get(&node) {
|
||||
Some(t) => *t,
|
||||
None => {
|
||||
let new_term = Term(self.nodes.len());
|
||||
self.nodes.push(node);
|
||||
self.cache.insert(node, new_term);
|
||||
#[cfg(feature = "variablelist")]
|
||||
{
|
||||
let mut var_set: HashSet<Var> = self.var_deps[lo.value()]
|
||||
.union(&self.var_deps[hi.value()])
|
||||
.copied()
|
||||
.collect();
|
||||
var_set.insert(var);
|
||||
self.var_deps.push(var_set);
|
||||
}
|
||||
log::debug!("newterm: {} as {:?}", new_term, node);
|
||||
#[cfg(feature = "adhoccounting")]
|
||||
{
|
||||
let mut count_cache = self.count_cache.borrow_mut();
|
||||
let (lo_counts, lo_paths, lodepth) =
|
||||
*count_cache.get(&lo).expect("Cache corrupted");
|
||||
let (hi_counts, hi_paths, hidepth) =
|
||||
*count_cache.get(&hi).expect("Cache corrupted");
|
||||
log::debug!(
|
||||
"lo (cm: {}, mo: {}, p-: {}, p+: {}, dp: {})",
|
||||
lo_counts.cmodels,
|
||||
lo_counts.models,
|
||||
lo_paths.cmodels,
|
||||
lo_paths.models,
|
||||
lodepth
|
||||
);
|
||||
log::debug!(
|
||||
"hi (cm: {}, mo: {}, p-: {}, p+: {}, dp: {})",
|
||||
hi_counts.cmodels,
|
||||
hi_counts.models,
|
||||
hi_paths.cmodels,
|
||||
hi_paths.models,
|
||||
hidepth
|
||||
);
|
||||
let (lo_exp, hi_exp) = if lodepth > hidepth {
|
||||
(1, 2usize.pow((lodepth - hidepth) as u32))
|
||||
} else {
|
||||
(2usize.pow((hidepth - lodepth) as u32), 1)
|
||||
};
|
||||
log::debug!("lo_exp {}, hi_exp {}", lo_exp, hi_exp);
|
||||
count_cache.insert(
|
||||
new_term,
|
||||
(
|
||||
(
|
||||
lo_counts.cmodels * lo_exp + hi_counts.cmodels * hi_exp,
|
||||
lo_counts.models * lo_exp + hi_counts.models * hi_exp,
|
||||
)
|
||||
.into(),
|
||||
(
|
||||
lo_paths.cmodels + hi_paths.cmodels,
|
||||
lo_paths.models + hi_paths.models,
|
||||
)
|
||||
.into(),
|
||||
std::cmp::max(lodepth, hidepth) + 1,
|
||||
),
|
||||
);
|
||||
}
|
||||
new_term
|
||||
}
|
||||
if let Some(t) = self.cache.borrow().get(&node) {
|
||||
return *t;
|
||||
}
|
||||
let new_term = Term(self.nodes.borrow().len());
|
||||
self.nodes.borrow_mut().push(node);
|
||||
self.cache.borrow_mut().insert(node, new_term);
|
||||
#[cfg(feature = "variablelist")]
|
||||
{
|
||||
let mut var_set: HashSet<Var> = self.var_deps.borrow()[lo.value()]
|
||||
.union(&self.var_deps.borrow()[hi.value()])
|
||||
.copied()
|
||||
.collect();
|
||||
var_set.insert(var);
|
||||
self.var_deps.borrow_mut().push(var_set);
|
||||
}
|
||||
log::debug!("newterm: {} as {:?}", new_term, node);
|
||||
#[cfg(feature = "adhoccounting")]
|
||||
{
|
||||
let mut count_cache = self.count_cache.borrow_mut();
|
||||
let (lo_counts, lo_paths, lodepth) =
|
||||
*count_cache.get(&lo).expect("Cache corrupted");
|
||||
let (hi_counts, hi_paths, hidepth) =
|
||||
*count_cache.get(&hi).expect("Cache corrupted");
|
||||
log::debug!(
|
||||
"lo (cm: {}, mo: {}, p-: {}, p+: {}, dp: {})",
|
||||
lo_counts.cmodels,
|
||||
lo_counts.models,
|
||||
lo_paths.cmodels,
|
||||
lo_paths.models,
|
||||
lodepth
|
||||
);
|
||||
log::debug!(
|
||||
"hi (cm: {}, mo: {}, p-: {}, p+: {}, dp: {})",
|
||||
hi_counts.cmodels,
|
||||
hi_counts.models,
|
||||
hi_paths.cmodels,
|
||||
hi_paths.models,
|
||||
hidepth
|
||||
);
|
||||
let (lo_exp, hi_exp) = if lodepth > hidepth {
|
||||
(1, 2usize.pow((lodepth - hidepth) as u32))
|
||||
} else {
|
||||
(2usize.pow((hidepth - lodepth) as u32), 1)
|
||||
};
|
||||
log::debug!("lo_exp {}, hi_exp {}", lo_exp, hi_exp);
|
||||
count_cache.insert(
|
||||
new_term,
|
||||
(
|
||||
(
|
||||
lo_counts.cmodels * lo_exp + hi_counts.cmodels * hi_exp,
|
||||
lo_counts.models * lo_exp + hi_counts.models * hi_exp,
|
||||
)
|
||||
.into(),
|
||||
(
|
||||
lo_paths.cmodels + hi_paths.cmodels,
|
||||
lo_paths.models + hi_paths.models,
|
||||
)
|
||||
.into(),
|
||||
std::cmp::max(lodepth, hidepth) + 1,
|
||||
),
|
||||
);
|
||||
}
|
||||
new_term
|
||||
}
|
||||
}
|
||||
|
||||
@ -369,7 +369,7 @@ impl Bdd {
|
||||
} else if term == Term::BOT {
|
||||
(ModelCounts::bot(), ModelCounts::bot(), 0)
|
||||
} else {
|
||||
let node = &self.nodes[term.0];
|
||||
let node = &self.nodes.borrow()[term.0];
|
||||
let mut lo_exp = 0u32;
|
||||
let mut hi_exp = 0u32;
|
||||
let (lo_counts, lo_paths, lodepth) = self.modelcount_naive(node.lo());
|
||||
@ -405,7 +405,7 @@ impl Bdd {
|
||||
return *result;
|
||||
}
|
||||
let result = {
|
||||
let node = &self.nodes[term.0];
|
||||
let node = &self.nodes.borrow()[term.0];
|
||||
let mut lo_exp = 0u32;
|
||||
let mut hi_exp = 0u32;
|
||||
let (lo_counts, lo_paths, lodepth) = self.modelcount_memoization(node.lo());
|
||||
@ -437,7 +437,7 @@ impl Bdd {
|
||||
}
|
||||
|
||||
/// Repairs the internal structures after an import.
|
||||
pub fn fix_import(&mut self) {
|
||||
pub fn fix_import(&self) {
|
||||
self.generate_var_dependencies();
|
||||
#[cfg(feature = "adhoccounting")]
|
||||
{
|
||||
@ -447,25 +447,25 @@ impl Bdd {
|
||||
self.count_cache
|
||||
.borrow_mut()
|
||||
.insert(Term::BOT, (ModelCounts::bot(), ModelCounts::bot(), 0));
|
||||
for i in 0..self.nodes.len() {
|
||||
for i in 0..self.nodes.borrow().len() {
|
||||
log::debug!("fixing Term({})", i);
|
||||
self.modelcount_memoization(Term(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_var_dependencies(&mut self) {
|
||||
fn generate_var_dependencies(&self) {
|
||||
#[cfg(feature = "variablelist")]
|
||||
self.nodes.iter().for_each(|node| {
|
||||
self.nodes.borrow().iter().for_each(|node| {
|
||||
if node.var() >= Var::BOT {
|
||||
self.var_deps.push(HashSet::new());
|
||||
self.var_deps.borrow_mut().push(HashSet::new());
|
||||
} else {
|
||||
let mut var_set: HashSet<Var> = self.var_deps[node.lo().value()]
|
||||
.union(&self.var_deps[node.hi().value()])
|
||||
let mut var_set: HashSet<Var> = self.var_deps.borrow()[node.lo().value()]
|
||||
.union(&self.var_deps.borrow()[node.hi().value()])
|
||||
.copied()
|
||||
.collect();
|
||||
var_set.insert(node.var());
|
||||
self.var_deps.push(var_set);
|
||||
self.var_deps.borrow_mut().push(var_set);
|
||||
}
|
||||
});
|
||||
}
|
||||
@ -474,7 +474,7 @@ impl Bdd {
|
||||
pub fn var_dependencies(&self, tree: Term) -> HashSet<Var> {
|
||||
#[cfg(feature = "variablelist")]
|
||||
{
|
||||
self.var_deps[tree.value()].clone()
|
||||
self.var_deps.borrow()[tree.value()].clone()
|
||||
}
|
||||
#[cfg(not(feature = "variablelist"))]
|
||||
{
|
||||
@ -525,7 +525,7 @@ mod test {
|
||||
#[test]
|
||||
fn newbdd() {
|
||||
let bdd = Bdd::new();
|
||||
assert_eq!(bdd.nodes.len(), 2);
|
||||
assert_eq!(bdd.nodes.borrow().len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -534,7 +534,7 @@ mod test {
|
||||
|
||||
assert_eq!(Bdd::constant(true), Term::TOP);
|
||||
assert_eq!(Bdd::constant(false), Term::BOT);
|
||||
assert_eq!(bdd.nodes.len(), 2);
|
||||
assert_eq!(bdd.nodes.borrow().len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -643,10 +643,15 @@ mod test {
|
||||
let formula4 = bdd.and(v3, formula2);
|
||||
|
||||
assert_eq!(bdd.models(v1, false), (1, 1).into());
|
||||
let mut x = bdd.count_cache.get_mut().iter().collect::<Vec<_>>();
|
||||
let mut x = bdd
|
||||
.count_cache
|
||||
.borrow_mut()
|
||||
.iter()
|
||||
.map(|(x, y)| (*x, *y))
|
||||
.collect::<Vec<_>>();
|
||||
x.sort();
|
||||
log::debug!("{:?}", formula1);
|
||||
for x in bdd.nodes.iter().enumerate() {
|
||||
for x in bdd.nodes.borrow().iter().enumerate() {
|
||||
log::debug!("{:?}", x);
|
||||
}
|
||||
log::debug!("{:?}", x);
|
||||
@ -757,14 +762,15 @@ mod test {
|
||||
bdd.not(formula4);
|
||||
|
||||
let constructed = bdd.var_deps.clone();
|
||||
bdd.var_deps = Vec::new();
|
||||
bdd.var_deps = Rc::new(RefCell::new(Vec::new()));
|
||||
bdd.generate_var_dependencies();
|
||||
|
||||
constructed
|
||||
.borrow()
|
||||
.iter()
|
||||
.zip(bdd.var_deps.iter())
|
||||
.zip(bdd.var_deps.borrow().iter())
|
||||
.for_each(|(left, right)| {
|
||||
assert!(left == right);
|
||||
assert!(*left == *right);
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user