1
use std::collections::{HashMap, VecDeque};
2
use std::fmt::{Debug, Display};
3
use std::hash::Hash;
4
use std::sync::{Arc, RwLock};
5

            
6
use crate::ast::Domain;
7
use crate::context::Context;
8
use crate::{bug, into_matrix_expr};
9
use derivative::Derivative;
10
use indexmap::IndexSet;
11
use itertools::izip;
12
use parking_lot::{RwLockReadGuard, RwLockWriteGuard};
13
use serde::{Deserialize, Serialize};
14
use serde_with::serde_as;
15
use uniplate::{Biplate, Tree, Uniplate};
16

            
17
use super::serde::{HasId, ObjId, PtrAsInner};
18
use super::{
19
    Atom, CnfClause, DeclarationPtr, Expression, Literal, Metadata, Moo, Name, ReturnType,
20
    SymbolTable, SymbolTablePtr, Typeable,
21
    comprehension::Comprehension,
22
    declaration::DeclarationKind,
23
    pretty::{
24
        pretty_clauses, pretty_domain_letting_declaration, pretty_expressions_as_top_level,
25
        pretty_value_letting_declaration, pretty_variable_declaration,
26
    },
27
};
28

            
29
/// An Essence model.
30
#[serde_as]
31
#[derive(Derivative, Clone, Debug, Serialize, Deserialize)]
32
#[derivative(PartialEq, Eq)]
33
pub struct Model {
34
    constraints: Moo<Expression>,
35
    #[serde_as(as = "PtrAsInner")]
36
    symbols: SymbolTablePtr,
37
    cnf_clauses: Vec<CnfClause>,
38

            
39
    pub search_order: Option<Vec<Name>>,
40
    pub dominance: Option<Expression>,
41

            
42
    #[serde(skip, default = "default_context")]
43
    #[derivative(PartialEq = "ignore")]
44
    pub context: Arc<RwLock<Context<'static>>>,
45
}
46

            
47
4916
fn default_context() -> Arc<RwLock<Context<'static>>> {
48
4916
    Arc::new(RwLock::new(Context::default()))
49
4916
}
50

            
51
impl Model {
52
15279
    fn new_empty(symbols: SymbolTablePtr, context: Arc<RwLock<Context<'static>>>) -> Model {
53
15279
        Model {
54
15279
            constraints: Moo::new(Expression::Root(Metadata::new(), vec![])),
55
15279
            symbols,
56
15279
            cnf_clauses: Vec::new(),
57
15279
            search_order: None,
58
15279
            dominance: None,
59
15279
            context,
60
15279
        }
61
15279
    }
62

            
63
    /// Creates a new top-level model from the given context.
64
10995
    pub fn new(context: Arc<RwLock<Context<'static>>>) -> Model {
65
10995
        Self::new_empty(SymbolTablePtr::new(), context)
66
10995
    }
67

            
68
    /// Creates a new model whose symbol table has `parent` as parent scope.
69
4284
    pub fn new_in_parent_scope(parent: SymbolTablePtr) -> Model {
70
4284
        Self::new_empty(SymbolTablePtr::with_parent(parent), default_context())
71
4284
    }
72

            
73
    /// The symbol table for this model as a pointer.
74
200638
    pub fn symbols_ptr_unchecked(&self) -> &SymbolTablePtr {
75
200638
        &self.symbols
76
200638
    }
77

            
78
    /// The symbol table for this model as a mutable pointer.
79
4284
    pub fn symbols_ptr_unchecked_mut(&mut self) -> &mut SymbolTablePtr {
80
4284
        &mut self.symbols
81
4284
    }
82

            
83
    /// The symbol table for this model as a reference.
84
164972066
    pub fn symbols(&self) -> RwLockReadGuard<'_, SymbolTable> {
85
164972066
        self.symbols.read()
86
164972066
    }
87

            
88
    /// The symbol table for this model as a mutable reference.
89
168232
    pub fn symbols_mut(&mut self) -> RwLockWriteGuard<'_, SymbolTable> {
90
168232
        self.symbols.write()
91
168232
    }
92

            
93
    /// The root node of this model.
94
2071594
    pub fn root(&self) -> &Expression {
95
2071594
        &self.constraints
96
2071594
    }
97

            
98
    /// The root node of this model, as a mutable reference.
99
    ///
100
    /// The caller is responsible for ensuring that the root node remains an [`Expression::Root`].
101
146128
    pub fn root_mut_unchecked(&mut self) -> &mut Expression {
102
146128
        Moo::make_mut(&mut self.constraints)
103
146128
    }
104

            
105
    /// Replaces the root node with `new_root`, returning the old root node.
106
146128
    pub fn replace_root(&mut self, new_root: Expression) -> Expression {
107
146128
        let Expression::Root(_, _) = new_root else {
108
            tracing::error!(new_root=?new_root,"new_root is not an Expression::Root");
109
            panic!("new_root is not an Expression::Root");
110
        };
111

            
112
146128
        std::mem::replace(self.root_mut_unchecked(), new_root)
113
146128
    }
114

            
115
    /// The top-level constraints in this model.
116
91547
    pub fn constraints(&self) -> &Vec<Expression> {
117
91547
        let Expression::Root(_, constraints) = self.constraints.as_ref() else {
118
            bug!("The top level expression in a model should be Expr::Root");
119
        };
120
91547
        constraints
121
91547
    }
122

            
123
    /// The cnf clauses in this model.
124
497409
    pub fn clauses(&self) -> &Vec<CnfClause> {
125
497409
        &self.cnf_clauses
126
497409
    }
127

            
128
    /// The top-level constraints in this model as a mutable vector.
129
161754
    pub fn constraints_mut(&mut self) -> &mut Vec<Expression> {
130
161754
        let Expression::Root(_, constraints) = Moo::make_mut(&mut self.constraints) else {
131
            bug!("The top level expression in a model should be Expr::Root");
132
        };
133

            
134
161754
        constraints
135
161754
    }
136

            
137
    /// The cnf clauses in this model as a mutable vector.
138
146128
    pub fn clauses_mut(&mut self) -> &mut Vec<CnfClause> {
139
146128
        &mut self.cnf_clauses
140
146128
    }
141

            
142
    /// Replaces the top-level constraints with `new_constraints`, returning the old ones.
143
    pub fn replace_constraints(&mut self, new_constraints: Vec<Expression>) -> Vec<Expression> {
144
        std::mem::replace(self.constraints_mut(), new_constraints)
145
    }
146

            
147
    /// Replaces the cnf clauses with `new_clauses`, returning the old ones.
148
    pub fn replace_clauses(&mut self, new_clauses: Vec<CnfClause>) -> Vec<CnfClause> {
149
        std::mem::replace(self.clauses_mut(), new_clauses)
150
    }
151

            
152
    /// Adds a top-level constraint.
153
4725
    pub fn add_constraint(&mut self, constraint: Expression) {
154
4725
        self.constraints_mut().push(constraint);
155
4725
    }
156

            
157
    /// Adds a cnf clause.
158
    pub fn add_clause(&mut self, clause: CnfClause) {
159
        self.clauses_mut().push(clause);
160
    }
161

            
162
    /// Adds top-level constraints.
163
156987
    pub fn add_constraints(&mut self, constraints: Vec<Expression>) {
164
156987
        self.constraints_mut().extend(constraints);
165
156987
    }
166

            
167
    /// Adds cnf clauses.
168
146128
    pub fn add_clauses(&mut self, clauses: Vec<CnfClause>) {
169
146128
        self.clauses_mut().extend(clauses);
170
146128
    }
171

            
172
    /// Adds a new symbol to the symbol table.
173
    pub fn add_symbol(&mut self, decl: DeclarationPtr) -> Option<()> {
174
        self.symbols_mut().insert(decl)
175
    }
176

            
177
    /// Converts the constraints in this model to a single expression suitable for use inside
178
    /// another expression tree.
179
25998
    pub fn into_single_expression(self) -> Expression {
180
25998
        let constraints = self.constraints().clone();
181
25998
        match constraints.len() {
182
            0 => Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true))),
183
25998
            1 => constraints[0].clone(),
184
            _ => Expression::And(Metadata::new(), Moo::new(into_matrix_expr![constraints])),
185
        }
186
25998
    }
187

            
188
    /// Collects all ObjId values from the model using uniplate traversal.
189
632
    pub fn collect_stable_id_mapping(&self) -> HashMap<ObjId, ObjId> {
190
632
        fn visit_symbol_table(symbol_table: SymbolTablePtr, id_list: &mut IndexSet<ObjId>) {
191
632
            if !id_list.insert(symbol_table.id()) {
192
                return;
193
632
            }
194

            
195
632
            let table_ref = symbol_table.read();
196
779
            table_ref.iter_local().for_each(|(_, decl)| {
197
779
                id_list.insert(decl.id());
198
779
            });
199
632
        }
200

            
201
632
        let mut id_list: IndexSet<ObjId> = IndexSet::new();
202

            
203
632
        visit_symbol_table(self.symbols_ptr_unchecked().clone(), &mut id_list);
204

            
205
632
        let mut exprs: VecDeque<Expression> = self.universe_bi();
206
632
        if let Some(dominance) = &self.dominance {
207
            exprs.push_back(dominance.clone());
208
632
        }
209

            
210
632
        for symbol_table in Biplate::<SymbolTablePtr>::universe_bi(&exprs) {
211
            visit_symbol_table(symbol_table, &mut id_list);
212
        }
213
1581
        for declaration in Biplate::<DeclarationPtr>::universe_bi(&exprs) {
214
1581
            id_list.insert(declaration.id());
215
1581
        }
216

            
217
632
        let mut id_map = HashMap::new();
218
1411
        for (stable_id, original_id) in id_list.into_iter().enumerate() {
219
1411
            let type_name = original_id.type_name;
220
1411
            id_map.insert(
221
1411
                original_id,
222
1411
                ObjId {
223
1411
                    object_id: stable_id as u32,
224
1411
                    type_name,
225
1411
                },
226
1411
            );
227
1411
        }
228

            
229
632
        id_map
230
632
    }
231
}
232

            
233
impl Default for Model {
234
    fn default() -> Self {
235
        Self::new(default_context())
236
    }
237
}
238

            
239
impl Typeable for Model {
240
    fn return_type(&self) -> ReturnType {
241
        ReturnType::Bool
242
    }
243
}
244

            
245
impl Hash for Model {
246
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
247
        self.constraints.hash(state);
248
        self.symbols.hash(state);
249
        self.cnf_clauses.hash(state);
250
        self.search_order.hash(state);
251
        self.dominance.hash(state);
252
    }
253
}
254

            
255
// At time of writing (03/02/2025), the Uniplate derive macro doesn't like the lifetimes inside
256
// context, and we do not yet have a way of ignoring this field.
257
impl Uniplate for Model {
258
    fn uniplate(&self) -> (Tree<Self>, Box<dyn Fn(Tree<Self>) -> Self>) {
259
        let self2 = self.clone();
260
        (Tree::Zero, Box::new(move |_| self2.clone()))
261
    }
262
}
263

            
264
impl Biplate<Expression> for Model {
265
157102
    fn biplate(&self) -> (Tree<Expression>, Box<dyn Fn(Tree<Expression>) -> Self>) {
266
157102
        let (symtab_tree, symtab_ctx) =
267
157102
            <SymbolTable as Biplate<Expression>>::biplate(&self.symbols());
268

            
269
157102
        let dom_tree = match &self.dominance {
270
            Some(expr) => Tree::One(expr.clone()),
271
157102
            None => Tree::Zero,
272
        };
273

            
274
157102
        let tree = Tree::Many(VecDeque::from([
275
157102
            Tree::One(self.root().clone()),
276
157102
            symtab_tree,
277
157102
            dom_tree,
278
157102
        ]));
279

            
280
157102
        let self2 = self.clone();
281
157102
        let ctx = Box::new(move |x| {
282
            let Tree::Many(xs) = x else {
283
                panic!("Expected a tree with three children");
284
            };
285
            if xs.len() != 3 {
286
                panic!("Expected a tree with three children");
287
            }
288

            
289
            let Tree::One(root) = xs[0].clone() else {
290
                panic!("Expected root expression tree");
291
            };
292

            
293
            let symtab = symtab_ctx(xs[1].clone());
294
            let dominance = match xs[2].clone() {
295
                Tree::One(expr) => Some(expr),
296
                Tree::Zero => None,
297
                _ => panic!("Expected dominance tree"),
298
            };
299

            
300
            let mut self3 = self2.clone();
301

            
302
            let Expression::Root(_, _) = root else {
303
                bug!("root expression not root");
304
            };
305

            
306
            *self3.root_mut_unchecked() = root;
307
            *self3.symbols_mut() = symtab;
308
            self3.dominance = dominance;
309

            
310
            self3
311
        });
312

            
313
157102
        (tree, ctx)
314
157102
    }
315
}
316

            
317
impl Biplate<Atom> for Model {
318
    fn biplate(&self) -> (Tree<Atom>, Box<dyn Fn(Tree<Atom>) -> Self>) {
319
        let (expression_tree, rebuild_self) = <Model as Biplate<Expression>>::biplate(self);
320
        let (expression_list, rebuild_expression_tree) = expression_tree.list();
321

            
322
        let (atom_trees, reconstruct_exprs): (VecDeque<_>, VecDeque<_>) = expression_list
323
            .iter()
324
            .map(|e| <Expression as Biplate<Atom>>::biplate(e))
325
            .unzip();
326

            
327
        let tree = Tree::Many(atom_trees);
328
        let ctx = Box::new(move |atom_tree: Tree<Atom>| {
329
            let Tree::Many(atoms) = atom_tree else {
330
                panic!();
331
            };
332

            
333
            assert_eq!(
334
                atoms.len(),
335
                reconstruct_exprs.len(),
336
                "the number of children should not change when using Biplate"
337
            );
338

            
339
            let expression_list: VecDeque<Expression> = izip!(atoms, &reconstruct_exprs)
340
                .map(|(atom, recons)| recons(atom))
341
                .collect();
342

            
343
            let expression_tree = rebuild_expression_tree(expression_list);
344
            rebuild_self(expression_tree)
345
        });
346

            
347
        (tree, ctx)
348
    }
349
}
350

            
351
impl Biplate<Comprehension> for Model {
352
    fn biplate(
353
        &self,
354
    ) -> (
355
        Tree<Comprehension>,
356
        Box<dyn Fn(Tree<Comprehension>) -> Self>,
357
    ) {
358
        let (f1_tree, f1_ctx) = <_ as Biplate<Comprehension>>::biplate(&self.constraints);
359
        let (f2_tree, f2_ctx) = <SymbolTable as Biplate<Comprehension>>::biplate(&self.symbols());
360

            
361
        let tree = Tree::Many(VecDeque::from([f1_tree, f2_tree]));
362
        let self2 = self.clone();
363
        let ctx = Box::new(move |x| {
364
            let Tree::Many(xs) = x else {
365
                panic!();
366
            };
367

            
368
            let root = f1_ctx(xs[0].clone());
369
            let symtab = f2_ctx(xs[1].clone());
370

            
371
            let mut self3 = self2.clone();
372

            
373
            let Expression::Root(_, _) = &*root else {
374
                bug!("root expression not root");
375
            };
376

            
377
            *self3.symbols_mut() = symtab;
378
            self3.constraints = root;
379

            
380
            self3
381
        });
382

            
383
        (tree, ctx)
384
    }
385
}
386

            
387
impl Display for Model {
388
    #[allow(clippy::unwrap_used)]
389
24093
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390
467025
        for (name, decl) in self.symbols().clone().into_iter_local() {
391
467025
            match &decl.kind() as &DeclarationKind {
392
                DeclarationKind::Find(_) => {
393
462069
                    writeln!(
394
462069
                        f,
395
                        "{}",
396
462069
                        pretty_variable_declaration(&self.symbols(), &name).unwrap()
397
                    )?;
398
                }
399
                DeclarationKind::ValueLetting(_) | DeclarationKind::TemporaryValueLetting(_) => {
400
4200
                    writeln!(
401
4200
                        f,
402
                        "{}",
403
4200
                        pretty_value_letting_declaration(&self.symbols(), &name).unwrap()
404
                    )?;
405
                }
406
                DeclarationKind::DomainLetting(_) => {
407
672
                    writeln!(
408
672
                        f,
409
                        "{}",
410
672
                        pretty_domain_letting_declaration(&self.symbols(), &name).unwrap()
411
                    )?;
412
                }
413
                DeclarationKind::Given(d) => {
414
                    writeln!(f, "given {name}: {d}")?;
415
                }
416
                DeclarationKind::Quantified(inner) => {
417
                    writeln!(f, "quantified {name}: {}", inner.domain())?;
418
                }
419
                DeclarationKind::RecordField(_) => {
420
84
                    writeln!(f)?;
421
                }
422
            }
423
        }
424

            
425
24093
        if !self.constraints().is_empty() {
426
21990
            writeln!(f, "\nsuch that\n")?;
427
21990
            writeln!(f, "{}", pretty_expressions_as_top_level(self.constraints()))?;
428
2103
        }
429

            
430
24093
        if !self.clauses().is_empty() {
431
1827
            writeln!(f, "\nclauses:\n")?;
432
1827
            writeln!(f, "{}", pretty_clauses(self.clauses()))?;
433
22266
        }
434
24093
        Ok(())
435
24093
    }
436
}
437

            
438
/// A model that is de/serializable using `serde`.
439
///
440
/// To turn this into a rewritable model, it needs to be initialised using
441
/// [`initialise`](SerdeModel::initialise).
442
#[serde_as]
443
#[derive(Clone, Debug, Serialize, Deserialize)]
444
pub struct SerdeModel {
445
    constraints: Moo<Expression>,
446
    #[serde_as(as = "PtrAsInner")]
447
    symbols: SymbolTablePtr,
448
    cnf_clauses: Vec<CnfClause>,
449
    search_order: Option<Vec<Name>>,
450
    dominance: Option<Expression>,
451
}
452

            
453
impl SerdeModel {
454
    /// Initialises the model for rewriting.
455
1260
    pub fn initialise(mut self, context: Arc<RwLock<Context<'static>>>) -> Option<Model> {
456
1260
        let mut tables: HashMap<ObjId, SymbolTablePtr> = HashMap::new();
457

            
458
        // Root model symbol table is always definitive.
459
1260
        tables.insert(self.symbols.id(), self.symbols.clone());
460

            
461
1260
        let mut exprs: VecDeque<Expression> = self.constraints.universe_bi();
462
1260
        if let Some(dominance) = &self.dominance {
463
            exprs.push_back(dominance.clone());
464
1260
        }
465

            
466
        // Some expressions (e.g. abstract comprehensions) contain additional symbol tables.
467
1260
        for table in Biplate::<SymbolTablePtr>::universe_bi(&exprs) {
468
            tables.entry(table.id()).or_insert(table);
469
        }
470

            
471
1260
        for table in tables.clone().into_values() {
472
1260
            let mut table_mut = table.write();
473
1260
            let parent_mut = table_mut.parent_mut_unchecked();
474

            
475
            #[allow(clippy::unwrap_used)]
476
1260
            if let Some(parent) = parent_mut {
477
                let parent_id = parent.id();
478
                *parent = tables.get(&parent_id).unwrap().clone();
479
1260
            }
480
        }
481

            
482
1260
        let mut all_declarations: HashMap<ObjId, DeclarationPtr> = HashMap::new();
483
1260
        for table in tables.values() {
484
1554
            for (_, decl) in table.read().iter_local() {
485
1554
                let id = decl.id();
486
1554
                all_declarations.insert(id, decl.clone());
487
1554
            }
488
        }
489

            
490
1260
        self.constraints = self.constraints.transform_bi(&move |decl: DeclarationPtr| {
491
882
            let id = decl.id();
492
882
            all_declarations
493
882
                .get(&id)
494
882
                .unwrap_or_else(|| {
495
                    panic!(
496
                        "A declaration used in the expression tree should exist in the symbol table. The missing declaration has id {id}."
497
                    )
498
                })
499
882
                .clone()
500
882
        });
501

            
502
1260
        Some(Model {
503
1260
            constraints: self.constraints,
504
1260
            symbols: self.symbols,
505
1260
            cnf_clauses: self.cnf_clauses,
506
1260
            search_order: self.search_order,
507
1260
            dominance: self.dominance,
508
1260
            context,
509
1260
        })
510
1260
    }
511
}
512

            
513
impl From<Model> for SerdeModel {
514
632
    fn from(val: Model) -> Self {
515
632
        SerdeModel {
516
632
            constraints: val.constraints,
517
632
            symbols: val.symbols,
518
632
            cnf_clauses: val.cnf_clauses,
519
632
            search_order: val.search_order,
520
632
            dominance: val.dominance,
521
632
        }
522
632
    }
523
}
524

            
525
impl Display for SerdeModel {
526
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
527
        let model = Model {
528
            constraints: self.constraints.clone(),
529
            symbols: self.symbols.clone(),
530
            cnf_clauses: self.cnf_clauses.clone(),
531
            search_order: self.search_order.clone(),
532
            dominance: self.dominance.clone(),
533
            context: default_context(),
534
        };
535
        std::fmt::Display::fmt(&model, f)
536
    }
537
}
538

            
539
impl SerdeModel {
540
    /// Collects all ObjId values from the model and maps them to stable sequential IDs.
541
632
    pub fn collect_stable_id_mapping(&self) -> HashMap<ObjId, ObjId> {
542
632
        let model = Model {
543
632
            constraints: self.constraints.clone(),
544
632
            symbols: self.symbols.clone(),
545
632
            cnf_clauses: self.cnf_clauses.clone(),
546
632
            search_order: self.search_order.clone(),
547
632
            dominance: self.dominance.clone(),
548
632
            context: default_context(),
549
632
        };
550
632
        model.collect_stable_id_mapping()
551
632
    }
552
}
553

            
554
/// A struct for the information about expressions
555
#[serde_as]
556
#[derive(Serialize, Deserialize)]
557
pub struct ExprInfo {
558
    pretty: String,
559
    domain: Option<Moo<Domain>>,
560
    children: Vec<ExprInfo>,
561
}
562

            
563
impl ExprInfo {
564
    pub fn create(expr: &Expression) -> ExprInfo {
565
        let pretty = expr.to_string();
566
        let domain = expr.domain_of();
567
        let children = expr.children().iter().map(Self::create).collect();
568

            
569
        ExprInfo {
570
            pretty,
571
            domain,
572
            children,
573
        }
574
    }
575
}