1use std::rc::Rc;
23use crate::ast::{Declaration, SymbolTable};
4use tracing::instrument;
5use uniplate::{Biplate, Uniplate};
67use crate::{
8 ast::{Atom, Domain, Expression as Expr, Name},
9 metadata::Metadata,
10};
1112/// True iff `expr` is an `Atom`.
13pub fn is_atom(expr: &Expr) -> bool {
14matches!(expr, Expr::Atomic(_, _))
15}
1617/// True if `expr` is flat; i.e. it only contains atoms.
18pub fn is_flat(expr: &Expr) -> bool {
19for e in expr.children() {
20if !is_atom(&e) {
21return false;
22 }
23 }
24true
25}
2627/// True if the entire AST is constants.
28pub fn is_all_constant(expression: &Expr) -> bool {
29for atom in expression.universe_bi() {
30match atom {
31 Atom::Literal(_) => {}
32_ => {
33return false;
34 }
35 }
36 }
3738true
39}
4041/// Converts a vector of expressions to a vector of atoms.
42///
43/// # Returns
44///
45/// `Some(Vec<Atom>)` if the vectors direct children expressions are all atomic, otherwise `None`.
46#[allow(dead_code)]
47pub fn expressions_to_atoms(exprs: &Vec<Expr>) -> Option<Vec<Atom>> {
48let mut atoms: Vec<Atom> = vec![];
49for expr in exprs {
50let Expr::Atomic(_, atom) = expr else {
51return None;
52 };
53 atoms.push(atom.clone());
54 }
5556Some(atoms)
57}
5859/// Creates a new auxiliary variable using the given expression.
60///
61/// # Returns
62///
63/// * `None` if `Expr` is a `Atom`, or `Expr` does not have a domain (for example, if it is a `Bubble`).
64///
65/// * `Some(ToAuxVarOutput)` if successful, containing:
66///
67/// + A new symbol table, modified to include the auxiliary variable.
68/// + A new top level expression, containing the declaration of the auxiliary variable.
69/// + A reference to the auxiliary variable to replace the existing expression with.
70///
71#[instrument]
72pub fn to_aux_var(expr: &Expr, symbols: &SymbolTable) -> Option<ToAuxVarOutput> {
73let mut symbols = symbols.clone();
7475// No need to put an atom in an aux_var
76if is_atom(expr) {
77return None;
78 }
7980// Anything that should be bubbled, bubble
81if !expr.is_safe() {
82return None;
83 }
8485let name = symbols.gensym();
8687let Some(domain) = expr.domain_of(&symbols) else {
88tracing::trace!("could not find domain of {}", expr);
89return None;
90 };
9192 symbols.insert(Rc::new(Declaration::new_var(name.clone(), domain.clone())))?;
93Some(ToAuxVarOutput {
94 aux_name: name.clone(),
95 aux_decl: Expr::AuxDeclaration(Metadata::new(), name, Box::new(expr.clone())),
96 aux_domain: domain,
97 symbols,
98 _unconstructable: (),
99 })
100}
101102/// Output data of `to_aux_var`.
103pub struct ToAuxVarOutput {
104 aux_name: Name,
105 aux_decl: Expr,
106#[allow(dead_code)] // TODO: aux_domain should be used soon, try removing this pragma
107aux_domain: Domain,
108 symbols: SymbolTable,
109 _unconstructable: (),
110}
111112impl ToAuxVarOutput {
113/// Returns the new auxiliary variable as an `Atom`.
114pub fn as_atom(&self) -> Atom {
115 Atom::Reference(self.aux_name())
116 }
117118/// Returns the new auxiliary variable as an `Expression`.
119 ///
120 /// This expression will have default `Metadata`.
121pub fn as_expr(&self) -> Expr {
122 Expr::Atomic(Metadata::new(), self.as_atom())
123 }
124125/// Returns the top level `Expression` to add to the model.
126pub fn top_level_expr(&self) -> Expr {
127self.aux_decl.clone()
128 }
129130/// Returns the new `SymbolTable`, modified to contain this auxiliary variable in the symbol table.
131pub fn symbols(&self) -> SymbolTable {
132self.symbols.clone()
133 }
134135/// Returns the name of the auxiliary variable.
136pub fn aux_name(&self) -> Name {
137self.aux_name.clone()
138 }
139}