1
use conjure_cp::{
2
    Model,
3
    ast::{
4
        Atom, DeclarationPtr, Domain, Expression, Literal, Metadata, Moo, Name, Range, Reference,
5
        SymbolTable, eval_constant,
6
    },
7
    into_matrix_expr, matrix_expr,
8
    rule_engine::{Rule, get_all_rules, get_rule_by_name, resolve_rule_sets, rewrite_naive},
9
    settings::{QuantifiedExpander, SolverFamily, set_comprehension_expander},
10
    solver::{Solver, adaptors},
11
};
12
#[allow(unused_imports)]
13
#[allow(clippy::single_component_path_imports)] // ensure this is linked so we can lookup rules
14
use conjure_cp_rules;
15
use pretty_assertions::assert_eq;
16
use std::process::exit;
17
use uniplate::Uniplate;
18

            
19
/// The rule registry must be populated once `conjure_cp_rules` is linked in.
#[test]
fn rules_present() {
    let registered = get_all_rules();
    assert!(!registered.is_empty());
}
24

            
25
/// `evaluate_sum_of_constants` folds a sum of integer literals, and yields `None`
/// when any term is not an integer literal.
#[test]
fn sum_of_constants() {
    let int = |value: i32| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(value)));

    // sum([1, 2, 3])
    let all_literals = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![int(1), int(2), int(3)]),
    );

    // sum([1, a]) where `a` is a `find` variable
    let var_a = Expression::Atomic(
        Metadata::new(),
        Atom::Reference(Reference::new(DeclarationPtr::new_find(
            Name::user("a"),
            Domain::bool(),
        ))),
    );
    let with_reference = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![int(1), var_a]));

    assert_eq!(evaluate_sum_of_constants(&all_literals), Some(6));
    assert_eq!(evaluate_sum_of_constants(&with_reference), None);
}
54

            
55
4
fn evaluate_sum_of_constants(expr: &Expression) -> Option<i32> {
56
4
    match expr {
57
4
        Expression::Sum(_metadata, expressions) => {
58
4
            let expressions = (**expressions).clone().unwrap_list()?;
59
4
            let mut sum = 0;
60
10
            for e in expressions {
61
8
                match e {
62
8
                    Expression::Atomic(_, Atom::Literal(Literal::Int(value))) => {
63
8
                        sum += value;
64
8
                    }
65
2
                    _ => return None,
66
                }
67
            }
68
2
            Some(sum)
69
        }
70
        _ => None,
71
    }
72
4
}
73

            
74
/// `simplify_expression` folds a constant inner sum in place while leaving the
/// surrounding non-constant sum and the equality around it intact.
#[test]
fn recursive_sum_of_constants() {
    let int = |value: i32| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(value)));
    let a = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("a"),
        Domain::int(vec![Range::Bounded(1, 5)]),
    )));

    // sum([1, 2, sum([1, 2]), a]) = 3
    let inner_sum = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![int(1), int(2)]));
    let input = Expression::Eq(
        Metadata::new(),
        Moo::new(Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![
                int(1),
                int(2),
                inner_sum,
                Expression::Atomic(Metadata::new(), a.clone()),
            ]),
        )),
        Moo::new(int(3)),
    );

    // sum([1, 2, 3, a]) = 3
    let expected = Expression::Eq(
        Metadata::new(),
        Moo::new(Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![
                int(1),
                int(2),
                int(3),
                Expression::Atomic(Metadata::new(), a),
            ]),
        )),
        Moo::new(int(3)),
    );

    assert_eq!(simplify_expression(input), expected);
}
122

            
123
7
fn simplify_expression(expr: Expression) -> Expression {
124
7
    match expr {
125
2
        Expression::Sum(_metadata, expressions) => {
126
2
            let expressions = Moo::unwrap_or_clone(expressions).unwrap_list().unwrap();
127
2
            if let Some(result) = evaluate_sum_of_constants(&Expression::Sum(
128
2
                Metadata::new(),
129
2
                Moo::new(into_matrix_expr![expressions.clone()]),
130
2
            )) {
131
1
                Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(result)))
132
            } else {
133
1
                Expression::Sum(
134
1
                    Metadata::new(),
135
1
                    Moo::new(into_matrix_expr![
136
1
                        expressions.into_iter().map(simplify_expression).collect()
137
1
                    ]),
138
1
                )
139
            }
140
        }
141
1
        Expression::Eq(_metadata, left, right) => Expression::Eq(
142
1
            Metadata::new(),
143
1
            Moo::new(simplify_expression(Moo::unwrap_or_clone(left))),
144
1
            Moo::new(simplify_expression(Moo::unwrap_or_clone(right))),
145
1
        ),
146
        Expression::Geq(_metadata, left, right) => Expression::Geq(
147
            Metadata::new(),
148
            Moo::new(simplify_expression(Moo::unwrap_or_clone(left))),
149
            Moo::new(simplify_expression(Moo::unwrap_or_clone(right))),
150
        ),
151
4
        _ => expr,
152
    }
153
7
}
154

            
155
/// Constant-folding `sum([1, 2, 3])` and then unwrapping the resulting sum yields `6`.
#[test]
fn rule_sum_constants() {
    let partial_evaluator = get_rule_by_name("partial_evaluator").unwrap();
    let remove_unit_vector_sum = get_rule_by_name("remove_unit_vector_sum").unwrap();
    let int = |value: i32| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(value)));

    let sum = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![int(1), int(2), int(3)]),
    );

    let folded = partial_evaluator
        .apply(&sum, &SymbolTable::new())
        .unwrap()
        .new_expression;
    let unwrapped = remove_unit_vector_sum
        .apply(&folded, &SymbolTable::new())
        .unwrap()
        .new_expression;

    assert_eq!(unwrapped, int(6));
}
183

            
184
/// `introduce_weighted_sumleq_sumgeq` rewrites `sum([1, 2]) >= 3` into
/// `FlatSumGeq([1, 2], 3)`.
#[test]
fn rule_sum_geq() {
    let rule = get_rule_by_name("introduce_weighted_sumleq_sumgeq").unwrap();
    let one = Atom::Literal(Literal::Int(1));
    let two = Atom::Literal(Literal::Int(2));
    let three = Atom::Literal(Literal::Int(3));

    let geq = Expression::Geq(
        Metadata::new(),
        Moo::new(Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![
                Expression::Atomic(Metadata::new(), one.clone()),
                Expression::Atomic(Metadata::new(), two.clone()),
            ]),
        )),
        Moo::new(Expression::Atomic(Metadata::new(), three.clone())),
    );

    let reduction = rule.apply(&geq, &SymbolTable::new()).unwrap();

    assert_eq!(
        reduction.new_expression,
        Expression::FlatSumGeq(Metadata::new(), vec![one, two], three)
    );
}
220

            
221
/// Reduce and solve:
/// ```text
/// find a,b,c : int(1..3)
/// such that a + b + c <= 2 + 3 - 1
/// such that a < b
/// ```
///
/// Each rule is applied by hand, one step at a time, before the resulting
/// constraints are handed to Minion.
#[test]
fn reduce_solve_xyz() {
    println!("Rules: {:?}", get_all_rules());
    let sum_constants = get_rule_by_name("partial_evaluator").unwrap();
    let unwrap_sum = get_rule_by_name("remove_unit_vector_sum").unwrap();
    let lt_to_leq = get_rule_by_name("lt_to_leq").unwrap();
    let leq_to_ineq = get_rule_by_name("x_leq_y_plus_k_to_ineq").unwrap();
    let introduce_sumleq = get_rule_by_name("introduce_weighted_sumleq_sumgeq").unwrap();

    // 2 + 3 - 1
    let mut expr1 = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(2))),
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(3))),
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(-1))),
        ]),
    );

    expr1 = sum_constants
        .apply(&expr1, &SymbolTable::new())
        .unwrap()
        .new_expression;
    expr1 = unwrap_sum
        .apply(&expr1, &SymbolTable::new())
        .unwrap()
        .new_expression;
    assert_eq!(
        expr1,
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(4)))
    );

    // Use one declaration per variable, both in the constraints and in the model's
    // symbol table, so the two agree on the domain int(1..3) documented above.
    let decl_a = DeclarationPtr::new_find(Name::user("a"), Domain::int(vec![Range::Bounded(1, 3)]));
    let decl_b = DeclarationPtr::new_find(Name::user("b"), Domain::int(vec![Range::Bounded(1, 3)]));
    let decl_c = DeclarationPtr::new_find(Name::user("c"), Domain::int(vec![Range::Bounded(1, 3)]));

    let a = Atom::Reference(Reference::new(decl_a.clone()));
    let b = Atom::Reference(Reference::new(decl_b.clone()));
    let c = Atom::Reference(Reference::new(decl_c.clone()));

    // a + b + c <= 4
    expr1 = Expression::Leq(
        Metadata::new(),
        Moo::new(Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![
                Expression::Atomic(Metadata::new(), a.clone()),
                Expression::Atomic(Metadata::new(), b.clone()),
                Expression::Atomic(Metadata::new(), c.clone()),
            ]),
        )),
        Moo::new(expr1),
    );
    expr1 = introduce_sumleq
        .apply(&expr1, &SymbolTable::new())
        .unwrap()
        .new_expression;
    assert_eq!(
        expr1,
        Expression::FlatSumLeq(
            Metadata::new(),
            vec![a.clone(), b.clone(), c],
            Atom::Literal(Literal::Int(4))
        )
    );

    // a < b
    let mut expr2 = Expression::Lt(
        Metadata::new(),
        Moo::new(Expression::Atomic(Metadata::new(), a.clone())),
        Moo::new(Expression::Atomic(Metadata::new(), b.clone())),
    );
    expr2 = lt_to_leq
        .apply(&expr2, &SymbolTable::new())
        .unwrap()
        .new_expression;

    expr2 = leq_to_ineq
        .apply(&expr2, &SymbolTable::new())
        .unwrap()
        .new_expression;
    assert_eq!(
        expr2,
        Expression::FlatIneq(
            Metadata::new(),
            Moo::new(a),
            Moo::new(b),
            Box::new(Literal::Int(-1)),
        )
    );

    let mut model = Model::new(Default::default());
    *model.constraints_mut() = vec![expr1, expr2];

    model.symbols_mut().insert(decl_a).unwrap();
    model.symbols_mut().insert(decl_b).unwrap();
    model.symbols_mut().insert(decl_c).unwrap();

    let solver: Solver = Solver::new(adaptors::Minion::new());
    let solver = solver.load_model(model).unwrap();
    solver.solve(Box::new(|_| true)).unwrap();
}
353

            
354
/// `remove_double_negation` rewrites `not(not(true))` to `true`.
#[test]
fn rule_remove_double_negation() {
    let rule = get_rule_by_name("remove_double_negation").unwrap();
    let truth = || Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));

    let double_negation = Expression::Not(
        Metadata::new(),
        Moo::new(Expression::Not(Metadata::new(), Moo::new(truth()))),
    );
    let rewritten = rule
        .apply(&double_negation, &SymbolTable::new())
        .unwrap()
        .new_expression;

    assert_eq!(rewritten, truth());
}
379

            
380
/// A one-element `and`/`or` collapses to its only operand.
#[test]
fn remove_trivial_and_or() {
    let remove_trivial_and = get_rule_by_name("remove_unit_vector_and").unwrap();
    let remove_trivial_or = get_rule_by_name("remove_unit_vector_or").unwrap();
    let boolean =
        |value: bool| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(value)));

    let and_of_true = Expression::And(Metadata::new(), Moo::new(matrix_expr![boolean(true)]));
    let or_of_false = Expression::Or(Metadata::new(), Moo::new(matrix_expr![boolean(false)]));

    let and_result = remove_trivial_and
        .apply(&and_of_true, &SymbolTable::new())
        .unwrap()
        .new_expression;
    let or_result = remove_trivial_or
        .apply(&or_of_false, &SymbolTable::new())
        .unwrap()
        .new_expression;

    assert_eq!(and_result, boolean(true));
    assert_eq!(or_result, boolean(false));
}
418

            
419
/// `distribute_not_over_and` applies De Morgan's law:
/// `not(and([a, b]))` becomes `or([not(a), not(b)])`.
#[test]
fn rule_distribute_not_over_and() {
    let rule = get_rule_by_name("distribute_not_over_and").unwrap();

    let a = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("a"),
        Domain::bool(),
    )));
    let b = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("b"),
        Domain::bool(),
    )));

    let atomic = |atom: &Atom| Expression::Atomic(Metadata::new(), atom.clone());
    let negated = |atom: &Atom| Expression::Not(Metadata::new(), Moo::new(atomic(atom)));

    let input = Expression::Not(
        Metadata::new(),
        Moo::new(Expression::And(
            Metadata::new(),
            Moo::new(matrix_expr![atomic(&a), atomic(&b)]),
        )),
    );
    let output = rule
        .apply(&input, &SymbolTable::new())
        .unwrap()
        .new_expression;

    let expected = Expression::Or(
        Metadata::new(),
        Moo::new(matrix_expr![negated(&a), negated(&b)]),
    );
    assert_eq!(output, expected);
}
466

            
467
/// `distribute_not_over_or` applies De Morgan's law:
/// `not(or([a, b]))` becomes `and([not(a), not(b)])`.
#[test]
fn rule_distribute_not_over_or() {
    let rule = get_rule_by_name("distribute_not_over_or").unwrap();

    let a = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("a"),
        Domain::bool(),
    )));
    let b = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("b"),
        Domain::bool(),
    )));

    let atomic = |atom: &Atom| Expression::Atomic(Metadata::new(), atom.clone());
    let negated = |atom: &Atom| Expression::Not(Metadata::new(), Moo::new(atomic(atom)));

    let input = Expression::Not(
        Metadata::new(),
        Moo::new(Expression::Or(
            Metadata::new(),
            Moo::new(matrix_expr![atomic(&a), atomic(&b)]),
        )),
    );
    let output = rule
        .apply(&input, &SymbolTable::new())
        .unwrap()
        .new_expression;

    let expected = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![negated(&a), negated(&b)]),
    );
    assert_eq!(output, expected);
}
514

            
515
/// `distribute_not_over_and` must not apply to the negation of a plain reference.
#[test]
fn rule_distribute_not_over_and_not_changed() {
    let rule = get_rule_by_name("distribute_not_over_and").unwrap();

    let a = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("a"),
        Domain::int(vec![Range::Bounded(1, 5)]),
    )));
    let negated_reference = Expression::Not(
        Metadata::new(),
        Moo::new(Expression::Atomic(Metadata::new(), a)),
    );

    let outcome = rule.apply(&negated_reference, &SymbolTable::new());
    assert!(outcome.is_err());
}
534

            
535
/// `distribute_not_over_or` must not apply to the negation of a plain reference.
#[test]
fn rule_distribute_not_over_or_not_changed() {
    let rule = get_rule_by_name("distribute_not_over_or").unwrap();

    let a = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::user("a"),
        Domain::int(vec![Range::Bounded(1, 5)]),
    )));
    let negated_reference = Expression::Not(
        Metadata::new(),
        Moo::new(Expression::Atomic(Metadata::new(), a)),
    );

    let outcome = rule.apply(&negated_reference, &SymbolTable::new());
    assert!(outcome.is_err());
}
554

            
555
/// `distribute_or_over_and` rewrites `or([and([d1, d2]), d2])` into
/// `and([or([d2, d1]), or([d2, d2])])`.
#[test]
fn rule_distribute_or_over_and() {
    let rule = get_rule_by_name("distribute_or_over_and").unwrap();

    let d1 = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::Machine(1),
        Domain::bool(),
    )));
    let d2 = Atom::Reference(Reference::new(DeclarationPtr::new_find(
        Name::Machine(2),
        Domain::bool(),
    )));

    let atomic = |atom: &Atom| Expression::Atomic(Metadata::new(), atom.clone());
    let or_of = |lhs: &Atom, rhs: &Atom| {
        Expression::Or(
            Metadata::new(),
            Moo::new(matrix_expr![atomic(lhs), atomic(rhs)]),
        )
    };

    let input = Expression::Or(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::And(
                Metadata::new(),
                Moo::new(matrix_expr![atomic(&d1), atomic(&d2)]),
            ),
            atomic(&d2),
        ]),
    );

    let reduction = rule.apply(&input, &SymbolTable::new()).unwrap();

    let expected = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![or_of(&d2, &d1), or_of(&d2, &d2)]),
    );
    assert_eq!(reduction.new_expression, expected);
}
610

            
611
/// Reduce and solve:
/// ```text
/// find a,b,c : int(1..5)
/// such that a + b + c = 4
/// such that a < b
/// ```
///
/// This test uses the rewrite function to simplify the expression instead
/// of applying the rules manually.
#[test]
fn rewrite_solve_xyz() {
    println!("Rules: {:?}", get_all_rules());
    set_comprehension_expander(QuantifiedExpander::Native);

    // Resolve once and reuse. Panicking (rather than calling `process::exit`) lets the
    // test harness report this failure instead of terminating the whole test binary.
    let rule_sets = resolve_rule_sets(SolverFamily::Minion, &["Constant"])
        .unwrap_or_else(|e| panic!("Error resolving rule sets: {e}"));
    println!("Rule sets: {rule_sets:?}");

    // Create variables and domains
    let decl_a = DeclarationPtr::new_find(Name::user("a"), Domain::int(vec![Range::Bounded(1, 5)]));
    let decl_b = DeclarationPtr::new_find(Name::user("b"), Domain::int(vec![Range::Bounded(1, 5)]));
    let decl_c = DeclarationPtr::new_find(Name::user("c"), Domain::int(vec![Range::Bounded(1, 5)]));

    let a = Atom::Reference(Reference::new(decl_a.clone()));
    let b = Atom::Reference(Reference::new(decl_b.clone()));
    let c = Atom::Reference(Reference::new(decl_c.clone()));

    // and([a + b + c = 4, a < b])
    let nested_expr = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::Eq(
                Metadata::new(),
                Moo::new(Expression::Sum(
                    Metadata::new(),
                    Moo::new(matrix_expr![
                        Expression::Atomic(Metadata::new(), a.clone()),
                        Expression::Atomic(Metadata::new(), b.clone()),
                        Expression::Atomic(Metadata::new(), c),
                    ]),
                )),
                Moo::new(Expression::Atomic(
                    Metadata::new(),
                    Atom::Literal(Literal::Int(4)),
                )),
            ),
            Expression::Lt(
                Metadata::new(),
                Moo::new(Expression::Atomic(Metadata::new(), a)),
                Moo::new(Expression::Atomic(Metadata::new(), b)),
            ),
        ]),
    );

    let mut model = Model::new(Default::default());

    // Insert variables and domains
    model.symbols_mut().insert(decl_a).unwrap();
    model.symbols_mut().insert(decl_b).unwrap();
    model.symbols_mut().insert(decl_c).unwrap();

    *model.constraints_mut() = vec![nested_expr];

    // Apply the rewrite function to the whole model.
    model = rewrite_naive(&model, &rule_sets, true).unwrap();
    let rewritten_expr = model.constraints();

    // Check that every constraint is in its simplest form.
    assert!(rewritten_expr.iter().all(is_simple));

    let solver: Solver = Solver::new(adaptors::Minion::new());
    let solver = solver.load_model(model).unwrap();
    solver.solve(Box::new(|_| true)).unwrap();
}
702

            
703
struct RuleResult<'a> {
704
    #[allow(dead_code)]
705
    rule: &'a Rule<'a>,
706
    new_expression: Expression,
707
}
708

            
709
/// # Returns
710
/// - True if `expression` is in its simplest form.
711
/// - False otherwise.
712
3
pub fn is_simple(expression: &Expression) -> bool {
713
3
    let rules = get_all_rules();
714
3
    let mut new = expression.clone();
715
3
    while let Some(step) = is_simple_iteration(&new, &rules) {
716
        new = step;
717
    }
718
3
    new == *expression
719
3
}
720

            
721
/// # Returns
722
/// - Some(<new_expression>) after applying the first applicable rule to `expr` or a sub-expression.
723
/// - None if no rule is applicable to the expression or any sub-expression.
724
3
fn is_simple_iteration<'a>(
725
3
    expression: &'a Expression,
726
3
    rules: &'a Vec<&'a Rule<'a>>,
727
3
) -> Option<Expression> {
728
3
    let rule_results = apply_all_rules(expression, rules);
729
3
    if let Some(new) = choose_rewrite(&rule_results) {
730
        return Some(new);
731
    } else {
732
3
        let mut sub = expression.children();
733
3
        for i in 0..sub.len() {
734
            if let Some(new) = is_simple_iteration(&sub[i], rules) {
735
                sub[i] = new;
736
                return Some(expression.with_children(sub.clone()));
737
            }
738
        }
739
    }
740
3
    None // No rules applicable to this branch of the expression
741
3
}
742

            
743
/// # Returns
744
/// - A list of RuleResults after applying all rules to `expression`.
745
/// - An empty list if no rules are applicable.
746
3
fn apply_all_rules<'a>(
747
3
    expression: &'a Expression,
748
3
    rules: &'a Vec<&'a Rule<'a>>,
749
3
) -> Vec<RuleResult<'a>> {
750
3
    let mut results = Vec::new();
751
402
    for rule in rules {
752
402
        match rule.apply(expression, &SymbolTable::new()) {
753
            Ok(red) => {
754
                results.push(RuleResult {
755
                    rule,
756
                    new_expression: red.new_expression,
757
                });
758
            }
759
402
            Err(_) => continue,
760
        }
761
    }
762
3
    results
763
3
}
764

            
765
/// # Returns
766
/// - Some(<new_expression>) after applying the first rule in `results`.
767
/// - None if `results` is empty.
768
3
fn choose_rewrite(results: &[RuleResult]) -> Option<Expression> {
769
3
    if results.is_empty() {
770
3
        return None;
771
    }
772
    // Return the first result for now
773
    // println!("Applying rule: {:?}", results[0].rule);
774
    Some(results[0].new_expression.clone())
775
3
}
776

            
777
/// An integer literal evaluates to itself.
#[test]
fn eval_const_int() {
    let one = Literal::Int(1);
    let expr = Expression::Atomic(Metadata::new(), Atom::Literal(one.clone()));
    assert_eq!(eval_constant(&expr), Some(one));
}
783

            
784
/// A boolean literal evaluates to itself.
#[test]
fn eval_const_bool() {
    let truth = Literal::Bool(true);
    let expr = Expression::Atomic(Metadata::new(), Atom::Literal(truth.clone()));
    assert_eq!(eval_constant(&expr), Some(truth));
}
790

            
791
/// `and([true, false])` evaluates to `false`.
#[test]
fn eval_const_and() {
    let boolean =
        |value: bool| Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(value)));
    let conjunction = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![boolean(true), boolean(false)]),
    );
    assert_eq!(eval_constant(&conjunction), Some(Literal::Bool(false)));
}
803

            
804
/// A reference to a `find` variable has no constant value.
#[test]
fn eval_const_ref() {
    let decl = DeclarationPtr::new_find(Name::user("a"), Domain::int(vec![Range::Bounded(1, 5)]));
    let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Reference::new(decl)));
    assert_eq!(eval_constant(&reference), None);
}
816

            
817
/// A sum containing a reference (nested inside an `and`) evaluates to `None`.
#[test]
fn eval_const_nested_ref() {
    let decl = DeclarationPtr::new_find(Name::user("a"), Domain::int(vec![Range::Bounded(1, 5)]));
    let nested_and = Expression::And(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true))),
            Expression::Atomic(Metadata::new(), Atom::Reference(Reference::new(decl))),
        ]),
    );
    let sum = Expression::Sum(
        Metadata::new(),
        Moo::new(matrix_expr![
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1))),
            nested_and,
        ]),
    );
    assert_eq!(eval_constant(&sum), None);
}
841

            
842
/// `1 = 1` evaluates to `true`.
#[test]
fn eval_const_eq_int() {
    let atomic = |literal: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(literal));
    let equality = Expression::Eq(
        Metadata::new(),
        Moo::new(atomic(Literal::Int(1))),
        Moo::new(atomic(Literal::Int(1))),
    );
    assert_eq!(eval_constant(&equality), Some(Literal::Bool(true)));
}
858

            
859
/// `true = true` evaluates to `true`.
#[test]
fn eval_const_eq_bool() {
    let atomic = |literal: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(literal));
    let equality = Expression::Eq(
        Metadata::new(),
        Moo::new(atomic(Literal::Bool(true))),
        Moo::new(atomic(Literal::Bool(true))),
    );
    assert_eq!(eval_constant(&equality), Some(Literal::Bool(true)));
}
875

            
876
/// `1 = true` compares an integer literal with a boolean literal; this test
/// pins that `eval_constant` returns `None` for it.
#[test]
fn eval_const_eq_mixed() {
    let lhs = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
    let rhs = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));

    let expr = Expression::Eq(Metadata::new(), Moo::new(lhs), Moo::new(rhs));

    assert!(eval_constant(&expr).is_none());
}
892

            
893
/// Summing an integer literal with a boolean literal (`sum([1, true])`) is
/// expected to produce no constant.
#[test]
fn eval_const_sum_mixed() {
    let operands = matrix_expr![
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1))),
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true))),
    ];
    let expr = Expression::Sum(Metadata::new(), Moo::new(operands));

    assert_eq!(eval_constant(&expr), None);
}
905

            
906
/// `(x + y + z = 4) /\ (x >= y)` over `find` variables `x`, `y`, `z: int(1..5)`
/// involves undetermined variables, so `eval_constant` is expected to return `None`.
#[test]
fn eval_const_sum_xyz() {
    // Builds a reference to a freshly declared `find <name>: int(1..5)`.
    // Every occurrence gets its own declaration, so `x` and `y` appear twice
    // with distinct `DeclarationPtr`s.
    let var = |name: &'static str| {
        Expression::Atomic(
            Metadata::new(),
            Atom::Reference(Reference::new(DeclarationPtr::new_find(
                Name::user(name),
                Domain::int(vec![Range::Bounded(1, 5)]),
            ))),
        )
    };

    let sum_is_four = Expression::Eq(
        Metadata::new(),
        Moo::new(Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![var("x"), var("y"), var("z")]),
        )),
        Moo::new(Expression::Atomic(
            Metadata::new(),
            Atom::Literal(Literal::Int(4)),
        )),
    );
    let x_geq_y = Expression::Geq(Metadata::new(), Moo::new(var("x")), Moo::new(var("y")));

    let expr = Expression::And(Metadata::new(), Moo::new(matrix_expr![sum_is_four, x_geq_y]));

    assert_eq!(eval_constant(&expr), None);
}
966

            
967
/// `false \/ false` over boolean literals is expected to fold to `false`.
#[test]
fn eval_const_or() {
    let falsity = || Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false)));

    let expr = Expression::Or(Metadata::new(), Moo::new(matrix_expr![falsity(), falsity()]));

    assert_eq!(eval_constant(&expr), Some(Literal::Bool(false)));
}