1
//! Rules for variables in domains.
2

            
3
use std::collections::HashMap;
4

            
5
use conjure_cp::{
6
    ast::{
7
        Atom, DecisionVariable, DeclarationKind, Domain, DomainPtr, Expression as Expr, HasDomain,
8
        IntVal, Literal as Lit, Metadata, Moo, Name, Range, Reference, SymbolTable,
9
    },
10
    rule_engine::{ApplicationError, ApplicationResult, Reduction, register_rule},
11
};
12
use uniplate::Biplate;
13

            
14
use ApplicationError::RuleNotApplicable;
15

            
16
/// Cache of numeric `(lower, upper)` bounds, keyed by declaration name.
type IntBoundsCache = HashMap<Name, (i32, i32)>;
/// Names whose bounds are currently being computed; a name found here again means a
/// reference cycle, and bound computation for it gives up.
type VisitingStack = Vec<Name>;
18

            
19
/// Rewrites variables in domains.
20
///
21
/// Solvers require variable declarations to have ground domains. For integer domains that contain variables in them, we widen to a finite ground superset-domain and add constraints that enforce membership in the original (possibly variable-dependent) domain.
22
#[register_rule("Base", 8990, [Root])]
23
2519548
fn handle_variables_in_domains(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
24
2519548
    let Expr::Root(_, _) = expr else {
25
2456184
        return Err(RuleNotApplicable);
26
    };
27

            
28
63364
    if !symbols_have_decision_variable_references(symbols) {
29
63286
        return Err(RuleNotApplicable);
30
78
    }
31

            
32
78
    let mut known_int_bounds: IntBoundsCache = HashMap::new();
33
78
    let mut domain_guards = Vec::new();
34
78
    let mut changed = false;
35

            
36
    // Collect declarations first to avoid iterator invalidation while mutating declarations.
37
78
    let declarations: Vec<_> = symbols
38
78
        .clone()
39
78
        .into_iter_local()
40
78
        .map(|(_, decl)| decl)
41
78
        .collect();
42

            
43
1404
    for mut decl in declarations {
44
1404
        let Some(domain) = decl.as_find().map(|var| var.domain_of()) else {
45
            continue;
46
        };
47

            
48
228
        if let Some(bounds) =
49
1404
            int_domain_bounds_from_domain(&domain, symbols, &mut known_int_bounds, &mut Vec::new())
50
228
        {
51
228
            known_int_bounds.insert(decl.name().clone(), bounds);
52
1176
        }
53

            
54
1404
        if domain.resolve().is_some() {
55
1314
            continue;
56
90
        }
57

            
58
90
        let Some(widened_domain) =
59
90
            resolve_or_widen_int_domain(&domain, symbols, &mut known_int_bounds, &mut Vec::new())
60
        else {
61
            return Err(RuleNotApplicable);
62
        };
63

            
64
90
        let Some(guards) = domain_consistency_constraints(&decl, &domain) else {
65
            return Err(RuleNotApplicable);
66
        };
67

            
68
90
        domain_guards.extend(guards);
69
90
        let _ = decl.replace_kind(DeclarationKind::Find(DecisionVariable::new(widened_domain)));
70
90
        changed = true;
71
    }
72

            
73
78
    if !changed {
74
        return Err(RuleNotApplicable);
75
78
    }
76

            
77
78
    Ok(Reduction::new(expr.clone(), domain_guards, symbols.clone()))
78
2519548
}
79

            
80
/// Returns true iff at least one local symbol contains a reference to a decision variable.
81
63364
fn symbols_have_decision_variable_references(symbols: &SymbolTable) -> bool {
82
63364
    let is_decision_reference = |reference: &Reference| {
83
20574
        reference.ptr().as_find().is_some()
84
20574
            || symbols
85
20574
                .lookup(&reference.name().clone())
86
20574
                .is_some_and(|decl| decl.as_find().is_some())
87
20574
    };
88

            
89
5627800
    symbols.iter_local().any(|(_, declaration)| {
90
5627800
        declaration.domain().is_some_and(|domain| {
91
5627788
            Biplate::<Reference>::universe_bi(domain.as_ref())
92
5627788
                .iter()
93
5627788
                .any(&is_decision_reference)
94
5627788
                || Biplate::<IntVal>::universe_bi(domain.as_ref())
95
5627788
                    .iter()
96
5627788
                    .any(|int_val| match int_val {
97
                        IntVal::Const(_) => false,
98
                        IntVal::Reference(reference) => is_decision_reference(reference),
99
                        IntVal::Expr(expr) => Biplate::<Atom>::universe_bi(expr.as_ref())
100
                            .iter()
101
                            .any(|atom| {
102
                                matches!(atom, Atom::Reference(reference) if is_decision_reference(reference))
103
                            }),
104
                    })
105
5627800
        }) || declaration.as_find().is_some_and(|find| {
106
5591532
            let domain = find.domain_of();
107
5591532
            domain.resolve().is_none() && domain.as_ref().as_int().is_some()
108
5591532
        })
109
5627800
    })
110
63364
}
111

            
112
/// Resolves an integer domain when possible; otherwise computes a finite widened domain
113
/// by replacing symbolic bounds with conservative numeric bounds.
114
294
fn resolve_or_widen_int_domain(
115
294
    domain: &DomainPtr,
116
294
    symbols: &SymbolTable,
117
294
    known_int_bounds: &mut IntBoundsCache,
118
294
    visiting: &mut VisitingStack,
119
294
) -> Option<DomainPtr> {
120
294
    if let Some(resolved) = domain.resolve() {
121
204
        return Some(resolved.into());
122
90
    }
123

            
124
90
    let ranges = domain.as_ref().as_int()?;
125
90
    let widened_ranges: Vec<Range<i32>> = ranges
126
90
        .iter()
127
102
        .map(|range| int_range_bounds(range, symbols, known_int_bounds, visiting))
128
102
        .map(|bounds| bounds.map(|(lo, hi)| Range::new(Some(lo), Some(hi))))
129
90
        .collect::<Option<Vec<_>>>()?;
130

            
131
90
    Some(Domain::int_ground(widened_ranges))
132
294
}
133

            
134
/// Returns overall numeric bounds for an unresolved integer domain.
135
1608
fn int_domain_bounds_from_domain(
136
1608
    domain: &DomainPtr,
137
1608
    symbols: &SymbolTable,
138
1608
    known_int_bounds: &mut IntBoundsCache,
139
1608
    visiting: &mut VisitingStack,
140
1608
) -> Option<(i32, i32)> {
141
1608
    let ranges = domain.as_ref().as_int()?;
142
432
    let mut lower = i32::MAX;
143
432
    let mut upper = i32::MIN;
144

            
145
468
    for range in ranges {
146
468
        let (lo, hi) = int_range_bounds(&range, symbols, known_int_bounds, visiting)?;
147
468
        lower = lower.min(lo);
148
468
        upper = upper.max(hi);
149
    }
150

            
151
432
    Some((lower, upper))
152
1608
}
153

            
154
/// Computes numeric bounds for a possibly symbolic integer range.
155
570
fn int_range_bounds(
156
570
    range: &Range<IntVal>,
157
570
    symbols: &SymbolTable,
158
570
    known_int_bounds: &mut IntBoundsCache,
159
570
    visiting: &mut VisitingStack,
160
570
) -> Option<(i32, i32)> {
161
570
    match range {
162
12
        Range::Single(v) => int_val_bounds(v, symbols, known_int_bounds, visiting),
163
558
        Range::Bounded(l, r) => {
164
558
            let (ll, lh) = int_val_bounds(l, symbols, known_int_bounds, visiting)?;
165
558
            let (rl, rh) = int_val_bounds(r, symbols, known_int_bounds, visiting)?;
166
558
            Some((ll.min(lh), rl.max(rh)))
167
        }
168
        Range::Unbounded | Range::UnboundedL(_) | Range::UnboundedR(_) => None,
169
    }
170
570
}
171

            
172
/// Computes numeric bounds for an unresolved integer value.
173
1128
fn int_val_bounds(
174
1128
    value: &IntVal,
175
1128
    symbols: &SymbolTable,
176
1128
    known_int_bounds: &mut IntBoundsCache,
177
1128
    visiting: &mut VisitingStack,
178
1128
) -> Option<(i32, i32)> {
179
1128
    if let Some(v) = value.resolve() {
180
828
        return Some((v, v));
181
300
    }
182

            
183
300
    match value {
184
        IntVal::Const(v) => Some((*v, *v)),
185
96
        IntVal::Reference(reference) => {
186
96
            let name = reference.name().clone();
187
96
            int_bounds_for_name(&name, symbols, known_int_bounds, visiting)
188
        }
189
204
        IntVal::Expr(expr) => {
190
204
            let domain = expression_int_bounds(expr, symbols, known_int_bounds, visiting)?;
191
204
            int_domain_bounds_from_domain(&domain, symbols, known_int_bounds, visiting)
192
        }
193
    }
194
1128
}
195

            
196
/// Resolves cached or derived bounds for a declaration by name.
197
96
fn int_bounds_for_name(
198
96
    name: &Name,
199
96
    symbols: &SymbolTable,
200
96
    known_int_bounds: &mut IntBoundsCache,
201
96
    visiting: &mut VisitingStack,
202
96
) -> Option<(i32, i32)> {
203
96
    if let Some(bounds) = known_int_bounds.get(name).copied() {
204
96
        return Some(bounds);
205
    }
206

            
207
    if visiting.contains(name) {
208
        return None;
209
    }
210

            
211
    visiting.push(name.clone());
212
    let maybe_bounds = symbols
213
        .lookup(name)
214
        .and_then(|decl| decl.domain())
215
        .and_then(|domain| {
216
            int_domain_bounds_from_domain(&domain, symbols, known_int_bounds, visiting)
217
        });
218
    visiting.pop();
219
    let bounds = maybe_bounds?;
220
    known_int_bounds.insert(name.clone(), bounds);
221
    Some(bounds)
222
96
}
223

            
224
/// Computes a conservative ground integer domain for an expression.
225
204
fn expression_int_bounds(
226
204
    expr: &Moo<Expr>,
227
204
    symbols: &SymbolTable,
228
204
    known_int_bounds: &mut IntBoundsCache,
229
204
    visiting: &mut VisitingStack,
230
204
) -> Option<DomainPtr> {
231
204
    if let Some(Lit::Int(v)) = conjure_cp::ast::eval_constant(expr) {
232
        return Some(Domain::int_ground(vec![Range::Single(v)]));
233
204
    }
234

            
235
204
    let domain = expr.as_ref().domain_of()?;
236
204
    resolve_or_widen_int_domain(&domain, symbols, known_int_bounds, visiting)
237
204
}
238

            
239
/// Builds guards ensuring widened integer find domains still satisfy original symbolic bounds.
240
90
fn domain_consistency_constraints(
241
90
    declaration: &conjure_cp::ast::DeclarationPtr,
242
90
    original_domain: &DomainPtr,
243
90
) -> Option<Vec<Expr>> {
244
90
    if original_domain.resolve().is_some() {
245
        return Some(Vec::new());
246
90
    }
247

            
248
90
    let ranges = original_domain.as_ref().as_int()?;
249
90
    if ranges.is_empty() {
250
        return None;
251
90
    }
252

            
253
90
    let var_expr = Expr::Atomic(
254
90
        Metadata::new(),
255
90
        Atom::Reference(Reference::new(declaration.clone())),
256
90
    );
257
90
    let mut allowed_intervals = Vec::new();
258

            
259
102
    for range in ranges {
260
102
        let interval = match range {
261
            Range::Single(v) => Expr::Eq(
262
                Metadata::new(),
263
                Moo::new(var_expr.clone()),
264
                Moo::new(Expr::from(v)),
265
            ),
266
102
            Range::Bounded(l, r) => {
267
102
                let geq = Expr::Geq(
268
102
                    Metadata::new(),
269
102
                    Moo::new(var_expr.clone()),
270
102
                    Moo::new(Expr::from(l)),
271
102
                );
272
102
                let leq = Expr::Leq(
273
102
                    Metadata::new(),
274
102
                    Moo::new(var_expr.clone()),
275
102
                    Moo::new(Expr::from(r)),
276
102
                );
277
102
                Expr::And(
278
102
                    Metadata::new(),
279
102
                    Moo::new(conjure_cp::into_matrix_expr!(vec![geq, leq])),
280
102
                )
281
            }
282
            Range::Unbounded | Range::UnboundedL(_) | Range::UnboundedR(_) => return None,
283
        };
284
102
        allowed_intervals.push(interval);
285
    }
286

            
287
90
    let guard = if allowed_intervals.len() == 1 {
288
78
        allowed_intervals.remove(0)
289
    } else {
290
12
        Expr::Or(
291
12
            Metadata::new(),
292
12
            Moo::new(conjure_cp::into_matrix_expr!(allowed_intervals)),
293
12
        )
294
    };
295

            
296
90
    Some(vec![guard])
297
90
}