diff --git a/Cargo.lock b/Cargo.lock index 449d92c..838842a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,7 +4,7 @@ version = 3 [[package]] name = "adf_bdd" -version = "0.1.3" +version = "0.1.4" dependencies = [ "assert_cmd", "assert_fs", diff --git a/src/adf.rs b/src/adf.rs index dd5334b..cb881ae 100644 --- a/src/adf.rs +++ b/src/adf.rs @@ -234,20 +234,20 @@ impl Adf { 'a: 'c, 'b: 'c, { - ThreeValuedInterpretationsIterator::new(grounded).filter(|interpretation| { - interpretation.iter().all(|ac| { - ac.compare_inf( - &interpretation - .iter() - .enumerate() - .fold(*ac, |acc, (var, term)| { - if term.is_truth_value() { - self.bdd.restrict(acc, Var(var), term.is_true()) - } else { - acc - } - }), - ) + let ac = self.ac.clone(); + ThreeValuedInterpretationsIterator::new(grounded).filter(move |interpretation| { + interpretation.iter().enumerate().all(|(ac_idx, it)| { + log::trace!("idx [{}], term: {}", ac_idx, it); + it.compare_inf(&interpretation.iter().enumerate().fold( + ac[ac_idx], + |acc, (var, term)| { + if term.is_truth_value() { + self.bdd.restrict(acc, Var(var), term.is_true()) + } else { + acc + } + }, + )) }) }) } @@ -421,21 +421,30 @@ mod test { assert_eq!( adf.complete(0), - vec![ - vec![Term(1), Term(3), Term(3), Term(9), Term(0), Term(1)], - vec![Term(1), Term(3), Term(3), Term(1), Term(0), Term(1)], - vec![Term(1), Term(3), Term(3), Term(0), Term(0), Term(1)], - vec![Term(1), Term(3), Term(1), Term(9), Term(0), Term(1)], - vec![Term(1), Term(3), Term(1), Term(1), Term(0), Term(1)], - vec![Term(1), Term(3), Term(1), Term(0), Term(0), Term(1)], - vec![Term(1), Term(3), Term(0), Term(0), Term(0), Term(1)], - vec![Term(1), Term(1), Term(1), Term(1), Term(0), Term(1)], - vec![Term(1), Term(1), Term(1), Term(0), Term(0), Term(1)], - vec![Term(1), Term(1), Term(0), Term(0), Term(0), Term(1)], - vec![Term(1), Term(0), Term(0), Term(0), Term(0), Term(1)], - vec![Term(1), Term(0), Term(1), Term(1), Term(0), Term(1)], - vec![Term(1), Term(0), Term(1), Term(0), Term(0), Term(1)] + [ + [Term(1), Term(3), Term(3), Term(9), Term(0), Term(1)], + [Term(1), Term(1), Term(1), Term(0), Term(0), Term(1)], + [Term(1), Term(0), Term(0), Term(1), Term(0), Term(1)] ] ); } + + #[test] + fn complete2() { + let parser = AdfParser::default(); + parser.parse()("s(a).s(b).s(c).s(d).ac(a,c(v)).ac(b,b).ac(c,and(a,b)).ac(d,neg(b)).") + .unwrap(); + let mut adf = Adf::from_parser(&parser); + assert_eq!( + adf.complete(0), + [ + [Term(1), Term(3), Term(3), Term(7)], + [Term(1), Term(1), Term(1), Term(0)], + [Term(1), Term(0), Term(0), Term(1)] + ] + ); + for model in adf.complete(0) { + println!("{}", adf.print_interpretation(&model)); + } + } } diff --git a/src/datatypes/adf.rs b/src/datatypes/adf.rs index c3b73b0..45f762a 100644 --- a/src/datatypes/adf.rs +++ b/src/datatypes/adf.rs @@ -176,29 +176,51 @@ impl ThreeValuedInterpretationsIterator { fn decrement(&mut self) { if let Some(ref mut current) = self.current { - if let Some((pos, val)) = current.iter().enumerate().find(|(idx, val)| **val > 0) { - if pos > 0 && *val == 2 { - for elem in &mut current[0..pos] { - *elem = 2; - } - } - current[pos] -= 1; - if self.last_iteration { - if current.iter().all(|val| *val == 0) { - self.current = None; - } - } - } else if !self.last_iteration { - let len = current.len(); - if len <= 1 { - self.current = None; - } else { - for elem in &mut current[0..len - 1] { - *elem = 2; - } - } - self.last_iteration = true; + if !ThreeValuedInterpretationsIterator::decrement_vec(current) { + self.current = None; } + // if let Some((pos, val)) = current.iter().enumerate().find(|(idx, val)| **val > 0) { + // if pos > 0 && *val == 2 { + // for elem in &mut current[0..pos] { + // *elem = 2; + // } + // } + // current[pos] -= 1; + // if self.last_iteration { + // if current.iter().all(|val| *val == 0) { + // self.current = None; + // } + // } + // } else if !self.last_iteration { + // let len = current.len(); + // if len <= 1 { + // self.current = None; + // } else { + // for elem in &mut current[0..len - 1] { + // *elem = 2; + // } + // } + // self.last_iteration = true; + //} + } + } + + fn decrement_vec(vector: &mut Vec) -> bool { + let mut cur_pos = None; + for (idx, value) in vector.iter_mut().enumerate() { + if *value > 0 { + *value -= 1; + cur_pos = Some(idx); + break; + } + } + if let Some(cur) = cur_pos { + for value in vector[0..cur].iter_mut() { + *value = 2; + } + true + } else { + false } } } @@ -299,10 +321,6 @@ mod test { iter.next(), Some(vec![Term::TOP, Term::TOP, Term::BOT, Term::BOT, Term::TOP]) ); - assert_eq!( - iter.next(), - Some(vec![Term::TOP, Term::BOT, Term::BOT, Term::BOT, Term::TOP]) - ); assert_eq!( iter.next(), Some(vec![Term::TOP, Term::BOT, Term::BOT, Term(12), Term::TOP]) @@ -311,7 +329,47 @@ mod test { iter.next(), Some(vec![Term::TOP, Term::BOT, Term::BOT, Term::TOP, Term::TOP]) ); + assert_eq!( + iter.next(), + Some(vec![Term::TOP, Term::BOT, Term::BOT, Term::BOT, Term::TOP]) + ); assert_eq!(iter.next(), None); + + let testinterpretation = vec![Term(1), Term(3), Term(3), Term(7)]; + let mut iter: Vec> = + ThreeValuedInterpretationsIterator::new(&testinterpretation).collect(); + assert_eq!( + iter, + [ + [Term(1), Term(3), Term(3), Term(7)], + [Term(1), Term(3), Term(3), Term(1)], + [Term(1), Term(3), Term(3), Term(0)], + [Term(1), Term(3), Term(1), Term(7)], + [Term(1), Term(3), Term(1), Term(1)], + [Term(1), Term(3), Term(1), Term(0)], + [Term(1), Term(3), Term(0), Term(7)], + [Term(1), Term(3), Term(0), Term(1)], + [Term(1), Term(3), Term(0), Term(0)], + [Term(1), Term(1), Term(3), Term(7)], + [Term(1), Term(1), Term(3), Term(1)], + [Term(1), Term(1), Term(3), Term(0)], + [Term(1), Term(1), Term(1), Term(7)], + [Term(1), Term(1), Term(1), Term(1)], + [Term(1), Term(1), Term(1), Term(0)], + [Term(1), Term(1), Term(0), Term(7)], + [Term(1), Term(1), Term(0), Term(1)], + [Term(1), Term(1), Term(0), Term(0)], + [Term(1), Term(0), Term(3), Term(7)], + [Term(1), Term(0), Term(3), Term(1)], + [Term(1), Term(0), Term(3), Term(0)], + [Term(1), Term(0), Term(1), Term(7)], + [Term(1), Term(0), Term(1), Term(1)], + [Term(1), Term(0), Term(1), Term(0)], + [Term(1), Term(0), Term(0), Term(7)], + [Term(1), Term(0), Term(0), Term(1)], + [Term(1), Term(0), Term(0), Term(0)] + ] + ); } #[test] @@ -330,12 +388,12 @@ mod test { iter.decrement(); assert_eq!(iter.current, Some(vec![0, 1])); iter.decrement(); - assert_eq!(iter.current, Some(vec![0, 0])); - iter.decrement(); assert_eq!(iter.current, Some(vec![2, 0])); iter.decrement(); assert_eq!(iter.current, Some(vec![1, 0])); iter.decrement(); + assert_eq!(iter.current, Some(vec![0, 0])); + iter.decrement(); assert_eq!(iter.current, None); let testinterpretation = vec![Term::TOP, Term(22), Term::BOT, Term::TOP, Term::TOP];