Skip to main content

conjure_cp_core/ast/
literals.rs

1use itertools::Itertools;
2use serde::{Deserialize, Serialize};
3use std::fmt::{Display, Formatter};
4use std::hash::Hash;
5use ustr::Ustr;
6
7use super::{
8    Atom, Domain, DomainPtr, Expression, GroundDomain, Metadata, Moo, PartitionAttr, Range,
9    ReturnType, SetAttr, Typeable, domains::HasDomain, domains::Int, records::FieldValue,
10};
11use crate::ast::domains::{MSetAttr, SequenceAttr};
12use crate::ast::pretty::pretty_vec;
13use crate::bug;
14use polyquine::Quine;
15use uniplate::{Biplate, Tree, Uniplate};
16
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate, Hash, Quine)]
#[uniplate(walk_into=[AbstractLiteral<Literal>])]
#[biplate(to=Atom)]
#[biplate(to=AbstractLiteral<Literal>)]
#[biplate(to=AbstractLiteral<Expression>)]
#[biplate(to=FieldValue<Literal>)]
#[biplate(to=FieldValue<Expression>)]
#[biplate(to=Expression)]
#[path_prefix(conjure_cp::ast)]
/// A literal value, equivalent to constants in Conjure.
///
/// Scalar constants are stored directly; compound constants (sets, matrices, tuples, …) are
/// wrapped in an [`AbstractLiteral`] whose elements are themselves `Literal`s.
pub enum Literal {
    /// An integer constant.
    Int(i32),
    /// A boolean constant.
    Bool(bool),
    // The variant name ends in "Literal" because it mirrors the wrapped type; clippy's
    // `enum_variant_names` lint would otherwise flag the repetition.
    /// A compound constant whose elements are literals.
    #[allow(clippy::enum_variant_names)]
    AbstractLiteral(AbstractLiteral<Literal>),
}
34
35impl HasDomain for Literal {
36    fn domain_of(&self) -> DomainPtr {
37        match self {
38            Literal::Int(i) => Domain::int(vec![Range::Single(*i)]),
39            Literal::Bool(_) => Domain::bool(),
40            Literal::AbstractLiteral(abstract_literal) => abstract_literal.domain_of(),
41        }
42    }
43}
44
// The set of possible element types of an AbstractLiteral is deliberately closed, which keeps
// the trait bounds manageable (particularly in the Uniplate/Biplate instances).
/// The element types an [`AbstractLiteral`] may hold.
///
/// In this file the implementors are [`Expression`] and [`Literal`].
pub trait AbstractLiteralValue:
    Clone + Eq + PartialEq + Display + Uniplate + Biplate<FieldValue<Self>> + 'static
{
    /// The domain type stored as the index domain of an `AbstractLiteral::Matrix`.
    type Dom: Clone + Eq + PartialEq + Display + Quine + From<GroundDomain> + Into<DomainPtr>;
}
/// Expression elements use a general `DomainPtr` as their matrix index domain.
impl AbstractLiteralValue for Expression {
    type Dom = DomainPtr;
}
/// Literal (constant) elements use a ground domain as their matrix index domain.
impl AbstractLiteralValue for Literal {
    type Dom = Moo<GroundDomain>;
}
57
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Quine)]
#[path_prefix(conjure_cp::ast)]
/// A compound value whose elements have type `T`.
///
/// `T` is either [`Expression`] or [`Literal`] (see [`AbstractLiteralValue`]); with `Literal`
/// elements the whole value is a constant.
pub enum AbstractLiteral<T: AbstractLiteralValue> {
    /// A set, displayed as `{a,b,c}`.
    Set(Vec<T>),

    /// A multiset, displayed as `mset(a,b,c)`.
    MSet(Vec<T>),

    /// A 1 dimensional matrix slice with an index domain, displayed as `[a,b,c;index_domain]`.
    ///
    /// n-dimensional matrices are represented as matrices whose elements are matrices.
    Matrix(Vec<T>, T::Dom),

    /// A tuple, displayed as `(a,b,c)`.
    Tuple(Vec<T>),

    /// A record: a list of named field values.
    Record(Vec<FieldValue<T>>),

    /// A sequence, displayed as `sequence(a,b,c)`.
    Sequence(Vec<T>),

    /// A function given by its `(input, output)` pairs, displayed as `function(a --> b,c --> d)`.
    Function(Vec<(T, T)>),

    /// A variant value. Only the single name–value pair that was chosen is stored.
    Variant(Moo<FieldValue<T>>),

    /// A partition: a list of parts, each part being a list of values.
    Partition(Vec<Vec<T>>),
    /// A relation: a list of tuples, each tuple given as a list of values.
    Relation(Vec<Vec<T>>),
}
84
85// TODO: use HasDomain instead once Expression::domain_of returns Domain not Option<Domain>
86impl AbstractLiteral<Expression> {
87    pub fn domain_of(&self) -> Option<DomainPtr> {
88        match self {
89            AbstractLiteral::Set(items) => {
90                // ensure that all items have a domain, or return None
91                let item_domains: Vec<DomainPtr> = items
92                    .iter()
93                    .map(|x| x.domain_of())
94                    .collect::<Option<Vec<DomainPtr>>>()?;
95
96                // union all item domains together
97                let mut item_domain_iter = item_domains.iter().cloned();
98                let first_item = item_domain_iter.next()?;
99                let item_domain = item_domains
100                    .iter()
101                    .try_fold(first_item, |x, y| x.union(y))
102                    .expect("taking the union of all item domains of a set literal should succeed");
103
104                Some(Domain::set(SetAttr::<Int>::default(), item_domain))
105            }
106
107            AbstractLiteral::MSet(items) => {
108                // ensure that all items have a domain, or return None
109                let item_domains: Vec<DomainPtr> = items
110                    .iter()
111                    .map(|x| x.domain_of())
112                    .collect::<Option<Vec<DomainPtr>>>()?;
113
114                // union all item domains together
115                let mut item_domain_iter = item_domains.iter().cloned();
116                let first_item = item_domain_iter.next()?;
117                let item_domain = item_domains
118                    .iter()
119                    .try_fold(first_item, |x, y| x.union(y))
120                    .expect("taking the union of all item domains of a set literal should succeed");
121
122                Some(Domain::mset(MSetAttr::<Int>::default(), item_domain))
123            }
124
125            AbstractLiteral::Sequence(elems) => {
126                let item_domains: Vec<DomainPtr> = elems
127                    .iter()
128                    .map(|x| x.domain_of())
129                    .collect::<Option<Vec<DomainPtr>>>()?;
130
131                // Get the union of all domains in the sequence.
132                // i.e. if <(1..3), (1..3), (5), (8..9)> then seq dom is (1..3, 5, 8..9)
133                let mut item_domain_iter = item_domains.iter().cloned();
134                let first_item = item_domain_iter.next()?;
135                let item_domain = item_domains
136                    .iter()
137                    .try_fold(first_item, |x, y| x.union(y))
138                    .expect("taking the union of all item domains of a set literal should succeed");
139
140                Some(Domain::sequence(
141                    SequenceAttr::<Int>::default(),
142                    item_domain,
143                ))
144            }
145
146            AbstractLiteral::Partition(items) => {
147                // Flatten the Vec<Vec< into a single vec
148                // ensure that all elemes in each part have a domain, or return None
149
150                let item_domains: Vec<DomainPtr> = items
151                    .iter()
152                    .flatten()
153                    .map(|x| x.domain_of())
154                    .collect::<Option<Vec<DomainPtr>>>()?;
155
156                // union all item domains together
157                let mut item_domain_iter = item_domains.iter().cloned();
158                let first_item = item_domain_iter.next()?;
159                let item_domain = item_domains
160                    .iter()
161                    .try_fold(first_item, |x, y| x.union(y))
162                    .expect("taking the union of all item domains of a partition literal should succeed");
163
164                Some(Domain::partition(
165                    PartitionAttr::<Int>::default(),
166                    item_domain,
167                ))
168            }
169
170            AbstractLiteral::Matrix(items, _) => {
171                // ensure that all items have a domain, or return None
172                let item_domains = items
173                    .iter()
174                    .map(|x| x.domain_of())
175                    .collect::<Option<Vec<DomainPtr>>>()?;
176
177                // union all item domains together
178                let mut item_domain_iter = item_domains.iter().cloned();
179
180                let first_item = item_domain_iter.next()?;
181
182                let item_domain = item_domains
183                    .iter()
184                    .try_fold(first_item, |x, y| x.union(y))
185                    .expect(
186                        "taking the union of all item domains of a matrix literal should succeed",
187                    );
188
189                let mut new_index_domain = vec![];
190
191                // flatten index domains of n-d matrix into list
192                let mut e = Expression::AbstractLiteral(Metadata::new(), self.clone());
193                while let Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, idx)) = e {
194                    assert!(
195                        idx.as_matrix().is_none(),
196                        "n-dimensional matrix literals should be represented as a matrix inside a matrix, got {idx}"
197                    );
198                    new_index_domain.push(idx);
199                    e = elems[0].clone();
200                }
201                Some(Domain::matrix(item_domain, new_index_domain))
202            }
203            AbstractLiteral::Tuple(_) => None,
204            AbstractLiteral::Record(_) => None,
205            AbstractLiteral::Function(_) => None,
206            AbstractLiteral::Variant(_) => None,
207            AbstractLiteral::Relation(_) => None,
208        }
209    }
210}
211
212impl HasDomain for AbstractLiteral<Literal> {
213    fn domain_of(&self) -> DomainPtr {
214        Domain::from_literal_vec(&[Literal::AbstractLiteral(self.clone())])
215            .expect("abstract literals should be correctly typed")
216    }
217}
218
219impl Typeable for AbstractLiteral<Expression> {
220    fn return_type(&self) -> ReturnType {
221        match self {
222            AbstractLiteral::Set(items) if items.is_empty() => {
223                ReturnType::Set(Box::new(ReturnType::Unknown))
224            }
225            AbstractLiteral::Set(items) => {
226                let item_type = items[0].return_type();
227
228                // if any items do not have a type, return none.
229                let item_types: Vec<ReturnType> = items.iter().map(|x| x.return_type()).collect();
230
231                assert!(
232                    item_types.iter().all(|x| x == &item_type),
233                    "all items in a set should have the same type"
234                );
235
236                ReturnType::Set(Box::new(item_type))
237            }
238            AbstractLiteral::MSet(items) if items.is_empty() => {
239                ReturnType::MSet(Box::new(ReturnType::Unknown))
240            }
241            AbstractLiteral::MSet(items) => {
242                let item_type = items[0].return_type();
243
244                // if any items do not have a type, return none.
245                let item_types: Vec<ReturnType> = items.iter().map(|x| x.return_type()).collect();
246
247                assert!(
248                    item_types.iter().all(|x| x == &item_type),
249                    "all items in a set should have the same type"
250                );
251
252                ReturnType::MSet(Box::new(item_type))
253            }
254            AbstractLiteral::Sequence(items) if items.is_empty() => {
255                ReturnType::Sequence(Box::new(ReturnType::Unknown))
256            }
257            AbstractLiteral::Sequence(items) => {
258                let item_type = items[0].return_type();
259
260                // if any items do not have a type, return none.
261                let item_types: Vec<ReturnType> = items.iter().map(|x| x.return_type()).collect();
262
263                assert!(
264                    item_types.iter().all(|x| x == &item_type),
265                    "all items in a sequence should have the same type"
266                );
267
268                ReturnType::Sequence(Box::new(item_type))
269            }
270            AbstractLiteral::Partition(items) if items.is_empty() || items[0].is_empty() => {
271                ReturnType::Partition(Box::new(ReturnType::Unknown))
272            }
273            AbstractLiteral::Partition(items) => {
274                let item_type = items[0][0].return_type();
275
276                // if any items do not have a type, return none.
277                let item_types: Vec<ReturnType> =
278                    items.iter().flatten().map(|x| x.return_type()).collect();
279
280                assert!(
281                    item_types.iter().all(|x| x == &item_type),
282                    "all items in every part of a partition should have the same type"
283                );
284
285                ReturnType::Partition(Box::new(item_type))
286            }
287            AbstractLiteral::Matrix(items, _) if items.is_empty() => {
288                ReturnType::Matrix(Box::new(ReturnType::Unknown))
289            }
290            AbstractLiteral::Matrix(items, _) => {
291                let item_type = items[0].return_type();
292
293                // if any items do not have a type, return none.
294                let item_types: Vec<ReturnType> = items.iter().map(|x| x.return_type()).collect();
295
296                assert!(
297                    item_types.iter().all(|x| x == &item_type),
298                    "all items in a matrix should have the same type. items: {items} types: {types:#?}",
299                    items = pretty_vec(items),
300                    types = items
301                        .iter()
302                        .map(|x| x.return_type())
303                        .collect::<Vec<ReturnType>>()
304                );
305
306                ReturnType::Matrix(Box::new(item_type))
307            }
308            AbstractLiteral::Tuple(items) => {
309                let mut item_types = vec![];
310                for item in items {
311                    item_types.push(item.return_type());
312                }
313                ReturnType::Tuple(item_types)
314            }
315            AbstractLiteral::Record(items) => {
316                let mut item_types = vec![];
317                for item in items {
318                    item_types.push(item.value.return_type());
319                }
320                ReturnType::Record(item_types)
321            }
322            AbstractLiteral::Function(items) => {
323                if items.is_empty() {
324                    return ReturnType::Function(
325                        Box::new(ReturnType::Unknown),
326                        Box::new(ReturnType::Unknown),
327                    );
328                }
329
330                // Check that all items have the same return type
331                let (x1, y1) = &items[0];
332                let (t1, t2) = (x1.return_type(), y1.return_type());
333                for (x, y) in items {
334                    let (tx, ty) = (x.return_type(), y.return_type());
335                    if tx != t1 {
336                        bug!("Expected {t1}, got {x}: {tx}");
337                    }
338                    if ty != t2 {
339                        bug!("Expected {t2}, got {y}: {ty}");
340                    }
341                }
342
343                ReturnType::Function(Box::new(t1), Box::new(t2))
344            }
345            AbstractLiteral::Variant(item) => {
346                // Variants hold multiple possible types. In the case of a literal we know which type it chose
347                ReturnType::Variant(vec![item.value.return_type()])
348            }
349            AbstractLiteral::Relation(items) => {
350                if items.is_empty() {
351                    return ReturnType::Relation(vec![ReturnType::Unknown]);
352                }
353                let mut item_types = vec![];
354                let x1 = &items[0];
355                let size = x1.len();
356                for item in x1 {
357                    item_types.push(item.return_type());
358                }
359                for x in items {
360                    if x.len() != size {
361                        let strs = item_types.iter().map(|x| format!("{}", x)).join(",");
362                        bug!("Expected ({strs}) with length {size}, got size {}", x.len());
363                    }
364                    for i in 1..size {
365                        if let Some(new_type) = x.get(i)
366                            && let Some(old_type) = item_types.get(i)
367                            && new_type.return_type() != *old_type
368                        {
369                            bug!("Expected {old_type}, got {new_type}");
370                        }
371                    }
372                }
373                ReturnType::Relation(item_types)
374            }
375        }
376    }
377}
378
379impl<T> AbstractLiteral<T>
380where
381    T: AbstractLiteralValue,
382{
383    /// Creates a matrix with elements `elems`, with domain `int(1..)`.
384    ///
385    /// This acts as a variable sized list.
386    pub fn matrix_implied_indices(elems: Vec<T>) -> Self {
387        AbstractLiteral::Matrix(elems, GroundDomain::Int(vec![Range::UnboundedR(1)]).into())
388    }
389
390    /// If the AbstractLiteral is a list, returns its elements.
391    ///
392    /// A list is any a matrix with the domain `int(1..)`. This includes matrix literals without
393    /// any explicitly specified domain.
394    pub fn unwrap_list(&self) -> Option<&Vec<T>> {
395        let AbstractLiteral::Matrix(elems, domain) = self else {
396            return None;
397        };
398
399        let domain: DomainPtr = domain.clone().into();
400        let Some(GroundDomain::Int(ranges)) = domain.as_ground() else {
401            return None;
402        };
403
404        let [Range::UnboundedR(1)] = ranges[..] else {
405            return None;
406        };
407
408        Some(elems)
409    }
410}
411
412impl<T> Display for AbstractLiteral<T>
413where
414    T: AbstractLiteralValue,
415{
416    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
417        match self {
418            AbstractLiteral::Set(elems) => {
419                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
420                write!(f, "{{{elems_str}}}")
421            }
422            AbstractLiteral::MSet(elems) => {
423                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
424                write!(f, "mset({elems_str})")
425            }
426            AbstractLiteral::Matrix(elems, index_domain) => {
427                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
428                write!(f, "[{elems_str};{index_domain}]")
429            }
430            AbstractLiteral::Tuple(elems) => {
431                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
432                write!(f, "({elems_str})")
433            }
434            AbstractLiteral::Sequence(elems) => {
435                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
436                write!(f, "sequence({elems_str})")
437            }
438            AbstractLiteral::Partition(parts) => {
439                let elems_str: String = parts
440                    .iter()
441                    .map(|inner| {
442                        let elems_str = inner.iter().map(|x| format!("{x}")).join(",");
443                        format!("{{{}}}", elems_str)
444                    })
445                    .join(", ");
446
447                write!(f, "partition({elems_str})")
448            }
449            AbstractLiteral::Record(entries) => {
450                let entries_str: String = entries
451                    .iter()
452                    .map(|entry| format!("{} = {}", entry.name, entry.value))
453                    .join(",");
454                write!(f, "record {{{entries_str}}}")
455            }
456            AbstractLiteral::Function(entries) => {
457                let entries_str: String = entries
458                    .iter()
459                    .map(|entry| format!("{} --> {}", entry.0, entry.1))
460                    .join(",");
461                write!(f, "function({entries_str})")
462            }
463            AbstractLiteral::Variant(entry) => {
464                write!(f, "variant{{{} = {}}}", entry.name, entry.value)
465            }
466            AbstractLiteral::Relation(elems) => {
467                let elems_str: String = elems
468                    .iter()
469                    .map(|x| format!("({})", x.iter().map(|x| format!("{x}")).join(",")))
470                    .join(",");
471                write!(f, "relation({elems_str})")
472            }
473        }
474    }
475}
476
/// Uniplate over abstract literals: the children of an `AbstractLiteral<T>` are the
/// `AbstractLiteral<T>` values reachable through its contents. Each variant delegates to the
/// `Biplate<AbstractLiteral<T>>` implementation of its contents.
impl<T> Uniplate for AbstractLiteral<T>
where
    T: AbstractLiteralValue + Biplate<AbstractLiteral<T>>,
{
    fn uniplate(&self) -> (Tree<Self>, Box<dyn Fn(Tree<Self>) -> Self>) {
        // Walk into the contents (elements of type T, field values, nested vectors), collecting
        // nested abstract literals; each context rebuilds the same variant around the result.
        match self {
            AbstractLiteral::Set(vec) => {
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(vec);
                (f1_tree, Box::new(move |x| AbstractLiteral::Set(f1_ctx(x))))
            }
            AbstractLiteral::MSet(vec) => {
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(vec);
                (f1_tree, Box::new(move |x| AbstractLiteral::MSet(f1_ctx(x))))
            }
            AbstractLiteral::Matrix(elems, index_domain) => {
                // The index domain is not traversed; it is captured and restored unchanged.
                let index_domain = index_domain.clone();
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
                (
                    f1_tree,
                    Box::new(move |x| AbstractLiteral::Matrix(f1_ctx(x), index_domain.clone())),
                )
            }
            AbstractLiteral::Sequence(vec) => {
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(vec);
                (
                    f1_tree,
                    Box::new(move |x| AbstractLiteral::Sequence(f1_ctx(x))),
                )
            }
            AbstractLiteral::Tuple(elems) => {
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
                (
                    f1_tree,
                    Box::new(move |x| AbstractLiteral::Tuple(f1_ctx(x))),
                )
            }
            AbstractLiteral::Record(entries) => {
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(entries);
                (
                    f1_tree,
                    Box::new(move |x| AbstractLiteral::Record(f1_ctx(x))),
                )
            }
            AbstractLiteral::Function(entries) => {
                // Flatten the (lhs, rhs) pairs into [lhs0, rhs0, lhs1, rhs1, ...] so the pairs can
                // be traversed as a single Vec<T>; the context re-pairs them in the same order.
                let entry_count = entries.len();
                let flattened: Vec<T> = entries
                    .iter()
                    .flat_map(|(lhs, rhs)| [lhs.clone(), rhs.clone()])
                    .collect();

                let (f1_tree, f1_ctx) =
                    <Vec<T> as Biplate<AbstractLiteral<T>>>::biplate(&flattened);
                (
                    f1_tree,
                    Box::new(move |x| {
                        let rebuilt = f1_ctx(x);
                        // Re-pairing is only meaningful if the element count is unchanged.
                        assert_eq!(
                            rebuilt.len(),
                            entry_count * 2,
                            "number of function literal children should remain unchanged"
                        );

                        let mut iter = rebuilt.into_iter();
                        let mut pairs = Vec::with_capacity(entry_count);
                        while let (Some(lhs), Some(rhs)) = (iter.next(), iter.next()) {
                            pairs.push((lhs, rhs));
                        }

                        AbstractLiteral::Function(pairs)
                    }),
                )
            }
            AbstractLiteral::Variant(entries) => {
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(entries);
                (
                    f1_tree,
                    Box::new(move |x| AbstractLiteral::Variant(f1_ctx(x))),
                )
            }
            AbstractLiteral::Relation(elems) => {
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
                (
                    f1_tree,
                    Box::new(move |x| AbstractLiteral::Relation(f1_ctx(x))),
                )
            }
            AbstractLiteral::Partition(elems) => {
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
                (
                    f1_tree,
                    Box::new(move |x| AbstractLiteral::Partition(f1_ctx(x))),
                )
            }
        }
    }
}
574
/// Biplate from an abstract literal to any `To`.
///
/// If `To` is `AbstractLiteral<U>` itself, the literal is its own single target (`Tree::One`).
/// Otherwise the literal's contents are searched for `To` values by delegating to the
/// `Biplate<To>` implementations of the contained collections and element types.
impl<U, To> Biplate<To> for AbstractLiteral<U>
where
    To: Uniplate,
    U: AbstractLiteralValue + Biplate<AbstractLiteral<U>> + Biplate<To>,
    FieldValue<U>: Biplate<AbstractLiteral<U>> + Biplate<To>,
{
    fn biplate(&self) -> (Tree<To>, Box<dyn Fn(Tree<To>) -> Self>) {
        if std::any::TypeId::of::<To>() == std::any::TypeId::of::<AbstractLiteral<U>>() {
            // `To` is `AbstractLiteral<U>`: this literal is itself the only target, so return
            // `Tree::One(self)` and a context that unwraps it again.

            unsafe {
                // SAFETY: the `TypeId` comparison above guarantees `To` and `AbstractLiteral<U>`
                // are the same type, so this reference cast is an identity conversion.
                let self_to = std::mem::transmute::<&AbstractLiteral<U>, &To>(self).clone();
                let tree = Tree::One(self_to);
                let ctx = Box::new(move |x| {
                    // The context only accepts the `Tree::One` shape produced above; any other
                    // tree shape panics.
                    let Tree::One(x) = x else {
                        panic!();
                    };

                    // SAFETY: same type-equality argument as above, in the reverse direction.
                    std::mem::transmute::<&To, &AbstractLiteral<U>>(&x).clone()
                });

                (tree, ctx)
            }
        } else {
            // Walk into the contents; each context rebuilds the same variant around the result.
            match self {
                AbstractLiteral::Set(vec) => {
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(vec);
                    (f1_tree, Box::new(move |x| AbstractLiteral::Set(f1_ctx(x))))
                }
                AbstractLiteral::MSet(vec) => {
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(vec);
                    (f1_tree, Box::new(move |x| AbstractLiteral::MSet(f1_ctx(x))))
                }
                AbstractLiteral::Matrix(elems, index_domain) => {
                    // The index domain is not traversed; it is captured and restored unchanged.
                    let index_domain = index_domain.clone();
                    let (f1_tree, f1_ctx) = <Vec<U> as Biplate<To>>::biplate(elems);
                    (
                        f1_tree,
                        Box::new(move |x| AbstractLiteral::Matrix(f1_ctx(x), index_domain.clone())),
                    )
                }
                AbstractLiteral::Sequence(vec) => {
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(vec);
                    (
                        f1_tree,
                        Box::new(move |x| AbstractLiteral::Sequence(f1_ctx(x))),
                    )
                }
                AbstractLiteral::Tuple(elems) => {
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(elems);
                    (
                        f1_tree,
                        Box::new(move |x| AbstractLiteral::Tuple(f1_ctx(x))),
                    )
                }
                AbstractLiteral::Record(entries) => {
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(entries);
                    (
                        f1_tree,
                        Box::new(move |x| AbstractLiteral::Record(f1_ctx(x))),
                    )
                }
                AbstractLiteral::Function(entries) => {
                    // Flatten the (lhs, rhs) pairs into [lhs0, rhs0, lhs1, rhs1, ...] so the pairs
                    // can be traversed as a single Vec<U>; the context re-pairs them in order.
                    let entry_count = entries.len();
                    let flattened: Vec<U> = entries
                        .iter()
                        .flat_map(|(lhs, rhs)| [lhs.clone(), rhs.clone()])
                        .collect();

                    let (f1_tree, f1_ctx) = <Vec<U> as Biplate<To>>::biplate(&flattened);
                    (
                        f1_tree,
                        Box::new(move |x| {
                            let rebuilt = f1_ctx(x);
                            // Re-pairing is only meaningful if the element count is unchanged.
                            assert_eq!(
                                rebuilt.len(),
                                entry_count * 2,
                                "number of function literal children should remain unchanged"
                            );

                            let mut iter = rebuilt.into_iter();
                            let mut pairs = Vec::with_capacity(entry_count);
                            while let (Some(lhs), Some(rhs)) = (iter.next(), iter.next()) {
                                pairs.push((lhs, rhs));
                            }

                            AbstractLiteral::Function(pairs)
                        }),
                    )
                }
                AbstractLiteral::Variant(entries) => {
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(entries);
                    (
                        f1_tree,
                        Box::new(move |x| AbstractLiteral::Variant(f1_ctx(x))),
                    )
                }
                AbstractLiteral::Relation(elems) => {
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(elems);
                    (
                        f1_tree,
                        Box::new(move |x| AbstractLiteral::Relation(f1_ctx(x))),
                    )
                }
                AbstractLiteral::Partition(elems) => {
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(elems);
                    (
                        f1_tree,
                        Box::new(move |x| AbstractLiteral::Partition(f1_ctx(x))),
                    )
                }
            }
        }
    }
}
692
693impl TryFrom<Literal> for i32 {
694    type Error = &'static str;
695
696    fn try_from(value: Literal) -> Result<Self, Self::Error> {
697        match value {
698            Literal::Int(i) => Ok(i),
699            _ => Err("Cannot convert non-i32 literal to i32"),
700        }
701    }
702}
703
704impl TryFrom<Box<Literal>> for i32 {
705    type Error = &'static str;
706
707    fn try_from(value: Box<Literal>) -> Result<Self, Self::Error> {
708        (*value).try_into()
709    }
710}
711
712impl TryFrom<&Box<Literal>> for i32 {
713    type Error = &'static str;
714
715    fn try_from(value: &Box<Literal>) -> Result<Self, Self::Error> {
716        TryFrom::<&Literal>::try_from(value.as_ref())
717    }
718}
719
720impl TryFrom<&Moo<Literal>> for i32 {
721    type Error = &'static str;
722
723    fn try_from(value: &Moo<Literal>) -> Result<Self, Self::Error> {
724        TryFrom::<&Literal>::try_from(value.as_ref())
725    }
726}
727
728impl TryFrom<&Literal> for i32 {
729    type Error = &'static str;
730
731    fn try_from(value: &Literal) -> Result<Self, Self::Error> {
732        match value {
733            Literal::Int(i) => Ok(*i),
734            _ => Err("Cannot convert non-i32 literal to i32"),
735        }
736    }
737}
738
739impl TryFrom<Literal> for bool {
740    type Error = &'static str;
741
742    fn try_from(value: Literal) -> Result<Self, Self::Error> {
743        match value {
744            Literal::Bool(b) => Ok(b),
745            _ => Err("Cannot convert non-bool literal to bool"),
746        }
747    }
748}
749
750impl TryFrom<&Literal> for bool {
751    type Error = &'static str;
752
753    fn try_from(value: &Literal) -> Result<Self, Self::Error> {
754        match value {
755            Literal::Bool(b) => Ok(*b),
756            _ => Err("Cannot convert non-bool literal to bool"),
757        }
758    }
759}
760
761impl From<i32> for Literal {
762    fn from(i: i32) -> Self {
763        Literal::Int(i)
764    }
765}
766
767impl From<bool> for Literal {
768    fn from(b: bool) -> Self {
769        Literal::Bool(b)
770    }
771}
772
773impl From<Literal> for Ustr {
774    fn from(value: Literal) -> Self {
775        // TODO: avoid the temporary-allocation of a string by format! here?
776        Ustr::from(&format!("{value}"))
777    }
778}
779
780impl AbstractLiteral<Expression> {
781    /// If all the elements are literals, returns this as an AbstractLiteral<Literal>.
782    /// Otherwise, returns `None`.
783    pub fn into_literals(self) -> Option<AbstractLiteral<Literal>> {
784        match self {
785            AbstractLiteral::Set(elements) => {
786                let literals = elements
787                    .into_iter()
788                    .map(|expr| match expr {
789                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
790                        Expression::AbstractLiteral(_, abslit) => {
791                            Some(Literal::AbstractLiteral(abslit.into_literals()?))
792                        }
793                        _ => None,
794                    })
795                    .collect::<Option<Vec<_>>>()?;
796                Some(AbstractLiteral::Set(literals))
797            }
798            AbstractLiteral::MSet(elements) => {
799                let literals = elements
800                    .into_iter()
801                    .map(|expr| match expr {
802                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
803                        Expression::AbstractLiteral(_, abslit) => {
804                            Some(Literal::AbstractLiteral(abslit.into_literals()?))
805                        }
806                        _ => None,
807                    })
808                    .collect::<Option<Vec<_>>>()?;
809                Some(AbstractLiteral::MSet(literals))
810            }
811            AbstractLiteral::Partition(elems) => {
812                // want to ascertain if every elem in Vec<Vec<Expr>> is a literal. If any are not, return none
813                // otherwise confirm it is an abslit<lit>
814                let mut partition: Vec<Vec<_>> = Vec::new();
815
816                for part in elems {
817                    let literals = part
818                        .into_iter()
819                        .map(|expr| match expr {
820                            Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
821                            Expression::AbstractLiteral(_, abslit) => {
822                                Some(Literal::AbstractLiteral(abslit.into_literals()?))
823                            }
824                            _ => None,
825                        })
826                        .collect::<Option<Vec<_>>>()?;
827
828                    partition.push(literals);
829                }
830
831                Some(AbstractLiteral::Partition(partition))
832            }
833            AbstractLiteral::Matrix(items, domain) => {
834                let mut literals = vec![];
835                for item in items {
836                    let literal = match item {
837                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
838                        Expression::AbstractLiteral(_, abslit) => {
839                            Some(Literal::AbstractLiteral(abslit.into_literals()?))
840                        }
841                        _ => None,
842                    }?;
843                    literals.push(literal);
844                }
845
846                Some(AbstractLiteral::Matrix(literals, domain.resolve()?))
847            }
848            AbstractLiteral::Sequence(elements) => {
849                let literals = elements
850                    .into_iter()
851                    .map(|expr| match expr {
852                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
853                        Expression::AbstractLiteral(_, abslit) => {
854                            Some(Literal::AbstractLiteral(abslit.into_literals()?))
855                        }
856                        _ => None,
857                    })
858                    .collect::<Option<Vec<_>>>()?;
859                Some(AbstractLiteral::Sequence(literals))
860            }
861            AbstractLiteral::Tuple(items) => {
862                let mut literals = vec![];
863                for item in items {
864                    let literal = match item {
865                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
866                        Expression::AbstractLiteral(_, abslit) => {
867                            Some(Literal::AbstractLiteral(abslit.into_literals()?))
868                        }
869                        _ => None,
870                    }?;
871                    literals.push(literal);
872                }
873
874                Some(AbstractLiteral::Tuple(literals))
875            }
876            AbstractLiteral::Record(entries) => {
877                let mut literals = vec![];
878                for entry in entries {
879                    let literal = match entry.value {
880                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
881                        Expression::AbstractLiteral(_, abslit) => {
882                            Some(Literal::AbstractLiteral(abslit.into_literals()?))
883                        }
884                        _ => None,
885                    }?;
886
887                    literals.push((entry.name, literal));
888                }
889                Some(AbstractLiteral::Record(
890                    literals
891                        .into_iter()
892                        .map(|(name, literal)| FieldValue {
893                            name,
894                            value: literal,
895                        })
896                        .collect(),
897                ))
898            }
899            AbstractLiteral::Function(_) => todo!("Implement into_literals for functions"),
900            AbstractLiteral::Variant(entry) => {
901                let literal = match entry.value.clone() {
902                    Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
903                    Expression::AbstractLiteral(_, abslit) => {
904                        Some(Literal::AbstractLiteral(abslit.into_literals()?))
905                    }
906                    _ => None,
907                }?;
908                Some(AbstractLiteral::Variant(Moo::new(FieldValue {
909                    name: entry.name.clone(),
910                    value: literal,
911                })))
912            }
913            AbstractLiteral::Relation(_) => todo!("Implement into_literals for relations"),
914        }
915    }
916}
917
// TODO: Display implementations are still needed for other types as well.
919impl Display for Literal {
920    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
921        match &self {
922            Literal::Int(i) => write!(f, "{i}"),
923            Literal::Bool(b) => write!(f, "{b}"),
924            Literal::AbstractLiteral(l) => write!(f, "{l}"),
925        }
926    }
927}
928
#[cfg(test)]
mod tests {

    use super::*;
    use crate::{into_matrix, matrix};
    use uniplate::Uniplate;

    #[test]
    fn matrix_uniplate_universe() {
        // Check that a uniplate catamorphism reaches every nested matrix. The
        // outer matrix holds five inner matrices, and every matrix is indexed
        // by the bool domain, so six index domains should be collected.
        let inner_matrix =
            Literal::AbstractLiteral(matrix![Literal::Bool(true);Moo::new(GroundDomain::Bool)]);
        let outer_matrix: AbstractLiteral<Literal> = into_matrix![
            vec![inner_matrix; 5];
            Moo::new(GroundDomain::Bool)
        ];

        let expected: Vec<Moo<GroundDomain>> = vec![Moo::new(GroundDomain::Bool); 6];
        let collected: Vec<Moo<GroundDomain>> =
            outer_matrix.cata(&move |node, child_results| {
                // Gather the domains found in children first, then this node's own.
                let mut domains: Vec<Moo<GroundDomain>> =
                    child_results.into_iter().flatten().collect();
                if let AbstractLiteral::Matrix(_, index_domain) = node {
                    domains.push(index_domain);
                }
                domains
            });

        assert_eq!(collected, expected);
    }
}