1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
use crate::sketchbook::ids::{UninterpretedFnId, VarId};
use crate::sketchbook::model::{Essentiality, FnArgument, FnTree, ModelState, Monotonicity};
use crate::sketchbook::utils::assert_name_valid;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::fmt::{Display, Formatter};

/// An uninterpreted function of a partially specified model.
///
/// Field `arguments` holds the properties (monotonicity and essentiality) of the function with
/// respect to each of its arguments, in order. Its length is the function's arity.
///
/// The function can be left completely unspecified, or it can be given a "partial expression".
/// Field `tree` holds the parsed version of that formula (`None` when no expression is set).
/// Field `expression` holds the textual form of the formula. That text is re-rendered from `tree`
/// by the setters, so it need not be byte-identical to the string a caller originally supplied.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct UninterpretedFn {
    /// Human-readable name; checked with `assert_name_valid` by the constructors and `set_name`.
    name: String,
    /// Free-form annotation text (empty by default).
    annotation: String,
    /// Per-argument properties, in argument order; `arguments.len()` is the arity.
    arguments: Vec<FnArgument>,
    /// Parsed expression, or `None` if the function has no expression.
    tree: Option<FnTree>,
    /// Textual form of the expression; empty when `tree` is `None`.
    expression: String,
}

/// Creating new `UninterpretedFn` instances.
impl UninterpretedFn {
    /// Create new `UninterpretedFn` object that has no constraints regarding monotonicity,
    /// essentiality, or its expression. Annotation is empty.
    ///
    /// All `arity` arguments are initialized with `FnArgument::default()`.
    ///
    /// # Errors
    ///
    /// Returns an error if `name` is rejected by `assert_name_valid`.
    pub fn new_without_constraints(name: &str, arity: usize) -> Result<UninterpretedFn, String> {
        assert_name_valid(name)?;

        Ok(UninterpretedFn {
            name: name.to_string(),
            annotation: String::new(),
            arguments: vec![FnArgument::default(); arity],
            tree: None,
            expression: String::new(),
        })
    }

    /// Create new `UninterpretedFn` instance given all its raw components.
    ///
    /// The arity is the length of `arguments`. The model and function's ID are used for validity
    /// checks while parsing `expression`. A blank `expression` leaves the function without one.
    ///
    /// # Errors
    ///
    /// Returns an error if `name` is rejected by `assert_name_valid`, or if `expression`
    /// cannot be parsed in the given context.
    pub fn new(
        name: &str,
        annotation: &str,
        expression: &str,
        arguments: Vec<FnArgument>,
        model: &ModelState,
        own_id: &UninterpretedFnId,
    ) -> Result<UninterpretedFn, String> {
        // Validate once here. The arity is defined by `arguments`, so the argument list can be
        // stored directly instead of building default arguments and then replacing them.
        assert_name_valid(name)?;
        let mut f = UninterpretedFn {
            name: name.to_string(),
            annotation: String::new(),
            arguments,
            tree: None,
            expression: String::new(),
        };
        // The arguments must already be in place here: the parser receives this function
        // (paired with `own_id`) as context.
        f.set_fn_expression(expression, model, own_id)?;
        f.set_annotation(annotation);

        Ok(f)
    }

    /// Create uninterpreted function using another one as a template, but changing the expression.
    /// The provided original function object is consumed.
    ///
    /// # Errors
    ///
    /// Returns an error if `new_expression` cannot be parsed in the given context. In that case
    /// the consumed original is dropped.
    pub fn with_new_expression(
        mut original_fn: UninterpretedFn,
        new_expression: &str,
        context: &ModelState,
        own_id: &UninterpretedFnId,
    ) -> Result<UninterpretedFn, String> {
        original_fn.set_fn_expression(new_expression, context, own_id)?;
        Ok(original_fn)
    }

    /// Create uninterpreted function from another one, substituting all occurrences of a given
    /// function symbol in the syntactic tree. The provided original function object is consumed.
    ///
    /// If the original has no expression, it is returned unchanged.
    pub fn with_substituted_fn_symbol(
        mut original_fn: UninterpretedFn,
        old_id: &UninterpretedFnId,
        new_id: &UninterpretedFnId,
        context: &ModelState,
    ) -> UninterpretedFn {
        original_fn.substitute_fn_symbol(old_id, new_id, context);
        original_fn
    }
}

/// Editing `UninterpretedFn` instances.
impl UninterpretedFn {
    /// Rename this uninterpreted fn.
    ///
    /// # Errors
    ///
    /// Returns an error (leaving the name unchanged) if `new_name` is rejected by
    /// `assert_name_valid`.
    pub fn set_name(&mut self, new_name: &str) -> Result<(), String> {
        assert_name_valid(new_name)?;
        self.name = new_name.to_string();
        Ok(())
    }

    /// Change annotation of this uninterpreted fn.
    pub fn set_annotation(&mut self, annotation: &str) {
        self.annotation = annotation.to_string();
    }

    /// Change arity of this uninterpreted fn.
    ///
    /// If arity is made larger, default arguments (without monotonicity/essentiality constraints)
    /// are appended. If arity is made smaller, the last arguments are dropped.
    ///
    /// # Errors
    ///
    /// When shrinking, returns an error (leaving the function unchanged) if the expression uses
    /// a variable `var<i>` with `i >= new_arity`, because that variable would become invalid.
    pub fn set_arity(&mut self, new_arity: usize) -> Result<(), String> {
        if new_arity < self.get_arity() {
            // Refuse to drop arguments that are still referenced by the expression.
            if let Some(highest_var_idx) = self.get_highest_var_idx_in_expression() {
                if new_arity <= highest_var_idx {
                    let msg = "Cannot change arity of a function - its expression contains variables that would become invalid.";
                    return Err(msg.to_string());
                }
            }
        }
        // `resize` truncates when shrinking and pads with default arguments when growing.
        self.arguments.resize(new_arity, FnArgument::default());
        Ok(())
    }

    /// Drop the last argument of the function, essentially decrementing the arity of this
    /// uninterpreted fn. The last argument must not be used in function's expression.
    ///
    /// # Errors
    ///
    /// Returns an error if the function has no arguments, or if the last argument is used in
    /// the expression (see `set_arity`).
    pub fn drop_last_argument(&mut self) -> Result<(), String> {
        match self.get_arity().checked_sub(1) {
            Some(new_arity) => self.set_arity(new_arity),
            None => Err("Cannot drop argument of a function with zero arguments.".to_string()),
        }
    }

    /// Add an argument with specified monotonicity/essentiality.
    /// Argument is added at the end of the current argument list.
    pub fn add_argument(&mut self, monotonicity: Monotonicity, essentiality: Essentiality) {
        self.arguments
            .push(FnArgument::new(essentiality, monotonicity));
    }

    /// Add default argument (with unknown monotonicity/essentiality) for this function.
    /// Argument is added at the end of the current argument list.
    pub fn add_default_argument(&mut self) {
        self.arguments.push(FnArgument::default());
    }

    /// Set the function's expression to a given string.
    ///
    /// An empty or whitespace-only string removes the expression. Any other string is parsed
    /// into a syntactic tree. The stored text is then re-rendered from that tree.
    ///
    /// `model` is used to provide context regarding valid IDs.
    ///
    /// We also need ID of this uninterpreted function to ensure that the expression is not defined
    /// recursively, i.e., to check that expression of function `f` does not contain `f` inside.
    ///
    /// # Errors
    ///
    /// Returns the parser's error if the expression is invalid. The function is left unchanged.
    pub fn set_fn_expression(
        &mut self,
        new_expression: &str,
        model: &ModelState,
        own_id: &UninterpretedFnId,
    ) -> Result<(), String> {
        if new_expression.trim().is_empty() {
            self.tree = None;
            self.expression = String::new();
        } else {
            let syntactic_tree = FnTree::try_from_str(new_expression, model, Some((own_id, self)))?;
            self.expression = syntactic_tree.to_string(model, Some(self.get_arity()));
            self.tree = Some(syntactic_tree);
        }
        Ok(())
    }

    /// Mutable access to the argument on `index`. Returns the shared out-of-range error used by
    /// all argument setters.
    fn constrainable_argument_mut(&mut self, index: usize) -> Result<&mut FnArgument, String> {
        self.arguments.get_mut(index).ok_or_else(|| {
            "Cannot constrain an argument on index higher than function's arity.".to_string()
        })
    }

    /// Set properties of an argument with given `index` (starting from 0).
    ///
    /// # Errors
    ///
    /// Returns an error if `index` is not smaller than the arity.
    pub fn set_argument(&mut self, index: usize, argument: FnArgument) -> Result<(), String> {
        *self.constrainable_argument_mut(index)? = argument;
        Ok(())
    }

    /// Set `Essentiality` of argument with given `index` (starting from 0).
    ///
    /// # Errors
    ///
    /// Returns an error if `index` is not smaller than the arity.
    pub fn set_essential(&mut self, index: usize, essential: Essentiality) -> Result<(), String> {
        self.constrainable_argument_mut(index)?.essential = essential;
        Ok(())
    }

    /// Set `Monotonicity` of argument with given `index` (starting from 0).
    ///
    /// # Errors
    ///
    /// Returns an error if `index` is not smaller than the arity.
    pub fn set_monotonic(&mut self, index: usize, monotone: Monotonicity) -> Result<(), String> {
        self.constrainable_argument_mut(index)?.monotonicity = monotone;
        Ok(())
    }

    /// Set the properties for all arguments (essentially replacing the current version).
    /// The number of arguments must stay the same, not changing arity.
    ///
    /// # Errors
    ///
    /// Returns an error if `argument_list.len()` differs from the current arity.
    pub fn set_all_arguments(&mut self, argument_list: Vec<FnArgument>) -> Result<(), String> {
        if argument_list.len() == self.get_arity() {
            self.arguments = argument_list;
            Ok(())
        } else {
            Err("Provided vector has different length than arity of this function.".to_string())
        }
    }

    /// Substitute all occurrences of a given function symbol in the syntactic tree.
    ///
    /// The textual expression is re-rendered from the new tree. Does nothing if the function
    /// has no expression.
    pub fn substitute_fn_symbol(
        &mut self,
        old_id: &UninterpretedFnId,
        new_id: &UninterpretedFnId,
        context: &ModelState,
    ) {
        if let Some(tree) = &self.tree {
            let new_tree = tree.substitute_fn_symbol(old_id, new_id);
            self.expression = new_tree.to_string(context, Some(self.get_arity()));
            self.tree = Some(new_tree);
        }
    }
}

/// Observing `UninterpretedFn` instances.
impl UninterpretedFn {
    /// Human-readable name of this uninterpreted fn.
    pub fn get_name(&self) -> &str {
        self.name.as_str()
    }

    /// Annotation of this uninterpreted fn.
    pub fn get_annotation(&self) -> &str {
        self.annotation.as_str()
    }

    /// Arity (number of arguments) of this uninterpreted fn.
    pub fn get_arity(&self) -> usize {
        self.arguments.len()
    }

    /// Largest `i` such that a variable rendered as `var<i>` occurs in the expression.
    ///
    /// Returns `None` if there is no expression or no such variable. The result may be
    /// lower than the function's actual arity.
    fn get_highest_var_idx_in_expression(&self) -> Option<usize> {
        let tree = self.tree.as_ref()?;
        tree.collect_variables()
            .iter()
            .filter_map(|var| var.to_string().strip_prefix("var")?.parse::<usize>().ok())
            .max()
    }

    /// Textual form of the function's expression (empty if there is none).
    pub fn get_fn_expression(&self) -> &str {
        self.expression.as_str()
    }

    /// Argument (`FnArgument` object) on position `index` (starting from 0).
    ///
    /// # Panics
    ///
    /// Panics if `index` is not smaller than the arity.
    pub fn get_argument(&self, index: usize) -> &FnArgument {
        &self.get_all_arguments()[index]
    }

    /// `Essentiality` of the argument on position `index` (starting from 0).
    ///
    /// # Panics
    ///
    /// Panics if `index` is not smaller than the arity.
    pub fn get_essential(&self, index: usize) -> &Essentiality {
        &self.get_argument(index).essential
    }

    /// `Monotonicity` of the argument on position `index` (starting from 0).
    ///
    /// # Panics
    ///
    /// Panics if `index` is not smaller than the arity.
    pub fn get_monotonic(&self, index: usize) -> &Monotonicity {
        &self.get_argument(index).monotonicity
    }

    /// All arguments (`FnArgument` objects) of this function, in order.
    pub fn get_all_arguments(&self) -> &Vec<FnArgument> {
        &self.arguments
    }

    /// Set of all variables actually used as inputs in the expression (empty if there is none).
    pub fn collect_variables(&self) -> HashSet<VarId> {
        self.tree
            .as_ref()
            .map_or_else(HashSet::new, |tree| tree.collect_variables())
    }

    /// Set of all uninterpreted fns actually used in the expression (empty if there is none).
    pub fn collect_fn_symbols(&self) -> HashSet<UninterpretedFnId> {
        self.tree
            .as_ref()
            .map_or_else(HashSet::new, |tree| tree.collect_fn_symbols())
    }
}

/// Formats the function as its name applied to positional placeholders `x_1, ..., x_n`,
/// one per argument (e.g. `f(x_1, x_2)`); a zero-arity function renders as `f()`.
impl Display for UninterpretedFn {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let placeholders = (1..=self.get_arity())
            .map(|position| format!("x_{position}"))
            .collect::<Vec<String>>()
            .join(", ");
        write!(f, "{}({})", self.name, placeholders)
    }
}

#[cfg(test)]
mod tests {
    use crate::sketchbook::ids::UninterpretedFnId;
    use crate::sketchbook::model::{ModelState, UninterpretedFn};

    /// A freshly created unconstrained function reports its name and arity, and is displayed
    /// with positional placeholder arguments.
    #[test]
    fn basic_uninterpreted_fn_test() {
        let func = UninterpretedFn::new_without_constraints("f", 3).unwrap();
        assert_eq!(func.get_name(), "f");
        assert_eq!(func.get_arity(), 3);
        assert_eq!(func.to_string(), "f(x_1, x_2, x_3)");
    }

    /// A name containing a newline is rejected.
    #[test]
    fn invalid_uninterpreted_fn_test() {
        let outcome = UninterpretedFn::new_without_constraints("f\nxyz", 3);
        assert!(outcome.is_err());
    }

    /// Setting a valid expression stores it in the expected textual form.
    #[test]
    fn uninterpreted_fn_expression_test() {
        // Normally the expression is edited through the owning `ModelState`. Here the function
        // is built by hand, and the model only serves as parsing context.
        let mut model = ModelState::new_empty();
        model.add_empty_uninterpreted_fn_by_str("f", "f", 3).unwrap();
        let fn_id = UninterpretedFnId::new("f").unwrap();

        let formula = "var0 & (var1 => var2)";
        let mut func = UninterpretedFn::new_without_constraints("f", 3).unwrap();
        func.set_fn_expression(formula, &model, &fn_id).unwrap();
        assert_eq!(formula, func.get_fn_expression());
    }
}