1
use std::cell::{Ref, RefCell, RefMut};
2
use std::collections::VecDeque;
3
use std::fmt::{Debug, Display};
4
use std::rc::Rc;
5
use std::sync::{Arc, RwLock};
6

            
7
use derivative::Derivative;
8
use serde::{Deserialize, Serialize};
9
use serde_with::serde_as;
10
use uniplate::{Biplate, Tree, Uniplate};
11

            
12
use crate::ast::serde::RcRefCellAsInner;
13
use crate::ast::{Expression, SymbolTable};
14
use crate::bug;
15
use crate::context::Context;
16

            
17
use crate::ast::pretty::{
18
    pretty_domain_letting_declaration, pretty_expressions_as_top_level,
19
    pretty_value_letting_declaration, pretty_variable_declaration,
20
};
21
use crate::metadata::Metadata;
22

            
23
use super::declaration::DeclarationKind;
24
use super::types::Typeable;
25
use super::ReturnType;
26

            
27
/// Represents a computational model containing variables, constraints, and a shared context.
28
///
29
/// To de/serialise a model using serde, see [`SerdeModel`].
30
#[derive(Derivative, Clone, Debug)]
31
#[derivative(PartialEq, Eq)]
32
pub struct Model {
33
    /// Top level constraints. This should be a `Expression::Root`.
34
    constraints: Box<Expression>,
35

            
36
    symbols: Rc<RefCell<SymbolTable>>,
37

            
38
    #[derivative(PartialEq = "ignore")]
39
    pub context: Arc<RwLock<Context<'static>>>,
40
}
41

            
42
impl Model {
43
    /// Creates a new model.
44
3145
    pub fn new(
45
3145
        symbols: Rc<RefCell<SymbolTable>>,
46
3145
        constraints: Vec<Expression>,
47
3145
        context: Arc<RwLock<Context<'static>>>,
48
3145
    ) -> Model {
49
3145
        Model {
50
3145
            symbols,
51
3145
            constraints: Box::new(Expression::Root(Metadata::new(), constraints)),
52
3145
            context,
53
3145
        }
54
3145
    }
55

            
56
3077
    pub fn new_empty(context: Arc<RwLock<Context<'static>>>) -> Model {
57
3077
        Model::new(Default::default(), Vec::new(), context)
58
3077
    }
59

            
60
    /// The symbol table for this model as a pointer.
61
    ///
62
    /// The caller should only mutate the returned symbol table if this method was called on a
63
    /// mutable model.
64
    pub fn symbols_ptr_unchecked(&self) -> &Rc<RefCell<SymbolTable>> {
65
        &self.symbols
66
    }
67

            
68
    /// The global symbol table for this model as a reference.
69
12543382
    pub fn symbols(&self) -> Ref<SymbolTable> {
70
12543382
        (*self.symbols).borrow()
71
12543382
    }
72

            
73
    /// The global symbol table for this model as a mutable reference.
74
21029
    pub fn symbols_mut(&self) -> RefMut<SymbolTable> {
75
21029
        (*self.symbols).borrow_mut()
76
21029
    }
77

            
78
32351
    pub fn get_constraints_vec(&self) -> Vec<Expression> {
79
32351
        match *self.constraints {
80
32351
            Expression::Root(_, ref exprs) => exprs.clone(),
81
            ref e => {
82
                bug!(
83
                    "get_constraints_vec: unexpected top level expression, {} ",
84
                    e
85
                );
86
            }
87
        }
88
32351
    }
89

            
90
15402
    pub fn set_constraints(&mut self, constraints: Vec<Expression>) {
91
15402
        self.constraints = Box::new(Expression::Root(Metadata::new(), constraints));
92
15402
    }
93

            
94
    pub fn set_context(&mut self, context: Arc<RwLock<Context<'static>>>) {
95
        self.context = context;
96
    }
97

            
98
    pub fn add_constraint(&mut self, expression: Expression) {
99
        // TODO (gs248): there is no checking whatsoever
100
        // We need to properly validate the expression but this is just for testing
101
        let mut constraints = self.get_constraints_vec();
102
        constraints.push(expression);
103
        self.set_constraints(constraints);
104
    }
105

            
106
15402
    pub fn add_constraints(&mut self, expressions: Vec<Expression>) {
107
15402
        let mut constraints = self.get_constraints_vec();
108
15402
        constraints.extend(expressions);
109
15402
        self.set_constraints(constraints);
110
15402
    }
111
}
112

            
113
impl Typeable for Model {
114
    fn return_type(&self) -> Option<ReturnType> {
115
        Some(ReturnType::Bool)
116
    }
117
}
118

            
119
// At time of writing (03/02/2025), the Uniplate derive macro doesn't like the lifetimes inside
120
// context, and we do not yet have a way of ignoring this field.
121
impl Uniplate for Model {
122
    fn uniplate(&self) -> (Tree<Self>, Box<dyn Fn(Tree<Self>) -> Self>) {
123
        // Model contains no sub-models.
124
        let self2 = self.clone();
125
        (Tree::Zero, Box::new(move |_| self2.clone()))
126
    }
127
}
128

            
129
// TODO: replace with derive macro when possible.
130
impl Biplate<Expression> for Model {
131
102646
    fn biplate(&self) -> (Tree<Expression>, Box<dyn Fn(Tree<Expression>) -> Self>) {
132
102646
        let (symtab_tree, symtab_ctx) = (*self.symbols).borrow().biplate();
133
102646
        let (constraints_tree, constraints_ctx) = self.constraints.biplate();
134
102646

            
135
102646
        let tree = Tree::Many(VecDeque::from([symtab_tree, constraints_tree]));
136
102646

            
137
102646
        let self2 = self.clone();
138
102646
        let ctx = Box::new(move |tree| {
139
12308
            let Tree::Many(fields) = tree else {
140
                panic!("number of children changed!");
141
            };
142

            
143
12308
            let mut self3 = self2.clone();
144
12308
            {
145
12308
                let mut symbols = (*self3.symbols).borrow_mut();
146
12308
                *symbols = (symtab_ctx)(fields[0].clone());
147
12308
            }
148
12308
            self3.constraints = Box::new((constraints_ctx)(fields[1].clone()));
149
12308
            self3
150
102646
        });
151
102646

            
152
102646
        (tree, ctx)
153
102646
    }
154
}
155

            
156
impl Display for Model {
157
    #[allow(clippy::unwrap_used)] // [rustdocs]: should only fail iff the formatter fails
158
15334
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159
99722
        for (name, decl) in self.symbols().clone().into_iter_local() {
160
99722
            match decl.kind() {
161
                DeclarationKind::DecisionVariable(_) => {
162
98923
                    writeln!(
163
98923
                        f,
164
98923
                        "{}",
165
98923
                        pretty_variable_declaration(&self.symbols(), &name).unwrap()
166
98923
                    )?;
167
                }
168
                DeclarationKind::ValueLetting(_) => {
169
561
                    writeln!(
170
561
                        f,
171
561
                        "{}",
172
561
                        pretty_value_letting_declaration(&self.symbols(), &name).unwrap()
173
561
                    )?;
174
                }
175
                DeclarationKind::DomainLetting(_) => {
176
238
                    writeln!(
177
238
                        f,
178
238
                        "{}",
179
238
                        pretty_domain_letting_declaration(&self.symbols(), &name).unwrap()
180
238
                    )?;
181
                }
182
            }
183
        }
184

            
185
15334
        writeln!(f, "\nsuch that\n")?;
186

            
187
15334
        writeln!(
188
15334
            f,
189
15334
            "{}",
190
15334
            pretty_expressions_as_top_level(&self.get_constraints_vec())
191
15334
        )?;
192

            
193
15334
        Ok(())
194
15334
    }
195
}
196

            
197
/// A model that is de/serializable using `serde`.
198
///
199
/// To turn this into a rewritable model, it needs to be initialised using [`initialise`](SerdeModel::initialise).
200
///
201
/// To deserialise a [`Model`], use `.into()` to convert it into a `SerdeModel` first.
202
#[serde_as]
203
#[derive(Clone, Debug, Serialize, Deserialize)]
204
pub struct SerdeModel {
205
    constraints: Box<Expression>,
206

            
207
    #[serde_as(as = "RcRefCellAsInner")]
208
    symbols: Rc<RefCell<SymbolTable>>,
209
}
210

            
211
impl SerdeModel {
212
    /// Initialises the model for rewriting.
213
4539
    pub fn initialise(self, context: Arc<RwLock<Context<'static>>>) -> Option<Model> {
214
4539
        // TODO: Once we have submodels and multiple symbol tables, de-duplicate deserialized
215
4539
        // Rc<RefCell<>> symbol tables and declarations using their stored ids.
216
4539
        //
217
4539
        // See ast::serde::RcRefCellAsId.
218
4539
        Some(Model {
219
4539
            constraints: self.constraints,
220
4539
            symbols: self.symbols,
221
4539
            context,
222
4539
        })
223
4539
    }
224
}
225

            
226
impl From<Model> for SerdeModel {
227
4539
    fn from(val: Model) -> Self {
228
4539
        SerdeModel {
229
4539
            constraints: val.constraints,
230
4539
            symbols: val.symbols,
231
4539
        }
232
4539
    }
233
}