conjure_core/ast/
literals.rs

1use itertools::Itertools;
2use serde::{Deserialize, Serialize};
3use std::fmt::{Display, Formatter};
4use std::hash::Hash;
5use std::hash::Hasher;
6
7use uniplate::derive::Uniplate;
8use uniplate::{Biplate, Tree, Uniplate};
9
10use super::{records::RecordValue, Atom, Domain, Expression, Range};
11use super::{ReturnType, Typeable};
12
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate, Hash)]
#[uniplate(walk_into=[AbstractLiteral<Literal>])]
#[biplate(to=Atom)]
#[biplate(to=AbstractLiteral<Literal>)]
#[biplate(to=AbstractLiteral<Expression>)]
#[biplate(to=RecordValue<Literal>,walk_into=[AbstractLiteral<Literal>])]
#[biplate(to=RecordValue<Expression>)]
#[biplate(to=Expression)]
/// A literal value, equivalent to constants in Conjure.
pub enum Literal {
    /// An integer literal.
    Int(i32),
    /// A boolean literal.
    Bool(bool),
    /// A compound literal (set, matrix, tuple or record) whose elements are themselves literals.
    AbstractLiteral(AbstractLiteral<Literal>),
}
27
// The possible element types of an AbstractLiteral form a closed world (only the types that
// implement this trait). This keeps the trait bounds manageable, particularly in the Uniplate
// instances.
/// Marker trait for the element types an [`AbstractLiteral`] may hold.
///
/// It bundles every bound the element type needs. It is implemented for [`Expression`] and
/// [`Literal`] only.
pub trait AbstractLiteralValue:
    Clone + Eq + PartialEq + Display + Uniplate + Biplate<RecordValue<Self>> + 'static
{
}
impl AbstractLiteralValue for Expression {}
impl AbstractLiteralValue for Literal {}
35
/// A compound literal whose elements have type `T` (an [`Expression`] or a [`Literal`]).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AbstractLiteral<T: AbstractLiteralValue> {
    /// A set literal.
    Set(Vec<T>),

    /// A 1 dimensional matrix slice with an index domain.
    Matrix(Vec<T>, Domain),

    /// A tuple of values.
    Tuple(Vec<T>),

    /// A record: a sequence of named entries, each pairing a name with a value.
    Record(Vec<RecordValue<T>>),
}
48
49impl Typeable for Literal {
50    fn return_type(&self) -> Option<ReturnType> {
51        match self {
52            Literal::Int(_) => Some(ReturnType::Int),
53            Literal::Bool(_) => Some(ReturnType::Bool),
54            Literal::AbstractLiteral(a) => a.return_type(),
55        }
56    }
57}
58
59// TODO: handle tuples and records
60impl<T: AbstractLiteralValue + Typeable> Typeable for AbstractLiteral<T> {
61    fn return_type(&self) -> Option<ReturnType> {
62        match self {
63            AbstractLiteral::Set(vector) => {
64                Some(ReturnType::Set(Box::new(vector.first()?.return_type()?)))
65            }
66            AbstractLiteral::Matrix(vector, _) => {
67                Some(ReturnType::Matrix(Box::new(vector.first()?.return_type()?)))
68            }
69            _ => None,
70        }
71    }
72}
73
74impl<T> AbstractLiteral<T>
75where
76    T: AbstractLiteralValue,
77{
78    /// Creates a matrix with elements `elems`, with domain `int(1..)`.
79    ///
80    /// This acts as a variable sized list.
81    pub fn matrix_implied_indices(elems: Vec<T>) -> Self {
82        AbstractLiteral::Matrix(elems, Domain::IntDomain(vec![Range::UnboundedR(1)]))
83    }
84
85    /// If the AbstractLiteral is a list, returns its elements.
86    ///
87    /// A list is any a matrix with the domain `int(1..)`. This includes matrix literals without
88    /// any explicitly specified domain.
89    pub fn unwrap_list(&self) -> Option<&Vec<T>> {
90        let AbstractLiteral::Matrix(elems, Domain::IntDomain(ranges)) = self else {
91            return None;
92        };
93
94        let [Range::UnboundedR(1)] = ranges[..] else {
95            return None;
96        };
97
98        Some(elems)
99    }
100}
101
102impl<T> Display for AbstractLiteral<T>
103where
104    T: AbstractLiteralValue,
105{
106    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
107        match self {
108            AbstractLiteral::Set(elems) => {
109                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
110                write!(f, "{{{elems_str}}}")
111            }
112            AbstractLiteral::Matrix(elems, index_domain) => {
113                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
114                write!(f, "[{elems_str};{index_domain}]")
115            }
116            AbstractLiteral::Tuple(elems) => {
117                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
118                write!(f, "({elems_str})")
119            }
120            AbstractLiteral::Record(entries) => {
121                let entries_str: String = entries
122                    .iter()
123                    .map(|entry| format!("{}: {}", entry.name, entry.value))
124                    .join(",");
125                write!(f, "{{{entries_str}}}")
126            }
127        }
128    }
129}
130
131impl Hash for AbstractLiteral<Literal> {
132    fn hash<H: Hasher>(&self, state: &mut H) {
133        match self {
134            AbstractLiteral::Set(vec) => {
135                0.hash(state);
136                vec.hash(state);
137            }
138            AbstractLiteral::Matrix(elems, index_domain) => {
139                1.hash(state);
140                elems.hash(state);
141                index_domain.hash(state);
142            }
143            AbstractLiteral::Tuple(elems) => {
144                2.hash(state);
145                elems.hash(state);
146            }
147            AbstractLiteral::Record(entries) => {
148                3.hash(state);
149                entries.hash(state);
150            }
151        }
152    }
153}
154
155impl<T> Uniplate for AbstractLiteral<T>
156where
157    T: AbstractLiteralValue + Biplate<AbstractLiteral<T>>,
158{
159    fn uniplate(&self) -> (Tree<Self>, Box<dyn Fn(Tree<Self>) -> Self>) {
160        // walking into T
161        match self {
162            AbstractLiteral::Set(vec) => {
163                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(vec);
164                (f1_tree, Box::new(move |x| AbstractLiteral::Set(f1_ctx(x))))
165            }
166            AbstractLiteral::Matrix(elems, index_domain) => {
167                let index_domain = index_domain.clone();
168                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
169                (
170                    f1_tree,
171                    Box::new(move |x| AbstractLiteral::Matrix(f1_ctx(x), index_domain.clone())),
172                )
173            }
174            AbstractLiteral::Tuple(elems) => {
175                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
176                (
177                    f1_tree,
178                    Box::new(move |x| AbstractLiteral::Tuple(f1_ctx(x))),
179                )
180            }
181            AbstractLiteral::Record(entries) => {
182                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(entries);
183                (
184                    f1_tree,
185                    Box::new(move |x| AbstractLiteral::Record(f1_ctx(x))),
186                )
187            }
188        }
189    }
190}
191
192impl<U, To> Biplate<To> for AbstractLiteral<U>
193where
194    To: Uniplate,
195    U: AbstractLiteralValue + Biplate<AbstractLiteral<U>> + Biplate<To>,
196    RecordValue<U>: Biplate<AbstractLiteral<U>> + Biplate<To>,
197{
198    fn biplate(&self) -> (Tree<To>, Box<dyn Fn(Tree<To>) -> Self>) {
199        if std::any::TypeId::of::<To>() == std::any::TypeId::of::<AbstractLiteral<U>>() {
200            // To ==From => return One(self)
201
202            unsafe {
203                // SAFETY: asserted the type equality above
204                let self_to = std::mem::transmute::<&AbstractLiteral<U>, &To>(self).clone();
205                let tree = Tree::One(self_to.clone());
206                let ctx = Box::new(move |x| {
207                    let Tree::One(x) = x else {
208                        panic!();
209                    };
210
211                    std::mem::transmute::<&To, &AbstractLiteral<U>>(&x).clone()
212                });
213
214                (tree, ctx)
215            }
216        } else {
217            // walking into T
218            match self {
219                AbstractLiteral::Set(vec) => {
220                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(vec);
221                    (f1_tree, Box::new(move |x| AbstractLiteral::Set(f1_ctx(x))))
222                }
223                AbstractLiteral::Matrix(elems, index_domain) => {
224                    let index_domain = index_domain.clone();
225                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(elems);
226                    (
227                        f1_tree,
228                        Box::new(move |x| AbstractLiteral::Matrix(f1_ctx(x), index_domain.clone())),
229                    )
230                }
231                AbstractLiteral::Tuple(elems) => {
232                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(elems);
233                    (
234                        f1_tree,
235                        Box::new(move |x| AbstractLiteral::Tuple(f1_ctx(x))),
236                    )
237                }
238                AbstractLiteral::Record(entries) => {
239                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(entries);
240                    (
241                        f1_tree,
242                        Box::new(move |x| AbstractLiteral::Record(f1_ctx(x))),
243                    )
244                }
245            }
246        }
247    }
248}
249
250impl TryFrom<Literal> for i32 {
251    type Error = &'static str;
252
253    fn try_from(value: Literal) -> Result<Self, Self::Error> {
254        match value {
255            Literal::Int(i) => Ok(i),
256            _ => Err("Cannot convert non-i32 literal to i32"),
257        }
258    }
259}
260
261impl TryFrom<&Literal> for i32 {
262    type Error = &'static str;
263
264    fn try_from(value: &Literal) -> Result<Self, Self::Error> {
265        match value {
266            Literal::Int(i) => Ok(*i),
267            _ => Err("Cannot convert non-i32 literal to i32"),
268        }
269    }
270}
271
272impl TryFrom<Literal> for bool {
273    type Error = &'static str;
274
275    fn try_from(value: Literal) -> Result<Self, Self::Error> {
276        match value {
277            Literal::Bool(b) => Ok(b),
278            _ => Err("Cannot convert non-bool literal to bool"),
279        }
280    }
281}
282
283impl TryFrom<&Literal> for bool {
284    type Error = &'static str;
285
286    fn try_from(value: &Literal) -> Result<Self, Self::Error> {
287        match value {
288            Literal::Bool(b) => Ok(*b),
289            _ => Err("Cannot convert non-bool literal to bool"),
290        }
291    }
292}
293
294impl From<i32> for Literal {
295    fn from(i: i32) -> Self {
296        Literal::Int(i)
297    }
298}
299
300impl From<bool> for Literal {
301    fn from(b: bool) -> Self {
302        Literal::Bool(b)
303    }
304}
305
306impl AbstractLiteral<Expression> {
307    /// If all the elements are literals, returns this as an AbstractLiteral<Literal>.
308    /// Otherwise, returns `None`.
309    pub fn as_literals(self) -> Option<AbstractLiteral<Literal>> {
310        match self {
311            AbstractLiteral::Set(_) => todo!(),
312            AbstractLiteral::Matrix(items, domain) => {
313                let mut literals = vec![];
314                for item in items {
315                    let literal = match item {
316                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
317                        Expression::AbstractLiteral(_, abslit) => {
318                            Some(Literal::AbstractLiteral(abslit.as_literals()?))
319                        }
320                        _ => None,
321                    }?;
322                    literals.push(literal);
323                }
324
325                Some(AbstractLiteral::Matrix(literals, domain))
326            }
327            AbstractLiteral::Tuple(items) => {
328                let mut literals = vec![];
329                for item in items {
330                    let literal = match item {
331                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
332                        Expression::AbstractLiteral(_, abslit) => {
333                            Some(Literal::AbstractLiteral(abslit.as_literals()?))
334                        }
335                        _ => None,
336                    }?;
337                    literals.push(literal);
338                }
339
340                Some(AbstractLiteral::Tuple(literals))
341            }
342            AbstractLiteral::Record(entries) => {
343                let mut literals = vec![];
344                for entry in entries {
345                    let literal = match entry.value {
346                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
347                        Expression::AbstractLiteral(_, abslit) => {
348                            Some(Literal::AbstractLiteral(abslit.as_literals()?))
349                        }
350                        _ => None,
351                    }?;
352
353                    literals.push((entry.name, literal));
354                }
355                Some(AbstractLiteral::Record(
356                    literals
357                        .into_iter()
358                        .map(|(name, literal)| RecordValue {
359                            name,
360                            value: literal,
361                        })
362                        .collect(),
363                ))
364            }
365        }
366    }
367}
368
369// need display implementations for other types as well
370impl Display for Literal {
371    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
372        match &self {
373            Literal::Int(i) => write!(f, "{}", i),
374            Literal::Bool(b) => write!(f, "{}", b),
375            Literal::AbstractLiteral(l) => write!(f, "{:?}", l),
376        }
377    }
378}
379
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::{into_matrix, matrix};
    use uniplate::Uniplate;

    /// Checks that a uniplate fold over a nested matrix literal visits every matrix in it.
    ///
    /// The outer matrix is built from 5 copies of an inner matrix. The test expects to collect 6
    /// `Domain::BoolDomain` index domains, one for the outer matrix and one per inner matrix.
    #[test]
    fn matrix_uniplate_universe() {
        // Can we traverse through matrices with uniplate?
        let my_matrix: AbstractLiteral<Literal> = into_matrix![
            vec![Literal::AbstractLiteral(matrix![Literal::Bool(true);Domain::BoolDomain]); 5];
            Domain::BoolDomain
        ];

        let expected_index_domains = vec![Domain::BoolDomain; 6];
        // Each node's result is the concatenation of its children's results, followed by its
        // own index domain if the node is a matrix.
        let actual_index_domains: Vec<Domain> = my_matrix.cata(Arc::new(move |elem, children| {
            let mut res = vec![];
            res.extend(children.into_iter().flatten());
            if let AbstractLiteral::Matrix(_, index_domain) = elem {
                res.push(index_domain);
            }

            res
        }));

        assert_eq!(actual_index_domains, expected_index_domains);
    }
}