conjure_cp_core/ast/
literals.rs

1use itertools::Itertools;
2use serde::{Deserialize, Serialize};
3use std::fmt::{Display, Formatter};
4use std::hash::Hash;
5use ustr::Ustr;
6
7use super::{
8    Atom, Domain, DomainPtr, Expression, GroundDomain, Metadata, Moo, Range, ReturnType, SetAttr,
9    Typeable, domains::HasDomain, domains::Int, records::RecordValue,
10};
11use crate::ast::pretty::pretty_vec;
12use crate::bug;
13use polyquine::Quine;
14use uniplate::{Biplate, Tree, Uniplate};
15
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate, Hash, Quine)]
#[uniplate(walk_into=[AbstractLiteral<Literal>])]
#[biplate(to=Atom)]
#[biplate(to=AbstractLiteral<Literal>)]
#[biplate(to=AbstractLiteral<Expression>)]
#[biplate(to=RecordValue<Literal>)]
#[biplate(to=RecordValue<Expression>)]
#[biplate(to=Expression)]
#[path_prefix(conjure_cp::ast)]
/// A literal value, equivalent to constants in Conjure.
///
/// Compound values are stored as [`AbstractLiteral<Literal>`], so every element of a compound
/// literal is itself a `Literal`.
pub enum Literal {
    /// An integer constant.
    Int(i32),
    /// A boolean constant.
    Bool(bool),
    /// A set, matrix, tuple, record or function literal whose elements are all literals.
    // The variant name ends in the enum's own name, which clippy's `enum_variant_names` lint
    // flags; the name is kept deliberately to mirror the `AbstractLiteral` type it wraps.
    #[allow(clippy::enum_variant_names)]
    AbstractLiteral(AbstractLiteral<Literal>),
}
33
34impl HasDomain for Literal {
35    fn domain_of(&self) -> DomainPtr {
36        match self {
37            Literal::Int(i) => Domain::int(vec![Range::Single(*i)]),
38            Literal::Bool(_) => Domain::bool(),
39            Literal::AbstractLiteral(abstract_literal) => abstract_literal.domain_of(),
40        }
41    }
42}
43
/// The types that may appear as the elements of an [`AbstractLiteral`].
///
/// The set of implementors is deliberately kept closed (only [`Expression`] and [`Literal`] in
/// this file) to keep the trait bounds manageable, particularly in the `Uniplate` instances.
pub trait AbstractLiteralValue:
    Clone + Eq + PartialEq + Display + Uniplate + Biplate<RecordValue<Self>> + 'static
{
    /// The type used to store the index domain of a matrix literal over this element type.
    type Dom: Clone + Eq + PartialEq + Display + Quine + From<GroundDomain> + Into<DomainPtr>;
}
/// Expression-valued abstract literals may carry unresolved index domains.
impl AbstractLiteralValue for Expression {
    type Dom = DomainPtr;
}
/// Literal-valued abstract literals always carry ground index domains.
impl AbstractLiteralValue for Literal {
    type Dom = Moo<GroundDomain>;
}
56
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Quine)]
#[path_prefix(conjure_cp::ast)]
/// A compound literal (set, matrix, tuple, record or function) whose elements are of type `T`.
///
/// `T` is either [`Expression`] (elements may be arbitrary expressions) or [`Literal`]
/// (elements are constants); see [`AbstractLiteralValue`].
pub enum AbstractLiteral<T: AbstractLiteralValue> {
    /// A set literal, displayed as `{a,b,...}`.
    Set(Vec<T>),

    /// A 1 dimensional matrix slice with an index domain.
    ///
    /// An n-dimensional matrix is represented as a matrix whose elements are matrices.
    Matrix(Vec<T>, T::Dom),

    /// A tuple literal, displayed as `(a,b,...)`.
    Tuple(Vec<T>),

    /// A record literal: an ordered list of named fields.
    Record(Vec<RecordValue<T>>),

    /// A function literal given as a list of `(from, to)` pairs, displayed as
    /// `function(from --> to,...)`.
    Function(Vec<(T, T)>),
}
72
73// TODO: use HasDomain instead once Expression::domain_of returns Domain not Option<Domain>
74impl AbstractLiteral<Expression> {
75    pub fn domain_of(&self) -> Option<DomainPtr> {
76        match self {
77            AbstractLiteral::Set(items) => {
78                // ensure that all items have a domain, or return None
79                let item_domains: Vec<DomainPtr> = items
80                    .iter()
81                    .map(|x| x.domain_of())
82                    .collect::<Option<Vec<DomainPtr>>>()?;
83
84                // union all item domains together
85                let mut item_domain_iter = item_domains.iter().cloned();
86                let first_item = item_domain_iter.next()?;
87                let item_domain = item_domains
88                    .iter()
89                    .try_fold(first_item, |x, y| x.union(y))
90                    .expect("taking the union of all item domains of a set literal should succeed");
91
92                Some(Domain::set(SetAttr::<Int>::default(), item_domain))
93            }
94
95            AbstractLiteral::Matrix(items, _) => {
96                // ensure that all items have a domain, or return None
97                let item_domains = items
98                    .iter()
99                    .map(|x| x.domain_of())
100                    .collect::<Option<Vec<DomainPtr>>>()?;
101
102                // union all item domains together
103                let mut item_domain_iter = item_domains.iter().cloned();
104
105                let first_item = item_domain_iter.next()?;
106
107                let item_domain = item_domains
108                    .iter()
109                    .try_fold(first_item, |x, y| x.union(y))
110                    .expect(
111                        "taking the union of all item domains of a matrix literal should succeed",
112                    );
113
114                let mut new_index_domain = vec![];
115
116                // flatten index domains of n-d matrix into list
117                let mut e = Expression::AbstractLiteral(Metadata::new(), self.clone());
118                while let Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, idx)) = e {
119                    assert!(
120                        idx.as_matrix().is_none(),
121                        "n-dimensional matrix literals should be represented as a matrix inside a matrix, got {idx}"
122                    );
123                    new_index_domain.push(idx);
124                    e = elems[0].clone();
125                }
126                Some(Domain::matrix(item_domain, new_index_domain))
127            }
128            AbstractLiteral::Tuple(_) => None,
129            AbstractLiteral::Record(_) => None,
130            AbstractLiteral::Function(_) => None,
131        }
132    }
133}
134
135impl HasDomain for AbstractLiteral<Literal> {
136    fn domain_of(&self) -> DomainPtr {
137        Domain::from_literal_vec(&[Literal::AbstractLiteral(self.clone())])
138            .expect("abstract literals should be correctly typed")
139    }
140}
141
142impl Typeable for AbstractLiteral<Expression> {
143    fn return_type(&self) -> ReturnType {
144        match self {
145            AbstractLiteral::Set(items) if items.is_empty() => {
146                ReturnType::Set(Box::new(ReturnType::Unknown))
147            }
148            AbstractLiteral::Set(items) => {
149                let item_type = items[0].return_type();
150
151                // if any items do not have a type, return none.
152                let item_types: Vec<ReturnType> = items.iter().map(|x| x.return_type()).collect();
153
154                assert!(
155                    item_types.iter().all(|x| x == &item_type),
156                    "all items in a set should have the same type"
157                );
158
159                ReturnType::Set(Box::new(item_type))
160            }
161            AbstractLiteral::Matrix(items, _) if items.is_empty() => {
162                ReturnType::Matrix(Box::new(ReturnType::Unknown))
163            }
164            AbstractLiteral::Matrix(items, _) => {
165                let item_type = items[0].return_type();
166
167                // if any items do not have a type, return none.
168                let item_types: Vec<ReturnType> = items.iter().map(|x| x.return_type()).collect();
169
170                assert!(
171                    item_types.iter().all(|x| x == &item_type),
172                    "all items in a matrix should have the same type. items: {items} types: {types:#?}",
173                    items = pretty_vec(items),
174                    types = items
175                        .iter()
176                        .map(|x| x.return_type())
177                        .collect::<Vec<ReturnType>>()
178                );
179
180                ReturnType::Matrix(Box::new(item_type))
181            }
182            AbstractLiteral::Tuple(items) => {
183                let mut item_types = vec![];
184                for item in items {
185                    item_types.push(item.return_type());
186                }
187                ReturnType::Tuple(item_types)
188            }
189            AbstractLiteral::Record(items) => {
190                let mut item_types = vec![];
191                for item in items {
192                    item_types.push(item.value.return_type());
193                }
194                ReturnType::Record(item_types)
195            }
196            AbstractLiteral::Function(items) => {
197                if items.is_empty() {
198                    return ReturnType::Function(
199                        Box::new(ReturnType::Unknown),
200                        Box::new(ReturnType::Unknown),
201                    );
202                }
203
204                // Check that all items have the same return type
205                let (x1, y1) = &items[0];
206                let (t1, t2) = (x1.return_type(), y1.return_type());
207                for (x, y) in items {
208                    let (tx, ty) = (x.return_type(), y.return_type());
209                    if tx != t1 {
210                        bug!("Expected {t1}, got {x}: {tx}");
211                    }
212                    if ty != t2 {
213                        bug!("Expected {t2}, got {y}: {ty}");
214                    }
215                }
216
217                ReturnType::Function(Box::new(t1), Box::new(t2))
218            }
219        }
220    }
221}
222
223impl<T> AbstractLiteral<T>
224where
225    T: AbstractLiteralValue,
226{
227    /// Creates a matrix with elements `elems`, with domain `int(1..)`.
228    ///
229    /// This acts as a variable sized list.
230    pub fn matrix_implied_indices(elems: Vec<T>) -> Self {
231        AbstractLiteral::Matrix(elems, GroundDomain::Int(vec![Range::UnboundedR(1)]).into())
232    }
233
234    /// If the AbstractLiteral is a list, returns its elements.
235    ///
236    /// A list is any a matrix with the domain `int(1..)`. This includes matrix literals without
237    /// any explicitly specified domain.
238    pub fn unwrap_list(&self) -> Option<&Vec<T>> {
239        let AbstractLiteral::Matrix(elems, domain) = self else {
240            return None;
241        };
242
243        let domain: DomainPtr = domain.clone().into();
244        let Some(GroundDomain::Int(ranges)) = domain.as_ground() else {
245            return None;
246        };
247
248        let [Range::UnboundedR(1)] = ranges[..] else {
249            return None;
250        };
251
252        Some(elems)
253    }
254}
255
256impl<T> Display for AbstractLiteral<T>
257where
258    T: AbstractLiteralValue,
259{
260    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
261        match self {
262            AbstractLiteral::Set(elems) => {
263                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
264                write!(f, "{{{elems_str}}}")
265            }
266            AbstractLiteral::Matrix(elems, index_domain) => {
267                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
268                write!(f, "[{elems_str};{index_domain}]")
269            }
270            AbstractLiteral::Tuple(elems) => {
271                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
272                write!(f, "({elems_str})")
273            }
274            AbstractLiteral::Record(entries) => {
275                let entries_str: String = entries
276                    .iter()
277                    .map(|entry| format!("{}: {}", entry.name, entry.value))
278                    .join(",");
279                write!(f, "{{{entries_str}}}")
280            }
281            AbstractLiteral::Function(entries) => {
282                let entries_str: String = entries
283                    .iter()
284                    .map(|entry| format!("{} --> {}", entry.0, entry.1))
285                    .join(",");
286                write!(f, "function({entries_str})")
287            }
288        }
289    }
290}
291
/// Uniplate over abstract literals: the direct children of an abstract literal are the abstract
/// literals found inside its elements, located via `T`'s `Biplate<AbstractLiteral<T>>`.
///
/// Matrix index domains are not traversed; they are cloned unchanged into the rebuild closure.
impl<T> Uniplate for AbstractLiteral<T>
where
    T: AbstractLiteralValue + Biplate<AbstractLiteral<T>>,
{
    fn uniplate(&self) -> (Tree<Self>, Box<dyn Fn(Tree<Self>) -> Self>) {
        // walking into T
        match self {
            AbstractLiteral::Set(vec) => {
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(vec);
                (f1_tree, Box::new(move |x| AbstractLiteral::Set(f1_ctx(x))))
            }
            AbstractLiteral::Matrix(elems, index_domain) => {
                // The index domain is captured by the closure and reused as-is on rebuild.
                let index_domain = index_domain.clone();
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
                (
                    f1_tree,
                    Box::new(move |x| AbstractLiteral::Matrix(f1_ctx(x), index_domain.clone())),
                )
            }
            AbstractLiteral::Tuple(elems) => {
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
                (
                    f1_tree,
                    Box::new(move |x| AbstractLiteral::Tuple(f1_ctx(x))),
                )
            }
            AbstractLiteral::Record(entries) => {
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(entries);
                (
                    f1_tree,
                    Box::new(move |x| AbstractLiteral::Record(f1_ctx(x))),
                )
            }
            AbstractLiteral::Function(entries) => {
                // Flatten the (lhs, rhs) pairs into [lhs0, rhs0, lhs1, rhs1, ...] so they can be
                // traversed as a single Vec<T>; the rebuild closure re-pairs them in this order.
                let entry_count = entries.len();
                let flattened: Vec<T> = entries
                    .iter()
                    .flat_map(|(lhs, rhs)| [lhs.clone(), rhs.clone()])
                    .collect();

                let (f1_tree, f1_ctx) =
                    <Vec<T> as Biplate<AbstractLiteral<T>>>::biplate(&flattened);
                (
                    f1_tree,
                    Box::new(move |x| {
                        let rebuilt = f1_ctx(x);
                        // Re-pairing is only meaningful if no elements were added or removed.
                        assert_eq!(
                            rebuilt.len(),
                            entry_count * 2,
                            "number of function literal children should remain unchanged"
                        );

                        // Consume the rebuilt elements two at a time to restore the pairs.
                        let mut iter = rebuilt.into_iter();
                        let mut pairs = Vec::with_capacity(entry_count);
                        while let (Some(lhs), Some(rhs)) = (iter.next(), iter.next()) {
                            pairs.push((lhs, rhs));
                        }

                        AbstractLiteral::Function(pairs)
                    }),
                )
            }
        }
    }
}
357
/// Biplate from an abstract literal to any `To`.
///
/// If `To` is `AbstractLiteral<U>` itself, the literal is its own single target. Otherwise the
/// traversal walks into the elements (and record fields) of the literal; matrix index domains
/// are not traversed and are reused unchanged on rebuild.
impl<U, To> Biplate<To> for AbstractLiteral<U>
where
    To: Uniplate,
    U: AbstractLiteralValue + Biplate<AbstractLiteral<U>> + Biplate<To>,
    RecordValue<U>: Biplate<AbstractLiteral<U>> + Biplate<To>,
{
    fn biplate(&self) -> (Tree<To>, Box<dyn Fn(Tree<To>) -> Self>) {
        if std::any::TypeId::of::<To>() == std::any::TypeId::of::<AbstractLiteral<U>>() {
            // To ==From => return One(self)

            unsafe {
                // SAFETY: asserted the type equality above
                // (`To` and `AbstractLiteral<U>` have the same `TypeId`, so they are the same
                // type and reinterpreting a reference between them is sound in both directions.)
                let self_to = std::mem::transmute::<&AbstractLiteral<U>, &To>(self).clone();
                let tree = Tree::One(self_to);
                let ctx = Box::new(move |x| {
                    // Rebuilding expects back a single-element tree, mirroring `tree` above.
                    let Tree::One(x) = x else {
                        panic!();
                    };

                    std::mem::transmute::<&To, &AbstractLiteral<U>>(&x).clone()
                });

                (tree, ctx)
            }
        } else {
            // walking into T
            match self {
                AbstractLiteral::Set(vec) => {
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(vec);
                    (f1_tree, Box::new(move |x| AbstractLiteral::Set(f1_ctx(x))))
                }
                AbstractLiteral::Matrix(elems, index_domain) => {
                    let index_domain = index_domain.clone();
                    let (f1_tree, f1_ctx) = <Vec<U> as Biplate<To>>::biplate(elems);
                    (
                        f1_tree,
                        Box::new(move |x| AbstractLiteral::Matrix(f1_ctx(x), index_domain.clone())),
                    )
                }
                AbstractLiteral::Tuple(elems) => {
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(elems);
                    (
                        f1_tree,
                        Box::new(move |x| AbstractLiteral::Tuple(f1_ctx(x))),
                    )
                }
                AbstractLiteral::Record(entries) => {
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(entries);
                    (
                        f1_tree,
                        Box::new(move |x| AbstractLiteral::Record(f1_ctx(x))),
                    )
                }
                AbstractLiteral::Function(entries) => {
                    // Flatten pairs into [lhs0, rhs0, lhs1, rhs1, ...] for traversal; the
                    // rebuild closure re-pairs the elements in the same order.
                    let entry_count = entries.len();
                    let flattened: Vec<U> = entries
                        .iter()
                        .flat_map(|(lhs, rhs)| [lhs.clone(), rhs.clone()])
                        .collect();

                    let (f1_tree, f1_ctx) = <Vec<U> as Biplate<To>>::biplate(&flattened);
                    (
                        f1_tree,
                        Box::new(move |x| {
                            let rebuilt = f1_ctx(x);
                            assert_eq!(
                                rebuilt.len(),
                                entry_count * 2,
                                "number of function literal children should remain unchanged"
                            );

                            // Consume the rebuilt elements two at a time to restore the pairs.
                            let mut iter = rebuilt.into_iter();
                            let mut pairs = Vec::with_capacity(entry_count);
                            while let (Some(lhs), Some(rhs)) = (iter.next(), iter.next()) {
                                pairs.push((lhs, rhs));
                            }

                            AbstractLiteral::Function(pairs)
                        }),
                    )
                }
            }
        }
    }
}
443
444impl TryFrom<Literal> for i32 {
445    type Error = &'static str;
446
447    fn try_from(value: Literal) -> Result<Self, Self::Error> {
448        match value {
449            Literal::Int(i) => Ok(i),
450            _ => Err("Cannot convert non-i32 literal to i32"),
451        }
452    }
453}
454
455impl TryFrom<Box<Literal>> for i32 {
456    type Error = &'static str;
457
458    fn try_from(value: Box<Literal>) -> Result<Self, Self::Error> {
459        (*value).try_into()
460    }
461}
462
463impl TryFrom<&Box<Literal>> for i32 {
464    type Error = &'static str;
465
466    fn try_from(value: &Box<Literal>) -> Result<Self, Self::Error> {
467        TryFrom::<&Literal>::try_from(value.as_ref())
468    }
469}
470
471impl TryFrom<&Moo<Literal>> for i32 {
472    type Error = &'static str;
473
474    fn try_from(value: &Moo<Literal>) -> Result<Self, Self::Error> {
475        TryFrom::<&Literal>::try_from(value.as_ref())
476    }
477}
478
479impl TryFrom<&Literal> for i32 {
480    type Error = &'static str;
481
482    fn try_from(value: &Literal) -> Result<Self, Self::Error> {
483        match value {
484            Literal::Int(i) => Ok(*i),
485            _ => Err("Cannot convert non-i32 literal to i32"),
486        }
487    }
488}
489
490impl TryFrom<Literal> for bool {
491    type Error = &'static str;
492
493    fn try_from(value: Literal) -> Result<Self, Self::Error> {
494        match value {
495            Literal::Bool(b) => Ok(b),
496            _ => Err("Cannot convert non-bool literal to bool"),
497        }
498    }
499}
500
501impl TryFrom<&Literal> for bool {
502    type Error = &'static str;
503
504    fn try_from(value: &Literal) -> Result<Self, Self::Error> {
505        match value {
506            Literal::Bool(b) => Ok(*b),
507            _ => Err("Cannot convert non-bool literal to bool"),
508        }
509    }
510}
511
512impl From<i32> for Literal {
513    fn from(i: i32) -> Self {
514        Literal::Int(i)
515    }
516}
517
518impl From<bool> for Literal {
519    fn from(b: bool) -> Self {
520        Literal::Bool(b)
521    }
522}
523
524impl From<Literal> for Ustr {
525    fn from(value: Literal) -> Self {
526        // TODO: avoid the temporary-allocation of a string by format! here?
527        Ustr::from(&format!("{value}"))
528    }
529}
530
531impl AbstractLiteral<Expression> {
532    /// If all the elements are literals, returns this as an AbstractLiteral<Literal>.
533    /// Otherwise, returns `None`.
534    pub fn into_literals(self) -> Option<AbstractLiteral<Literal>> {
535        match self {
536            AbstractLiteral::Set(_) => todo!(),
537            AbstractLiteral::Matrix(items, domain) => {
538                let mut literals = vec![];
539                for item in items {
540                    let literal = match item {
541                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
542                        Expression::AbstractLiteral(_, abslit) => {
543                            Some(Literal::AbstractLiteral(abslit.into_literals()?))
544                        }
545                        _ => None,
546                    }?;
547                    literals.push(literal);
548                }
549
550                Some(AbstractLiteral::Matrix(literals, domain.resolve()?))
551            }
552            AbstractLiteral::Tuple(items) => {
553                let mut literals = vec![];
554                for item in items {
555                    let literal = match item {
556                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
557                        Expression::AbstractLiteral(_, abslit) => {
558                            Some(Literal::AbstractLiteral(abslit.into_literals()?))
559                        }
560                        _ => None,
561                    }?;
562                    literals.push(literal);
563                }
564
565                Some(AbstractLiteral::Tuple(literals))
566            }
567            AbstractLiteral::Record(entries) => {
568                let mut literals = vec![];
569                for entry in entries {
570                    let literal = match entry.value {
571                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
572                        Expression::AbstractLiteral(_, abslit) => {
573                            Some(Literal::AbstractLiteral(abslit.into_literals()?))
574                        }
575                        _ => None,
576                    }?;
577
578                    literals.push((entry.name, literal));
579                }
580                Some(AbstractLiteral::Record(
581                    literals
582                        .into_iter()
583                        .map(|(name, literal)| RecordValue {
584                            name,
585                            value: literal,
586                        })
587                        .collect(),
588                ))
589            }
590            AbstractLiteral::Function(_) => todo!("Implement into_literals for functions"),
591        }
592    }
593}
594
595// need display implementations for other types as well
596impl Display for Literal {
597    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
598        match &self {
599            Literal::Int(i) => write!(f, "{i}"),
600            Literal::Bool(b) => write!(f, "{b}"),
601            Literal::AbstractLiteral(l) => write!(f, "{l:?}"),
602        }
603    }
604}
605
#[cfg(test)]
mod tests {

    use super::*;
    use crate::{into_matrix, matrix};
    use uniplate::Uniplate;

    /// `cata` over a nested matrix literal should visit the outer matrix and every inner
    /// matrix, collecting one index domain per matrix visited.
    #[test]
    fn matrix_uniplate_universe() {
        // Can we traverse through matrices with uniplate?
        // The outer matrix (index domain `Bool`) holds 5 inner matrix literals, each also with
        // index domain `Bool`, so 1 + 5 = 6 index domains are expected.
        let my_matrix: AbstractLiteral<Literal> = into_matrix![
            vec![Literal::AbstractLiteral(matrix![Literal::Bool(true);Moo::new(GroundDomain::Bool)]); 5];
            Moo::new(GroundDomain::Bool)
        ];

        let expected_index_domains = vec![Moo::new(GroundDomain::Bool); 6];
        let actual_index_domains: Vec<Moo<GroundDomain>> =
            my_matrix.cata(&move |elem, children| {
                // Gather the index domains found in the children first, then this node's own.
                let mut res = vec![];
                res.extend(children.into_iter().flatten());
                if let AbstractLiteral::Matrix(_, index_domain) = elem {
                    res.push(index_domain);
                }

                res
            });

        assert_eq!(actual_index_domains, expected_index_domains);
    }
}