1
use conjure_cp::{
2
    Model,
3
    ast::{
4
        Atom, DeclarationPtr, Domain, Expression, Literal, Metadata, Moo, Name, Range, Reference,
5
        SymbolTable, eval_constant,
6
    },
7
    into_matrix_expr, matrix_expr,
8
    rule_engine::{Rule, get_all_rules, get_rule_by_name, resolve_rule_sets, rewrite_naive},
9
    settings::SolverFamily,
10
    solver::{Solver, adaptors},
11
};
12
#[allow(unused_imports)]
13
#[allow(clippy::single_component_path_imports)] // ensure this is linked so we can lookup rules
14
use conjure_cp_rules;
15
use pretty_assertions::assert_eq;
16
use std::process::exit;
17
use uniplate::Uniplate;
18

            
19
/// Checks that `get_all_rules` returns at least one rule.
#[test]
fn rules_present() {
    assert!(!get_all_rules().is_empty());
}
24

            
25
/// Exercises `evaluate_sum_of_constants` on a sum made only of integer
/// literals and on a sum that also contains a reference.
#[test]
fn sum_of_constants() {
    let int = |n: i32| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(n)));

    // 1 + 2 + 3: every operand is an integer literal.
    let valid_sum_expression = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![int(1), int(2), int(3)]),
    );

    // 1 + a: the second operand is a reference, not a literal.
    let bool_a = DeclarationPtr::new_find(Name::user("a"), Domain::bool());
    let invalid_sum_expression = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![
            int(1),
            Expression::Atomic(Metadata::new(), Atom::Reference(Reference::new(bool_a))),
        ]),
    );

    assert_eq!(evaluate_sum_of_constants(&valid_sum_expression), Some(6));
    assert_eq!(evaluate_sum_of_constants(&invalid_sum_expression), None);
}
54

            
55
4
fn evaluate_sum_of_constants(expr: &Expression) -> Option<i32> {
56
4
    match expr {
57
4
        Expression::Sum(_metadata, expressions) => {
58
4
            let expressions = (**expressions).clone().unwrap_list()?;
59
4
            let mut sum = 0;
60
10
            for e in expressions {
61
8
                match e {
62
8
                    Expression::Atomic(_, Atom::Literal(Literal::Int(value))) => {
63
8
                        sum += value;
64
8
                    }
65
2
                    _ => return None,
66
                }
67
            }
68
2
            Some(sum)
69
        }
70
        _ => None,
71
    }
72
4
}
73

            
74
/// Checks that `simplify_expression` folds a nested constant sum while leaving
/// the enclosing sum, which contains a reference, in place.
#[test]
fn recursive_sum_of_constants() {
    let int = |n: i32| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(n)));

    let a = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("a"),
        Domain::int(vec![Range::Bounded(1, 5)]),
    )));

    // (1 + 2 + (1 + 2) + a) = 3
    let nested_constant_sum =
        Expression::Sum(Metadata::new(), Moo::new(matrix_expr![int(1), int(2)]));
    let complex_expression = Expression::Eq(
        Metadata::new(),
        Moo::new(Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![
                int(1),
                int(2),
                nested_constant_sum,
                Expression::Atomic(Metadata::new(), a.clone()),
            ]),
        )),
        Moo::new(int(3)),
    );

    // (1 + 2 + 3 + a) = 3
    let correct_simplified_expression = Expression::Eq(
        Metadata::new(),
        Moo::new(Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![
                int(1),
                int(2),
                int(3),
                Expression::Atomic(Metadata::new(), a),
            ]),
        )),
        Moo::new(int(3)),
    );

    assert_eq!(
        simplify_expression(complex_expression),
        correct_simplified_expression
    );
}
122

            
123
7
fn simplify_expression(expr: Expression) -> Expression {
124
7
    match expr {
125
2
        Expression::Sum(_metadata, expressions) => {
126
2
            let expressions = Moo::unwrap_or_clone(expressions).unwrap_list().unwrap();
127
2
            if let Some(result) = evaluate_sum_of_constants(&Expression::Sum(
128
2
                Metadata::new(),
129
2
                Moo::new(into_matrix_expr![expressions.clone()]),
130
2
            )) {
131
1
                Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(result)))
132
            } else {
133
1
                Expression::Sum(
134
1
                    Metadata::new(),
135
1
                    Moo::new(into_matrix_expr![
136
1
                        expressions.into_iter().map(simplify_expression).collect()
137
1
                    ]),
138
1
                )
139
            }
140
        }
141
1
        Expression::Eq(_metadata, left, right) => Expression::Eq(
142
1
            Metadata::new(),
143
1
            Moo::new(simplify_expression(Moo::unwrap_or_clone(left))),
144
1
            Moo::new(simplify_expression(Moo::unwrap_or_clone(right))),
145
1
        ),
146
        Expression::Geq(_metadata, left, right) => Expression::Geq(
147
            Metadata::new(),
148
            Moo::new(simplify_expression(Moo::unwrap_or_clone(left))),
149
            Moo::new(simplify_expression(Moo::unwrap_or_clone(right))),
150
        ),
151
4
        _ => expr,
152
    }
153
7
}
154

            
155
/// Applies `partial_evaluator` and then `remove_unit_vector_sum` to `1 + 2 + 3`
/// and expects the literal `6`.
#[test]
fn rule_sum_constants() {
    let sum_constants = get_rule_by_name("partial_evaluator").unwrap();
    let unwrap_sum = get_rule_by_name("remove_unit_vector_sum").unwrap();

    let int = |n: i32| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(n)));

    // 1 + 2 + 3
    let constant_sum = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![int(1), int(2), int(3)]),
    );

    let evaluated = sum_constants
        .apply(&constant_sum, &SymbolTable::new())
        .unwrap()
        .new_expression;
    let unwrapped = unwrap_sum
        .apply(&evaluated, &SymbolTable::new())
        .unwrap()
        .new_expression;

    assert_eq!(unwrapped, int(6));
}
183

            
184
/// Applies `introduce_weighted_sumleq_sumgeq` to `(1 + 2) >= 3` and expects a
/// `FlatSumGeq` over the same atoms.
#[test]
fn rule_sum_geq() {
    let introduce_sumgeq = get_rule_by_name("introduce_weighted_sumleq_sumgeq").unwrap();

    let int_atom = |n: i32| Atom::Literal(Literal::Int(n));
    let int_expr = |n: i32| Expression::Atomic(Metadata::new(), int_atom(n));

    // (1 + 2) >= 3
    let sum_geq = Expression::Geq(
        Metadata::new(),
        Moo::new(Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![int_expr(1), int_expr(2)]),
        )),
        Moo::new(int_expr(3)),
    );

    let flattened = introduce_sumgeq
        .apply(&sum_geq, &SymbolTable::new())
        .unwrap()
        .new_expression;

    let expected = Expression::FlatSumGeq(
        Metadata::new(),
        vec![int_atom(1), int_atom(2)],
        int_atom(3),
    );
    assert_eq!(flattened, expected);
}
220

            
221
/// Reduce and solve:
/// ```text
/// find a,b,c : int(1..3)
/// such that a + b + c <= 2 + 3 - 1
/// such that a < b
/// ```
///
/// Each rule is applied by hand and every intermediate expression is asserted
/// before the resulting constraints are loaded into the Minion solver.
#[test]
fn reduce_solve_xyz() {
    println!("Rules: {:?}", get_all_rules());
    let sum_constants = get_rule_by_name("partial_evaluator").unwrap();
    let unwrap_sum = get_rule_by_name("remove_unit_vector_sum").unwrap();
    let lt_to_leq = get_rule_by_name("lt_to_leq").unwrap();
    let leq_to_ineq = get_rule_by_name("x_leq_y_plus_k_to_ineq").unwrap();
    let introduce_sumleq = get_rule_by_name("introduce_weighted_sumleq_sumgeq").unwrap();

    let int = |n: i32| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(n)));
    let var = |atom: &Atom| Expression::Atomic(Metadata::new(), atom.clone());
    let find_int = |name: &str, upper: i32| {
        DeclarationPtr::new_find(Name::user(name), Domain::int(vec![Range::Bounded(1, upper)]))
    };

    // 2 + 3 - 1, which the two rules below reduce to the literal 4.
    let constant_sum = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![int(2), int(3), int(-1)]),
    );
    let bound = sum_constants
        .apply(&constant_sum, &SymbolTable::new())
        .unwrap()
        .new_expression;
    let bound = unwrap_sum
        .apply(&bound, &SymbolTable::new())
        .unwrap()
        .new_expression;
    assert_eq!(bound, int(4));

    // NOTE(review): the references used in the constraints point at declarations
    // with domain int(1..5), while the declarations inserted into the model
    // below use int(1..3) -- confirm this difference is intended.
    let a = Atom::Reference(Reference::new(find_int("a", 5)));
    let b = Atom::Reference(Reference::new(find_int("b", 5)));
    let c = Atom::Reference(Reference::new(find_int("c", 5)));

    // a + b + c <= 4
    let sum_leq = Expression::Leq(
        Metadata::new(),
        Moo::new(Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![var(&a), var(&b), var(&c)]),
        )),
        Moo::new(bound),
    );
    let sum_leq = introduce_sumleq
        .apply(&sum_leq, &SymbolTable::new())
        .unwrap()
        .new_expression;
    assert_eq!(
        sum_leq,
        Expression::FlatSumLeq(
            Metadata::new(),
            vec![a.clone(), b.clone(), c],
            Atom::Literal(Literal::Int(4))
        )
    );

    // a < b
    let a_lt_b = Expression::Lt(Metadata::new(), Moo::new(var(&a)), Moo::new(var(&b)));
    let a_lt_b = lt_to_leq
        .apply(&a_lt_b, &SymbolTable::new())
        .unwrap()
        .new_expression;
    let a_lt_b = leq_to_ineq
        .apply(&a_lt_b, &SymbolTable::new())
        .unwrap()
        .new_expression;
    assert_eq!(
        a_lt_b,
        Expression::FlatIneq(
            Metadata::new(),
            Moo::new(a),
            Moo::new(b),
            Box::new(Literal::Int(-1)),
        )
    );

    let mut model = Model::new(Default::default());
    *model.as_submodel_mut().constraints_mut() = vec![sum_leq, a_lt_b];

    // Declare a, b and c in the model's symbol table.
    for name in ["a", "b", "c"] {
        model
            .as_submodel_mut()
            .symbols_mut()
            .insert(find_int(name, 3))
            .unwrap();
    }

    let solver: Solver = Solver::new(adaptors::Minion::new());
    let solver = solver.load_model(model).unwrap();
    solver.solve(Box::new(|_| true)).unwrap();
}
356

            
357
/// Applies `remove_double_negation` to `not(not(true))` and expects `true`.
#[test]
fn rule_remove_double_negation() {
    let remove_double_negation = get_rule_by_name("remove_double_negation").unwrap();

    let truth = || Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
    let negate = |inner: Expression| Expression::Not(Metadata::new(), Moo::new(inner));

    let double_negated = negate(negate(truth()));

    let reduced = remove_double_negation
        .apply(&double_negated, &SymbolTable::new())
        .unwrap()
        .new_expression;

    assert_eq!(reduced, truth());
}
382

            
383
/// Applies the unit-vector rules to a one-element `And` and a one-element `Or`
/// and expects the single element back in each case.
#[test]
fn remove_trivial_and_or() {
    let remove_trivial_and = get_rule_by_name("remove_unit_vector_and").unwrap();
    let remove_trivial_or = get_rule_by_name("remove_unit_vector_or").unwrap();

    let boolean =
        |value: bool| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(value)));

    let unit_and = Expression::And(Metadata::new(), Moo::new(matrix_expr![boolean(true)]));
    let unit_or = Expression::Or(Metadata::new(), Moo::new(matrix_expr![boolean(false)]));

    let reduced_and = remove_trivial_and
        .apply(&unit_and, &SymbolTable::new())
        .unwrap()
        .new_expression;
    let reduced_or = remove_trivial_or
        .apply(&unit_or, &SymbolTable::new())
        .unwrap()
        .new_expression;

    assert_eq!(reduced_and, boolean(true));
    assert_eq!(reduced_or, boolean(false));
}
421

            
422
/// Applies `distribute_not_over_and` to `not(a /\ b)` and expects
/// `not(a) \/ not(b)`.
#[test]
fn rule_distribute_not_over_and() {
    let distribute_not_over_and = get_rule_by_name("distribute_not_over_and").unwrap();

    let bool_ref = |name: &str| {
        Atom::Reference(Reference::new(DeclarationPtr::new_find(
            Name::user(name),
            Domain::bool(),
        )))
    };
    let var = |atom: &Atom| Expression::Atomic(Metadata::new(), atom.clone());
    let negate = |inner: Expression| Expression::Not(Metadata::new(), Moo::new(inner));

    let a = bool_ref("a");
    let b = bool_ref("b");

    // not(a /\ b)
    let negated_and = negate(Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![var(&a), var(&b)]),
    ));

    let distributed = distribute_not_over_and
        .apply(&negated_and, &SymbolTable::new())
        .unwrap()
        .new_expression;

    // not(a) \/ not(b)
    let expected = Expression::Or(
        Metadata::new(),
        Moo::new(matrix_expr![negate(var(&a)), negate(var(&b))]),
    );
    assert_eq!(distributed, expected);
}
469

            
470
/// Applies `distribute_not_over_or` to `not(a \/ b)` and expects
/// `not(a) /\ not(b)`.
#[test]
fn rule_distribute_not_over_or() {
    let distribute_not_over_or = get_rule_by_name("distribute_not_over_or").unwrap();

    let bool_ref = |name: &str| {
        Atom::Reference(Reference::new(DeclarationPtr::new_find(
            Name::user(name),
            Domain::bool(),
        )))
    };
    let var = |atom: &Atom| Expression::Atomic(Metadata::new(), atom.clone());
    let negate = |inner: Expression| Expression::Not(Metadata::new(), Moo::new(inner));

    let a = bool_ref("a");
    let b = bool_ref("b");

    // not(a \/ b)
    let negated_or = negate(Expression::Or(
        Metadata::new(),
        Moo::new(matrix_expr![var(&a), var(&b)]),
    ));

    let distributed = distribute_not_over_or
        .apply(&negated_or, &SymbolTable::new())
        .unwrap()
        .new_expression;

    // not(a) /\ not(b)
    let expected = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![negate(var(&a)), negate(var(&b))]),
    );
    assert_eq!(distributed, expected);
}
517

            
518
/// Expects `distribute_not_over_and` to return an error for `not(a)`, where the
/// negated operand is a plain reference.
#[test]
fn rule_distribute_not_over_and_not_changed() {
    let distribute_not_over_and = get_rule_by_name("distribute_not_over_and").unwrap();

    let int_a = DeclarationPtr::new_find(Name::user("a"), Domain::int(vec![Range::Bounded(1, 5)]));
    let not_a = Expression::Not(
        Metadata::new(),
        Moo::new(Expression::Atomic(
            Metadata::new(),
            Atom::Reference(Reference::new(int_a)),
        )),
    );

    assert!(
        distribute_not_over_and
            .apply(&not_a, &SymbolTable::new())
            .is_err()
    );
}
537

            
538
/// Expects `distribute_not_over_or` to return an error for `not(a)`, where the
/// negated operand is a plain reference.
#[test]
fn rule_distribute_not_over_or_not_changed() {
    let distribute_not_over_or = get_rule_by_name("distribute_not_over_or").unwrap();

    let int_a = DeclarationPtr::new_find(Name::user("a"), Domain::int(vec![Range::Bounded(1, 5)]));
    let not_a = Expression::Not(
        Metadata::new(),
        Moo::new(Expression::Atomic(
            Metadata::new(),
            Atom::Reference(Reference::new(int_a)),
        )),
    );

    assert!(
        distribute_not_over_or
            .apply(&not_a, &SymbolTable::new())
            .is_err()
    );
}
557

            
558
/// Applies `distribute_or_over_and` to `(d1 /\ d2) \/ d2` and expects
/// `(d2 \/ d1) /\ (d2 \/ d2)`.
#[test]
fn rule_distribute_or_over_and() {
    let distribute_or_over_and = get_rule_by_name("distribute_or_over_and").unwrap();

    let machine_bool = |id| {
        Atom::Reference(Reference::new(DeclarationPtr::new_find(
            Name::Machine(id),
            Domain::bool(),
        )))
    };
    let var = |atom: &Atom| Expression::Atomic(Metadata::new(), atom.clone());

    let d1 = machine_bool(1);
    let d2 = machine_bool(2);

    // (d1 /\ d2) \/ d2
    let expr = Expression::Or(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::And(Metadata::new(), Moo::new(matrix_expr![var(&d1), var(&d2)])),
            var(&d2),
        ]),
    );

    let red = distribute_or_over_and
        .apply(&expr, &SymbolTable::new())
        .unwrap();

    // (d2 \/ d1) /\ (d2 \/ d2)
    let expected = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::Or(Metadata::new(), Moo::new(matrix_expr![var(&d2), var(&d1)])),
            Expression::Or(Metadata::new(), Moo::new(matrix_expr![var(&d2), var(&d2)])),
        ]),
    );
    assert_eq!(red.new_expression, expected);
}
613

            
614
/// Reduce and solve:
/// ```text
/// find a,b,c : int(1..5)
/// such that a + b + c = 4
/// such that a < b
/// ```
///
/// This test uses the rewrite function to simplify the expression instead
/// of applying the rules manually.
#[test]
fn rewrite_solve_xyz() {
    println!("Rules: {:?}", get_all_rules());

    // Resolve the rule sets once; the same value is printed here and passed to
    // the rewriter below.
    // NOTE(review): `exit` terminates the whole test process, so a failure here
    // is not reported through the test harness. Panicking would be preferable,
    // but would leave the file-level `exit` import unused.
    let rule_sets = match resolve_rule_sets(SolverFamily::Minion, &["Constant"]) {
        Ok(rs) => rs,
        Err(e) => {
            eprintln!("Error resolving rule sets: {e}");
            exit(1);
        }
    };
    println!("Rule sets: {rule_sets:?}");

    // Create variables and domains
    let decl_a = DeclarationPtr::new_find(Name::user("a"), Domain::int(vec![Range::Bounded(1, 5)]));
    let decl_b = DeclarationPtr::new_find(Name::user("b"), Domain::int(vec![Range::Bounded(1, 5)]));
    let decl_c = DeclarationPtr::new_find(Name::user("c"), Domain::int(vec![Range::Bounded(1, 5)]));

    let a = Atom::Reference(Reference::new(decl_a.clone()));
    let b = Atom::Reference(Reference::new(decl_b.clone()));
    let c = Atom::Reference(Reference::new(decl_c.clone()));

    // Construct nested expression: (a + b + c = 4) /\ (a < b)
    let nested_expr = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::Eq(
                Metadata::new(),
                Moo::new(Expression::Sum(
                    Metadata::new(),
                    Moo::new(matrix_expr![
                        Expression::Atomic(Metadata::new(), a.clone()),
                        Expression::Atomic(Metadata::new(), b.clone()),
                        Expression::Atomic(Metadata::new(), c),
                    ]),
                )),
                Moo::new(Expression::Atomic(
                    Metadata::new(),
                    Atom::Literal(Literal::Int(4)),
                )),
            ),
            Expression::Lt(
                Metadata::new(),
                Moo::new(Expression::Atomic(Metadata::new(), a)),
                Moo::new(Expression::Atomic(Metadata::new(), b)),
            ),
        ]),
    );

    let mut model = Model::new(Default::default());

    // Insert variables and domains
    for decl in [decl_a, decl_b, decl_c] {
        model
            .as_submodel_mut()
            .symbols_mut()
            .insert(decl)
            .unwrap();
    }

    *model.as_submodel_mut().constraints_mut() = vec![nested_expr];

    // Apply rewrite function to the nested expression
    model = rewrite_naive(&model, &rule_sets, true, false).unwrap();
    let rewritten_expr = model.as_submodel().constraints();

    // Check if the expression is in its simplest form
    assert!(rewritten_expr.iter().all(is_simple));

    let solver: Solver = Solver::new(adaptors::Minion::new());
    let solver = solver.load_model(model).unwrap();
    solver.solve(Box::new(|_| true)).unwrap();
}
716

            
717
struct RuleResult<'a> {
718
    #[allow(dead_code)]
719
    rule: &'a Rule<'a>,
720
    new_expression: Expression,
721
}
722

            
723
/// # Returns
724
/// - True if `expression` is in its simplest form.
725
/// - False otherwise.
726
3
pub fn is_simple(expression: &Expression) -> bool {
727
3
    let rules = get_all_rules();
728
3
    let mut new = expression.clone();
729
3
    while let Some(step) = is_simple_iteration(&new, &rules) {
730
        new = step;
731
    }
732
3
    new == *expression
733
3
}
734

            
735
/// # Returns
736
/// - Some(<new_expression>) after applying the first applicable rule to `expr` or a sub-expression.
737
/// - None if no rule is applicable to the expression or any sub-expression.
738
3
fn is_simple_iteration<'a>(
739
3
    expression: &'a Expression,
740
3
    rules: &'a Vec<&'a Rule<'a>>,
741
3
) -> Option<Expression> {
742
3
    let rule_results = apply_all_rules(expression, rules);
743
3
    if let Some(new) = choose_rewrite(&rule_results) {
744
        return Some(new);
745
    } else {
746
3
        let mut sub = expression.children();
747
3
        for i in 0..sub.len() {
748
            if let Some(new) = is_simple_iteration(&sub[i], rules) {
749
                sub[i] = new;
750
                return Some(expression.with_children(sub.clone()));
751
            }
752
        }
753
    }
754
3
    None // No rules applicable to this branch of the expression
755
3
}
756

            
757
/// # Returns
758
/// - A list of RuleResults after applying all rules to `expression`.
759
/// - An empty list if no rules are applicable.
760
3
fn apply_all_rules<'a>(
761
3
    expression: &'a Expression,
762
3
    rules: &'a Vec<&'a Rule<'a>>,
763
3
) -> Vec<RuleResult<'a>> {
764
3
    let mut results = Vec::new();
765
393
    for rule in rules {
766
393
        match rule.apply(expression, &SymbolTable::new()) {
767
            Ok(red) => {
768
                results.push(RuleResult {
769
                    rule,
770
                    new_expression: red.new_expression,
771
                });
772
            }
773
393
            Err(_) => continue,
774
        }
775
    }
776
3
    results
777
3
}
778

            
779
/// # Returns
780
/// - Some(<new_expression>) after applying the first rule in `results`.
781
/// - None if `results` is empty.
782
3
fn choose_rewrite(results: &[RuleResult]) -> Option<Expression> {
783
3
    if results.is_empty() {
784
3
        return None;
785
    }
786
    // Return the first result for now
787
    // println!("Applying rule: {:?}", results[0].rule);
788
    Some(results[0].new_expression.clone())
789
3
}
790

            
791
/// `eval_constant` on the integer literal `1` yields that literal.
#[test]
fn eval_const_int() {
    let one = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
    assert_eq!(eval_constant(&one), Some(Literal::Int(1)));
}
797

            
798
/// `eval_constant` on the boolean literal `true` yields that literal.
#[test]
fn eval_const_bool() {
    let truth = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
    assert_eq!(eval_constant(&truth), Some(Literal::Bool(true)));
}
804

            
805
/// `eval_constant` on `true /\ false` yields `false`.
#[test]
fn eval_const_and() {
    let boolean =
        |value: bool| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(value)));

    let conjunction = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![boolean(true), boolean(false)]),
    );

    assert_eq!(eval_constant(&conjunction), Some(Literal::Bool(false)));
}
817

            
818
/// `eval_constant` on a bare reference yields `None`.
#[test]
fn eval_const_ref() {
    let int_a = DeclarationPtr::new_find(Name::user("a"), Domain::int(vec![Range::Bounded(1, 5)]));
    let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Reference::new(int_a)));

    assert_eq!(eval_constant(&reference), None);
}
830

            
831
/// `eval_constant` yields `None` when a reference is nested inside an
/// otherwise constant expression.
#[test]
fn eval_const_nested_ref() {
    let int_a = DeclarationPtr::new_find(Name::user("a"), Domain::int(vec![Range::Bounded(1, 5)]));

    // true /\ a
    let conjunction = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true))),
            Expression::Atomic(Metadata::new(), Atom::Reference(Reference::new(int_a))),
        ]),
    );

    // 1 + (true /\ a)
    let sum = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1))),
            conjunction,
        ]),
    );

    assert_eq!(eval_constant(&sum), None);
}
855

            
856
/// `eval_constant` on `1 = 1` yields `true`.
#[test]
fn eval_const_eq_int() {
    let one = || Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));

    let equality = Expression::Eq(Metadata::new(), Moo::new(one()), Moo::new(one()));

    assert_eq!(eval_constant(&equality), Some(Literal::Bool(true)));
}
872

            
873
/// Checks that `eval_constant` evaluates the equality of two boolean literals
/// `true = true` to `Some(Literal::Bool(true))`.
#[test]
fn eval_const_eq_bool() {
    let lhs = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
    let rhs = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
    let equality = Expression::Eq(Metadata::new(), Moo::new(lhs), Moo::new(rhs));

    assert_eq!(eval_constant(&equality), Some(Literal::Bool(true)));
}
889

            
890
/// Checks that `eval_constant` yields `None` when an integer literal is compared for
/// equality with a boolean literal (`1 = true`).
#[test]
fn eval_const_eq_mixed() {
    let int_side = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
    let bool_side = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
    let equality = Expression::Eq(Metadata::new(), Moo::new(int_side), Moo::new(bool_side));

    assert_eq!(eval_constant(&equality), None);
}
906

            
907
/// Checks that `eval_constant` yields `None` for a sum that mixes an integer literal
/// with a boolean literal (`1 + true`).
#[test]
fn eval_const_sum_mixed() {
    let int_term = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
    let bool_term = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
    let terms = matrix_expr![
        int_term,
        bool_term,
    ];
    let sum = Expression::Sum(Metadata::new(), Moo::new(terms));

    assert_eq!(eval_constant(&sum), None);
}
919

            
920
/// Checks that `eval_constant` yields `None` for the conjunction
/// `and(x + y + z = 4, x >= y)`, where every operand other than `4` is a reference.
#[test]
fn eval_const_sum_xyz() {
    // Builds a reference expression for `name` over an integer domain bounded by 1 and 5.
    // Every call goes through its own `DeclarationPtr::new_find`, so each occurrence of a
    // name below is constructed separately.
    let var = |name: &'static str| {
        Expression::Atomic(
            Metadata::new(),
            Atom::Reference(Reference::new(DeclarationPtr::new_find(
                Name::user(name),
                Domain::int(vec![Range::Bounded(1, 5)]),
            ))),
        )
    };

    // x + y + z = 4
    let total = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![
            var("x"),
            var("y"),
            var("z"),
        ]),
    );
    let four = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(4)));
    let sum_is_four = Expression::Eq(Metadata::new(), Moo::new(total), Moo::new(four));

    // x >= y
    let geq_lhs = Moo::new(var("x"));
    let geq_rhs = Moo::new(var("y"));
    let x_geq_y = Expression::Geq(Metadata::new(), geq_lhs, geq_rhs);

    let constraints = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![
            sum_is_four,
            x_geq_y,
        ]),
    );

    assert_eq!(eval_constant(&constraints), None);
}
980

            
981
/// Checks that `eval_constant` folds the disjunction of two `false` literals into
/// `Some(Literal::Bool(false))`.
#[test]
fn eval_const_or() {
    // Local shorthand for a `false` literal expression.
    let falsity = || Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false)));

    let operands = matrix_expr![
        falsity(),
        falsity(),
    ];
    let disjunction = Expression::Or(Metadata::new(), Moo::new(operands));

    assert_eq!(eval_constant(&disjunction), Some(Literal::Bool(false)));
}