1
use std::{
2
    collections::HashMap,
3
    ffi::CString,
4
    sync::{Mutex, MutexGuard},
5
};
6

            
7
use crate::ffi;
8
use crate::{ast::*, error::*, scoped_ptr::Scoped};
9
use anyhow::anyhow;
10

            
11
// TODO: allow passing of options.
12

            
13
/// Callback function used to capture results from minion as they are generated.
14
/// Should return `true` if search is to continue, `false` otherwise.
15
///
16
/// Consider using a global mutex (or other static variable) to use these results
17
/// elsewhere.
18
///
19
/// For example:
20
///
21
/// ```
22
///   use minion_rs::ast::*;
23
///   use minion_rs::run_minion;
24
///   use std::{
25
///       collections::HashMap,
26
///       sync::{Mutex, MutexGuard},
27
///   };
28
///
29
///   // More elaborate data-structures are possible, but for sake of example store
30
///   // a vector of solution sets.
31
///   static ALL_SOLUTIONS: Mutex<Vec<HashMap<VarName,Constant>>>  = Mutex::new(vec![]);
32
///   
33
///   fn callback(solutions: HashMap<VarName,Constant>) -> bool {
34
///       let mut guard = ALL_SOLUTIONS.lock().unwrap();
35
///       guard.push(solutions);
36
///       true
37
///   }
38
///    
39
///   // Build and run the model.
40
///   let mut model = Model::new();
41
///
42
///   // ... omitted for brevity ...
43
/// # model
44
/// #     .named_variables
45
/// #     .add_var("x".to_owned(), VarDomain::Bound(1, 3));
46
/// # model
47
/// #     .named_variables
48
/// #     .add_var("y".to_owned(), VarDomain::Bound(2, 4));
49
/// # model
50
/// #     .named_variables
51
/// #     .add_var("z".to_owned(), VarDomain::Bound(1, 5));
52
/// #
53
/// # let leq = Constraint::SumLeq(
54
/// #     vec![
55
/// #         Var::NameRef("x".to_owned()),
56
/// #         Var::NameRef("y".to_owned()),
57
/// #         Var::NameRef("z".to_owned()),
58
/// #     ],
59
/// #     Var::ConstantAsVar(4),
60
/// # );
61
/// #
62
/// # let geq = Constraint::SumGeq(
63
/// #     vec![
64
/// #         Var::NameRef("x".to_owned()),
65
/// #         Var::NameRef("y".to_owned()),
66
/// #         Var::NameRef("z".to_owned()),
67
/// #     ],
68
/// #     Var::ConstantAsVar(4),
69
/// # );
70
/// #
71
/// # let ineq = Constraint::Ineq(
72
/// #     Var::NameRef("x".to_owned()),
73
/// #     Var::NameRef("y".to_owned()),
74
/// #     Constant::Integer(-1),
75
/// # );
76
/// #
77
/// # model.constraints.push(leq);
78
/// # model.constraints.push(geq);
79
/// # model.constraints.push(ineq);
80
///  
81
///   let res = run_minion(model, callback);
82
///   res.expect("Error occurred");
83
///
84
///   // Get solutions
85
///   let guard = ALL_SOLUTIONS.lock().unwrap();
86
///   let solution_set_1 = &(guard.get(0).unwrap());
87
///
88
///   let x1 = solution_set_1.get("x").unwrap();
89
///   let y1 = solution_set_1.get("y").unwrap();
90
///   let z1 = solution_set_1.get("z").unwrap();
91
/// #
92
/// # // TODO: this test would be better with an example with >1 solution.
93
/// # assert_eq!(guard.len(),1);
94
/// # assert_eq!(*x1,Constant::Integer(1));
95
/// # assert_eq!(*y1,Constant::Integer(2));
96
/// # assert_eq!(*z1,Constant::Integer(1));
97
/// ```
98
// This is a plain `fn` pointer type, so only free functions (or non-capturing
// closures) can be used; results have to be handed out through statics, as the
// example above does.
pub type Callback = fn(solution_set: HashMap<VarName, Constant>) -> bool;
99

            
100
// Use globals to pass things between run_minion and the callback function.
101
// Minion is (currently) single-threaded anyway, so the mutexes don't matter.
102

            
103
// the current callback function
104
// Set by `run_minion` before the solver starts; read by `run_callback` each
// time it is invoked. `None` until the first call to `run_minion`.
static CALLBACK: Mutex<Option<Callback>> = Mutex::new(None);
105

            
106
// the variables we want to return, and their ordering in the print matrix
107
// Reset and refilled by `convert_model_to_raw`; `run_callback` uses the index of
// each name to fetch its value with `ffi::printMatrix_getValue`.
static PRINT_VARS: Mutex<Option<Vec<VarName>>> = Mutex::new(None);
108

            
109
#[no_mangle]
110
7
unsafe extern "C" fn run_callback() -> bool {
111
7
    // get printvars from static PRINT_VARS if they exist.
112
7
    // if not, return true and continue search.
113
7

            
114
7
    // Mutex poisoning is probably panic worthy.
115
7
    #[allow(clippy::unwrap_used)]
116
7
    let mut guard: MutexGuard<'_, Option<Vec<VarName>>> = PRINT_VARS.lock().unwrap();
117
7

            
118
7
    if guard.is_none() {
119
        return true;
120
7
    }
121

            
122
7
    let print_vars = match &mut *guard {
123
7
        Some(x) => x,
124
        None => unreachable!(),
125
    };
126

            
127
7
    if print_vars.is_empty() {
128
        return true;
129
7
    }
130
7

            
131
7
    // build nice solutions view to be used by callback
132
7
    let mut solutions: HashMap<VarName, Constant> = HashMap::new();
133

            
134
21
    for (i, var) in print_vars.iter().enumerate() {
135
21
        let solution_int: i32 = ffi::printMatrix_getValue(i as _);
136
21
        let solution: Constant = Constant::Integer(solution_int);
137
21
        solutions.insert(var.to_string(), solution);
138
21
    }
139

            
140
    #[allow(clippy::unwrap_used)]
141
7
    match *CALLBACK.lock().unwrap() {
142
        None => true,
143
7
        Some(func) => func(solutions),
144
    }
145
7
}
146

            
147
/// Run Minion on the given [Model].
148
///
149
/// The given [callback](Callback) is ran whenever a new solution set is found.
150

            
151
// Turn it into a warning for this function, cant unwarn it directly above callback wierdness
152
#[allow(clippy::unwrap_used)]
153
7
pub fn run_minion(model: Model, callback: Callback) -> Result<(), MinionError> {
154
7
    // Mutex poisoning is probably panic worthy.
155
7
    *CALLBACK.lock().unwrap() = Some(callback);
156
7

            
157
7
    unsafe {
158
7
        let options = Scoped::new(ffi::newSearchOptions(), |x| ffi::searchOptions_free(x as _));
159
7
        let args = Scoped::new(ffi::newSearchMethod(), |x| ffi::searchMethod_free(x as _));
160
7
        let instance = Scoped::new(ffi::newInstance(), |x| ffi::instance_free(x as _));
161
7

            
162
7
        convert_model_to_raw(&instance, &model)?;
163

            
164
7
        let res = ffi::runMinion(options.ptr, args.ptr, instance.ptr, Some(run_callback));
165
7
        match res {
166
7
            0 => Ok(()),
167
            x => Err(MinionError::from(RuntimeError::from(x))),
168
        }
169
    }
170
7
}
171

            
172
7
unsafe fn convert_model_to_raw(
173
7
    instance: &Scoped<ffi::ProbSpec_CSPInstance>,
174
7
    model: &Model,
175
7
) -> Result<(), MinionError> {
176
7
    /*******************************/
177
7
    /*        Add variables        */
178
7
    /*******************************/
179
7

            
180
7
    /*
181
7
     * Add variables to:
182
7
     * 1. symbol table
183
7
     * 2. print matrix
184
7
     * 3. search vars
185
7
     *
186
7
     * These are all done in the order saved in the SymbolTable.
187
7
     */
188
7

            
189
7
    let search_vars = Scoped::new(ffi::vec_var_new(), |x| ffi::vec_var_free(x as _));
190
7

            
191
7
    // store variables and the order they will be returned inside rust for later use.
192
7
    #[allow(clippy::unwrap_used)]
193
7
    let mut print_vars_guard = PRINT_VARS.lock().unwrap();
194
7
    *print_vars_guard = Some(vec![]);
195

            
196
21
    for var_name in model.named_variables.get_variable_order() {
197
21
        let c_str = CString::new(var_name.clone()).map_err(|_| {
198
            anyhow!(
199
                "Variable name {:?} contains a null character.",
200
                var_name.clone()
201
            )
202
21
        })?;
203

            
204
21
        let vartype = model
205
21
            .named_variables
206
21
            .get_vartype(var_name.clone())
207
21
            .ok_or(anyhow!("Could not get var type for {:?}", var_name.clone()))?;
208

            
209
21
        let (vartype_raw, domain_low, domain_high) = match vartype {
210
21
            VarDomain::Bound(a, b) => Ok((ffi::VariableType_VAR_BOUND, a, b)),
211
            x => Err(MinionError::NotImplemented(format!("{:?}", x))),
212
        }?;
213

            
214
21
        ffi::newVar_ffi(
215
21
            instance.ptr,
216
21
            c_str.as_ptr() as _,
217
21
            vartype_raw,
218
21
            domain_low,
219
21
            domain_high,
220
21
        );
221
21

            
222
21
        let var = ffi::getVarByName(instance.ptr, c_str.as_ptr() as _);
223
21

            
224
21
        ffi::printMatrix_addVar(instance.ptr, var);
225
21

            
226
21
        // add to the print vars stored in rust so to remember
227
21
        // the order for callback function.
228
21

            
229
21
        #[allow(clippy::unwrap_used)]
230
21
        (*print_vars_guard).as_mut().unwrap().push(var_name.clone());
231
21

            
232
21
        ffi::vec_var_push_back(search_vars.ptr, var);
233
    }
234

            
235
7
    let search_order = Scoped::new(
236
7
        ffi::newSearchOrder(search_vars.ptr, ffi::VarOrderEnum_ORDER_STATIC, false),
237
7
        |x| ffi::searchOrder_free(x as _),
238
7
    );
239
7

            
240
7
    ffi::instance_addSearchOrder(instance.ptr, search_order.ptr);
241

            
242
    /*********************************/
243
    /*        Add constraints        */
244
    /*********************************/
245

            
246
28
    for constraint in &model.constraints {
247
        // 1. get constraint type and create C++ constraint object
248
        // 2. run through arguments and add them to the constraint
249
        // 3. add constraint to instance
250

            
251
21
        let constraint_type = get_constraint_type(constraint)?;
252
21
        let raw_constraint = Scoped::new(ffi::newConstraintBlob(constraint_type), |x| {
253
21
            ffi::constraint_free(x as _)
254
21
        });
255
21

            
256
21
        constraint_add_args(instance.ptr, raw_constraint.ptr, constraint)?;
257
21
        ffi::instance_addConstraint(instance.ptr, raw_constraint.ptr);
258
    }
259

            
260
7
    Ok(())
261
7
}
262

            
263
21
unsafe fn get_constraint_type(constraint: &Constraint) -> Result<u32, MinionError> {
264
21
    match constraint {
265
7
        Constraint::SumGeq(_, _) => Ok(ffi::ConstraintType_CT_GEQSUM),
266
7
        Constraint::SumLeq(_, _) => Ok(ffi::ConstraintType_CT_LEQSUM),
267
7
        Constraint::Ineq(_, _, _) => Ok(ffi::ConstraintType_CT_INEQ),
268
        Constraint::Eq(_, _) => Ok(ffi::ConstraintType_CT_EQ),
269
        #[allow(unreachable_patterns)]
270
        x => Err(MinionError::NotImplemented(format!(
271
            "Constraint not implemented {:?}",
272
            x,
273
        ))),
274
    }
275
21
}
276

            
277
21
unsafe fn constraint_add_args(
278
21
    i: *mut ffi::ProbSpec_CSPInstance,
279
21
    r_constr: *mut ffi::ProbSpec_ConstraintBlob,
280
21
    constr: &Constraint,
281
21
) -> Result<(), MinionError> {
282
21
    match constr {
283
7
        Constraint::SumGeq(lhs_vars, rhs_var) => {
284
7
            read_vars(i, r_constr, lhs_vars)?;
285
7
            read_var(i, r_constr, rhs_var)?;
286
7
            Ok(())
287
        }
288
7
        Constraint::SumLeq(lhs_vars, rhs_var) => {
289
7
            read_vars(i, r_constr, lhs_vars)?;
290
7
            read_var(i, r_constr, rhs_var)?;
291
7
            Ok(())
292
        }
293
7
        Constraint::Ineq(var1, var2, c) => {
294
7
            read_var(i, r_constr, var1)?;
295
7
            read_var(i, r_constr, var2)?;
296
7
            read_const(r_constr, c)?;
297
7
            Ok(())
298
        }
299
        Constraint::Eq(var1, var2) => {
300
            read_var(i, r_constr, var1)?;
301
            read_var(i, r_constr, var2)?;
302
            Ok(())
303
        }
304
        #[allow(unreachable_patterns)]
305
        x => Err(MinionError::NotImplemented(format!("{:?}", x))),
306
    }
307
21
}
308

            
309
// DO NOT call manually - this assumes that all needed vars are already in the symbol table.
310
// TODO not happy with this just assuming the name is in the symbol table
311
14
unsafe fn read_vars(
312
14
    instance: *mut ffi::ProbSpec_CSPInstance,
313
14
    raw_constraint: *mut ffi::ProbSpec_ConstraintBlob,
314
14
    vars: &Vec<Var>,
315
14
) -> Result<(), MinionError> {
316
14
    let raw_vars = Scoped::new(ffi::vec_var_new(), |x| ffi::vec_var_free(x as _));
317
56
    for var in vars {
318
42
        let raw_var = match var {
319
42
            Var::NameRef(name) => {
320
42
                let c_str = CString::new(name.clone()).map_err(|_| {
321
                    anyhow!(
322
                        "Variable name {:?} contains a null character.",
323
                        name.clone()
324
                    )
325
42
                })?;
326
42
                ffi::getVarByName(instance, c_str.as_ptr() as _)
327
            }
328
            Var::ConstantAsVar(n) => ffi::constantAsVar(*n),
329
        };
330

            
331
42
        ffi::vec_var_push_back(raw_vars.ptr, raw_var);
332
    }
333

            
334
14
    ffi::constraint_addVarList(raw_constraint, raw_vars.ptr);
335
14

            
336
14
    Ok(())
337
14
}
338

            
339
28
unsafe fn read_var(
340
28
    instance: *mut ffi::ProbSpec_CSPInstance,
341
28
    raw_constraint: *mut ffi::ProbSpec_ConstraintBlob,
342
28
    var: &Var,
343
28
) -> Result<(), MinionError> {
344
28
    let raw_vars = Scoped::new(ffi::vec_var_new(), |x| ffi::vec_var_free(x as _));
345
28
    let raw_var = match var {
346
14
        Var::NameRef(name) => {
347
14
            let c_str = CString::new(name.clone()).map_err(|_| {
348
                anyhow!(
349
                    "Variable name {:?} contains a null character.",
350
                    name.clone()
351
                )
352
14
            })?;
353
14
            ffi::getVarByName(instance, c_str.as_ptr() as _)
354
        }
355
14
        Var::ConstantAsVar(n) => ffi::constantAsVar(*n),
356
    };
357

            
358
28
    ffi::vec_var_push_back(raw_vars.ptr, raw_var);
359
28
    ffi::constraint_addVarList(raw_constraint, raw_vars.ptr);
360
28

            
361
28
    Ok(())
362
28
}
363

            
364
7
unsafe fn read_const(
365
7
    raw_constraint: *mut ffi::ProbSpec_ConstraintBlob,
366
7
    constant: &Constant,
367
7
) -> Result<(), MinionError> {
368
7
    let raw_consts = Scoped::new(ffi::vec_int_new(), |x| ffi::vec_var_free(x as _));
369

            
370
7
    let val = match constant {
371
7
        Constant::Integer(n) => Ok(n),
372
        x => Err(MinionError::NotImplemented(format!("{:?}", x))),
373
    }?;
374

            
375
7
    ffi::vec_int_push_back(raw_consts.ptr, *val);
376
7
    ffi::constraint_addConstantList(raw_constraint, raw_consts.ptr);
377
7

            
378
7
    Ok(())
379
7
}