1
use super::{
2
    Atom, CnfClause, DeclarationPtr, Expression, Literal, Metadata, Moo, ReturnType, SymbolTable,
3
    Typeable,
4
    comprehension::Comprehension,
5
    declaration::DeclarationKind,
6
    pretty::{
7
        pretty_clauses, pretty_domain_letting_declaration, pretty_expressions_as_top_level,
8
        pretty_value_letting_declaration, pretty_variable_declaration,
9
    },
10
    serde::RcRefCellAsInner,
11
};
12
use itertools::izip;
13
use serde::{Deserialize, Serialize};
14
use serde_with::serde_as;
15
use uniplate::{Biplate, Tree, Uniplate};
16

            
17
use crate::{bug, into_matrix_expr};
18
use std::hash::{Hash, Hasher};
19
use std::{
20
    cell::{Ref, RefCell, RefMut},
21
    collections::VecDeque,
22
    fmt::Display,
23
    rc::Rc,
24
};
25

            
26
/// A sub-model, representing a lexical scope in the model.
///
/// Each sub-model contains a symbol table representing its scope, as well as an expression tree.
///
/// The expression tree is formed of a root node of type [`Expression::Root`], which contains a
/// vector of top-level constraints.
#[serde_as]
#[derive(Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub struct SubModel {
    /// Root of the expression tree. The accessors on `SubModel` require this to be an
    /// [`Expression::Root`] and call `bug!` otherwise.
    constraints: Moo<Expression>,
    /// This scope's symbol table. Because it sits behind an `Rc`, the derived `Clone` makes
    /// clones of a sub-model share (not copy) the same table. (De)serialised through
    /// `RcRefCellAsInner`.
    #[serde_as(as = "RcRefCellAsInner")]
    symbols: Rc<RefCell<SymbolTable>>,
    cnf_clauses: Vec<CnfClause>, // CNF clauses
}
40

            
41
impl SubModel {
42
    /// Creates a new [`Submodel`] with no parent scope.
43
    ///
44
    /// Top level models are represented as [`Model`](super::model): consider using
45
    /// [`Model::new`](super::Model::new) instead.
46
    #[doc(hidden)]
47
    pub(super) fn new_top_level() -> SubModel {
48
        SubModel {
49
            constraints: Moo::new(Expression::Root(Metadata::new(), vec![])),
50
            symbols: Rc::new(RefCell::new(SymbolTable::new())),
51
            cnf_clauses: Vec::new(),
52
        }
53
    }
54

            
55
    /// Creates a new [`Submodel`] as a child scope of `parent`.
56
    ///
57
    /// `parent` should be the symbol table of the containing scope of this sub-model.
58
    pub fn new(parent: Rc<RefCell<SymbolTable>>) -> SubModel {
59
        SubModel {
60
            constraints: Moo::new(Expression::Root(Metadata::new(), vec![])),
61
            symbols: Rc::new(RefCell::new(SymbolTable::with_parent(parent))),
62
            cnf_clauses: Vec::new(),
63
        }
64
    }
65

            
66
    /// The symbol table for this sub-model as a pointer.
67
    ///
68
    /// The caller should only mutate the returned symbol table if this method was called on a
69
    /// mutable model.
70
    pub fn symbols_ptr_unchecked(&self) -> &Rc<RefCell<SymbolTable>> {
71
        &self.symbols
72
    }
73

            
74
    /// The symbol table for this sub-model as a mutable pointer.
75
    ///
76
    /// The caller should only mutate the returned symbol table if this method was called on a
77
    /// mutable model.
78
    pub fn symbols_ptr_unchecked_mut(&mut self) -> &mut Rc<RefCell<SymbolTable>> {
79
        &mut self.symbols
80
    }
81

            
82
    /// The symbol table for this sub-model as a reference.
83
    pub fn symbols(&self) -> Ref<'_, SymbolTable> {
84
        (*self.symbols).borrow()
85
    }
86

            
87
    /// The symbol table for this sub-model as a mutable reference.
88
    pub fn symbols_mut(&mut self) -> RefMut<'_, SymbolTable> {
89
        (*self.symbols).borrow_mut()
90
    }
91

            
92
    /// The root node of this sub-model.
93
    ///
94
    /// The root node is an [`Expression::Root`] containing a vector of the top level constraints
95
    /// in this sub-model.
96
    pub fn root(&self) -> &Expression {
97
        &self.constraints
98
    }
99

            
100
    /// The root node of this sub-model, as a mutable reference.
101
    ///
102
    /// The caller is responsible for ensuring that the root node remains an [`Expression::Root`].
103
    ///
104
    pub fn root_mut_unchecked(&mut self) -> &mut Expression {
105
        Moo::make_mut(&mut self.constraints)
106
    }
107

            
108
    /// Replaces the root node with `new_root`, returning the old root node.
109
    ///
110
    /// # Panics
111
    ///
112
    /// - If `new_root` is not an [`Expression::Root`].
113
    pub fn replace_root(&mut self, new_root: Expression) -> Expression {
114
        let Expression::Root(_, _) = new_root else {
115
            tracing::error!(new_root=?new_root,"new_root is not an Expression::root");
116
            panic!("new_root is not an Expression::Root");
117
        };
118

            
119
        // INVARIANT: already checked that `new_root` is an [`Expression::Root`]
120
        std::mem::replace(self.root_mut_unchecked(), new_root)
121
    }
122

            
123
    /// The top-level constraints in this sub-model.
124
    pub fn constraints(&self) -> &Vec<Expression> {
125
        let Expression::Root(_, constraints) = self.constraints.as_ref() else {
126
            bug!("The top level expression in a submodel should be Expr::Root");
127
        };
128

            
129
        constraints
130
    }
131

            
132
    /// The cnf clauses in this sub-model.
133
    pub fn clauses(&self) -> &Vec<CnfClause> {
134
        &self.cnf_clauses
135
    }
136

            
137
    /// The top-level constraints in this sub-model as a mutable vector.
138
    pub fn constraints_mut(&mut self) -> &mut Vec<Expression> {
139
        let Expression::Root(_, constraints) = Moo::make_mut(&mut self.constraints) else {
140
            bug!("The top level expression in a submodel should be Expr::Root");
141
        };
142

            
143
        constraints
144
    }
145

            
146
    /// The cnf clauses in this sub-model as a mutable vector.
147
    pub fn clauses_mut(&mut self) -> &mut Vec<CnfClause> {
148
        &mut self.cnf_clauses
149
    }
150

            
151
    /// Replaces the top-level constraints with `new_constraints`, returning the old ones.
152
    pub fn replace_constraints(&mut self, new_constraints: Vec<Expression>) -> Vec<Expression> {
153
        std::mem::replace(self.constraints_mut(), new_constraints)
154
    }
155

            
156
    /// Replaces the cnf clauses with `new_clauses`, returning the old ones.
157
    pub fn replace_clauses(&mut self, new_clauses: Vec<CnfClause>) -> Vec<CnfClause> {
158
        std::mem::replace(self.clauses_mut(), new_clauses)
159
    }
160

            
161
    /// Adds a top-level constraint.
162
    pub fn add_constraint(&mut self, constraint: Expression) {
163
        self.constraints_mut().push(constraint);
164
    }
165

            
166
    /// Adds a cnf clause.
167
    pub fn add_clause(&mut self, clause: CnfClause) {
168
        self.clauses_mut().push(clause);
169
    }
170

            
171
    /// Adds top-level constraints.
172
    pub fn add_constraints(&mut self, constraints: Vec<Expression>) {
173
        self.constraints_mut().extend(constraints);
174
    }
175

            
176
    /// Adds cnf clauses.
177
    pub fn add_clauses(&mut self, clauses: Vec<CnfClause>) {
178
        self.clauses_mut().extend(clauses);
179
    }
180

            
181
    /// Adds a new symbol to the symbol table
182
    /// (Wrapper over `SymbolTable.insert`)
183
    pub fn add_symbol(&mut self, decl: DeclarationPtr) -> Option<()> {
184
        self.symbols_mut().insert(decl)
185
    }
186

            
187
    /// Converts the constraints in this submodel to a single expression suitable for use inside
188
    /// another expression tree.
189
    ///
190
    /// * If this submodel has no constraints, true is returned.
191
    /// * If this submodel has a single constraint, that constraint is returned.
192
    /// * If this submodel has multiple constraints, they are returned as an `and` constraint.
193
    pub fn into_single_expression(self) -> Expression {
194
        let constraints = self.constraints().clone();
195
        match constraints.len() {
196
            0 => Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true))),
197
            1 => constraints[0].clone(),
198
            _ => Expression::And(Metadata::new(), Moo::new(into_matrix_expr![constraints])),
199
        }
200
    }
201
}
202

            
203
impl Typeable for SubModel {
204
    fn return_type(&self) -> ReturnType {
205
        ReturnType::Bool
206
    }
207
}
208

            
209
impl Hash for SubModel {
210
    fn hash<H: Hasher>(&self, state: &mut H) {
211
        self.symbols.borrow().hash(state);
212
        self.constraints.hash(state);
213
        self.cnf_clauses.hash(state);
214
    }
215
}
216

            
217
impl Display for SubModel {
218
    #[allow(clippy::unwrap_used)] // [rustdocs]: should only fail iff the formatter fails
219
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220
        for (name, decl) in self.symbols().clone().into_iter_local() {
221
            match &decl.kind() as &DeclarationKind {
222
                DeclarationKind::DecisionVariable(_) => {
223
                    writeln!(
224
                        f,
225
                        "{}",
226
                        pretty_variable_declaration(&self.symbols(), &name).unwrap()
227
                    )?;
228
                }
229
                DeclarationKind::ValueLetting(_) => {
230
                    writeln!(
231
                        f,
232
                        "{}",
233
                        pretty_value_letting_declaration(&self.symbols(), &name).unwrap()
234
                    )?;
235
                }
236
                DeclarationKind::DomainLetting(_) => {
237
                    writeln!(
238
                        f,
239
                        "{}",
240
                        pretty_domain_letting_declaration(&self.symbols(), &name).unwrap()
241
                    )?;
242
                }
243
                DeclarationKind::Given(d) => {
244
                    writeln!(f, "given {name}: {d}")?;
245
                }
246

            
247
                DeclarationKind::RecordField(_) => {
248
                    // Do not print a record field as it is an internal type
249
                    writeln!(f)?;
250
                    // TODO: is this correct?
251
                }
252
            }
253
        }
254

            
255
        if !self.constraints().is_empty() {
256
            writeln!(f, "\nsuch that\n")?;
257
            writeln!(f, "{}", pretty_expressions_as_top_level(self.constraints()))?;
258
        }
259

            
260
        if !self.clauses().is_empty() {
261
            writeln!(f, "\nclauses:\n")?;
262

            
263
            writeln!(f, "{}", pretty_clauses(self.clauses()))?;
264
        }
265
        Ok(())
266
    }
267
}
268

            
269
// Using manual implementations of Uniplate so that we can update the old Rc<RefCell<_>> with the
270
// new value instead of creating a new one. This will keep the parent pointers in sync.
271
//
272
// I considered adding Rc RefCell shared-mutability to Uniplate, but I think this is unsound in
273
// generality: e.g. two pointers to the same object are in our tree, and both get modified in
274
// different ways.
275
//
276
// Shared mutability is probably fine here, as we only have one pointer to each symbol table
277
// reachable via Uniplate, the one in its Submodel. The SymbolTable implementation doesn't return
278
// or traverse through the parent pointers.
279
//
280
// -- nd60
281

            
282
impl Uniplate for SubModel {
283
    fn uniplate(&self) -> (Tree<Self>, Box<dyn Fn(Tree<Self>) -> Self>) {
284
        // Look inside constraint tree and symbol tables.
285

            
286
        let (expr_tree, expr_ctx) = <Expression as Biplate<SubModel>>::biplate(self.root());
287

            
288
        let symtab_ptr = self.symbols();
289
        let (symtab_tree, symtab_ctx) = <SymbolTable as Biplate<SubModel>>::biplate(&symtab_ptr);
290

            
291
        let tree = Tree::Many(VecDeque::from([expr_tree, symtab_tree]));
292

            
293
        let self2 = self.clone();
294
        let ctx = Box::new(move |x| {
295
            let Tree::Many(xs) = x else {
296
                panic!();
297
            };
298

            
299
            let root = expr_ctx(xs[0].clone());
300
            let symtab = symtab_ctx(xs[1].clone());
301

            
302
            let mut self3 = self2.clone();
303

            
304
            let Expression::Root(_, _) = root else {
305
                bug!("root expression not root");
306
            };
307

            
308
            *self3.root_mut_unchecked() = root;
309

            
310
            *self3.symbols_mut() = symtab;
311

            
312
            self3
313
        });
314

            
315
        (tree, ctx)
316
    }
317
}
318

            
319
impl Biplate<Expression> for SubModel {
320
    fn biplate(&self) -> (Tree<Expression>, Box<dyn Fn(Tree<Expression>) -> Self>) {
321
        // Return constraints tree and look inside symbol table.
322
        let symtab_ptr = self.symbols();
323
        let (symtab_tree, symtab_ctx) = <SymbolTable as Biplate<Expression>>::biplate(&symtab_ptr);
324

            
325
        let tree = Tree::Many(VecDeque::from([
326
            Tree::One(self.root().clone()),
327
            symtab_tree,
328
        ]));
329

            
330
        let self2 = self.clone();
331
        let ctx = Box::new(move |x| {
332
            let Tree::Many(xs) = x else {
333
                panic!();
334
            };
335

            
336
            let Tree::One(root) = xs[0].clone() else {
337
                panic!();
338
            };
339

            
340
            let symtab = symtab_ctx(xs[1].clone());
341

            
342
            let mut self3 = self2.clone();
343

            
344
            let Expression::Root(_, _) = root else {
345
                bug!("root expression not root");
346
            };
347

            
348
            *self3.root_mut_unchecked() = root;
349

            
350
            *self3.symbols_mut() = symtab;
351

            
352
            self3
353
        });
354

            
355
        (tree, ctx)
356
    }
357
}
358

            
359
impl Biplate<SubModel> for SubModel {
360
    fn biplate(&self) -> (Tree<SubModel>, Box<dyn Fn(Tree<SubModel>) -> Self>) {
361
        (
362
            Tree::One(self.clone()),
363
            Box::new(move |x| {
364
                let Tree::One(x) = x else {
365
                    panic!();
366
                };
367
                x
368
            }),
369
        )
370
    }
371
}
372

            
373
impl Biplate<Atom> for SubModel {
374
    fn biplate(&self) -> (Tree<Atom>, Box<dyn Fn(Tree<Atom>) -> Self>) {
375
        // As atoms are only found in expressions, create a tree of atoms by
376
        //
377
        //  1. getting the expression tree
378
        //  2. Turning that into a list
379
        //  3. For each expression in the list, use Biplate<Atom> to turn it into an atom
380
        //
381
        //  Reconstruction works in reverse.
382

            
383
        let (expression_tree, rebuild_self) = <SubModel as Biplate<Expression>>::biplate(self);
384
        let (expression_list, rebuild_expression_tree) = expression_tree.list();
385

            
386
        // Let the atom tree be a Tree::Many where each element is the result of running Biplate<Atom>::biplate on an expression in the expression list.
387
        let (atom_trees, reconstruct_exprs): (VecDeque<_>, VecDeque<_>) = expression_list
388
            .iter()
389
            .map(|e| <Expression as Biplate<Atom>>::biplate(e))
390
            .unzip();
391

            
392
        let tree = Tree::Many(atom_trees);
393
        let ctx = Box::new(move |atom_tree: Tree<Atom>| {
394
            // 1. reconstruct expression_list from the atom tree
395

            
396
            let Tree::Many(atoms) = atom_tree else {
397
                panic!();
398
            };
399

            
400
            assert_eq!(
401
                atoms.len(),
402
                reconstruct_exprs.len(),
403
                "the number of children should not change when using Biplate"
404
            );
405

            
406
            let expression_list: VecDeque<Expression> = izip!(atoms, &reconstruct_exprs)
407
                .map(|(atom, recons)| recons(atom))
408
                .collect();
409

            
410
            // 2. reconstruct expression_tree from expression_list
411
            let expression_tree = rebuild_expression_tree(expression_list);
412

            
413
            // 3. reconstruct submodel from expression_tree
414
            rebuild_self(expression_tree)
415
        });
416

            
417
        (tree, ctx)
418
    }
419
}
420

            
421
impl Biplate<Comprehension> for SubModel {
422
    fn biplate(
423
        &self,
424
    ) -> (
425
        Tree<Comprehension>,
426
        Box<dyn Fn(Tree<Comprehension>) -> Self>,
427
    ) {
428
        let (f1_tree, f1_ctx) = <_ as Biplate<Comprehension>>::biplate(&self.constraints);
429
        let (f2_tree, f2_ctx) =
430
            <SymbolTable as Biplate<Comprehension>>::biplate(&self.symbols.borrow());
431

            
432
        let tree = Tree::Many(VecDeque::from([f1_tree, f2_tree]));
433
        let self2 = self.clone();
434
        let ctx = Box::new(move |x| {
435
            let Tree::Many(xs) = x else {
436
                panic!();
437
            };
438

            
439
            let root = f1_ctx(xs[0].clone());
440
            let symtab = f2_ctx(xs[1].clone());
441

            
442
            let mut self3 = self2.clone();
443

            
444
            let Expression::Root(_, _) = &*root else {
445
                bug!("root expression not root");
446
            };
447

            
448
            *self3.symbols_mut() = symtab;
449
            self3.constraints = root;
450

            
451
            self3
452
        });
453

            
454
        (tree, ctx)
455
    }
456
}