1
use std::collections::HashMap;
2
use std::sync::{Mutex, OnceLock};
3

            
4
use regex::Regex;
5

            
6
use minion_ast::Model as MinionModel;
7
use minion_rs::ast as minion_ast;
8
use minion_rs::error::MinionError;
9
use minion_rs::{get_from_table, run_minion};
10

            
11
use crate::ast as conjure_ast;
12
use crate::solver::SolverCallback;
13
use crate::solver::SolverFamily;
14
use crate::solver::SolverMutCallback;
15
use crate::stats::SolverStats;
16
use crate::Model as ConjureModel;
17

            
18
use super::super::model_modifier::NotModifiable;
19
use super::super::private;
20
use super::super::SearchComplete::*;
21
use super::super::SearchIncomplete::*;
22
use super::super::SearchStatus::*;
23
use super::super::SolveSuccess;
24
use super::super::SolverAdaptor;
25
use super::super::SolverError;
26
use super::super::SolverError::*;
27

            
28
/// A [SolverAdaptor] for interacting with Minion.
29
///
30
/// This adaptor uses the `minion_rs` crate to talk to Minion over FFI.
31
pub struct Minion {
32
    __non_constructable: private::Internal,
33
    model: Option<MinionModel>,
34
}
35

            
36
static MINION_LOCK: Mutex<()> = Mutex::new(());
37
static USER_CALLBACK: OnceLock<Mutex<SolverCallback>> = OnceLock::new();
38
static ANY_SOLUTIONS: Mutex<bool> = Mutex::new(false);
39
static USER_TERMINATED: Mutex<bool> = Mutex::new(false);
40

            
41
#[allow(clippy::unwrap_used)]
42
26775
fn minion_rs_callback(solutions: HashMap<minion_ast::VarName, minion_ast::Constant>) -> bool {
43
26775
    *(ANY_SOLUTIONS.lock().unwrap()) = true;
44
26775
    let callback = USER_CALLBACK
45
26775
        .get_or_init(|| Mutex::new(Box::new(|x| true)))
46
26775
        .lock()
47
26775
        .unwrap();
48
26775

            
49
26775
    let mut conjure_solutions: HashMap<conjure_ast::Name, conjure_ast::Literal> = HashMap::new();
50
120802
    for (minion_name, minion_const) in solutions.into_iter() {
51
120802
        let conjure_const = match minion_const {
52
            minion_ast::Constant::Bool(x) => conjure_ast::Literal::Bool(x),
53
120802
            minion_ast::Constant::Integer(x) => conjure_ast::Literal::Int(x),
54
            _ => todo!(),
55
        };
56

            
57
120802
        let machine_name_re = Regex::new(r"__conjure_machine_name_([0-9]+)").unwrap();
58
120802
        let conjure_name = if let Some(caps) = machine_name_re.captures(&minion_name) {
59
44285
            conjure_ast::Name::MachineName(caps[1].parse::<i32>().unwrap())
60
        } else {
61
76517
            conjure_ast::Name::UserName(minion_name)
62
        };
63

            
64
120802
        conjure_solutions.insert(conjure_name, conjure_const);
65
    }
66

            
67
26775
    let continue_search = (**callback)(conjure_solutions);
68
26775
    if !continue_search {
69
        *(USER_TERMINATED.lock().unwrap()) = true;
70
26775
    }
71

            
72
26775
    continue_search
73
26775
}
74

            
75
// Opts `Minion` into the crate-private `private::Sealed` marker trait; presumably this is the
// seal that `SolverAdaptor` requires so only in-crate adaptors can implement it — confirm
// against the trait definition.
impl private::Sealed for Minion {}
76

            
77
impl Minion {
78
1037
    pub fn new() -> Minion {
79
1037
        Minion {
80
1037
            __non_constructable: private::Internal,
81
1037
            model: None,
82
1037
        }
83
1037
    }
84
}
85

            
86
impl Default for Minion {
87
    fn default() -> Self {
88
        Minion::new()
89
    }
90
}
91

            
92
impl SolverAdaptor for Minion {
93
    #[allow(clippy::unwrap_used)]
94
1037
    fn solve(
95
1037
        &mut self,
96
1037
        callback: SolverCallback,
97
1037
        _: private::Internal,
98
1037
    ) -> Result<SolveSuccess, SolverError> {
99
1037
        // our minion callback is global state, so single threading the adaptor as a whole is
100
1037
        // probably a good move...
101
1037
        #[allow(clippy::unwrap_used)]
102
1037
        let mut minion_lock = MINION_LOCK.lock().unwrap();
103
1037

            
104
1037
        #[allow(clippy::unwrap_used)]
105
1037
        let mut user_callback = USER_CALLBACK
106
1037
            .get_or_init(|| Mutex::new(Box::new(|x| true)))
107
1037
            .lock()
108
1037
            .unwrap();
109
1037
        *user_callback = callback;
110
1037
        drop(user_callback); // release mutex. REQUIRED so that run_minion can use the
111
1037
                             // user callback and not deadlock.
112
1037

            
113
1037
        run_minion(
114
1037
            self.model.clone().expect("STATE MACHINE ERR"),
115
1037
            minion_rs_callback,
116
1037
        )
117
1037
        .map_err(|err| match err {
118
            MinionError::RuntimeError(x) => Runtime(format!("{:#?}", x)),
119
            MinionError::Other(x) => Runtime(format!("{:#?}", x)),
120
            MinionError::NotImplemented(x) => RuntimeNotImplemented(x),
121
            x => Runtime(format!("unknown minion_rs error: {:#?}", x)),
122
1037
        })?;
123

            
124
1037
        let mut status = Complete(HasSolutions);
125
1037
        if *(USER_TERMINATED.lock()).unwrap() {
126
            status = Incomplete(UserTerminated);
127
1037
        } else if *(ANY_SOLUTIONS.lock()).unwrap() {
128
1037
            status = Complete(NoSolutions);
129
1037
        }
130
1037
        Ok(SolveSuccess {
131
1037
            stats: get_solver_stats(),
132
1037
            status,
133
1037
        })
134
1037
    }
135

            
136
    fn solve_mut(
137
        &mut self,
138
        callback: SolverMutCallback,
139
        _: private::Internal,
140
    ) -> Result<SolveSuccess, SolverError> {
141
        Err(OpNotImplemented("solve_mut".into()))
142
    }
143

            
144
1037
    fn load_model(&mut self, model: ConjureModel, _: private::Internal) -> Result<(), SolverError> {
145
1037
        let mut minion_model = MinionModel::new();
146
1037
        parse_vars(&model, &mut minion_model)?;
147
1037
        parse_exprs(&model, &mut minion_model)?;
148
1037
        self.model = Some(minion_model);
149
1037
        Ok(())
150
1037
    }
151

            
152
1037
    fn get_family(&self) -> SolverFamily {
153
1037
        SolverFamily::Minion
154
1037
    }
155

            
156
1037
    fn get_name(&self) -> Option<String> {
157
1037
        Some("Minion".to_owned())
158
1037
    }
159
}
160

            
161
1037
fn parse_vars(
162
1037
    conjure_model: &ConjureModel,
163
1037
    minion_model: &mut MinionModel,
164
1037
) -> Result<(), SolverError> {
165
    // TODO (niklasdewally): remove unused vars?
166
    // TODO (niklasdewally): ensure all vars references are used.
167

            
168
3026
    for (name, variable) in conjure_model.variables.iter() {
169
3026
        parse_var(name, variable, minion_model)?;
170
    }
171
1037
    Ok(())
172
1037
}
173

            
174
3026
fn parse_var(
175
3026
    name: &conjure_ast::Name,
176
3026
    var: &conjure_ast::DecisionVariable,
177
3026
    minion_model: &mut MinionModel,
178
3026
) -> Result<(), SolverError> {
179
3026
    match &var.domain {
180
2703
        conjure_ast::Domain::IntDomain(ranges) => _parse_intdomain_var(name, ranges, minion_model),
181
323
        conjure_ast::Domain::BoolDomain => _parse_booldomain_var(name, minion_model),
182
        x => Err(ModelFeatureNotSupported(format!("{:?}", x))),
183
    }
184
3026
}
185

            
186
2703
fn _parse_intdomain_var(
187
2703
    name: &conjure_ast::Name,
188
2703
    ranges: &[conjure_ast::Range<i32>],
189
2703
    minion_model: &mut MinionModel,
190
2703
) -> Result<(), SolverError> {
191
2703
    let str_name = _name_to_string(name.to_owned());
192
2703

            
193
2703
    if ranges.len() != 1 {
194
        return Err(ModelFeatureNotImplemented(format!(
195
            "variable {:?} has {:?} ranges. Multiple ranges / SparseBound is not yet supported.",
196
            str_name,
197
            ranges.len()
198
        )));
199
2703
    }
200

            
201
2703
    let range = ranges.first().ok_or(ModelInvalid(format!(
202
2703
        "variable {:?} has no range",
203
2703
        str_name
204
2703
    )))?;
205

            
206
2703
    let (low, high) = match range {
207
2686
        conjure_ast::Range::Bounded(x, y) => Ok((x.to_owned(), y.to_owned())),
208
17
        conjure_ast::Range::Single(x) => Ok((x.to_owned(), x.to_owned())),
209
        #[allow(unreachable_patterns)]
210
        x => Err(ModelFeatureNotSupported(format!("{:?}", x))),
211
    }?;
212

            
213
2703
    _try_add_var(
214
2703
        str_name.to_owned(),
215
2703
        minion_ast::VarDomain::Bound(low, high),
216
2703
        minion_model,
217
2703
    )
218
2703
}
219

            
220
323
fn _parse_booldomain_var(
221
323
    name: &conjure_ast::Name,
222
323
    minion_model: &mut MinionModel,
223
323
) -> Result<(), SolverError> {
224
323
    let str_name = _name_to_string(name.to_owned());
225
323
    _try_add_var(
226
323
        str_name.to_owned(),
227
323
        minion_ast::VarDomain::Bool,
228
323
        minion_model,
229
323
    )
230
323
}
231

            
232
3026
fn _try_add_var(
233
3026
    name: minion_ast::VarName,
234
3026
    domain: minion_ast::VarDomain,
235
3026
    minion_model: &mut MinionModel,
236
3026
) -> Result<(), SolverError> {
237
3026
    minion_model
238
3026
        .named_variables
239
3026
        .add_var(name.clone(), domain)
240
3026
        .ok_or(ModelInvalid(format!(
241
3026
            "variable {:?} is defined twice",
242
3026
            name
243
3026
        )))
244
3026
}
245

            
246
1037
fn parse_exprs(
247
1037
    conjure_model: &ConjureModel,
248
1037
    minion_model: &mut MinionModel,
249
1037
) -> Result<(), SolverError> {
250
6443
    for expr in conjure_model.get_constraints_vec().iter() {
251
        // TODO: top level false / trues should not go to the solver to begin with
252
        // ... but changing this at this stage would require rewriting the tester
253
986
        use crate::metadata::Metadata;
254
986
        use conjure_ast::Atom;
255
986
        use conjure_ast::Expression as Expr;
256
986
        use conjure_ast::Literal::*;
257

            
258
51
        match expr {
259
            // top level false
260
            Expr::Atomic(_, Atom::Literal(Bool(false))) => {
261
17
                minion_model.constraints.push(minion_ast::Constraint::False);
262
17
                return Ok(());
263
            }
264
            // top level true
265
            Expr::Atomic(_, Atom::Literal(Bool(true))) => {
266
34
                minion_model.constraints.push(minion_ast::Constraint::True);
267
34
                return Ok(());
268
            }
269

            
270
            _ => {
271
6392
                parse_expr(expr.to_owned(), minion_model)?;
272
            }
273
        }
274
    }
275
986
    Ok(())
276
1037
}
277

            
278
6392
fn parse_expr(
279
6392
    expr: conjure_ast::Expression,
280
6392
    minion_model: &mut MinionModel,
281
6392
) -> Result<(), SolverError> {
282
6392
    minion_model.constraints.push(read_expr(expr)?);
283
6392
    Ok(())
284
6392
}
285

            
286
21556
fn read_expr(expr: conjure_ast::Expression) -> Result<minion_ast::Constraint, SolverError> {
287
21556
    match expr {
288
        conjure_ast::Expression::Atomic(_metadata, reff) => Ok(minion_ast::Constraint::WLiteral(
289
            read_var(reff.into())?,
290
            minion_ast::Constant::Integer(1),
291
        )),
292
4998
        conjure_ast::Expression::SumLeq(_metadata, lhs, rhs) => Ok(minion_ast::Constraint::SumLeq(
293
4998
            read_vars(lhs)?,
294
4998
            read_var(*rhs)?,
295
        )),
296
4947
        conjure_ast::Expression::SumGeq(_metadata, lhs, rhs) => Ok(minion_ast::Constraint::SumGeq(
297
4947
            read_vars(lhs)?,
298
4947
            read_var(*rhs)?,
299
        )),
300
5440
        conjure_ast::Expression::Ineq(_metadata, a, b, c) => Ok(minion_ast::Constraint::Ineq(
301
5440
            read_var(*a)?,
302
5440
            read_var(*b)?,
303
5440
            minion_ast::Constant::Integer(read_const(*c)?),
304
        )),
305
544
        conjure_ast::Expression::Neq(_metadata, a, b) => {
306
544
            Ok(minion_ast::Constraint::DisEq(read_var(*a)?, read_var(*b)?))
307
        }
308
272
        conjure_ast::Expression::DivEqUndefZero(_metadata, a, b, c) => {
309
272
            Ok(minion_ast::Constraint::DivUndefZero(
310
272
                (read_var(a.into())?, read_var(b.into())?),
311
272
                read_var(c.into())?,
312
            ))
313
        }
314
272
        conjure_ast::Expression::ModuloEqUndefZero(_metadata, a, b, c) => {
315
272
            Ok(minion_ast::Constraint::ModuloUndefZero(
316
272
                (read_var(a.into())?, read_var(b.into())?),
317
272
                read_var(c.into())?,
318
            ))
319
        }
320
4352
        conjure_ast::Expression::Or(_metadata, exprs) => Ok(minion_ast::Constraint::WatchedOr(
321
4352
            exprs
322
4352
                .iter()
323
15164
                .map(|x| read_expr(x.to_owned()))
324
4352
                .collect::<Result<Vec<minion_ast::Constraint>, SolverError>>()?,
325
        )),
326
        conjure_ast::Expression::And(_metadata, exprs) => Ok(minion_ast::Constraint::WatchedAnd(
327
            exprs
328
                .iter()
329
                .map(|x| read_expr(x.to_owned()))
330
                .collect::<Result<Vec<minion_ast::Constraint>, SolverError>>()?,
331
        )),
332
476
        conjure_ast::Expression::Eq(_metadata, a, b) => {
333
476
            Ok(minion_ast::Constraint::Eq(read_var(*a)?, read_var(*b)?))
334
        }
335

            
336
255
        conjure_ast::Expression::WatchedLiteral(_metadata, name, k) => {
337
255
            Ok(minion_ast::Constraint::WLiteral(
338
255
                minion_ast::Var::NameRef(_name_to_string(name)),
339
255
                minion_ast::Constant::Integer(read_const_1(k)?),
340
            ))
341
        }
342
        conjure_ast::Expression::Reify(_metadata, e, v) => Ok(minion_ast::Constraint::Reify(
343
            Box::new(read_expr(*e)?),
344
            read_var(*v)?,
345
        )),
346

            
347
        conjure_ast::Expression::AuxDeclaration(_metadata, name, expr) => {
348
            Ok(minion_ast::Constraint::Eq(
349
                read_var(conjure_ast::Expression::Atomic(
350
                    _metadata,
351
                    conjure_ast::Atom::Reference(name),
352
                ))?,
353
                read_var(*expr)?,
354
            ))
355
        }
356
        x => Err(ModelFeatureNotSupported(format!("{:?}", x))),
357
    }
358
21556
}
359
9945
fn read_vars(exprs: Vec<conjure_ast::Expression>) -> Result<Vec<minion_ast::Var>, SolverError> {
360
9945
    let mut minion_vars: Vec<minion_ast::Var> = vec![];
361
39576
    for expr in exprs {
362
29631
        let minion_var = read_var(expr)?;
363
29631
        minion_vars.push(minion_var);
364
    }
365
9945
    Ok(minion_vars)
366
9945
}
367

            
368
54128
fn read_var(e: conjure_ast::Expression) -> Result<minion_ast::Var, SolverError> {
369
54128
    // a minion var is either a reference or a "var as const"
370
54128
    match _read_ref(e.clone()) {
371
43316
        Ok(name) => Ok(minion_ast::Var::NameRef(name)),
372
10812
        Err(_) => match read_const(e) {
373
10812
            Ok(n) => Ok(minion_ast::Var::ConstantAsVar(n)),
374
            Err(x) => Err(x),
375
        },
376
    }
377
54128
}
378

            
379
54128
fn _read_ref(e: conjure_ast::Expression) -> Result<String, SolverError> {
380
54128
    let name = match e {
381
43316
        conjure_ast::Expression::Atomic(_metadata, conjure_ast::Atom::Reference(n)) => Ok(n),
382
10812
        x => Err(ModelInvalid(format!(
383
10812
            "expected a reference, but got `{0:?}`",
384
10812
            x
385
10812
        ))),
386
10812
    }?;
387

            
388
43316
    let str_name = _name_to_string(name);
389
43316
    Ok(str_name)
390
54128
}
391

            
392
16252
fn read_const(e: conjure_ast::Expression) -> Result<i32, SolverError> {
393
16252
    match e {
394
16252
        conjure_ast::Expression::Atomic(_, conjure_ast::Atom::Literal(x)) => Ok(read_const_1(x)?),
395
        x => Err(ModelInvalid(format!(
396
            "expected a constant, but got `{0:?}`",
397
            x
398
        ))),
399
    }
400
16252
}
401

            
402
16507
fn read_const_1(k: conjure_ast::Literal) -> Result<i32, SolverError> {
403
16507
    match k {
404
16218
        conjure_ast::Literal::Int(n) => Ok(n),
405
238
        conjure_ast::Literal::Bool(true) => Ok(1),
406
51
        conjure_ast::Literal::Bool(false) => Ok(0),
407
        x => Err(ModelInvalid(format!(
408
            "expected a constant, but got `{0:?}`",
409
            x
410
        ))),
411
    }
412
16507
}
413

            
414
46597
fn _name_to_string(name: conjure_ast::Name) -> String {
415
46597
    match name {
416
44557
        conjure_ast::Name::UserName(x) => x,
417
2040
        conjure_ast::Name::MachineName(x) => format!("__conjure_machine_name_{}", x),
418
    }
419
46597
}
420

            
421
#[allow(clippy::unwrap_used)]
422
1037
fn get_solver_stats() -> SolverStats {
423
1037
    SolverStats {
424
1037
        nodes: get_from_table("Nodes".into()).map(|x| x.parse::<u64>().unwrap()),
425
1037
        ..Default::default()
426
1037
    }
427
1037
}