1
use std::borrow::{Borrow, BorrowMut};
2
use std::cell::RefCell;
3
use std::collections::HashSet;
4
use std::fmt::Debug;
5
use std::sync::{Arc, RwLock};
6

            
7
use derivative::Derivative;
8
use serde::{Deserialize, Serialize};
9
use serde_with::serde_as;
10

            
11
use crate::ast::{DecisionVariable, Domain, Expression, Name, SymbolTable};
12
use crate::context::Context;
13
use crate::metadata::Metadata;
14

            
15
/// Represents a computational model containing variables, constraints, and a shared context.
16
///
17
/// The `Model` struct holds a set of variables and constraints for manipulating and evaluating symbolic expressions.
18
///
19
/// # Fields
20
/// - `variables`:
21
///   - Type: `SymbolTable`
22
///   - A table that links each variable's name to its corresponding `DecisionVariable`.
23
///   - For example, the name `x` might be linked to a `DecisionVariable` that says `x` can only take values between 1 and 10.
24
///
25
/// - `constraints`:
26
///   - Type: `Expression`
27
///   - Represents the logical constraints applied to the model's variables.
28
///   - Can be a single constraint or a combination of various expressions, such as logical operations (e.g., `AND`, `OR`),
29
///     arithmetic operations (e.g., `SafeDiv`, `UnsafeDiv`), or specialized constraints like `SumEq`.
30
///
31
/// - `context`:
32
///   - Type: `Arc<RwLock<Context<'static>>>`
33
///   - A shared object that stores global settings and state for the model.
34
///   - Can be safely read or changed by multiple parts of the program at the same time, making it good for multi-threaded use.
35
///
36
/// - `next_var`:
37
///   - Type: `RefCell<i32>`
38
///   - A counter used to create new, unique variable names.
39
///   - Allows updating the counter inside the model without making the whole model mutable.
40
///
41
/// # Usage
42
/// This struct is typically used to:
43
/// - Define a set of variables and constraints for rule-based evaluation.
44
/// - Have transformations, optimizations, and simplifications applied to it using a set of rules.
45
#[serde_as]
46
1404
#[derive(Derivative, Clone, Debug, Serialize, Deserialize)]
47
#[derivative(PartialEq, Eq)]
48
pub struct Model {
49
    #[serde_as(as = "Vec<(_, _)>")]
50
    pub variables: SymbolTable,
51
    pub constraints: Expression,
52
    #[serde(skip)]
53
    #[derivative(PartialEq = "ignore")]
54
    pub context: Arc<RwLock<Context<'static>>>,
55
    next_var: RefCell<i32>,
56
}
57

            
58
impl Model {
59
2735895
    pub fn new(
60
2735895
        variables: SymbolTable,
61
2735895
        constraints: Expression,
62
2735895
        context: Arc<RwLock<Context<'static>>>,
63
2735895
    ) -> Model {
64
2735895
        Model {
65
2735895
            variables,
66
2735895
            constraints,
67
2735895
            context,
68
2735895
            next_var: RefCell::new(0),
69
2735895
        }
70
2735895
    }
71

            
72
2735555
    pub fn new_empty(context: Arc<RwLock<Context<'static>>>) -> Model {
73
2735555
        Model::new(
74
2735555
            Default::default(),
75
2735555
            Expression::And(Metadata::new(), Vec::new()),
76
2735555
            context,
77
2735555
        )
78
2735555
    }
79
    // Function to update a DecisionVariable based on its Name
80
17
    pub fn update_domain(&mut self, name: &Name, new_domain: Domain) {
81
17
        if let Some(decision_var) = self.variables.get_mut(name) {
82
17
            decision_var.domain = new_domain;
83
17
        }
84
17
    }
85

            
86
289
    pub fn get_domain(&self, name: &Name) -> Option<&Domain> {
87
289
        self.variables.get(name).map(|v| &v.domain)
88
289
    }
89

            
90
    // Function to add a new DecisionVariable to the Model
91
1921
    pub fn add_variable(&mut self, name: Name, decision_var: DecisionVariable) {
92
1921
        self.variables.insert(name, decision_var);
93
1921
    }
94

            
95
1581
    pub fn get_constraints_vec(&self) -> Vec<Expression> {
96
1581
        match &self.constraints {
97
1292
            Expression::And(_, constraints) => constraints.clone(),
98
289
            _ => vec![self.constraints.clone()],
99
        }
100
1581
    }
101

            
102
731
    pub fn set_constraints(&mut self, constraints: Vec<Expression>) {
103
731
        if constraints.is_empty() {
104
            self.constraints = Expression::And(Metadata::new(), Vec::new());
105
731
        } else if constraints.len() == 1 {
106
595
            self.constraints = constraints[0].clone();
107
595
        } else {
108
136
            self.constraints = Expression::And(Metadata::new(), constraints);
109
136
        }
110
731
    }
111

            
112
    pub fn set_context(&mut self, context: Arc<RwLock<Context<'static>>>) {
113
        self.context = context;
114
    }
115

            
116
    pub fn add_constraint(&mut self, expression: Expression) {
117
        // ToDo (gs248) - there is no checking whatsoever
118
        // We need to properly validate the expression but this is just for testing
119
        let mut constraints = self.get_constraints_vec();
120
        constraints.push(expression);
121
        self.set_constraints(constraints);
122
    }
123

            
124
731
    pub fn add_constraints(&mut self, expressions: Vec<Expression>) {
125
731
        let mut constraints = self.get_constraints_vec();
126
731
        constraints.extend(expressions);
127
731
        self.set_constraints(constraints);
128
731
    }
129

            
130
    /// Returns an arbitrary variable name that is not in the model.
131
476
    pub fn gensym(&self) -> Name {
132
476
        let num = *self.next_var.borrow();
133
476
        *(self.next_var.borrow_mut()) += 1;
134
476
        Name::MachineName(num) // incremented when inserted
135
476
    }
136

            
137
    /// Extends the models symbol table with the given symbol table, updating the gensym counter if
138
    /// necessary.
139
    ///
140
20570
    pub fn extend_sym_table(&mut self, symbol_table: SymbolTable) {
141
20570
        if symbol_table.keys().len() > self.variables.keys().len() {
142
170
            let new_vars = symbol_table.keys().collect::<HashSet<_>>();
143
170
            let old_vars = self.variables.keys().collect::<HashSet<_>>();
144

            
145
170
            for added_var in new_vars.difference(&old_vars) {
146
170
                let mut next_var = self.next_var.borrow_mut();
147
170
                match *added_var {
148
                    Name::UserName(_) => {}
149
170
                    Name::MachineName(m) => {
150
170
                        if *m >= *next_var {
151
170
                            *next_var = *m + 1;
152
170
                        }
153
                    }
154
                }
155
            }
156
20400
        }
157
20570
        self.variables.extend(symbol_table);
158
20570
    }
159
}