1
use conjure_cp::ast::{Expression as Expr, *};
2
use conjure_cp::rule_engine::ApplicationError;
3
use conjure_cp::rule_engine::{
4
    ApplicationError::{DomainError, RuleNotApplicable},
5
    ApplicationResult, Reduction, register_rule, register_rule_set,
6
};
7
use conjure_cp::solver::SolverFamily;
8
use conjure_cp::{bug, essence_expr};
9
use uniplate::Uniplate;
10

            
11
// These rules are applicable regardless of what theories are used.
//
// Declares the "Smt" rule set. `("Base")` is presumably the list of rule sets this one builds on.
// NOTE(review): confirm this against `register_rule_set!`.
// The closure enables the set for every `SolverFamily::Smt(..)` variant, whatever its payload,
// and for no other solver family.
register_rule_set!("Smt", ("Base"), |f: &SolverFamily| {
    matches!(f, SolverFamily::Smt(..))
});
15

            
16
#[register_rule(("Smt", 1000))]
17
fn flatten_indomain(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
18
    let Expr::InDomain(_, inner, domain) = expr else {
19
        return Err(RuleNotApplicable);
20
    };
21

            
22
    let dom = domain.resolve().ok_or(RuleNotApplicable)?;
23
    let new_expr = match dom.as_ref() {
24
        // Bool values are always in the bool domain
25
        GroundDomain::Bool => Ok(Expr::Atomic(
26
            Metadata::new(),
27
            Atom::Literal(Literal::Bool(true)),
28
        )),
29
        GroundDomain::Empty(_) => Ok(Expr::Atomic(
30
            Metadata::new(),
31
            Atom::Literal(Literal::Bool(false)),
32
        )),
33
        GroundDomain::Int(ranges) => {
34
            let elements: Vec<_> = ranges
35
                .iter()
36
                .map(|range| match range {
37
                    Range::Single(n) => {
38
                        let eq = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Int(*n)));
39
                        Expr::Eq(Metadata::new(), inner.clone(), Moo::new(eq))
40
                    }
41
                    Range::Bounded(l, r) => {
42
                        let l_expr = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Int(*l)));
43
                        let r_expr = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Int(*r)));
44
                        let lit = AbstractLiteral::list(vec![
45
                            Expr::Geq(Metadata::new(), inner.clone(), Moo::new(l_expr)),
46
                            Expr::Leq(Metadata::new(), inner.clone(), Moo::new(r_expr)),
47
                        ]);
48
                        Expr::And(
49
                            Metadata::new(),
50
                            Moo::new(Expr::AbstractLiteral(Metadata::new(), lit)),
51
                        )
52
                    }
53
                    Range::UnboundedL(r) => {
54
                        let bound = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Int(*r)));
55
                        Expr::Leq(Metadata::new(), inner.clone(), Moo::new(bound))
56
                    }
57
                    Range::UnboundedR(l) => {
58
                        let bound = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Int(*l)));
59
                        Expr::Geq(Metadata::new(), inner.clone(), Moo::new(bound))
60
                    }
61
                    Range::Unbounded => bug!("integer domains should not have unbounded ranges"),
62
                })
63
                .collect();
64
            Ok(Expr::Or(
65
                Metadata::new(),
66
                Moo::new(Expr::AbstractLiteral(
67
                    Metadata::new(),
68
                    AbstractLiteral::list(elements),
69
                )),
70
            ))
71
        }
72
        _ => Err(RuleNotApplicable),
73
    }?;
74
    Ok(Reduction::pure(new_expr))
75
}
76

            
77
/// Matrix a = b iff every index in the union of their indices has the same value.
78
/// E.g. a: matrix indexed by [int(1..2)] of int(1..2), b: matrix indexed by [int(2..3)] of int(1..2)
79
/// a = b ~> a[1] = b[1] /\ a[2] = b[2] /\ a[3] = b[3]
80
#[register_rule(("Smt", 1000))]
81
fn flatten_matrix_eq_neq(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
82
    let (a, b) = match expr {
83
        Expr::Eq(_, a, b) | Expr::Neq(_, a, b) => (a, b),
84
        _ => return Err(RuleNotApplicable),
85
    };
86

            
87
    let dom_a = a.domain_of().ok_or(RuleNotApplicable)?;
88
    let dom_b = b.domain_of().ok_or(RuleNotApplicable)?;
89

            
90
    let (Some((_, a_idx_domains)), Some((_, b_idx_domains))) =
91
        (dom_a.as_matrix_ground(), dom_b.as_matrix_ground())
92
    else {
93
        return Err(RuleNotApplicable);
94
    };
95

            
96
    let pairs = matrix::enumerate_index_union_indices(a_idx_domains, b_idx_domains)
97
        .map_err(|_| ApplicationError::DomainError)?
98
        .map(|idx_lits| {
99
            let idx_vec: Vec<_> = idx_lits
100
                .into_iter()
101
                .map(|lit| Atom::Literal(lit).into())
102
                .collect();
103
            (
104
                Expression::UnsafeIndex(Metadata::new(), a.clone(), idx_vec.clone()),
105
                Expression::UnsafeIndex(Metadata::new(), b.clone(), idx_vec),
106
            )
107
        });
108

            
109
    let new_expr = match expr {
110
        Expr::Eq(..) => {
111
            let eqs: Vec<_> = pairs.map(|(a, b)| essence_expr!(&a = &b)).collect();
112
            Expr::And(
113
                Metadata::new(),
114
                Moo::new(Expr::AbstractLiteral(
115
                    Metadata::new(),
116
                    AbstractLiteral::list(eqs),
117
                )),
118
            )
119
        }
120
        Expr::Neq(..) => {
121
            let neqs: Vec<_> = pairs.map(|(a, b)| essence_expr!(&a != &b)).collect();
122
            Expr::Or(
123
                Metadata::new(),
124
                Moo::new(Expr::AbstractLiteral(
125
                    Metadata::new(),
126
                    AbstractLiteral::list(neqs),
127
                )),
128
            )
129
        }
130
        _ => unreachable!(),
131
    };
132

            
133
    Ok(Reduction::pure(new_expr))
134
}
135

            
136
/// Turn a matrix slice into a 1-d matrix of the slice elements
137
/// E.g. m[1,..] ~> [m[1,1], m[1,2], m[1,3]]
138
#[register_rule(("Smt", 1000))]
139
fn flatten_matrix_slice(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
140
    let Expr::SafeSlice(_, m, slice_idxs) = expr else {
141
        return Err(RuleNotApplicable);
142
    };
143

            
144
    let dom = m.domain_of().ok_or(RuleNotApplicable)?;
145
    let Some((_, mat_idxs)) = dom.as_matrix_ground() else {
146
        return Err(RuleNotApplicable);
147
    };
148

            
149
    if slice_idxs.len() != mat_idxs.len() {
150
        return Err(DomainError);
151
    }
152

            
153
    // Find where in the index vector the ".." is
154
    let (slice_dim, _) = slice_idxs
155
        .iter()
156
        .enumerate()
157
        .find(|(_, idx)| idx.is_none())
158
        .ok_or(RuleNotApplicable)?;
159
    let other_idxs = {
160
        let opt: Option<Vec<_>> = [&slice_idxs[..slice_dim], &slice_idxs[(slice_dim + 1)..]]
161
            .concat()
162
            .into_iter()
163
            .collect();
164
        opt.ok_or(DomainError)?
165
    };
166
    let elements: Vec<Expr> = mat_idxs[slice_dim]
167
        .values()
168
        .map_err(|_| DomainError)?
169
        .map(|lit| {
170
            let mut new_idx = other_idxs.clone();
171
            new_idx.insert(slice_dim, Expr::Atomic(Metadata::new(), Atom::Literal(lit)));
172
            Expr::SafeIndex(Metadata::new(), m.clone(), new_idx)
173
        })
174
        .collect();
175
    Ok(Reduction::pure(Expr::AbstractLiteral(
176
        Metadata::new(),
177
        AbstractLiteral::list(elements),
178
    )))
179
}
180

            
181
/// Expressions like allDiff and sum support 1-dimensional matrices as inputs, e.g. sum(m) where m is indexed by 1..3.
182
///
183
/// This rule is very similar to `matrix_ref_to_atom`, but turns the matrix reference into a slice rather its atoms.
184
/// Other rules like `flatten_matrix_slice` take care of actually turning the slice into the matrix elements.
185
#[register_rule(("Smt", 999))]
186
fn matrix_ref_to_slice(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
187
    if let Expr::SafeSlice(_, _, _)
188
    | Expr::UnsafeSlice(_, _, _)
189
    | Expr::SafeIndex(_, _, _)
190
    | Expr::UnsafeIndex(_, _, _) = expr
191
    {
192
        return Err(RuleNotApplicable);
193
    };
194

            
195
    for (child, ctx) in expr.holes() {
196
        let Expr::Atomic(_, Atom::Reference(decl)) = &child else {
197
            continue;
198
        };
199

            
200
        let dom = decl.resolved_domain().ok_or(RuleNotApplicable)?;
201
        let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
202
            continue;
203
        };
204

            
205
        // Must be a 1d matrix
206
        if index_domains.len() > 1 {
207
            continue;
208
        }
209

            
210
        let new_child = Expr::SafeSlice(Metadata::new(), Moo::new(child.clone()), vec![None]);
211
        return Ok(Reduction::pure(ctx(new_child)));
212
    }
213

            
214
    Err(RuleNotApplicable)
215
}