1
use std::rc::Rc;
2

            
3
use crate::ast::{Declaration, SymbolTable};
4
use tracing::instrument;
5
use uniplate::{Biplate, Uniplate};
6

            
7
use crate::{
8
    ast::{Atom, Domain, Expression as Expr, Name},
9
    metadata::Metadata,
10
};
11

            
12
/// True iff `expr` is an `Atom`.
13
31590
pub fn is_atom(expr: &Expr) -> bool {
14
31590
    matches!(expr, Expr::Atomic(_, _))
15
31590
}
16

            
17
/// True if `expr` is flat; i.e. it only contains atoms.
18
pub fn is_flat(expr: &Expr) -> bool {
19
    for e in expr.children() {
20
        if !is_atom(&e) {
21
            return false;
22
        }
23
    }
24
    true
25
}
26

            
27
/// True if the entire AST is constants.
28
948294
pub fn is_all_constant(expression: &Expr) -> bool {
29
963198
    for atom in expression.universe_bi() {
30
963198
        match atom {
31
169992
            Atom::Literal(_) => {}
32
            _ => {
33
793206
                return false;
34
            }
35
        }
36
    }
37

            
38
155088
    true
39
948294
}
40

            
41
/// Converts a vector of expressions to a vector of atoms.
42
///
43
/// # Returns
44
///
45
/// `Some(Vec<Atom>)` if the vectors direct children expressions are all atomic, otherwise `None`.
46
#[allow(dead_code)]
47
pub fn expressions_to_atoms(exprs: &Vec<Expr>) -> Option<Vec<Atom>> {
48
    let mut atoms: Vec<Atom> = vec![];
49
    for expr in exprs {
50
        let Expr::Atomic(_, atom) = expr else {
51
            return None;
52
        };
53
        atoms.push(atom.clone());
54
    }
55

            
56
    Some(atoms)
57
}
58

            
59
/// Creates a new auxiliary variable using the given expression.
60
///
61
/// # Returns
62
///
63
/// * `None` if `Expr` is a `Atom`, or `Expr` does not have a domain (for example, if it is a `Bubble`).
64
///
65
/// * `Some(ToAuxVarOutput)` if successful, containing:
66
///     
67
///     + A new symbol table, modified to include the auxiliary variable.
68
///     + A new top level expression, containing the declaration of the auxiliary variable.
69
///     + A reference to the auxiliary variable to replace the existing expression with.
70
///
71
#[instrument]
72
pub fn to_aux_var(expr: &Expr, symbols: &SymbolTable) -> Option<ToAuxVarOutput> {
73
    let mut symbols = symbols.clone();
74

            
75
    // No need to put an atom in an aux_var
76
    if is_atom(expr) {
77
        return None;
78
    }
79

            
80
    // Anything that should be bubbled, bubble
81
    if !expr.is_safe() {
82
        return None;
83
    }
84

            
85
    let name = symbols.gensym();
86

            
87
    let Some(domain) = expr.domain_of(&symbols) else {
88
        tracing::trace!("could not find domain of {}", expr);
89
        return None;
90
    };
91

            
92
    symbols.insert(Rc::new(Declaration::new_var(name.clone(), domain.clone())))?;
93
    Some(ToAuxVarOutput {
94
        aux_name: name.clone(),
95
        aux_decl: Expr::AuxDeclaration(Metadata::new(), name, Box::new(expr.clone())),
96
        aux_domain: domain,
97
        symbols,
98
        _unconstructable: (),
99
    })
100
}
101

            
102
/// Output data of `to_aux_var`.
103
pub struct ToAuxVarOutput {
104
    aux_name: Name,
105
    aux_decl: Expr,
106
    #[allow(dead_code)] // TODO: aux_domain should be used soon, try removing this pragma
107
    aux_domain: Domain,
108
    symbols: SymbolTable,
109
    _unconstructable: (),
110
}
111

            
112
impl ToAuxVarOutput {
113
    /// Returns the new auxiliary variable as an `Atom`.
114
2340
    pub fn as_atom(&self) -> Atom {
115
2340
        Atom::Reference(self.aux_name())
116
2340
    }
117

            
118
    /// Returns the new auxiliary variable as an `Expression`.
119
    ///
120
    /// This expression will have default `Metadata`.
121
2142
    pub fn as_expr(&self) -> Expr {
122
2142
        Expr::Atomic(Metadata::new(), self.as_atom())
123
2142
    }
124

            
125
    /// Returns the top level `Expression` to add to the model.
126
2340
    pub fn top_level_expr(&self) -> Expr {
127
2340
        self.aux_decl.clone()
128
2340
    }
129

            
130
    /// Returns the new `SymbolTable`, modified to contain this auxiliary variable in the symbol table.
131
2340
    pub fn symbols(&self) -> SymbolTable {
132
2340
        self.symbols.clone()
133
2340
    }
134

            
135
    /// Returns the name of the auxiliary variable.
136
2340
    pub fn aux_name(&self) -> Name {
137
2340
        self.aux_name.clone()
138
2340
    }
139
}