1
use std::cell::RefCell;
2
use std::collections::HashSet;
3
use std::fmt::Debug;
4
use std::sync::{Arc, RwLock};
5

            
6
use derivative::Derivative;
7
use serde::{Deserialize, Serialize};
8
use serde_with::serde_as;
9

            
10
use crate::ast::{DecisionVariable, Domain, Expression, Name, SymbolTable};
11
use crate::context::Context;
12
use crate::metadata::Metadata;
13

            
14
/// Represents a computational model containing variables, constraints, and a shared context.
15
///
16
/// The `Model` struct holds a set of variables and constraints for manipulating and evaluating symbolic expressions.
17
///
18
/// # Fields
19
/// - `variables`:
20
///   - Type: `SymbolTable`
21
///   - A table that links each variable's name to its corresponding `DecisionVariable`.
22
///   - For example, the name `x` might be linked to a `DecisionVariable` that says `x` can only take values between 1 and 10.
23
///
24
/// - `constraints`:
25
///   - Type: `Expression`
26
///   - Represents the logical constraints applied to the model's variables.
27
///   - Can be a single constraint or a combination of various expressions, such as logical operations (e.g., `AND`, `OR`),
28
///     arithmetic operations (e.g., `SafeDiv`, `UnsafeDiv`), or specialized constraints like `SumEq`.
29
///
30
/// - `context`:
31
///   - Type: `Arc<RwLock<Context<'static>>>`
32
///   - A shared object that stores global settings and state for the model.
33
///   - Can be safely read or changed by multiple parts of the program at the same time, making it good for multi-threaded use.
34
///
35
/// - `next_var`:
36
///   - Type: `RefCell<i32>`
37
///   - A counter used to create new, unique variable names.
38
///   - Allows updating the counter inside the model without making the whole model mutable.
39
///
40
/// # Usage
41
/// This struct is typically used to:
42
/// - Define a set of variables and constraints for rule-based evaluation.
43
/// - Have transformations, optimizations, and simplifications applied to it using a set of rules.
44
#[serde_as]
45
354
#[derive(Derivative, Clone, Debug, Serialize, Deserialize)]
46
#[derivative(PartialEq, Eq)]
47
pub struct Model {
48
    #[serde_as(as = "Vec<(_, _)>")]
49
    pub variables: SymbolTable,
50
    pub constraints: Vec<Expression>,
51
    #[serde(skip)]
52
    #[derivative(PartialEq = "ignore")]
53
    pub context: Arc<RwLock<Context<'static>>>,
54
    next_var: RefCell<i32>,
55
}
56

            
57
impl Model {
58
468
    pub fn new(
59
468
        variables: SymbolTable,
60
468
        constraints: Vec<Expression>,
61
468
        context: Arc<RwLock<Context<'static>>>,
62
468
    ) -> Model {
63
468
        Model {
64
468
            variables,
65
468
            constraints,
66
468
            context,
67
468
            next_var: RefCell::new(0),
68
468
        }
69
468
    }
70

            
71
468
    pub fn new_empty(context: Arc<RwLock<Context<'static>>>) -> Model {
72
468
        Model::new(Default::default(), Vec::new(), context)
73
468
    }
74
    // Function to update a DecisionVariable based on its Name
75
    pub fn update_domain(&mut self, name: &Name, new_domain: Domain) {
76
        if let Some(decision_var) = self.variables.get_mut(name) {
77
            decision_var.domain = new_domain;
78
        }
79
    }
80

            
81
    pub fn get_domain(&self, name: &Name) -> Option<&Domain> {
82
        self.variables.get(name).map(|v| &v.domain)
83
    }
84

            
85
    // Function to add a new DecisionVariable to the Model
86
1089
    pub fn add_variable(&mut self, name: Name, decision_var: DecisionVariable) {
87
1089
        self.variables.insert(name, decision_var);
88
1089
    }
89

            
90
477
    pub fn get_constraints_vec(&self) -> Vec<Expression> {
91
477
        self.constraints.clone()
92
477
    }
93

            
94
477
    pub fn set_constraints(&mut self, constraints: Vec<Expression>) {
95
477
        if constraints.is_empty() {
96
            self.constraints = Vec::new();
97
477
        } else {
98
477
            self.constraints = constraints;
99
477
        }
100
477
    }
101

            
102
    pub fn set_context(&mut self, context: Arc<RwLock<Context<'static>>>) {
103
        self.context = context;
104
    }
105

            
106
    pub fn add_constraint(&mut self, expression: Expression) {
107
        // ToDo (gs248) - there is no checking whatsoever
108
        // We need to properly validate the expression but this is just for testing
109
        let mut constraints = self.get_constraints_vec();
110
        constraints.push(expression);
111
        self.set_constraints(constraints);
112
    }
113

            
114
477
    pub fn add_constraints(&mut self, expressions: Vec<Expression>) {
115
477
        let mut constraints = self.get_constraints_vec();
116
477
        constraints.extend(expressions);
117
477
        self.set_constraints(constraints);
118
477
    }
119

            
120
    /// Returns an arbitrary variable name that is not in the model.
121
    pub fn gensym(&self) -> Name {
122
        let num = *self.next_var.borrow();
123
        *(self.next_var.borrow_mut()) += 1;
124
        Name::MachineName(num) // incremented when inserted
125
    }
126

            
127
    /// Extends the models symbol table with the given symbol table, updating the gensym counter if
128
    /// necessary.
129
    ///
130
    pub fn extend_sym_table(&mut self, symbol_table: SymbolTable) {
131
        if symbol_table.keys().len() > self.variables.keys().len() {
132
            let new_vars = symbol_table.keys().collect::<HashSet<_>>();
133
            let old_vars = self.variables.keys().collect::<HashSet<_>>();
134

            
135
            for added_var in new_vars.difference(&old_vars) {
136
                let mut next_var = self.next_var.borrow_mut();
137
                match *added_var {
138
                    Name::UserName(_) => {}
139
                    Name::MachineName(m) => {
140
                        if *m >= *next_var {
141
                            *next_var = *m + 1;
142
                        }
143
                    }
144
                }
145
            }
146
        }
147
        self.variables.extend(symbol_table);
148
    }
149
}