1
use conjure_cp::{
2
    ast::Metadata,
3
    ast::{Atom, DeclarationPtr, Expression as Expr, Moo, SymbolTable, categories::Category},
4
};
5

            
6
use tracing::{instrument, trace};
7
use uniplate::{Biplate, Uniplate};
8

            
9
/// True iff `expr` is an `Atom`.
10
pub fn is_atom(expr: &Expr) -> bool {
11
    matches!(expr, Expr::Atomic(_, _))
12
}
13

            
14
/// True iff `expr` is an `Atom` or `Not(Atom)`.
15
pub fn is_literal(expr: &Expr) -> bool {
16
    match expr {
17
        Expr::Atomic(_, _) => true,
18
        Expr::Not(_, inner) => matches!(**inner, Expr::Atomic(_, _)),
19
        _ => false,
20
    }
21
}
22

            
23
/// True if `expr` is flat; i.e. it only contains atoms.
24
pub fn is_flat(expr: &Expr) -> bool {
25
    for e in expr.children() {
26
        if !is_atom(&e) {
27
            return false;
28
        }
29
    }
30
    true
31
}
32

            
33
/// True if the entire AST is constants.
34
pub fn is_all_constant(expression: &Expr) -> bool {
35
    for atom in expression.universe_bi() {
36
        match atom {
37
            Atom::Literal(_) => {}
38
            _ => {
39
                return false;
40
            }
41
        }
42
    }
43

            
44
    true
45
}
46

            
47
/// Converts a vector of expressions to a vector of atoms.
48
///
49
/// # Returns
50
///
51
/// `Some(Vec<Atom>)` if the vectors direct children expressions are all atomic, otherwise `None`.
52
#[allow(dead_code)]
53
pub fn expressions_to_atoms(exprs: &Vec<Expr>) -> Option<Vec<Atom>> {
54
    let mut atoms: Vec<Atom> = vec![];
55
    for expr in exprs {
56
        let Expr::Atomic(_, atom) = expr else {
57
            return None;
58
        };
59
        atoms.push(atom.clone());
60
    }
61

            
62
    Some(atoms)
63
}
64

            
65
/// Creates a new auxiliary variable using the given expression.
66
///
67
/// # Returns
68
///
69
/// * `None` if `Expr` is a `Atom`, or `Expr` does not have a domain (for example, if it is a `Bubble`).
70
///
71
/// * `Some(ToAuxVarOutput)` if successful, containing:
72
///
73
///     + A new symbol table, modified to include the auxiliary variable.
74
///     + A new top level expression, containing the declaration of the auxiliary variable.
75
///     + A reference to the auxiliary variable to replace the existing expression with.
76
///
77
#[instrument(skip_all, fields(expr = %expr))]
78
pub fn to_aux_var(expr: &Expr, symbols: &SymbolTable) -> Option<ToAuxVarOutput> {
79
    let mut symbols = symbols.clone();
80

            
81
    // No need to put an atom in an aux_var
82
    if is_atom(expr) {
83
        if cfg!(debug_assertions) {
84
            trace!(why = "expression is an atom", "to_aux_var() failed");
85
        }
86
        return None;
87
    }
88

            
89
    // Anything that should be bubbled, bubble
90
    if !expr.is_safe() {
91
        if cfg!(debug_assertions) {
92
            trace!(why = "expression is unsafe", "to_aux_var() failed");
93
        }
94
        return None;
95
    }
96

            
97
    // Do not put abstract literals containing expressions into aux vars.
98
    //
99
    // e.g. for `[1,2,3,f/2,e][e]`, the lhs should not be put in an aux var.
100
    //
101
    // instead, we should flatten the elements inside this abstract literal, or wait for it to be
102
    // turned into an atom, or an abstract literal containing only literals - e.g. through an index
103
    // or slice operation.
104
    //
105
    if let Expr::AbstractLiteral(_, _) = expr {
106
        if cfg!(debug_assertions) {
107
            trace!(
108
                why = "expression is an abstract literal",
109
                "to_aux_var() failed"
110
            );
111
        }
112
        return None;
113
    }
114

            
115
    // Only flatten an expression if it contains decision variables or decision variables with some
116
    // constants.
117
    //
118
    // i.e. dont flatten things containing givens, quantified variables, just constants, etc.
119
    let categories = expr.universe_categories();
120

            
121
    assert!(!categories.is_empty());
122

            
123
    if !(categories.len() == 1 && categories.contains(&Category::Decision)
124
        || categories.len() == 2
125
            && categories.contains(&Category::Decision)
126
            && categories.contains(&Category::Constant))
127
    {
128
        if cfg!(debug_assertions) {
129
            trace!(
130
                why = "expression has sub-expressions that are not in the decision category",
131
                "to_aux_var() failed"
132
            );
133
        }
134
        return None;
135
    }
136

            
137
    // FIXME: why does removing this make tests fail!
138
    //
139
    // do not put matrix[e] in auxvar
140
    //
141
    // eventually this will rewrite into an indomain constraint, or a single variable.
142
    //
143
    // To understand why deferring this until a lower level constraint is chosen is good, consider
144
    // the comprehension:
145
    //
146
    // and([m[i] = i + 1  | i: int(1..5)])
147
    //
148
    // Here, if we rewrite inside the comprehension, we will end up making the auxvar
149
    // __0  = m[i].
150
    //
151
    // When we expand the matrix, this will expand to:
152
    //
153
    // __0 = m[1]
154
    // __1 = m[2]
155
    // __2 = m[3]
156
    // __3 = m[4]
157
    // __4 = m[5]
158
    //
159
    //
160
    // These all rewrite to variable references (e.g. m[1] ~> m#matrix_to_atom_1), so these auxvars
161
    // are redundant. However, we don't know this before expanding, as they are just m[i].
162
    //
163
    // In the future, we can do this more fine-grained using categories (e.g. only flatten matrices
164
    // indexed by expressions with the decision variable category) : however, doing this for
165
    // all matrix indexing is fine, as they can be rewritten into a lower-level expression, then
166
    // flattened.
167
    if let Expr::SafeIndex(_, _, _) = expr {
168
        if cfg!(debug_assertions) {
169
            trace!(expr=%expr, why = "expression is an matrix indexing operation", "to_aux_var() failed");
170
        }
171
        return None;
172
    }
173

            
174
    let Some(domain) = expr.domain_of() else {
175
        if cfg!(debug_assertions) {
176
            trace!(expr=%expr, why = "could not find the domain of the expression", "to_aux_var() failed");
177
        }
178
        return None;
179
    };
180

            
181
    let decl = symbols.gensym(&domain);
182

            
183
    if cfg!(debug_assertions) {
184
        trace!(expr=%expr, "to_auxvar() succeeded in putting expr into an auxvar");
185
    }
186

            
187
    Some(ToAuxVarOutput {
188
        aux_declaration: decl.clone(),
189
        aux_expression: Expr::AuxDeclaration(
190
            Metadata::new(),
191
            conjure_cp::ast::Reference::new(decl),
192
            Moo::new(expr.clone()),
193
        ),
194
        symbols,
195
        _unconstructable: (),
196
    })
197
}
198

            
199
/// Output data of `to_aux_var`.
///
/// All fields are private; they are read through the accessor methods on this type.
pub struct ToAuxVarOutput {
    /// Declaration of the new auxiliary variable (obtained from `SymbolTable::gensym`).
    aux_declaration: DeclarationPtr,
    /// The `AuxDeclaration` expression linking the auxiliary variable to the original expression.
    aux_expression: Expr,
    /// Copy of the input symbol table, extended with the auxiliary variable.
    symbols: SymbolTable,
    // NOTE(review): presumably a guard so this struct can never be built with a struct literal
    // outside this module, even if other fields become public; confirm intent.
    _unconstructable: (),
}
206

            
207
impl ToAuxVarOutput {
208
    /// Returns the new auxiliary variable as an `Atom`.
209
    pub fn as_atom(&self) -> Atom {
210
        Atom::Reference(conjure_cp::ast::Reference::new(
211
            self.aux_declaration.clone(),
212
        ))
213
    }
214

            
215
    /// Returns the new auxiliary variable as an `Expression`.
216
    ///
217
    /// This expression will have default `Metadata`.
218
    pub fn as_expr(&self) -> Expr {
219
        Expr::Atomic(Metadata::new(), self.as_atom())
220
    }
221

            
222
    /// Returns the top level `Expression` to add to the model.
223
    pub fn top_level_expr(&self) -> Expr {
224
        self.aux_expression.clone()
225
    }
226

            
227
    /// Returns the new `SymbolTable`, modified to contain this auxiliary variable in the symbol table.
228
    pub fn symbols(&self) -> SymbolTable {
229
        self.symbols.clone()
230
    }
231
}