1
use doc_solver_support::doc_solver_support;
2
use serde::{Deserialize, Serialize};
3
use serde_with::serde_as;
4
use std::collections::HashMap;
5
use std::fmt::{Debug, Display, Formatter};
6

            
7
#[serde_as]
8
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
9
pub struct Model {
10
    #[serde_as(as = "Vec<(_, _)>")]
11
    pub variables: HashMap<Name, DecisionVariable>,
12
    pub constraints: Expression,
13
}
14

            
15
impl Model {
16
63
    pub fn new() -> Model {
17
63
        Model {
18
63
            variables: HashMap::new(),
19
63
            constraints: Expression::Nothing,
20
63
        }
21
63
    }
22
    // Function to update a DecisionVariable based on its Name
23
    pub fn update_domain(&mut self, name: &Name, new_domain: Domain) {
24
        if let Some(decision_var) = self.variables.get_mut(name) {
25
            decision_var.domain = new_domain;
26
        }
27
    }
28
    // Function to add a new DecisionVariable to the Model
29
119
    pub fn add_variable(&mut self, name: Name, decision_var: DecisionVariable) {
30
119
        self.variables.insert(name, decision_var);
31
119
    }
32

            
33
147
    pub fn get_constraints_vec(&self) -> Vec<Expression> {
34
147
        match &self.constraints {
35
28
            Expression::And(constraints) => constraints.clone(),
36
63
            Expression::Nothing => vec![],
37
56
            _ => vec![self.constraints.clone()],
38
        }
39
147
    }
40

            
41
84
    pub fn set_constraints(&mut self, constraints: Vec<Expression>) {
42
84
        if constraints.is_empty() {
43
            self.constraints = Expression::Nothing;
44
84
        } else if constraints.len() == 1 {
45
63
            self.constraints = constraints[0].clone();
46
63
        } else {
47
21
            self.constraints = Expression::And(constraints);
48
21
        }
49
84
    }
50

            
51
77
    pub fn add_constraint(&mut self, expression: Expression) {
52
77
        // ToDo (gs248) - there is no checking whatsoever
53
77
        // We need to properly validate the expression but this is just for testing
54
77
        let mut constraints = self.get_constraints_vec();
55
77
        constraints.push(expression);
56
77
        self.set_constraints(constraints);
57
77
    }
58

            
59
7
    pub fn add_constraints(&mut self, expressions: Vec<Expression>) {
60
7
        let mut constraints = self.get_constraints_vec();
61
7
        constraints.extend(expressions);
62
7
        self.set_constraints(constraints);
63
7
    }
64
}
65

            
66
impl Default for Model {
67
    fn default() -> Self {
68
        Self::new()
69
    }
70
}
71

            
72
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
73
pub enum Name {
74
    UserName(String),
75
    MachineName(i32),
76
}
77

            
78
impl Display for Name {
79
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80
        match self {
81
            Name::UserName(s) => write!(f, "UserName({})", s),
82
            Name::MachineName(i) => write!(f, "MachineName({})", i),
83
        }
84
    }
85
}
86

            
87
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
88
pub struct DecisionVariable {
89
    pub domain: Domain,
90
}
91

            
92
impl Display for DecisionVariable {
93
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94
        match &self.domain {
95
            Domain::BoolDomain => write!(f, "bool"),
96
            Domain::IntDomain(ranges) => {
97
                let mut first = true;
98
                for r in ranges {
99
                    if first {
100
                        first = false;
101
                    } else {
102
                        write!(f, " or ")?;
103
                    }
104
                    match r {
105
                        Range::Single(i) => write!(f, "{}", i)?,
106
                        Range::Bounded(i, j) => write!(f, "{}..{}", i, j)?,
107
                    }
108
                }
109
                Ok(())
110
            }
111
        }
112
    }
113
}
114

            
115
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
116
pub enum Domain {
117
    BoolDomain,
118
    IntDomain(Vec<Range<i32>>),
119
}
120

            
121
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
122
pub enum Range<A> {
123
    Single(A),
124
    Bounded(A, A),
125
}
126

            
127
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
128
pub enum Constant {
129
    Int(i32),
130
    Bool(bool),
131
}
132

            
133
impl TryFrom<Constant> for i32 {
134
    type Error = &'static str;
135

            
136
    fn try_from(value: Constant) -> Result<Self, Self::Error> {
137
        match value {
138
            Constant::Int(i) => Ok(i),
139
            _ => Err("Cannot convert non-i32 Constant to i32"),
140
        }
141
    }
142
}
143
impl TryFrom<Constant> for bool {
144
    type Error = &'static str;
145

            
146
    fn try_from(value: Constant) -> Result<Self, Self::Error> {
147
        match value {
148
            Constant::Bool(b) => Ok(b),
149
            _ => Err("Cannot convert non-bool Constant to bool"),
150
        }
151
    }
152
}
153

            
154
#[doc_solver_support]
155
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
156
#[non_exhaustive]
157
pub enum Expression {
158
    /**
159
     * Represents an empty expression
160
     * NB: we only expect this at the top level of a model (if there is no constraints)
161
     */
162
    Nothing,
163

            
164
    #[solver(Minion, SAT)]
165
    Constant(Constant),
166

            
167
    #[solver(Minion)]
168
    Reference(Name),
169

            
170
    Sum(Vec<Expression>),
171

            
172
    #[solver(SAT)]
173
    Not(Box<Expression>),
174
    #[solver(SAT)]
175
    Or(Vec<Expression>),
176
    #[solver(SAT)]
177
    And(Vec<Expression>),
178

            
179
    Eq(Box<Expression>, Box<Expression>),
180
    Neq(Box<Expression>, Box<Expression>),
181
    Geq(Box<Expression>, Box<Expression>),
182
    Leq(Box<Expression>, Box<Expression>),
183
    Gt(Box<Expression>, Box<Expression>),
184
    Lt(Box<Expression>, Box<Expression>),
185

            
186
    /* Flattened SumEq.
187
     *
188
     * Note: this is an intermediary step that's used in the process of converting from conjure model to minion.
189
     * This is NOT a valid expression in either Essence or minion.
190
     *
191
     * ToDo: This is a stop gap solution. Eventually it may be better to have multiple constraints instead? (gs248)
192
     */
193
    SumEq(Vec<Expression>, Box<Expression>),
194

            
195
    // Flattened Constraints
196
    #[solver(Minion)]
197
    SumGeq(Vec<Expression>, Box<Expression>),
198
    #[solver(Minion)]
199
    SumLeq(Vec<Expression>, Box<Expression>),
200
    #[solver(Minion)]
201
    Ineq(Box<Expression>, Box<Expression>, Box<Expression>),
202
}
203

            
204
impl Expression {
205
    /**
206
     * Returns a vector of references to the sub-expressions of the expression.
207
     * If the expression is a primitive (variable, constant, etc.), returns None.
208
     *
209
     * Note: If the expression is NOT MEANT TO have sub-expressions, this function will return None.
210
     * Otherwise, it will return Some(Vec), where the Vec can be empty.
211
     */
212
    pub fn sub_expressions(&self) -> Option<Vec<&Expression>> {
213
        fn unwrap_flat_expression<'a>(
214
            lhs: &'a Vec<Expression>,
215
            rhs: &'a Box<Expression>,
216
        ) -> Vec<&'a Expression> {
217
            let mut sub_exprs = lhs.iter().collect::<Vec<_>>();
218
            sub_exprs.push(rhs.as_ref());
219
            sub_exprs
220
        }
221

            
222
        match self {
223
            Expression::Constant(_) => None,
224
            Expression::Reference(_) => None,
225
            Expression::Nothing => None,
226
            Expression::Sum(exprs) => Some(exprs.iter().collect()),
227
            Expression::Not(expr_box) => Some(vec![expr_box.as_ref()]),
228
            Expression::Or(exprs) => Some(exprs.iter().collect()),
229
            Expression::And(exprs) => Some(exprs.iter().collect()),
230
            Expression::Eq(lhs, rhs) => Some(vec![lhs.as_ref(), rhs.as_ref()]),
231
            Expression::Neq(lhs, rhs) => Some(vec![lhs.as_ref(), rhs.as_ref()]),
232
            Expression::Geq(lhs, rhs) => Some(vec![lhs.as_ref(), rhs.as_ref()]),
233
            Expression::Leq(lhs, rhs) => Some(vec![lhs.as_ref(), rhs.as_ref()]),
234
            Expression::Gt(lhs, rhs) => Some(vec![lhs.as_ref(), rhs.as_ref()]),
235
            Expression::Lt(lhs, rhs) => Some(vec![lhs.as_ref(), rhs.as_ref()]),
236
            Expression::SumGeq(lhs, rhs) => Some(unwrap_flat_expression(lhs, rhs)),
237
            Expression::SumLeq(lhs, rhs) => Some(unwrap_flat_expression(lhs, rhs)),
238
            Expression::SumEq(lhs, rhs) => Some(unwrap_flat_expression(lhs, rhs)),
239
            Expression::Ineq(lhs, rhs, _) => Some(vec![lhs.as_ref(), rhs.as_ref()]),
240
        }
241
    }
242

            
243
    /// Returns a clone of the same expression type with the given sub-expressions.
244
    pub fn with_sub_expressions(&self, sub: Vec<&Expression>) -> Expression {
245
        match self {
246
            Expression::Constant(c) => Expression::Constant(c.clone()),
247
            Expression::Reference(name) => Expression::Reference(name.clone()),
248
            Expression::Nothing => Expression::Nothing,
249
            Expression::Sum(_) => Expression::Sum(sub.iter().cloned().cloned().collect()),
250
            Expression::Not(_) => Expression::Not(Box::new(sub[0].clone())),
251
            Expression::Or(_) => Expression::Or(sub.iter().cloned().cloned().collect()),
252
            Expression::And(_) => Expression::And(sub.iter().cloned().cloned().collect()),
253
            Expression::Eq(_, _) => {
254
                Expression::Eq(Box::new(sub[0].clone()), Box::new(sub[1].clone()))
255
            }
256
            Expression::Neq(_, _) => {
257
                Expression::Neq(Box::new(sub[0].clone()), Box::new(sub[1].clone()))
258
            }
259
            Expression::Geq(_, _) => {
260
                Expression::Geq(Box::new(sub[0].clone()), Box::new(sub[1].clone()))
261
            }
262
            Expression::Leq(_, _) => {
263
                Expression::Leq(Box::new(sub[0].clone()), Box::new(sub[1].clone()))
264
            }
265
            Expression::Gt(_, _) => {
266
                Expression::Gt(Box::new(sub[0].clone()), Box::new(sub[1].clone()))
267
            }
268
            Expression::Lt(_, _) => {
269
                Expression::Lt(Box::new(sub[0].clone()), Box::new(sub[1].clone()))
270
            }
271
            Expression::SumGeq(_, _) => Expression::SumGeq(
272
                sub.iter().cloned().cloned().collect(),
273
                Box::new(sub[2].clone()), // ToDo (gs248) - Why are we using sub[2] here?
274
            ),
275
            Expression::SumLeq(_, _) => Expression::SumLeq(
276
                sub.iter().cloned().cloned().collect(),
277
                Box::new(sub[2].clone()),
278
            ),
279
            Expression::SumEq(_, _) => Expression::SumEq(
280
                sub.iter().cloned().cloned().collect(),
281
                Box::new(sub[2].clone()),
282
            ),
283
            Expression::Ineq(_, _, _) => Expression::Ineq(
284
                Box::new(sub[0].clone()),
285
                Box::new(sub[1].clone()),
286
                Box::new(sub[2].clone()),
287
            ),
288
        }
289
    }
290

            
291
    pub fn is_constant(&self) -> bool {
292
        match self {
293
            Expression::Constant(_) => true,
294
            _ => false,
295
        }
296
    }
297
}
298

            
299
fn display_expressions(expressions: &Vec<Expression>) -> String {
300
    if expressions.len() <= 3 {
301
        format!(
302
            "Sum({})",
303
            expressions
304
                .iter()
305
                .map(|e| e.to_string())
306
                .collect::<Vec<String>>()
307
                .join(", ")
308
        )
309
    } else {
310
        format!(
311
            "Sum({}..{})",
312
            expressions[0],
313
            expressions[expressions.len() - 1]
314
        )
315
    }
316
}
317

            
318
impl Display for Constant {
319
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
320
        match &self {
321
            Constant::Int(i) => write!(f, "Int({})", i),
322
            Constant::Bool(b) => write!(f, "Bool({})", b),
323
        }
324
    }
325
}
326

            
327
impl Display for Expression {
328
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
329
        match &self {
330
            Expression::Constant(c) => write!(f, "Constant::{}", c),
331
            Expression::Reference(name) => write!(f, "Reference({})", name),
332
            Expression::Nothing => write!(f, "Nothing"),
333
            Expression::Sum(expressions) => write!(f, "Sum({})", display_expressions(expressions)),
334
            Expression::Not(expr_box) => write!(f, "Not({})", expr_box.clone()),
335
            Expression::Or(expressions) => write!(f, "Not({})", display_expressions(expressions)),
336
            Expression::And(expressions) => write!(f, "And({})", display_expressions(expressions)),
337
            Expression::Eq(box1, box2) => write!(f, "Eq({}, {})", box1.clone(), box2.clone()),
338
            Expression::Neq(box1, box2) => write!(f, "Neq({}, {})", box1.clone(), box2.clone()),
339
            Expression::Geq(box1, box2) => write!(f, "Geq({}, {})", box1.clone(), box2.clone()),
340
            Expression::Leq(box1, box2) => write!(f, "Leq({}, {})", box1.clone(), box2.clone()),
341
            Expression::Gt(box1, box2) => write!(f, "Gt({}, {})", box1.clone(), box2.clone()),
342
            Expression::Lt(box1, box2) => write!(f, "Lt({}, {})", box1.clone(), box2.clone()),
343
            Expression::SumGeq(box1, box2) => {
344
                write!(f, "SumGeq({}, {})", display_expressions(box1), box2.clone())
345
            }
346
            Expression::SumLeq(box1, box2) => {
347
                write!(f, "SumLeq({}, {})", display_expressions(box1), box2.clone())
348
            }
349
            Expression::Ineq(box1, box2, box3) => write!(
350
                f,
351
                "Ineq({}, {}, {})",
352
                box1.clone(),
353
                box2.clone(),
354
                box3.clone()
355
            ),
356
            #[allow(unreachable_patterns)]
357
            _ => write!(f, "Expression::Unknown"),
358
        }
359
    }
360
}