From cdd21f7e4f553eb9b5e8aa6e998360bbc0bd7ad7 Mon Sep 17 00:00:00 2001 From: monsterkrampe Date: Mon, 29 Aug 2022 18:21:50 +0200 Subject: [PATCH] Continue implementing basic solving endpoint --- Cargo.lock | 16 ++++++ lib/src/adf.rs | 7 +-- lib/src/obdd.rs | 4 +- server/Cargo.toml | 1 + server/src/main.rs | 131 +++++++++++++++++++++++++++++++++++++++++---- 5 files changed, 144 insertions(+), 15 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ec894e8..f5c573a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,6 +19,21 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "actix-cors" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02a0adcaabb68f1dfe8880cb3c5f049261c68f5d69ce06b6f3a930f31710838e" +dependencies = [ + "actix-utils", + "actix-web", + "derive_more", + "futures-util", + "log", + "once_cell", + "smallvec", +] + [[package]] name = "actix-http" version = "3.3.1" @@ -204,6 +219,7 @@ dependencies = [ name = "adf-bdd-server" version = "0.3.0" dependencies = [ + "actix-cors", "actix-web", "adf_bdd", "env_logger 0.9.3", diff --git a/lib/src/adf.rs b/lib/src/adf.rs index 3b007a1..4d13df6 100644 --- a/lib/src/adf.rs +++ b/lib/src/adf.rs @@ -30,9 +30,10 @@ use self::heuristics::Heuristic; /// /// Please note that due to the nature of the underlying reduced and ordered Bdd the concept of a [`Term`][crate::datatypes::Term] represents one (sub) formula as well as truth-values. pub struct Adf { - ordering: VarContainer, - bdd: Bdd, - ac: Vec, + // TODO: none of this should be public + pub ordering: VarContainer, + pub bdd: Bdd, + pub ac: Vec, #[serde(skip, default = "Adf::default_rng")] rng: RefCell, } diff --git a/lib/src/obdd.rs b/lib/src/obdd.rs index a54cb70..321a2f3 100644 --- a/lib/src/obdd.rs +++ b/lib/src/obdd.rs @@ -13,7 +13,9 @@ use std::{cell::RefCell, cmp::min, collections::HashMap, fmt::Display}; /// Each roBDD is identified by its corresponding [`Term`], which implicitly identifies the root node of a roBDD. #[derive(Debug, Serialize, Deserialize)] pub struct Bdd { - pub(crate) nodes: Vec, + // TODO: use this again + // pub(crate) nodes: Vec, + pub nodes: Vec, #[cfg(feature = "variablelist")] #[serde(skip)] var_deps: Vec>, diff --git a/server/Cargo.toml b/server/Cargo.toml index 663b71b..2ad8642 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -14,6 +14,7 @@ description = "Offer Solving ADFs as a service" [dependencies] adf_bdd = { version="0.3.1", path="../lib", features = ["frontend"] } actix-web = "4" +actix-cors = "0.6" env_logger = "0.9" log = "0.4" serde = "1" diff --git a/server/src/main.rs b/server/src/main.rs index 2509031..f35d3e5 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,7 +1,11 @@ -use actix_web::{get, post, web, App, HttpResponse, HttpServer, Responder}; +use actix_cors::Cors; +use actix_web::{get, http, post, web, App, HttpResponse, HttpServer, Responder}; use serde::{Deserialize, Serialize}; +use std::collections::{HashMap, HashSet}; use adf_bdd::adf::Adf; +use adf_bdd::datatypes::BddNode; +use adf_bdd::datatypes::Var; use adf_bdd::parser::AdfParser; #[get("/")] @@ -15,15 +19,16 @@ async fn root() -> impl Responder { struct DoubleLabeledGraph { // number of nodes equals the number of node labels // nodes implicitly have their index as their ID - node_labels: Vec, + node_labels: HashMap, // every node gets this label containing multiple entries (it might be empty) - tree_root_labels: Vec>, - edges: Vec<(usize, usize)>, + tree_root_labels: HashMap>, + lo_edges: Vec<(usize, usize)>, + hi_edges: Vec<(usize, usize)>, } #[derive(Deserialize)] struct SolveReqBody { - adf_input: String, + code: String, } #[derive(Serialize)] @@ -33,7 +38,7 @@ struct SolveResBody { #[post("/solve")] async fn solve(req_body: web::Json) -> impl Responder { - let input = &req_body.adf_input; + let input = &req_body.code; let parser = AdfParser::default(); match parser.parse()(input) { @@ -51,7 +56,98 @@ async fn solve(req_body: web::Json) -> impl Responder { // TODO: as first test: turn full graph with initial ac into DoubleLabeledGraph DTO and return it - "Hello World" + // get relevant nodes from bdd and ac + let mut node_indices: HashSet = HashSet::new(); + let mut new_node_indices: HashSet = adf.ac.iter().map(|term| term.value()).collect(); + + while !new_node_indices.is_empty() { + node_indices = node_indices.union(&new_node_indices).map(|i| *i).collect(); + new_node_indices = HashSet::new(); + + for node_index in &node_indices { + let lo_node_index = adf.bdd.nodes[*node_index].lo().value(); + if !node_indices.contains(&lo_node_index) { + new_node_indices.insert(lo_node_index); + } + + let hi_node_index = adf.bdd.nodes[*node_index].hi().value(); + if !node_indices.contains(&hi_node_index) { + new_node_indices.insert(hi_node_index); + } + } + } + + let node_labels: HashMap = + adf.bdd + .nodes + .iter() + .enumerate() + .filter(|(i, _)| node_indices.contains(i)) + .map(|(i, &node)| { + let value_part = match node.var() { + Var::TOP => "TOP".to_string(), + Var::BOT => "BOT".to_string(), + _ => adf.ordering.name(node.var()).expect( + "name for each var should exist; special cases are handled separately", + ), + }; + + (i, value_part) + }) + .collect(); + + let tree_root_labels: HashMap> = adf.ac.iter().enumerate().fold( + adf.bdd + .nodes + .iter() + .enumerate() + .filter(|(i, _)| node_indices.contains(i)) + .map(|(i, _)| (i, vec![])) + .collect(), + |mut acc, (root_for, root_node)| { + acc.get_mut(&root_node.value()) + .expect("we know that the index will be in the map") + .push(adf.ordering.name(Var(root_for)).expect( + "name for each var should exist; special cases are handled separately", + )); + + acc + }, + ); + + let lo_edges: Vec<(usize, usize)> = adf + .bdd + .nodes + .iter() + .enumerate() + .filter(|(i, _)| node_indices.contains(i)) + .filter(|(_, node)| !vec![Var::TOP, Var::BOT].contains(&node.var())) + .map(|(i, &node)| (i, node.lo().value())) + .collect(); + + let hi_edges: Vec<(usize, usize)> = adf + .bdd + .nodes + .iter() + .enumerate() + .filter(|(i, _)| node_indices.contains(i)) + .filter(|(_, node)| !vec![Var::TOP, Var::BOT].contains(&node.var())) + .map(|(i, &node)| (i, node.hi().value())) + .collect(); + + log::debug!("{:?}", node_labels); + log::debug!("{:?}", tree_root_labels); + log::debug!("{:?}", lo_edges); + log::debug!("{:?}", hi_edges); + + let dto = DoubleLabeledGraph { + node_labels, + tree_root_labels, + lo_edges, + hi_edges, + }; + + web::Json(dto) } #[actix_web::main] @@ -59,8 +155,21 @@ async fn main() -> std::io::Result<()> { env_logger::builder() .filter_level(log::LevelFilter::Debug) .init(); - HttpServer::new(|| App::new().service(root).service(solve)) - .bind(("127.0.0.1", 8080))? - .run() - .await + + HttpServer::new(|| { + let cors = Cors::default() + .allowed_origin("http://localhost:1234") + .allowed_methods(vec!["GET", "POST"]) + .allowed_headers(vec![ + http::header::AUTHORIZATION, + http::header::ACCEPT, + http::header::CONTENT_TYPE, + ]) + .max_age(3600); + + App::new().wrap(cors).service(root).service(solve) + }) + .bind(("127.0.0.1", 8080))? + .run() + .await }