1
use std::collections::{HashMap, VecDeque};
2
use std::fmt::{Debug, Display};
3
use std::hash::Hash;
4
use std::sync::{Arc, RwLock};
5

            
6
use crate::ast::Domain;
7
use crate::context::Context;
8
use crate::{bug, into_matrix_expr};
9
use derivative::Derivative;
10
use indexmap::IndexSet;
11
use itertools::izip;
12
use parking_lot::{RwLockReadGuard, RwLockWriteGuard};
13
use serde::{Deserialize, Serialize};
14
use serde_with::serde_as;
15
use uniplate::{Biplate, Tree, Uniplate};
16

            
17
use super::serde::{HasId, ObjId, PtrAsInner};
18
use super::{
19
    Atom, CnfClause, DeclarationPtr, Expression, Literal, Metadata, Moo, Name, ReturnType,
20
    SymbolTable, SymbolTablePtr, Typeable,
21
    comprehension::Comprehension,
22
    declaration::DeclarationKind,
23
    pretty::{
24
        pretty_clauses, pretty_domain_letting_declaration, pretty_expressions_as_top_level,
25
        pretty_value_letting_declaration, pretty_variable_declaration,
26
    },
27
};
28

            
29
/// An Essence model.
///
/// Holds the top-level constraints (stored as a single [`Expression::Root`]), the model's symbol
/// table, any CNF clauses, an optional search order, an optional dominance expression, and a
/// shared [`Context`].
#[serde_as]
#[derive(Derivative, Clone, Debug, Serialize, Deserialize)]
#[derivative(PartialEq, Eq)]
pub struct Model {
    /// The root expression; accessors treat it as an [`Expression::Root`] whose children are the
    /// top-level constraints.
    constraints: Moo<Expression>,
    /// The model's symbol table. Serialized via `PtrAsInner` (presumably the table contents
    /// rather than the pointer itself — confirm against `super::serde`).
    #[serde_as(as = "PtrAsInner")]
    symbols: SymbolTablePtr,
    /// CNF clauses attached to the model.
    cnf_clauses: Vec<CnfClause>,

    /// Optional search order, as a list of names.
    pub search_order: Option<Vec<Name>>,
    /// Optional dominance expression.
    pub dominance: Option<Expression>,

    /// Shared context. Not serialized (deserialization fills it with `default_context()`), and
    /// ignored by the derived equality.
    #[serde(skip, default = "default_context")]
    #[derivative(PartialEq = "ignore")]
    pub context: Arc<RwLock<Context<'static>>>,
}
46

            
47
4690
fn default_context() -> Arc<RwLock<Context<'static>>> {
48
4690
    Arc::new(RwLock::new(Context::default()))
49
4690
}
50

            
51
impl Model {
52
14576
    fn new_empty(symbols: SymbolTablePtr, context: Arc<RwLock<Context<'static>>>) -> Model {
53
14576
        Model {
54
14576
            constraints: Moo::new(Expression::Root(Metadata::new(), vec![])),
55
14576
            symbols,
56
14576
            cnf_clauses: Vec::new(),
57
14576
            search_order: None,
58
14576
            dominance: None,
59
14576
            context,
60
14576
        }
61
14576
    }
62

            
63
    /// Creates a new top-level model from the given context.
64
10488
    pub fn new(context: Arc<RwLock<Context<'static>>>) -> Model {
65
10488
        Self::new_empty(SymbolTablePtr::new(), context)
66
10488
    }
67

            
68
    /// Creates a new model whose symbol table has `parent` as parent scope.
69
4088
    pub fn new_in_parent_scope(parent: SymbolTablePtr) -> Model {
70
4088
        Self::new_empty(SymbolTablePtr::with_parent(parent), default_context())
71
4088
    }
72

            
73
    /// The symbol table for this model as a pointer.
74
191576
    pub fn symbols_ptr_unchecked(&self) -> &SymbolTablePtr {
75
191576
        &self.symbols
76
191576
    }
77

            
78
    /// The symbol table for this model as a mutable pointer.
79
4088
    pub fn symbols_ptr_unchecked_mut(&mut self) -> &mut SymbolTablePtr {
80
4088
        &mut self.symbols
81
4088
    }
82

            
83
    /// The symbol table for this model as a reference.
84
158278726
    pub fn symbols(&self) -> RwLockReadGuard<'_, SymbolTable> {
85
158278726
        self.symbols.read()
86
158278726
    }
87

            
88
    /// The symbol table for this model as a mutable reference.
89
160696
    pub fn symbols_mut(&mut self) -> RwLockWriteGuard<'_, SymbolTable> {
90
160696
        self.symbols.write()
91
160696
    }
92

            
93
    /// The root node of this model.
94
1978668
    pub fn root(&self) -> &Expression {
95
1978668
        &self.constraints
96
1978668
    }
97

            
98
    /// The root node of this model, as a mutable reference.
99
    ///
100
    /// The caller is responsible for ensuring that the root node remains an [`Expression::Root`].
101
139584
    pub fn root_mut_unchecked(&mut self) -> &mut Expression {
102
139584
        Moo::make_mut(&mut self.constraints)
103
139584
    }
104

            
105
    /// Replaces the root node with `new_root`, returning the old root node.
106
139584
    pub fn replace_root(&mut self, new_root: Expression) -> Expression {
107
139584
        let Expression::Root(_, _) = new_root else {
108
            tracing::error!(new_root=?new_root,"new_root is not an Expression::Root");
109
            panic!("new_root is not an Expression::Root");
110
        };
111

            
112
139584
        std::mem::replace(self.root_mut_unchecked(), new_root)
113
139584
    }
114

            
115
    /// The top-level constraints in this model.
116
87226
    pub fn constraints(&self) -> &Vec<Expression> {
117
87226
        let Expression::Root(_, constraints) = self.constraints.as_ref() else {
118
            bug!("The top level expression in a model should be Expr::Root");
119
        };
120
87226
        constraints
121
87226
    }
122

            
123
    /// The cnf clauses in this model.
124
474974
    pub fn clauses(&self) -> &Vec<CnfClause> {
125
474974
        &self.cnf_clauses
126
474974
    }
127

            
128
    /// The top-level constraints in this model as a mutable vector.
129
154494
    pub fn constraints_mut(&mut self) -> &mut Vec<Expression> {
130
154494
        let Expression::Root(_, constraints) = Moo::make_mut(&mut self.constraints) else {
131
            bug!("The top level expression in a model should be Expr::Root");
132
        };
133

            
134
154494
        constraints
135
154494
    }
136

            
137
    /// The cnf clauses in this model as a mutable vector.
138
139584
    pub fn clauses_mut(&mut self) -> &mut Vec<CnfClause> {
139
139584
        &mut self.cnf_clauses
140
139584
    }
141

            
142
    /// Replaces the top-level constraints with `new_constraints`, returning the old ones.
143
    pub fn replace_constraints(&mut self, new_constraints: Vec<Expression>) -> Vec<Expression> {
144
        std::mem::replace(self.constraints_mut(), new_constraints)
145
    }
146

            
147
    /// Replaces the cnf clauses with `new_clauses`, returning the old ones.
148
    pub fn replace_clauses(&mut self, new_clauses: Vec<CnfClause>) -> Vec<CnfClause> {
149
        std::mem::replace(self.clauses_mut(), new_clauses)
150
    }
151

            
152
    /// Adds a top-level constraint.
153
4508
    pub fn add_constraint(&mut self, constraint: Expression) {
154
4508
        self.constraints_mut().push(constraint);
155
4508
    }
156

            
157
    /// Adds a cnf clause.
158
    pub fn add_clause(&mut self, clause: CnfClause) {
159
        self.clauses_mut().push(clause);
160
    }
161

            
162
    /// Adds top-level constraints.
163
149946
    pub fn add_constraints(&mut self, constraints: Vec<Expression>) {
164
149946
        self.constraints_mut().extend(constraints);
165
149946
    }
166

            
167
    /// Adds cnf clauses.
168
139584
    pub fn add_clauses(&mut self, clauses: Vec<CnfClause>) {
169
139584
        self.clauses_mut().extend(clauses);
170
139584
    }
171

            
172
    /// Adds a new symbol to the symbol table.
173
    pub fn add_symbol(&mut self, decl: DeclarationPtr) -> Option<()> {
174
        self.symbols_mut().insert(decl)
175
    }
176

            
177
    /// Converts the constraints in this model to a single expression suitable for use inside
178
    /// another expression tree.
179
24766
    pub fn into_single_expression(self) -> Expression {
180
24766
        let constraints = self.constraints().clone();
181
24766
        match constraints.len() {
182
            0 => Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true))),
183
24766
            1 => constraints[0].clone(),
184
            _ => Expression::And(Metadata::new(), Moo::new(into_matrix_expr![constraints])),
185
        }
186
24766
    }
187

            
188
    /// Collects all ObjId values from the model using uniplate traversal.
189
602
    pub fn collect_stable_id_mapping(&self) -> HashMap<ObjId, ObjId> {
190
602
        fn visit_symbol_table(symbol_table: SymbolTablePtr, id_list: &mut IndexSet<ObjId>) {
191
602
            if !id_list.insert(symbol_table.id()) {
192
                return;
193
602
            }
194

            
195
602
            let table_ref = symbol_table.read();
196
742
            table_ref.iter_local().for_each(|(_, decl)| {
197
742
                id_list.insert(decl.id());
198
742
            });
199
602
        }
200

            
201
602
        let mut id_list: IndexSet<ObjId> = IndexSet::new();
202

            
203
602
        visit_symbol_table(self.symbols_ptr_unchecked().clone(), &mut id_list);
204

            
205
602
        let mut exprs: VecDeque<Expression> = self.universe_bi();
206
602
        if let Some(dominance) = &self.dominance {
207
            exprs.push_back(dominance.clone());
208
602
        }
209

            
210
602
        for symbol_table in Biplate::<SymbolTablePtr>::universe_bi(&exprs) {
211
            visit_symbol_table(symbol_table, &mut id_list);
212
        }
213
1506
        for declaration in Biplate::<DeclarationPtr>::universe_bi(&exprs) {
214
1506
            id_list.insert(declaration.id());
215
1506
        }
216

            
217
602
        let mut id_map = HashMap::new();
218
1344
        for (stable_id, original_id) in id_list.into_iter().enumerate() {
219
1344
            let type_name = original_id.type_name;
220
1344
            id_map.insert(
221
1344
                original_id,
222
1344
                ObjId {
223
1344
                    object_id: stable_id as u32,
224
1344
                    type_name,
225
1344
                },
226
1344
            );
227
1344
        }
228

            
229
602
        id_map
230
602
    }
231
}
232

            
233
impl Default for Model {
234
    fn default() -> Self {
235
        Self::new(default_context())
236
    }
237
}
238

            
239
impl Typeable for Model {
240
    fn return_type(&self) -> ReturnType {
241
        ReturnType::Bool
242
    }
243
}
244

            
245
impl Hash for Model {
246
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
247
        self.constraints.hash(state);
248
        self.symbols.hash(state);
249
        self.cnf_clauses.hash(state);
250
        self.search_order.hash(state);
251
        self.dominance.hash(state);
252
    }
253
}
254

            
255
// At time of writing (03/02/2025), the Uniplate derive macro doesn't like the lifetimes inside
256
// context, and we do not yet have a way of ignoring this field.
257
impl Uniplate for Model {
258
    fn uniplate(&self) -> (Tree<Self>, Box<dyn Fn(Tree<Self>) -> Self>) {
259
        let self2 = self.clone();
260
        (Tree::Zero, Box::new(move |_| self2.clone()))
261
    }
262
}
263

            
264
impl Biplate<Expression> for Model {
265
150036
    fn biplate(&self) -> (Tree<Expression>, Box<dyn Fn(Tree<Expression>) -> Self>) {
266
150036
        let (symtab_tree, symtab_ctx) =
267
150036
            <SymbolTable as Biplate<Expression>>::biplate(&self.symbols());
268

            
269
150036
        let dom_tree = match &self.dominance {
270
            Some(expr) => Tree::One(expr.clone()),
271
150036
            None => Tree::Zero,
272
        };
273

            
274
150036
        let tree = Tree::Many(VecDeque::from([
275
150036
            Tree::One(self.root().clone()),
276
150036
            symtab_tree,
277
150036
            dom_tree,
278
150036
        ]));
279

            
280
150036
        let self2 = self.clone();
281
150036
        let ctx = Box::new(move |x| {
282
            let Tree::Many(xs) = x else {
283
                panic!("Expected a tree with three children");
284
            };
285
            if xs.len() != 3 {
286
                panic!("Expected a tree with three children");
287
            }
288

            
289
            let Tree::One(root) = xs[0].clone() else {
290
                panic!("Expected root expression tree");
291
            };
292

            
293
            let symtab = symtab_ctx(xs[1].clone());
294
            let dominance = match xs[2].clone() {
295
                Tree::One(expr) => Some(expr),
296
                Tree::Zero => None,
297
                _ => panic!("Expected dominance tree"),
298
            };
299

            
300
            let mut self3 = self2.clone();
301

            
302
            let Expression::Root(_, _) = root else {
303
                bug!("root expression not root");
304
            };
305

            
306
            *self3.root_mut_unchecked() = root;
307
            *self3.symbols_mut() = symtab;
308
            self3.dominance = dominance;
309

            
310
            self3
311
        });
312

            
313
150036
        (tree, ctx)
314
150036
    }
315
}
316

            
317
impl Biplate<Atom> for Model {
318
    fn biplate(&self) -> (Tree<Atom>, Box<dyn Fn(Tree<Atom>) -> Self>) {
319
        let (expression_tree, rebuild_self) = <Model as Biplate<Expression>>::biplate(self);
320
        let (expression_list, rebuild_expression_tree) = expression_tree.list();
321

            
322
        let (atom_trees, reconstruct_exprs): (VecDeque<_>, VecDeque<_>) = expression_list
323
            .iter()
324
            .map(|e| <Expression as Biplate<Atom>>::biplate(e))
325
            .unzip();
326

            
327
        let tree = Tree::Many(atom_trees);
328
        let ctx = Box::new(move |atom_tree: Tree<Atom>| {
329
            let Tree::Many(atoms) = atom_tree else {
330
                panic!();
331
            };
332

            
333
            assert_eq!(
334
                atoms.len(),
335
                reconstruct_exprs.len(),
336
                "the number of children should not change when using Biplate"
337
            );
338

            
339
            let expression_list: VecDeque<Expression> = izip!(atoms, &reconstruct_exprs)
340
                .map(|(atom, recons)| recons(atom))
341
                .collect();
342

            
343
            let expression_tree = rebuild_expression_tree(expression_list);
344
            rebuild_self(expression_tree)
345
        });
346

            
347
        (tree, ctx)
348
    }
349
}
350

            
351
impl Biplate<Comprehension> for Model {
352
    fn biplate(
353
        &self,
354
    ) -> (
355
        Tree<Comprehension>,
356
        Box<dyn Fn(Tree<Comprehension>) -> Self>,
357
    ) {
358
        let (f1_tree, f1_ctx) = <_ as Biplate<Comprehension>>::biplate(&self.constraints);
359
        let (f2_tree, f2_ctx) = <SymbolTable as Biplate<Comprehension>>::biplate(&self.symbols());
360

            
361
        let tree = Tree::Many(VecDeque::from([f1_tree, f2_tree]));
362
        let self2 = self.clone();
363
        let ctx = Box::new(move |x| {
364
            let Tree::Many(xs) = x else {
365
                panic!();
366
            };
367

            
368
            let root = f1_ctx(xs[0].clone());
369
            let symtab = f2_ctx(xs[1].clone());
370

            
371
            let mut self3 = self2.clone();
372

            
373
            let Expression::Root(_, _) = &*root else {
374
                bug!("root expression not root");
375
            };
376

            
377
            *self3.symbols_mut() = symtab;
378
            self3.constraints = root;
379

            
380
            self3
381
        });
382

            
383
        (tree, ctx)
384
    }
385
}
386

            
387
impl Display for Model {
388
    #[allow(clippy::unwrap_used)]
389
22952
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390
444826
        for (name, decl) in self.symbols().clone().into_iter_local() {
391
444826
            match &decl.kind() as &DeclarationKind {
392
                DeclarationKind::Find(_) => {
393
440102
                    writeln!(
394
440102
                        f,
395
                        "{}",
396
440102
                        pretty_variable_declaration(&self.symbols(), &name).unwrap()
397
                    )?;
398
                }
399
                DeclarationKind::ValueLetting(_) | DeclarationKind::TemporaryValueLetting(_) => {
400
4004
                    writeln!(
401
4004
                        f,
402
                        "{}",
403
4004
                        pretty_value_letting_declaration(&self.symbols(), &name).unwrap()
404
                    )?;
405
                }
406
                DeclarationKind::DomainLetting(_) => {
407
640
                    writeln!(
408
640
                        f,
409
                        "{}",
410
640
                        pretty_domain_letting_declaration(&self.symbols(), &name).unwrap()
411
                    )?;
412
                }
413
                DeclarationKind::Given(d) => {
414
                    writeln!(f, "given {name}: {d}")?;
415
                }
416
                DeclarationKind::Quantified(inner) => {
417
                    writeln!(f, "quantified {name}: {}", inner.domain())?;
418
                }
419
                DeclarationKind::RecordField(_) => {
420
80
                    writeln!(f)?;
421
                }
422
            }
423
        }
424

            
425
22952
        if !self.constraints().is_empty() {
426
20948
            writeln!(f, "\nsuch that\n")?;
427
20948
            writeln!(f, "{}", pretty_expressions_as_top_level(self.constraints()))?;
428
2004
        }
429

            
430
22952
        if !self.clauses().is_empty() {
431
1740
            writeln!(f, "\nclauses:\n")?;
432
1740
            writeln!(f, "{}", pretty_clauses(self.clauses()))?;
433
21212
        }
434
22952
        Ok(())
435
22952
    }
436
}
437

            
438
/// A model that is de/serializable using `serde`.
///
/// Carries the same data as [`Model`] except for the context. To turn this into a rewritable
/// model, it needs to be initialised using [`initialise`](SerdeModel::initialise).
#[serde_as]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SerdeModel {
    /// The root expression; expected to be an [`Expression::Root`] once turned into a [`Model`].
    constraints: Moo<Expression>,
    /// The model's symbol table, (de)serialized via `PtrAsInner`.
    #[serde_as(as = "PtrAsInner")]
    symbols: SymbolTablePtr,
    /// CNF clauses attached to the model.
    cnf_clauses: Vec<CnfClause>,
    /// Optional search order, as a list of names.
    search_order: Option<Vec<Name>>,
    /// Optional dominance expression.
    dominance: Option<Expression>,
}
452

            
453
impl SerdeModel {
454
    /// Initialises the model for rewriting.
455
1200
    pub fn initialise(mut self, context: Arc<RwLock<Context<'static>>>) -> Option<Model> {
456
1200
        let mut tables: HashMap<ObjId, SymbolTablePtr> = HashMap::new();
457

            
458
        // Root model symbol table is always definitive.
459
1200
        tables.insert(self.symbols.id(), self.symbols.clone());
460

            
461
1200
        let mut exprs: VecDeque<Expression> = self.constraints.universe_bi();
462
1200
        if let Some(dominance) = &self.dominance {
463
            exprs.push_back(dominance.clone());
464
1200
        }
465

            
466
        // Some expressions (e.g. abstract comprehensions) contain additional symbol tables.
467
1200
        for table in Biplate::<SymbolTablePtr>::universe_bi(&exprs) {
468
            tables.entry(table.id()).or_insert(table);
469
        }
470

            
471
1200
        for table in tables.clone().into_values() {
472
1200
            let mut table_mut = table.write();
473
1200
            let parent_mut = table_mut.parent_mut_unchecked();
474

            
475
            #[allow(clippy::unwrap_used)]
476
1200
            if let Some(parent) = parent_mut {
477
                let parent_id = parent.id();
478
                *parent = tables.get(&parent_id).unwrap().clone();
479
1200
            }
480
        }
481

            
482
1200
        let mut all_declarations: HashMap<ObjId, DeclarationPtr> = HashMap::new();
483
1200
        for table in tables.values() {
484
1480
            for (_, decl) in table.read().iter_local() {
485
1480
                let id = decl.id();
486
1480
                all_declarations.insert(id, decl.clone());
487
1480
            }
488
        }
489

            
490
1200
        self.constraints = self.constraints.transform_bi(&move |decl: DeclarationPtr| {
491
840
            let id = decl.id();
492
840
            all_declarations
493
840
                .get(&id)
494
840
                .unwrap_or_else(|| {
495
                    panic!(
496
                        "A declaration used in the expression tree should exist in the symbol table. The missing declaration has id {id}."
497
                    )
498
                })
499
840
                .clone()
500
840
        });
501

            
502
1200
        Some(Model {
503
1200
            constraints: self.constraints,
504
1200
            symbols: self.symbols,
505
1200
            cnf_clauses: self.cnf_clauses,
506
1200
            search_order: self.search_order,
507
1200
            dominance: self.dominance,
508
1200
            context,
509
1200
        })
510
1200
    }
511
}
512

            
513
impl From<Model> for SerdeModel {
514
602
    fn from(val: Model) -> Self {
515
602
        SerdeModel {
516
602
            constraints: val.constraints,
517
602
            symbols: val.symbols,
518
602
            cnf_clauses: val.cnf_clauses,
519
602
            search_order: val.search_order,
520
602
            dominance: val.dominance,
521
602
        }
522
602
    }
523
}
524

            
525
impl Display for SerdeModel {
526
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
527
        let model = Model {
528
            constraints: self.constraints.clone(),
529
            symbols: self.symbols.clone(),
530
            cnf_clauses: self.cnf_clauses.clone(),
531
            search_order: self.search_order.clone(),
532
            dominance: self.dominance.clone(),
533
            context: default_context(),
534
        };
535
        std::fmt::Display::fmt(&model, f)
536
    }
537
}
538

            
539
impl SerdeModel {
540
    /// Collects all ObjId values from the model and maps them to stable sequential IDs.
541
602
    pub fn collect_stable_id_mapping(&self) -> HashMap<ObjId, ObjId> {
542
602
        let model = Model {
543
602
            constraints: self.constraints.clone(),
544
602
            symbols: self.symbols.clone(),
545
602
            cnf_clauses: self.cnf_clauses.clone(),
546
602
            search_order: self.search_order.clone(),
547
602
            dominance: self.dominance.clone(),
548
602
            context: default_context(),
549
602
        };
550
602
        model.collect_stable_id_mapping()
551
602
    }
552
}
553

            
554
/// Serializable information about an expression and, recursively, about its children.
#[serde_as]
#[derive(Serialize)]
pub struct ExprInfo {
    /// The expression's `Display` output.
    pretty: String,
    /// The expression's domain as returned by `domain_of`, if any.
    domain: Option<Moo<Domain>>,
    /// Information about each direct child expression, in child order.
    children: Vec<ExprInfo>,
}
562

            
563
impl ExprInfo {
564
    pub fn create(expr: &Expression) -> ExprInfo {
565
        let pretty = expr.to_string();
566
        let domain = expr.domain_of();
567
        let children = expr.children().iter().map(Self::create).collect();
568

            
569
        ExprInfo {
570
            pretty,
571
            domain,
572
            children,
573
        }
574
    }
575
}