1
0
mirror of https://github.com/ellmau/adf-obdd.git synced 2025-12-19 09:29:36 +01:00

Refactor bdd to use rc refcell

This commit is contained in:
Stefan Ellmauthaler 2022-06-01 13:07:56 +02:00
parent 78949b18e5
commit be68af7d05
Failed to extract signature
4 changed files with 113 additions and 122 deletions

17
Cargo.lock generated
View File

@ -19,26 +19,11 @@ dependencies = [
"test-log",
]
[[package]]
name = "adf_bdd"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e781519ea5434514f014476c02ccee777b28e600ad58fadca195715acb194c69"
dependencies = [
"biodivine-lib-bdd",
"derivative",
"lexical-sort",
"log",
"nom",
"serde",
"serde_json",
]
[[package]]
name = "adf_bdd-solver"
version = "0.2.4"
dependencies = [
"adf_bdd 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)",
"adf_bdd",
"assert_cmd",
"assert_fs",
"clap",

View File

@ -14,7 +14,7 @@ name = "adf-bdd"
path = "src/main.rs"
[dependencies]
adf_bdd = { version="0.2.4", default-features = false }
adf_bdd = { version="0.2.4", path="../lib", default-features = false }
clap = {version = "3.1.14", features = [ "derive", "cargo", "env" ]}
log = { version = "0.4", features = [ "max_level_trace", "release_max_level_info" ] }
serde = { version = "1.0", features = ["derive","rc"] }

View File

@ -946,7 +946,7 @@ mod test {
"s(a). s(b). s(c). s(d). ac(a,c(v)). ac(b,b). ac(c,and(a,b)). ac(d,neg(b)).",
)
.unwrap();
let mut adf = Adf::from_parser(&parser);
let adf = Adf::from_parser(&parser);
let mut v = adf.ac.clone();
let mut fcs = adf.facet_count(&v);

View File

@ -4,6 +4,7 @@ pub mod vectorize;
use crate::datatypes::*;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::rc::Rc;
use std::{cell::RefCell, cmp::min, collections::HashMap, fmt::Display};
/// Contains the data of (possibly) multiple roBDDs, managed over one collection of nodes.
@ -11,20 +12,21 @@ use std::{cell::RefCell, cmp::min, collections::HashMap, fmt::Display};
/// Each roBDD is identified by its corresponding [`Term`], which implicitly identifies the root node of a roBDD.
#[derive(Debug, Serialize, Deserialize)]
pub struct Bdd {
pub(crate) nodes: Vec<BddNode>,
nodes: Rc<RefCell<Vec<BddNode>>>,
#[cfg(feature = "variablelist")]
#[serde(skip)]
var_deps: Vec<HashSet<Var>>,
#[serde(with = "vectorize")]
cache: HashMap<BddNode, Term>,
var_deps: Rc<RefCell<Vec<HashSet<Var>>>>,
//#[serde(with = "vectorize")]
#[serde(skip)]
cache: Rc<RefCell<HashMap<BddNode, Term>>>,
#[serde(skip, default = "Bdd::default_count_cache")]
count_cache: RefCell<HashMap<Term, CountNode>>,
count_cache: Rc<RefCell<HashMap<Term, CountNode>>>,
}
impl Display for Bdd {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, " ")?;
for (idx, elem) in self.nodes.iter().enumerate() {
for (idx, elem) in self.nodes.borrow().iter().enumerate() {
writeln!(f, "{} {}", idx, *elem)?;
}
Ok(())
@ -47,18 +49,18 @@ impl Bdd {
nodes: vec![BddNode::bot_node(), BddNode::top_node()],
#[cfg(feature = "variablelist")]
var_deps: vec![HashSet::new(), HashSet::new()],
cache: HashMap::new(),
cache: Rc::new(RefCell::new(HashMap::new())),
count_cache: RefCell::new(HashMap::new()),
}
}
#[cfg(feature = "adhoccounting")]
{
let result = Self {
nodes: vec![BddNode::bot_node(), BddNode::top_node()],
nodes: Rc::new(RefCell::new(vec![BddNode::bot_node(), BddNode::top_node()])),
#[cfg(feature = "variablelist")]
var_deps: vec![HashSet::new(), HashSet::new()],
cache: HashMap::new(),
count_cache: RefCell::new(HashMap::new()),
var_deps: Rc::new(RefCell::new(vec![HashSet::new(), HashSet::new()])),
cache: Rc::new(RefCell::new(HashMap::new())),
count_cache: Rc::new(RefCell::new(HashMap::new())),
};
result
.count_cache
@ -72,8 +74,8 @@ impl Bdd {
}
}
fn default_count_cache() -> RefCell<HashMap<Term, CountNode>> {
RefCell::new(HashMap::new())
fn default_count_cache() -> Rc<RefCell<HashMap<Term, CountNode>>> {
Rc::new(RefCell::new(HashMap::new()))
}
/// Instantiates a [variable][crate::datatypes::Var] and returns the representing roBDD as a [`Term`][crate::datatypes::Term].
@ -134,7 +136,7 @@ impl Bdd {
positive: &[Var],
) -> Vec<(Vec<Var>, Vec<Var>)> {
let mut result = Vec::new();
let node = self.nodes[tree.value()];
let node = self.nodes.borrow()[tree.value()];
let var = node.var();
if tree.is_truth_value() {
return Vec::new();
@ -175,11 +177,11 @@ impl Bdd {
}
/// Restrict the value of a given [variable][crate::datatypes::Var] to **val**.
pub fn restrict(&mut self, tree: Term, var: Var, val: bool) -> Term {
let node = self.nodes[tree.0];
pub fn restrict(&self, tree: Term, var: Var, val: bool) -> Term {
let node = self.nodes.borrow()[tree.0];
#[cfg(feature = "variablelist")]
{
if !self.var_deps[tree.value()].contains(&var) {
if !self.var_deps.borrow()[tree.value()].contains(&var) {
return tree;
}
}
@ -201,7 +203,7 @@ impl Bdd {
}
/// Creates an roBDD, based on the relation of three roBDDs, which are in an `if-then-else` relation.
fn if_then_else(&mut self, i: Term, t: Term, e: Term) -> Term {
fn if_then_else(&self, i: Term, t: Term, e: Term) -> Term {
if i == Term::TOP {
t
} else if i == Term::BOT {
@ -212,10 +214,10 @@ impl Bdd {
i
} else {
let minvar = Var(min(
self.nodes[i.value()].var().value(),
self.nodes.borrow()[i.value()].var().value(),
min(
self.nodes[t.value()].var().value(),
self.nodes[e.value()].var().value(),
self.nodes.borrow()[t.value()].var().value(),
self.nodes.borrow()[e.value()].var().value(),
),
));
let itop = self.restrict(i, minvar, true);
@ -233,76 +235,74 @@ impl Bdd {
/// Creates a new node in the roBDD.
/// It will not create duplicate nodes and uses already existing nodes, if applicable.
pub fn node(&mut self, var: Var, lo: Term, hi: Term) -> Term {
pub fn node(&self, var: Var, lo: Term, hi: Term) -> Term {
if lo == hi {
lo
} else {
let node = BddNode::new(var, lo, hi);
match self.cache.get(&node) {
Some(t) => *t,
None => {
let new_term = Term(self.nodes.len());
self.nodes.push(node);
self.cache.insert(node, new_term);
#[cfg(feature = "variablelist")]
{
let mut var_set: HashSet<Var> = self.var_deps[lo.value()]
.union(&self.var_deps[hi.value()])
.copied()
.collect();
var_set.insert(var);
self.var_deps.push(var_set);
}
log::debug!("newterm: {} as {:?}", new_term, node);
#[cfg(feature = "adhoccounting")]
{
let mut count_cache = self.count_cache.borrow_mut();
let (lo_counts, lo_paths, lodepth) =
*count_cache.get(&lo).expect("Cache corrupted");
let (hi_counts, hi_paths, hidepth) =
*count_cache.get(&hi).expect("Cache corrupted");
log::debug!(
"lo (cm: {}, mo: {}, p-: {}, p+: {}, dp: {})",
lo_counts.cmodels,
lo_counts.models,
lo_paths.cmodels,
lo_paths.models,
lodepth
);
log::debug!(
"hi (cm: {}, mo: {}, p-: {}, p+: {}, dp: {})",
hi_counts.cmodels,
hi_counts.models,
hi_paths.cmodels,
hi_paths.models,
hidepth
);
let (lo_exp, hi_exp) = if lodepth > hidepth {
(1, 2usize.pow((lodepth - hidepth) as u32))
} else {
(2usize.pow((hidepth - lodepth) as u32), 1)
};
log::debug!("lo_exp {}, hi_exp {}", lo_exp, hi_exp);
count_cache.insert(
new_term,
(
(
lo_counts.cmodels * lo_exp + hi_counts.cmodels * hi_exp,
lo_counts.models * lo_exp + hi_counts.models * hi_exp,
)
.into(),
(
lo_paths.cmodels + hi_paths.cmodels,
lo_paths.models + hi_paths.models,
)
.into(),
std::cmp::max(lodepth, hidepth) + 1,
),
);
}
new_term
}
if let Some(t) = self.cache.borrow().get(&node) {
return *t;
}
let new_term = Term(self.nodes.borrow().len());
self.nodes.borrow_mut().push(node);
self.cache.borrow_mut().insert(node, new_term);
#[cfg(feature = "variablelist")]
{
let mut var_set: HashSet<Var> = self.var_deps.borrow()[lo.value()]
.union(&self.var_deps.borrow()[hi.value()])
.copied()
.collect();
var_set.insert(var);
self.var_deps.borrow_mut().push(var_set);
}
log::debug!("newterm: {} as {:?}", new_term, node);
#[cfg(feature = "adhoccounting")]
{
let mut count_cache = self.count_cache.borrow_mut();
let (lo_counts, lo_paths, lodepth) =
*count_cache.get(&lo).expect("Cache corrupted");
let (hi_counts, hi_paths, hidepth) =
*count_cache.get(&hi).expect("Cache corrupted");
log::debug!(
"lo (cm: {}, mo: {}, p-: {}, p+: {}, dp: {})",
lo_counts.cmodels,
lo_counts.models,
lo_paths.cmodels,
lo_paths.models,
lodepth
);
log::debug!(
"hi (cm: {}, mo: {}, p-: {}, p+: {}, dp: {})",
hi_counts.cmodels,
hi_counts.models,
hi_paths.cmodels,
hi_paths.models,
hidepth
);
let (lo_exp, hi_exp) = if lodepth > hidepth {
(1, 2usize.pow((lodepth - hidepth) as u32))
} else {
(2usize.pow((hidepth - lodepth) as u32), 1)
};
log::debug!("lo_exp {}, hi_exp {}", lo_exp, hi_exp);
count_cache.insert(
new_term,
(
(
lo_counts.cmodels * lo_exp + hi_counts.cmodels * hi_exp,
lo_counts.models * lo_exp + hi_counts.models * hi_exp,
)
.into(),
(
lo_paths.cmodels + hi_paths.cmodels,
lo_paths.models + hi_paths.models,
)
.into(),
std::cmp::max(lodepth, hidepth) + 1,
),
);
}
new_term
}
}
@ -369,7 +369,7 @@ impl Bdd {
} else if term == Term::BOT {
(ModelCounts::bot(), ModelCounts::bot(), 0)
} else {
let node = &self.nodes[term.0];
let node = &self.nodes.borrow()[term.0];
let mut lo_exp = 0u32;
let mut hi_exp = 0u32;
let (lo_counts, lo_paths, lodepth) = self.modelcount_naive(node.lo());
@ -405,7 +405,7 @@ impl Bdd {
return *result;
}
let result = {
let node = &self.nodes[term.0];
let node = &self.nodes.borrow()[term.0];
let mut lo_exp = 0u32;
let mut hi_exp = 0u32;
let (lo_counts, lo_paths, lodepth) = self.modelcount_memoization(node.lo());
@ -437,7 +437,7 @@ impl Bdd {
}
/// Repairs the internal structures after an import.
pub fn fix_import(&mut self) {
pub fn fix_import(&self) {
self.generate_var_dependencies();
#[cfg(feature = "adhoccounting")]
{
@ -447,25 +447,25 @@ impl Bdd {
self.count_cache
.borrow_mut()
.insert(Term::BOT, (ModelCounts::bot(), ModelCounts::bot(), 0));
for i in 0..self.nodes.len() {
for i in 0..self.nodes.borrow().len() {
log::debug!("fixing Term({})", i);
self.modelcount_memoization(Term(i));
}
}
}
fn generate_var_dependencies(&mut self) {
fn generate_var_dependencies(&self) {
#[cfg(feature = "variablelist")]
self.nodes.iter().for_each(|node| {
self.nodes.borrow().iter().for_each(|node| {
if node.var() >= Var::BOT {
self.var_deps.push(HashSet::new());
self.var_deps.borrow_mut().push(HashSet::new());
} else {
let mut var_set: HashSet<Var> = self.var_deps[node.lo().value()]
.union(&self.var_deps[node.hi().value()])
let mut var_set: HashSet<Var> = self.var_deps.borrow()[node.lo().value()]
.union(&self.var_deps.borrow()[node.hi().value()])
.copied()
.collect();
var_set.insert(node.var());
self.var_deps.push(var_set);
self.var_deps.borrow_mut().push(var_set);
}
});
}
@ -474,7 +474,7 @@ impl Bdd {
pub fn var_dependencies(&self, tree: Term) -> HashSet<Var> {
#[cfg(feature = "variablelist")]
{
self.var_deps[tree.value()].clone()
self.var_deps.borrow()[tree.value()].clone()
}
#[cfg(not(feature = "variablelist"))]
{
@ -525,7 +525,7 @@ mod test {
#[test]
fn newbdd() {
let bdd = Bdd::new();
assert_eq!(bdd.nodes.len(), 2);
assert_eq!(bdd.nodes.borrow().len(), 2);
}
#[test]
@ -534,7 +534,7 @@ mod test {
assert_eq!(Bdd::constant(true), Term::TOP);
assert_eq!(Bdd::constant(false), Term::BOT);
assert_eq!(bdd.nodes.len(), 2);
assert_eq!(bdd.nodes.borrow().len(), 2);
}
#[test]
@ -643,10 +643,15 @@ mod test {
let formula4 = bdd.and(v3, formula2);
assert_eq!(bdd.models(v1, false), (1, 1).into());
let mut x = bdd.count_cache.get_mut().iter().collect::<Vec<_>>();
let mut x = bdd
.count_cache
.borrow_mut()
.iter()
.map(|(x, y)| (*x, *y))
.collect::<Vec<_>>();
x.sort();
log::debug!("{:?}", formula1);
for x in bdd.nodes.iter().enumerate() {
for x in bdd.nodes.borrow().iter().enumerate() {
log::debug!("{:?}", x);
}
log::debug!("{:?}", x);
@ -757,14 +762,15 @@ mod test {
bdd.not(formula4);
let constructed = bdd.var_deps.clone();
bdd.var_deps = Vec::new();
bdd.var_deps = Rc::new(RefCell::new(Vec::new()));
bdd.generate_var_dependencies();
constructed
.borrow()
.iter()
.zip(bdd.var_deps.iter())
.zip(bdd.var_deps.borrow().iter())
.for_each(|(left, right)| {
assert!(left == right);
assert!(*left == *right);
});
assert_eq!(