1
//! Rules for variables in domains.
2

            
3
use std::collections::HashMap;
4

            
5
use conjure_cp::{
6
    ast::{
7
        Atom, DecisionVariable, DeclarationKind, Domain, DomainPtr, Expression as Expr, HasDomain,
8
        IntVal, Literal as Lit, Metadata, Moo, Name, Range, Reference, SymbolTable,
9
    },
10
    rule_engine::{ApplicationError, ApplicationResult, Reduction, register_rule},
11
};
12
use uniplate::Biplate;
13

            
14
use ApplicationError::RuleNotApplicable;
15

            
16
/// Memoized `(lower, upper)` integer bounds, keyed by declaration name.
type IntBoundsCache = HashMap<Name, (i32, i32)>;
17
/// Names whose bounds are currently being computed; `int_bounds_for_name` consults it to stop
/// recursing when a domain (transitively) depends on itself.
type VisitingStack = Vec<Name>;
18

            
19
/// Rewrites variables in domains.
20
///
21
/// Solvers require variable declarations to have ground domains. For integer domains that contain variables in them, we widen to a finite ground superset-domain and add constraints that enforce membership in the original (possibly variable-dependent) domain.
22
#[register_rule(("Base", 8990))]
23
463250
fn handle_variables_in_domains(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
24
463250
    let Expr::Root(_, _) = expr else {
25
448537
        return Err(RuleNotApplicable);
26
    };
27

            
28
14713
    if !symbols_have_decision_variable_references(symbols) {
29
14692
        return Err(RuleNotApplicable);
30
21
    }
31

            
32
21
    let mut known_int_bounds: IntBoundsCache = HashMap::new();
33
21
    let mut domain_guards = Vec::new();
34
21
    let mut changed = false;
35

            
36
    // Collect declarations first to avoid iterator invalidation while mutating declarations.
37
21
    let declarations: Vec<_> = symbols
38
21
        .clone()
39
21
        .into_iter_local()
40
21
        .map(|(_, decl)| decl)
41
21
        .collect();
42

            
43
60
    for mut decl in declarations {
44
60
        let Some(domain) = decl.as_find().map(|var| var.domain_of()) else {
45
            continue;
46
        };
47

            
48
60
        if let Some(bounds) =
49
60
            int_domain_bounds_from_domain(&domain, symbols, &mut known_int_bounds, &mut Vec::new())
50
60
        {
51
60
            known_int_bounds.insert(decl.name().clone(), bounds);
52
60
        }
53

            
54
60
        if domain.resolve().is_some() {
55
33
            continue;
56
27
        }
57

            
58
27
        let Some(widened_domain) =
59
27
            resolve_or_widen_int_domain(&domain, symbols, &mut known_int_bounds, &mut Vec::new())
60
        else {
61
            return Err(RuleNotApplicable);
62
        };
63

            
64
27
        let Some(guards) = domain_consistency_constraints(&decl, &domain) else {
65
            return Err(RuleNotApplicable);
66
        };
67

            
68
27
        domain_guards.extend(guards);
69
27
        let _ = decl.replace_kind(DeclarationKind::Find(DecisionVariable::new(widened_domain)));
70
27
        changed = true;
71
    }
72

            
73
21
    if !changed {
74
        return Err(RuleNotApplicable);
75
21
    }
76

            
77
21
    Ok(Reduction::new(expr.clone(), domain_guards, symbols.clone()))
78
463250
}
79

            
80
/// Returns true iff at least one local symbol contains a reference to a decision variable.
81
14713
fn symbols_have_decision_variable_references(symbols: &SymbolTable) -> bool {
82
14713
    let is_decision_reference = |reference: &Reference| {
83
9129
        reference.ptr().as_find().is_some()
84
9129
            || symbols
85
9129
                .lookup(&reference.name().clone())
86
9129
                .is_some_and(|decl| decl.as_find().is_some())
87
9129
    };
88

            
89
578416
    symbols.iter_local().any(|(_, declaration)| {
90
578416
        declaration.domain().is_some_and(|domain| {
91
578413
            Biplate::<Reference>::universe_bi(domain.as_ref())
92
578413
                .iter()
93
578413
                .any(&is_decision_reference)
94
578413
                || Biplate::<IntVal>::universe_bi(domain.as_ref())
95
578413
                    .iter()
96
578413
                    .any(|int_val| match int_val {
97
                        IntVal::Const(_) => false,
98
                        IntVal::Reference(reference) => is_decision_reference(reference),
99
                        IntVal::Expr(expr) => Biplate::<Atom>::universe_bi(expr.as_ref())
100
                            .iter()
101
                            .any(|atom| {
102
                                matches!(atom, Atom::Reference(reference) if is_decision_reference(reference))
103
                            }),
104
                    })
105
578416
        }) || declaration.as_find().is_some_and(|find| {
106
573775
            let domain = find.domain_of();
107
573775
            domain.resolve().is_none() && domain.as_ref().as_int().is_some()
108
573775
        })
109
578416
    })
110
14713
}
111

            
112
/// Resolves an integer domain when possible; otherwise computes a finite widened domain
113
/// by replacing symbolic bounds with conservative numeric bounds.
114
57
fn resolve_or_widen_int_domain(
115
57
    domain: &DomainPtr,
116
57
    symbols: &SymbolTable,
117
57
    known_int_bounds: &mut IntBoundsCache,
118
57
    visiting: &mut VisitingStack,
119
57
) -> Option<DomainPtr> {
120
57
    if let Some(resolved) = domain.resolve() {
121
30
        return Some(resolved.into());
122
27
    }
123

            
124
27
    let ranges = domain.as_ref().as_int()?;
125
27
    let widened_ranges: Vec<Range<i32>> = ranges
126
27
        .iter()
127
33
        .map(|range| int_range_bounds(range, symbols, known_int_bounds, visiting))
128
33
        .map(|bounds| bounds.map(|(lo, hi)| Range::new(Some(lo), Some(hi))))
129
27
        .collect::<Option<Vec<_>>>()?;
130

            
131
27
    Some(Domain::int_ground(widened_ranges))
132
57
}
133

            
134
/// Returns overall numeric bounds for an unresolved integer domain.
135
90
fn int_domain_bounds_from_domain(
136
90
    domain: &DomainPtr,
137
90
    symbols: &SymbolTable,
138
90
    known_int_bounds: &mut IntBoundsCache,
139
90
    visiting: &mut VisitingStack,
140
90
) -> Option<(i32, i32)> {
141
90
    let ranges = domain.as_ref().as_int()?;
142
90
    let mut lower = i32::MAX;
143
90
    let mut upper = i32::MIN;
144

            
145
108
    for range in ranges {
146
108
        let (lo, hi) = int_range_bounds(&range, symbols, known_int_bounds, visiting)?;
147
108
        lower = lower.min(lo);
148
108
        upper = upper.max(hi);
149
    }
150

            
151
90
    Some((lower, upper))
152
90
}
153

            
154
/// Computes numeric bounds for a possibly symbolic integer range.
155
141
fn int_range_bounds(
156
141
    range: &Range<IntVal>,
157
141
    symbols: &SymbolTable,
158
141
    known_int_bounds: &mut IntBoundsCache,
159
141
    visiting: &mut VisitingStack,
160
141
) -> Option<(i32, i32)> {
161
141
    match range {
162
6
        Range::Single(v) => int_val_bounds(v, symbols, known_int_bounds, visiting),
163
135
        Range::Bounded(l, r) => {
164
135
            let (ll, lh) = int_val_bounds(l, symbols, known_int_bounds, visiting)?;
165
135
            let (rl, rh) = int_val_bounds(r, symbols, known_int_bounds, visiting)?;
166
135
            Some((ll.min(lh), rl.max(rh)))
167
        }
168
        Range::Unbounded | Range::UnboundedL(_) | Range::UnboundedR(_) => None,
169
    }
170
141
}
171

            
172
/// Computes numeric bounds for an unresolved integer value.
173
276
fn int_val_bounds(
174
276
    value: &IntVal,
175
276
    symbols: &SymbolTable,
176
276
    known_int_bounds: &mut IntBoundsCache,
177
276
    visiting: &mut VisitingStack,
178
276
) -> Option<(i32, i32)> {
179
276
    if let Some(v) = value.resolve() {
180
198
        return Some((v, v));
181
78
    }
182

            
183
78
    match value {
184
        IntVal::Const(v) => Some((*v, *v)),
185
48
        IntVal::Reference(reference) => {
186
48
            let name = reference.name().clone();
187
48
            int_bounds_for_name(&name, symbols, known_int_bounds, visiting)
188
        }
189
30
        IntVal::Expr(expr) => {
190
30
            let domain = expression_int_bounds(expr, symbols, known_int_bounds, visiting)?;
191
30
            int_domain_bounds_from_domain(&domain, symbols, known_int_bounds, visiting)
192
        }
193
    }
194
276
}
195

            
196
/// Resolves cached or derived bounds for a declaration by name.
197
48
fn int_bounds_for_name(
198
48
    name: &Name,
199
48
    symbols: &SymbolTable,
200
48
    known_int_bounds: &mut IntBoundsCache,
201
48
    visiting: &mut VisitingStack,
202
48
) -> Option<(i32, i32)> {
203
48
    if let Some(bounds) = known_int_bounds.get(name).copied() {
204
48
        return Some(bounds);
205
    }
206

            
207
    if visiting.contains(name) {
208
        return None;
209
    }
210

            
211
    visiting.push(name.clone());
212
    let maybe_bounds = symbols
213
        .lookup(name)
214
        .and_then(|decl| decl.domain())
215
        .and_then(|domain| {
216
            int_domain_bounds_from_domain(&domain, symbols, known_int_bounds, visiting)
217
        });
218
    visiting.pop();
219
    let bounds = maybe_bounds?;
220
    known_int_bounds.insert(name.clone(), bounds);
221
    Some(bounds)
222
48
}
223

            
224
/// Computes a conservative ground integer domain for an expression.
225
30
fn expression_int_bounds(
226
30
    expr: &Moo<Expr>,
227
30
    symbols: &SymbolTable,
228
30
    known_int_bounds: &mut IntBoundsCache,
229
30
    visiting: &mut VisitingStack,
230
30
) -> Option<DomainPtr> {
231
30
    if let Some(Lit::Int(v)) = conjure_cp::ast::eval_constant(expr) {
232
        return Some(Domain::int_ground(vec![Range::Single(v)]));
233
30
    }
234

            
235
30
    let domain = expr.as_ref().domain_of()?;
236
30
    resolve_or_widen_int_domain(&domain, symbols, known_int_bounds, visiting)
237
30
}
238

            
239
/// Builds guards ensuring widened integer find domains still satisfy original symbolic bounds.
240
27
fn domain_consistency_constraints(
241
27
    declaration: &conjure_cp::ast::DeclarationPtr,
242
27
    original_domain: &DomainPtr,
243
27
) -> Option<Vec<Expr>> {
244
27
    if original_domain.resolve().is_some() {
245
        return Some(Vec::new());
246
27
    }
247

            
248
27
    let ranges = original_domain.as_ref().as_int()?;
249
27
    if ranges.is_empty() {
250
        return None;
251
27
    }
252

            
253
27
    let var_expr = Expr::Atomic(
254
27
        Metadata::new(),
255
27
        Atom::Reference(Reference::new(declaration.clone())),
256
27
    );
257
27
    let mut allowed_intervals = Vec::new();
258

            
259
33
    for range in ranges {
260
33
        let interval = match range {
261
            Range::Single(v) => Expr::Eq(
262
                Metadata::new(),
263
                Moo::new(var_expr.clone()),
264
                Moo::new(Expr::from(v)),
265
            ),
266
33
            Range::Bounded(l, r) => {
267
33
                let geq = Expr::Geq(
268
33
                    Metadata::new(),
269
33
                    Moo::new(var_expr.clone()),
270
33
                    Moo::new(Expr::from(l)),
271
33
                );
272
33
                let leq = Expr::Leq(
273
33
                    Metadata::new(),
274
33
                    Moo::new(var_expr.clone()),
275
33
                    Moo::new(Expr::from(r)),
276
33
                );
277
33
                Expr::And(
278
33
                    Metadata::new(),
279
33
                    Moo::new(conjure_cp::into_matrix_expr!(vec![geq, leq])),
280
33
                )
281
            }
282
            Range::Unbounded | Range::UnboundedL(_) | Range::UnboundedR(_) => return None,
283
        };
284
33
        allowed_intervals.push(interval);
285
    }
286

            
287
27
    let guard = if allowed_intervals.len() == 1 {
288
21
        allowed_intervals.remove(0)
289
    } else {
290
6
        Expr::Or(
291
6
            Metadata::new(),
292
6
            Moo::new(conjure_cp::into_matrix_expr!(allowed_intervals)),
293
6
        )
294
    };
295

            
296
27
    Some(vec![guard])
297
27
}