1
use super::{
2
    Atom, CnfClause, DeclarationPtr, Expression, Literal, Metadata, Moo, ReturnType, SymbolTable,
3
    SymbolTablePtr, Typeable,
4
    comprehension::Comprehension,
5
    declaration::DeclarationKind,
6
    pretty::{
7
        pretty_clauses, pretty_domain_letting_declaration, pretty_expressions_as_top_level,
8
        pretty_value_letting_declaration, pretty_variable_declaration,
9
    },
10
    serde::PtrAsInner,
11
};
12
use itertools::izip;
13
use serde::{Deserialize, Serialize};
14
use serde_with::serde_as;
15
use uniplate::{Biplate, Tree, Uniplate};
16

            
17
use crate::{bug, into_matrix_expr};
18
use parking_lot::{RwLockReadGuard, RwLockWriteGuard};
19
use std::hash::Hash;
20
use std::{collections::VecDeque, fmt::Display};
21

            
22
/// A sub-model, representing a lexical scope in the model.
23
///
24
/// Each sub-model contains a symbol table representing its scope, as well as a expression tree.
25
///
26
/// The expression tree is formed of a root node of type [`Expression::Root`], which contains a
27
/// vector of top-level constraints.
28
#[serde_as]
29
#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
30
pub struct SubModel {
31
    constraints: Moo<Expression>,
32
    #[serde_as(as = "PtrAsInner")]
33
    symbols: SymbolTablePtr,
34
    cnf_clauses: Vec<CnfClause>, // CNF clauses
35
}
36

            
37
impl SubModel {
38
    /// Creates a new [`Submodel`] with no parent scope.
39
    ///
40
    /// Top level models are represented as [`Model`](super::model): consider using
41
    /// [`Model::new`](super::Model::new) instead.
42
    #[doc(hidden)]
43
8696
    pub(super) fn new_top_level() -> SubModel {
44
8696
        SubModel {
45
8696
            constraints: Moo::new(Expression::Root(Metadata::new(), vec![])),
46
8696
            symbols: SymbolTablePtr::new(),
47
8696
            cnf_clauses: Vec::new(),
48
8696
        }
49
8696
    }
50

            
51
    /// Creates a new [`Submodel`] as a child scope of `parent`.
52
    ///
53
    /// `parent` should be the symbol table of the containing scope of this sub-model.
54
1600
    pub fn new(parent: SymbolTablePtr) -> SubModel {
55
1600
        SubModel {
56
1600
            constraints: Moo::new(Expression::Root(Metadata::new(), vec![])),
57
1600
            symbols: SymbolTablePtr::with_parent(parent),
58
1600
            cnf_clauses: Vec::new(),
59
1600
        }
60
1600
    }
61

            
62
    /// The symbol table for this sub-model as a pointer.
63
    ///
64
    /// The caller should only mutate the returned symbol table if this method was called on a
65
    /// mutable model.
66
32728
    pub fn symbols_ptr_unchecked(&self) -> &SymbolTablePtr {
67
32728
        &self.symbols
68
32728
    }
69

            
70
    /// The symbol table for this sub-model as a mutable pointer.
71
    ///
72
    /// The caller should only mutate the returned symbol table if this method was called on a
73
    /// mutable model.
74
1600
    pub fn symbols_ptr_unchecked_mut(&mut self) -> &mut SymbolTablePtr {
75
1600
        &mut self.symbols
76
1600
    }
77

            
78
    /// The symbol table for this sub-model as a reference.
79
44613830
    pub fn symbols(&self) -> RwLockReadGuard<'_, SymbolTable> {
80
44613830
        self.symbols.read()
81
44613830
    }
82

            
83
    /// The symbol table for this sub-model as a mutable reference.
84
60088
    pub fn symbols_mut(&mut self) -> RwLockWriteGuard<'_, SymbolTable> {
85
60088
        self.symbols.write()
86
60088
    }
87

            
88
    /// The root node of this sub-model.
89
    ///
90
    /// The root node is an [`Expression::Root`] containing a vector of the top level constraints
91
    /// in this sub-model.
92
653584
    pub fn root(&self) -> &Expression {
93
653584
        &self.constraints
94
653584
    }
95

            
96
    /// The root node of this sub-model, as a mutable reference.
97
    ///
98
    /// The caller is responsible for ensuring that the root node remains an [`Expression::Root`].
99
    ///
100
47790
    pub fn root_mut_unchecked(&mut self) -> &mut Expression {
101
47790
        Moo::make_mut(&mut self.constraints)
102
47790
    }
103

            
104
    /// Replaces the root node with `new_root`, returning the old root node.
105
    ///
106
    /// # Panics
107
    ///
108
    /// - If `new_root` is not an [`Expression::Root`].
109
44170
    pub fn replace_root(&mut self, new_root: Expression) -> Expression {
110
44170
        let Expression::Root(_, _) = new_root else {
111
            tracing::error!(new_root=?new_root,"new_root is not an Expression::root");
112
            panic!("new_root is not an Expression::Root");
113
        };
114

            
115
        // INVARIANT: already checked that `new_root` is an [`Expression::Root`]
116
44170
        std::mem::replace(self.root_mut_unchecked(), new_root)
117
44170
    }
118

            
119
    /// The top-level constraints in this sub-model.
120
59422
    pub fn constraints(&self) -> &Vec<Expression> {
121
59422
        let Expression::Root(_, constraints) = self.constraints.as_ref() else {
122
            bug!("The top level expression in a submodel should be Expr::Root");
123
        };
124
59422
        constraints
125
59422
    }
126

            
127
    /// The cnf clauses in this sub-model.
128
15648
    pub fn clauses(&self) -> &Vec<CnfClause> {
129
15648
        &self.cnf_clauses
130
15648
    }
131

            
132
    /// The top-level constraints in this sub-model as a mutable vector.
133
54012
    pub fn constraints_mut(&mut self) -> &mut Vec<Expression> {
134
54012
        let Expression::Root(_, constraints) = Moo::make_mut(&mut self.constraints) else {
135
            bug!("The top level expression in a submodel should be Expr::Root");
136
        };
137

            
138
54012
        constraints
139
54012
    }
140

            
141
    /// The cnf clauses in this sub-model as a mutable vector.
142
44170
    pub fn clauses_mut(&mut self) -> &mut Vec<CnfClause> {
143
44170
        &mut self.cnf_clauses
144
44170
    }
145

            
146
    /// Replaces the top-level constraints with `new_constraints`, returning the old ones.
147
    pub fn replace_constraints(&mut self, new_constraints: Vec<Expression>) -> Vec<Expression> {
148
        std::mem::replace(self.constraints_mut(), new_constraints)
149
    }
150

            
151
    /// Replaces the cnf clauses with `new_clauses`, returning the old ones.
152
    pub fn replace_clauses(&mut self, new_clauses: Vec<CnfClause>) -> Vec<CnfClause> {
153
        std::mem::replace(self.clauses_mut(), new_clauses)
154
    }
155

            
156
    /// Adds a top-level constraint.
157
2222
    pub fn add_constraint(&mut self, constraint: Expression) {
158
2222
        self.constraints_mut().push(constraint);
159
2222
    }
160

            
161
    /// Adds a cnf clause.
162
    pub fn add_clause(&mut self, clause: CnfClause) {
163
        self.clauses_mut().push(clause);
164
    }
165

            
166
    /// Adds top-level constraints.
167
50550
    pub fn add_constraints(&mut self, constraints: Vec<Expression>) {
168
50550
        self.constraints_mut().extend(constraints);
169
50550
    }
170

            
171
    /// Adds cnf clauses.
172
44170
    pub fn add_clauses(&mut self, clauses: Vec<CnfClause>) {
173
44170
        self.clauses_mut().extend(clauses);
174
44170
    }
175

            
176
    /// Adds a new symbol to the symbol table
177
    /// (Wrapper over `SymbolTable.insert`)
178
    pub fn add_symbol(&mut self, decl: DeclarationPtr) -> Option<()> {
179
        self.symbols_mut().insert(decl)
180
    }
181

            
182
    /// Converts the constraints in this submodel to a single expression suitable for use inside
183
    /// another expression tree.
184
    ///
185
    /// * If this submodel has no constraints, true is returned.
186
    /// * If this submodel has a single constraint, that constraint is returned.
187
    /// * If this submodel has multiple constraints, they are returned as an `and` constraint.
188
13220
    pub fn into_single_expression(self) -> Expression {
189
13220
        let constraints = self.constraints().clone();
190
13220
        match constraints.len() {
191
            0 => Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true))),
192
12400
            1 => constraints[0].clone(),
193
820
            _ => Expression::And(Metadata::new(), Moo::new(into_matrix_expr![constraints])),
194
        }
195
13220
    }
196
}
197

            
198
impl Typeable for SubModel {
199
    fn return_type(&self) -> ReturnType {
200
        ReturnType::Bool
201
    }
202
}
203

            
204
impl Display for SubModel {
205
    #[allow(clippy::unwrap_used)] // [rustdocs]: should only fail iff the formatter fails
206
15648
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207
39288
        for (name, decl) in self.symbols().clone().into_iter_local() {
208
39288
            match &decl.kind() as &DeclarationKind {
209
                DeclarationKind::Find(_) => {
210
33448
                    writeln!(
211
33448
                        f,
212
                        "{}",
213
33448
                        pretty_variable_declaration(&self.symbols(), &name).unwrap()
214
                    )?;
215
                }
216
                DeclarationKind::ValueLetting(_) => {
217
3240
                    writeln!(
218
3240
                        f,
219
                        "{}",
220
3240
                        pretty_value_letting_declaration(&self.symbols(), &name).unwrap()
221
                    )?;
222
                }
223
                DeclarationKind::DomainLetting(_) => {
224
600
                    writeln!(
225
600
                        f,
226
                        "{}",
227
600
                        pretty_domain_letting_declaration(&self.symbols(), &name).unwrap()
228
                    )?;
229
                }
230
                DeclarationKind::Given(d) => {
231
                    writeln!(f, "given {name}: {d}")?;
232
                }
233
1920
                DeclarationKind::Quantified(inner) => {
234
1920
                    writeln!(f, "given {name}: {}", inner.domain())?;
235
                }
236

            
237
                DeclarationKind::RecordField(_) => {
238
                    // Do not print a record field as it is an internal type
239
80
                    writeln!(f)?;
240
                    // TODO: is this correct?
241
                }
242
            }
243
        }
244

            
245
15648
        if !self.constraints().is_empty() {
246
14584
            writeln!(f, "\nsuch that\n")?;
247
14584
            writeln!(f, "{}", pretty_expressions_as_top_level(self.constraints()))?;
248
1064
        }
249

            
250
15648
        if !self.clauses().is_empty() {
251
            writeln!(f, "\nclauses:\n")?;
252

            
253
            writeln!(f, "{}", pretty_clauses(self.clauses()))?;
254
15648
        }
255
15648
        Ok(())
256
15648
    }
257
}
258

            
259
// Using manual implementations of Uniplate so that we can update the old `Rc<RefCell<_>>` with the
260
// new value instead of creating a new one. This will keep the parent pointers in sync.
261
//
262
// I considered adding Rc RefCell shared-mutability to Uniplate, but I think this is unsound in
263
// generality: e.g. two pointers to the same object are in our tree, and both get modified in
264
// different ways.
265
//
266
// Shared mutability is probably fine here, as we only have one pointer to each symbol table
267
// reachable via Uniplate, the one in its Submodel. The SymbolTable implementation doesn't return
268
// or traverse through the parent pointers.
269
//
270
// -- nd60
271

            
272
impl Uniplate for SubModel {
273
54942
    fn uniplate(&self) -> (Tree<Self>, Box<dyn Fn(Tree<Self>) -> Self>) {
274
        // Look inside constraint tree and symbol tables.
275

            
276
54942
        let (expr_tree, expr_ctx) = <Expression as Biplate<SubModel>>::biplate(self.root());
277

            
278
54942
        let symtab_ptr = self.symbols();
279
54942
        let (symtab_tree, symtab_ctx) = <SymbolTable as Biplate<SubModel>>::biplate(&symtab_ptr);
280

            
281
54942
        let tree = Tree::Many(VecDeque::from([expr_tree, symtab_tree]));
282

            
283
54942
        let self2 = self.clone();
284
54942
        let ctx = Box::new(move |x| {
285
1740
            let Tree::Many(xs) = x else {
286
                panic!();
287
            };
288

            
289
1740
            let root = expr_ctx(xs[0].clone());
290
1740
            let symtab = symtab_ctx(xs[1].clone());
291

            
292
1740
            let mut self3 = self2.clone();
293

            
294
1740
            let Expression::Root(_, _) = root else {
295
                bug!("root expression not root");
296
            };
297

            
298
1740
            *self3.root_mut_unchecked() = root;
299

            
300
1740
            *self3.symbols_mut() = symtab;
301

            
302
1740
            self3
303
1740
        });
304

            
305
54942
        (tree, ctx)
306
54942
    }
307
}
308

            
309
impl Biplate<Expression> for SubModel {
310
54602
    fn biplate(&self) -> (Tree<Expression>, Box<dyn Fn(Tree<Expression>) -> Self>) {
311
        // Return constraints tree and look inside symbol table.
312
54602
        let symtab_ptr = self.symbols();
313
54602
        let (symtab_tree, symtab_ctx) = <SymbolTable as Biplate<Expression>>::biplate(&symtab_ptr);
314

            
315
54602
        let tree = Tree::Many(VecDeque::from([
316
54602
            Tree::One(self.root().clone()),
317
54602
            symtab_tree,
318
54602
        ]));
319

            
320
54602
        let self2 = self.clone();
321
54602
        let ctx = Box::new(move |x| {
322
280
            let Tree::Many(xs) = x else {
323
                panic!();
324
            };
325

            
326
280
            let Tree::One(root) = xs[0].clone() else {
327
                panic!();
328
            };
329

            
330
280
            let symtab = symtab_ctx(xs[1].clone());
331

            
332
280
            let mut self3 = self2.clone();
333

            
334
280
            let Expression::Root(_, _) = root else {
335
                bug!("root expression not root");
336
            };
337

            
338
280
            *self3.root_mut_unchecked() = root;
339

            
340
280
            *self3.symbols_mut() = symtab;
341

            
342
280
            self3
343
280
        });
344

            
345
54602
        (tree, ctx)
346
54602
    }
347
}
348

            
349
impl Biplate<SubModel> for SubModel {
350
11160
    fn biplate(&self) -> (Tree<SubModel>, Box<dyn Fn(Tree<SubModel>) -> Self>) {
351
        (
352
11160
            Tree::One(self.clone()),
353
11160
            Box::new(move |x| {
354
280
                let Tree::One(x) = x else {
355
                    panic!();
356
                };
357
280
                x
358
280
            }),
359
        )
360
11160
    }
361
}
362

            
363
impl Biplate<Atom> for SubModel {
364
    fn biplate(&self) -> (Tree<Atom>, Box<dyn Fn(Tree<Atom>) -> Self>) {
365
        // As atoms are only found in expressions, create a tree of atoms by
366
        //
367
        //  1. getting the expression tree
368
        //  2. Turning that into a list
369
        //  3. For each expression in the list, use Biplate<Atom> to turn it into an atom
370
        //
371
        //  Reconstruction works in reverse.
372

            
373
        let (expression_tree, rebuild_self) = <SubModel as Biplate<Expression>>::biplate(self);
374
        let (expression_list, rebuild_expression_tree) = expression_tree.list();
375

            
376
        // Let the atom tree be a Tree::Many where each element is the result of running Biplate<Atom>::biplate on an expression in the expression list.
377
        let (atom_trees, reconstruct_exprs): (VecDeque<_>, VecDeque<_>) = expression_list
378
            .iter()
379
            .map(|e| <Expression as Biplate<Atom>>::biplate(e))
380
            .unzip();
381

            
382
        let tree = Tree::Many(atom_trees);
383
        let ctx = Box::new(move |atom_tree: Tree<Atom>| {
384
            // 1. reconstruct expression_list from the atom tree
385

            
386
            let Tree::Many(atoms) = atom_tree else {
387
                panic!();
388
            };
389

            
390
            assert_eq!(
391
                atoms.len(),
392
                reconstruct_exprs.len(),
393
                "the number of children should not change when using Biplate"
394
            );
395

            
396
            let expression_list: VecDeque<Expression> = izip!(atoms, &reconstruct_exprs)
397
                .map(|(atom, recons)| recons(atom))
398
                .collect();
399

            
400
            // 2. reconstruct expression_tree from expression_list
401
            let expression_tree = rebuild_expression_tree(expression_list);
402

            
403
            // 3. reconstruct submodel from expression_tree
404
            rebuild_self(expression_tree)
405
        });
406

            
407
        (tree, ctx)
408
    }
409
}
410

            
411
impl Biplate<Comprehension> for SubModel {
412
69840
    fn biplate(
413
69840
        &self,
414
69840
    ) -> (
415
69840
        Tree<Comprehension>,
416
69840
        Box<dyn Fn(Tree<Comprehension>) -> Self>,
417
69840
    ) {
418
69840
        let (f1_tree, f1_ctx) = <_ as Biplate<Comprehension>>::biplate(&self.constraints);
419
69840
        let (f2_tree, f2_ctx) = <SymbolTable as Biplate<Comprehension>>::biplate(&self.symbols());
420

            
421
69840
        let tree = Tree::Many(VecDeque::from([f1_tree, f2_tree]));
422
69840
        let self2 = self.clone();
423
69840
        let ctx = Box::new(move |x| {
424
            let Tree::Many(xs) = x else {
425
                panic!();
426
            };
427

            
428
            let root = f1_ctx(xs[0].clone());
429
            let symtab = f2_ctx(xs[1].clone());
430

            
431
            let mut self3 = self2.clone();
432

            
433
            let Expression::Root(_, _) = &*root else {
434
                bug!("root expression not root");
435
            };
436

            
437
            *self3.symbols_mut() = symtab;
438
            self3.constraints = root;
439

            
440
            self3
441
        });
442

            
443
69840
        (tree, ctx)
444
69840
    }
445
}