1
use conjure_cp::ast::{
2
    AbstractLiteral, Atom, DeclarationPtr, Expression as Expr, Literal, Metadata, Moo, Name,
3
    SymbolTable,
4
    categories::Category,
5
    comprehension::{Comprehension, ComprehensionQualifier},
6
};
7

            
8
use tracing::{instrument, trace};
9
use uniplate::{Biplate, Uniplate};
10

            
11
/// True iff `expr` is an `Atom`.
12
73526
pub fn is_atom(expr: &Expr) -> bool {
13
73526
    matches!(expr, Expr::Atomic(_, _))
14
73526
}
15

            
16
/// True iff `expr` is an `Atom` or `Not(Atom)`.
17
8985
pub fn is_literal(expr: &Expr) -> bool {
18
8985
    match expr {
19
4851
        Expr::Atomic(_, _) => true,
20
        Expr::Not(_, inner) => matches!(**inner, Expr::Atomic(_, _)),
21
4134
        _ => false,
22
    }
23
8985
}
24

            
25
/// True if `expr` is flat; i.e. it only contains atoms.
26
pub fn is_flat(expr: &Expr) -> bool {
27
    for e in expr.children() {
28
        if !is_atom(&e) {
29
            return false;
30
        }
31
    }
32
    true
33
}
34

            
35
/// Returns the arity of a tuple constant expression, if this expression is one.
36
60
pub fn constant_tuple_len(expr: &Expr) -> Option<usize> {
37
    match expr {
38
        Expr::AbstractLiteral(_, AbstractLiteral::Tuple(elems)) => Some(elems.len()),
39
24
        Expr::Atomic(_, Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Tuple(elems)))) => {
40
24
            Some(elems.len())
41
        }
42
36
        _ => None,
43
    }
44
60
}
45

            
46
/// Returns record field names of a record constant expression, if this expression is one.
47
24
pub fn constant_record_names(expr: &Expr) -> Option<Vec<Name>> {
48
    match expr {
49
        Expr::AbstractLiteral(_, AbstractLiteral::Record(entries)) => {
50
            Some(entries.iter().map(|x| x.name.clone()).collect())
51
        }
52
        Expr::Atomic(
53
            _,
54
12
            Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Record(entries))),
55
24
        ) => Some(entries.iter().map(|x| x.name.clone()).collect()),
56
12
        _ => None,
57
    }
58
24
}
59

            
60
/// True if the entire AST is constants.
61
3839655
pub fn is_all_constant(expression: &Expr) -> bool {
62
4180062
    for atom in expression.universe_bi() {
63
4149432
        match atom {
64
1091151
            Atom::Literal(_) => {}
65
            _ => {
66
3058281
                return false;
67
            }
68
        }
69
    }
70

            
71
781374
    true
72
3839655
}
73

            
74
/// Converts a vector of expressions to a vector of atoms.
75
///
76
/// # Returns
77
///
78
/// `Some(Vec<Atom>)` if the vectors direct children expressions are all atomic, otherwise `None`.
79
#[allow(dead_code)]
80
pub fn expressions_to_atoms(exprs: &Vec<Expr>) -> Option<Vec<Atom>> {
81
    let mut atoms: Vec<Atom> = vec![];
82
    for expr in exprs {
83
        let Expr::Atomic(_, atom) = expr else {
84
            return None;
85
        };
86
        atoms.push(atom.clone());
87
    }
88

            
89
    Some(atoms)
90
}
91

            
92
/// Creates a new auxiliary variable using the given expression.
93
///
94
/// # Returns
95
///
96
/// * `None` if `Expr` is a `Atom`, or `Expr` does not have a domain (for example, if it is a `Bubble`).
97
///
98
/// * `Some(ToAuxVarOutput)` if successful, containing:
99
///
100
///     + A new symbol table, modified to include the auxiliary variable.
101
///     + A new top level expression, containing the declaration of the auxiliary variable.
102
///     + A reference to the auxiliary variable to replace the existing expression with.
103
///
104
#[instrument(skip_all, fields(expr = %expr))]
105
73526
pub fn to_aux_var(expr: &Expr, symbols: &SymbolTable) -> Option<ToAuxVarOutput> {
106
73526
    let mut symbols = symbols.clone();
107

            
108
    // No need to put an atom in an aux_var
109
73526
    if is_atom(expr) {
110
56132
        if cfg!(debug_assertions) {
111
56132
            trace!(why = "expression is an atom", "to_aux_var() failed");
112
        }
113
56132
        return None;
114
17394
    }
115

            
116
    // Anything that should be bubbled, bubble
117
17394
    if !expr.is_safe() {
118
536
        if cfg!(debug_assertions) {
119
536
            trace!(why = "expression is unsafe", "to_aux_var() failed");
120
        }
121
536
        return None;
122
16858
    }
123

            
124
    // Do not put abstract literals containing expressions into aux vars.
125
    //
126
    // e.g. for `[1,2,3,f/2,e][e]`, the lhs should not be put in an aux var.
127
    //
128
    // instead, we should flatten the elements inside this abstract literal, or wait for it to be
129
    // turned into an atom, or an abstract literal containing only literals - e.g. through an index
130
    // or slice operation.
131
    //
132
16858
    if let Expr::AbstractLiteral(_, _) = expr {
133
2559
        if cfg!(debug_assertions) {
134
2559
            trace!(
135
                why = "expression is an abstract literal",
136
                "to_aux_var() failed"
137
            );
138
        }
139
2559
        return None;
140
14299
    }
141

            
142
    // Only flatten an expression if it contains decision variables or decision variables with some
143
    // constants.
144
    //
145
    // i.e. dont flatten things containing givens, quantified variables, just constants, etc.
146
14299
    let categories = expr.universe_categories();
147

            
148
14299
    assert!(!categories.is_empty());
149

            
150
14299
    if !(categories.len() == 1 && categories.contains(&Category::Decision)
151
10529
        || categories.len() == 2
152
10259
            && categories.contains(&Category::Decision)
153
10070
            && categories.contains(&Category::Constant))
154
    {
155
1002
        if cfg!(debug_assertions) {
156
1002
            trace!(
157
                why = "expression has sub-expressions that are not in the decision category",
158
                "to_aux_var() failed"
159
            );
160
        }
161
1002
        return None;
162
13297
    }
163

            
164
    // Avoid introducing auxvars for generic matrix indexing (can create many redundant auxvars
165
    // before comprehension expansion). However, keep list indexing eligible so Minion lowering
166
    // can introduce `element` constraints in non-equality contexts.
167
13297
    if let Expr::SafeIndex(_, subject, indices) = expr {
168
6573
        let can_lower_via_element = subject.clone().unwrap_list().is_some()
169
5985
            && indices.iter().all(|i| matches!(i, Expr::Atomic(_, _)));
170

            
171
6573
        if !can_lower_via_element {
172
4728
            if cfg!(debug_assertions) {
173
4728
                trace!(expr=%expr, why = "matrix indexing is not element-lowerable", "to_aux_var() failed");
174
            }
175
4728
            return None;
176
1845
        }
177
6724
    }
178

            
179
8569
    let Some(domain) = expr.domain_of() else {
180
382
        if cfg!(debug_assertions) {
181
382
            trace!(expr=%expr, why = "could not find the domain of the expression", "to_aux_var() failed");
182
        }
183
382
        return None;
184
    };
185

            
186
8187
    let decl = symbols.gen_find(&domain);
187

            
188
8187
    if cfg!(debug_assertions) {
189
8187
        trace!(expr=%expr, "to_auxvar() succeeded in putting expr into an auxvar");
190
    }
191

            
192
8187
    Some(ToAuxVarOutput {
193
8187
        aux_declaration: decl.clone(),
194
8187
        aux_expression: Expr::AuxDeclaration(
195
8187
            Metadata::new(),
196
8187
            conjure_cp::ast::Reference::new(decl),
197
8187
            Moo::new(expr.clone()),
198
8187
        ),
199
8187
        symbols,
200
8187
        _unconstructable: (),
201
8187
    })
202
73526
}
203

            
204
/// Output data of `to_aux_var`.
205
pub struct ToAuxVarOutput {
206
    aux_declaration: DeclarationPtr,
207
    aux_expression: Expr,
208
    symbols: SymbolTable,
209
    _unconstructable: (),
210
}
211

            
212
impl ToAuxVarOutput {
213
    /// Returns the new auxiliary variable as an `Atom`.
214
8187
    pub fn as_atom(&self) -> Atom {
215
8187
        Atom::Reference(conjure_cp::ast::Reference::new(
216
8187
            self.aux_declaration.clone(),
217
8187
        ))
218
8187
    }
219

            
220
    /// Returns the new auxiliary variable as an `Expression`.
221
    ///
222
    /// This expression will have default `Metadata`.
223
6544
    pub fn as_expr(&self) -> Expr {
224
6544
        Expr::Atomic(Metadata::new(), self.as_atom())
225
6544
    }
226

            
227
    /// Returns the top level `Expression` to add to the model.
228
8187
    pub fn top_level_expr(&self) -> Expr {
229
8187
        self.aux_expression.clone()
230
8187
    }
231

            
232
    /// Returns the new `SymbolTable`, modified to contain this auxiliary variable in the symbol table.
233
8187
    pub fn symbols(&self) -> SymbolTable {
234
8187
        self.symbols.clone()
235
8187
    }
236
}
237

            
238
/// Clone comprehension with expression generator into its own detached comprehension scope
239
/// and rewrite all uses of the original quantified declaration to a fresh branch-local
240
/// expression generator.
241
22
pub fn replace_expression_generator_source(
242
22
    comp: &Comprehension,
243
22
    gen_decl: &DeclarationPtr,
244
22
    replacement_expr: Expr,
245
22
) -> (Comprehension, DeclarationPtr) {
246
22
    let replacement_ptr =
247
22
        DeclarationPtr::new_quantified_expr(gen_decl.name().clone(), replacement_expr);
248
22
    let mut comprehension = comp.clone();
249

            
250
    // detach the scope so rewriting this branch does not mutate the original
251
    // comprehension through shared pointers
252
22
    comprehension.symbols = comprehension.symbols.detach();
253

            
254
    // rewrite all uses of the original quantified declaration to the branch-local
255
    // generator declaration
256
22
    comprehension.return_expression =
257
22
        comprehension
258
22
            .return_expression
259
98
            .transform_bi(&|decl: DeclarationPtr| {
260
98
                if decl == *gen_decl {
261
22
                    replacement_ptr.clone()
262
                } else {
263
76
                    decl
264
                }
265
98
            });
266

            
267
22
    comprehension.qualifiers = comprehension
268
22
        .qualifiers
269
22
        .into_iter()
270
42
        .map(|qualifier| {
271
110
            qualifier.transform_bi(&|decl: DeclarationPtr| {
272
110
                if decl == *gen_decl {
273
22
                    replacement_ptr.clone()
274
                } else {
275
88
                    decl
276
                }
277
110
            })
278
42
        })
279
22
        .collect();
280

            
281
    // keep the detached local scope in sync with the rewritten generator
282
    // declarations used by this branch
283
22
    comprehension
284
22
        .symbols
285
22
        .write()
286
22
        .update_insert(replacement_ptr.clone());
287
42
    for qualifier in &comprehension.qualifiers {
288
42
        match qualifier {
289
34
            ComprehensionQualifier::ExpressionGenerator { ptr }
290
38
            | ComprehensionQualifier::Generator { ptr } => {
291
38
                comprehension.symbols.write().update_insert(ptr.clone());
292
38
            }
293
4
            ComprehensionQualifier::Condition(_) => {}
294
        }
295
    }
296

            
297
22
    (comprehension, replacement_ptr)
298
22
}