1
use conjure_cp::{
2
    ast::Metadata,
3
    ast::{
4
        AbstractLiteral, Atom, DeclarationPtr, Expression as Expr, Literal, Moo, Name, SymbolTable,
5
        categories::Category,
6
    },
7
};
8

            
9
use tracing::{instrument, trace};
10
use uniplate::{Biplate, Uniplate};
11

            
12
/// True iff `expr` is an `Atom`.
13
54981
pub fn is_atom(expr: &Expr) -> bool {
14
54981
    matches!(expr, Expr::Atomic(_, _))
15
54981
}
16

            
17
/// True iff `expr` is an `Atom` or `Not(Atom)`.
18
2835
pub fn is_literal(expr: &Expr) -> bool {
19
2835
    match expr {
20
1647
        Expr::Atomic(_, _) => true,
21
        Expr::Not(_, inner) => matches!(**inner, Expr::Atomic(_, _)),
22
1188
        _ => false,
23
    }
24
2835
}
25

            
26
/// True if `expr` is flat; i.e. it only contains atoms.
27
pub fn is_flat(expr: &Expr) -> bool {
28
    for e in expr.children() {
29
        if !is_atom(&e) {
30
            return false;
31
        }
32
    }
33
    true
34
}
35

            
36
/// Returns the arity of a tuple constant expression, if this expression is one.
37
45
pub fn constant_tuple_len(expr: &Expr) -> Option<usize> {
38
    match expr {
39
        Expr::AbstractLiteral(_, AbstractLiteral::Tuple(elems)) => Some(elems.len()),
40
18
        Expr::Atomic(_, Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Tuple(elems)))) => {
41
18
            Some(elems.len())
42
        }
43
27
        _ => None,
44
    }
45
45
}
46

            
47
/// Returns record field names of a record constant expression, if this expression is one.
48
18
pub fn constant_record_names(expr: &Expr) -> Option<Vec<Name>> {
49
    match expr {
50
        Expr::AbstractLiteral(_, AbstractLiteral::Record(entries)) => {
51
            Some(entries.iter().map(|x| x.name.clone()).collect())
52
        }
53
        Expr::Atomic(
54
            _,
55
9
            Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Record(entries))),
56
18
        ) => Some(entries.iter().map(|x| x.name.clone()).collect()),
57
9
        _ => None,
58
    }
59
18
}
60

            
61
/// True if the entire AST is constants.
62
2337471
pub fn is_all_constant(expression: &Expr) -> bool {
63
2557836
    for atom in expression.universe_bi() {
64
2557836
        match atom {
65
678942
            Atom::Literal(_) => {}
66
            _ => {
67
1878894
                return false;
68
            }
69
        }
70
    }
71

            
72
458577
    true
73
2337471
}
74

            
75
/// Converts a vector of expressions to a vector of atoms.
76
///
77
/// # Returns
78
///
79
/// `Some(Vec<Atom>)` if the vectors direct children expressions are all atomic, otherwise `None`.
80
#[allow(dead_code)]
81
pub fn expressions_to_atoms(exprs: &Vec<Expr>) -> Option<Vec<Atom>> {
82
    let mut atoms: Vec<Atom> = vec![];
83
    for expr in exprs {
84
        let Expr::Atomic(_, atom) = expr else {
85
            return None;
86
        };
87
        atoms.push(atom.clone());
88
    }
89

            
90
    Some(atoms)
91
}
92

            
93
/// Creates a new auxiliary variable using the given expression.
94
///
95
/// # Returns
96
///
97
/// * `None` if `Expr` is a `Atom`, or `Expr` does not have a domain (for example, if it is a `Bubble`).
98
///
99
/// * `Some(ToAuxVarOutput)` if successful, containing:
100
///
101
///     + A new symbol table, modified to include the auxiliary variable.
102
///     + A new top level expression, containing the declaration of the auxiliary variable.
103
///     + A reference to the auxiliary variable to replace the existing expression with.
104
///
105
#[instrument(skip_all, fields(expr = %expr))]
106
54981
pub fn to_aux_var(expr: &Expr, symbols: &SymbolTable) -> Option<ToAuxVarOutput> {
107
54981
    let mut symbols = symbols.clone();
108

            
109
    // No need to put an atom in an aux_var
110
54981
    if is_atom(expr) {
111
42624
        if cfg!(debug_assertions) {
112
42624
            trace!(why = "expression is an atom", "to_aux_var() failed");
113
        }
114
42624
        return None;
115
12357
    }
116

            
117
    // Anything that should be bubbled, bubble
118
12357
    if !expr.is_safe() {
119
135
        if cfg!(debug_assertions) {
120
135
            trace!(why = "expression is unsafe", "to_aux_var() failed");
121
        }
122
135
        return None;
123
12222
    }
124

            
125
    // Do not put abstract literals containing expressions into aux vars.
126
    //
127
    // e.g. for `[1,2,3,f/2,e][e]`, the lhs should not be put in an aux var.
128
    //
129
    // instead, we should flatten the elements inside this abstract literal, or wait for it to be
130
    // turned into an atom, or an abstract literal containing only literals - e.g. through an index
131
    // or slice operation.
132
    //
133
12222
    if let Expr::AbstractLiteral(_, _) = expr {
134
2187
        if cfg!(debug_assertions) {
135
2187
            trace!(
136
                why = "expression is an abstract literal",
137
                "to_aux_var() failed"
138
            );
139
        }
140
2187
        return None;
141
10035
    }
142

            
143
    // Only flatten an expression if it contains decision variables or decision variables with some
144
    // constants.
145
    //
146
    // i.e. dont flatten things containing givens, quantified variables, just constants, etc.
147
10035
    let categories = expr.universe_categories();
148

            
149
10035
    assert!(!categories.is_empty());
150

            
151
10035
    if !(categories.len() == 1 && categories.contains(&Category::Decision)
152
6786
        || categories.len() == 2
153
6768
            && categories.contains(&Category::Decision)
154
6768
            && categories.contains(&Category::Constant))
155
    {
156
18
        if cfg!(debug_assertions) {
157
18
            trace!(
158
                why = "expression has sub-expressions that are not in the decision category",
159
                "to_aux_var() failed"
160
            );
161
        }
162
18
        return None;
163
10017
    }
164

            
165
    // Avoid introducing auxvars for generic matrix indexing (can create many redundant auxvars
166
    // before comprehension expansion). However, keep list indexing eligible so Minion lowering
167
    // can introduce `element` constraints in non-equality contexts.
168
10017
    if let Expr::SafeIndex(_, subject, indices) = expr {
169
2970
        let can_lower_via_element = subject.clone().unwrap_list().is_some()
170
2547
            && indices.iter().all(|i| matches!(i, Expr::Atomic(_, _)));
171

            
172
2970
        if !can_lower_via_element {
173
1512
            if cfg!(debug_assertions) {
174
1512
                trace!(expr=%expr, why = "matrix indexing is not element-lowerable", "to_aux_var() failed");
175
            }
176
1512
            return None;
177
1458
        }
178
7047
    }
179

            
180
8505
    let Some(domain) = expr.domain_of() else {
181
315
        if cfg!(debug_assertions) {
182
315
            trace!(expr=%expr, why = "could not find the domain of the expression", "to_aux_var() failed");
183
        }
184
315
        return None;
185
    };
186

            
187
8190
    let decl = symbols.gensym(&domain);
188

            
189
8190
    if cfg!(debug_assertions) {
190
8190
        trace!(expr=%expr, "to_auxvar() succeeded in putting expr into an auxvar");
191
    }
192

            
193
8190
    Some(ToAuxVarOutput {
194
8190
        aux_declaration: decl.clone(),
195
8190
        aux_expression: Expr::AuxDeclaration(
196
8190
            Metadata::new(),
197
8190
            conjure_cp::ast::Reference::new(decl),
198
8190
            Moo::new(expr.clone()),
199
8190
        ),
200
8190
        symbols,
201
8190
        _unconstructable: (),
202
8190
    })
203
54981
}
204

            
205
/// Output data of `to_aux_var`.
206
pub struct ToAuxVarOutput {
207
    aux_declaration: DeclarationPtr,
208
    aux_expression: Expr,
209
    symbols: SymbolTable,
210
    _unconstructable: (),
211
}
212

            
213
impl ToAuxVarOutput {
214
    /// Returns the new auxiliary variable as an `Atom`.
215
8190
    pub fn as_atom(&self) -> Atom {
216
8190
        Atom::Reference(conjure_cp::ast::Reference::new(
217
8190
            self.aux_declaration.clone(),
218
8190
        ))
219
8190
    }
220

            
221
    /// Returns the new auxiliary variable as an `Expression`.
222
    ///
223
    /// This expression will have default `Metadata`.
224
6876
    pub fn as_expr(&self) -> Expr {
225
6876
        Expr::Atomic(Metadata::new(), self.as_atom())
226
6876
    }
227

            
228
    /// Returns the top level `Expression` to add to the model.
229
8190
    pub fn top_level_expr(&self) -> Expr {
230
8190
        self.aux_expression.clone()
231
8190
    }
232

            
233
    /// Returns the new `SymbolTable`, modified to contain this auxiliary variable in the symbol table.
234
8190
    pub fn symbols(&self) -> SymbolTable {
235
8190
        self.symbols.clone()
236
8190
    }
237
}