1
use conjure_cp::ast::abstract_comprehension::{Generator, Qualifier};
2
use conjure_cp::ast::{Expression as Expr, *};
3
use conjure_cp::rule_engine::ApplicationError;
4
use conjure_cp::rule_engine::{
5
    ApplicationError::{DomainError, RuleNotApplicable},
6
    ApplicationResult, Reduction, register_rule, register_rule_set,
7
};
8
use conjure_cp::settings::SolverFamily;
9
use conjure_cp::{bug, essence_expr};
10
use uniplate::Uniplate;
11

            
12
// These rules are applicable regardless of what theories are used.
13
register_rule_set!("Smt", ("Base"), |f: &SolverFamily| {
14
982
    matches!(f, SolverFamily::Smt(..))
15
982
});
16

            
17
#[register_rule(("Smt", 1000))]
18
10926
fn flatten_indomain(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
19
10926
    let Expr::InDomain(_, inner, domain) = expr else {
20
10923
        return Err(RuleNotApplicable);
21
    };
22

            
23
3
    let dom = domain.resolve().ok_or(RuleNotApplicable)?;
24
3
    let new_expr = match dom.as_ref() {
25
        // Bool values are always in the bool domain
26
        GroundDomain::Bool => Ok(Expr::Atomic(
27
            Metadata::new(),
28
            Atom::Literal(Literal::Bool(true)),
29
        )),
30
        GroundDomain::Empty(_) => Ok(Expr::Atomic(
31
            Metadata::new(),
32
            Atom::Literal(Literal::Bool(false)),
33
        )),
34
3
        GroundDomain::Int(ranges) => {
35
3
            let elements: Vec<_> = ranges
36
3
                .iter()
37
3
                .map(|range| match range {
38
                    Range::Single(n) => {
39
                        let eq = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Int(*n)));
40
                        Expr::Eq(Metadata::new(), inner.clone(), Moo::new(eq))
41
                    }
42
3
                    Range::Bounded(l, r) => {
43
3
                        let l_expr = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Int(*l)));
44
3
                        let r_expr = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Int(*r)));
45
3
                        let lit = AbstractLiteral::matrix_implied_indices(vec![
46
3
                            Expr::Geq(Metadata::new(), inner.clone(), Moo::new(l_expr)),
47
3
                            Expr::Leq(Metadata::new(), inner.clone(), Moo::new(r_expr)),
48
                        ]);
49
3
                        Expr::And(
50
3
                            Metadata::new(),
51
3
                            Moo::new(Expr::AbstractLiteral(Metadata::new(), lit)),
52
3
                        )
53
                    }
54
                    Range::UnboundedL(r) => {
55
                        let bound = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Int(*r)));
56
                        Expr::Leq(Metadata::new(), inner.clone(), Moo::new(bound))
57
                    }
58
                    Range::UnboundedR(l) => {
59
                        let bound = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Int(*l)));
60
                        Expr::Geq(Metadata::new(), inner.clone(), Moo::new(bound))
61
                    }
62
                    Range::Unbounded => bug!("integer domains should not have unbounded ranges"),
63
3
                })
64
3
                .collect();
65
3
            Ok(Expr::Or(
66
3
                Metadata::new(),
67
3
                Moo::new(Expr::AbstractLiteral(
68
3
                    Metadata::new(),
69
3
                    AbstractLiteral::matrix_implied_indices(elements),
70
3
                )),
71
3
            ))
72
        }
73
        _ => Err(RuleNotApplicable),
74
    }?;
75
3
    Ok(Reduction::pure(new_expr))
76
10926
}
77

            
78
/// Matrix a = b iff every index in the union of their indices has the same value.
79
/// E.g. a: matrix indexed by [int(1..2)] of int(1..2), b: matrix indexed by [int(2..3)] of int(1..2)
80
/// a = b ~> a[1] = b[1] /\ a[2] = b[2] /\ a[3] = b[3]
81
// Must run before `matrix_ref_to_atom` ("Base", 2000), otherwise matrix equality can be
82
// rewritten into `int(1..)` indexed literals, losing finite index bounds for this rule.
83
#[register_rule(("Smt", 3000))]
84
19317
fn flatten_matrix_eq_neq(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
85
19317
    let (a, b) = match expr {
86
867
        Expr::Eq(_, a, b) | Expr::Neq(_, a, b) => (a, b),
87
18450
        _ => return Err(RuleNotApplicable),
88
    };
89

            
90
867
    let dom_a = a.domain_of().ok_or(RuleNotApplicable)?;
91
762
    let dom_b = b.domain_of().ok_or(RuleNotApplicable)?;
92

            
93
15
    let (Some((_, a_idx_domains)), Some((_, b_idx_domains))) =
94
759
        (dom_a.as_matrix_ground(), dom_b.as_matrix_ground())
95
    else {
96
744
        return Err(RuleNotApplicable);
97
    };
98

            
99
15
    let pairs = matrix::enumerate_index_union_indices(a_idx_domains, b_idx_domains)
100
15
        .map_err(|_| ApplicationError::DomainError)?
101
45
        .map(|idx_lits| {
102
45
            let idx_vec: Vec<_> = idx_lits
103
45
                .into_iter()
104
57
                .map(|lit| Atom::Literal(lit).into())
105
45
                .collect();
106
45
            (
107
45
                Expression::UnsafeIndex(Metadata::new(), a.clone(), idx_vec.clone()),
108
45
                Expression::UnsafeIndex(Metadata::new(), b.clone(), idx_vec),
109
45
            )
110
45
        });
111

            
112
15
    let new_expr = match expr {
113
        Expr::Eq(..) => {
114
33
            let eqs: Vec<_> = pairs.map(|(a, b)| essence_expr!(&a = &b)).collect();
115
9
            Expr::And(
116
9
                Metadata::new(),
117
9
                Moo::new(Expr::AbstractLiteral(
118
9
                    Metadata::new(),
119
9
                    AbstractLiteral::matrix_implied_indices(eqs),
120
9
                )),
121
9
            )
122
        }
123
        Expr::Neq(..) => {
124
12
            let neqs: Vec<_> = pairs.map(|(a, b)| essence_expr!(&a != &b)).collect();
125
6
            Expr::Or(
126
6
                Metadata::new(),
127
6
                Moo::new(Expr::AbstractLiteral(
128
6
                    Metadata::new(),
129
6
                    AbstractLiteral::matrix_implied_indices(neqs),
130
6
                )),
131
6
            )
132
        }
133
        _ => unreachable!(),
134
    };
135

            
136
15
    Ok(Reduction::pure(new_expr))
137
19317
}
138

            
139
/// Turn a matrix slice into a 1-d matrix of the slice elements
140
/// E.g. m[1,..] ~> [m[1,1], m[1,2], m[1,3]]
141
#[register_rule(("Smt", 1000))]
142
10926
fn flatten_matrix_slice(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
143
10926
    let Expr::SafeSlice(_, m, slice_idxs) = expr else {
144
10902
        return Err(RuleNotApplicable);
145
    };
146

            
147
24
    let dom = m.domain_of().ok_or(RuleNotApplicable)?;
148
24
    let Some((_, mat_idxs)) = dom.as_matrix_ground() else {
149
3
        return Err(RuleNotApplicable);
150
    };
151

            
152
21
    if slice_idxs.len() != mat_idxs.len() {
153
        return Err(DomainError);
154
21
    }
155

            
156
    // Find where in the index vector the ".." is
157
21
    let (slice_dim, _) = slice_idxs
158
21
        .iter()
159
21
        .enumerate()
160
30
        .find(|(_, idx)| idx.is_none())
161
21
        .ok_or(RuleNotApplicable)?;
162
21
    let other_idxs = {
163
21
        let opt: Option<Vec<_>> = [&slice_idxs[..slice_dim], &slice_idxs[(slice_dim + 1)..]]
164
21
            .concat()
165
21
            .into_iter()
166
21
            .collect();
167
21
        opt.ok_or(DomainError)?
168
    };
169
21
    let elements: Vec<Expr> = mat_idxs[slice_dim]
170
21
        .values()
171
21
        .map_err(|_| DomainError)?
172
54
        .map(|lit| {
173
54
            let mut new_idx = other_idxs.clone();
174
54
            new_idx.insert(slice_dim, Expr::Atomic(Metadata::new(), Atom::Literal(lit)));
175
54
            Expr::SafeIndex(Metadata::new(), m.clone(), new_idx)
176
54
        })
177
21
        .collect();
178
21
    Ok(Reduction::pure(Expr::AbstractLiteral(
179
21
        Metadata::new(),
180
21
        AbstractLiteral::matrix_implied_indices(elements),
181
21
    )))
182
10926
}
183

            
184
/// Expressions like allDiff and sum support 1-dimensional matrices as inputs, e.g. sum(m) where m is indexed by 1..3.
185
///
186
/// This rule is very similar to `matrix_ref_to_atom`, but turns the matrix reference into a slice rather its atoms.
187
/// Other rules like `flatten_matrix_slice` take care of actually turning the slice into the matrix elements.
188
#[register_rule(("Smt", 999))]
189
10164
fn matrix_ref_to_slice(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
190
    if let Expr::SafeSlice(_, _, _)
191
    | Expr::UnsafeSlice(_, _, _)
192
    | Expr::SafeIndex(_, _, _)
193
10164
    | Expr::UnsafeIndex(_, _, _) = expr
194
    {
195
1923
        return Err(RuleNotApplicable);
196
8241
    };
197

            
198
8241
    for (child, ctx) in expr.holes() {
199
534
        let Expr::Atomic(_, Atom::Reference(decl)) = &child else {
200
4317
            continue;
201
        };
202

            
203
297
        let dom = decl.resolved_domain().ok_or(RuleNotApplicable)?;
204
297
        let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
205
288
            continue;
206
        };
207

            
208
        // Must be a 1d matrix
209
9
        if index_domains.len() > 1 {
210
6
            continue;
211
3
        }
212

            
213
3
        let new_child = Expr::SafeSlice(Metadata::new(), Moo::new(child.clone()), vec![None]);
214
3
        return Ok(Reduction::pure(ctx(new_child)));
215
    }
216

            
217
8238
    Err(RuleNotApplicable)
218
10164
}
219

            
220
/// This rule is applicable in SMT when atomic representation is not used for matrices.
221
///
222
/// Namely, it unwraps flatten(m) into [m[1, 1], m[1, 2], ...]
223
#[register_rule(("Smt", 999))]
224
10164
fn unwrap_flatten_matrix_nonatomic(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
225
    // TODO: depth not supported yet
226
6
    let Expr::Flatten(_, None, m) = expr else {
227
10158
        return Err(RuleNotApplicable);
228
    };
229

            
230
6
    let dom = m.domain_of().ok_or(RuleNotApplicable)?;
231
6
    let Some(GroundDomain::Matrix(_, index_domains)) = dom.resolve().map(Moo::unwrap_or_clone)
232
    else {
233
        return Err(RuleNotApplicable);
234
    };
235

            
236
6
    let elems: Vec<Expr> = matrix::enumerate_indices(index_domains)
237
39
        .map(|lits| {
238
39
            let idxs: Vec<Expr> = lits.into_iter().map(Into::into).collect();
239
39
            Expr::SafeIndex(Metadata::new(), m.clone(), idxs)
240
39
        })
241
6
        .collect();
242

            
243
6
    let new_dom = GroundDomain::Int(vec![Range::Bounded(
244
6
        1,
245
6
        elems
246
6
            .len()
247
6
            .try_into()
248
6
            .expect("length of matrix should be able to be held in Int type"),
249
6
    )]);
250
6
    let new_expr = Expr::AbstractLiteral(
251
6
        Metadata::new(),
252
6
        AbstractLiteral::Matrix(elems, new_dom.into()),
253
6
    );
254
6
    Ok(Reduction::pure(new_expr))
255
10164
}
256

            
257
/// Expands a sum over an "in set" comprehension to a list.
258
///
259
/// TODO: We currently only support one "in set" generator.
260
/// This rule can be made much more general and nicer.
261
#[register_rule(("Smt", 999))]
262
10164
fn unwrap_abstract_comprehension_sum(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
263
10164
    let Expr::Sum(_, inner) = expr else {
264
9996
        return Err(RuleNotApplicable);
265
    };
266
168
    let Expr::AbstractComprehension(_, comp) = inner.as_ref() else {
267
165
        return Err(RuleNotApplicable);
268
    };
269

            
270
3
    let [Qualifier::Generator(Generator::ExpressionGenerator(generator))] = &comp.qualifiers[..]
271
    else {
272
        return Err(RuleNotApplicable);
273
    };
274

            
275
3
    let set = &generator.expression;
276
3
    let elem_domain = generator.decl.domain().ok_or(DomainError)?;
277
3
    let list: Vec<_> = elem_domain
278
3
        .values()
279
3
        .map_err(|_| DomainError)?
280
15
        .map(|lit| essence_expr!("&lit * toInt(&lit in &set)"))
281
3
        .collect();
282

            
283
3
    let new_expr = Expr::Sum(
284
3
        Metadata::new(),
285
3
        Moo::new(Expr::AbstractLiteral(
286
3
            Metadata::new(),
287
3
            AbstractLiteral::matrix_implied_indices(list),
288
3
        )),
289
3
    );
290
3
    Ok(Reduction::pure(new_expr))
291
10164
}
292

            
293
/// Unwraps a subsetEq expression into checking membership equality.
294
///
295
/// Any elements not in the domain of one set must not be in the other set.
296
#[register_rule(("Smt", 999))]
297
10164
fn unwrap_subseteq(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
298
10164
    let Expr::SubsetEq(_, a, b) = expr else {
299
10164
        return Err(RuleNotApplicable);
300
    };
301

            
302
    let dom_a = a.domain_of().and_then(|d| d.resolve()).ok_or(DomainError)?;
303
    let dom_b = b.domain_of().and_then(|d| d.resolve()).ok_or(DomainError)?;
304

            
305
    let GroundDomain::Set(_, elem_dom_a) = dom_a.as_ref() else {
306
        return Err(RuleNotApplicable);
307
    };
308
    let GroundDomain::Set(_, elem_dom_b) = dom_b.as_ref() else {
309
        return Err(RuleNotApplicable);
310
    };
311

            
312
    let domain_a_iter = elem_dom_a.values().map_err(|_| DomainError)?;
313
    let memberships = domain_a_iter
314
        .map(|lit| {
315
            let b_contains = elem_dom_b.contains(&lit).map_err(|_| DomainError)?;
316
            match b_contains {
317
                true => Ok(essence_expr!("(&lit in &a) -> (&lit in &b)")),
318
                false => Ok(essence_expr!("!(&lit in &a)")),
319
            }
320
        })
321
        .collect::<Result<Vec<_>, _>>()?;
322

            
323
    let new_expr = Expr::And(
324
        Metadata::new(),
325
        Moo::new(Expr::AbstractLiteral(
326
            Metadata::new(),
327
            AbstractLiteral::matrix_implied_indices(memberships),
328
        )),
329
    );
330

            
331
    Ok(Reduction::pure(new_expr))
332
10164
}
333

            
334
/// Unwraps equality between sets into checking membership equality.
335
///
336
/// This is an optimisation over unwrap_subseteq to avoid unnecessary additional -> exprs
337
/// where a single <-> is enough. This must apply before eq_to_subset_eq.
338
#[register_rule(("Smt", 8801))]
339
33777
fn unwrap_set_eq(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
340
33777
    let Expr::Eq(_, a, b) = expr else {
341
32148
        return Err(RuleNotApplicable);
342
    };
343

            
344
1629
    let dom_a = a.domain_of().and_then(|d| d.resolve()).ok_or(DomainError)?;
345
1524
    let dom_b = b.domain_of().and_then(|d| d.resolve()).ok_or(DomainError)?;
346

            
347
1521
    let GroundDomain::Set(_, elem_dom_a) = dom_a.as_ref() else {
348
1515
        return Err(RuleNotApplicable);
349
    };
350
6
    let GroundDomain::Set(_, elem_dom_b) = dom_b.as_ref() else {
351
        return Err(RuleNotApplicable);
352
    };
353

            
354
6
    let union_val_iter = elem_dom_a
355
6
        .union(elem_dom_b)
356
6
        .and_then(|d| d.values())
357
6
        .map_err(|_| DomainError)?;
358
6
    let memberships = union_val_iter
359
15
        .map(|lit| {
360
15
            let a_contains = elem_dom_a.contains(&lit).map_err(|_| DomainError)?;
361
15
            let b_contains = elem_dom_b.contains(&lit).map_err(|_| DomainError)?;
362
15
            match (a_contains, b_contains) {
363
9
                (true, true) => Ok(essence_expr!("(&lit in &a) <-> (&lit in &b)")),
364
3
                (true, false) => Ok(essence_expr!("!(&lit in &a)")),
365
3
                (false, true) => Ok(essence_expr!("!(&lit in &b)")),
366
                (false, false) => unreachable!(),
367
            }
368
15
        })
369
6
        .collect::<Result<Vec<_>, _>>()?;
370

            
371
6
    let new_expr = Expr::And(
372
6
        Metadata::new(),
373
6
        Moo::new(Expr::AbstractLiteral(
374
6
            Metadata::new(),
375
6
            AbstractLiteral::matrix_implied_indices(memberships),
376
6
        )),
377
6
    );
378

            
379
6
    Ok(Reduction::pure(new_expr))
380
33777
}