1
use conjure_cp::{
2
    ast::Metadata,
3
    ast::{
4
        AbstractLiteral, Atom, DeclarationPtr, Expression as Expr, Literal, Moo, Name, SymbolTable,
5
        categories::Category,
6
    },
7
};
8

            
9
use tracing::{instrument, trace};
10
use uniplate::{Biplate, Uniplate};
11

            
12
/// True iff `expr` is an `Atom`.
13
63432
pub fn is_atom(expr: &Expr) -> bool {
14
63432
    matches!(expr, Expr::Atomic(_, _))
15
63432
}
16

            
17
/// True iff `expr` is an `Atom` or `Not(Atom)`.
18
pub fn is_literal(expr: &Expr) -> bool {
19
    match expr {
20
        Expr::Atomic(_, _) => true,
21
        Expr::Not(_, inner) => matches!(**inner, Expr::Atomic(_, _)),
22
        _ => false,
23
    }
24
}
25

            
26
/// True if `expr` is flat; i.e. it only contains atoms.
27
pub fn is_flat(expr: &Expr) -> bool {
28
    for e in expr.children() {
29
        if !is_atom(&e) {
30
            return false;
31
        }
32
    }
33
    true
34
}
35

            
36
/// Returns the arity of a tuple constant expression, if this expression is one.
37
30
pub fn constant_tuple_len(expr: &Expr) -> Option<usize> {
38
    match expr {
39
        Expr::AbstractLiteral(_, AbstractLiteral::Tuple(elems)) => Some(elems.len()),
40
12
        Expr::Atomic(_, Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Tuple(elems)))) => {
41
12
            Some(elems.len())
42
        }
43
18
        _ => None,
44
    }
45
30
}
46

            
47
/// Returns record field names of a record constant expression, if this expression is one.
48
12
pub fn constant_record_names(expr: &Expr) -> Option<Vec<Name>> {
49
    match expr {
50
        Expr::AbstractLiteral(_, AbstractLiteral::Record(entries)) => {
51
            Some(entries.iter().map(|x| x.name.clone()).collect())
52
        }
53
        Expr::Atomic(
54
            _,
55
6
            Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Record(entries))),
56
12
        ) => Some(entries.iter().map(|x| x.name.clone()).collect()),
57
6
        _ => None,
58
    }
59
12
}
60

            
61
/// True if the entire AST is constants.
62
2178738
pub fn is_all_constant(expression: &Expr) -> bool {
63
2451429
    for atom in expression.universe_bi() {
64
2451429
        match atom {
65
614484
            Atom::Literal(_) => {}
66
            _ => {
67
1836945
                return false;
68
            }
69
        }
70
    }
71

            
72
341793
    true
73
2178738
}
74

            
75
/// Converts a vector of expressions to a vector of atoms.
76
///
77
/// # Returns
78
///
79
/// `Some(Vec<Atom>)` if the vectors direct children expressions are all atomic, otherwise `None`.
80
#[allow(dead_code)]
81
pub fn expressions_to_atoms(exprs: &Vec<Expr>) -> Option<Vec<Atom>> {
82
    let mut atoms: Vec<Atom> = vec![];
83
    for expr in exprs {
84
        let Expr::Atomic(_, atom) = expr else {
85
            return None;
86
        };
87
        atoms.push(atom.clone());
88
    }
89

            
90
    Some(atoms)
91
}
92

            
93
/// Creates a new auxiliary variable using the given expression.
94
///
95
/// # Returns
96
///
97
/// * `None` if `Expr` is a `Atom`, or `Expr` does not have a domain (for example, if it is a `Bubble`).
98
///
99
/// * `Some(ToAuxVarOutput)` if successful, containing:
100
///
101
///     + A new symbol table, modified to include the auxiliary variable.
102
///     + A new top level expression, containing the declaration of the auxiliary variable.
103
///     + A reference to the auxiliary variable to replace the existing expression with.
104
///
105
#[instrument(skip_all, fields(expr = %expr))]
106
63432
pub fn to_aux_var(expr: &Expr, symbols: &SymbolTable) -> Option<ToAuxVarOutput> {
107
63432
    let mut symbols = symbols.clone();
108

            
109
    // No need to put an atom in an aux_var
110
63432
    if is_atom(expr) {
111
46242
        if cfg!(debug_assertions) {
112
46242
            trace!(why = "expression is an atom", "to_aux_var() failed");
113
        }
114
46242
        return None;
115
17190
    }
116

            
117
    // Anything that should be bubbled, bubble
118
17190
    if !expr.is_safe() {
119
330
        if cfg!(debug_assertions) {
120
330
            trace!(why = "expression is unsafe", "to_aux_var() failed");
121
        }
122
330
        return None;
123
16860
    }
124

            
125
    // Do not put abstract literals containing expressions into aux vars.
126
    //
127
    // e.g. for `[1,2,3,f/2,e][e]`, the lhs should not be put in an aux var.
128
    //
129
    // instead, we should flatten the elements inside this abstract literal, or wait for it to be
130
    // turned into an atom, or an abstract literal containing only literals - e.g. through an index
131
    // or slice operation.
132
    //
133
16860
    if let Expr::AbstractLiteral(_, _) = expr {
134
5952
        if cfg!(debug_assertions) {
135
5952
            trace!(
136
                why = "expression is an abstract literal",
137
                "to_aux_var() failed"
138
            );
139
        }
140
5952
        return None;
141
10908
    }
142

            
143
    // Only flatten an expression if it contains decision variables or decision variables with some
144
    // constants.
145
    //
146
    // i.e. dont flatten things containing givens, quantified variables, just constants, etc.
147
10908
    let categories = expr.universe_categories();
148

            
149
10908
    assert!(!categories.is_empty());
150

            
151
10908
    if !(categories.len() == 1 && categories.contains(&Category::Decision)
152
7212
        || categories.len() == 2
153
7200
            && categories.contains(&Category::Decision)
154
7200
            && categories.contains(&Category::Constant))
155
    {
156
12
        if cfg!(debug_assertions) {
157
12
            trace!(
158
                why = "expression has sub-expressions that are not in the decision category",
159
                "to_aux_var() failed"
160
            );
161
        }
162
12
        return None;
163
10896
    }
164

            
165
    // FIXME: why does removing this make tests fail!
166
    //
167
    // do not put matrix[e] in auxvar
168
    //
169
    // eventually this will rewrite into an indomain constraint, or a single variable.
170
    //
171
    // To understand why deferring this until a lower level constraint is chosen is good, consider
172
    // the comprehension:
173
    //
174
    // and([m[i] = i + 1  | i: int(1..5)])
175
    //
176
    // Here, if we rewrite inside the comprehension, we will end up making the auxvar
177
    // __0  = m[i].
178
    //
179
    // When we expand the matrix, this will expand to:
180
    //
181
    // __0 = m[1]
182
    // __1 = m[2]
183
    // __2 = m[3]
184
    // __3 = m[4]
185
    // __4 = m[5]
186
    //
187
    //
188
    // These all rewrite to variable references (e.g. m[1] ~> m#matrix_to_atom_1), so these auxvars
189
    // are redundant. However, we don't know this before expanding, as they are just m[i].
190
    //
191
    // In the future, we can do this more fine-grained using categories (e.g. only flatten matrices
192
    // indexed by expressions with the decision variable category) : however, doing this for
193
    // all matrix indexing is fine, as they can be rewritten into a lower-level expression, then
194
    // flattened.
195
10896
    if let Expr::SafeIndex(_, _, _) = expr {
196
3132
        if cfg!(debug_assertions) {
197
3132
            trace!(expr=%expr, why = "expression is an matrix indexing operation", "to_aux_var() failed");
198
        }
199
3132
        return None;
200
7764
    }
201

            
202
7764
    let Some(domain) = expr.domain_of() else {
203
348
        if cfg!(debug_assertions) {
204
348
            trace!(expr=%expr, why = "could not find the domain of the expression", "to_aux_var() failed");
205
        }
206
348
        return None;
207
    };
208

            
209
7416
    let decl = symbols.gensym(&domain);
210

            
211
7416
    if cfg!(debug_assertions) {
212
7416
        trace!(expr=%expr, "to_auxvar() succeeded in putting expr into an auxvar");
213
    }
214

            
215
7416
    Some(ToAuxVarOutput {
216
7416
        aux_declaration: decl.clone(),
217
7416
        aux_expression: Expr::AuxDeclaration(
218
7416
            Metadata::new(),
219
7416
            conjure_cp::ast::Reference::new(decl),
220
7416
            Moo::new(expr.clone()),
221
7416
        ),
222
7416
        symbols,
223
7416
        _unconstructable: (),
224
7416
    })
225
63432
}
226

            
227
/// Output data of `to_aux_var`.
228
pub struct ToAuxVarOutput {
229
    aux_declaration: DeclarationPtr,
230
    aux_expression: Expr,
231
    symbols: SymbolTable,
232
    _unconstructable: (),
233
}
234

            
235
impl ToAuxVarOutput {
236
    /// Returns the new auxiliary variable as an `Atom`.
237
7416
    pub fn as_atom(&self) -> Atom {
238
7416
        Atom::Reference(conjure_cp::ast::Reference::new(
239
7416
            self.aux_declaration.clone(),
240
7416
        ))
241
7416
    }
242

            
243
    /// Returns the new auxiliary variable as an `Expression`.
244
    ///
245
    /// This expression will have default `Metadata`.
246
6054
    pub fn as_expr(&self) -> Expr {
247
6054
        Expr::Atomic(Metadata::new(), self.as_atom())
248
6054
    }
249

            
250
    /// Returns the top level `Expression` to add to the model.
251
7416
    pub fn top_level_expr(&self) -> Expr {
252
7416
        self.aux_expression.clone()
253
7416
    }
254

            
255
    /// Returns the new `SymbolTable`, modified to contain this auxiliary variable in the symbol table.
256
7416
    pub fn symbols(&self) -> SymbolTable {
257
7416
        self.symbols.clone()
258
7416
    }
259
}