1
use conjure_cp::{
2
    ast::Metadata,
3
    ast::{
4
        AbstractLiteral, Atom, DeclarationPtr, Expression as Expr, Literal, Moo, Name, SymbolTable,
5
        categories::Category,
6
    },
7
};
8

            
9
use tracing::{instrument, trace};
10
use uniplate::{Biplate, Uniplate};
11

            
12
/// True iff `expr` is an `Atom`.
13
19335
pub fn is_atom(expr: &Expr) -> bool {
14
19335
    matches!(expr, Expr::Atomic(_, _))
15
19335
}
16

            
17
/// True iff `expr` is an `Atom` or `Not(Atom)`.
18
945
pub fn is_literal(expr: &Expr) -> bool {
19
945
    match expr {
20
549
        Expr::Atomic(_, _) => true,
21
        Expr::Not(_, inner) => matches!(**inner, Expr::Atomic(_, _)),
22
396
        _ => false,
23
    }
24
945
}
25

            
26
/// True if `expr` is flat; i.e. it only contains atoms.
27
pub fn is_flat(expr: &Expr) -> bool {
28
    for e in expr.children() {
29
        if !is_atom(&e) {
30
            return false;
31
        }
32
    }
33
    true
34
}
35

            
36
/// Returns the arity of a tuple constant expression, if this expression is one.
37
15
pub fn constant_tuple_len(expr: &Expr) -> Option<usize> {
38
    match expr {
39
        Expr::AbstractLiteral(_, AbstractLiteral::Tuple(elems)) => Some(elems.len()),
40
6
        Expr::Atomic(_, Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Tuple(elems)))) => {
41
6
            Some(elems.len())
42
        }
43
9
        _ => None,
44
    }
45
15
}
46

            
47
/// Returns record field names of a record constant expression, if this expression is one.
48
6
pub fn constant_record_names(expr: &Expr) -> Option<Vec<Name>> {
49
    match expr {
50
        Expr::AbstractLiteral(_, AbstractLiteral::Record(entries)) => {
51
            Some(entries.iter().map(|x| x.name.clone()).collect())
52
        }
53
        Expr::Atomic(
54
            _,
55
3
            Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Record(entries))),
56
6
        ) => Some(entries.iter().map(|x| x.name.clone()).collect()),
57
3
        _ => None,
58
    }
59
6
}
60

            
61
/// True if the entire AST is constants.
62
825888
pub fn is_all_constant(expression: &Expr) -> bool {
63
903813
    for atom in expression.universe_bi() {
64
903813
        match atom {
65
234834
            Atom::Literal(_) => {}
66
            _ => {
67
668979
                return false;
68
            }
69
        }
70
    }
71

            
72
156909
    true
73
825888
}
74

            
75
/// Converts a vector of expressions to a vector of atoms.
76
///
77
/// # Returns
78
///
79
/// `Some(Vec<Atom>)` if the vectors direct children expressions are all atomic, otherwise `None`.
80
#[allow(dead_code)]
81
pub fn expressions_to_atoms(exprs: &Vec<Expr>) -> Option<Vec<Atom>> {
82
    let mut atoms: Vec<Atom> = vec![];
83
    for expr in exprs {
84
        let Expr::Atomic(_, atom) = expr else {
85
            return None;
86
        };
87
        atoms.push(atom.clone());
88
    }
89

            
90
    Some(atoms)
91
}
92

            
93
/// Creates a new auxiliary variable using the given expression.
94
///
95
/// # Returns
96
///
97
/// * `None` if `Expr` is a `Atom`, or `Expr` does not have a domain (for example, if it is a `Bubble`).
98
///
99
/// * `Some(ToAuxVarOutput)` if successful, containing:
100
///
101
///     + A new symbol table, modified to include the auxiliary variable.
102
///     + A new top level expression, containing the declaration of the auxiliary variable.
103
///     + A reference to the auxiliary variable to replace the existing expression with.
104
///
105
#[instrument(skip_all, fields(expr = %expr))]
106
19335
pub fn to_aux_var(expr: &Expr, symbols: &SymbolTable) -> Option<ToAuxVarOutput> {
107
19335
    let mut symbols = symbols.clone();
108

            
109
    // No need to put an atom in an aux_var
110
19335
    if is_atom(expr) {
111
15072
        if cfg!(debug_assertions) {
112
15072
            trace!(why = "expression is an atom", "to_aux_var() failed");
113
        }
114
15072
        return None;
115
4263
    }
116

            
117
    // Anything that should be bubbled, bubble
118
4263
    if !expr.is_safe() {
119
45
        if cfg!(debug_assertions) {
120
45
            trace!(why = "expression is unsafe", "to_aux_var() failed");
121
        }
122
45
        return None;
123
4218
    }
124

            
125
    // Do not put abstract literals containing expressions into aux vars.
126
    //
127
    // e.g. for `[1,2,3,f/2,e][e]`, the lhs should not be put in an aux var.
128
    //
129
    // instead, we should flatten the elements inside this abstract literal, or wait for it to be
130
    // turned into an atom, or an abstract literal containing only literals - e.g. through an index
131
    // or slice operation.
132
    //
133
4218
    if let Expr::AbstractLiteral(_, _) = expr {
134
753
        if cfg!(debug_assertions) {
135
753
            trace!(
136
                why = "expression is an abstract literal",
137
                "to_aux_var() failed"
138
            );
139
        }
140
753
        return None;
141
3465
    }
142

            
143
    // Only flatten an expression if it contains decision variables or decision variables with some
144
    // constants.
145
    //
146
    // i.e. dont flatten things containing givens, quantified variables, just constants, etc.
147
3465
    let categories = expr.universe_categories();
148

            
149
3465
    assert!(!categories.is_empty());
150

            
151
3465
    if !(categories.len() == 1 && categories.contains(&Category::Decision)
152
2358
        || categories.len() == 2
153
2352
            && categories.contains(&Category::Decision)
154
2352
            && categories.contains(&Category::Constant))
155
    {
156
6
        if cfg!(debug_assertions) {
157
6
            trace!(
158
                why = "expression has sub-expressions that are not in the decision category",
159
                "to_aux_var() failed"
160
            );
161
        }
162
6
        return None;
163
3459
    }
164

            
165
    // Avoid introducing auxvars for generic matrix indexing (can create many redundant auxvars
166
    // before comprehension expansion). However, keep list indexing eligible so Minion lowering
167
    // can introduce `element` constraints in non-equality contexts.
168
3459
    if let Expr::SafeIndex(_, subject, indices) = expr {
169
1062
        let can_lower_via_element = subject.clone().unwrap_list().is_some()
170
921
            && indices.iter().all(|i| matches!(i, Expr::Atomic(_, _)));
171

            
172
1062
        if !can_lower_via_element {
173
552
            if cfg!(debug_assertions) {
174
552
                trace!(expr=%expr, why = "matrix indexing is not element-lowerable", "to_aux_var() failed");
175
            }
176
552
            return None;
177
510
        }
178
2397
    }
179

            
180
2907
    let Some(domain) = expr.domain_of() else {
181
105
        if cfg!(debug_assertions) {
182
105
            trace!(expr=%expr, why = "could not find the domain of the expression", "to_aux_var() failed");
183
        }
184
105
        return None;
185
    };
186

            
187
2802
    let decl = symbols.gensym(&domain);
188

            
189
2802
    if cfg!(debug_assertions) {
190
2802
        trace!(expr=%expr, "to_auxvar() succeeded in putting expr into an auxvar");
191
    }
192

            
193
2802
    Some(ToAuxVarOutput {
194
2802
        aux_declaration: decl.clone(),
195
2802
        aux_expression: Expr::AuxDeclaration(
196
2802
            Metadata::new(),
197
2802
            conjure_cp::ast::Reference::new(decl),
198
2802
            Moo::new(expr.clone()),
199
2802
        ),
200
2802
        symbols,
201
2802
        _unconstructable: (),
202
2802
    })
203
19335
}
204

            
205
/// Output data of `to_aux_var`.
206
pub struct ToAuxVarOutput {
207
    aux_declaration: DeclarationPtr,
208
    aux_expression: Expr,
209
    symbols: SymbolTable,
210
    _unconstructable: (),
211
}
212

            
213
impl ToAuxVarOutput {
214
    /// Returns the new auxiliary variable as an `Atom`.
215
2802
    pub fn as_atom(&self) -> Atom {
216
2802
        Atom::Reference(conjure_cp::ast::Reference::new(
217
2802
            self.aux_declaration.clone(),
218
2802
        ))
219
2802
    }
220

            
221
    /// Returns the new auxiliary variable as an `Expression`.
222
    ///
223
    /// This expression will have default `Metadata`.
224
2356
    pub fn as_expr(&self) -> Expr {
225
2356
        Expr::Atomic(Metadata::new(), self.as_atom())
226
2356
    }
227

            
228
    /// Returns the top level `Expression` to add to the model.
229
2802
    pub fn top_level_expr(&self) -> Expr {
230
2802
        self.aux_expression.clone()
231
2802
    }
232

            
233
    /// Returns the new `SymbolTable`, modified to contain this auxiliary variable in the symbol table.
234
2802
    pub fn symbols(&self) -> SymbolTable {
235
2802
        self.symbols.clone()
236
2802
    }
237
}