1
//! Rules for variables in domains.
2

            
3
use std::collections::HashMap;
4

            
5
use conjure_cp::{
6
    ast::{
7
        Atom, DecisionVariable, DeclarationKind, Domain, DomainPtr, Expression as Expr, HasDomain,
8
        IntVal, Literal as Lit, Metadata, Moo, Name, Range, Reference, SymbolTable,
9
    },
10
    rule_engine::{ApplicationError, ApplicationResult, Reduction, register_rule},
11
};
12
use uniplate::Biplate;
13

            
14
use ApplicationError::RuleNotApplicable;
15

            
16
/// Memoised `(lower, upper)` integer bounds per declaration name, built up during one rule application.
type IntBoundsCache = HashMap<Name, (i32, i32)>;
17
/// Names whose bounds are currently being computed; used to detect cyclic references.
type VisitingStack = Vec<Name>;
18

            
19
/// Rewrites variables in domains.
20
///
21
/// Solvers require variable declarations to have ground domains. For integer domains that contain variables in them, we widen to a finite ground superset-domain and add constraints that enforce membership in the original (possibly variable-dependent) domain.
22
#[register_rule(("Base", 8990))]
23
1021237
fn handle_variables_in_domains(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
24
1021237
    let Expr::Root(_, _) = expr else {
25
989651
        return Err(RuleNotApplicable);
26
    };
27

            
28
31586
    if !symbols_have_decision_variable_references(symbols) {
29
31544
        return Err(RuleNotApplicable);
30
42
    }
31

            
32
42
    let mut known_int_bounds: IntBoundsCache = HashMap::new();
33
42
    let mut domain_guards = Vec::new();
34
42
    let mut changed = false;
35

            
36
    // Collect declarations first to avoid iterator invalidation while mutating declarations.
37
42
    let declarations: Vec<_> = symbols
38
42
        .clone()
39
42
        .into_iter_local()
40
42
        .map(|(_, decl)| decl)
41
42
        .collect();
42

            
43
120
    for mut decl in declarations {
44
120
        let Some(domain) = decl.as_find().map(|var| var.domain_of()) else {
45
            continue;
46
        };
47

            
48
120
        if let Some(bounds) =
49
120
            int_domain_bounds_from_domain(&domain, symbols, &mut known_int_bounds, &mut Vec::new())
50
120
        {
51
120
            known_int_bounds.insert(decl.name().clone(), bounds);
52
120
        }
53

            
54
120
        if domain.resolve().is_some() {
55
66
            continue;
56
54
        }
57

            
58
54
        let Some(widened_domain) =
59
54
            resolve_or_widen_int_domain(&domain, symbols, &mut known_int_bounds, &mut Vec::new())
60
        else {
61
            return Err(RuleNotApplicable);
62
        };
63

            
64
54
        let Some(guards) = domain_consistency_constraints(&decl, &domain) else {
65
            return Err(RuleNotApplicable);
66
        };
67

            
68
54
        domain_guards.extend(guards);
69
54
        let _ = decl.replace_kind(DeclarationKind::Find(DecisionVariable::new(widened_domain)));
70
54
        changed = true;
71
    }
72

            
73
42
    if !changed {
74
        return Err(RuleNotApplicable);
75
42
    }
76

            
77
42
    Ok(Reduction::new(expr.clone(), domain_guards, symbols.clone()))
78
1021237
}
79

            
80
/// Returns true iff at least one local symbol contains a reference to a decision variable.
81
31586
fn symbols_have_decision_variable_references(symbols: &SymbolTable) -> bool {
82
31586
    let is_decision_reference = |reference: &Reference| {
83
18258
        reference.ptr().as_find().is_some()
84
18258
            || symbols
85
18258
                .lookup(&reference.name().clone())
86
18258
                .is_some_and(|decl| decl.as_find().is_some())
87
18258
    };
88

            
89
1726054
    symbols.iter_local().any(|(_, declaration)| {
90
1726054
        declaration.domain().is_some_and(|domain| {
91
1726048
            Biplate::<Reference>::universe_bi(domain.as_ref())
92
1726048
                .iter()
93
1726048
                .any(&is_decision_reference)
94
1726048
                || Biplate::<IntVal>::universe_bi(domain.as_ref())
95
1726048
                    .iter()
96
1726048
                    .any(|int_val| match int_val {
97
                        IntVal::Const(_) => false,
98
                        IntVal::Reference(reference) => is_decision_reference(reference),
99
                        IntVal::Expr(expr) => Biplate::<Atom>::universe_bi(expr.as_ref())
100
                            .iter()
101
                            .any(|atom| {
102
                                matches!(atom, Atom::Reference(reference) if is_decision_reference(reference))
103
                            }),
104
                    })
105
1726054
        }) || declaration.as_find().is_some_and(|find| {
106
1716648
            let domain = find.domain_of();
107
1716648
            domain.resolve().is_none() && domain.as_ref().as_int().is_some()
108
1716648
        })
109
1726054
    })
110
31586
}
111

            
112
/// Resolves an integer domain when possible; otherwise computes a finite widened domain
113
/// by replacing symbolic bounds with conservative numeric bounds.
114
114
fn resolve_or_widen_int_domain(
115
114
    domain: &DomainPtr,
116
114
    symbols: &SymbolTable,
117
114
    known_int_bounds: &mut IntBoundsCache,
118
114
    visiting: &mut VisitingStack,
119
114
) -> Option<DomainPtr> {
120
114
    if let Some(resolved) = domain.resolve() {
121
60
        return Some(resolved.into());
122
54
    }
123

            
124
54
    let ranges = domain.as_ref().as_int()?;
125
54
    let widened_ranges: Vec<Range<i32>> = ranges
126
54
        .iter()
127
66
        .map(|range| int_range_bounds(range, symbols, known_int_bounds, visiting))
128
66
        .map(|bounds| bounds.map(|(lo, hi)| Range::new(Some(lo), Some(hi))))
129
54
        .collect::<Option<Vec<_>>>()?;
130

            
131
54
    Some(Domain::int_ground(widened_ranges))
132
114
}
133

            
134
/// Returns overall numeric bounds for an unresolved integer domain.
135
180
fn int_domain_bounds_from_domain(
136
180
    domain: &DomainPtr,
137
180
    symbols: &SymbolTable,
138
180
    known_int_bounds: &mut IntBoundsCache,
139
180
    visiting: &mut VisitingStack,
140
180
) -> Option<(i32, i32)> {
141
180
    let ranges = domain.as_ref().as_int()?;
142
180
    let mut lower = i32::MAX;
143
180
    let mut upper = i32::MIN;
144

            
145
216
    for range in ranges {
146
216
        let (lo, hi) = int_range_bounds(&range, symbols, known_int_bounds, visiting)?;
147
216
        lower = lower.min(lo);
148
216
        upper = upper.max(hi);
149
    }
150

            
151
180
    Some((lower, upper))
152
180
}
153

            
154
/// Computes numeric bounds for a possibly symbolic integer range.
155
282
fn int_range_bounds(
156
282
    range: &Range<IntVal>,
157
282
    symbols: &SymbolTable,
158
282
    known_int_bounds: &mut IntBoundsCache,
159
282
    visiting: &mut VisitingStack,
160
282
) -> Option<(i32, i32)> {
161
282
    match range {
162
12
        Range::Single(v) => int_val_bounds(v, symbols, known_int_bounds, visiting),
163
270
        Range::Bounded(l, r) => {
164
270
            let (ll, lh) = int_val_bounds(l, symbols, known_int_bounds, visiting)?;
165
270
            let (rl, rh) = int_val_bounds(r, symbols, known_int_bounds, visiting)?;
166
270
            Some((ll.min(lh), rl.max(rh)))
167
        }
168
        Range::Unbounded | Range::UnboundedL(_) | Range::UnboundedR(_) => None,
169
    }
170
282
}
171

            
172
/// Computes numeric bounds for an unresolved integer value.
173
552
fn int_val_bounds(
174
552
    value: &IntVal,
175
552
    symbols: &SymbolTable,
176
552
    known_int_bounds: &mut IntBoundsCache,
177
552
    visiting: &mut VisitingStack,
178
552
) -> Option<(i32, i32)> {
179
552
    if let Some(v) = value.resolve() {
180
396
        return Some((v, v));
181
156
    }
182

            
183
156
    match value {
184
        IntVal::Const(v) => Some((*v, *v)),
185
96
        IntVal::Reference(reference) => {
186
96
            let name = reference.name().clone();
187
96
            int_bounds_for_name(&name, symbols, known_int_bounds, visiting)
188
        }
189
60
        IntVal::Expr(expr) => {
190
60
            let domain = expression_int_bounds(expr, symbols, known_int_bounds, visiting)?;
191
60
            int_domain_bounds_from_domain(&domain, symbols, known_int_bounds, visiting)
192
        }
193
    }
194
552
}
195

            
196
/// Resolves cached or derived bounds for a declaration by name.
197
96
fn int_bounds_for_name(
198
96
    name: &Name,
199
96
    symbols: &SymbolTable,
200
96
    known_int_bounds: &mut IntBoundsCache,
201
96
    visiting: &mut VisitingStack,
202
96
) -> Option<(i32, i32)> {
203
96
    if let Some(bounds) = known_int_bounds.get(name).copied() {
204
96
        return Some(bounds);
205
    }
206

            
207
    if visiting.contains(name) {
208
        return None;
209
    }
210

            
211
    visiting.push(name.clone());
212
    let maybe_bounds = symbols
213
        .lookup(name)
214
        .and_then(|decl| decl.domain())
215
        .and_then(|domain| {
216
            int_domain_bounds_from_domain(&domain, symbols, known_int_bounds, visiting)
217
        });
218
    visiting.pop();
219
    let bounds = maybe_bounds?;
220
    known_int_bounds.insert(name.clone(), bounds);
221
    Some(bounds)
222
96
}
223

            
224
/// Computes a conservative ground integer domain for an expression.
225
60
fn expression_int_bounds(
226
60
    expr: &Moo<Expr>,
227
60
    symbols: &SymbolTable,
228
60
    known_int_bounds: &mut IntBoundsCache,
229
60
    visiting: &mut VisitingStack,
230
60
) -> Option<DomainPtr> {
231
60
    if let Some(Lit::Int(v)) = conjure_cp::ast::eval_constant(expr) {
232
        return Some(Domain::int_ground(vec![Range::Single(v)]));
233
60
    }
234

            
235
60
    let domain = expr.as_ref().domain_of()?;
236
60
    resolve_or_widen_int_domain(&domain, symbols, known_int_bounds, visiting)
237
60
}
238

            
239
/// Builds guards ensuring widened integer find domains still satisfy original symbolic bounds.
240
54
fn domain_consistency_constraints(
241
54
    declaration: &conjure_cp::ast::DeclarationPtr,
242
54
    original_domain: &DomainPtr,
243
54
) -> Option<Vec<Expr>> {
244
54
    if original_domain.resolve().is_some() {
245
        return Some(Vec::new());
246
54
    }
247

            
248
54
    let ranges = original_domain.as_ref().as_int()?;
249
54
    if ranges.is_empty() {
250
        return None;
251
54
    }
252

            
253
54
    let var_expr = Expr::Atomic(
254
54
        Metadata::new(),
255
54
        Atom::Reference(Reference::new(declaration.clone())),
256
54
    );
257
54
    let mut allowed_intervals = Vec::new();
258

            
259
66
    for range in ranges {
260
66
        let interval = match range {
261
            Range::Single(v) => Expr::Eq(
262
                Metadata::new(),
263
                Moo::new(var_expr.clone()),
264
                Moo::new(Expr::from(v)),
265
            ),
266
66
            Range::Bounded(l, r) => {
267
66
                let geq = Expr::Geq(
268
66
                    Metadata::new(),
269
66
                    Moo::new(var_expr.clone()),
270
66
                    Moo::new(Expr::from(l)),
271
66
                );
272
66
                let leq = Expr::Leq(
273
66
                    Metadata::new(),
274
66
                    Moo::new(var_expr.clone()),
275
66
                    Moo::new(Expr::from(r)),
276
66
                );
277
66
                Expr::And(
278
66
                    Metadata::new(),
279
66
                    Moo::new(conjure_cp::into_matrix_expr!(vec![geq, leq])),
280
66
                )
281
            }
282
            Range::Unbounded | Range::UnboundedL(_) | Range::UnboundedR(_) => return None,
283
        };
284
66
        allowed_intervals.push(interval);
285
    }
286

            
287
54
    let guard = if allowed_intervals.len() == 1 {
288
42
        allowed_intervals.remove(0)
289
    } else {
290
12
        Expr::Or(
291
12
            Metadata::new(),
292
12
            Moo::new(conjure_cp::into_matrix_expr!(allowed_intervals)),
293
12
        )
294
    };
295

            
296
54
    Some(vec![guard])
297
54
}