1use std::collections::{HashMap, VecDeque};
2use std::fmt::{Debug, Display};
3use std::hash::Hash;
4use std::sync::{Arc, RwLock};
5
6use crate::ast::Domain;
7use crate::context::Context;
8use crate::{bug, into_matrix_expr};
9use derivative::Derivative;
10use indexmap::IndexSet;
11use itertools::izip;
12use parking_lot::{RwLockReadGuard, RwLockWriteGuard};
13use serde::{Deserialize, Serialize};
14use serde_with::serde_as;
15use uniplate::{Biplate, Tree, Uniplate};
16
17use super::serde::{HasId, ObjId, PtrAsInner};
18use super::{
19 Atom, CnfClause, DeclarationPtr, Expression, Literal, Metadata, Moo, Name, ReturnType,
20 SymbolTable, SymbolTablePtr, Typeable,
21 comprehension::Comprehension,
22 declaration::DeclarationKind,
23 pretty::{
24 pretty_clauses, pretty_domain_letting_declaration, pretty_expressions_as_top_level,
25 pretty_value_letting_declaration, pretty_variable_declaration,
26 },
27};
28
/// A constraint model: a root expression whose children are the top-level constraints,
/// the symbol table those constraints refer to, any CNF clauses, and optional
/// search-order and dominance information.
///
/// Equality (derived through `derivative`) ignores `context`. [`SerdeModel`] mirrors
/// this struct without the context.
#[serde_as]
#[derive(Derivative, Clone, Debug, Serialize, Deserialize)]
#[derivative(PartialEq, Eq)]
pub struct Model {
    /// The root expression. Expected to be an `Expression::Root` whose children are the
    /// top-level constraints; the `constraints*` accessors raise `bug!` otherwise.
    constraints: Moo<Expression>,
    /// Pointer to the model's symbol table. Serialised via `PtrAsInner`
    /// (presumably writes the pointed-to table inline — confirm).
    #[serde_as(as = "PtrAsInner")]
    symbols: SymbolTablePtr,
    /// CNF clauses, kept (and printed) separately from the expression constraints.
    cnf_clauses: Vec<CnfClause>,

    /// Optional ordering of variable names; presumably the solver search order — confirm.
    pub search_order: Option<Vec<Name>>,
    /// Optional dominance expression, stored outside the root expression.
    pub dominance: Option<Expression>,

    /// Shared context (`Arc<RwLock<…>>`). Not serialised: deserialisation fills in a
    /// fresh `default_context()`. Ignored by `PartialEq`.
    #[serde(skip, default = "default_context")]
    #[derivative(PartialEq = "ignore")]
    pub context: Arc<RwLock<Context<'static>>>,
}
46
47fn default_context() -> Arc<RwLock<Context<'static>>> {
48 Arc::new(RwLock::new(Context::default()))
49}
50
51impl Model {
52 fn new_empty(symbols: SymbolTablePtr, context: Arc<RwLock<Context<'static>>>) -> Model {
53 Model {
54 constraints: Moo::new(Expression::Root(Metadata::new(), vec![])),
55 symbols,
56 cnf_clauses: Vec::new(),
57 search_order: None,
58 dominance: None,
59 context,
60 }
61 }
62
63 pub fn new(context: Arc<RwLock<Context<'static>>>) -> Model {
65 Self::new_empty(SymbolTablePtr::new(), context)
66 }
67
68 pub fn new_in_parent_scope(parent: SymbolTablePtr) -> Model {
70 Self::new_empty(SymbolTablePtr::with_parent(parent), default_context())
71 }
72
73 pub fn symbols_ptr_unchecked(&self) -> &SymbolTablePtr {
75 &self.symbols
76 }
77
78 pub fn symbols_ptr_unchecked_mut(&mut self) -> &mut SymbolTablePtr {
80 &mut self.symbols
81 }
82
83 pub fn symbols(&self) -> RwLockReadGuard<'_, SymbolTable> {
85 self.symbols.read()
86 }
87
88 pub fn symbols_mut(&mut self) -> RwLockWriteGuard<'_, SymbolTable> {
90 self.symbols.write()
91 }
92
93 pub fn root(&self) -> &Expression {
95 &self.constraints
96 }
97
98 pub fn root_mut_unchecked(&mut self) -> &mut Expression {
102 Moo::make_mut(&mut self.constraints)
103 }
104
105 pub fn replace_root(&mut self, new_root: Expression) -> Expression {
107 let Expression::Root(_, _) = new_root else {
108 tracing::error!(new_root=?new_root,"new_root is not an Expression::Root");
109 panic!("new_root is not an Expression::Root");
110 };
111
112 std::mem::replace(self.root_mut_unchecked(), new_root)
113 }
114
115 pub fn constraints(&self) -> &Vec<Expression> {
117 let Expression::Root(_, constraints) = self.constraints.as_ref() else {
118 bug!("The top level expression in a model should be Expr::Root");
119 };
120 constraints
121 }
122
123 pub fn clauses(&self) -> &Vec<CnfClause> {
125 &self.cnf_clauses
126 }
127
128 pub fn constraints_mut(&mut self) -> &mut Vec<Expression> {
130 let Expression::Root(_, constraints) = Moo::make_mut(&mut self.constraints) else {
131 bug!("The top level expression in a model should be Expr::Root");
132 };
133
134 constraints
135 }
136
137 pub fn clauses_mut(&mut self) -> &mut Vec<CnfClause> {
139 &mut self.cnf_clauses
140 }
141
142 pub fn replace_constraints(&mut self, new_constraints: Vec<Expression>) -> Vec<Expression> {
144 std::mem::replace(self.constraints_mut(), new_constraints)
145 }
146
147 pub fn replace_clauses(&mut self, new_clauses: Vec<CnfClause>) -> Vec<CnfClause> {
149 std::mem::replace(self.clauses_mut(), new_clauses)
150 }
151
152 pub fn add_constraint(&mut self, constraint: Expression) {
154 self.constraints_mut().push(constraint);
155 }
156
157 pub fn add_clause(&mut self, clause: CnfClause) {
159 self.clauses_mut().push(clause);
160 }
161
162 pub fn add_constraints(&mut self, constraints: Vec<Expression>) {
164 self.constraints_mut().extend(constraints);
165 }
166
167 pub fn add_clauses(&mut self, clauses: Vec<CnfClause>) {
169 self.clauses_mut().extend(clauses);
170 }
171
172 pub fn add_symbol(&mut self, decl: DeclarationPtr) -> Option<()> {
174 self.symbols_mut().insert(decl)
175 }
176
177 pub fn into_single_expression(self) -> Expression {
180 let constraints = self.constraints().clone();
181 match constraints.len() {
182 0 => Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true))),
183 1 => constraints[0].clone(),
184 _ => Expression::And(Metadata::new(), Moo::new(into_matrix_expr![constraints])),
185 }
186 }
187
188 pub fn collect_stable_id_mapping(&self) -> HashMap<ObjId, ObjId> {
190 fn visit_symbol_table(symbol_table: SymbolTablePtr, id_list: &mut IndexSet<ObjId>) {
191 if !id_list.insert(symbol_table.id()) {
192 return;
193 }
194
195 let table_ref = symbol_table.read();
196 table_ref.iter_local().for_each(|(_, decl)| {
197 id_list.insert(decl.id());
198 });
199 }
200
201 let mut id_list: IndexSet<ObjId> = IndexSet::new();
202
203 visit_symbol_table(self.symbols_ptr_unchecked().clone(), &mut id_list);
204
205 let mut exprs: VecDeque<Expression> = self.universe_bi();
206 if let Some(dominance) = &self.dominance {
207 exprs.push_back(dominance.clone());
208 }
209
210 for symbol_table in Biplate::<SymbolTablePtr>::universe_bi(&exprs) {
211 visit_symbol_table(symbol_table, &mut id_list);
212 }
213 for declaration in Biplate::<DeclarationPtr>::universe_bi(&exprs) {
214 id_list.insert(declaration.id());
215 }
216
217 let mut id_map = HashMap::new();
218 for (stable_id, original_id) in id_list.into_iter().enumerate() {
219 let type_name = original_id.type_name;
220 id_map.insert(
221 original_id,
222 ObjId {
223 object_id: stable_id as u32,
224 type_name,
225 },
226 );
227 }
228
229 id_map
230 }
231}
232
233impl Default for Model {
234 fn default() -> Self {
235 Self::new(default_context())
236 }
237}
238
239impl Typeable for Model {
240 fn return_type(&self) -> ReturnType {
241 ReturnType::Bool
242 }
243}
244
245impl Hash for Model {
246 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
247 self.constraints.hash(state);
248 self.symbols.hash(state);
249 self.cnf_clauses.hash(state);
250 self.search_order.hash(state);
251 self.dominance.hash(state);
252 }
253}
254
255impl Uniplate for Model {
258 fn uniplate(&self) -> (Tree<Self>, Box<dyn Fn(Tree<Self>) -> Self>) {
259 let self2 = self.clone();
260 (Tree::Zero, Box::new(move |_| self2.clone()))
261 }
262}
263
264impl Biplate<Expression> for Model {
265 fn biplate(&self) -> (Tree<Expression>, Box<dyn Fn(Tree<Expression>) -> Self>) {
266 let (symtab_tree, symtab_ctx) =
267 <SymbolTable as Biplate<Expression>>::biplate(&self.symbols());
268
269 let dom_tree = match &self.dominance {
270 Some(expr) => Tree::One(expr.clone()),
271 None => Tree::Zero,
272 };
273
274 let tree = Tree::Many(VecDeque::from([
275 Tree::One(self.root().clone()),
276 symtab_tree,
277 dom_tree,
278 ]));
279
280 let self2 = self.clone();
281 let ctx = Box::new(move |x| {
282 let Tree::Many(xs) = x else {
283 panic!("Expected a tree with three children");
284 };
285 if xs.len() != 3 {
286 panic!("Expected a tree with three children");
287 }
288
289 let Tree::One(root) = xs[0].clone() else {
290 panic!("Expected root expression tree");
291 };
292
293 let symtab = symtab_ctx(xs[1].clone());
294 let dominance = match xs[2].clone() {
295 Tree::One(expr) => Some(expr),
296 Tree::Zero => None,
297 _ => panic!("Expected dominance tree"),
298 };
299
300 let mut self3 = self2.clone();
301
302 let Expression::Root(_, _) = root else {
303 bug!("root expression not root");
304 };
305
306 *self3.root_mut_unchecked() = root;
307 *self3.symbols_mut() = symtab;
308 self3.dominance = dominance;
309
310 self3
311 });
312
313 (tree, ctx)
314 }
315}
316
317impl Biplate<Atom> for Model {
318 fn biplate(&self) -> (Tree<Atom>, Box<dyn Fn(Tree<Atom>) -> Self>) {
319 let (expression_tree, rebuild_self) = <Model as Biplate<Expression>>::biplate(self);
320 let (expression_list, rebuild_expression_tree) = expression_tree.list();
321
322 let (atom_trees, reconstruct_exprs): (VecDeque<_>, VecDeque<_>) = expression_list
323 .iter()
324 .map(|e| <Expression as Biplate<Atom>>::biplate(e))
325 .unzip();
326
327 let tree = Tree::Many(atom_trees);
328 let ctx = Box::new(move |atom_tree: Tree<Atom>| {
329 let Tree::Many(atoms) = atom_tree else {
330 panic!();
331 };
332
333 assert_eq!(
334 atoms.len(),
335 reconstruct_exprs.len(),
336 "the number of children should not change when using Biplate"
337 );
338
339 let expression_list: VecDeque<Expression> = izip!(atoms, &reconstruct_exprs)
340 .map(|(atom, recons)| recons(atom))
341 .collect();
342
343 let expression_tree = rebuild_expression_tree(expression_list);
344 rebuild_self(expression_tree)
345 });
346
347 (tree, ctx)
348 }
349}
350
351impl Biplate<Comprehension> for Model {
352 fn biplate(
353 &self,
354 ) -> (
355 Tree<Comprehension>,
356 Box<dyn Fn(Tree<Comprehension>) -> Self>,
357 ) {
358 let (f1_tree, f1_ctx) = <_ as Biplate<Comprehension>>::biplate(&self.constraints);
359 let (f2_tree, f2_ctx) = <SymbolTable as Biplate<Comprehension>>::biplate(&self.symbols());
360
361 let tree = Tree::Many(VecDeque::from([f1_tree, f2_tree]));
362 let self2 = self.clone();
363 let ctx = Box::new(move |x| {
364 let Tree::Many(xs) = x else {
365 panic!();
366 };
367
368 let root = f1_ctx(xs[0].clone());
369 let symtab = f2_ctx(xs[1].clone());
370
371 let mut self3 = self2.clone();
372
373 let Expression::Root(_, _) = &*root else {
374 bug!("root expression not root");
375 };
376
377 *self3.symbols_mut() = symtab;
378 self3.constraints = root;
379
380 self3
381 });
382
383 (tree, ctx)
384 }
385}
386
387impl Display for Model {
388 #[allow(clippy::unwrap_used)]
389 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390 for (name, decl) in self.symbols().clone().into_iter_local() {
391 match &decl.kind() as &DeclarationKind {
392 DeclarationKind::Find(_) => {
393 writeln!(
394 f,
395 "{}",
396 pretty_variable_declaration(&self.symbols(), &name).unwrap()
397 )?;
398 }
399 DeclarationKind::ValueLetting(_, _) | DeclarationKind::TemporaryValueLetting(_) => {
400 writeln!(
401 f,
402 "{}",
403 pretty_value_letting_declaration(&self.symbols(), &name).unwrap()
404 )?;
405 }
406 DeclarationKind::DomainLetting(_) => {
407 writeln!(
408 f,
409 "{}",
410 pretty_domain_letting_declaration(&self.symbols(), &name).unwrap()
411 )?;
412 }
413 DeclarationKind::Given(d) => {
414 writeln!(f, "given {name}: {d}")?;
415 }
416 DeclarationKind::Quantified(inner) => {
417 writeln!(f, "quantified {name}: {}", inner.domain())?;
418 }
419 DeclarationKind::RecordField(_) => {
420 writeln!(f)?;
421 }
422 }
423 }
424
425 if !self.constraints().is_empty() {
426 writeln!(f, "\nsuch that\n")?;
427 writeln!(f, "{}", pretty_expressions_as_top_level(self.constraints()))?;
428 }
429
430 if !self.clauses().is_empty() {
431 writeln!(f, "\nclauses:\n")?;
432 writeln!(f, "{}", pretty_clauses(self.clauses()))?;
433 }
434 Ok(())
435 }
436}
437
/// A (de)serialisable mirror of [`Model`] without the `context` field.
///
/// Convert a `Model` into one with `From`. Use [`SerdeModel::initialise`] to turn it
/// back into a `Model`.
#[serde_as]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SerdeModel {
    /// The root expression (expected to be an `Expression::Root`).
    constraints: Moo<Expression>,
    /// The top-level symbol table, serialised via `PtrAsInner`.
    #[serde_as(as = "PtrAsInner")]
    symbols: SymbolTablePtr,
    /// CNF clauses, as in `Model`.
    cnf_clauses: Vec<CnfClause>,
    /// Optional ordering of variable names, as in `Model`.
    search_order: Option<Vec<Name>>,
    /// Optional dominance expression, as in `Model`.
    dominance: Option<Expression>,
}
452
453impl SerdeModel {
454 pub fn initialise(mut self, context: Arc<RwLock<Context<'static>>>) -> Option<Model> {
456 let mut tables: HashMap<ObjId, SymbolTablePtr> = HashMap::new();
457
458 tables.insert(self.symbols.id(), self.symbols.clone());
460
461 let mut exprs: VecDeque<Expression> = self.constraints.universe_bi();
462 if let Some(dominance) = &self.dominance {
463 exprs.push_back(dominance.clone());
464 }
465
466 for table in Biplate::<SymbolTablePtr>::universe_bi(&exprs) {
468 tables.entry(table.id()).or_insert(table);
469 }
470
471 for table in tables.clone().into_values() {
472 let mut table_mut = table.write();
473 let parent_mut = table_mut.parent_mut_unchecked();
474
475 #[allow(clippy::unwrap_used)]
476 if let Some(parent) = parent_mut {
477 let parent_id = parent.id();
478 *parent = tables.get(&parent_id).unwrap().clone();
479 }
480 }
481
482 let mut all_declarations: HashMap<ObjId, DeclarationPtr> = HashMap::new();
483 for table in tables.values() {
484 for (_, decl) in table.read().iter_local() {
485 let id = decl.id();
486 all_declarations.insert(id, decl.clone());
487 }
488 }
489
490 self.constraints = self.constraints.transform_bi(&move |decl: DeclarationPtr| {
491 let id = decl.id();
492 all_declarations
493 .get(&id)
494 .unwrap_or_else(|| {
495 panic!(
496 "A declaration used in the expression tree should exist in the symbol table. The missing declaration has id {id}."
497 )
498 })
499 .clone()
500 });
501
502 Some(Model {
503 constraints: self.constraints,
504 symbols: self.symbols,
505 cnf_clauses: self.cnf_clauses,
506 search_order: self.search_order,
507 dominance: self.dominance,
508 context,
509 })
510 }
511}
512
513impl From<Model> for SerdeModel {
514 fn from(val: Model) -> Self {
515 SerdeModel {
516 constraints: val.constraints,
517 symbols: val.symbols,
518 cnf_clauses: val.cnf_clauses,
519 search_order: val.search_order,
520 dominance: val.dominance,
521 }
522 }
523}
524
525impl Display for SerdeModel {
526 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
527 let model = Model {
528 constraints: self.constraints.clone(),
529 symbols: self.symbols.clone(),
530 cnf_clauses: self.cnf_clauses.clone(),
531 search_order: self.search_order.clone(),
532 dominance: self.dominance.clone(),
533 context: default_context(),
534 };
535 std::fmt::Display::fmt(&model, f)
536 }
537}
538
539impl SerdeModel {
540 pub fn collect_stable_id_mapping(&self) -> HashMap<ObjId, ObjId> {
542 let model = Model {
543 constraints: self.constraints.clone(),
544 symbols: self.symbols.clone(),
545 cnf_clauses: self.cnf_clauses.clone(),
546 search_order: self.search_order.clone(),
547 dominance: self.dominance.clone(),
548 context: default_context(),
549 };
550 model.collect_stable_id_mapping()
551 }
552}
553
/// A serialisable snapshot of an expression tree, built by [`ExprInfo::create`].
#[serde_as]
#[derive(Serialize)]
pub struct ExprInfo {
    /// The expression's `Display` output.
    pretty: String,
    /// The result of `Expression::domain_of` for this expression.
    domain: Option<Moo<Domain>>,
    /// Snapshots of the expression's direct children, in order.
    children: Vec<ExprInfo>,
}
562
563impl ExprInfo {
564 pub fn create(expr: &Expression) -> ExprInfo {
565 let pretty = expr.to_string();
566 let domain = expr.domain_of();
567 let children = expr.children().iter().map(Self::create).collect();
568
569 ExprInfo {
570 pretty,
571 domain,
572 children,
573 }
574 }
575}