1
use conjure_cp::{
2
    ast::Metadata,
3
    ast::{
4
        AbstractLiteral, Atom, DeclarationPtr, Expression as Expr, Literal, Moo, Name, SymbolTable,
5
        categories::Category,
6
    },
7
};
8

            
9
use tracing::{instrument, trace};
10
use uniplate::{Biplate, Uniplate};
11

            
12
/// True iff `expr` is an `Atom`.
13
9150
pub fn is_atom(expr: &Expr) -> bool {
14
9150
    matches!(expr, Expr::Atomic(_, _))
15
9150
}
16

            
17
/// True iff `expr` is an `Atom` or `Not(Atom)`.
18
pub fn is_literal(expr: &Expr) -> bool {
19
    match expr {
20
        Expr::Atomic(_, _) => true,
21
        Expr::Not(_, inner) => matches!(**inner, Expr::Atomic(_, _)),
22
        _ => false,
23
    }
24
}
25

            
26
/// True if `expr` is flat; i.e. it only contains atoms.
27
pub fn is_flat(expr: &Expr) -> bool {
28
    for e in expr.children() {
29
        if !is_atom(&e) {
30
            return false;
31
        }
32
    }
33
    true
34
}
35

            
36
/// Returns the arity of a tuple constant expression, if this expression is one.
37
15
pub fn constant_tuple_len(expr: &Expr) -> Option<usize> {
38
    match expr {
39
        Expr::AbstractLiteral(_, AbstractLiteral::Tuple(elems)) => Some(elems.len()),
40
6
        Expr::Atomic(_, Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Tuple(elems)))) => {
41
6
            Some(elems.len())
42
        }
43
9
        _ => None,
44
    }
45
15
}
46

            
47
/// Returns record field names of a record constant expression, if this expression is one.
48
6
pub fn constant_record_names(expr: &Expr) -> Option<Vec<Name>> {
49
    match expr {
50
        Expr::AbstractLiteral(_, AbstractLiteral::Record(entries)) => {
51
            Some(entries.iter().map(|x| x.name.clone()).collect())
52
        }
53
        Expr::Atomic(
54
            _,
55
3
            Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Record(entries))),
56
6
        ) => Some(entries.iter().map(|x| x.name.clone()).collect()),
57
3
        _ => None,
58
    }
59
6
}
60

            
61
/// True if the entire AST is constants.
62
259404
pub fn is_all_constant(expression: &Expr) -> bool {
63
282921
    for atom in expression.universe_bi() {
64
282921
        match atom {
65
77172
            Atom::Literal(_) => {}
66
            _ => {
67
205749
                return false;
68
            }
69
        }
70
    }
71

            
72
53655
    true
73
259404
}
74

            
75
/// Converts a vector of expressions to a vector of atoms.
76
///
77
/// # Returns
78
///
79
/// `Some(Vec<Atom>)` if the vectors direct children expressions are all atomic, otherwise `None`.
80
#[allow(dead_code)]
81
pub fn expressions_to_atoms(exprs: &Vec<Expr>) -> Option<Vec<Atom>> {
82
    let mut atoms: Vec<Atom> = vec![];
83
    for expr in exprs {
84
        let Expr::Atomic(_, atom) = expr else {
85
            return None;
86
        };
87
        atoms.push(atom.clone());
88
    }
89

            
90
    Some(atoms)
91
}
92

            
93
/// Creates a new auxiliary variable using the given expression.
94
///
95
/// # Returns
96
///
97
/// * `None` if `Expr` is a `Atom`, or `Expr` does not have a domain (for example, if it is a `Bubble`).
98
///
99
/// * `Some(ToAuxVarOutput)` if successful, containing:
100
///
101
///     + A new symbol table, modified to include the auxiliary variable.
102
///     + A new top level expression, containing the declaration of the auxiliary variable.
103
///     + A reference to the auxiliary variable to replace the existing expression with.
104
///
105
#[instrument(skip_all, fields(expr = %expr))]
106
9150
pub fn to_aux_var(expr: &Expr, symbols: &SymbolTable) -> Option<ToAuxVarOutput> {
107
9150
    let mut symbols = symbols.clone();
108

            
109
    // No need to put an atom in an aux_var
110
9150
    if is_atom(expr) {
111
7437
        if cfg!(debug_assertions) {
112
7437
            trace!(why = "expression is an atom", "to_aux_var() failed");
113
        }
114
7437
        return None;
115
1713
    }
116

            
117
    // Anything that should be bubbled, bubble
118
1713
    if !expr.is_safe() {
119
6
        if cfg!(debug_assertions) {
120
6
            trace!(why = "expression is unsafe", "to_aux_var() failed");
121
        }
122
6
        return None;
123
1707
    }
124

            
125
    // Do not put abstract literals containing expressions into aux vars.
126
    //
127
    // e.g. for `[1,2,3,f/2,e][e]`, the lhs should not be put in an aux var.
128
    //
129
    // instead, we should flatten the elements inside this abstract literal, or wait for it to be
130
    // turned into an atom, or an abstract literal containing only literals - e.g. through an index
131
    // or slice operation.
132
    //
133
1707
    if let Expr::AbstractLiteral(_, _) = expr {
134
129
        if cfg!(debug_assertions) {
135
129
            trace!(
136
                why = "expression is an abstract literal",
137
                "to_aux_var() failed"
138
            );
139
        }
140
129
        return None;
141
1578
    }
142

            
143
    // Only flatten an expression if it contains decision variables or decision variables with some
144
    // constants.
145
    //
146
    // i.e. dont flatten things containing givens, quantified variables, just constants, etc.
147
1578
    let categories = expr.universe_categories();
148

            
149
1578
    assert!(!categories.is_empty());
150

            
151
1578
    if !(categories.len() == 1 && categories.contains(&Category::Decision)
152
1125
        || categories.len() == 2
153
963
            && categories.contains(&Category::Decision)
154
849
            && categories.contains(&Category::Constant))
155
    {
156
546
        if cfg!(debug_assertions) {
157
546
            trace!(
158
                why = "expression has sub-expressions that are not in the decision category",
159
                "to_aux_var() failed"
160
            );
161
        }
162
546
        return None;
163
1032
    }
164

            
165
    // FIXME: why does removing this make tests fail!
166
    //
167
    // do not put matrix[e] in auxvar
168
    //
169
    // eventually this will rewrite into an indomain constraint, or a single variable.
170
    //
171
    // To understand why deferring this until a lower level constraint is chosen is good, consider
172
    // the comprehension:
173
    //
174
    // and([m[i] = i + 1  | i: int(1..5)])
175
    //
176
    // Here, if we rewrite inside the comprehension, we will end up making the auxvar
177
    // __0  = m[i].
178
    //
179
    // When we expand the matrix, this will expand to:
180
    //
181
    // __0 = m[1]
182
    // __1 = m[2]
183
    // __2 = m[3]
184
    // __3 = m[4]
185
    // __4 = m[5]
186
    //
187
    //
188
    // These all rewrite to variable references (e.g. m[1] ~> m#matrix_to_atom_1), so these auxvars
189
    // are redundant. However, we don't know this before expanding, as they are just m[i].
190
    //
191
    // In the future, we can do this more fine-grained using categories (e.g. only flatten matrices
192
    // indexed by expressions with the decision variable category) : however, doing this for
193
    // all matrix indexing is fine, as they can be rewritten into a lower-level expression, then
194
    // flattened.
195
1032
    if let Expr::SafeIndex(_, _, _) = expr {
196
237
        if cfg!(debug_assertions) {
197
237
            trace!(expr=%expr, why = "expression is an matrix indexing operation", "to_aux_var() failed");
198
        }
199
237
        return None;
200
795
    }
201

            
202
795
    let Some(domain) = expr.domain_of() else {
203
6
        if cfg!(debug_assertions) {
204
6
            trace!(expr=%expr, why = "could not find the domain of the expression", "to_aux_var() failed");
205
        }
206
6
        return None;
207
    };
208

            
209
789
    let decl = symbols.gensym(&domain);
210

            
211
789
    if cfg!(debug_assertions) {
212
789
        trace!(expr=%expr, "to_auxvar() succeeded in putting expr into an auxvar");
213
    }
214

            
215
789
    Some(ToAuxVarOutput {
216
789
        aux_declaration: decl.clone(),
217
789
        aux_expression: Expr::AuxDeclaration(
218
789
            Metadata::new(),
219
789
            conjure_cp::ast::Reference::new(decl),
220
789
            Moo::new(expr.clone()),
221
789
        ),
222
789
        symbols,
223
789
        _unconstructable: (),
224
789
    })
225
9150
}
226

            
227
/// Output data of `to_aux_var`.
228
pub struct ToAuxVarOutput {
229
    aux_declaration: DeclarationPtr,
230
    aux_expression: Expr,
231
    symbols: SymbolTable,
232
    _unconstructable: (),
233
}
234

            
235
impl ToAuxVarOutput {
236
    /// Returns the new auxiliary variable as an `Atom`.
237
789
    pub fn as_atom(&self) -> Atom {
238
789
        Atom::Reference(conjure_cp::ast::Reference::new(
239
789
            self.aux_declaration.clone(),
240
789
        ))
241
789
    }
242

            
243
    /// Returns the new auxiliary variable as an `Expression`.
244
    ///
245
    /// This expression will have default `Metadata`.
246
717
    pub fn as_expr(&self) -> Expr {
247
717
        Expr::Atomic(Metadata::new(), self.as_atom())
248
717
    }
249

            
250
    /// Returns the top level `Expression` to add to the model.
251
789
    pub fn top_level_expr(&self) -> Expr {
252
789
        self.aux_expression.clone()
253
789
    }
254

            
255
    /// Returns the new `SymbolTable`, modified to contain this auxiliary variable in the symbol table.
256
789
    pub fn symbols(&self) -> SymbolTable {
257
789
        self.symbols.clone()
258
789
    }
259
}