1
use conjure_cp::ast::abstract_comprehension::{Generator, Qualifier};
2
use conjure_cp::ast::{Expression as Expr, *};
3
use conjure_cp::rule_engine::ApplicationError;
4
use conjure_cp::rule_engine::{
5
    ApplicationError::{DomainError, RuleNotApplicable},
6
    ApplicationResult, Reduction, register_rule, register_rule_set,
7
};
8
use conjure_cp::settings::SolverFamily;
9
use conjure_cp::{bug, essence_expr};
10
use uniplate::Uniplate;
11

            
12
// These rules are applicable regardless of what theories are used.
13
register_rule_set!("Smt", ("Base"), |f: &SolverFamily| {
14
4845
    matches!(f, SolverFamily::Smt(..))
15
4845
});
16

            
17
#[register_rule(("Smt", 1000))]
18
5655
fn flatten_indomain(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
19
5655
    let Expr::InDomain(_, inner, domain) = expr else {
20
5652
        return Err(RuleNotApplicable);
21
    };
22

            
23
3
    let dom = domain.resolve().ok_or(RuleNotApplicable)?;
24
3
    let new_expr = match dom.as_ref() {
25
        // Bool values are always in the bool domain
26
        GroundDomain::Bool => Ok(Expr::Atomic(
27
            Metadata::new(),
28
            Atom::Literal(Literal::Bool(true)),
29
        )),
30
        GroundDomain::Empty(_) => Ok(Expr::Atomic(
31
            Metadata::new(),
32
            Atom::Literal(Literal::Bool(false)),
33
        )),
34
3
        GroundDomain::Int(ranges) => {
35
3
            let elements: Vec<_> = ranges
36
3
                .iter()
37
3
                .map(|range| match range {
38
                    Range::Single(n) => {
39
                        let eq = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Int(*n)));
40
                        Expr::Eq(Metadata::new(), inner.clone(), Moo::new(eq))
41
                    }
42
3
                    Range::Bounded(l, r) => {
43
3
                        let l_expr = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Int(*l)));
44
3
                        let r_expr = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Int(*r)));
45
3
                        let lit = AbstractLiteral::matrix_implied_indices(vec![
46
3
                            Expr::Geq(Metadata::new(), inner.clone(), Moo::new(l_expr)),
47
3
                            Expr::Leq(Metadata::new(), inner.clone(), Moo::new(r_expr)),
48
                        ]);
49
3
                        Expr::And(
50
3
                            Metadata::new(),
51
3
                            Moo::new(Expr::AbstractLiteral(Metadata::new(), lit)),
52
3
                        )
53
                    }
54
                    Range::UnboundedL(r) => {
55
                        let bound = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Int(*r)));
56
                        Expr::Leq(Metadata::new(), inner.clone(), Moo::new(bound))
57
                    }
58
                    Range::UnboundedR(l) => {
59
                        let bound = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Int(*l)));
60
                        Expr::Geq(Metadata::new(), inner.clone(), Moo::new(bound))
61
                    }
62
                    Range::Unbounded => bug!("integer domains should not have unbounded ranges"),
63
3
                })
64
3
                .collect();
65
3
            Ok(Expr::Or(
66
3
                Metadata::new(),
67
3
                Moo::new(Expr::AbstractLiteral(
68
3
                    Metadata::new(),
69
3
                    AbstractLiteral::matrix_implied_indices(elements),
70
3
                )),
71
3
            ))
72
        }
73
        _ => Err(RuleNotApplicable),
74
    }?;
75
3
    Ok(Reduction::pure(new_expr))
76
5655
}
77

            
78
/// Matrix a = b iff every index in the union of their indices has the same value.
79
/// E.g. a: matrix indexed by [int(1..2)] of int(1..2), b: matrix indexed by [int(2..3)] of int(1..2)
80
/// a = b ~> a[1] = b[1] /\ a[2] = b[2] /\ a[3] = b[3]
81
// Must run before `matrix_ref_to_atom` ("Base", 2000), otherwise matrix equality can be
82
// rewritten into `int(1..)` indexed literals, losing finite index bounds for this rule.
83
#[register_rule(("Smt", 3000))]
84
73350
fn flatten_matrix_eq_neq(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
85
73350
    let (a, b) = match expr {
86
3924
        Expr::Eq(_, a, b) | Expr::Neq(_, a, b) => (a, b),
87
69426
        _ => return Err(RuleNotApplicable),
88
    };
89

            
90
3924
    let dom_a = a.domain_of().ok_or(RuleNotApplicable)?;
91
3483
    let dom_b = b.domain_of().ok_or(RuleNotApplicable)?;
92

            
93
45
    let (Some((_, a_idx_domains)), Some((_, b_idx_domains))) =
94
3474
        (dom_a.as_matrix_ground(), dom_b.as_matrix_ground())
95
    else {
96
3429
        return Err(RuleNotApplicable);
97
    };
98

            
99
45
    let pairs = matrix::enumerate_index_union_indices(a_idx_domains, b_idx_domains)
100
45
        .map_err(|_| ApplicationError::DomainError)?
101
135
        .map(|idx_lits| {
102
135
            let idx_vec: Vec<_> = idx_lits
103
135
                .into_iter()
104
171
                .map(|lit| Atom::Literal(lit).into())
105
135
                .collect();
106
135
            (
107
135
                Expression::UnsafeIndex(Metadata::new(), a.clone(), idx_vec.clone()),
108
135
                Expression::UnsafeIndex(Metadata::new(), b.clone(), idx_vec),
109
135
            )
110
135
        });
111

            
112
45
    let new_expr = match expr {
113
        Expr::Eq(..) => {
114
99
            let eqs: Vec<_> = pairs.map(|(a, b)| essence_expr!(&a = &b)).collect();
115
27
            Expr::And(
116
27
                Metadata::new(),
117
27
                Moo::new(Expr::AbstractLiteral(
118
27
                    Metadata::new(),
119
27
                    AbstractLiteral::matrix_implied_indices(eqs),
120
27
                )),
121
27
            )
122
        }
123
        Expr::Neq(..) => {
124
36
            let neqs: Vec<_> = pairs.map(|(a, b)| essence_expr!(&a != &b)).collect();
125
18
            Expr::Or(
126
18
                Metadata::new(),
127
18
                Moo::new(Expr::AbstractLiteral(
128
18
                    Metadata::new(),
129
18
                    AbstractLiteral::matrix_implied_indices(neqs),
130
18
                )),
131
18
            )
132
        }
133
        _ => unreachable!(),
134
    };
135

            
136
45
    Ok(Reduction::pure(new_expr))
137
73350
}
138

            
139
/// Turn a matrix slice into a 1-d matrix of the slice elements
140
/// E.g. m[1,..] ~> [m[1,1], m[1,2], m[1,3]]
141
#[register_rule(("Smt", 1000))]
142
16965
fn flatten_matrix_slice(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
143
16965
    let Expr::SafeSlice(_, m, slice_idxs) = expr else {
144
16893
        return Err(RuleNotApplicable);
145
    };
146

            
147
72
    let dom = m.domain_of().ok_or(RuleNotApplicable)?;
148
72
    let Some((_, mat_idxs)) = dom.as_matrix_ground() else {
149
9
        return Err(RuleNotApplicable);
150
    };
151

            
152
63
    if slice_idxs.len() != mat_idxs.len() {
153
        return Err(DomainError);
154
63
    }
155

            
156
    // Find where in the index vector the ".." is
157
63
    let (slice_dim, _) = slice_idxs
158
63
        .iter()
159
63
        .enumerate()
160
90
        .find(|(_, idx)| idx.is_none())
161
63
        .ok_or(RuleNotApplicable)?;
162
63
    let other_idxs = {
163
63
        let opt: Option<Vec<_>> = [&slice_idxs[..slice_dim], &slice_idxs[(slice_dim + 1)..]]
164
63
            .concat()
165
63
            .into_iter()
166
63
            .collect();
167
63
        opt.ok_or(DomainError)?
168
    };
169
63
    let elements: Vec<Expr> = mat_idxs[slice_dim]
170
63
        .values()
171
63
        .map_err(|_| DomainError)?
172
162
        .map(|lit| {
173
162
            let mut new_idx = other_idxs.clone();
174
162
            new_idx.insert(slice_dim, Expr::Atomic(Metadata::new(), Atom::Literal(lit)));
175
162
            Expr::SafeIndex(Metadata::new(), m.clone(), new_idx)
176
162
        })
177
63
        .collect();
178
63
    Ok(Reduction::pure(Expr::AbstractLiteral(
179
63
        Metadata::new(),
180
63
        AbstractLiteral::matrix_implied_indices(elements),
181
63
    )))
182
16965
}
183

            
184
/// Expressions like allDiff and sum support 1-dimensional matrices as inputs, e.g. sum(m) where m is indexed by 1..3.
185
///
186
/// This rule is very similar to `matrix_ref_to_atom`, but turns the matrix reference into a slice rather its atoms.
187
/// Other rules like `flatten_matrix_slice` take care of actually turning the slice into the matrix elements.
188
#[register_rule(("Smt", 999))]
189
13851
fn matrix_ref_to_slice(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
190
    if let Expr::SafeSlice(_, _, _)
191
    | Expr::UnsafeSlice(_, _, _)
192
    | Expr::SafeIndex(_, _, _)
193
13851
    | Expr::UnsafeIndex(_, _, _) = expr
194
    {
195
2421
        return Err(RuleNotApplicable);
196
11430
    };
197

            
198
11430
    for (child, ctx) in expr.holes() {
199
1404
        let Expr::Atomic(_, Atom::Reference(decl)) = &child else {
200
6318
            continue;
201
        };
202

            
203
828
        let dom = decl.resolved_domain().ok_or(RuleNotApplicable)?;
204
828
        let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
205
801
            continue;
206
        };
207

            
208
        // Must be a 1d matrix
209
27
        if index_domains.len() > 1 {
210
18
            continue;
211
9
        }
212

            
213
9
        let new_child = Expr::SafeSlice(Metadata::new(), Moo::new(child.clone()), vec![None]);
214
9
        return Ok(Reduction::pure(ctx(new_child)));
215
    }
216

            
217
11421
    Err(RuleNotApplicable)
218
13851
}
219

            
220
/// This rule is applicable in SMT when atomic representation is not used for matrices.
221
///
222
/// Namely, it unwraps flatten(m) into [m[1, 1], m[1, 2], ...]
223
#[register_rule(("Smt", 999))]
224
13851
fn unwrap_flatten_matrix_nonatomic(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
225
    // TODO: depth not supported yet
226
18
    let Expr::Flatten(_, None, m) = expr else {
227
13833
        return Err(RuleNotApplicable);
228
    };
229

            
230
18
    let dom = m.domain_of().ok_or(RuleNotApplicable)?;
231
18
    let Some(GroundDomain::Matrix(_, index_domains)) = dom.resolve().map(Moo::unwrap_or_clone)
232
    else {
233
        return Err(RuleNotApplicable);
234
    };
235

            
236
18
    let elems: Vec<Expr> = matrix::enumerate_indices(index_domains)
237
117
        .map(|lits| {
238
117
            let idxs: Vec<Expr> = lits.into_iter().map(Into::into).collect();
239
117
            Expr::SafeIndex(Metadata::new(), m.clone(), idxs)
240
117
        })
241
18
        .collect();
242

            
243
18
    let new_dom = GroundDomain::Int(vec![Range::Bounded(
244
18
        1,
245
18
        elems
246
18
            .len()
247
18
            .try_into()
248
18
            .expect("length of matrix should be able to be held in Int type"),
249
18
    )]);
250
18
    let new_expr = Expr::AbstractLiteral(
251
18
        Metadata::new(),
252
18
        AbstractLiteral::Matrix(elems, new_dom.into()),
253
18
    );
254
18
    Ok(Reduction::pure(new_expr))
255
13851
}
256

            
257
/// Expands a sum over an "in set" comprehension to a list.
258
///
259
/// TODO: We currently only support one "in set" generator.
260
/// This rule can be made much more general and nicer.
261
#[register_rule(("Smt", 999))]
262
13851
fn unwrap_abstract_comprehension_sum(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
263
13851
    let Expr::Sum(_, inner) = expr else {
264
13545
        return Err(RuleNotApplicable);
265
    };
266
306
    let Expr::AbstractComprehension(_, comp) = inner.as_ref() else {
267
297
        return Err(RuleNotApplicable);
268
    };
269

            
270
9
    let [Qualifier::Generator(Generator::ExpressionGenerator(generator))] = &comp.qualifiers[..]
271
    else {
272
        return Err(RuleNotApplicable);
273
    };
274

            
275
9
    let set = &generator.expression;
276
9
    let elem_domain = generator.decl.domain().ok_or(DomainError)?;
277
9
    let list: Vec<_> = elem_domain
278
9
        .values()
279
9
        .map_err(|_| DomainError)?
280
45
        .map(|lit| essence_expr!("&lit * toInt(&lit in &set)"))
281
9
        .collect();
282

            
283
9
    let new_expr = Expr::Sum(
284
9
        Metadata::new(),
285
9
        Moo::new(Expr::AbstractLiteral(
286
9
            Metadata::new(),
287
9
            AbstractLiteral::matrix_implied_indices(list),
288
9
        )),
289
9
    );
290
9
    Ok(Reduction::pure(new_expr))
291
13851
}
292

            
293
/// Unwraps a subsetEq expression into checking membership equality.
294
///
295
/// Any elements not in the domain of one set must not be in the other set.
296
#[register_rule(("Smt", 999))]
297
13851
fn unwrap_subseteq(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
298
13851
    let Expr::SubsetEq(_, a, b) = expr else {
299
13851
        return Err(RuleNotApplicable);
300
    };
301

            
302
    let dom_a = a.domain_of().and_then(|d| d.resolve()).ok_or(DomainError)?;
303
    let dom_b = b.domain_of().and_then(|d| d.resolve()).ok_or(DomainError)?;
304

            
305
    let GroundDomain::Set(_, elem_dom_a) = dom_a.as_ref() else {
306
        return Err(RuleNotApplicable);
307
    };
308
    let GroundDomain::Set(_, elem_dom_b) = dom_b.as_ref() else {
309
        return Err(RuleNotApplicable);
310
    };
311

            
312
    let domain_a_iter = elem_dom_a.values().map_err(|_| DomainError)?;
313
    let memberships = domain_a_iter
314
        .map(|lit| {
315
            let b_contains = elem_dom_b.contains(&lit).map_err(|_| DomainError)?;
316
            match b_contains {
317
                true => Ok(essence_expr!("(&lit in &a) -> (&lit in &b)")),
318
                false => Ok(essence_expr!("!(&lit in &a)")),
319
            }
320
        })
321
        .collect::<Result<Vec<_>, _>>()?;
322

            
323
    let new_expr = Expr::And(
324
        Metadata::new(),
325
        Moo::new(Expr::AbstractLiteral(
326
            Metadata::new(),
327
            AbstractLiteral::matrix_implied_indices(memberships),
328
        )),
329
    );
330

            
331
    Ok(Reduction::pure(new_expr))
332
13851
}
333

            
334
/// Unwraps equality between sets into checking membership equality.
335
///
336
/// This is an optimisation over unwrap_subseteq to avoid unnecessary additional -> exprs
337
/// where a single <-> is enough. This must apply before eq_to_subset_eq.
338
#[register_rule(("Smt", 8801))]
339
221256
fn unwrap_set_eq(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
340
221256
    let Expr::Eq(_, a, b) = expr else {
341
209538
        return Err(RuleNotApplicable);
342
    };
343

            
344
11718
    let dom_a = a.domain_of().and_then(|d| d.resolve()).ok_or(DomainError)?;
345
10440
    let dom_b = b.domain_of().and_then(|d| d.resolve()).ok_or(DomainError)?;
346

            
347
10431
    let GroundDomain::Set(_, elem_dom_a) = dom_a.as_ref() else {
348
10413
        return Err(RuleNotApplicable);
349
    };
350
18
    let GroundDomain::Set(_, elem_dom_b) = dom_b.as_ref() else {
351
        return Err(RuleNotApplicable);
352
    };
353

            
354
18
    let union_val_iter = elem_dom_a
355
18
        .union(elem_dom_b)
356
18
        .and_then(|d| d.values())
357
18
        .map_err(|_| DomainError)?;
358
18
    let memberships = union_val_iter
359
45
        .map(|lit| {
360
45
            let a_contains = elem_dom_a.contains(&lit).map_err(|_| DomainError)?;
361
45
            let b_contains = elem_dom_b.contains(&lit).map_err(|_| DomainError)?;
362
45
            match (a_contains, b_contains) {
363
27
                (true, true) => Ok(essence_expr!("(&lit in &a) <-> (&lit in &b)")),
364
9
                (true, false) => Ok(essence_expr!("!(&lit in &a)")),
365
9
                (false, true) => Ok(essence_expr!("!(&lit in &b)")),
366
                (false, false) => unreachable!(),
367
            }
368
45
        })
369
18
        .collect::<Result<Vec<_>, _>>()?;
370

            
371
18
    let new_expr = Expr::And(
372
18
        Metadata::new(),
373
18
        Moo::new(Expr::AbstractLiteral(
374
18
            Metadata::new(),
375
18
            AbstractLiteral::matrix_implied_indices(memberships),
376
18
        )),
377
18
    );
378

            
379
18
    Ok(Reduction::pure(new_expr))
380
221256
}