1
use std::collections::{HashMap, VecDeque};
2
use std::fmt::{Debug, Display};
3
use std::hash::Hash;
4
use std::sync::{Arc, RwLock};
5

            
6
use crate::ast::Domain;
7
use crate::context::Context;
8
use crate::{bug, into_matrix_expr};
9
use derivative::Derivative;
10
use indexmap::IndexSet;
11
use itertools::izip;
12
use parking_lot::{RwLockReadGuard, RwLockWriteGuard};
13
use serde::{Deserialize, Serialize};
14
use serde_with::serde_as;
15
use uniplate::{Biplate, Tree, Uniplate};
16

            
17
use super::serde::{HasId, ObjId, PtrAsInner};
18
use super::{
19
    Atom, CnfClause, DeclarationPtr, Expression, Literal, Metadata, Moo, Name, ReturnType,
20
    SymbolTable, SymbolTablePtr, Typeable,
21
    comprehension::Comprehension,
22
    declaration::DeclarationKind,
23
    pretty::{
24
        pretty_clauses, pretty_domain_letting_declaration, pretty_expressions_as_top_level,
25
        pretty_value_letting_declaration, pretty_variable_declaration,
26
    },
27
};
28

            
29
/// An Essence model.
30
#[serde_as]
31
#[derive(Derivative, Clone, Debug, Serialize, Deserialize)]
32
#[derivative(PartialEq, Eq)]
33
pub struct Model {
34
    constraints: Moo<Expression>,
35
    #[serde_as(as = "PtrAsInner")]
36
    symbols: SymbolTablePtr,
37
    cnf_clauses: Vec<CnfClause>,
38

            
39
    pub search_order: Option<Vec<Name>>,
40
    pub dominance: Option<Expression>,
41

            
42
    #[serde(skip, default = "default_context")]
43
    #[derivative(PartialEq = "ignore")]
44
    pub context: Arc<RwLock<Context<'static>>>,
45
}
46

            
47
4790
fn default_context() -> Arc<RwLock<Context<'static>>> {
48
4790
    Arc::new(RwLock::new(Context::default()))
49
4790
}
50

            
51
impl Model {
52
15522
    fn new_empty(symbols: SymbolTablePtr, context: Arc<RwLock<Context<'static>>>) -> Model {
53
15522
        Model {
54
15522
            constraints: Moo::new(Expression::Root(Metadata::new(), vec![])),
55
15522
            symbols,
56
15522
            cnf_clauses: Vec::new(),
57
15522
            search_order: None,
58
15522
            dominance: None,
59
15522
            context,
60
15522
        }
61
15522
    }
62

            
63
    /// Creates a new top-level model from the given context.
64
11434
    pub fn new(context: Arc<RwLock<Context<'static>>>) -> Model {
65
11434
        Self::new_empty(SymbolTablePtr::new(), context)
66
11434
    }
67

            
68
    /// Creates a new model whose symbol table has `parent` as parent scope.
69
4088
    pub fn new_in_parent_scope(parent: SymbolTablePtr) -> Model {
70
4088
        Self::new_empty(SymbolTablePtr::with_parent(parent), default_context())
71
4088
    }
72

            
73
    /// The symbol table for this model as a pointer.
74
207100
    pub fn symbols_ptr_unchecked(&self) -> &SymbolTablePtr {
75
207100
        &self.symbols
76
207100
    }
77

            
78
    /// The symbol table for this model as a mutable pointer.
79
4088
    pub fn symbols_ptr_unchecked_mut(&mut self) -> &mut SymbolTablePtr {
80
4088
        &mut self.symbols
81
4088
    }
82

            
83
    /// The symbol table for this model as a reference.
84
166617602
    pub fn symbols(&self) -> RwLockReadGuard<'_, SymbolTable> {
85
166617602
        self.symbols.read()
86
166617602
    }
87

            
88
    /// The symbol table for this model as a mutable reference.
89
173544
    pub fn symbols_mut(&mut self) -> RwLockWriteGuard<'_, SymbolTable> {
90
173544
        self.symbols.write()
91
173544
    }
92

            
93
    /// The root node of this model.
94
2090988
    pub fn root(&self) -> &Expression {
95
2090988
        &self.constraints
96
2090988
    }
97

            
98
    /// The root node of this model, as a mutable reference.
99
    ///
100
    /// The caller is responsible for ensuring that the root node remains an [`Expression::Root`].
101
150468
    pub fn root_mut_unchecked(&mut self) -> &mut Expression {
102
150468
        Moo::make_mut(&mut self.constraints)
103
150468
    }
104

            
105
    /// Replaces the root node with `new_root`, returning the old root node.
106
150468
    pub fn replace_root(&mut self, new_root: Expression) -> Expression {
107
150468
        let Expression::Root(_, _) = new_root else {
108
            tracing::error!(new_root=?new_root,"new_root is not an Expression::Root");
109
            panic!("new_root is not an Expression::Root");
110
        };
111

            
112
150468
        std::mem::replace(self.root_mut_unchecked(), new_root)
113
150468
    }
114

            
115
    /// The top-level constraints in this model.
116
90474
    pub fn constraints(&self) -> &Vec<Expression> {
117
90474
        let Expression::Root(_, constraints) = self.constraints.as_ref() else {
118
            bug!("The top level expression in a model should be Expr::Root");
119
        };
120
90474
        constraints
121
90474
    }
122

            
123
    /// The cnf clauses in this model.
124
511972
    pub fn clauses(&self) -> &Vec<CnfClause> {
125
511972
        &self.cnf_clauses
126
511972
    }
127

            
128
    /// The top-level constraints in this model as a mutable vector.
129
166110
    pub fn constraints_mut(&mut self) -> &mut Vec<Expression> {
130
166110
        let Expression::Root(_, constraints) = Moo::make_mut(&mut self.constraints) else {
131
            bug!("The top level expression in a model should be Expr::Root");
132
        };
133

            
134
166110
        constraints
135
166110
    }
136

            
137
    /// The cnf clauses in this model as a mutable vector.
138
150468
    pub fn clauses_mut(&mut self) -> &mut Vec<CnfClause> {
139
150468
        &mut self.cnf_clauses
140
150468
    }
141

            
142
    /// Replaces the top-level constraints with `new_constraints`, returning the old ones.
143
    pub fn replace_constraints(&mut self, new_constraints: Vec<Expression>) -> Vec<Expression> {
144
        std::mem::replace(self.constraints_mut(), new_constraints)
145
    }
146

            
147
    /// Replaces the cnf clauses with `new_clauses`, returning the old ones.
148
    pub fn replace_clauses(&mut self, new_clauses: Vec<CnfClause>) -> Vec<CnfClause> {
149
        std::mem::replace(self.clauses_mut(), new_clauses)
150
    }
151

            
152
    /// Adds a top-level constraint.
153
4648
    pub fn add_constraint(&mut self, constraint: Expression) {
154
4648
        self.constraints_mut().push(constraint);
155
4648
    }
156

            
157
    /// Adds a cnf clause.
158
    pub fn add_clause(&mut self, clause: CnfClause) {
159
        self.clauses_mut().push(clause);
160
    }
161

            
162
    /// Adds top-level constraints.
163
161422
    pub fn add_constraints(&mut self, constraints: Vec<Expression>) {
164
161422
        self.constraints_mut().extend(constraints);
165
161422
    }
166

            
167
    /// Adds cnf clauses.
168
150468
    pub fn add_clauses(&mut self, clauses: Vec<CnfClause>) {
169
150468
        self.clauses_mut().extend(clauses);
170
150468
    }
171

            
172
    /// Adds a new symbol to the symbol table.
173
    pub fn add_symbol(&mut self, decl: DeclarationPtr) -> Option<()> {
174
        self.symbols_mut().insert(decl)
175
    }
176

            
177
    /// Converts the constraints in this model to a single expression suitable for use inside
178
    /// another expression tree.
179
24766
    pub fn into_single_expression(self) -> Expression {
180
24766
        let constraints = self.constraints().clone();
181
24766
        match constraints.len() {
182
            0 => Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true))),
183
24766
            1 => constraints[0].clone(),
184
            _ => Expression::And(Metadata::new(), Moo::new(into_matrix_expr![constraints])),
185
        }
186
24766
    }
187

            
188
    /// Collects all ObjId values from the model using uniplate traversal.
189
702
    pub fn collect_stable_id_mapping(&self) -> HashMap<ObjId, ObjId> {
190
702
        fn visit_symbol_table(symbol_table: SymbolTablePtr, id_list: &mut IndexSet<ObjId>) {
191
702
            if !id_list.insert(symbol_table.id()) {
192
                return;
193
702
            }
194

            
195
702
            let table_ref = symbol_table.read();
196
862
            table_ref.iter_local().for_each(|(_, decl)| {
197
862
                id_list.insert(decl.id());
198
862
            });
199
702
        }
200

            
201
702
        let mut id_list: IndexSet<ObjId> = IndexSet::new();
202

            
203
702
        visit_symbol_table(self.symbols_ptr_unchecked().clone(), &mut id_list);
204

            
205
702
        let mut exprs: VecDeque<Expression> = self.universe_bi();
206
702
        if let Some(dominance) = &self.dominance {
207
            exprs.push_back(dominance.clone());
208
702
        }
209

            
210
702
        for symbol_table in Biplate::<SymbolTablePtr>::universe_bi(&exprs) {
211
            visit_symbol_table(symbol_table, &mut id_list);
212
        }
213
1786
        for declaration in Biplate::<DeclarationPtr>::universe_bi(&exprs) {
214
1786
            id_list.insert(declaration.id());
215
1786
        }
216

            
217
702
        let mut id_map = HashMap::new();
218
1564
        for (stable_id, original_id) in id_list.into_iter().enumerate() {
219
1564
            let type_name = original_id.type_name;
220
1564
            id_map.insert(
221
1564
                original_id,
222
1564
                ObjId {
223
1564
                    object_id: stable_id as u32,
224
1564
                    type_name,
225
1564
                },
226
1564
            );
227
1564
        }
228

            
229
702
        id_map
230
702
    }
231
}
232

            
233
impl Default for Model {
234
    fn default() -> Self {
235
        Self::new(default_context())
236
    }
237
}
238

            
239
impl Typeable for Model {
240
    fn return_type(&self) -> ReturnType {
241
        ReturnType::Bool
242
    }
243
}
244

            
245
impl Hash for Model {
246
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
247
        self.constraints.hash(state);
248
        self.symbols.hash(state);
249
        self.cnf_clauses.hash(state);
250
        self.search_order.hash(state);
251
        self.dominance.hash(state);
252
    }
253
}
254

            
255
// At time of writing (03/02/2025), the Uniplate derive macro doesn't like the lifetimes inside
256
// context, and we do not yet have a way of ignoring this field.
257
impl Uniplate for Model {
258
    fn uniplate(&self) -> (Tree<Self>, Box<dyn Fn(Tree<Self>) -> Self>) {
259
        let self2 = self.clone();
260
        (Tree::Zero, Box::new(move |_| self2.clone()))
261
    }
262
}
263

            
264
impl Biplate<Expression> for Model {
265
161842
    fn biplate(&self) -> (Tree<Expression>, Box<dyn Fn(Tree<Expression>) -> Self>) {
266
161842
        let (symtab_tree, symtab_ctx) =
267
161842
            <SymbolTable as Biplate<Expression>>::biplate(&self.symbols());
268

            
269
161842
        let dom_tree = match &self.dominance {
270
            Some(expr) => Tree::One(expr.clone()),
271
161842
            None => Tree::Zero,
272
        };
273

            
274
161842
        let tree = Tree::Many(VecDeque::from([
275
161842
            Tree::One(self.root().clone()),
276
161842
            symtab_tree,
277
161842
            dom_tree,
278
161842
        ]));
279

            
280
161842
        let self2 = self.clone();
281
161842
        let ctx = Box::new(move |x| {
282
            let Tree::Many(xs) = x else {
283
                panic!("Expected a tree with three children");
284
            };
285
            if xs.len() != 3 {
286
                panic!("Expected a tree with three children");
287
            }
288

            
289
            let Tree::One(root) = xs[0].clone() else {
290
                panic!("Expected root expression tree");
291
            };
292

            
293
            let symtab = symtab_ctx(xs[1].clone());
294
            let dominance = match xs[2].clone() {
295
                Tree::One(expr) => Some(expr),
296
                Tree::Zero => None,
297
                _ => panic!("Expected dominance tree"),
298
            };
299

            
300
            let mut self3 = self2.clone();
301

            
302
            let Expression::Root(_, _) = root else {
303
                bug!("root expression not root");
304
            };
305

            
306
            *self3.root_mut_unchecked() = root;
307
            *self3.symbols_mut() = symtab;
308
            self3.dominance = dominance;
309

            
310
            self3
311
        });
312

            
313
161842
        (tree, ctx)
314
161842
    }
315
}
316

            
317
impl Biplate<Atom> for Model {
318
    fn biplate(&self) -> (Tree<Atom>, Box<dyn Fn(Tree<Atom>) -> Self>) {
319
        let (expression_tree, rebuild_self) = <Model as Biplate<Expression>>::biplate(self);
320
        let (expression_list, rebuild_expression_tree) = expression_tree.list();
321

            
322
        let (atom_trees, reconstruct_exprs): (VecDeque<_>, VecDeque<_>) = expression_list
323
            .iter()
324
            .map(|e| <Expression as Biplate<Atom>>::biplate(e))
325
            .unzip();
326

            
327
        let tree = Tree::Many(atom_trees);
328
        let ctx = Box::new(move |atom_tree: Tree<Atom>| {
329
            let Tree::Many(atoms) = atom_tree else {
330
                panic!();
331
            };
332

            
333
            assert_eq!(
334
                atoms.len(),
335
                reconstruct_exprs.len(),
336
                "the number of children should not change when using Biplate"
337
            );
338

            
339
            let expression_list: VecDeque<Expression> = izip!(atoms, &reconstruct_exprs)
340
                .map(|(atom, recons)| recons(atom))
341
                .collect();
342

            
343
            let expression_tree = rebuild_expression_tree(expression_list);
344
            rebuild_self(expression_tree)
345
        });
346

            
347
        (tree, ctx)
348
    }
349
}
350

            
351
impl Biplate<Comprehension> for Model {
352
    fn biplate(
353
        &self,
354
    ) -> (
355
        Tree<Comprehension>,
356
        Box<dyn Fn(Tree<Comprehension>) -> Self>,
357
    ) {
358
        let (f1_tree, f1_ctx) = <_ as Biplate<Comprehension>>::biplate(&self.constraints);
359
        let (f2_tree, f2_ctx) = <SymbolTable as Biplate<Comprehension>>::biplate(&self.symbols());
360

            
361
        let tree = Tree::Many(VecDeque::from([f1_tree, f2_tree]));
362
        let self2 = self.clone();
363
        let ctx = Box::new(move |x| {
364
            let Tree::Many(xs) = x else {
365
                panic!();
366
            };
367

            
368
            let root = f1_ctx(xs[0].clone());
369
            let symtab = f2_ctx(xs[1].clone());
370

            
371
            let mut self3 = self2.clone();
372

            
373
            let Expression::Root(_, _) = &*root else {
374
                bug!("root expression not root");
375
            };
376

            
377
            *self3.symbols_mut() = symtab;
378
            self3.constraints = root;
379

            
380
            self3
381
        });
382

            
383
        (tree, ctx)
384
    }
385
}
386

            
387
impl Display for Model {
388
    #[allow(clippy::unwrap_used)]
389
24272
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390
575346
        for (name, decl) in self.symbols().clone().into_iter_local() {
391
575346
            match &decl.kind() as &DeclarationKind {
392
                DeclarationKind::Find(_) => {
393
570582
                    writeln!(
394
570582
                        f,
395
                        "{}",
396
570582
                        pretty_variable_declaration(&self.symbols(), &name).unwrap()
397
                    )?;
398
                }
399
                DeclarationKind::ValueLetting(_, _) | DeclarationKind::TemporaryValueLetting(_) => {
400
4044
                    writeln!(
401
4044
                        f,
402
                        "{}",
403
4044
                        pretty_value_letting_declaration(&self.symbols(), &name).unwrap()
404
                    )?;
405
                }
406
                DeclarationKind::DomainLetting(_) => {
407
640
                    writeln!(
408
640
                        f,
409
                        "{}",
410
640
                        pretty_domain_letting_declaration(&self.symbols(), &name).unwrap()
411
                    )?;
412
                }
413
                DeclarationKind::Given(d) => {
414
                    writeln!(f, "given {name}: {d}")?;
415
                }
416
                DeclarationKind::Quantified(inner) => {
417
                    writeln!(f, "quantified {name}: {}", inner.domain())?;
418
                }
419
                DeclarationKind::RecordField(_) => {
420
80
                    writeln!(f)?;
421
                }
422
            }
423
        }
424

            
425
24272
        if !self.constraints().is_empty() {
426
22188
            writeln!(f, "\nsuch that\n")?;
427
22188
            writeln!(f, "{}", pretty_expressions_as_top_level(self.constraints()))?;
428
2084
        }
429

            
430
24272
        if !self.clauses().is_empty() {
431
2020
            writeln!(f, "\nclauses:\n")?;
432
2020
            writeln!(f, "{}", pretty_clauses(self.clauses()))?;
433
22252
        }
434
24272
        Ok(())
435
24272
    }
436
}
437

            
438
/// A model that is de/serializable using `serde`.
439
///
440
/// To turn this into a rewritable model, it needs to be initialised using
441
/// [`initialise`](SerdeModel::initialise).
442
#[serde_as]
443
#[derive(Clone, Debug, Serialize, Deserialize)]
444
pub struct SerdeModel {
445
    constraints: Moo<Expression>,
446
    #[serde_as(as = "PtrAsInner")]
447
    symbols: SymbolTablePtr,
448
    cnf_clauses: Vec<CnfClause>,
449
    search_order: Option<Vec<Name>>,
450
    dominance: Option<Expression>,
451
}
452

            
453
impl SerdeModel {
454
    /// Initialises the model for rewriting.
455
1400
    pub fn initialise(mut self, context: Arc<RwLock<Context<'static>>>) -> Option<Model> {
456
1400
        let mut tables: HashMap<ObjId, SymbolTablePtr> = HashMap::new();
457

            
458
        // Root model symbol table is always definitive.
459
1400
        tables.insert(self.symbols.id(), self.symbols.clone());
460

            
461
1400
        let mut exprs: VecDeque<Expression> = self.constraints.universe_bi();
462
1400
        if let Some(dominance) = &self.dominance {
463
            exprs.push_back(dominance.clone());
464
1400
        }
465

            
466
        // Some expressions (e.g. abstract comprehensions) contain additional symbol tables.
467
1400
        for table in Biplate::<SymbolTablePtr>::universe_bi(&exprs) {
468
            tables.entry(table.id()).or_insert(table);
469
        }
470

            
471
1400
        for table in tables.clone().into_values() {
472
1400
            let mut table_mut = table.write();
473
1400
            let parent_mut = table_mut.parent_mut_unchecked();
474

            
475
            #[allow(clippy::unwrap_used)]
476
1400
            if let Some(parent) = parent_mut {
477
                let parent_id = parent.id();
478
                *parent = tables.get(&parent_id).unwrap().clone();
479
1400
            }
480
        }
481

            
482
1400
        let mut all_declarations: HashMap<ObjId, DeclarationPtr> = HashMap::new();
483
1400
        for table in tables.values() {
484
1720
            for (_, decl) in table.read().iter_local() {
485
1720
                let id = decl.id();
486
1720
                all_declarations.insert(id, decl.clone());
487
1720
            }
488
        }
489

            
490
1400
        self.constraints = self.constraints.transform_bi(&move |decl: DeclarationPtr| {
491
1000
            let id = decl.id();
492
1000
            all_declarations
493
1000
                .get(&id)
494
1000
                .unwrap_or_else(|| {
495
                    panic!(
496
                        "A declaration used in the expression tree should exist in the symbol table. The missing declaration has id {id}."
497
                    )
498
                })
499
1000
                .clone()
500
1000
        });
501

            
502
1400
        Some(Model {
503
1400
            constraints: self.constraints,
504
1400
            symbols: self.symbols,
505
1400
            cnf_clauses: self.cnf_clauses,
506
1400
            search_order: self.search_order,
507
1400
            dominance: self.dominance,
508
1400
            context,
509
1400
        })
510
1400
    }
511
}
512

            
513
impl From<Model> for SerdeModel {
514
702
    fn from(val: Model) -> Self {
515
702
        SerdeModel {
516
702
            constraints: val.constraints,
517
702
            symbols: val.symbols,
518
702
            cnf_clauses: val.cnf_clauses,
519
702
            search_order: val.search_order,
520
702
            dominance: val.dominance,
521
702
        }
522
702
    }
523
}
524

            
525
impl Display for SerdeModel {
526
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
527
        let model = Model {
528
            constraints: self.constraints.clone(),
529
            symbols: self.symbols.clone(),
530
            cnf_clauses: self.cnf_clauses.clone(),
531
            search_order: self.search_order.clone(),
532
            dominance: self.dominance.clone(),
533
            context: default_context(),
534
        };
535
        std::fmt::Display::fmt(&model, f)
536
    }
537
}
538

            
539
impl SerdeModel {
540
    /// Collects all ObjId values from the model and maps them to stable sequential IDs.
541
702
    pub fn collect_stable_id_mapping(&self) -> HashMap<ObjId, ObjId> {
542
702
        let model = Model {
543
702
            constraints: self.constraints.clone(),
544
702
            symbols: self.symbols.clone(),
545
702
            cnf_clauses: self.cnf_clauses.clone(),
546
702
            search_order: self.search_order.clone(),
547
702
            dominance: self.dominance.clone(),
548
702
            context: default_context(),
549
702
        };
550
702
        model.collect_stable_id_mapping()
551
702
    }
552
}
553

            
554
/// A struct for the information about expressions
555
#[serde_as]
556
#[derive(Serialize)]
557
pub struct ExprInfo {
558
    pretty: String,
559
    domain: Option<Moo<Domain>>,
560
    children: Vec<ExprInfo>,
561
}
562

            
563
impl ExprInfo {
564
    pub fn create(expr: &Expression) -> ExprInfo {
565
        let pretty = expr.to_string();
566
        let domain = expr.domain_of();
567
        let children = expr.children().iter().map(Self::create).collect();
568

            
569
        ExprInfo {
570
            pretty,
571
            domain,
572
            children,
573
        }
574
    }
575
}