1
use conjure_cp::{
2
    Model,
3
    ast::{
4
        Atom, DeclarationPtr, Domain, Expression, Literal, Metadata, Moo, Name, Range, Reference,
5
        SymbolTable, eval_constant,
6
    },
7
    into_matrix_expr, matrix_expr,
8
    rule_engine::{Rule, get_all_rules, get_rule_by_name, resolve_rule_sets, rewrite_naive},
9
    settings::{QuantifiedExpander, SolverFamily, set_comprehension_expander},
10
    solver::{Solver, adaptors},
11
};
12
#[allow(unused_imports)]
13
#[allow(clippy::single_component_path_imports)] // ensure this is linked so we can lookup rules
14
use conjure_cp_rules;
15
use pretty_assertions::assert_eq;
16
use std::process::exit;
17
use uniplate::Uniplate;
18

            
19
/// The rule registry must expose at least one rule.
#[test]
fn rules_present() {
    assert!(!get_all_rules().is_empty());
}
24

            
25
/// `evaluate_sum_of_constants` folds a sum of integer literals and returns
/// `None` for a sum that contains a reference.
#[test]
fn sum_of_constants() {
    let int = |n: i32| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(n)));

    // 1 + 2 + 3: every operand is an integer literal.
    let all_literals = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![int(1), int(2), int(3)]),
    );

    // 1 + a: the second operand is a decision variable, not a literal.
    let with_reference = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![
            int(1),
            Expression::Atomic(
                Metadata::new(),
                Atom::Reference(Reference::new(DeclarationPtr::new_find(
                    Name::user("a"),
                    Domain::bool()
                ))),
            ),
        ]),
    );

    assert_eq!(evaluate_sum_of_constants(&all_literals), Some(6));
    assert_eq!(evaluate_sum_of_constants(&with_reference), None);
}
54

            
55
12
fn evaluate_sum_of_constants(expr: &Expression) -> Option<i32> {
56
12
    match expr {
57
12
        Expression::Sum(_metadata, expressions) => {
58
12
            let expressions = (**expressions).clone().unwrap_list()?;
59
12
            let mut sum = 0;
60
30
            for e in expressions {
61
24
                match e {
62
24
                    Expression::Atomic(_, Atom::Literal(Literal::Int(value))) => {
63
24
                        sum += value;
64
24
                    }
65
6
                    _ => return None,
66
                }
67
            }
68
6
            Some(sum)
69
        }
70
        _ => None,
71
    }
72
12
}
73

            
74
/// `simplify_expression` folds the nested constant sum `1 + 2` into `3` while
/// leaving the surrounding sum (which contains a variable) in place.
#[test]
fn recursive_sum_of_constants() {
    let int = |n: i32| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(n)));
    let atomic = |atom: &Atom| Expression::Atomic(Metadata::new(), atom.clone());

    let a = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("a"),
        Domain::int(vec![Range::Bounded(1, 5)]),
    )));

    // (1 + 2 + (1 + 2) + a) = 3
    let complex_expression = Expression::Eq(
        Metadata::new(),
        Moo::new(Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![
                int(1),
                int(2),
                Expression::Sum(Metadata::new(), Moo::new(matrix_expr![int(1), int(2)])),
                atomic(&a),
            ]),
        )),
        Moo::new(int(3)),
    );

    // (1 + 2 + 3 + a) = 3
    let expected = Expression::Eq(
        Metadata::new(),
        Moo::new(Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![int(1), int(2), int(3), atomic(&a)]),
        )),
        Moo::new(int(3)),
    );

    assert_eq!(simplify_expression(complex_expression), expected);
}
122

            
123
21
fn simplify_expression(expr: Expression) -> Expression {
124
21
    match expr {
125
6
        Expression::Sum(_metadata, expressions) => {
126
6
            let expressions = Moo::unwrap_or_clone(expressions).unwrap_list().unwrap();
127
6
            if let Some(result) = evaluate_sum_of_constants(&Expression::Sum(
128
6
                Metadata::new(),
129
6
                Moo::new(into_matrix_expr![expressions.clone()]),
130
6
            )) {
131
3
                Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(result)))
132
            } else {
133
3
                Expression::Sum(
134
3
                    Metadata::new(),
135
3
                    Moo::new(into_matrix_expr![
136
3
                        expressions.into_iter().map(simplify_expression).collect()
137
3
                    ]),
138
3
                )
139
            }
140
        }
141
3
        Expression::Eq(_metadata, left, right) => Expression::Eq(
142
3
            Metadata::new(),
143
3
            Moo::new(simplify_expression(Moo::unwrap_or_clone(left))),
144
3
            Moo::new(simplify_expression(Moo::unwrap_or_clone(right))),
145
3
        ),
146
        Expression::Geq(_metadata, left, right) => Expression::Geq(
147
            Metadata::new(),
148
            Moo::new(simplify_expression(Moo::unwrap_or_clone(left))),
149
            Moo::new(simplify_expression(Moo::unwrap_or_clone(right))),
150
        ),
151
12
        _ => expr,
152
    }
153
21
}
154

            
155
/// Applying `partial_evaluator` and then `remove_unit_vector_sum` to
/// `1 + 2 + 3` yields the literal `6`.
#[test]
fn rule_sum_constants() {
    let int = |n: i32| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(n)));

    let sum_constants = get_rule_by_name("partial_evaluator").unwrap();
    let unwrap_sum = get_rule_by_name("remove_unit_vector_sum").unwrap();

    let sum = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![int(1), int(2), int(3)]),
    );

    let evaluated = sum_constants
        .apply(&sum, &SymbolTable::new())
        .unwrap()
        .new_expression;
    let unwrapped = unwrap_sum
        .apply(&evaluated, &SymbolTable::new())
        .unwrap()
        .new_expression;

    assert_eq!(unwrapped, int(6));
}
183

            
184
/// `introduce_weighted_sumleq_sumgeq` rewrites `sum([1, 2]) >= 3` into a
/// `FlatSumGeq` over the same atoms.
#[test]
fn rule_sum_geq() {
    let literal = |n: i32| Atom::Literal(Literal::Int(n));
    let atomic = |atom: Atom| Expression::Atomic(Metadata::new(), atom);

    let introduce_sumgeq = get_rule_by_name("introduce_weighted_sumleq_sumgeq").unwrap();

    let geq = Expression::Geq(
        Metadata::new(),
        Moo::new(Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![atomic(literal(1)), atomic(literal(2))]),
        )),
        Moo::new(atomic(literal(3))),
    );

    let flattened = introduce_sumgeq
        .apply(&geq, &SymbolTable::new())
        .unwrap()
        .new_expression;

    let expected = Expression::FlatSumGeq(
        Metadata::new(),
        vec![literal(1), literal(2)],
        literal(3),
    );
    assert_eq!(flattened, expected);
}
220

            
221
///
/// Reduce and solve:
/// ```text
/// find a,b,c : int(1..3)
/// such that a + b + c <= 2 + 3 - 1
/// such that a < b
/// ```
///
/// Each rule is applied by hand and its result asserted before the two
/// resulting constraints are loaded into Minion and solved.
#[test]
fn reduce_solve_xyz() {
    println!("Rules: {:?}", get_all_rules());
    let sum_constants = get_rule_by_name("partial_evaluator").unwrap();
    let unwrap_sum = get_rule_by_name("remove_unit_vector_sum").unwrap();
    let lt_to_leq = get_rule_by_name("lt_to_leq").unwrap();
    let leq_to_ineq = get_rule_by_name("x_leq_y_plus_k_to_ineq").unwrap();
    let introduce_sumleq = get_rule_by_name("introduce_weighted_sumleq_sumgeq").unwrap();

    // Local shorthands for the repetitive construction below.
    let int = |n: i32| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(n)));
    let atomic = |atom: &Atom| Expression::Atomic(Metadata::new(), atom.clone());
    let find_var = |name: &str| {
        Atom::Reference(Reference::new(DeclarationPtr::new_find(
            Name::user(name),
            Domain::int(vec![Range::Bounded(1, 5)]),
        )))
    };
    // Applies one rule against an empty symbol table and returns the rewritten expression.
    let rewrite = |rule: &Rule, expr: &Expression| {
        rule.apply(expr, &SymbolTable::new())
            .unwrap()
            .new_expression
    };

    // 2 + 3 - 1
    let constant_sum = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![int(2), int(3), int(-1)]),
    );
    let folded = rewrite(sum_constants, &constant_sum);
    let bound = rewrite(unwrap_sum, &folded);
    assert_eq!(bound, int(4));

    let a = find_var("a");
    let b = find_var("b");
    let c = find_var("c");

    // a + b + c <= 4
    let sum_leq = Expression::Leq(
        Metadata::new(),
        Moo::new(Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![atomic(&a), atomic(&b), atomic(&c)]),
        )),
        Moo::new(bound),
    );
    let flat_sum_leq = rewrite(introduce_sumleq, &sum_leq);
    assert_eq!(
        flat_sum_leq,
        Expression::FlatSumLeq(
            Metadata::new(),
            vec![a.clone(), b.clone(), c],
            Atom::Literal(Literal::Int(4))
        )
    );

    // a < b
    let less_than = Expression::Lt(
        Metadata::new(),
        Moo::new(atomic(&a)),
        Moo::new(atomic(&b)),
    );
    let as_leq = rewrite(lt_to_leq, &less_than);
    let flat_ineq = rewrite(leq_to_ineq, &as_leq);
    assert_eq!(
        flat_ineq,
        Expression::FlatIneq(
            Metadata::new(),
            Moo::new(a),
            Moo::new(b),
            Box::new(Literal::Int(-1)),
        )
    );

    let mut model = Model::new(Default::default());
    *model.constraints_mut() = vec![flat_sum_leq, flat_ineq];

    // The model's symbol table declares a, b and c over 1..3.
    for name in ["a", "b", "c"] {
        model
            .symbols_mut()
            .insert(DeclarationPtr::new_find(
                Name::user(name),
                Domain::int(vec![Range::Bounded(1, 3)]),
            ))
            .unwrap();
    }

    let solver: Solver = Solver::new(adaptors::Minion::new());
    let solver = solver.load_model(model).unwrap();
    solver.solve(Box::new(|_| true)).unwrap();
}
353

            
354
/// `remove_double_negation` rewrites `not(not(true))` to `true`.
#[test]
fn rule_remove_double_negation() {
    let negate = |inner: Expression| Expression::Not(Metadata::new(), Moo::new(inner));
    let truth = || Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));

    let remove_double_negation = get_rule_by_name("remove_double_negation").unwrap();

    let double_negated = negate(negate(truth()));
    let simplified = remove_double_negation
        .apply(&double_negated, &SymbolTable::new())
        .unwrap()
        .new_expression;

    assert_eq!(simplified, truth());
}
379

            
380
/// A single-operand `And` / `Or` is replaced by its only operand.
#[test]
fn remove_trivial_and_or() {
    let boolean =
        |value: bool| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(value)));

    let remove_trivial_and = get_rule_by_name("remove_unit_vector_and").unwrap();
    let remove_trivial_or = get_rule_by_name("remove_unit_vector_or").unwrap();

    let unit_and = Expression::And(Metadata::new(), Moo::new(matrix_expr![boolean(true)]));
    let unit_or = Expression::Or(Metadata::new(), Moo::new(matrix_expr![boolean(false)]));

    let and_result = remove_trivial_and
        .apply(&unit_and, &SymbolTable::new())
        .unwrap()
        .new_expression;
    let or_result = remove_trivial_or
        .apply(&unit_or, &SymbolTable::new())
        .unwrap()
        .new_expression;

    assert_eq!(and_result, boolean(true));
    assert_eq!(or_result, boolean(false));
}
418

            
419
/// `distribute_not_over_and` rewrites `not(a /\ b)` to `not(a) \/ not(b)`.
#[test]
fn rule_distribute_not_over_and() {
    let atomic = |atom: &Atom| Expression::Atomic(Metadata::new(), atom.clone());
    let negated = |atom: &Atom| {
        Expression::Not(
            Metadata::new(),
            Moo::new(Expression::Atomic(Metadata::new(), atom.clone())),
        )
    };

    let distribute_not_over_and = get_rule_by_name("distribute_not_over_and").unwrap();

    let a = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("a"),
        Domain::bool(),
    )));
    let b = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("b"),
        Domain::bool(),
    )));

    let not_and = Expression::Not(
        Metadata::new(),
        Moo::new(Expression::And(
            Metadata::new(),
            Moo::new(matrix_expr![atomic(&a), atomic(&b)]),
        )),
    );

    let distributed = distribute_not_over_and
        .apply(&not_and, &SymbolTable::new())
        .unwrap()
        .new_expression;

    let expected = Expression::Or(
        Metadata::new(),
        Moo::new(matrix_expr![negated(&a), negated(&b)]),
    );
    assert_eq!(distributed, expected);
}
466

            
467
/// `distribute_not_over_or` rewrites `not(a \/ b)` to `not(a) /\ not(b)`.
#[test]
fn rule_distribute_not_over_or() {
    let atomic = |atom: &Atom| Expression::Atomic(Metadata::new(), atom.clone());
    let negated = |atom: &Atom| {
        Expression::Not(
            Metadata::new(),
            Moo::new(Expression::Atomic(Metadata::new(), atom.clone())),
        )
    };

    let distribute_not_over_or = get_rule_by_name("distribute_not_over_or").unwrap();

    let a = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("a"),
        Domain::bool(),
    )));
    let b = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("b"),
        Domain::bool(),
    )));

    let not_or = Expression::Not(
        Metadata::new(),
        Moo::new(Expression::Or(
            Metadata::new(),
            Moo::new(matrix_expr![atomic(&a), atomic(&b)]),
        )),
    );

    let distributed = distribute_not_over_or
        .apply(&not_or, &SymbolTable::new())
        .unwrap()
        .new_expression;

    let expected = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![negated(&a), negated(&b)]),
    );
    assert_eq!(distributed, expected);
}
514

            
515
/// `distribute_not_over_and` must not apply to a `Not` wrapping a plain atom.
#[test]
fn rule_distribute_not_over_and_not_changed() {
    let distribute_not_over_and = get_rule_by_name("distribute_not_over_and").unwrap();

    let a = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("a"),
        Domain::int(vec![Range::Bounded(1, 5)]),
    )));
    let not_a = Expression::Not(
        Metadata::new(),
        Moo::new(Expression::Atomic(Metadata::new(), a)),
    );

    assert!(
        distribute_not_over_and
            .apply(&not_a, &SymbolTable::new())
            .is_err()
    );
}
534

            
535
/// `distribute_not_over_or` must not apply to a `Not` wrapping a plain atom.
#[test]
fn rule_distribute_not_over_or_not_changed() {
    let distribute_not_over_or = get_rule_by_name("distribute_not_over_or").unwrap();

    let a = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("a"),
        Domain::int(vec![Range::Bounded(1, 5)]),
    )));
    let not_a = Expression::Not(
        Metadata::new(),
        Moo::new(Expression::Atomic(Metadata::new(), a)),
    );

    assert!(
        distribute_not_over_or
            .apply(&not_a, &SymbolTable::new())
            .is_err()
    );
}
554

            
555
/// `distribute_or_over_and` rewrites `(d1 /\ d2) \/ d2` to
/// `(d2 \/ d1) /\ (d2 \/ d2)`.
#[test]
fn rule_distribute_or_over_and() {
    let atomic = |atom: &Atom| Expression::Atomic(Metadata::new(), atom.clone());
    let or_of = |left: &Atom, right: &Atom| {
        Expression::Or(
            Metadata::new(),
            Moo::new(matrix_expr![
                Expression::Atomic(Metadata::new(), left.clone()),
                Expression::Atomic(Metadata::new(), right.clone()),
            ]),
        )
    };

    let distribute_or_over_and = get_rule_by_name("distribute_or_over_and").unwrap();

    let d1 = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::Machine(1),
        Domain::bool(),
    )));
    let d2 = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::Machine(2),
        Domain::bool(),
    )));

    let input = Expression::Or(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::And(
                Metadata::new(),
                Moo::new(matrix_expr![atomic(&d1), atomic(&d2)]),
            ),
            atomic(&d2),
        ]),
    );

    let reduction = distribute_or_over_and
        .apply(&input, &SymbolTable::new())
        .unwrap();

    let expected = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![or_of(&d2, &d1), or_of(&d2, &d2)]),
    );
    assert_eq!(reduction.new_expression, expected);
}
610

            
611
///
/// Reduce and solve:
/// ```text
/// find a,b,c : int(1..5)
/// such that a + b + c = 4
/// such that a < b
/// ```
///
/// This test uses the rewrite function to simplify the expression instead
/// of applying the rules manually.
#[test]
fn rewrite_solve_xyz() {
    println!("Rules: {:?}", get_all_rules());
    set_comprehension_expander(QuantifiedExpander::Native);

    // Resolved once and reused for both the debug print and the rewrite.
    let rule_sets = match resolve_rule_sets(SolverFamily::Minion, &["Constant"]) {
        Ok(rs) => rs,
        Err(e) => {
            eprintln!("Error resolving rule sets: {e}");
            exit(1);
        }
    };
    println!("Rule sets: {rule_sets:?}");

    // Create variables and domains
    let decl_a = DeclarationPtr::new_find(Name::user("a"), Domain::int(vec![Range::Bounded(1, 5)]));
    let decl_b = DeclarationPtr::new_find(Name::user("b"), Domain::int(vec![Range::Bounded(1, 5)]));
    let decl_c = DeclarationPtr::new_find(Name::user("c"), Domain::int(vec![Range::Bounded(1, 5)]));

    // The constraint references share the declarations inserted into the model below.
    let a = Atom::Reference(Reference::new(decl_a.clone()));
    let b = Atom::Reference(Reference::new(decl_b.clone()));
    let c = Atom::Reference(Reference::new(decl_c.clone()));

    // Construct nested expression: (a + b + c = 4) /\ (a < b)
    let nested_expr = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::Eq(
                Metadata::new(),
                Moo::new(Expression::Sum(
                    Metadata::new(),
                    Moo::new(matrix_expr![
                        Expression::Atomic(Metadata::new(), a.clone()),
                        Expression::Atomic(Metadata::new(), b.clone()),
                        Expression::Atomic(Metadata::new(), c),
                    ]),
                )),
                Moo::new(Expression::Atomic(
                    Metadata::new(),
                    Atom::Literal(Literal::Int(4)),
                )),
            ),
            Expression::Lt(
                Metadata::new(),
                Moo::new(Expression::Atomic(Metadata::new(), a)),
                Moo::new(Expression::Atomic(Metadata::new(), b)),
            ),
        ]),
    );

    // Apply rewrite function to the nested expression
    let mut model = Model::new(Default::default());

    // Insert variables and domains
    model.symbols_mut().insert(decl_a).unwrap();
    model.symbols_mut().insert(decl_b).unwrap();
    model.symbols_mut().insert(decl_c).unwrap();

    *model.constraints_mut() = vec![nested_expr];

    model = rewrite_naive(&model, &rule_sets, true).unwrap();
    let rewritten_expr = model.constraints();

    // Check if the expression is in its simplest form
    assert!(rewritten_expr.iter().all(is_simple));

    let solver: Solver = Solver::new(adaptors::Minion::new());
    let solver = solver.load_model(model).unwrap();
    solver.solve(Box::new(|_| true)).unwrap();
}
702

            
703
struct RuleResult<'a> {
704
    #[allow(dead_code)]
705
    rule: &'a Rule<'a>,
706
    new_expression: Expression,
707
}
708

            
709
/// # Returns
710
/// - True if `expression` is in its simplest form.
711
/// - False otherwise.
712
9
pub fn is_simple(expression: &Expression) -> bool {
713
9
    let rules = get_all_rules();
714
9
    let mut new = expression.clone();
715
9
    while let Some(step) = is_simple_iteration(&new, &rules) {
716
        new = step;
717
    }
718
9
    new == *expression
719
9
}
720

            
721
/// # Returns
722
/// - Some(<new_expression>) after applying the first applicable rule to `expr` or a sub-expression.
723
/// - None if no rule is applicable to the expression or any sub-expression.
724
9
fn is_simple_iteration<'a>(
725
9
    expression: &'a Expression,
726
9
    rules: &'a Vec<&'a Rule<'a>>,
727
9
) -> Option<Expression> {
728
9
    let rule_results = apply_all_rules(expression, rules);
729
9
    if let Some(new) = choose_rewrite(&rule_results) {
730
        return Some(new);
731
    } else {
732
9
        let mut sub = expression.children();
733
9
        for i in 0..sub.len() {
734
            if let Some(new) = is_simple_iteration(&sub[i], rules) {
735
                sub[i] = new;
736
                return Some(expression.with_children(sub.clone()));
737
            }
738
        }
739
    }
740
9
    None // No rules applicable to this branch of the expression
741
9
}
742

            
743
/// # Returns
744
/// - A list of RuleResults after applying all rules to `expression`.
745
/// - An empty list if no rules are applicable.
746
9
fn apply_all_rules<'a>(
747
9
    expression: &'a Expression,
748
9
    rules: &'a Vec<&'a Rule<'a>>,
749
9
) -> Vec<RuleResult<'a>> {
750
9
    let mut results = Vec::new();
751
1209
    for rule in rules {
752
1209
        match rule.apply(expression, &SymbolTable::new()) {
753
            Ok(red) => {
754
                results.push(RuleResult {
755
                    rule,
756
                    new_expression: red.new_expression,
757
                });
758
            }
759
1209
            Err(_) => continue,
760
        }
761
    }
762
9
    results
763
9
}
764

            
765
/// # Returns
766
/// - Some(<new_expression>) after applying the first rule in `results`.
767
/// - None if `results` is empty.
768
9
fn choose_rewrite(results: &[RuleResult]) -> Option<Expression> {
769
9
    if results.is_empty() {
770
9
        return None;
771
    }
772
    // Return the first result for now
773
    // println!("Applying rule: {:?}", results[0].rule);
774
    Some(results[0].new_expression.clone())
775
9
}
776

            
777
/// An integer literal evaluates to itself.
#[test]
fn eval_const_int() {
    let one = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
    assert_eq!(eval_constant(&one), Some(Literal::Int(1)));
}
783

            
784
/// A boolean literal evaluates to itself.
#[test]
fn eval_const_bool() {
    let truth = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
    assert_eq!(eval_constant(&truth), Some(Literal::Bool(true)));
}
790

            
791
/// `true /\ false` evaluates to `false`.
#[test]
fn eval_const_and() {
    let boolean =
        |value: bool| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(value)));

    let conjunction = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![boolean(true), boolean(false)]),
    );

    assert_eq!(eval_constant(&conjunction), Some(Literal::Bool(false)));
}
803

            
804
/// A reference to a decision variable has no constant value.
#[test]
fn eval_const_ref() {
    let a = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("a"),
        Domain::int(vec![Range::Bounded(1, 5)]),
    )));
    let reference = Expression::Atomic(Metadata::new(), a);

    assert_eq!(eval_constant(&reference), None);
}
816

            
817
/// A sum with a variable reference nested inside one operand has no constant value.
#[test]
fn eval_const_nested_ref() {
    let a = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("a"),
        Domain::int(vec![Range::Bounded(1, 5)]),
    )));

    // true /\ a
    let conjunction = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true))),
            Expression::Atomic(Metadata::new(), a),
        ]),
    );

    // 1 + (true /\ a)
    let sum = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1))),
            conjunction,
        ]),
    );

    assert_eq!(eval_constant(&sum), None);
}
841

            
842
/// `1 = 1` evaluates to `true`.
#[test]
fn eval_const_eq_int() {
    let one = || Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));

    let equality = Expression::Eq(Metadata::new(), Moo::new(one()), Moo::new(one()));

    assert_eq!(eval_constant(&equality), Some(Literal::Bool(true)));
}
858

            
859
/// `true = true` evaluates to `true`.
#[test]
fn eval_const_eq_bool() {
    let truth = || Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));

    let equality = Expression::Eq(Metadata::new(), Moo::new(truth()), Moo::new(truth()));

    assert_eq!(eval_constant(&equality), Some(Literal::Bool(true)));
}
875

            
876
/// Equality between an integer literal and a boolean literal (`1 = true`):
/// the test expects `eval_constant` to return `None`.
#[test]
fn eval_const_eq_mixed() {
    // Boxes a literal as an atomic expression with fresh metadata.
    let literal = |l: Literal| Moo::new(Expression::Atomic(Metadata::new(), Atom::Literal(l)));

    let int_side = literal(Literal::Int(1));
    let bool_side = literal(Literal::Bool(true));
    let expr = Expression::Eq(Metadata::new(), int_side, bool_side);

    assert_eq!(eval_constant(&expr), None);
}
892

            
893
/// A sum whose operands are the integer literal `1` and the boolean literal
/// `true`: the test expects `eval_constant` to return `None`.
#[test]
fn eval_const_sum_mixed() {
    let int_operand = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
    let bool_operand = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));

    let operands = matrix_expr![int_operand, bool_operand];
    let expr = Expression::Sum(Metadata::new(), Moo::new(operands));

    assert_eq!(eval_constant(&expr), None);
}
905

            
906
/// Builds the conjunction `(x + y + z = 4) /\ (x >= y)` over find variables
/// with int domain `1..5` and expects `eval_constant` to return `None`.
///
/// Every variable occurrence comes from its own `DeclarationPtr::new_find`
/// call, so the `x` and `y` in the second conjunct are declared separately
/// from the ones inside the sum.
#[test]
fn eval_const_sum_xyz() {
    // Declares a find variable called `name` and wraps a reference to it.
    let var = |name: &str| {
        Expression::Atomic(
            Metadata::new(),
            Atom::Reference(Reference::new(DeclarationPtr::new_find(
                Name::user(name),
                Domain::int(vec![Range::Bounded(1, 5)]),
            ))),
        )
    };

    // x + y + z = 4
    let operands = matrix_expr![var("x"), var("y"), var("z")];
    let total = Expression::Sum(Metadata::new(), Moo::new(operands));
    let four = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(4)));
    let sum_is_four = Expression::Eq(Metadata::new(), Moo::new(total), Moo::new(four));

    // x >= y
    let x_geq_y = Expression::Geq(Metadata::new(), Moo::new(var("x")), Moo::new(var("y")));

    let expr = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![sum_is_four, x_geq_y]),
    );

    assert_eq!(eval_constant(&expr), None);
}
966

            
967
/// `false \/ false` over boolean literals: the test expects `eval_constant`
/// to return `Some(Literal::Bool(false))`.
#[test]
fn eval_const_or() {
    // Each call yields a fresh `false` literal expression.
    let falsum = || Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false)));

    let disjuncts = matrix_expr![falsum(), falsum()];
    let expr = Expression::Or(Metadata::new(), Moo::new(disjuncts));

    assert_eq!(eval_constant(&expr), Some(Literal::Bool(false)));
}