use crate::sketchbook::ids::{UninterpretedFnId, VarId};
use crate::sketchbook::model::{BinaryOp, ModelState, UninterpretedFn};
use biodivine_lib_param_bn::{BooleanNetwork, FnUpdate};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum FnTree {
Const(bool),
Var(VarId),
PlaceholderVar(VarId),
UninterpretedFn(UninterpretedFnId, Vec<FnTree>),
Not(Box<FnTree>),
Binary(BinaryOp, Box<FnTree>, Box<FnTree>),
}
fn parse_update_fn_wrapper(
expression: &str,
bn_context: &BooleanNetwork,
) -> Result<FnUpdate, String> {
let fn_update = FnUpdate::try_from_str(expression, bn_context)
.map_err(|e| format!("Error during update function processing: {}", e))?;
Ok(fn_update)
}
impl FnTree {
pub fn try_from_str(
expression: &str,
model: &ModelState,
is_uninterpreted: Option<(&UninterpretedFnId, &UninterpretedFn)>,
) -> Result<FnTree, String> {
let bn_context = if let Some((_, f)) = is_uninterpreted {
model.to_fake_bn_with_params(f.get_arity())
} else {
model.to_empty_bn_with_params()
};
let fn_update = parse_update_fn_wrapper(expression, &bn_context)?;
let fn_tree = Self::from_fn_update(fn_update, model, is_uninterpreted)?;
Ok(fn_tree)
}
pub fn to_string(&self, model: &ModelState, is_uninterpreted: Option<usize>) -> String {
let bn_context = if let Some(n) = is_uninterpreted {
model.to_fake_bn_with_params(n)
} else {
model.to_empty_bn_with_params()
};
let fn_update = self.to_fn_update_recursive(&bn_context);
fn_update.to_string(&bn_context)
}
fn from_fn_update(
fn_update: FnUpdate,
model: &ModelState,
is_uninterpreted: Option<(&UninterpretedFnId, &UninterpretedFn)>,
) -> Result<FnTree, String> {
if let Some((fn_id, f)) = is_uninterpreted {
let bn_context = model.to_fake_bn_with_params(f.get_arity());
Self::from_fn_update_recursive(fn_update, model, &bn_context, Some(fn_id))
} else {
let bn_context = model.to_empty_bn_with_params();
Self::from_fn_update_recursive(fn_update, model, &bn_context, None)
}
}
fn from_fn_update_recursive(
fn_update: FnUpdate,
model: &ModelState,
bn_context: &BooleanNetwork,
is_uninterpreted: Option<&UninterpretedFnId>,
) -> Result<FnTree, String> {
match fn_update {
FnUpdate::Const(value) => Ok(FnTree::Const(value)),
FnUpdate::Var(id) => {
let var_id_str = bn_context.get_variable_name(id);
if is_uninterpreted.is_some() {
let var_id = model.get_placeholder_var_id(var_id_str)?;
Ok(FnTree::PlaceholderVar(var_id))
} else {
let var_id = model.get_var_id(var_id_str)?;
Ok(FnTree::Var(var_id))
}
}
FnUpdate::Not(inner) => {
let inner_transformed =
Self::from_fn_update_recursive(*inner, model, bn_context, is_uninterpreted)?;
Ok(FnTree::Not(Box::new(inner_transformed)))
}
FnUpdate::Binary(op, l, r) => {
let binary_transformed = BinaryOp::from(op);
let l_transformed =
Self::from_fn_update_recursive(*l, model, bn_context, is_uninterpreted)?;
let r_transformed =
Self::from_fn_update_recursive(*r, model, bn_context, is_uninterpreted)?;
Ok(FnTree::Binary(
binary_transformed,
Box::new(l_transformed),
Box::new(r_transformed),
))
}
FnUpdate::Param(id, args) => {
let fn_id_str = bn_context[id].get_name();
let fn_id = model.get_uninterpreted_fn_id(fn_id_str)?;
if let Some(fn_id_def) = is_uninterpreted {
if fn_id == *fn_id_def {
let msg = format!(
"An uninterpreted fn {fn_id} cannot be used in its own expression."
);
return Err(msg);
}
}
let args_transformed: Result<Vec<FnTree>, String> = args
.into_iter()
.map(|f| Self::from_fn_update_recursive(f, model, bn_context, is_uninterpreted))
.collect();
Ok(FnTree::UninterpretedFn(fn_id, args_transformed?))
}
}
}
pub(crate) fn to_fn_update_recursive(&self, bn_context: &BooleanNetwork) -> FnUpdate {
match self {
FnTree::Const(value) => FnUpdate::Const(*value),
FnTree::Var(var_id) => {
let bn_var_id = bn_context
.as_graph()
.find_variable(var_id.as_str())
.unwrap();
FnUpdate::Var(bn_var_id)
}
FnTree::PlaceholderVar(var_id) => {
let bn_var_id = bn_context
.as_graph()
.find_variable(var_id.as_str())
.unwrap();
FnUpdate::Var(bn_var_id)
}
FnTree::Not(inner) => {
let inner_transformed = inner.to_fn_update_recursive(bn_context);
FnUpdate::Not(Box::new(inner_transformed))
}
FnTree::Binary(op, l, r) => {
let binary_transformed = op.to_lib_param_bn_version();
let l_transformed = l.to_fn_update_recursive(bn_context);
let r_transformed = r.to_fn_update_recursive(bn_context);
FnUpdate::Binary(
binary_transformed,
Box::new(l_transformed),
Box::new(r_transformed),
)
}
FnTree::UninterpretedFn(fn_id, args) => {
let bn_param_id = bn_context.find_parameter(fn_id.as_str()).unwrap();
let args_transformed: Vec<FnUpdate> = args
.iter()
.map(|f| f.to_fn_update_recursive(bn_context))
.collect();
FnUpdate::Param(bn_param_id, args_transformed)
}
}
}
pub fn collect_variables(&self) -> HashSet<VarId> {
fn r_arguments(function: &FnTree, args: &mut HashSet<VarId>) {
match function {
FnTree::Const(_) => (),
FnTree::Var(id) => {
args.insert(id.clone());
}
FnTree::PlaceholderVar(id) => {
args.insert(id.clone());
}
FnTree::UninterpretedFn(_, p_args) => {
for fun in p_args {
r_arguments(fun, args);
}
}
FnTree::Not(inner) => r_arguments(inner, args),
FnTree::Binary(_, l, r) => {
r_arguments(l, args);
r_arguments(r, args);
}
};
}
let mut vars = HashSet::new();
r_arguments(self, &mut vars);
vars
}
pub fn collect_fn_symbols(&self) -> HashSet<UninterpretedFnId> {
fn r_parameters(function: &FnTree, params: &mut HashSet<UninterpretedFnId>) {
match function {
FnTree::Const(_) => (),
FnTree::Var(_) => (),
FnTree::PlaceholderVar(_) => (),
FnTree::UninterpretedFn(id, args) => {
params.insert(id.clone());
for fun in args {
r_parameters(fun, params);
}
}
FnTree::Not(inner) => r_parameters(inner, params),
FnTree::Binary(_, l, r) => {
r_parameters(l, params);
r_parameters(r, params);
}
};
}
let mut params = HashSet::new();
r_parameters(self, &mut params);
params
}
pub fn substitute_var(&self, old_id: &VarId, new_id: &VarId) -> FnTree {
match self {
FnTree::Const(_) => self.clone(),
FnTree::Var(id) => {
if id == old_id {
FnTree::Var(new_id.clone())
} else {
self.clone()
}
}
FnTree::PlaceholderVar(_) => self.clone(),
FnTree::UninterpretedFn(id, args) => {
let new_args = args
.iter()
.map(|it| it.substitute_var(old_id, new_id))
.collect::<Vec<_>>();
FnTree::UninterpretedFn(id.clone(), new_args)
}
FnTree::Not(inner) => (*inner).substitute_var(old_id, new_id),
FnTree::Binary(op, l, r) => FnTree::Binary(
*op,
Box::new((*l).substitute_var(old_id, new_id)),
Box::new((*r).substitute_var(old_id, new_id)),
),
}
}
pub fn substitute_fn_symbol(
&self,
old_id: &UninterpretedFnId,
new_id: &UninterpretedFnId,
) -> FnTree {
match self {
FnTree::Const(_) => self.clone(),
FnTree::Var(_) => self.clone(),
FnTree::PlaceholderVar(_) => self.clone(),
FnTree::UninterpretedFn(id, args) => {
let new_args = args
.iter()
.map(|it| it.substitute_fn_symbol(old_id, new_id))
.collect::<Vec<_>>();
if old_id == id {
FnTree::UninterpretedFn(new_id.clone(), new_args)
} else {
FnTree::UninterpretedFn(id.clone(), new_args)
}
}
FnTree::Not(inner) => (*inner).substitute_fn_symbol(old_id, new_id),
FnTree::Binary(op, l, r) => FnTree::Binary(
*op,
Box::new((*l).substitute_fn_symbol(old_id, new_id)),
Box::new((*r).substitute_fn_symbol(old_id, new_id)),
),
}
}
}
#[cfg(test)]
mod tests {
use crate::sketchbook::model::{FnTree, ModelState};
use std::collections::HashSet;
#[test]
fn test_valid_update_fn() {
let mut model = ModelState::new_from_vars(vec![("a", "a"), ("b", "b")]).unwrap();
model
.add_empty_uninterpreted_fn_by_str("f", "f", 1)
.unwrap();
let expression = "a & (b | f(b))";
let fn_tree = FnTree::try_from_str(expression, &model, None).unwrap();
let processed_expression = fn_tree.to_string(&model, None);
assert_eq!(processed_expression.as_str(), expression);
}
#[test]
fn test_valid_uninterpreted_fn() {
let mut model = ModelState::new_empty();
let arity = 2;
model
.add_empty_uninterpreted_fn_by_str("f", "f", arity)
.unwrap();
model
.add_empty_uninterpreted_fn_by_str("g", "g", arity)
.unwrap();
let expression = "var0 & (var1 | f(var0, var0))";
let fn_id = model.get_uninterpreted_fn_id("g").unwrap();
let uninterpreted_fn = model.get_uninterpreted_fn(&fn_id).unwrap();
let fn_tree =
FnTree::try_from_str(expression, &model, Some((&fn_id, uninterpreted_fn))).unwrap();
let processed_expression = fn_tree.to_string(&model, Some(arity));
assert_eq!(processed_expression.as_str(), expression,);
}
#[test]
fn test_invalid_update_fns() {
let mut model = ModelState::new_from_vars(vec![("a", "a"), ("b", "b")]).unwrap();
model
.add_empty_uninterpreted_fn_by_str("f", "f", 2)
.unwrap();
let expression = "var0 & var1";
let fn_tree = FnTree::try_from_str(expression, &model, None);
assert!(fn_tree.is_err());
let expression = "a & (b | g(b))";
let fn_tree = FnTree::try_from_str(expression, &model, None);
assert!(fn_tree.is_err());
let expression = "a & (b | f(b))";
let fn_tree = FnTree::try_from_str(expression, &model, None);
assert!(fn_tree.is_err());
}
#[test]
fn test_invalid_uninterpreted_fn() {
let mut model = ModelState::new_from_vars(vec![("a", "a"), ("b", "b")]).unwrap();
model
.add_empty_uninterpreted_fn_by_str("f", "f", 1)
.unwrap();
model
.add_empty_uninterpreted_fn_by_str("g", "g", 2)
.unwrap();
let expression = "a & (b | f(a))";
let fn_id = model.get_uninterpreted_fn_id("g").unwrap();
let uninterpreted_fn = model.get_uninterpreted_fn(&fn_id).unwrap();
let fn_tree = FnTree::try_from_str(expression, &model, Some((&fn_id, uninterpreted_fn)));
assert!(fn_tree.is_err());
let expression = "f(var0)";
let fn_id = model.get_uninterpreted_fn_id("f").unwrap();
let uninterpreted_fn = model.get_uninterpreted_fn(&fn_id).unwrap();
let fn_tree = FnTree::try_from_str(expression, &model, Some((&fn_id, uninterpreted_fn)));
assert!(fn_tree.is_err());
let expression = "var0 | var1";
let fn_id = model.get_uninterpreted_fn_id("f").unwrap();
let uninterpreted_fn = model.get_uninterpreted_fn(&fn_id).unwrap();
let fn_tree = FnTree::try_from_str(expression, &model, Some((&fn_id, uninterpreted_fn)));
assert!(fn_tree.is_err());
}
#[test]
fn test_substitution() {
let mut model = ModelState::new_from_vars(vec![("a", "a"), ("b", "b")]).unwrap();
model
.add_empty_uninterpreted_fn_by_str("f", "f", 1)
.unwrap();
model
.add_empty_uninterpreted_fn_by_str("g", "g", 1)
.unwrap();
let a = model.get_var_id("a").unwrap();
let b = model.get_var_id("b").unwrap();
let f = model.get_uninterpreted_fn_id("f").unwrap();
let g = model.get_uninterpreted_fn_id("g").unwrap();
let fn_tree = FnTree::try_from_str("a & f(a)", &model, None).unwrap();
let modified_tree = fn_tree.substitute_var(&a, &b);
assert_eq!(modified_tree.to_string(&model, None), "b & f(b)");
let modified_tree = fn_tree.substitute_fn_symbol(&f, &g);
assert_eq!(modified_tree.to_string(&model, None), "a & g(a)");
}
#[test]
fn test_collect_fns() {
let mut model = ModelState::new_from_vars(vec![("a", "a"), ("b", "b")]).unwrap();
let fns = vec![("f", "f", 1), ("g", "g", 1), ("h", "h", 1)];
model.add_multiple_uninterpreted_fns(fns).unwrap();
let f = model.get_uninterpreted_fn_id("f").unwrap();
let g = model.get_uninterpreted_fn_id("g").unwrap();
let fn_tree = FnTree::try_from_str("a & f(a) | (g(b))", &model, None).unwrap();
let collected_fns = fn_tree.collect_fn_symbols();
let expected_fns = HashSet::from([f, g]);
assert_eq!(expected_fns, collected_fns);
}
#[test]
fn test_collect_vars() {
let variables = vec![("a", "a"), ("b", "b"), ("c", "c")];
let mut model = ModelState::new_from_vars(variables).unwrap();
model
.add_empty_uninterpreted_fn_by_str("f", "f", 1)
.unwrap();
let a = model.get_var_id("a").unwrap();
let b = model.get_var_id("b").unwrap();
let fn_tree = FnTree::try_from_str("a & f(a) | (f(b))", &model, None).unwrap();
let collected_vars = fn_tree.collect_variables();
let expected_vars = HashSet::from([a, b]);
assert_eq!(expected_vars, collected_vars);
}
}