1
//! Rules for variables in domains.
2

            
3
use std::collections::HashMap;
4

            
5
use conjure_cp::{
6
    ast::{
7
        Atom, DecisionVariable, DeclarationKind, Domain, DomainPtr, Expression as Expr, HasDomain,
8
        IntVal, Literal as Lit, Metadata, Moo, Name, Range, Reference, SymbolTable,
9
    },
10
    rule_engine::{ApplicationError, ApplicationResult, Reduction, register_rule},
11
};
12
use uniplate::Biplate;
13

            
14
use ApplicationError::RuleNotApplicable;
15

            
16
/// Cache of `(lower, upper)` integer bounds already computed for a declaration, keyed by the
/// declaration's name.
type IntBoundsCache = HashMap<Name, (i32, i32)>;
17
/// Names whose bounds are currently being derived; used to stop when a name is reached again
/// while its own bounds are still being computed.
type VisitingStack = Vec<Name>;
18

            
19
/// Rewrites variables in domains.
20
///
21
/// Solvers require variable declarations to have ground domains. For integer domains that contain variables in them, we widen to a finite ground superset-domain and add constraints that enforce membership in the original (possibly variable-dependent) domain.
22
#[register_rule(("Base", 8990))]
23
1389798
fn handle_variables_in_domains(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
24
1389798
    let Expr::Root(_, _) = expr else {
25
1345677
        return Err(RuleNotApplicable);
26
    };
27

            
28
44121
    if !symbols_have_decision_variable_references(symbols) {
29
44058
        return Err(RuleNotApplicable);
30
63
    }
31

            
32
63
    let mut known_int_bounds: IntBoundsCache = HashMap::new();
33
63
    let mut domain_guards = Vec::new();
34
63
    let mut changed = false;
35

            
36
    // Collect declarations first to avoid iterator invalidation while mutating declarations.
37
63
    let declarations: Vec<_> = symbols
38
63
        .clone()
39
63
        .into_iter_local()
40
63
        .map(|(_, decl)| decl)
41
63
        .collect();
42

            
43
180
    for mut decl in declarations {
44
180
        let Some(domain) = decl.as_find().map(|var| var.domain_of()) else {
45
            continue;
46
        };
47

            
48
180
        if let Some(bounds) =
49
180
            int_domain_bounds_from_domain(&domain, symbols, &mut known_int_bounds, &mut Vec::new())
50
180
        {
51
180
            known_int_bounds.insert(decl.name().clone(), bounds);
52
180
        }
53

            
54
180
        if domain.resolve().is_some() {
55
99
            continue;
56
81
        }
57

            
58
81
        let Some(widened_domain) =
59
81
            resolve_or_widen_int_domain(&domain, symbols, &mut known_int_bounds, &mut Vec::new())
60
        else {
61
            return Err(RuleNotApplicable);
62
        };
63

            
64
81
        let Some(guards) = domain_consistency_constraints(&decl, &domain) else {
65
            return Err(RuleNotApplicable);
66
        };
67

            
68
81
        domain_guards.extend(guards);
69
81
        let _ = decl.replace_kind(DeclarationKind::Find(DecisionVariable::new(widened_domain)));
70
81
        changed = true;
71
    }
72

            
73
63
    if !changed {
74
        return Err(RuleNotApplicable);
75
63
    }
76

            
77
63
    Ok(Reduction::new(expr.clone(), domain_guards, symbols.clone()))
78
1389798
}
79

            
80
/// Returns true iff at least one local symbol contains a reference to a decision variable.
81
44121
fn symbols_have_decision_variable_references(symbols: &SymbolTable) -> bool {
82
44121
    let is_decision_reference = |reference: &Reference| {
83
27387
        reference.ptr().as_find().is_some()
84
27387
            || symbols
85
27387
                .lookup(&reference.name().clone())
86
27387
                .is_some_and(|decl| decl.as_find().is_some())
87
27387
    };
88

            
89
1732323
    symbols.iter_local().any(|(_, declaration)| {
90
1732323
        declaration.domain().is_some_and(|domain| {
91
1732314
            Biplate::<Reference>::universe_bi(domain.as_ref())
92
1732314
                .iter()
93
1732314
                .any(&is_decision_reference)
94
1732314
                || Biplate::<IntVal>::universe_bi(domain.as_ref())
95
1732314
                    .iter()
96
1732314
                    .any(|int_val| match int_val {
97
                        IntVal::Const(_) => false,
98
                        IntVal::Reference(reference) => is_decision_reference(reference),
99
                        IntVal::Expr(expr) => Biplate::<Atom>::universe_bi(expr.as_ref())
100
                            .iter()
101
                            .any(|atom| {
102
                                matches!(atom, Atom::Reference(reference) if is_decision_reference(reference))
103
                            }),
104
                    })
105
1732323
        }) || declaration.as_find().is_some_and(|find| {
106
1718400
            let domain = find.domain_of();
107
1718400
            domain.resolve().is_none() && domain.as_ref().as_int().is_some()
108
1718400
        })
109
1732323
    })
110
44121
}
111

            
112
/// Resolves an integer domain when possible; otherwise computes a finite widened domain
113
/// by replacing symbolic bounds with conservative numeric bounds.
114
171
fn resolve_or_widen_int_domain(
115
171
    domain: &DomainPtr,
116
171
    symbols: &SymbolTable,
117
171
    known_int_bounds: &mut IntBoundsCache,
118
171
    visiting: &mut VisitingStack,
119
171
) -> Option<DomainPtr> {
120
171
    if let Some(resolved) = domain.resolve() {
121
90
        return Some(resolved.into());
122
81
    }
123

            
124
81
    let ranges = domain.as_ref().as_int()?;
125
81
    let widened_ranges: Vec<Range<i32>> = ranges
126
81
        .iter()
127
99
        .map(|range| int_range_bounds(range, symbols, known_int_bounds, visiting))
128
99
        .map(|bounds| bounds.map(|(lo, hi)| Range::new(Some(lo), Some(hi))))
129
81
        .collect::<Option<Vec<_>>>()?;
130

            
131
81
    Some(Domain::int_ground(widened_ranges))
132
171
}
133

            
134
/// Returns overall numeric bounds for an unresolved integer domain.
135
270
fn int_domain_bounds_from_domain(
136
270
    domain: &DomainPtr,
137
270
    symbols: &SymbolTable,
138
270
    known_int_bounds: &mut IntBoundsCache,
139
270
    visiting: &mut VisitingStack,
140
270
) -> Option<(i32, i32)> {
141
270
    let ranges = domain.as_ref().as_int()?;
142
270
    let mut lower = i32::MAX;
143
270
    let mut upper = i32::MIN;
144

            
145
324
    for range in ranges {
146
324
        let (lo, hi) = int_range_bounds(&range, symbols, known_int_bounds, visiting)?;
147
324
        lower = lower.min(lo);
148
324
        upper = upper.max(hi);
149
    }
150

            
151
270
    Some((lower, upper))
152
270
}
153

            
154
/// Computes numeric bounds for a possibly symbolic integer range.
155
423
fn int_range_bounds(
156
423
    range: &Range<IntVal>,
157
423
    symbols: &SymbolTable,
158
423
    known_int_bounds: &mut IntBoundsCache,
159
423
    visiting: &mut VisitingStack,
160
423
) -> Option<(i32, i32)> {
161
423
    match range {
162
18
        Range::Single(v) => int_val_bounds(v, symbols, known_int_bounds, visiting),
163
405
        Range::Bounded(l, r) => {
164
405
            let (ll, lh) = int_val_bounds(l, symbols, known_int_bounds, visiting)?;
165
405
            let (rl, rh) = int_val_bounds(r, symbols, known_int_bounds, visiting)?;
166
405
            Some((ll.min(lh), rl.max(rh)))
167
        }
168
        Range::Unbounded | Range::UnboundedL(_) | Range::UnboundedR(_) => None,
169
    }
170
423
}
171

            
172
/// Computes numeric bounds for an unresolved integer value.
173
828
fn int_val_bounds(
174
828
    value: &IntVal,
175
828
    symbols: &SymbolTable,
176
828
    known_int_bounds: &mut IntBoundsCache,
177
828
    visiting: &mut VisitingStack,
178
828
) -> Option<(i32, i32)> {
179
828
    if let Some(v) = value.resolve() {
180
594
        return Some((v, v));
181
234
    }
182

            
183
234
    match value {
184
        IntVal::Const(v) => Some((*v, *v)),
185
144
        IntVal::Reference(reference) => {
186
144
            let name = reference.name().clone();
187
144
            int_bounds_for_name(&name, symbols, known_int_bounds, visiting)
188
        }
189
90
        IntVal::Expr(expr) => {
190
90
            let domain = expression_int_bounds(expr, symbols, known_int_bounds, visiting)?;
191
90
            int_domain_bounds_from_domain(&domain, symbols, known_int_bounds, visiting)
192
        }
193
    }
194
828
}
195

            
196
/// Resolves cached or derived bounds for a declaration by name.
197
144
fn int_bounds_for_name(
198
144
    name: &Name,
199
144
    symbols: &SymbolTable,
200
144
    known_int_bounds: &mut IntBoundsCache,
201
144
    visiting: &mut VisitingStack,
202
144
) -> Option<(i32, i32)> {
203
144
    if let Some(bounds) = known_int_bounds.get(name).copied() {
204
144
        return Some(bounds);
205
    }
206

            
207
    if visiting.contains(name) {
208
        return None;
209
    }
210

            
211
    visiting.push(name.clone());
212
    let maybe_bounds = symbols
213
        .lookup(name)
214
        .and_then(|decl| decl.domain())
215
        .and_then(|domain| {
216
            int_domain_bounds_from_domain(&domain, symbols, known_int_bounds, visiting)
217
        });
218
    visiting.pop();
219
    let bounds = maybe_bounds?;
220
    known_int_bounds.insert(name.clone(), bounds);
221
    Some(bounds)
222
144
}
223

            
224
/// Computes a conservative ground integer domain for an expression.
225
90
fn expression_int_bounds(
226
90
    expr: &Moo<Expr>,
227
90
    symbols: &SymbolTable,
228
90
    known_int_bounds: &mut IntBoundsCache,
229
90
    visiting: &mut VisitingStack,
230
90
) -> Option<DomainPtr> {
231
90
    if let Some(Lit::Int(v)) = conjure_cp::ast::eval_constant(expr) {
232
        return Some(Domain::int_ground(vec![Range::Single(v)]));
233
90
    }
234

            
235
90
    let domain = expr.as_ref().domain_of()?;
236
90
    resolve_or_widen_int_domain(&domain, symbols, known_int_bounds, visiting)
237
90
}
238

            
239
/// Builds guards ensuring widened integer find domains still satisfy original symbolic bounds.
240
81
fn domain_consistency_constraints(
241
81
    declaration: &conjure_cp::ast::DeclarationPtr,
242
81
    original_domain: &DomainPtr,
243
81
) -> Option<Vec<Expr>> {
244
81
    if original_domain.resolve().is_some() {
245
        return Some(Vec::new());
246
81
    }
247

            
248
81
    let ranges = original_domain.as_ref().as_int()?;
249
81
    if ranges.is_empty() {
250
        return None;
251
81
    }
252

            
253
81
    let var_expr = Expr::Atomic(
254
81
        Metadata::new(),
255
81
        Atom::Reference(Reference::new(declaration.clone())),
256
81
    );
257
81
    let mut allowed_intervals = Vec::new();
258

            
259
99
    for range in ranges {
260
99
        let interval = match range {
261
            Range::Single(v) => Expr::Eq(
262
                Metadata::new(),
263
                Moo::new(var_expr.clone()),
264
                Moo::new(Expr::from(v)),
265
            ),
266
99
            Range::Bounded(l, r) => {
267
99
                let geq = Expr::Geq(
268
99
                    Metadata::new(),
269
99
                    Moo::new(var_expr.clone()),
270
99
                    Moo::new(Expr::from(l)),
271
99
                );
272
99
                let leq = Expr::Leq(
273
99
                    Metadata::new(),
274
99
                    Moo::new(var_expr.clone()),
275
99
                    Moo::new(Expr::from(r)),
276
99
                );
277
99
                Expr::And(
278
99
                    Metadata::new(),
279
99
                    Moo::new(conjure_cp::into_matrix_expr!(vec![geq, leq])),
280
99
                )
281
            }
282
            Range::Unbounded | Range::UnboundedL(_) | Range::UnboundedR(_) => return None,
283
        };
284
99
        allowed_intervals.push(interval);
285
    }
286

            
287
81
    let guard = if allowed_intervals.len() == 1 {
288
63
        allowed_intervals.remove(0)
289
    } else {
290
18
        Expr::Or(
291
18
            Metadata::new(),
292
18
            Moo::new(conjure_cp::into_matrix_expr!(allowed_intervals)),
293
18
        )
294
    };
295

            
296
81
    Some(vec![guard])
297
81
}