1
use std::collections::{HashMap, VecDeque};
2
use std::fmt::{Debug, Display};
3
use std::hash::Hash;
4
use std::sync::{Arc, RwLock};
5

            
6
use crate::context::Context;
7
use crate::{bug, into_matrix_expr};
8
use derivative::Derivative;
9
use indexmap::IndexSet;
10
use itertools::izip;
11
use parking_lot::{RwLockReadGuard, RwLockWriteGuard};
12
use serde::{Deserialize, Serialize};
13
use serde_with::serde_as;
14
use uniplate::{Biplate, Tree, Uniplate};
15

            
16
use super::serde::{HasId, ObjId, PtrAsInner};
17
use super::{
18
    Atom, CnfClause, DeclarationPtr, Expression, Literal, Metadata, Moo, Name, ReturnType,
19
    SymbolTable, SymbolTablePtr, Typeable,
20
    comprehension::Comprehension,
21
    declaration::DeclarationKind,
22
    pretty::{
23
        pretty_clauses, pretty_domain_letting_declaration, pretty_expressions_as_top_level,
24
        pretty_value_letting_declaration, pretty_variable_declaration,
25
    },
26
};
27

            
28
/// An Essence model.
29
#[serde_as]
30
#[derive(Derivative, Clone, Debug, Serialize, Deserialize)]
31
#[derivative(PartialEq, Eq)]
32
pub struct Model {
33
    constraints: Moo<Expression>,
34
    #[serde_as(as = "PtrAsInner")]
35
    symbols: SymbolTablePtr,
36
    cnf_clauses: Vec<CnfClause>,
37

            
38
    pub search_order: Option<Vec<Name>>,
39
    pub dominance: Option<Expression>,
40

            
41
    #[serde(skip, default = "default_context")]
42
    #[derivative(PartialEq = "ignore")]
43
    pub context: Arc<RwLock<Context<'static>>>,
44
}
45

            
46
12846
fn default_context() -> Arc<RwLock<Context<'static>>> {
47
12846
    Arc::new(RwLock::new(Context::default()))
48
12846
}
49

            
50
impl Model {
51
40476
    fn new_empty(symbols: SymbolTablePtr, context: Arc<RwLock<Context<'static>>>) -> Model {
52
40476
        Model {
53
40476
            constraints: Moo::new(Expression::Root(Metadata::new(), vec![])),
54
40476
            symbols,
55
40476
            cnf_clauses: Vec::new(),
56
40476
            search_order: None,
57
40476
            dominance: None,
58
40476
            context,
59
40476
        }
60
40476
    }
61

            
62
    /// Creates a new top-level model from the given context.
63
28236
    pub fn new(context: Arc<RwLock<Context<'static>>>) -> Model {
64
28236
        Self::new_empty(SymbolTablePtr::new(), context)
65
28236
    }
66

            
67
    /// Creates a new model whose symbol table has `parent` as parent scope.
68
12240
    pub fn new_in_parent_scope(parent: SymbolTablePtr) -> Model {
69
12240
        Self::new_empty(SymbolTablePtr::with_parent(parent), default_context())
70
12240
    }
71

            
72
    /// The symbol table for this model as a pointer.
73
564578
    pub fn symbols_ptr_unchecked(&self) -> &SymbolTablePtr {
74
564578
        &self.symbols
75
564578
    }
76

            
77
    /// The symbol table for this model as a mutable pointer.
78
12240
    pub fn symbols_ptr_unchecked_mut(&mut self) -> &mut SymbolTablePtr {
79
12240
        &mut self.symbols
80
12240
    }
81

            
82
    /// The symbol table for this model as a reference.
83
471433014
    pub fn symbols(&self) -> RwLockReadGuard<'_, SymbolTable> {
84
471433014
        self.symbols.read()
85
471433014
    }
86

            
87
    /// The symbol table for this model as a mutable reference.
88
476866
    pub fn symbols_mut(&mut self) -> RwLockWriteGuard<'_, SymbolTable> {
89
476866
        self.symbols.write()
90
476866
    }
91

            
92
    /// The root node of this model.
93
5904166
    pub fn root(&self) -> &Expression {
94
5904166
        &self.constraints
95
5904166
    }
96

            
97
    /// The root node of this model, as a mutable reference.
98
    ///
99
    /// The caller is responsible for ensuring that the root node remains an [`Expression::Root`].
100
417270
    pub fn root_mut_unchecked(&mut self) -> &mut Expression {
101
417270
        Moo::make_mut(&mut self.constraints)
102
417270
    }
103

            
104
    /// Replaces the root node with `new_root`, returning the old root node.
105
417270
    pub fn replace_root(&mut self, new_root: Expression) -> Expression {
106
417270
        let Expression::Root(_, _) = new_root else {
107
            tracing::error!(new_root=?new_root,"new_root is not an Expression::Root");
108
            panic!("new_root is not an Expression::Root");
109
        };
110

            
111
417270
        std::mem::replace(self.root_mut_unchecked(), new_root)
112
417270
    }
113

            
114
    /// The top-level constraints in this model.
115
87192
    pub fn constraints(&self) -> &Vec<Expression> {
116
87192
        let Expression::Root(_, constraints) = self.constraints.as_ref() else {
117
            bug!("The top level expression in a model should be Expr::Root");
118
        };
119
87192
        constraints
120
87192
    }
121

            
122
    /// The cnf clauses in this model.
123
1410524
    pub fn clauses(&self) -> &Vec<CnfClause> {
124
1410524
        &self.cnf_clauses
125
1410524
    }
126

            
127
    /// The top-level constraints in this model as a mutable vector.
128
154052
    pub fn constraints_mut(&mut self) -> &mut Vec<Expression> {
129
154052
        let Expression::Root(_, constraints) = Moo::make_mut(&mut self.constraints) else {
130
            bug!("The top level expression in a model should be Expr::Root");
131
        };
132

            
133
154052
        constraints
134
154052
    }
135

            
136
    /// The cnf clauses in this model as a mutable vector.
137
417270
    pub fn clauses_mut(&mut self) -> &mut Vec<CnfClause> {
138
417270
        &mut self.cnf_clauses
139
417270
    }
140

            
141
    /// Replaces the top-level constraints with `new_constraints`, returning the old ones.
142
    pub fn replace_constraints(&mut self, new_constraints: Vec<Expression>) -> Vec<Expression> {
143
        std::mem::replace(self.constraints_mut(), new_constraints)
144
    }
145

            
146
    /// Replaces the cnf clauses with `new_clauses`, returning the old ones.
147
    pub fn replace_clauses(&mut self, new_clauses: Vec<CnfClause>) -> Vec<CnfClause> {
148
        std::mem::replace(self.clauses_mut(), new_clauses)
149
    }
150

            
151
    /// Adds a top-level constraint.
152
13100
    pub fn add_constraint(&mut self, constraint: Expression) {
153
13100
        self.constraints_mut().push(constraint);
154
13100
    }
155

            
156
    /// Adds a cnf clause.
157
    pub fn add_clause(&mut self, clause: CnfClause) {
158
        self.clauses_mut().push(clause);
159
    }
160

            
161
    /// Adds top-level constraints.
162
447356
    pub fn add_constraints(&mut self, constraints: Vec<Expression>) {
163
447356
        self.constraints_mut().extend(constraints);
164
447356
    }
165

            
166
    /// Adds cnf clauses.
167
417270
    pub fn add_clauses(&mut self, clauses: Vec<CnfClause>) {
168
417270
        self.clauses_mut().extend(clauses);
169
417270
    }
170

            
171
    /// Adds a new symbol to the symbol table.
172
    pub fn add_symbol(&mut self, decl: DeclarationPtr) -> Option<()> {
173
        self.symbols_mut().insert(decl)
174
    }
175

            
176
    /// Converts the constraints in this model to a single expression suitable for use inside
177
    /// another expression tree.
178
74280
    pub fn into_single_expression(self) -> Expression {
179
74280
        let constraints = self.constraints().clone();
180
74280
        match constraints.len() {
181
            0 => Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true))),
182
74280
            1 => constraints[0].clone(),
183
            _ => Expression::And(Metadata::new(), Moo::new(into_matrix_expr![constraints])),
184
        }
185
74280
    }
186

            
187
    /// Collects all ObjId values from the model using uniplate traversal.
188
606
    pub fn collect_stable_id_mapping(&self) -> HashMap<ObjId, ObjId> {
189
606
        fn visit_symbol_table(symbol_table: SymbolTablePtr, id_list: &mut IndexSet<ObjId>) {
190
606
            if !id_list.insert(symbol_table.id()) {
191
                return;
192
606
            }
193

            
194
606
            let table_ref = symbol_table.read();
195
746
            table_ref.iter_local().for_each(|(_, decl)| {
196
746
                id_list.insert(decl.id());
197
746
            });
198
606
        }
199

            
200
606
        let mut id_list: IndexSet<ObjId> = IndexSet::new();
201

            
202
606
        visit_symbol_table(self.symbols_ptr_unchecked().clone(), &mut id_list);
203

            
204
606
        let mut exprs: VecDeque<Expression> = self.universe_bi();
205
606
        if let Some(dominance) = &self.dominance {
206
            exprs.push_back(dominance.clone());
207
606
        }
208

            
209
606
        for symbol_table in Biplate::<SymbolTablePtr>::universe_bi(&exprs) {
210
            visit_symbol_table(symbol_table, &mut id_list);
211
        }
212
1518
        for declaration in Biplate::<DeclarationPtr>::universe_bi(&exprs) {
213
1518
            id_list.insert(declaration.id());
214
1518
        }
215

            
216
606
        let mut id_map = HashMap::new();
217
1352
        for (stable_id, original_id) in id_list.into_iter().enumerate() {
218
1352
            let type_name = original_id.type_name;
219
1352
            id_map.insert(
220
1352
                original_id,
221
1352
                ObjId {
222
1352
                    object_id: stable_id as u32,
223
1352
                    type_name,
224
1352
                },
225
1352
            );
226
1352
        }
227

            
228
606
        id_map
229
606
    }
230
}
231

            
232
impl Default for Model {
233
    fn default() -> Self {
234
        Self::new(default_context())
235
    }
236
}
237

            
238
impl Typeable for Model {
239
    fn return_type(&self) -> ReturnType {
240
        ReturnType::Bool
241
    }
242
}
243

            
244
impl Hash for Model {
245
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
246
        self.constraints.hash(state);
247
        self.symbols.hash(state);
248
        self.cnf_clauses.hash(state);
249
        self.search_order.hash(state);
250
        self.dominance.hash(state);
251
    }
252
}
253

            
254
// At time of writing (03/02/2025), the Uniplate derive macro doesn't like the lifetimes inside
255
// context, and we do not yet have a way of ignoring this field.
256
impl Uniplate for Model {
257
    fn uniplate(&self) -> (Tree<Self>, Box<dyn Fn(Tree<Self>) -> Self>) {
258
        let self2 = self.clone();
259
        (Tree::Zero, Box::new(move |_| self2.clone()))
260
    }
261
}
262

            
263
impl Biplate<Expression> for Model {
264
149622
    fn biplate(&self) -> (Tree<Expression>, Box<dyn Fn(Tree<Expression>) -> Self>) {
265
149622
        let (symtab_tree, symtab_ctx) =
266
149622
            <SymbolTable as Biplate<Expression>>::biplate(&self.symbols());
267

            
268
149622
        let dom_tree = match &self.dominance {
269
            Some(expr) => Tree::One(expr.clone()),
270
149622
            None => Tree::Zero,
271
        };
272

            
273
149622
        let tree = Tree::Many(VecDeque::from([
274
149622
            Tree::One(self.root().clone()),
275
149622
            symtab_tree,
276
149622
            dom_tree,
277
149622
        ]));
278

            
279
149622
        let self2 = self.clone();
280
149622
        let ctx = Box::new(move |x| {
281
            let Tree::Many(xs) = x else {
282
                panic!("Expected a tree with three children");
283
            };
284
            if xs.len() != 3 {
285
                panic!("Expected a tree with three children");
286
            }
287

            
288
            let Tree::One(root) = xs[0].clone() else {
289
                panic!("Expected root expression tree");
290
            };
291

            
292
            let symtab = symtab_ctx(xs[1].clone());
293
            let dominance = match xs[2].clone() {
294
                Tree::One(expr) => Some(expr),
295
                Tree::Zero => None,
296
                _ => panic!("Expected dominance tree"),
297
            };
298

            
299
            let mut self3 = self2.clone();
300

            
301
            let Expression::Root(_, _) = root else {
302
                bug!("root expression not root");
303
            };
304

            
305
            *self3.root_mut_unchecked() = root;
306
            *self3.symbols_mut() = symtab;
307
            self3.dominance = dominance;
308

            
309
            self3
310
        });
311

            
312
149622
        (tree, ctx)
313
149622
    }
314
}
315

            
316
impl Biplate<Atom> for Model {
317
    fn biplate(&self) -> (Tree<Atom>, Box<dyn Fn(Tree<Atom>) -> Self>) {
318
        let (expression_tree, rebuild_self) = <Model as Biplate<Expression>>::biplate(self);
319
        let (expression_list, rebuild_expression_tree) = expression_tree.list();
320

            
321
        let (atom_trees, reconstruct_exprs): (VecDeque<_>, VecDeque<_>) = expression_list
322
            .iter()
323
            .map(|e| <Expression as Biplate<Atom>>::biplate(e))
324
            .unzip();
325

            
326
        let tree = Tree::Many(atom_trees);
327
        let ctx = Box::new(move |atom_tree: Tree<Atom>| {
328
            let Tree::Many(atoms) = atom_tree else {
329
                panic!();
330
            };
331

            
332
            assert_eq!(
333
                atoms.len(),
334
                reconstruct_exprs.len(),
335
                "the number of children should not change when using Biplate"
336
            );
337

            
338
            let expression_list: VecDeque<Expression> = izip!(atoms, &reconstruct_exprs)
339
                .map(|(atom, recons)| recons(atom))
340
                .collect();
341

            
342
            let expression_tree = rebuild_expression_tree(expression_list);
343
            rebuild_self(expression_tree)
344
        });
345

            
346
        (tree, ctx)
347
    }
348
}
349

            
350
impl Biplate<Comprehension> for Model {
351
    fn biplate(
352
        &self,
353
    ) -> (
354
        Tree<Comprehension>,
355
        Box<dyn Fn(Tree<Comprehension>) -> Self>,
356
    ) {
357
        let (f1_tree, f1_ctx) = <_ as Biplate<Comprehension>>::biplate(&self.constraints);
358
        let (f2_tree, f2_ctx) = <SymbolTable as Biplate<Comprehension>>::biplate(&self.symbols());
359

            
360
        let tree = Tree::Many(VecDeque::from([f1_tree, f2_tree]));
361
        let self2 = self.clone();
362
        let ctx = Box::new(move |x| {
363
            let Tree::Many(xs) = x else {
364
                panic!();
365
            };
366

            
367
            let root = f1_ctx(xs[0].clone());
368
            let symtab = f2_ctx(xs[1].clone());
369

            
370
            let mut self3 = self2.clone();
371

            
372
            let Expression::Root(_, _) = &*root else {
373
                bug!("root expression not root");
374
            };
375

            
376
            *self3.symbols_mut() = symtab;
377
            self3.constraints = root;
378

            
379
            self3
380
        });
381

            
382
        (tree, ctx)
383
    }
384
}
385

            
386
impl Display for Model {
387
    #[allow(clippy::unwrap_used)]
388
66324
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
389
1327844
        for (name, decl) in self.symbols().clone().into_iter_local() {
390
1327844
            match &decl.kind() as &DeclarationKind {
391
                DeclarationKind::Find(_) => {
392
1314004
                    writeln!(
393
1314004
                        f,
394
                        "{}",
395
1314004
                        pretty_variable_declaration(&self.symbols(), &name).unwrap()
396
                    )?;
397
                }
398
                DeclarationKind::ValueLetting(_) | DeclarationKind::TemporaryValueLetting(_) => {
399
11760
                    writeln!(
400
11760
                        f,
401
                        "{}",
402
11760
                        pretty_value_letting_declaration(&self.symbols(), &name).unwrap()
403
                    )?;
404
                }
405
                DeclarationKind::DomainLetting(_) => {
406
1840
                    writeln!(
407
1840
                        f,
408
                        "{}",
409
1840
                        pretty_domain_letting_declaration(&self.symbols(), &name).unwrap()
410
                    )?;
411
                }
412
                DeclarationKind::Given(d) => {
413
                    writeln!(f, "given {name}: {d}")?;
414
                }
415
                DeclarationKind::Quantified(inner) => {
416
                    writeln!(f, "quantified {name}: {}", inner.domain())?;
417
                }
418
                DeclarationKind::RecordField(_) => {
419
240
                    writeln!(f)?;
420
                }
421
            }
422
        }
423

            
424
66324
        if !self.constraints().is_empty() {
425
61672
            writeln!(f, "\nsuch that\n")?;
426
61672
            writeln!(f, "{}", pretty_expressions_as_top_level(self.constraints()))?;
427
4652
        }
428

            
429
66324
        if !self.clauses().is_empty() {
430
5200
            writeln!(f, "\nclauses:\n")?;
431
5200
            writeln!(f, "{}", pretty_clauses(self.clauses()))?;
432
61124
        }
433
66324
        Ok(())
434
66324
    }
435
}
436

            
437
/// A model that is de/serializable using `serde`.
438
///
439
/// To turn this into a rewritable model, it needs to be initialised using
440
/// [`initialise`](SerdeModel::initialise).
441
#[serde_as]
442
#[derive(Clone, Debug, Serialize, Deserialize)]
443
pub struct SerdeModel {
444
    constraints: Moo<Expression>,
445
    #[serde_as(as = "PtrAsInner")]
446
    symbols: SymbolTablePtr,
447
    cnf_clauses: Vec<CnfClause>,
448
    search_order: Option<Vec<Name>>,
449
    dominance: Option<Expression>,
450
}
451

            
452
impl SerdeModel {
453
    /// Initialises the model for rewriting.
454
1200
    pub fn initialise(mut self, context: Arc<RwLock<Context<'static>>>) -> Option<Model> {
455
1200
        let mut tables: HashMap<ObjId, SymbolTablePtr> = HashMap::new();
456

            
457
        // Root model symbol table is always definitive.
458
1200
        tables.insert(self.symbols.id(), self.symbols.clone());
459

            
460
1200
        let mut exprs: VecDeque<Expression> = self.constraints.universe_bi();
461
1200
        if let Some(dominance) = &self.dominance {
462
            exprs.push_back(dominance.clone());
463
1200
        }
464

            
465
        // Some expressions (e.g. abstract comprehensions) contain additional symbol tables.
466
1200
        for table in Biplate::<SymbolTablePtr>::universe_bi(&exprs) {
467
            tables.entry(table.id()).or_insert(table);
468
        }
469

            
470
1200
        for table in tables.clone().into_values() {
471
1200
            let mut table_mut = table.write();
472
1200
            let parent_mut = table_mut.parent_mut_unchecked();
473

            
474
            #[allow(clippy::unwrap_used)]
475
1200
            if let Some(parent) = parent_mut {
476
                let parent_id = parent.id();
477
                *parent = tables.get(&parent_id).unwrap().clone();
478
1200
            }
479
        }
480

            
481
1200
        let mut all_declarations: HashMap<ObjId, DeclarationPtr> = HashMap::new();
482
1200
        for table in tables.values() {
483
1480
            for (_, decl) in table.read().iter_local() {
484
1480
                let id = decl.id();
485
1480
                all_declarations.insert(id, decl.clone());
486
1480
            }
487
        }
488

            
489
1200
        self.constraints = self.constraints.transform_bi(&move |decl: DeclarationPtr| {
490
840
            let id = decl.id();
491
840
            all_declarations
492
840
                .get(&id)
493
840
                .unwrap_or_else(|| {
494
                    panic!(
495
                        "A declaration used in the expression tree should exist in the symbol table. The missing declaration has id {id}."
496
                    )
497
                })
498
840
                .clone()
499
840
        });
500

            
501
1200
        Some(Model {
502
1200
            constraints: self.constraints,
503
1200
            symbols: self.symbols,
504
1200
            cnf_clauses: self.cnf_clauses,
505
1200
            search_order: self.search_order,
506
1200
            dominance: self.dominance,
507
1200
            context,
508
1200
        })
509
1200
    }
510
}
511

            
512
impl From<Model> for SerdeModel {
513
606
    fn from(val: Model) -> Self {
514
606
        SerdeModel {
515
606
            constraints: val.constraints,
516
606
            symbols: val.symbols,
517
606
            cnf_clauses: val.cnf_clauses,
518
606
            search_order: val.search_order,
519
606
            dominance: val.dominance,
520
606
        }
521
606
    }
522
}
523

            
524
impl Display for SerdeModel {
525
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
526
        let model = Model {
527
            constraints: self.constraints.clone(),
528
            symbols: self.symbols.clone(),
529
            cnf_clauses: self.cnf_clauses.clone(),
530
            search_order: self.search_order.clone(),
531
            dominance: self.dominance.clone(),
532
            context: default_context(),
533
        };
534
        std::fmt::Display::fmt(&model, f)
535
    }
536
}
537

            
538
impl SerdeModel {
539
    /// Collects all ObjId values from the model and maps them to stable sequential IDs.
540
606
    pub fn collect_stable_id_mapping(&self) -> HashMap<ObjId, ObjId> {
541
606
        let model = Model {
542
606
            constraints: self.constraints.clone(),
543
606
            symbols: self.symbols.clone(),
544
606
            cnf_clauses: self.cnf_clauses.clone(),
545
606
            search_order: self.search_order.clone(),
546
606
            dominance: self.dominance.clone(),
547
606
            context: default_context(),
548
606
        };
549
606
        model.collect_stable_id_mapping()
550
606
    }
551
}