1
use conjure_cp::ast::comprehension::ComprehensionQualifier;
2
use conjure_cp::ast::{Expression as Expr, *};
3
use conjure_cp::rule_engine::ApplicationError;
4
use conjure_cp::rule_engine::{
5
    ApplicationError::{DomainError, RuleNotApplicable},
6
    ApplicationResult, Reduction, register_rule, register_rule_set,
7
};
8
use conjure_cp::settings::SolverFamily;
9
use conjure_cp::{bug, essence_expr};
10
use uniplate::Uniplate;
11

            
12
// These rules are applicable regardless of what theories are used.
13
register_rule_set!("Smt", ("Base"), |family: &SolverFamily| matches!(
    family,
    SolverFamily::Smt(..)
));
16

            
17
#[register_rule("Smt", 1000, [InDomain])]
18
41806
fn flatten_indomain(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
19
41806
    let Expr::InDomain(_, inner, domain) = expr else {
20
41794
        return Err(RuleNotApplicable);
21
    };
22

            
23
12
    let dom = domain.resolve().ok_or(RuleNotApplicable)?;
24
12
    let new_expr = match dom.as_ref() {
25
        // Bool values are always in the bool domain
26
        GroundDomain::Bool => Ok(Expr::Atomic(
27
            Metadata::new(),
28
            Atom::Literal(Literal::Bool(true)),
29
        )),
30
        GroundDomain::Empty(_) => Ok(Expr::Atomic(
31
            Metadata::new(),
32
            Atom::Literal(Literal::Bool(false)),
33
        )),
34
12
        GroundDomain::Int(ranges) => {
35
12
            let elements: Vec<_> = ranges
36
12
                .iter()
37
12
                .map(|range| match range {
38
                    Range::Single(n) => {
39
                        let eq = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Int(*n)));
40
                        Expr::Eq(Metadata::new(), inner.clone(), Moo::new(eq))
41
                    }
42
12
                    Range::Bounded(l, r) => {
43
12
                        let l_expr = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Int(*l)));
44
12
                        let r_expr = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Int(*r)));
45
12
                        let lit = AbstractLiteral::matrix_implied_indices(vec![
46
12
                            Expr::Geq(Metadata::new(), inner.clone(), Moo::new(l_expr)),
47
12
                            Expr::Leq(Metadata::new(), inner.clone(), Moo::new(r_expr)),
48
                        ]);
49
12
                        Expr::And(
50
12
                            Metadata::new(),
51
12
                            Moo::new(Expr::AbstractLiteral(Metadata::new(), lit)),
52
12
                        )
53
                    }
54
                    Range::UnboundedL(r) => {
55
                        let bound = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Int(*r)));
56
                        Expr::Leq(Metadata::new(), inner.clone(), Moo::new(bound))
57
                    }
58
                    Range::UnboundedR(l) => {
59
                        let bound = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Int(*l)));
60
                        Expr::Geq(Metadata::new(), inner.clone(), Moo::new(bound))
61
                    }
62
                    Range::Unbounded => bug!("integer domains should not have unbounded ranges"),
63
12
                })
64
12
                .collect();
65
12
            Ok(Expr::Or(
66
12
                Metadata::new(),
67
12
                Moo::new(Expr::AbstractLiteral(
68
12
                    Metadata::new(),
69
12
                    AbstractLiteral::matrix_implied_indices(elements),
70
12
                )),
71
12
            ))
72
        }
73
        _ => Err(RuleNotApplicable),
74
    }?;
75
12
    Ok(Reduction::pure(new_expr))
76
41806
}
77

            
78
/// Matrix a = b iff every index in the union of their indices has the same value.
79
/// E.g. a: matrix indexed by [int(1..2)] of int(1..2), b: matrix indexed by [int(2..3)] of int(1..2)
80
/// a = b ~> a[1] = b[1] /\ a[2] = b[2] /\ a[3] = b[3]
81
// Must run before `matrix_ref_to_atom` ("Base", 2000), otherwise matrix equality can be
82
// rewritten into `int(1..)` indexed literals, losing finite index bounds for this rule.
83
#[register_rule("Smt", 3000, [Eq, Neq])]
84
167556
fn flatten_matrix_eq_neq(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
85
167556
    let (a, b) = match expr {
86
5086
        Expr::Eq(_, a, b) | Expr::Neq(_, a, b) => (a, b),
87
162470
        _ => return Err(RuleNotApplicable),
88
    };
89

            
90
5086
    let a_idx_domains = matrix::bound_index_domains_of_expr(a.as_ref()).ok_or(RuleNotApplicable)?;
91
68
    let b_idx_domains = matrix::bound_index_domains_of_expr(b.as_ref()).ok_or(RuleNotApplicable)?;
92

            
93
60
    let pairs = matrix::enumerate_index_union_indices(&a_idx_domains, &b_idx_domains)
94
60
        .map_err(|_| ApplicationError::DomainError)?
95
180
        .map(|idx_lits| {
96
180
            let idx_vec: Vec<_> = idx_lits
97
180
                .into_iter()
98
228
                .map(|lit| Atom::Literal(lit).into())
99
180
                .collect();
100
180
            (
101
180
                Expression::UnsafeIndex(Metadata::new(), a.clone(), idx_vec.clone()),
102
180
                Expression::UnsafeIndex(Metadata::new(), b.clone(), idx_vec),
103
180
            )
104
180
        });
105

            
106
60
    let new_expr = match expr {
107
        Expr::Eq(..) => {
108
132
            let eqs: Vec<_> = pairs.map(|(a, b)| essence_expr!(&a = &b)).collect();
109
36
            Expr::And(
110
36
                Metadata::new(),
111
36
                Moo::new(Expr::AbstractLiteral(
112
36
                    Metadata::new(),
113
36
                    AbstractLiteral::matrix_implied_indices(eqs),
114
36
                )),
115
36
            )
116
        }
117
        Expr::Neq(..) => {
118
48
            let neqs: Vec<_> = pairs.map(|(a, b)| essence_expr!(&a != &b)).collect();
119
24
            Expr::Or(
120
24
                Metadata::new(),
121
24
                Moo::new(Expr::AbstractLiteral(
122
24
                    Metadata::new(),
123
24
                    AbstractLiteral::matrix_implied_indices(neqs),
124
24
                )),
125
24
            )
126
        }
127
        _ => unreachable!(),
128
    };
129

            
130
60
    Ok(Reduction::pure(new_expr))
131
167556
}
132

            
133
/// Turn a matrix slice into a 1-d matrix of the slice elements
134
/// E.g. m[1,..] ~> [m[1,1], m[1,2], m[1,3]]
135
#[register_rule("Smt", 1000, [SafeSlice])]
136
41806
fn flatten_matrix_slice(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
137
41806
    let Expr::SafeSlice(_, m, slice_idxs) = expr else {
138
41710
        return Err(RuleNotApplicable);
139
    };
140

            
141
96
    let mat_idxs = matrix::bound_index_domains_of_expr(m.as_ref()).ok_or(RuleNotApplicable)?;
142

            
143
96
    if slice_idxs.len() != mat_idxs.len() {
144
        return Err(DomainError);
145
96
    }
146

            
147
    // Find where in the index vector the ".." is
148
96
    let (slice_dim, _) = slice_idxs
149
96
        .iter()
150
96
        .enumerate()
151
132
        .find(|(_, idx)| idx.is_none())
152
96
        .ok_or(RuleNotApplicable)?;
153
96
    let other_idxs = {
154
96
        let opt: Option<Vec<_>> = [&slice_idxs[..slice_dim], &slice_idxs[(slice_dim + 1)..]]
155
96
            .concat()
156
96
            .into_iter()
157
96
            .collect();
158
96
        opt.ok_or(DomainError)?
159
    };
160
96
    let elements: Vec<Expr> = mat_idxs[slice_dim]
161
96
        .values()
162
96
        .map_err(|_| DomainError)?
163
264
        .map(|lit| {
164
264
            let mut new_idx = other_idxs.clone();
165
264
            new_idx.insert(slice_dim, Expr::Atomic(Metadata::new(), Atom::Literal(lit)));
166
264
            Expr::SafeIndex(Metadata::new(), m.clone(), new_idx)
167
264
        })
168
96
        .collect();
169
96
    Ok(Reduction::pure(Expr::AbstractLiteral(
170
96
        Metadata::new(),
171
96
        AbstractLiteral::matrix_implied_indices(elements),
172
96
    )))
173
41806
}
174

            
175
/// Expressions like allDiff and sum support 1-dimensional matrices as inputs, e.g. sum(m) where m is indexed by 1..3.
176
///
177
/// This rule is very similar to `matrix_ref_to_atom`, but turns the matrix reference into a slice rather its atoms.
178
/// Other rules like `flatten_matrix_slice` take care of actually turning the slice into the matrix elements.
179
#[register_rule("Smt", 999)]
180
38401
fn matrix_ref_to_slice(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
181
    if let Expr::SafeSlice(_, _, _)
182
    | Expr::UnsafeSlice(_, _, _)
183
    | Expr::SafeIndex(_, _, _)
184
38401
    | Expr::UnsafeIndex(_, _, _) = expr
185
    {
186
3578
        return Err(RuleNotApplicable);
187
34823
    };
188

            
189
34823
    for (child, ctx) in expr.holes() {
190
8720
        let Expr::Atomic(_, Atom::Reference(decl)) = &child else {
191
17298
            continue;
192
        };
193

            
194
6374
        let dom = decl.resolved_domain().ok_or(RuleNotApplicable)?;
195
6374
        let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
196
6344
            continue;
197
        };
198

            
199
        // Must be a 1d matrix
200
30
        if index_domains.len() > 1 {
201
18
            continue;
202
12
        }
203

            
204
12
        let new_child = Expr::SafeSlice(Metadata::new(), Moo::new(child.clone()), vec![None]);
205
12
        return Ok(Reduction::pure(ctx(new_child)));
206
    }
207

            
208
34811
    Err(RuleNotApplicable)
209
38401
}
210

            
211
/// This rule is applicable in SMT when atomic representation is not used for matrices.
212
///
213
/// Namely, it unwraps flatten(m) into [m[1, 1], m[1, 2], ...]
214
#[register_rule("Smt", 999, [Flatten])]
215
38401
fn unwrap_flatten_matrix_nonatomic(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
216
    // TODO: depth not supported yet
217
18
    let Expr::Flatten(_, None, m) = expr else {
218
38383
        return Err(RuleNotApplicable);
219
    };
220

            
221
18
    let index_domains = matrix::bound_index_domains_of_expr(m.as_ref()).ok_or(RuleNotApplicable)?;
222

            
223
18
    let elems: Vec<Expr> = matrix::try_enumerate_indices(index_domains)
224
18
        .map_err(|_| DomainError)?
225
102
        .map(|lits| {
226
102
            let idxs: Vec<Expr> = lits.into_iter().map(Into::into).collect();
227
102
            Expr::SafeIndex(Metadata::new(), m.clone(), idxs)
228
102
        })
229
18
        .collect();
230

            
231
18
    let new_dom = GroundDomain::Int(vec![Range::Bounded(
232
18
        1,
233
18
        elems
234
18
            .len()
235
18
            .try_into()
236
18
            .expect("length of matrix should be able to be held in Int type"),
237
18
    )]);
238
18
    let new_expr = Expr::AbstractLiteral(
239
18
        Metadata::new(),
240
18
        AbstractLiteral::Matrix(elems, new_dom.into()),
241
18
    );
242
18
    Ok(Reduction::pure(new_expr))
243
38401
}
244

            
245
/// Expands a sum over an "in set" comprehension to a list.
246
///
247
/// TODO: We currently only support one "in set" generator.
248
/// This rule can be made much more general and nicer.
249
#[register_rule("Smt", 999, [Sum])]
250
38401
fn unwrap_abstract_comprehension_sum(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
251
38401
    let Expr::Sum(_, inner) = expr else {
252
37823
        return Err(RuleNotApplicable);
253
    };
254
578
    let Expr::Comprehension(_, comp) = inner.as_ref() else {
255
572
        return Err(RuleNotApplicable);
256
    };
257

            
258
6
    let [ComprehensionQualifier::ExpressionGenerator { ptr }] = &comp.qualifiers[..] else {
259
        return Err(RuleNotApplicable);
260
    };
261

            
262
6
    let Some(set) = ptr
263
6
        .as_quantified_expr()
264
6
        .map(|expr_guard| expr_guard.clone())
265
    else {
266
        return Err(RuleNotApplicable);
267
    };
268

            
269
6
    let elem_domain = set
270
6
        .domain_of()
271
6
        .expect("Expression must have a domain")
272
6
        .element_domain()
273
6
        .expect("Expression must contain elements with uniform domain");
274
6
    let list: Vec<_> = elem_domain
275
6
        .values()
276
6
        .map_err(|_| DomainError)?
277
30
        .map(|lit| essence_expr!("&lit * toInt(&lit in &set)"))
278
6
        .collect();
279

            
280
6
    let new_expr = Expr::Sum(
281
6
        Metadata::new(),
282
6
        Moo::new(Expr::AbstractLiteral(
283
6
            Metadata::new(),
284
6
            AbstractLiteral::matrix_implied_indices(list),
285
6
        )),
286
6
    );
287
6
    Ok(Reduction::pure(new_expr))
288
38401
}
289

            
290
/// Unwraps a subsetEq expression into checking membership equality.
291
///
292
/// Any elements not in the domain of one set must not be in the other set.
293
#[register_rule("Smt", 999, [SubsetEq])]
294
38401
fn unwrap_subseteq(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
295
38401
    let Expr::SubsetEq(_, a, b) = expr else {
296
38401
        return Err(RuleNotApplicable);
297
    };
298

            
299
    let dom_a = a.domain_of().and_then(|d| d.resolve()).ok_or(DomainError)?;
300
    let dom_b = b.domain_of().and_then(|d| d.resolve()).ok_or(DomainError)?;
301

            
302
    let GroundDomain::Set(_, elem_dom_a) = dom_a.as_ref() else {
303
        return Err(RuleNotApplicable);
304
    };
305
    let GroundDomain::Set(_, elem_dom_b) = dom_b.as_ref() else {
306
        return Err(RuleNotApplicable);
307
    };
308

            
309
    let domain_a_iter = elem_dom_a.values().map_err(|_| DomainError)?;
310
    let memberships = domain_a_iter
311
        .map(|lit| {
312
            let b_contains = elem_dom_b.contains(&lit).map_err(|_| DomainError)?;
313
            match b_contains {
314
                true => Ok(essence_expr!("(&lit in &a) -> (&lit in &b)")),
315
                false => Ok(essence_expr!("!(&lit in &a)")),
316
            }
317
        })
318
        .collect::<Result<Vec<_>, _>>()?;
319

            
320
    let new_expr = Expr::And(
321
        Metadata::new(),
322
        Moo::new(Expr::AbstractLiteral(
323
            Metadata::new(),
324
            AbstractLiteral::matrix_implied_indices(memberships),
325
        )),
326
    );
327

            
328
    Ok(Reduction::pure(new_expr))
329
38401
}
330

            
331
/// Unwraps equality between sets into checking membership equality.
332
///
333
/// This is an optimisation over unwrap_subseteq to avoid unnecessary additional -> exprs
334
/// where a single <-> is enough. This must apply before eq_to_subset_eq.
335
#[register_rule("Smt", 8801, [Eq])]
336
344880
fn unwrap_set_eq(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
337
344880
    let Expr::Eq(_, a, b) = expr else {
338
332562
        return Err(RuleNotApplicable);
339
    };
340

            
341
12318
    let dom_a = a.domain_of().and_then(|d| d.resolve()).ok_or(DomainError)?;
342
11124
    let dom_b = b.domain_of().and_then(|d| d.resolve()).ok_or(DomainError)?;
343

            
344
11032
    let GroundDomain::Set(_, elem_dom_a) = dom_a.as_ref() else {
345
11008
        return Err(RuleNotApplicable);
346
    };
347
24
    let GroundDomain::Set(_, elem_dom_b) = dom_b.as_ref() else {
348
        return Err(RuleNotApplicable);
349
    };
350

            
351
24
    let union_val_iter = elem_dom_a
352
24
        .union(elem_dom_b)
353
24
        .and_then(|d| d.values())
354
24
        .map_err(|_| DomainError)?;
355
24
    let memberships = union_val_iter
356
60
        .map(|lit| {
357
60
            let a_contains = elem_dom_a.contains(&lit).map_err(|_| DomainError)?;
358
60
            let b_contains = elem_dom_b.contains(&lit).map_err(|_| DomainError)?;
359
60
            match (a_contains, b_contains) {
360
36
                (true, true) => Ok(essence_expr!("(&lit in &a) <-> (&lit in &b)")),
361
12
                (true, false) => Ok(essence_expr!("!(&lit in &a)")),
362
12
                (false, true) => Ok(essence_expr!("!(&lit in &b)")),
363
                (false, false) => unreachable!(),
364
            }
365
60
        })
366
24
        .collect::<Result<Vec<_>, _>>()?;
367

            
368
24
    let new_expr = Expr::And(
369
24
        Metadata::new(),
370
24
        Moo::new(Expr::AbstractLiteral(
371
24
            Metadata::new(),
372
24
            AbstractLiteral::matrix_implied_indices(memberships),
373
24
        )),
374
24
    );
375

            
376
24
    Ok(Reduction::pure(new_expr))
377
344880
}