1
use std::cell::RefCell;
2
use std::collections::HashSet;
3
use std::fmt::Debug;
4
use std::sync::{Arc, RwLock};
5

            
6
use derivative::Derivative;
7
use serde::{Deserialize, Serialize};
8
use serde_with::serde_as;
9

            
10
use crate::ast::{DecisionVariable, Domain, Expression, Name, SymbolTable};
11
use crate::context::Context;
12
use crate::metadata::Metadata;
13

            
14
/// Represents a computational model containing variables, constraints, and a shared context.
15
///
16
/// The `Model` struct holds a set of variables and constraints for manipulating and evaluating symbolic expressions.
17
///
18
/// # Fields
19
/// - `variables`:
20
///   - Type: `SymbolTable`
21
///   - A table that links each variable's name to its corresponding `DecisionVariable`.
22
///   - For example, the name `x` might be linked to a `DecisionVariable` that says `x` can only take values between 1 and 10.
23
///
24
/// - `constraints`:
25
///   - Type: `Expression`
26
///   - Represents the logical constraints applied to the model's variables.
27
///   - Can be a single constraint or a combination of various expressions, such as logical operations (e.g., `AND`, `OR`),
28
///     arithmetic operations (e.g., `SafeDiv`, `UnsafeDiv`), or specialized constraints like `SumEq`.
29
///
30
/// - `context`:
31
///   - Type: `Arc<RwLock<Context<'static>>>`
32
///   - A shared object that stores global settings and state for the model.
33
///   - Can be safely read or changed by multiple parts of the program at the same time, making it good for multi-threaded use.
34
///
35
/// - `next_var`:
36
///   - Type: `RefCell<i32>`
37
///   - A counter used to create new, unique variable names.
38
///   - Allows updating the counter inside the model without making the whole model mutable.
39
///
40
/// # Usage
41
/// This struct is typically used to:
42
/// - Define a set of variables and constraints for rule-based evaluation.
43
/// - Have transformations, optimizations, and simplifications applied to it using a set of rules.
44
#[serde_as]
45
1800
#[derive(Derivative, Clone, Debug, Serialize, Deserialize)]
46
#[derivative(PartialEq, Eq)]
47
pub struct Model {
48
    #[serde_as(as = "Vec<(_, _)>")]
49
    pub variables: SymbolTable,
50
    pub constraints: Expression,
51
    #[serde(skip)]
52
    #[derivative(PartialEq = "ignore")]
53
    pub context: Arc<RwLock<Context<'static>>>,
54
    next_var: RefCell<i32>,
55
}
56

            
57
impl Model {
58
2880004
    pub fn new(
59
2880004
        variables: SymbolTable,
60
2880004
        constraints: Expression,
61
2880004
        context: Arc<RwLock<Context<'static>>>,
62
2880004
    ) -> Model {
63
2880004
        Model {
64
2880004
            variables,
65
2880004
            constraints,
66
2880004
            context,
67
2880004
            next_var: RefCell::new(0),
68
2880004
        }
69
2880004
    }
70

            
71
2879664
    pub fn new_empty(context: Arc<RwLock<Context<'static>>>) -> Model {
72
2879664
        Model::new(
73
2879664
            Default::default(),
74
2879664
            Expression::And(Metadata::new(), Vec::new()),
75
2879664
            context,
76
2879664
        )
77
2879664
    }
78
    // Function to update a DecisionVariable based on its Name
79
17
    pub fn update_domain(&mut self, name: &Name, new_domain: Domain) {
80
17
        if let Some(decision_var) = self.variables.get_mut(name) {
81
17
            decision_var.domain = new_domain;
82
17
        }
83
17
    }
84

            
85
289
    pub fn get_domain(&self, name: &Name) -> Option<&Domain> {
86
289
        self.variables.get(name).map(|v| &v.domain)
87
289
    }
88

            
89
    // Function to add a new DecisionVariable to the Model
90
2669
    pub fn add_variable(&mut self, name: Name, decision_var: DecisionVariable) {
91
2669
        self.variables.insert(name, decision_var);
92
2669
    }
93

            
94
1955
    pub fn get_constraints_vec(&self) -> Vec<Expression> {
95
1955
        match &self.constraints {
96
1649
            Expression::And(_, constraints) => constraints.clone(),
97
306
            _ => vec![self.constraints.clone()],
98
        }
99
1955
    }
100

            
101
918
    pub fn set_constraints(&mut self, constraints: Vec<Expression>) {
102
918
        if constraints.is_empty() {
103
            self.constraints = Expression::And(Metadata::new(), Vec::new());
104
918
        } else if constraints.len() == 1 {
105
782
            self.constraints = constraints[0].clone();
106
782
        } else {
107
136
            self.constraints = Expression::And(Metadata::new(), constraints);
108
136
        }
109
918
    }
110

            
111
    pub fn set_context(&mut self, context: Arc<RwLock<Context<'static>>>) {
112
        self.context = context;
113
    }
114

            
115
    pub fn add_constraint(&mut self, expression: Expression) {
116
        // ToDo (gs248) - there is no checking whatsoever
117
        // We need to properly validate the expression but this is just for testing
118
        let mut constraints = self.get_constraints_vec();
119
        constraints.push(expression);
120
        self.set_constraints(constraints);
121
    }
122

            
123
918
    pub fn add_constraints(&mut self, expressions: Vec<Expression>) {
124
918
        let mut constraints = self.get_constraints_vec();
125
918
        constraints.extend(expressions);
126
918
        self.set_constraints(constraints);
127
918
    }
128

            
129
    /// Returns an arbitrary variable name that is not in the model.
130
765
    pub fn gensym(&self) -> Name {
131
765
        let num = *self.next_var.borrow();
132
765
        *(self.next_var.borrow_mut()) += 1;
133
765
        Name::MachineName(num) // incremented when inserted
134
765
    }
135

            
136
    /// Extends the models symbol table with the given symbol table, updating the gensym counter if
137
    /// necessary.
138
    ///
139
22270
    pub fn extend_sym_table(&mut self, symbol_table: SymbolTable) {
140
22270
        if symbol_table.keys().len() > self.variables.keys().len() {
141
340
            let new_vars = symbol_table.keys().collect::<HashSet<_>>();
142
340
            let old_vars = self.variables.keys().collect::<HashSet<_>>();
143

            
144
340
            for added_var in new_vars.difference(&old_vars) {
145
340
                let mut next_var = self.next_var.borrow_mut();
146
340
                match *added_var {
147
                    Name::UserName(_) => {}
148
340
                    Name::MachineName(m) => {
149
340
                        if *m >= *next_var {
150
340
                            *next_var = *m + 1;
151
340
                        }
152
                    }
153
                }
154
            }
155
21930
        }
156
22270
        self.variables.extend(symbol_table);
157
22270
    }
158
}