1
use std::fmt::{Debug, Display};
2
use std::sync::{Arc, RwLock};
3

            
4
use derivative::Derivative;
5
use serde::{Deserialize, Serialize};
6

            
7
use crate::ast::{DecisionVariable, Domain, Expression, Name, SymbolTable};
8
use crate::context::Context;
9

            
10
use crate::ast::pretty::{pretty_expressions_as_top_level, pretty_variable_declaration};
11

            
12
/// Represents a computational model containing variables, constraints, and a shared context.
13
///
14
/// The `Model` struct holds a set of variables and constraints for manipulating and evaluating symbolic expressions.
15
///
16
/// # Fields
17
/// - `constraints`:
18
///   - Type: `Vec<Expression>`
19
///   - Represents the logical constraints applied to the model's variables.
20
///   - Can be a single constraint or a combination of various expressions, such as logical operations (e.g., `AND`, `OR`),
21
///     arithmetic operations (e.g., `SafeDiv`, `UnsafeDiv`), or specialized constraints like `SumEq`.
22
///
23
/// - `context`:
24
///   - Type: `Arc<RwLock<Context<'static>>>`
25
///   - A shared object that stores global settings and state for the model.
26
///   - Can be safely read or changed by multiple parts of the program at the same time, making it good for multi-threaded use.
27
///
28
/// # Usage
29
/// This struct is typically used to:
30
/// - Define a set of variables and constraints for rule-based evaluation.
31
/// - Have transformations, optimizations, and simplifications applied to it using a set of rules.
32
#[derive(Derivative, Clone, Debug, Serialize, Deserialize)]
33
#[derivative(PartialEq, Eq)]
34
pub struct Model {
35
    pub constraints: Vec<Expression>,
36

            
37
    symbols: SymbolTable,
38

            
39
    #[serde(skip)]
40
    #[derivative(PartialEq = "ignore")]
41
    pub context: Arc<RwLock<Context<'static>>>,
42
}
43

            
44
impl Model {
45
    /// Creates a new model.
46
6307
    pub fn new(
47
6307
        symbols: SymbolTable,
48
6307
        constraints: Vec<Expression>,
49
6307
        context: Arc<RwLock<Context<'static>>>,
50
6307
    ) -> Model {
51
6307
        Model {
52
6307
            symbols,
53
6307
            constraints,
54
6307
            context,
55
6307
        }
56
6307
    }
57

            
58
6239
    pub fn new_empty(context: Arc<RwLock<Context<'static>>>) -> Model {
59
6239
        Model::new(Default::default(), Vec::new(), context)
60
6239
    }
61

            
62
    /// The global symbol table for this model.
63
18853
    pub fn symbols(&self) -> &SymbolTable {
64
18853
        &self.symbols
65
18853
    }
66

            
67
    /// The global symbol table for this model, as a mutable reference.
68
20009
    pub fn symbols_mut(&mut self) -> &mut SymbolTable {
69
20009
        &mut self.symbols
70
20009
    }
71

            
72
    // Function to update a DecisionVariable based on its Name
73
17
    pub fn update_domain(&mut self, name: &Name, new_domain: Domain) {
74
17
        if let Some(decision_var) = self.symbols_mut().get_var_mut(name) {
75
17
            decision_var.domain = new_domain;
76
17
        }
77
17
    }
78

            
79
    /// Gets the domain of `name` if it exists and has one.
80
85
    pub fn get_domain(&self, name: &Name) -> Option<&Domain> {
81
85
        self.symbols().domain_of(name)
82
85
    }
83

            
84
    /// Adds a decision variable to the model.
85
    ///
86
    /// Returns `None` if there is a decision variable or other object with that name in the symbol
87
    /// table.
88
8177
    pub fn add_variable(&mut self, name: Name, decision_var: DecisionVariable) -> Option<()> {
89
8177
        self.symbols_mut().add_var(name, decision_var)
90
8177
    }
91

            
92
92174
    pub fn get_constraints_vec(&self) -> Vec<Expression> {
93
92174
        self.constraints.clone()
94
92174
    }
95

            
96
13243
    pub fn set_constraints(&mut self, constraints: Vec<Expression>) {
97
13243
        if constraints.is_empty() {
98
            self.constraints = Vec::new();
99
13243
        } else {
100
13243
            self.constraints = constraints;
101
13243
        }
102
13243
    }
103

            
104
    pub fn set_context(&mut self, context: Arc<RwLock<Context<'static>>>) {
105
        self.context = context;
106
    }
107

            
108
    pub fn add_constraint(&mut self, expression: Expression) {
109
        // ToDo (gs248) - there is no checking whatsoever
110
        // We need to properly validate the expression but this is just for testing
111
        let mut constraints = self.get_constraints_vec();
112
        constraints.push(expression);
113
        self.set_constraints(constraints);
114
    }
115

            
116
1496
    pub fn add_constraints(&mut self, expressions: Vec<Expression>) {
117
1496
        let mut constraints = self.get_constraints_vec();
118
1496
        constraints.extend(expressions);
119
1496
        self.set_constraints(constraints);
120
1496
    }
121

            
122
    /// Returns an arbitrary variable name that is not in the model.
123
1972
    pub fn gensym(&self) -> Name {
124
1972
        self.symbols().gensym()
125
1972
    }
126

            
127
    /// Extends the models symbol table with the given symbol table, updating the gensym counter if
128
    /// necessary.
129
11815
    pub fn extend_sym_table(&mut self, other: SymbolTable) {
130
11815
        self.symbols_mut().extend(other);
131
11815
    }
132
}
133

            
134
impl Display for Model {
135
    #[allow(clippy::unwrap_used)] // [rustdocs]: should only fail iff the formatter fails
136
2856
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
137
9044
        for name in self.symbols.names() {
138
9044
            writeln!(
139
9044
                f,
140
9044
                "find {}",
141
9044
                pretty_variable_declaration(&self.symbols, name).unwrap()
142
9044
            )?;
143
        }
144

            
145
2856
        writeln!(f, "\nsuch that\n")?;
146

            
147
2856
        writeln!(f, "{}", pretty_expressions_as_top_level(&self.constraints))?;
148

            
149
2856
        Ok(())
150
2856
    }
151
}