1
use conjure_cp::{
2
    Model,
3
    ast::{
4
        Atom, DeclarationPtr, Domain, Expression, Literal, Metadata, Moo, Name, Range, Reference,
5
        SymbolTable, eval_constant,
6
    },
7
    into_matrix_expr, matrix_expr,
8
    rule_engine::{Rule, get_all_rules, get_rule_by_name, resolve_rule_sets, rewrite_naive},
9
    solver::{Solver, SolverFamily, adaptors},
10
};
11
#[allow(unused_imports)]
12
#[allow(clippy::single_component_path_imports)] // ensure this is linked so we can lookup rules
13
use conjure_cp_rules;
14
use pretty_assertions::assert_eq;
15
use std::process::exit;
16
use uniplate::Uniplate;
17

            
/// Linking `conjure_cp_rules` (see the import above) must populate the global rule registry.
#[test]
fn rules_present() {
    assert!(!get_all_rules().is_empty());
}

/// `evaluate_sum_of_constants` folds a sum whose operands are all integer literals, and gives up
/// (returns `None`) as soon as one operand is anything else.
#[test]
fn sum_of_constants() {
    let int_lit = |n| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(n)));

    // 1 + 2 + 3: every operand is a literal, so the sum can be computed.
    let all_literals = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![int_lit(1), int_lit(2), int_lit(3)]),
    );

    // 1 + a: `a` is a decision variable, so the sum cannot be folded.
    let var_a = DeclarationPtr::new_var(Name::user("a"), Domain::bool());
    let with_variable = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![
            int_lit(1),
            Expression::Atomic(Metadata::new(), Atom::Reference(Reference::new(var_a))),
        ]),
    );

    assert_eq!(evaluate_sum_of_constants(&all_literals), Some(6));

    assert_eq!(evaluate_sum_of_constants(&with_variable), None);
}

54
fn evaluate_sum_of_constants(expr: &Expression) -> Option<i32> {
55
    match expr {
56
        Expression::Sum(_metadata, expressions) => {
57
            let expressions = (**expressions).clone().unwrap_list()?;
58
            let mut sum = 0;
59
            for e in expressions {
60
                match e {
61
                    Expression::Atomic(_, Atom::Literal(Literal::Int(value))) => {
62
                        sum += value;
63
                    }
64
                    _ => return None,
65
                }
66
            }
67
            Some(sum)
68
        }
69
        _ => None,
70
    }
71
}
72

            
/// `simplify_expression` folds the nested constant sum `(1 + 2)` but leaves the enclosing sum,
/// which still mentions the variable `a`, and the right-hand side of the equality untouched.
#[test]
fn recursive_sum_of_constants() {
    let int_lit = |n| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(n)));
    let a = Atom::Reference(Reference::new(DeclarationPtr::new_var(
        Name::user("a"),
        Domain::int(vec![Range::Bounded(1, 5)]),
    )));

    // (1 + 2 + (1 + 2) + a) = 3
    let inner_sum = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![int_lit(1), int_lit(2)]),
    );
    let input = Expression::Eq(
        Metadata::new(),
        Moo::new(Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![
                int_lit(1),
                int_lit(2),
                inner_sum,
                Expression::Atomic(Metadata::new(), a.clone()),
            ]),
        )),
        Moo::new(int_lit(3)),
    );

    // (1 + 2 + 3 + a) = 3
    let expected = Expression::Eq(
        Metadata::new(),
        Moo::new(Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![
                int_lit(1),
                int_lit(2),
                int_lit(3),
                Expression::Atomic(Metadata::new(), a),
            ]),
        )),
        Moo::new(int_lit(3)),
    );

    let actual = simplify_expression(input);
    assert_eq!(actual, expected);
}

122
fn simplify_expression(expr: Expression) -> Expression {
123
    match expr {
124
        Expression::Sum(_metadata, expressions) => {
125
            let expressions = Moo::unwrap_or_clone(expressions).unwrap_list().unwrap();
126
            if let Some(result) = evaluate_sum_of_constants(&Expression::Sum(
127
                Metadata::new(),
128
                Moo::new(into_matrix_expr![expressions.clone()]),
129
            )) {
130
                Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(result)))
131
            } else {
132
                Expression::Sum(
133
                    Metadata::new(),
134
                    Moo::new(into_matrix_expr![
135
                        expressions.into_iter().map(simplify_expression).collect()
136
                    ]),
137
                )
138
            }
139
        }
140
        Expression::Eq(_metadata, left, right) => Expression::Eq(
141
            Metadata::new(),
142
            Moo::new(simplify_expression(Moo::unwrap_or_clone(left))),
143
            Moo::new(simplify_expression(Moo::unwrap_or_clone(right))),
144
        ),
145
        Expression::Geq(_metadata, left, right) => Expression::Geq(
146
            Metadata::new(),
147
            Moo::new(simplify_expression(Moo::unwrap_or_clone(left))),
148
            Moo::new(simplify_expression(Moo::unwrap_or_clone(right))),
149
        ),
150
        _ => expr,
151
    }
152
}
153

            
/// Applies `partial_evaluator` and then `remove_unit_vector_sum` to `1 + 2 + 3` and expects the
/// literal `6`.
#[test]
fn rule_sum_constants() {
    let partial_evaluator = get_rule_by_name("partial_evaluator").unwrap();
    let remove_unit_vector_sum = get_rule_by_name("remove_unit_vector_sum").unwrap();
    let int_lit = |n| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(n)));

    let sum = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![int_lit(1), int_lit(2), int_lit(3)]),
    );

    let evaluated = partial_evaluator
        .apply(&sum, &SymbolTable::new())
        .unwrap()
        .new_expression;
    let unwrapped = remove_unit_vector_sum
        .apply(&evaluated, &SymbolTable::new())
        .unwrap()
        .new_expression;

    assert_eq!(unwrapped, int_lit(6));
}

/// `introduce_weighted_sumleq_sumgeq` flattens `1 + 2 >= 3` into `FlatSumGeq([1, 2], 3)`.
#[test]
fn rule_sum_geq() {
    let rule = get_rule_by_name("introduce_weighted_sumleq_sumgeq").unwrap();
    let int_lit = |n| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(n)));

    let sum_geq = Expression::Geq(
        Metadata::new(),
        Moo::new(Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![int_lit(1), int_lit(2)]),
        )),
        Moo::new(int_lit(3)),
    );

    let flattened = rule
        .apply(&sum_geq, &SymbolTable::new())
        .unwrap()
        .new_expression;

    let expected = Expression::FlatSumGeq(
        Metadata::new(),
        vec![
            Atom::Literal(Literal::Int(1)),
            Atom::Literal(Literal::Int(2)),
        ],
        Atom::Literal(Literal::Int(3)),
    );
    assert_eq!(flattened, expected);
}

///
/// Reduce and solve:
/// ```text
/// find a,b,c : int(1..3)
/// such that a + b + c <= 2 + 3 - 1
/// such that a < b
/// ```
///
/// The rules are applied one at a time by hand, and each intermediate result is checked. The
/// flattened constraints are then solved with Minion.
#[test]
fn reduce_solve_xyz() {
    println!("Rules: {:?}", get_all_rules());
    let sum_constants = get_rule_by_name("partial_evaluator").unwrap();
    let unwrap_sum = get_rule_by_name("remove_unit_vector_sum").unwrap();
    let lt_to_leq = get_rule_by_name("lt_to_leq").unwrap();
    let leq_to_ineq = get_rule_by_name("x_leq_y_plus_k_to_ineq").unwrap();
    let introduce_sumleq = get_rule_by_name("introduce_weighted_sumleq_sumgeq").unwrap();

    // 2 + 3 - 1
    let mut expr1 = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(2))),
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(3))),
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(-1))),
        ]),
    );

    expr1 = sum_constants
        .apply(&expr1, &SymbolTable::new())
        .unwrap()
        .new_expression;
    expr1 = unwrap_sum
        .apply(&expr1, &SymbolTable::new())
        .unwrap()
        .new_expression;
    assert_eq!(
        expr1,
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(4)))
    );

    // Declare each variable once, with the documented int(1..3) domain. The same declaration is
    // used in the constraints and in the model's symbol table, so the constraints refer to
    // exactly the variables the solver is given.
    let decl_a = DeclarationPtr::new_var(Name::user("a"), Domain::int(vec![Range::Bounded(1, 3)]));
    let decl_b = DeclarationPtr::new_var(Name::user("b"), Domain::int(vec![Range::Bounded(1, 3)]));
    let decl_c = DeclarationPtr::new_var(Name::user("c"), Domain::int(vec![Range::Bounded(1, 3)]));

    let a = Atom::Reference(Reference::new(decl_a.clone()));
    let b = Atom::Reference(Reference::new(decl_b.clone()));
    let c = Atom::Reference(Reference::new(decl_c.clone()));

    // a + b + c <= 4
    expr1 = Expression::Leq(
        Metadata::new(),
        Moo::new(Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![
                Expression::Atomic(Metadata::new(), a.clone()),
                Expression::Atomic(Metadata::new(), b.clone()),
                Expression::Atomic(Metadata::new(), c.clone()),
            ]),
        )),
        Moo::new(expr1),
    );
    expr1 = introduce_sumleq
        .apply(&expr1, &SymbolTable::new())
        .unwrap()
        .new_expression;
    assert_eq!(
        expr1,
        Expression::FlatSumLeq(
            Metadata::new(),
            vec![a.clone(), b.clone(), c],
            Atom::Literal(Literal::Int(4))
        )
    );

    // a < b
    let mut expr2 = Expression::Lt(
        Metadata::new(),
        Moo::new(Expression::Atomic(Metadata::new(), a.clone())),
        Moo::new(Expression::Atomic(Metadata::new(), b.clone())),
    );
    expr2 = lt_to_leq
        .apply(&expr2, &SymbolTable::new())
        .unwrap()
        .new_expression;

    expr2 = leq_to_ineq
        .apply(&expr2, &SymbolTable::new())
        .unwrap()
        .new_expression;
    assert_eq!(
        expr2,
        Expression::FlatIneq(
            Metadata::new(),
            Moo::new(a),
            Moo::new(b),
            Box::new(Literal::Int(-1)),
        )
    );

    let mut model = Model::new(Default::default());
    *model.as_submodel_mut().constraints_mut() = vec![expr1, expr2];

    for decl in [decl_a, decl_b, decl_c] {
        model.as_submodel_mut().symbols_mut().insert(decl).unwrap();
    }

    let solver: Solver = Solver::new(adaptors::Minion::new());
    let solver = solver.load_model(model).unwrap();
    solver.solve(Box::new(|_| true)).unwrap();
}

/// `remove_double_negation` rewrites `!!true` to `true`.
#[test]
fn rule_remove_double_negation() {
    let rule = get_rule_by_name("remove_double_negation").unwrap();
    let bool_lit = |b| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)));

    let single_negation = Expression::Not(Metadata::new(), Moo::new(bool_lit(true)));
    let double_negation = Expression::Not(Metadata::new(), Moo::new(single_negation));

    let rewritten = rule
        .apply(&double_negation, &SymbolTable::new())
        .unwrap()
        .new_expression;

    assert_eq!(rewritten, bool_lit(true));
}

/// The unit-vector rules unwrap single-element connectives: `and([true])` becomes `true` and
/// `or([false])` becomes `false`.
#[test]
fn remove_trivial_and_or() {
    let and_rule = get_rule_by_name("remove_unit_vector_and").unwrap();
    let or_rule = get_rule_by_name("remove_unit_vector_or").unwrap();
    let bool_lit = |b| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)));

    let unit_and = Expression::And(Metadata::new(), Moo::new(matrix_expr![bool_lit(true)]));
    let unit_or = Expression::Or(Metadata::new(), Moo::new(matrix_expr![bool_lit(false)]));

    let and_result = and_rule
        .apply(&unit_and, &SymbolTable::new())
        .unwrap()
        .new_expression;
    let or_result = or_rule
        .apply(&unit_or, &SymbolTable::new())
        .unwrap()
        .new_expression;

    assert_eq!(and_result, bool_lit(true));
    assert_eq!(or_result, bool_lit(false));
}

/// `distribute_not_over_and` applies De Morgan's law: `!(a /\ b)` becomes `!a \/ !b`.
#[test]
fn rule_distribute_not_over_and() {
    let rule = get_rule_by_name("distribute_not_over_and").unwrap();

    let a = Atom::Reference(Reference::new(DeclarationPtr::new_var(
        Name::user("a"),
        Domain::bool(),
    )));
    let b = Atom::Reference(Reference::new(DeclarationPtr::new_var(
        Name::user("b"),
        Domain::bool(),
    )));

    let atomic = |atom| Expression::Atomic(Metadata::new(), atom);
    let negated = |atom| {
        Expression::Not(
            Metadata::new(),
            Moo::new(Expression::Atomic(Metadata::new(), atom)),
        )
    };

    let negated_and = Expression::Not(
        Metadata::new(),
        Moo::new(Expression::And(
            Metadata::new(),
            Moo::new(matrix_expr![atomic(a.clone()), atomic(b.clone())]),
        )),
    );

    let rewritten = rule
        .apply(&negated_and, &SymbolTable::new())
        .unwrap()
        .new_expression;

    let expected = Expression::Or(
        Metadata::new(),
        Moo::new(matrix_expr![negated(a), negated(b)]),
    );
    assert_eq!(rewritten, expected);
}

/// `distribute_not_over_or` applies De Morgan's law: `!(a \/ b)` becomes `!a /\ !b`.
#[test]
fn rule_distribute_not_over_or() {
    let rule = get_rule_by_name("distribute_not_over_or").unwrap();

    let a = Atom::Reference(Reference::new(DeclarationPtr::new_var(
        Name::user("a"),
        Domain::bool(),
    )));
    let b = Atom::Reference(Reference::new(DeclarationPtr::new_var(
        Name::user("b"),
        Domain::bool(),
    )));

    let atomic = |atom| Expression::Atomic(Metadata::new(), atom);
    let negated = |atom| {
        Expression::Not(
            Metadata::new(),
            Moo::new(Expression::Atomic(Metadata::new(), atom)),
        )
    };

    let negated_or = Expression::Not(
        Metadata::new(),
        Moo::new(Expression::Or(
            Metadata::new(),
            Moo::new(matrix_expr![atomic(a.clone()), atomic(b.clone())]),
        )),
    );

    let rewritten = rule
        .apply(&negated_or, &SymbolTable::new())
        .unwrap()
        .new_expression;

    let expected = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![negated(a), negated(b)]),
    );
    assert_eq!(rewritten, expected);
}

/// `distribute_not_over_and` must refuse a negation whose operand is not a conjunction.
#[test]
fn rule_distribute_not_over_and_not_changed() {
    let rule = get_rule_by_name("distribute_not_over_and").unwrap();

    let int_var = DeclarationPtr::new_var(Name::user("a"), Domain::int(vec![Range::Bounded(1, 5)]));
    let not_a = Expression::Not(
        Metadata::new(),
        Moo::new(Expression::Atomic(
            Metadata::new(),
            Atom::Reference(Reference::new(int_var)),
        )),
    );

    let outcome = rule.apply(&not_a, &SymbolTable::new());
    assert!(outcome.is_err());
}

/// `distribute_not_over_or` must refuse a negation whose operand is not a disjunction.
#[test]
fn rule_distribute_not_over_or_not_changed() {
    let rule = get_rule_by_name("distribute_not_over_or").unwrap();

    let int_var = DeclarationPtr::new_var(Name::user("a"), Domain::int(vec![Range::Bounded(1, 5)]));
    let not_a = Expression::Not(
        Metadata::new(),
        Moo::new(Expression::Atomic(
            Metadata::new(),
            Atom::Reference(Reference::new(int_var)),
        )),
    );

    let outcome = rule.apply(&not_a, &SymbolTable::new());
    assert!(outcome.is_err());
}

/// `distribute_or_over_and` rewrites `(d1 /\ d2) \/ d2` into `(d2 \/ d1) /\ (d2 \/ d2)`.
#[test]
fn rule_distribute_or_over_and() {
    let rule = get_rule_by_name("distribute_or_over_and").unwrap();

    let d1 = Atom::Reference(Reference::new(DeclarationPtr::new_var(
        Name::Machine(1),
        Domain::bool(),
    )));
    let d2 = Atom::Reference(Reference::new(DeclarationPtr::new_var(
        Name::Machine(2),
        Domain::bool(),
    )));
    let atomic = |atom| Expression::Atomic(Metadata::new(), atom);

    let input = Expression::Or(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::And(
                Metadata::new(),
                Moo::new(matrix_expr![atomic(d1.clone()), atomic(d2.clone())]),
            ),
            atomic(d2.clone()),
        ]),
    );

    let reduction = rule.apply(&input, &SymbolTable::new()).unwrap();

    let expected = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::Or(
                Metadata::new(),
                Moo::new(matrix_expr![atomic(d2.clone()), atomic(d1)])
            ),
            Expression::Or(
                Metadata::new(),
                Moo::new(matrix_expr![atomic(d2.clone()), atomic(d2)])
            ),
        ]),
    );
    assert_eq!(reduction.new_expression, expected);
}

///
/// Reduce and solve:
/// ```text
/// find a,b,c : int(1..5)
/// such that a + b + c = 4
/// such that a < b
/// ```
///
/// This test uses the rewrite function to simplify the expression instead
/// of applying the rules manually.
#[test]
fn rewrite_solve_xyz() {
    println!("Rules: {:?}", get_all_rules());

    // Resolve the rule sets once; the same value is logged and used for the rewrite below.
    let rule_sets = match resolve_rule_sets(SolverFamily::Minion, &["Constant"]) {
        Ok(rs) => rs,
        Err(e) => {
            eprintln!("Error resolving rule sets: {e}");
            exit(1);
        }
    };
    println!("Rule sets: {rule_sets:?}");

    // Create variables and domains
    let decl_a = DeclarationPtr::new_var(Name::user("a"), Domain::int(vec![Range::Bounded(1, 5)]));
    let decl_b = DeclarationPtr::new_var(Name::user("b"), Domain::int(vec![Range::Bounded(1, 5)]));
    let decl_c = DeclarationPtr::new_var(Name::user("c"), Domain::int(vec![Range::Bounded(1, 5)]));

    let a = Atom::Reference(Reference::new(decl_a.clone()));
    let b = Atom::Reference(Reference::new(decl_b.clone()));
    let c = Atom::Reference(Reference::new(decl_c.clone()));

    // Construct nested expression: (a + b + c = 4) /\ (a < b)
    let nested_expr = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::Eq(
                Metadata::new(),
                Moo::new(Expression::Sum(
                    Metadata::new(),
                    Moo::new(matrix_expr![
                        Expression::Atomic(Metadata::new(), a.clone()),
                        Expression::Atomic(Metadata::new(), b.clone()),
                        Expression::Atomic(Metadata::new(), c),
                    ]),
                )),
                Moo::new(Expression::Atomic(
                    Metadata::new(),
                    Atom::Literal(Literal::Int(4)),
                )),
            ),
            Expression::Lt(
                Metadata::new(),
                Moo::new(Expression::Atomic(Metadata::new(), a)),
                Moo::new(Expression::Atomic(Metadata::new(), b)),
            ),
        ]),
    );

    let mut model = Model::new(Default::default());

    // Declare the variables referenced by the constraint in the model's symbol table.
    for decl in [decl_a, decl_b, decl_c] {
        model.as_submodel_mut().symbols_mut().insert(decl).unwrap();
    }

    *model.as_submodel_mut().constraints_mut() = vec![nested_expr];

    // Apply the rewriter to the whole model.
    model = rewrite_naive(&model, &rule_sets, true, false).unwrap();
    let rewritten_expr = model.as_submodel().constraints();

    // Every rewritten constraint must already be in its simplest form.
    assert!(rewritten_expr.iter().all(is_simple));

    let solver: Solver = Solver::new(adaptors::Minion::new());
    let solver = solver.load_model(model).unwrap();
    solver.solve(Box::new(|_| true)).unwrap();
}

716
struct RuleResult<'a> {
717
    #[allow(dead_code)]
718
    rule: &'a Rule<'a>,
719
    new_expression: Expression,
720
}
721

            
722
/// # Returns
723
/// - True if `expression` is in its simplest form.
724
/// - False otherwise.
725
pub fn is_simple(expression: &Expression) -> bool {
726
    let rules = get_all_rules();
727
    let mut new = expression.clone();
728
    while let Some(step) = is_simple_iteration(&new, &rules) {
729
        new = step;
730
    }
731
    new == *expression
732
}
733

            
734
/// # Returns
735
/// - Some(<new_expression>) after applying the first applicable rule to `expr` or a sub-expression.
736
/// - None if no rule is applicable to the expression or any sub-expression.
737
fn is_simple_iteration<'a>(
738
    expression: &'a Expression,
739
    rules: &'a Vec<&'a Rule<'a>>,
740
) -> Option<Expression> {
741
    let rule_results = apply_all_rules(expression, rules);
742
    if let Some(new) = choose_rewrite(&rule_results) {
743
        return Some(new);
744
    } else {
745
        let mut sub = expression.children();
746
        for i in 0..sub.len() {
747
            if let Some(new) = is_simple_iteration(&sub[i], rules) {
748
                sub[i] = new;
749
                return Some(expression.with_children(sub.clone()));
750
            }
751
        }
752
    }
753
    None // No rules applicable to this branch of the expression
754
}
755

            
756
/// # Returns
757
/// - A list of RuleResults after applying all rules to `expression`.
758
/// - An empty list if no rules are applicable.
759
fn apply_all_rules<'a>(
760
    expression: &'a Expression,
761
    rules: &'a Vec<&'a Rule<'a>>,
762
) -> Vec<RuleResult<'a>> {
763
    let mut results = Vec::new();
764
    for rule in rules {
765
        match rule.apply(expression, &SymbolTable::new()) {
766
            Ok(red) => {
767
                results.push(RuleResult {
768
                    rule,
769
                    new_expression: red.new_expression,
770
                });
771
            }
772
            Err(_) => continue,
773
        }
774
    }
775
    results
776
}
777

            
778
/// # Returns
779
/// - Some(<new_expression>) after applying the first rule in `results`.
780
/// - None if `results` is empty.
781
fn choose_rewrite(results: &[RuleResult]) -> Option<Expression> {
782
    if results.is_empty() {
783
        return None;
784
    }
785
    // Return the first result for now
786
    // println!("Applying rule: {:?}", results[0].rule);
787
    Some(results[0].new_expression.clone())
788
}
789

            
/// An integer literal evaluates to itself.
#[test]
fn eval_const_int() {
    let one = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
    assert_eq!(eval_constant(&one), Some(Literal::Int(1)));
}

/// A boolean literal evaluates to itself.
#[test]
fn eval_const_bool() {
    let truth = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
    assert_eq!(eval_constant(&truth), Some(Literal::Bool(true)));
}

/// `and([true, false])` evaluates to `false`.
#[test]
fn eval_const_and() {
    let bool_lit = |b| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)));
    let conjunction = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![bool_lit(true), bool_lit(false)]),
    );
    assert_eq!(eval_constant(&conjunction), Some(Literal::Bool(false)));
}

/// A reference to a decision variable is not a constant.
#[test]
fn eval_const_ref() {
    let var_a = DeclarationPtr::new_var(Name::user("a"), Domain::int(vec![Range::Bounded(1, 5)]));
    let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Reference::new(var_a)));
    assert_eq!(eval_constant(&reference), None);
}

/// A variable buried inside a nested sub-expression prevents constant evaluation of the whole.
#[test]
fn eval_const_nested_ref() {
    let var_a = DeclarationPtr::new_var(Name::user("a"), Domain::int(vec![Range::Bounded(1, 5)]));
    let inner_and = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true))),
            Expression::Atomic(Metadata::new(), Atom::Reference(Reference::new(var_a))),
        ]),
    );
    let sum = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1))),
            inner_and,
        ]),
    );
    assert_eq!(eval_constant(&sum), None);
}

/// `1 = 1` evaluates to `true`.
#[test]
fn eval_const_eq_int() {
    let int_lit = |n| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(n)));
    let equality = Expression::Eq(Metadata::new(), Moo::new(int_lit(1)), Moo::new(int_lit(1)));
    assert_eq!(eval_constant(&equality), Some(Literal::Bool(true)));
}

/// `true = true` evaluates to `true`.
#[test]
fn eval_const_eq_bool() {
    let bool_lit = |b| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)));
    let equality = Expression::Eq(
        Metadata::new(),
        Moo::new(bool_lit(true)),
        Moo::new(bool_lit(true)),
    );
    assert_eq!(eval_constant(&equality), Some(Literal::Bool(true)));
}

/// Comparing an integer with a boolean (`1 = true`) cannot be evaluated.
#[test]
fn eval_const_eq_mixed() {
    let lhs = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
    let rhs = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
    let equality = Expression::Eq(Metadata::new(), Moo::new(lhs), Moo::new(rhs));
    assert_eq!(eval_constant(&equality), None);
}

/// A sum mixing an integer and a boolean (`sum([1, true])`) cannot be evaluated.
#[test]
fn eval_const_sum_mixed() {
    let int_one = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
    let bool_true = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
    let sum = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![int_one, bool_true]));
    assert_eq!(eval_constant(&sum), None);
}

/// `(x + y + z = 4) /\ (x >= y)` over decision variables cannot be evaluated.
#[test]
fn eval_const_sum_xyz() {
    // Each call creates a fresh int(1..5) declaration, as the expression below expects.
    let int_var = |name: &'static str| {
        Expression::Atomic(
            Metadata::new(),
            Atom::Reference(Reference::new(DeclarationPtr::new_var(
                Name::user(name),
                Domain::int(vec![Range::Bounded(1, 5)]),
            ))),
        )
    };

    let sum_is_four = Expression::Eq(
        Metadata::new(),
        Moo::new(Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![int_var("x"), int_var("y"), int_var("z")]),
        )),
        Moo::new(Expression::Atomic(
            Metadata::new(),
            Atom::Literal(Literal::Int(4)),
        )),
    );
    let x_geq_y = Expression::Geq(
        Metadata::new(),
        Moo::new(int_var("x")),
        Moo::new(int_var("y")),
    );
    let conjunction = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![sum_is_four, x_geq_y]),
    );

    assert_eq!(eval_constant(&conjunction), None);
}

/// `or([false, false])` evaluates to `false`.
#[test]
fn eval_const_or() {
    let bool_lit = |b| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)));
    let disjunction = Expression::Or(
        Metadata::new(),
        Moo::new(matrix_expr![bool_lit(false), bool_lit(false)]),
    );
    assert_eq!(eval_constant(&disjunction), Some(Literal::Bool(false)));
}