1
use std::borrow::{Borrow, BorrowMut};
2
use std::cell::RefCell;
3
use std::collections::HashSet;
4
use std::fmt::Debug;
5
use std::sync::{Arc, RwLock};
6

            
7
use derivative::Derivative;
8
use serde::{Deserialize, Serialize};
9
use serde_with::serde_as;
10

            
11
use crate::ast::{DecisionVariable, Domain, Expression, Name, SymbolTable};
12
use crate::context::Context;
13
use crate::metadata::Metadata;
14

            
15
/// Represents a computational model containing variables, constraints, and a shared context.
16
///
17
/// The `Model` struct holds a set of variables and constraints for manipulating and evaluating symbolic expressions.
18
///
19
/// # Fields
20
/// - `variables`:
21
///   - Type: `SymbolTable`
22
///   - A table that links each variable's name to its corresponding `DecisionVariable`.
23
///   - For example, the name `x` might be linked to a `DecisionVariable` that says `x` can only take values between 1 and 10.
24
///
25
/// - `constraints`:
26
///   - Type: `Expression`
27
///   - Represents the logical constraints applied to the model's variables.
28
///   - Can be a single constraint or a combination of various expressions, such as logical operations (e.g., `AND`, `OR`),
29
///     arithmetic operations (e.g., `SafeDiv`, `UnsafeDiv`), or specialized constraints like `SumEq`.
30
///
31
/// - `context`:
32
///   - Type: `Arc<RwLock<Context<'static>>>`
33
///   - A shared object that stores global settings and state for the model.
34
///   - Can be safely read or changed by multiple parts of the program at the same time, making it good for multi-threaded use.
35
///
36
/// - `next_var`:
37
///   - Type: `RefCell<i32>`
38
///   - A counter used to create new, unique variable names.
39
///   - Allows updating the counter inside the model without making the whole model mutable.
40
///
41
/// # Usage
42
/// This struct is typically used to:
43
/// - Define a set of variables and constraints for rule-based evaluation.
44
/// - Have transformations, optimizations, and simplifications applied to it using a set of rules.
45
#[serde_as]
46
1800
#[derive(Derivative, Clone, Debug, Serialize, Deserialize)]
47
#[derivative(PartialEq, Eq)]
48
pub struct Model {
49
    #[serde_as(as = "Vec<(_, _)>")]
50
    pub variables: SymbolTable,
51
    pub constraints: Expression,
52
    #[serde(skip)]
53
    #[derivative(PartialEq = "ignore")]
54
    pub context: Arc<RwLock<Context<'static>>>,
55
    next_var: RefCell<i32>,
56
}
57

            
58
impl Model {
59
2880004
    pub fn new(
60
2880004
        variables: SymbolTable,
61
2880004
        constraints: Expression,
62
2880004
        context: Arc<RwLock<Context<'static>>>,
63
2880004
    ) -> Model {
64
2880004
        Model {
65
2880004
            variables,
66
2880004
            constraints,
67
2880004
            context,
68
2880004
            next_var: RefCell::new(0),
69
2880004
        }
70
2880004
    }
71

            
72
2879664
    pub fn new_empty(context: Arc<RwLock<Context<'static>>>) -> Model {
73
2879664
        Model::new(
74
2879664
            Default::default(),
75
2879664
            Expression::And(Metadata::new(), Vec::new()),
76
2879664
            context,
77
2879664
        )
78
2879664
    }
79
    // Function to update a DecisionVariable based on its Name
80
17
    pub fn update_domain(&mut self, name: &Name, new_domain: Domain) {
81
17
        if let Some(decision_var) = self.variables.get_mut(name) {
82
17
            decision_var.domain = new_domain;
83
17
        }
84
17
    }
85

            
86
289
    pub fn get_domain(&self, name: &Name) -> Option<&Domain> {
87
289
        self.variables.get(name).map(|v| &v.domain)
88
289
    }
89

            
90
    // Function to add a new DecisionVariable to the Model
91
2669
    pub fn add_variable(&mut self, name: Name, decision_var: DecisionVariable) {
92
2669
        self.variables.insert(name, decision_var);
93
2669
    }
94

            
95
1955
    pub fn get_constraints_vec(&self) -> Vec<Expression> {
96
1955
        match &self.constraints {
97
1649
            Expression::And(_, constraints) => constraints.clone(),
98
306
            _ => vec![self.constraints.clone()],
99
        }
100
1955
    }
101

            
102
918
    pub fn set_constraints(&mut self, constraints: Vec<Expression>) {
103
918
        if constraints.is_empty() {
104
            self.constraints = Expression::And(Metadata::new(), Vec::new());
105
918
        } else if constraints.len() == 1 {
106
782
            self.constraints = constraints[0].clone();
107
782
        } else {
108
136
            self.constraints = Expression::And(Metadata::new(), constraints);
109
136
        }
110
918
    }
111

            
112
    pub fn set_context(&mut self, context: Arc<RwLock<Context<'static>>>) {
113
        self.context = context;
114
    }
115

            
116
    pub fn add_constraint(&mut self, expression: Expression) {
117
        // ToDo (gs248) - there is no checking whatsoever
118
        // We need to properly validate the expression but this is just for testing
119
        let mut constraints = self.get_constraints_vec();
120
        constraints.push(expression);
121
        self.set_constraints(constraints);
122
    }
123

            
124
918
    pub fn add_constraints(&mut self, expressions: Vec<Expression>) {
125
918
        let mut constraints = self.get_constraints_vec();
126
918
        constraints.extend(expressions);
127
918
        self.set_constraints(constraints);
128
918
    }
129

            
130
    /// Returns an arbitrary variable name that is not in the model.
131
765
    pub fn gensym(&self) -> Name {
132
765
        let num = *self.next_var.borrow();
133
765
        *(self.next_var.borrow_mut()) += 1;
134
765
        Name::MachineName(num) // incremented when inserted
135
765
    }
136

            
137
    /// Extends the models symbol table with the given symbol table, updating the gensym counter if
138
    /// necessary.
139
    ///
140
22270
    pub fn extend_sym_table(&mut self, symbol_table: SymbolTable) {
141
22270
        if symbol_table.keys().len() > self.variables.keys().len() {
142
340
            let new_vars = symbol_table.keys().collect::<HashSet<_>>();
143
340
            let old_vars = self.variables.keys().collect::<HashSet<_>>();
144

            
145
340
            for added_var in new_vars.difference(&old_vars) {
146
340
                let mut next_var = self.next_var.borrow_mut();
147
340
                match *added_var {
148
                    Name::UserName(_) => {}
149
340
                    Name::MachineName(m) => {
150
340
                        if *m >= *next_var {
151
340
                            *next_var = *m + 1;
152
340
                        }
153
                    }
154
                }
155
            }
156
21930
        }
157
22270
        self.variables.extend(symbol_table);
158
22270
    }
159
}