// conjure_cp_core/ast/model.rs

1use std::collections::{HashMap, VecDeque};
2use std::fmt::{Debug, Display};
3use std::hash::Hash;
4use std::sync::{Arc, RwLock};
5
6use crate::ast::Domain;
7use crate::context::Context;
8use crate::{bug, into_matrix_expr};
9use derivative::Derivative;
10use indexmap::IndexSet;
11use itertools::izip;
12use parking_lot::{RwLockReadGuard, RwLockWriteGuard};
13use serde::{Deserialize, Serialize};
14use serde_with::serde_as;
15use uniplate::{Biplate, Tree, Uniplate};
16
17use super::serde::{HasId, ObjId, PtrAsInner};
18use super::{
19    Atom, CnfClause, DeclarationPtr, Expression, Literal, Metadata, Moo, Name, ReturnType,
20    SymbolTable, SymbolTablePtr, Typeable,
21    comprehension::Comprehension,
22    declaration::DeclarationKind,
23    pretty::{
24        pretty_clauses, pretty_domain_letting_declaration, pretty_expressions_as_top_level,
25        pretty_value_letting_declaration, pretty_variable_declaration,
26    },
27};
28
/// An Essence model.
///
/// Derived equality ignores the `context` field (see its `#[derivative(PartialEq = "ignore")]`
/// attribute).
#[serde_as]
#[derive(Derivative, Clone, Debug, Serialize, Deserialize)]
#[derivative(PartialEq, Eq)]
pub struct Model {
    /// The top-level expression. Expected to always be an `Expression::Root` whose children are
    /// the model's top-level constraints.
    constraints: Moo<Expression>,
    /// The model's symbol table, serialised via `PtrAsInner`.
    #[serde_as(as = "PtrAsInner")]
    symbols: SymbolTablePtr,
    /// CNF clauses, stored separately from the expression-level constraints.
    cnf_clauses: Vec<CnfClause>,

    /// Optional ordered list of names (presumably the variable search order — TODO confirm).
    pub search_order: Option<Vec<Name>>,
    /// Optional dominance expression.
    pub dominance: Option<Expression>,

    /// Shared context. Not serialised (deserialisation fills it in with `default_context`) and
    /// ignored by equality.
    #[serde(skip, default = "default_context")]
    #[derivative(PartialEq = "ignore")]
    pub context: Arc<RwLock<Context<'static>>>,
}
46
47fn default_context() -> Arc<RwLock<Context<'static>>> {
48    Arc::new(RwLock::new(Context::default()))
49}
50
51impl Model {
52    fn new_empty(symbols: SymbolTablePtr, context: Arc<RwLock<Context<'static>>>) -> Model {
53        Model {
54            constraints: Moo::new(Expression::Root(Metadata::new(), vec![])),
55            symbols,
56            cnf_clauses: Vec::new(),
57            search_order: None,
58            dominance: None,
59            context,
60        }
61    }
62
63    /// Creates a new top-level model from the given context.
64    pub fn new(context: Arc<RwLock<Context<'static>>>) -> Model {
65        Self::new_empty(SymbolTablePtr::new(), context)
66    }
67
68    /// Creates a new model whose symbol table has `parent` as parent scope.
69    pub fn new_in_parent_scope(parent: SymbolTablePtr) -> Model {
70        Self::new_empty(SymbolTablePtr::with_parent(parent), default_context())
71    }
72
73    /// The symbol table for this model as a pointer.
74    pub fn symbols_ptr_unchecked(&self) -> &SymbolTablePtr {
75        &self.symbols
76    }
77
78    /// The symbol table for this model as a mutable pointer.
79    pub fn symbols_ptr_unchecked_mut(&mut self) -> &mut SymbolTablePtr {
80        &mut self.symbols
81    }
82
83    /// The symbol table for this model as a reference.
84    pub fn symbols(&self) -> RwLockReadGuard<'_, SymbolTable> {
85        self.symbols.read()
86    }
87
88    /// The symbol table for this model as a mutable reference.
89    pub fn symbols_mut(&mut self) -> RwLockWriteGuard<'_, SymbolTable> {
90        self.symbols.write()
91    }
92
93    /// The root node of this model.
94    pub fn root(&self) -> &Expression {
95        &self.constraints
96    }
97
98    /// The root node of this model, as a mutable reference.
99    ///
100    /// The caller is responsible for ensuring that the root node remains an [`Expression::Root`].
101    pub fn root_mut_unchecked(&mut self) -> &mut Expression {
102        Moo::make_mut(&mut self.constraints)
103    }
104
105    /// Replaces the root node with `new_root`, returning the old root node.
106    pub fn replace_root(&mut self, new_root: Expression) -> Expression {
107        let Expression::Root(_, _) = new_root else {
108            tracing::error!(new_root=?new_root,"new_root is not an Expression::Root");
109            panic!("new_root is not an Expression::Root");
110        };
111
112        std::mem::replace(self.root_mut_unchecked(), new_root)
113    }
114
115    /// The top-level constraints in this model.
116    pub fn constraints(&self) -> &Vec<Expression> {
117        let Expression::Root(_, constraints) = self.constraints.as_ref() else {
118            bug!("The top level expression in a model should be Expr::Root");
119        };
120        constraints
121    }
122
123    /// The cnf clauses in this model.
124    pub fn clauses(&self) -> &Vec<CnfClause> {
125        &self.cnf_clauses
126    }
127
128    /// The top-level constraints in this model as a mutable vector.
129    pub fn constraints_mut(&mut self) -> &mut Vec<Expression> {
130        let Expression::Root(_, constraints) = Moo::make_mut(&mut self.constraints) else {
131            bug!("The top level expression in a model should be Expr::Root");
132        };
133
134        constraints
135    }
136
137    /// The cnf clauses in this model as a mutable vector.
138    pub fn clauses_mut(&mut self) -> &mut Vec<CnfClause> {
139        &mut self.cnf_clauses
140    }
141
142    /// Replaces the top-level constraints with `new_constraints`, returning the old ones.
143    pub fn replace_constraints(&mut self, new_constraints: Vec<Expression>) -> Vec<Expression> {
144        std::mem::replace(self.constraints_mut(), new_constraints)
145    }
146
147    /// Replaces the cnf clauses with `new_clauses`, returning the old ones.
148    pub fn replace_clauses(&mut self, new_clauses: Vec<CnfClause>) -> Vec<CnfClause> {
149        std::mem::replace(self.clauses_mut(), new_clauses)
150    }
151
152    /// Adds a top-level constraint.
153    pub fn add_constraint(&mut self, constraint: Expression) {
154        self.constraints_mut().push(constraint);
155    }
156
157    /// Adds a cnf clause.
158    pub fn add_clause(&mut self, clause: CnfClause) {
159        self.clauses_mut().push(clause);
160    }
161
162    /// Adds top-level constraints.
163    pub fn add_constraints(&mut self, constraints: Vec<Expression>) {
164        self.constraints_mut().extend(constraints);
165    }
166
167    /// Adds cnf clauses.
168    pub fn add_clauses(&mut self, clauses: Vec<CnfClause>) {
169        self.clauses_mut().extend(clauses);
170    }
171
172    /// Adds a new symbol to the symbol table.
173    pub fn add_symbol(&mut self, decl: DeclarationPtr) -> Option<()> {
174        self.symbols_mut().insert(decl)
175    }
176
177    /// Converts the constraints in this model to a single expression suitable for use inside
178    /// another expression tree.
179    pub fn into_single_expression(self) -> Expression {
180        let constraints = self.constraints().clone();
181        match constraints.len() {
182            0 => Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true))),
183            1 => constraints[0].clone(),
184            _ => Expression::And(Metadata::new(), Moo::new(into_matrix_expr![constraints])),
185        }
186    }
187
188    /// Collects all ObjId values from the model using uniplate traversal.
189    pub fn collect_stable_id_mapping(&self) -> HashMap<ObjId, ObjId> {
190        fn visit_symbol_table(symbol_table: SymbolTablePtr, id_list: &mut IndexSet<ObjId>) {
191            if !id_list.insert(symbol_table.id()) {
192                return;
193            }
194
195            let table_ref = symbol_table.read();
196            table_ref.iter_local().for_each(|(_, decl)| {
197                id_list.insert(decl.id());
198            });
199        }
200
201        let mut id_list: IndexSet<ObjId> = IndexSet::new();
202
203        visit_symbol_table(self.symbols_ptr_unchecked().clone(), &mut id_list);
204
205        let mut exprs: VecDeque<Expression> = self.universe_bi();
206        if let Some(dominance) = &self.dominance {
207            exprs.push_back(dominance.clone());
208        }
209
210        for symbol_table in Biplate::<SymbolTablePtr>::universe_bi(&exprs) {
211            visit_symbol_table(symbol_table, &mut id_list);
212        }
213        for declaration in Biplate::<DeclarationPtr>::universe_bi(&exprs) {
214            id_list.insert(declaration.id());
215        }
216
217        let mut id_map = HashMap::new();
218        for (stable_id, original_id) in id_list.into_iter().enumerate() {
219            let type_name = original_id.type_name;
220            id_map.insert(
221                original_id,
222                ObjId {
223                    object_id: stable_id as u32,
224                    type_name,
225                },
226            );
227        }
228
229        id_map
230    }
231}
232
233impl Default for Model {
234    fn default() -> Self {
235        Self::new(default_context())
236    }
237}
238
239impl Typeable for Model {
240    fn return_type(&self) -> ReturnType {
241        ReturnType::Bool
242    }
243}
244
245impl Hash for Model {
246    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
247        self.constraints.hash(state);
248        self.symbols.hash(state);
249        self.cnf_clauses.hash(state);
250        self.search_order.hash(state);
251        self.dominance.hash(state);
252    }
253}
254
255// At time of writing (03/02/2025), the Uniplate derive macro doesn't like the lifetimes inside
256// context, and we do not yet have a way of ignoring this field.
257impl Uniplate for Model {
258    fn uniplate(&self) -> (Tree<Self>, Box<dyn Fn(Tree<Self>) -> Self>) {
259        let self2 = self.clone();
260        (Tree::Zero, Box::new(move |_| self2.clone()))
261    }
262}
263
264impl Biplate<Expression> for Model {
265    fn biplate(&self) -> (Tree<Expression>, Box<dyn Fn(Tree<Expression>) -> Self>) {
266        let (symtab_tree, symtab_ctx) =
267            <SymbolTable as Biplate<Expression>>::biplate(&self.symbols());
268
269        let dom_tree = match &self.dominance {
270            Some(expr) => Tree::One(expr.clone()),
271            None => Tree::Zero,
272        };
273
274        let tree = Tree::Many(VecDeque::from([
275            Tree::One(self.root().clone()),
276            symtab_tree,
277            dom_tree,
278        ]));
279
280        let self2 = self.clone();
281        let ctx = Box::new(move |x| {
282            let Tree::Many(xs) = x else {
283                panic!("Expected a tree with three children");
284            };
285            if xs.len() != 3 {
286                panic!("Expected a tree with three children");
287            }
288
289            let Tree::One(root) = xs[0].clone() else {
290                panic!("Expected root expression tree");
291            };
292
293            let symtab = symtab_ctx(xs[1].clone());
294            let dominance = match xs[2].clone() {
295                Tree::One(expr) => Some(expr),
296                Tree::Zero => None,
297                _ => panic!("Expected dominance tree"),
298            };
299
300            let mut self3 = self2.clone();
301
302            let Expression::Root(_, _) = root else {
303                bug!("root expression not root");
304            };
305
306            *self3.root_mut_unchecked() = root;
307            *self3.symbols_mut() = symtab;
308            self3.dominance = dominance;
309
310            self3
311        });
312
313        (tree, ctx)
314    }
315}
316
317impl Biplate<Atom> for Model {
318    fn biplate(&self) -> (Tree<Atom>, Box<dyn Fn(Tree<Atom>) -> Self>) {
319        let (expression_tree, rebuild_self) = <Model as Biplate<Expression>>::biplate(self);
320        let (expression_list, rebuild_expression_tree) = expression_tree.list();
321
322        let (atom_trees, reconstruct_exprs): (VecDeque<_>, VecDeque<_>) = expression_list
323            .iter()
324            .map(|e| <Expression as Biplate<Atom>>::biplate(e))
325            .unzip();
326
327        let tree = Tree::Many(atom_trees);
328        let ctx = Box::new(move |atom_tree: Tree<Atom>| {
329            let Tree::Many(atoms) = atom_tree else {
330                panic!();
331            };
332
333            assert_eq!(
334                atoms.len(),
335                reconstruct_exprs.len(),
336                "the number of children should not change when using Biplate"
337            );
338
339            let expression_list: VecDeque<Expression> = izip!(atoms, &reconstruct_exprs)
340                .map(|(atom, recons)| recons(atom))
341                .collect();
342
343            let expression_tree = rebuild_expression_tree(expression_list);
344            rebuild_self(expression_tree)
345        });
346
347        (tree, ctx)
348    }
349}
350
351impl Biplate<Comprehension> for Model {
352    fn biplate(
353        &self,
354    ) -> (
355        Tree<Comprehension>,
356        Box<dyn Fn(Tree<Comprehension>) -> Self>,
357    ) {
358        let (f1_tree, f1_ctx) = <_ as Biplate<Comprehension>>::biplate(&self.constraints);
359        let (f2_tree, f2_ctx) = <SymbolTable as Biplate<Comprehension>>::biplate(&self.symbols());
360
361        let tree = Tree::Many(VecDeque::from([f1_tree, f2_tree]));
362        let self2 = self.clone();
363        let ctx = Box::new(move |x| {
364            let Tree::Many(xs) = x else {
365                panic!();
366            };
367
368            let root = f1_ctx(xs[0].clone());
369            let symtab = f2_ctx(xs[1].clone());
370
371            let mut self3 = self2.clone();
372
373            let Expression::Root(_, _) = &*root else {
374                bug!("root expression not root");
375            };
376
377            *self3.symbols_mut() = symtab;
378            self3.constraints = root;
379
380            self3
381        });
382
383        (tree, ctx)
384    }
385}
386
387impl Display for Model {
388    #[allow(clippy::unwrap_used)]
389    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390        for (name, decl) in self.symbols().clone().into_iter_local() {
391            match &decl.kind() as &DeclarationKind {
392                DeclarationKind::Find(_) => {
393                    writeln!(
394                        f,
395                        "{}",
396                        pretty_variable_declaration(&self.symbols(), &name).unwrap()
397                    )?;
398                }
399                DeclarationKind::ValueLetting(_, _) | DeclarationKind::TemporaryValueLetting(_) => {
400                    writeln!(
401                        f,
402                        "{}",
403                        pretty_value_letting_declaration(&self.symbols(), &name).unwrap()
404                    )?;
405                }
406                DeclarationKind::DomainLetting(_) => {
407                    writeln!(
408                        f,
409                        "{}",
410                        pretty_domain_letting_declaration(&self.symbols(), &name).unwrap()
411                    )?;
412                }
413                DeclarationKind::Given(d) => {
414                    writeln!(f, "given {name}: {d}")?;
415                }
416                DeclarationKind::Quantified(inner) => {
417                    writeln!(f, "quantified {name}: {}", inner.domain())?;
418                }
419                DeclarationKind::RecordField(_) => {
420                    writeln!(f)?;
421                }
422            }
423        }
424
425        if !self.constraints().is_empty() {
426            writeln!(f, "\nsuch that\n")?;
427            writeln!(f, "{}", pretty_expressions_as_top_level(self.constraints()))?;
428        }
429
430        if !self.clauses().is_empty() {
431            writeln!(f, "\nclauses:\n")?;
432            writeln!(f, "{}", pretty_clauses(self.clauses()))?;
433        }
434        Ok(())
435    }
436}
437
/// A model that is de/serializable using `serde`.
///
/// To turn this into a rewritable model, it needs to be initialised using
/// [`initialise`](SerdeModel::initialise). It has the same fields as `Model` except for the
/// (non-serialisable) context.
#[serde_as]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SerdeModel {
    /// The top-level expression (expected to be an `Expression::Root`).
    constraints: Moo<Expression>,
    /// The model's symbol table, serialised via `PtrAsInner`.
    #[serde_as(as = "PtrAsInner")]
    symbols: SymbolTablePtr,
    /// CNF clauses, stored separately from the expression-level constraints.
    cnf_clauses: Vec<CnfClause>,
    /// Optional ordered list of names (presumably the variable search order — TODO confirm).
    search_order: Option<Vec<Name>>,
    /// Optional dominance expression.
    dominance: Option<Expression>,
}
452
453impl SerdeModel {
454    /// Initialises the model for rewriting.
455    pub fn initialise(mut self, context: Arc<RwLock<Context<'static>>>) -> Option<Model> {
456        let mut tables: HashMap<ObjId, SymbolTablePtr> = HashMap::new();
457
458        // Root model symbol table is always definitive.
459        tables.insert(self.symbols.id(), self.symbols.clone());
460
461        let mut exprs: VecDeque<Expression> = self.constraints.universe_bi();
462        if let Some(dominance) = &self.dominance {
463            exprs.push_back(dominance.clone());
464        }
465
466        // Some expressions (e.g. abstract comprehensions) contain additional symbol tables.
467        for table in Biplate::<SymbolTablePtr>::universe_bi(&exprs) {
468            tables.entry(table.id()).or_insert(table);
469        }
470
471        for table in tables.clone().into_values() {
472            let mut table_mut = table.write();
473            let parent_mut = table_mut.parent_mut_unchecked();
474
475            #[allow(clippy::unwrap_used)]
476            if let Some(parent) = parent_mut {
477                let parent_id = parent.id();
478                *parent = tables.get(&parent_id).unwrap().clone();
479            }
480        }
481
482        let mut all_declarations: HashMap<ObjId, DeclarationPtr> = HashMap::new();
483        for table in tables.values() {
484            for (_, decl) in table.read().iter_local() {
485                let id = decl.id();
486                all_declarations.insert(id, decl.clone());
487            }
488        }
489
490        self.constraints = self.constraints.transform_bi(&move |decl: DeclarationPtr| {
491            let id = decl.id();
492            all_declarations
493                .get(&id)
494                .unwrap_or_else(|| {
495                    panic!(
496                        "A declaration used in the expression tree should exist in the symbol table. The missing declaration has id {id}."
497                    )
498                })
499                .clone()
500        });
501
502        Some(Model {
503            constraints: self.constraints,
504            symbols: self.symbols,
505            cnf_clauses: self.cnf_clauses,
506            search_order: self.search_order,
507            dominance: self.dominance,
508            context,
509        })
510    }
511}
512
513impl From<Model> for SerdeModel {
514    fn from(val: Model) -> Self {
515        SerdeModel {
516            constraints: val.constraints,
517            symbols: val.symbols,
518            cnf_clauses: val.cnf_clauses,
519            search_order: val.search_order,
520            dominance: val.dominance,
521        }
522    }
523}
524
525impl Display for SerdeModel {
526    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
527        let model = Model {
528            constraints: self.constraints.clone(),
529            symbols: self.symbols.clone(),
530            cnf_clauses: self.cnf_clauses.clone(),
531            search_order: self.search_order.clone(),
532            dominance: self.dominance.clone(),
533            context: default_context(),
534        };
535        std::fmt::Display::fmt(&model, f)
536    }
537}
538
539impl SerdeModel {
540    /// Collects all ObjId values from the model and maps them to stable sequential IDs.
541    pub fn collect_stable_id_mapping(&self) -> HashMap<ObjId, ObjId> {
542        let model = Model {
543            constraints: self.constraints.clone(),
544            symbols: self.symbols.clone(),
545            cnf_clauses: self.cnf_clauses.clone(),
546            search_order: self.search_order.clone(),
547            dominance: self.dominance.clone(),
548            context: default_context(),
549        };
550        model.collect_stable_id_mapping()
551    }
552}
553
/// Serialisable information about an expression tree. It holds the expression's
/// pretty-printed form, its domain (if one is known), and the same information for each of its
/// children.
#[serde_as]
#[derive(Serialize)]
pub struct ExprInfo {
    /// The expression rendered via `to_string()`.
    pretty: String,
    /// The expression's domain, as returned by `domain_of()`.
    domain: Option<Moo<Domain>>,
    /// Information for each direct child expression, in order.
    children: Vec<ExprInfo>,
}
562
563impl ExprInfo {
564    pub fn create(expr: &Expression) -> ExprInfo {
565        let pretty = expr.to_string();
566        let domain = expr.domain_of();
567        let children = expr.children().iter().map(Self::create).collect();
568
569        ExprInfo {
570            pretty,
571            domain,
572            children,
573        }
574    }
575}