1
use conjure_cp::{
2
    Model,
3
    ast::{
4
        Atom, DeclarationPtr, Domain, Expression, Literal, Metadata, Moo, Name, Range, Reference,
5
        SymbolTable, eval_constant,
6
    },
7
    into_matrix_expr, matrix_expr,
8
    rule_engine::{Rule, get_all_rules, get_rule_by_name, resolve_rule_sets, rewrite_naive},
9
    settings::{QuantifiedExpander, SolverFamily, set_comprehension_expander},
10
    solver::{Solver, adaptors},
11
};
12
#[allow(unused_imports)]
13
#[allow(clippy::single_component_path_imports)] // ensure this is linked so we can lookup rules
14
use conjure_cp_rules;
15
use pretty_assertions::assert_eq;
16
use std::process::exit;
17
use uniplate::Uniplate;
18

            
19
/// Smoke test: `get_all_rules()` must return a non-empty collection.
#[test]
fn rules_present() {
    assert!(!get_all_rules().is_empty());
}
24

            
25
/// Checks `evaluate_sum_of_constants` on an all-literal sum and on a sum that contains a
/// reference.
#[test]
fn sum_of_constants() {
    // Shorthand for an atomic integer-literal expression.
    let int = |n: i32| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(n)));

    // 1 + 2 + 3: every operand is an integer literal.
    let valid_sum_expression = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![int(1), int(2), int(3)]),
    );

    // 1 + a: the second operand is a reference, not a literal.
    let a = DeclarationPtr::new_find(Name::user("a"), Domain::bool());
    let invalid_sum_expression = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![
            int(1),
            Expression::Atomic(Metadata::new(), Atom::Reference(Reference::new(a))),
        ]),
    );

    assert_eq!(evaluate_sum_of_constants(&valid_sum_expression), Some(6));
    assert_eq!(evaluate_sum_of_constants(&invalid_sum_expression), None);
}
54

            
55
8
fn evaluate_sum_of_constants(expr: &Expression) -> Option<i32> {
56
8
    match expr {
57
8
        Expression::Sum(_metadata, expressions) => {
58
8
            let expressions = (**expressions).clone().unwrap_list()?;
59
8
            let mut sum = 0;
60
20
            for e in expressions {
61
16
                match e {
62
16
                    Expression::Atomic(_, Atom::Literal(Literal::Int(value))) => {
63
16
                        sum += value;
64
16
                    }
65
4
                    _ => return None,
66
                }
67
            }
68
4
            Some(sum)
69
        }
70
        _ => None,
71
    }
72
8
}
73

            
74
/// Checks that `simplify_expression` folds a constant sum nested inside a non-constant sum.
#[test]
fn recursive_sum_of_constants() {
    // Shorthands for an integer-literal expression and for a sum over a matrix of operands.
    let int = |n: i32| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(n)));
    let sum = |operands: Expression| Expression::Sum(Metadata::new(), Moo::new(operands));

    let a = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("a"),
        Domain::int(vec![Range::Bounded(1, 5)]),
    )));

    // (1 + 2 + (1 + 2) + a) = 3
    let complex_expression = Expression::Eq(
        Metadata::new(),
        Moo::new(sum(matrix_expr![
            int(1),
            int(2),
            sum(matrix_expr![int(1), int(2)]),
            Expression::Atomic(Metadata::new(), a.clone()),
        ])),
        Moo::new(int(3)),
    );

    // (1 + 2 + 3 + a) = 3
    let correct_simplified_expression = Expression::Eq(
        Metadata::new(),
        Moo::new(sum(matrix_expr![
            int(1),
            int(2),
            int(3),
            Expression::Atomic(Metadata::new(), a),
        ])),
        Moo::new(int(3)),
    );

    assert_eq!(
        simplify_expression(complex_expression),
        correct_simplified_expression
    );
}
122

            
123
14
fn simplify_expression(expr: Expression) -> Expression {
124
14
    match expr {
125
4
        Expression::Sum(_metadata, expressions) => {
126
4
            let expressions = Moo::unwrap_or_clone(expressions).unwrap_list().unwrap();
127
4
            if let Some(result) = evaluate_sum_of_constants(&Expression::Sum(
128
4
                Metadata::new(),
129
4
                Moo::new(into_matrix_expr![expressions.clone()]),
130
4
            )) {
131
2
                Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(result)))
132
            } else {
133
2
                Expression::Sum(
134
2
                    Metadata::new(),
135
2
                    Moo::new(into_matrix_expr![
136
2
                        expressions.into_iter().map(simplify_expression).collect()
137
2
                    ]),
138
2
                )
139
            }
140
        }
141
2
        Expression::Eq(_metadata, left, right) => Expression::Eq(
142
2
            Metadata::new(),
143
2
            Moo::new(simplify_expression(Moo::unwrap_or_clone(left))),
144
2
            Moo::new(simplify_expression(Moo::unwrap_or_clone(right))),
145
2
        ),
146
        Expression::Geq(_metadata, left, right) => Expression::Geq(
147
            Metadata::new(),
148
            Moo::new(simplify_expression(Moo::unwrap_or_clone(left))),
149
            Moo::new(simplify_expression(Moo::unwrap_or_clone(right))),
150
        ),
151
8
        _ => expr,
152
    }
153
14
}
154

            
155
/// Applies `partial_evaluator` and then `remove_unit_vector_sum` to `1 + 2 + 3` and expects
/// the literal `6`.
#[test]
fn rule_sum_constants() {
    let int = |n: i32| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(n)));

    // Rules in the order they are applied.
    let pipeline = [
        get_rule_by_name("partial_evaluator").unwrap(),
        get_rule_by_name("remove_unit_vector_sum").unwrap(),
    ];

    let start = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![int(1), int(2), int(3)]),
    );

    // Each rule is applied against a fresh, empty symbol table.
    let reduced = pipeline.into_iter().fold(start, |expr, rule| {
        rule.apply(&expr, &SymbolTable::new())
            .unwrap()
            .new_expression
    });

    assert_eq!(reduced, int(6));
}
183

            
184
/// Applies `introduce_weighted_sumleq_sumgeq` to `1 + 2 >= 3` and expects a `FlatSumGeq`.
#[test]
fn rule_sum_geq() {
    let introduce_sumgeq = get_rule_by_name("introduce_weighted_sumleq_sumgeq").unwrap();

    let lit = |n: i32| Atom::Literal(Literal::Int(n));
    let atomic = |atom: Atom| Expression::Atomic(Metadata::new(), atom);

    // 1 + 2 >= 3
    let input = Expression::Geq(
        Metadata::new(),
        Moo::new(Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![atomic(lit(1)), atomic(lit(2))]),
        )),
        Moo::new(atomic(lit(3))),
    );

    let flattened = introduce_sumgeq
        .apply(&input, &SymbolTable::new())
        .unwrap()
        .new_expression;

    let expected = Expression::FlatSumGeq(Metadata::new(), vec![lit(1), lit(2)], lit(3));
    assert_eq!(flattened, expected);
}
220

            
221
///
/// Reduce and solve:
/// ```text
/// find a,b,c : int(1..3)
/// such that a + b + c <= 2 + 3 - 1
/// such that a < b
/// ```
///
/// Every rule is looked up by name and applied by hand; the two resulting flat constraints are
/// then placed in a model that is loaded into the Minion adaptor and solved.
#[test]
fn reduce_solve_xyz() {
    println!("Rules: {:?}", get_all_rules());
    let sum_constants = get_rule_by_name("partial_evaluator").unwrap();
    let unwrap_sum = get_rule_by_name("remove_unit_vector_sum").unwrap();
    let lt_to_leq = get_rule_by_name("lt_to_leq").unwrap();
    let leq_to_ineq = get_rule_by_name("x_leq_y_plus_k_to_ineq").unwrap();
    let introduce_sumleq = get_rule_by_name("introduce_weighted_sumleq_sumgeq").unwrap();

    // Shorthand for an atomic integer-literal expression.
    let int = |n: i32| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(n)));

    // 2 + 3 - 1
    let mut expr1 = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![int(2), int(3), int(-1)]),
    );

    expr1 = sum_constants
        .apply(&expr1, &SymbolTable::new())
        .unwrap()
        .new_expression;
    expr1 = unwrap_sum
        .apply(&expr1, &SymbolTable::new())
        .unwrap()
        .new_expression;
    assert_eq!(expr1, int(4));

    // One declaration per variable, shared by the constraint references and the model's symbol
    // table, so both use the documented int(1..3) domain.
    let decl_a = DeclarationPtr::new_find(Name::user("a"), Domain::int(vec![Range::Bounded(1, 3)]));
    let decl_b = DeclarationPtr::new_find(Name::user("b"), Domain::int(vec![Range::Bounded(1, 3)]));
    let decl_c = DeclarationPtr::new_find(Name::user("c"), Domain::int(vec![Range::Bounded(1, 3)]));

    let a = Atom::Reference(Reference::new(decl_a.clone()));
    let b = Atom::Reference(Reference::new(decl_b.clone()));
    let c = Atom::Reference(Reference::new(decl_c.clone()));

    // a + b + c <= 4
    expr1 = Expression::Leq(
        Metadata::new(),
        Moo::new(Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![
                Expression::Atomic(Metadata::new(), a.clone()),
                Expression::Atomic(Metadata::new(), b.clone()),
                Expression::Atomic(Metadata::new(), c.clone()),
            ]),
        )),
        Moo::new(expr1),
    );
    expr1 = introduce_sumleq
        .apply(&expr1, &SymbolTable::new())
        .unwrap()
        .new_expression;
    assert_eq!(
        expr1,
        Expression::FlatSumLeq(
            Metadata::new(),
            vec![a.clone(), b.clone(), c],
            Atom::Literal(Literal::Int(4))
        )
    );

    // a < b
    let mut expr2 = Expression::Lt(
        Metadata::new(),
        Moo::new(Expression::Atomic(Metadata::new(), a.clone())),
        Moo::new(Expression::Atomic(Metadata::new(), b.clone())),
    );
    expr2 = lt_to_leq
        .apply(&expr2, &SymbolTable::new())
        .unwrap()
        .new_expression;

    expr2 = leq_to_ineq
        .apply(&expr2, &SymbolTable::new())
        .unwrap()
        .new_expression;
    assert_eq!(
        expr2,
        Expression::FlatIneq(
            Metadata::new(),
            Moo::new(a),
            Moo::new(b),
            Box::new(Literal::Int(-1)),
        )
    );

    let mut model = Model::new(Default::default());
    *model.constraints_mut() = vec![expr1, expr2];

    // Register the same declarations the constraints refer to.
    model.symbols_mut().insert(decl_a).unwrap();
    model.symbols_mut().insert(decl_b).unwrap();
    model.symbols_mut().insert(decl_c).unwrap();

    let solver: Solver = Solver::new(adaptors::Minion::new());
    let solver = solver.load_model(model).unwrap();
    solver.solve(Box::new(|_| true)).unwrap();
}
353

            
354
/// Applies `remove_double_negation` to `not(not(true))` and expects `true`.
#[test]
fn rule_remove_double_negation() {
    let remove_double_negation = get_rule_by_name("remove_double_negation").unwrap();

    let truth = || Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
    let negate = |inner: Expression| Expression::Not(Metadata::new(), Moo::new(inner));

    // not(not(true))
    let double_negation = negate(negate(truth()));

    let reduced = remove_double_negation
        .apply(&double_negation, &SymbolTable::new())
        .unwrap()
        .new_expression;

    assert_eq!(reduced, truth());
}
379

            
380
/// Applies the unit-vector rules to a single-operand `and` and a single-operand `or`; each is
/// expected to reduce to its only operand.
#[test]
fn remove_trivial_and_or() {
    let remove_trivial_and = get_rule_by_name("remove_unit_vector_and").unwrap();
    let remove_trivial_or = get_rule_by_name("remove_unit_vector_or").unwrap();

    let boolean = |b: bool| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)));

    // and([true]) and or([false])
    let unit_and = Expression::And(Metadata::new(), Moo::new(matrix_expr![boolean(true)]));
    let unit_or = Expression::Or(Metadata::new(), Moo::new(matrix_expr![boolean(false)]));

    let reduced_and = remove_trivial_and
        .apply(&unit_and, &SymbolTable::new())
        .unwrap()
        .new_expression;
    let reduced_or = remove_trivial_or
        .apply(&unit_or, &SymbolTable::new())
        .unwrap()
        .new_expression;

    assert_eq!(reduced_and, boolean(true));
    assert_eq!(reduced_or, boolean(false));
}
418

            
419
/// Applies `distribute_not_over_and` to `not(and([a, b]))` and expects `or([not(a), not(b)])`.
#[test]
fn rule_distribute_not_over_and() {
    let distribute_not_over_and = get_rule_by_name("distribute_not_over_and").unwrap();

    let a = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("a"),
        Domain::bool(),
    )));
    let b = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("b"),
        Domain::bool(),
    )));

    let atomic = |atom: &Atom| Expression::Atomic(Metadata::new(), atom.clone());
    let negate = |inner: Expression| Expression::Not(Metadata::new(), Moo::new(inner));

    // not(and([a, b]))
    let input = negate(Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![atomic(&a), atomic(&b)]),
    ));

    let rewritten = distribute_not_over_and
        .apply(&input, &SymbolTable::new())
        .unwrap()
        .new_expression;

    // or([not(a), not(b)])
    let expected = Expression::Or(
        Metadata::new(),
        Moo::new(matrix_expr![negate(atomic(&a)), negate(atomic(&b))]),
    );
    assert_eq!(rewritten, expected);
}
466

            
467
/// Applies `distribute_not_over_or` to `not(or([a, b]))` and expects `and([not(a), not(b)])`.
#[test]
fn rule_distribute_not_over_or() {
    let distribute_not_over_or = get_rule_by_name("distribute_not_over_or").unwrap();

    let a = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("a"),
        Domain::bool(),
    )));
    let b = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("b"),
        Domain::bool(),
    )));

    let atomic = |atom: &Atom| Expression::Atomic(Metadata::new(), atom.clone());
    let negate = |inner: Expression| Expression::Not(Metadata::new(), Moo::new(inner));

    // not(or([a, b]))
    let input = negate(Expression::Or(
        Metadata::new(),
        Moo::new(matrix_expr![atomic(&a), atomic(&b)]),
    ));

    let rewritten = distribute_not_over_or
        .apply(&input, &SymbolTable::new())
        .unwrap()
        .new_expression;

    // and([not(a), not(b)])
    let expected = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![negate(atomic(&a)), negate(atomic(&b))]),
    );
    assert_eq!(rewritten, expected);
}
514

            
515
/// `distribute_not_over_and` must report an error for `not(a)`, where the negated operand is a
/// plain reference rather than an `and`.
#[test]
fn rule_distribute_not_over_and_not_changed() {
    let distribute_not_over_and = get_rule_by_name("distribute_not_over_and").unwrap();

    let a = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("a"),
        Domain::int(vec![Range::Bounded(1, 5)]),
    )));

    // not(a)
    let expr = Expression::Not(
        Metadata::new(),
        Moo::new(Expression::Atomic(Metadata::new(), a)),
    );

    assert!(
        distribute_not_over_and
            .apply(&expr, &SymbolTable::new())
            .is_err()
    );
}
534

            
535
/// `distribute_not_over_or` must report an error for `not(a)`, where the negated operand is a
/// plain reference rather than an `or`.
#[test]
fn rule_distribute_not_over_or_not_changed() {
    let distribute_not_over_or = get_rule_by_name("distribute_not_over_or").unwrap();

    let a = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("a"),
        Domain::int(vec![Range::Bounded(1, 5)]),
    )));

    // not(a)
    let expr = Expression::Not(
        Metadata::new(),
        Moo::new(Expression::Atomic(Metadata::new(), a)),
    );

    assert!(
        distribute_not_over_or
            .apply(&expr, &SymbolTable::new())
            .is_err()
    );
}
554

            
555
/// Applies `distribute_or_over_and` to `or([and([d1, d2]), d2])` and expects
/// `and([or([d2, d1]), or([d2, d2])])`.
#[test]
fn rule_distribute_or_over_and() {
    let distribute_or_over_and = get_rule_by_name("distribute_or_over_and").unwrap();

    let d1 = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::Machine(1),
        Domain::bool(),
    )));
    let d2 = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::Machine(2),
        Domain::bool(),
    )));

    let atomic = |atom: &Atom| Expression::Atomic(Metadata::new(), atom.clone());
    let or_of = |operands: Expression| Expression::Or(Metadata::new(), Moo::new(operands));

    // or([and([d1, d2]), d2])
    let expr = or_of(matrix_expr![
        Expression::And(
            Metadata::new(),
            Moo::new(matrix_expr![atomic(&d1), atomic(&d2)]),
        ),
        atomic(&d2),
    ]);

    let red = distribute_or_over_and
        .apply(&expr, &SymbolTable::new())
        .unwrap();

    // and([or([d2, d1]), or([d2, d2])])
    let expected = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![
            or_of(matrix_expr![atomic(&d2), atomic(&d1)]),
            or_of(matrix_expr![atomic(&d2), atomic(&d2)]),
        ]),
    );
    assert_eq!(red.new_expression, expected);
}
610

            
611
///
/// Reduce and solve:
/// ```text
/// find a,b,c : int(1..5)
/// such that a + b + c = 4
/// such that a < b
/// ```
///
/// This test uses the rewrite function to simplify the expression instead
/// of applying the rules manually.
#[test]
fn rewrite_solve_xyz() {
    println!("Rules: {:?}", get_all_rules());
    set_comprehension_expander(QuantifiedExpander::Native);

    // Resolved once: the same rule sets are printed here and later handed to the rewriter.
    let rule_sets = match resolve_rule_sets(SolverFamily::Minion, &["Constant"]) {
        Ok(rs) => rs,
        Err(e) => {
            eprintln!("Error resolving rule sets: {e}");
            exit(1);
        }
    };
    println!("Rule sets: {rule_sets:?}");

    // Create variables and domains
    let decl_a = DeclarationPtr::new_find(Name::user("a"), Domain::int(vec![Range::Bounded(1, 5)]));
    let decl_b = DeclarationPtr::new_find(Name::user("b"), Domain::int(vec![Range::Bounded(1, 5)]));
    let decl_c = DeclarationPtr::new_find(Name::user("c"), Domain::int(vec![Range::Bounded(1, 5)]));

    let a = Atom::Reference(Reference::new(decl_a.clone()));
    let b = Atom::Reference(Reference::new(decl_b.clone()));
    let c = Atom::Reference(Reference::new(decl_c.clone()));

    // Construct nested expression: and([a + b + c = 4, a < b])
    let nested_expr = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::Eq(
                Metadata::new(),
                Moo::new(Expression::Sum(
                    Metadata::new(),
                    Moo::new(matrix_expr![
                        Expression::Atomic(Metadata::new(), a.clone()),
                        Expression::Atomic(Metadata::new(), b.clone()),
                        Expression::Atomic(Metadata::new(), c),
                    ]),
                )),
                Moo::new(Expression::Atomic(
                    Metadata::new(),
                    Atom::Literal(Literal::Int(4)),
                )),
            ),
            Expression::Lt(
                Metadata::new(),
                Moo::new(Expression::Atomic(Metadata::new(), a)),
                Moo::new(Expression::Atomic(Metadata::new(), b)),
            ),
        ]),
    );

    let mut model = Model::new(Default::default());

    // Insert variables and domains
    model.symbols_mut().insert(decl_a).unwrap();
    model.symbols_mut().insert(decl_b).unwrap();
    model.symbols_mut().insert(decl_c).unwrap();

    *model.constraints_mut() = vec![nested_expr];

    // Apply rewrite function to the model holding the nested expression
    model = rewrite_naive(&model, &rule_sets, true).unwrap();
    let rewritten_expr = model.constraints();

    // Check if the expression is in its simplest form
    assert!(rewritten_expr.iter().all(is_simple));

    let solver: Solver = Solver::new(adaptors::Minion::new());
    let solver = solver.load_model(model).unwrap();
    solver.solve(Box::new(|_| true)).unwrap();
}
702

            
703
struct RuleResult<'a> {
704
    #[allow(dead_code)]
705
    rule: &'a Rule<'a>,
706
    new_expression: Expression,
707
}
708

            
709
/// # Returns
710
/// - True if `expression` is in its simplest form.
711
/// - False otherwise.
712
6
pub fn is_simple(expression: &Expression) -> bool {
713
6
    let rules = get_all_rules();
714
6
    let mut new = expression.clone();
715
6
    while let Some(step) = is_simple_iteration(&new, &rules) {
716
        new = step;
717
    }
718
6
    new == *expression
719
6
}
720

            
721
/// # Returns
722
/// - Some(<new_expression>) after applying the first applicable rule to `expr` or a sub-expression.
723
/// - None if no rule is applicable to the expression or any sub-expression.
724
6
fn is_simple_iteration<'a>(
725
6
    expression: &'a Expression,
726
6
    rules: &'a Vec<&'a Rule<'a>>,
727
6
) -> Option<Expression> {
728
6
    let rule_results = apply_all_rules(expression, rules);
729
6
    if let Some(new) = choose_rewrite(&rule_results) {
730
        return Some(new);
731
    } else {
732
6
        let mut sub = expression.children();
733
6
        for i in 0..sub.len() {
734
            if let Some(new) = is_simple_iteration(&sub[i], rules) {
735
                sub[i] = new;
736
                return Some(expression.with_children(sub.clone()));
737
            }
738
        }
739
    }
740
6
    None // No rules applicable to this branch of the expression
741
6
}
742

            
743
/// # Returns
744
/// - A list of RuleResults after applying all rules to `expression`.
745
/// - An empty list if no rules are applicable.
746
6
fn apply_all_rules<'a>(
747
6
    expression: &'a Expression,
748
6
    rules: &'a Vec<&'a Rule<'a>>,
749
6
) -> Vec<RuleResult<'a>> {
750
6
    let mut results = Vec::new();
751
822
    for rule in rules {
752
822
        match rule.apply(expression, &SymbolTable::new()) {
753
            Ok(red) => {
754
                results.push(RuleResult {
755
                    rule,
756
                    new_expression: red.new_expression,
757
                });
758
            }
759
822
            Err(_) => continue,
760
        }
761
    }
762
6
    results
763
6
}
764

            
765
/// # Returns
766
/// - Some(<new_expression>) after applying the first rule in `results`.
767
/// - None if `results` is empty.
768
6
fn choose_rewrite(results: &[RuleResult]) -> Option<Expression> {
769
6
    if results.is_empty() {
770
6
        return None;
771
    }
772
    // Return the first result for now
773
    // println!("Applying rule: {:?}", results[0].rule);
774
    Some(results[0].new_expression.clone())
775
6
}
776

            
777
/// An integer literal evaluates to itself.
#[test]
fn eval_const_int() {
    let one = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
    assert_eq!(eval_constant(&one), Some(Literal::Int(1)));
}
783

            
784
/// A boolean literal evaluates to itself.
#[test]
fn eval_const_bool() {
    let truth = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
    assert_eq!(eval_constant(&truth), Some(Literal::Bool(true)));
}
790

            
791
/// `and([true, false])` evaluates to `false`.
#[test]
fn eval_const_and() {
    let boolean = |b: bool| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)));

    let conjunction = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![boolean(true), boolean(false)]),
    );

    assert_eq!(eval_constant(&conjunction), Some(Literal::Bool(false)));
}
803

            
804
/// A bare reference has no constant value, so evaluation yields `None`.
#[test]
fn eval_const_ref() {
    let a = DeclarationPtr::new_find(Name::user("a"), Domain::int(vec![Range::Bounded(1, 5)]));
    let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Reference::new(a)));

    assert_eq!(eval_constant(&reference), None);
}
816

            
817
/// A reference nested below the root expression still makes evaluation yield `None`.
#[test]
fn eval_const_nested_ref() {
    let a = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("a"),
        Domain::int(vec![Range::Bounded(1, 5)]),
    )));

    // and([true, a])
    let inner = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true))),
            Expression::Atomic(Metadata::new(), a),
        ]),
    );

    // 1 + and([true, a])
    let expr = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1))),
            inner,
        ]),
    );

    assert_eq!(eval_constant(&expr), None);
}
841

            
842
/// `1 = 1` evaluates to `true`.
#[test]
fn eval_const_eq_int() {
    let one = || Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));

    let equality = Expression::Eq(Metadata::new(), Moo::new(one()), Moo::new(one()));

    assert_eq!(eval_constant(&equality), Some(Literal::Bool(true)));
}
858

            
859
/// `true = true` evaluates to `true`.
#[test]
fn eval_const_eq_bool() {
    let truth = || Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));

    let equality = Expression::Eq(Metadata::new(), Moo::new(truth()), Moo::new(truth()));

    assert_eq!(eval_constant(&equality), Some(Literal::Bool(true)));
}
875

            
876
/// Equality between an integer literal and a boolean literal is expected to
/// evaluate to `None`.
#[test]
fn eval_const_eq_mixed() {
    let int_side = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
    let bool_side = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
    let equality = Expression::Eq(Metadata::new(), Moo::new(int_side), Moo::new(bool_side));

    let folded = eval_constant(&equality);
    assert_eq!(folded, None);
}
892

            
893
/// A sum over an integer literal and a boolean literal is expected to
/// evaluate to `None`.
#[test]
fn eval_const_sum_mixed() {
    let int_term = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
    let bool_term = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
    let sum = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![
            int_term,
            bool_term,
        ]),
    );

    assert_eq!(eval_constant(&sum), None);
}
905

            
906
/// The conjunction `sum([x, y, z]) = 4 /\ x >= y`, built entirely from
/// references to `find` declarations plus the literal `4`, is expected to
/// evaluate to `None`.
#[test]
fn eval_const_sum_xyz() {
    // Every call creates its own `find` declaration bounded by
    // `Range::Bounded(1, 5)`, so the two `x` (and two `y`) references below
    // come from separate `new_find` calls.
    let find_var = |name: &'static str| {
        Expression::Atomic(
            Metadata::new(),
            Atom::Reference(Reference::new(DeclarationPtr::new_find(
                Name::user(name),
                Domain::int(vec![Range::Bounded(1, 5)]),
            ))),
        )
    };

    // Left conjunct: x + y + z = 4.
    let total = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![
            find_var("x"),
            find_var("y"),
            find_var("z"),
        ]),
    );
    let four = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(4)));
    let sum_is_four = Expression::Eq(Metadata::new(), Moo::new(total), Moo::new(four));

    // Right conjunct: x >= y.
    let x_geq_y = Expression::Geq(
        Metadata::new(),
        Moo::new(find_var("x")),
        Moo::new(find_var("y")),
    );

    let constraint = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![
            sum_is_four,
            x_geq_y,
        ]),
    );

    assert_eq!(eval_constant(&constraint), None);
}
966

            
967
/// A disjunction of two `false` literals is expected to evaluate to the
/// boolean literal `false`.
#[test]
fn eval_const_or() {
    let falsity = || Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false)));
    let disjunction = Expression::Or(
        Metadata::new(),
        Moo::new(matrix_expr![
            falsity(),
            falsity(),
        ]),
    );

    assert_eq!(eval_constant(&disjunction), Some(Literal::Bool(false)));
}