1use itertools::Itertools;
2use serde::{Deserialize, Serialize};
3use std::fmt::{Display, Formatter};
4use std::hash::Hash;
5use ustr::Ustr;
6
7use super::{
8 Atom, Domain, DomainPtr, Expression, GroundDomain, Metadata, Moo, Range, ReturnType, SetAttr,
9 Typeable, domains::HasDomain, domains::Int, records::RecordValue,
10};
11use crate::ast::pretty::pretty_vec;
12use crate::bug;
13use polyquine::Quine;
14use uniplate::{Biplate, Tree, Uniplate};
15
/// A fully-evaluated constant value.
///
/// Unlike [`AbstractLiteral<Expression>`], a `Literal` never contains unevaluated
/// sub-expressions: compound values are stored as [`AbstractLiteral<Literal>`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate, Hash, Quine)]
#[uniplate(walk_into=[AbstractLiteral<Literal>])]
#[biplate(to=Atom)]
#[biplate(to=AbstractLiteral<Literal>)]
#[biplate(to=AbstractLiteral<Expression>)]
#[biplate(to=RecordValue<Literal>)]
#[biplate(to=RecordValue<Expression>)]
#[biplate(to=Expression)]
#[path_prefix(conjure_cp::ast)]
pub enum Literal {
    /// An integer constant.
    Int(i32),
    /// A boolean constant.
    Bool(bool),
    /// A compound constant (set, matrix, tuple, record or function) whose elements are
    /// themselves literals.
    // The variant name ends with the enum's name, which trips
    // `clippy::enum_variant_names`; it is kept because it mirrors the wrapped type.
    #[allow(clippy::enum_variant_names)]
    AbstractLiteral(AbstractLiteral<Literal>),
}
33
34impl HasDomain for Literal {
35 fn domain_of(&self) -> DomainPtr {
36 match self {
37 Literal::Int(i) => Domain::int(vec![Range::Single(*i)]),
38 Literal::Bool(_) => Domain::bool(),
39 Literal::AbstractLiteral(abstract_literal) => abstract_literal.domain_of(),
40 }
41 }
42}
43
/// Element types that can be stored inside an [`AbstractLiteral`].
///
/// Implemented for [`Expression`] (literals that may still contain unevaluated
/// sub-expressions) and [`Literal`] (fully-evaluated constants).
pub trait AbstractLiteralValue:
    Clone + Eq + PartialEq + Display + Uniplate + Biplate<RecordValue<Self>> + 'static
{
    /// The type used to store the index domain of a matrix literal
    /// ([`AbstractLiteral::Matrix`]).
    ///
    /// `From<GroundDomain>` lets generic code build an implied index domain (see
    /// `matrix_implied_indices`), and `Into<DomainPtr>` lets generic code inspect it
    /// (see `unwrap_list`).
    type Dom: Clone + Eq + PartialEq + Display + Quine + From<GroundDomain> + Into<DomainPtr>;
}
/// Expression-valued abstract literals store their matrix index domains as a
/// [`DomainPtr`], which may not be ground yet.
impl AbstractLiteralValue for Expression {
    type Dom = DomainPtr;
}
/// Literal-valued abstract literals store ground matrix index domains
/// (`into_literals` resolves the domain when converting from expressions).
impl AbstractLiteralValue for Literal {
    type Dom = Moo<GroundDomain>;
}
56
/// A compound literal value whose elements have type `T`.
///
/// `T` is [`Expression`] for literals whose elements may still be unevaluated
/// expressions, and [`Literal`] for fully-evaluated constants.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Quine)]
#[path_prefix(conjure_cp::ast)]
pub enum AbstractLiteral<T: AbstractLiteralValue> {
    /// A set literal; displayed as `{a,b,...}`.
    Set(Vec<T>),

    /// A matrix literal: its elements and its index domain.
    ///
    /// Multi-dimensional matrices are represented as matrices whose elements are
    /// matrices, so each level carries exactly one (non-matrix) index domain.
    Matrix(Vec<T>, T::Dom),

    /// A tuple literal; its elements may have different types.
    Tuple(Vec<T>),

    /// A record literal: a list of named fields.
    Record(Vec<RecordValue<T>>),

    /// A function literal given as a list of `(input, output)` pairs;
    /// displayed as `function(a --> b,...)`.
    Function(Vec<(T, T)>),
}
72
impl AbstractLiteral<Expression> {
    /// Infers the domain of this abstract literal from the domains of its elements.
    ///
    /// Returns `None` if any element has no domain, if a set/matrix literal is empty,
    /// or for tuple, record and function literals (not inferred here).
    ///
    /// # Panics
    ///
    /// Panics if the element domains cannot be unioned, or if a matrix level's index
    /// domain is itself a matrix domain.
    pub fn domain_of(&self) -> Option<DomainPtr> {
        match self {
            AbstractLiteral::Set(items) => {
                // Every element must have a domain, otherwise the set's domain is unknown.
                let item_domains: Vec<DomainPtr> = items
                    .iter()
                    .map(|x| x.domain_of())
                    .collect::<Option<Vec<DomainPtr>>>()?;

                // An empty set has no element to seed the union with.
                let mut item_domain_iter = item_domains.iter().cloned();
                let first_item = item_domain_iter.next()?;
                // The fold also revisits the first element; unioning a domain with
                // itself is presumably harmless.
                let item_domain = item_domains
                    .iter()
                    .try_fold(first_item, |x, y| x.union(y))
                    .expect("taking the union of all item domains of a set literal should succeed");

                Some(Domain::set(SetAttr::<Int>::default(), item_domain))
            }

            AbstractLiteral::Matrix(items, _) => {
                // Every element must have a domain, otherwise the matrix's domain is unknown.
                let item_domains = items
                    .iter()
                    .map(|x| x.domain_of())
                    .collect::<Option<Vec<DomainPtr>>>()?;

                let mut item_domain_iter = item_domains.iter().cloned();

                // An empty matrix has no element to seed the union with.
                let first_item = item_domain_iter.next()?;

                let item_domain = item_domains
                    .iter()
                    .try_fold(first_item, |x, y| x.union(y))
                    .expect(
                        "taking the union of all item domains of a matrix literal should succeed",
                    );

                // Collect the index domain of every nesting level, outermost first, by
                // repeatedly descending into the first element while it is a matrix.
                let mut new_index_domain = vec![];

                let mut e = Expression::AbstractLiteral(Metadata::new(), self.clone());
                while let Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, idx)) = e {
                    assert!(
                        idx.as_matrix().is_none(),
                        "n-dimensional matrix literals should be represented as a matrix inside a matrix, got {idx}"
                    );
                    new_index_domain.push(idx);
                    // NOTE(review): `elems[0]` would panic on an empty nested level; this
                    // is presumably unreachable because an empty nested matrix has no
                    // domain, so the `?` above returns early. Confirm.
                    e = elems[0].clone();
                }
                Some(Domain::matrix(item_domain, new_index_domain))
            }
            AbstractLiteral::Tuple(_) => None,
            AbstractLiteral::Record(_) => None,
            AbstractLiteral::Function(_) => None,
        }
    }
}
134
135impl HasDomain for AbstractLiteral<Literal> {
136 fn domain_of(&self) -> DomainPtr {
137 Domain::from_literal_vec(&[Literal::AbstractLiteral(self.clone())])
138 .expect("abstract literals should be correctly typed")
139 }
140}
141
142impl Typeable for AbstractLiteral<Expression> {
143 fn return_type(&self) -> ReturnType {
144 match self {
145 AbstractLiteral::Set(items) if items.is_empty() => {
146 ReturnType::Set(Box::new(ReturnType::Unknown))
147 }
148 AbstractLiteral::Set(items) => {
149 let item_type = items[0].return_type();
150
151 let item_types: Vec<ReturnType> = items.iter().map(|x| x.return_type()).collect();
153
154 assert!(
155 item_types.iter().all(|x| x == &item_type),
156 "all items in a set should have the same type"
157 );
158
159 ReturnType::Set(Box::new(item_type))
160 }
161 AbstractLiteral::Matrix(items, _) if items.is_empty() => {
162 ReturnType::Matrix(Box::new(ReturnType::Unknown))
163 }
164 AbstractLiteral::Matrix(items, _) => {
165 let item_type = items[0].return_type();
166
167 let item_types: Vec<ReturnType> = items.iter().map(|x| x.return_type()).collect();
169
170 assert!(
171 item_types.iter().all(|x| x == &item_type),
172 "all items in a matrix should have the same type. items: {items} types: {types:#?}",
173 items = pretty_vec(items),
174 types = items
175 .iter()
176 .map(|x| x.return_type())
177 .collect::<Vec<ReturnType>>()
178 );
179
180 ReturnType::Matrix(Box::new(item_type))
181 }
182 AbstractLiteral::Tuple(items) => {
183 let mut item_types = vec![];
184 for item in items {
185 item_types.push(item.return_type());
186 }
187 ReturnType::Tuple(item_types)
188 }
189 AbstractLiteral::Record(items) => {
190 let mut item_types = vec![];
191 for item in items {
192 item_types.push(item.value.return_type());
193 }
194 ReturnType::Record(item_types)
195 }
196 AbstractLiteral::Function(items) => {
197 if items.is_empty() {
198 return ReturnType::Function(
199 Box::new(ReturnType::Unknown),
200 Box::new(ReturnType::Unknown),
201 );
202 }
203
204 let (x1, y1) = &items[0];
206 let (t1, t2) = (x1.return_type(), y1.return_type());
207 for (x, y) in items {
208 let (tx, ty) = (x.return_type(), y.return_type());
209 if tx != t1 {
210 bug!("Expected {t1}, got {x}: {tx}");
211 }
212 if ty != t2 {
213 bug!("Expected {t2}, got {y}: {ty}");
214 }
215 }
216
217 ReturnType::Function(Box::new(t1), Box::new(t2))
218 }
219 }
220 }
221}
222
223impl<T> AbstractLiteral<T>
224where
225 T: AbstractLiteralValue,
226{
227 pub fn matrix_implied_indices(elems: Vec<T>) -> Self {
231 AbstractLiteral::Matrix(elems, GroundDomain::Int(vec![Range::UnboundedR(1)]).into())
232 }
233
234 pub fn unwrap_list(&self) -> Option<&Vec<T>> {
239 let AbstractLiteral::Matrix(elems, domain) = self else {
240 return None;
241 };
242
243 let domain: DomainPtr = domain.clone().into();
244 let Some(GroundDomain::Int(ranges)) = domain.as_ground() else {
245 return None;
246 };
247
248 let [Range::UnboundedR(1)] = ranges[..] else {
249 return None;
250 };
251
252 Some(elems)
253 }
254}
255
256impl<T> Display for AbstractLiteral<T>
257where
258 T: AbstractLiteralValue,
259{
260 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
261 match self {
262 AbstractLiteral::Set(elems) => {
263 let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
264 write!(f, "{{{elems_str}}}")
265 }
266 AbstractLiteral::Matrix(elems, index_domain) => {
267 let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
268 write!(f, "[{elems_str};{index_domain}]")
269 }
270 AbstractLiteral::Tuple(elems) => {
271 let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
272 write!(f, "({elems_str})")
273 }
274 AbstractLiteral::Record(entries) => {
275 let entries_str: String = entries
276 .iter()
277 .map(|entry| format!("{}: {}", entry.name, entry.value))
278 .join(",");
279 write!(f, "{{{entries_str}}}")
280 }
281 AbstractLiteral::Function(entries) => {
282 let entries_str: String = entries
283 .iter()
284 .map(|entry| format!("{} --> {}", entry.0, entry.1))
285 .join(",");
286 write!(f, "function({entries_str})")
287 }
288 }
289 }
290}
291
292impl<T> Uniplate for AbstractLiteral<T>
293where
294 T: AbstractLiteralValue + Biplate<AbstractLiteral<T>>,
295{
296 fn uniplate(&self) -> (Tree<Self>, Box<dyn Fn(Tree<Self>) -> Self>) {
297 match self {
299 AbstractLiteral::Set(vec) => {
300 let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(vec);
301 (f1_tree, Box::new(move |x| AbstractLiteral::Set(f1_ctx(x))))
302 }
303 AbstractLiteral::Matrix(elems, index_domain) => {
304 let index_domain = index_domain.clone();
305 let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
306 (
307 f1_tree,
308 Box::new(move |x| AbstractLiteral::Matrix(f1_ctx(x), index_domain.clone())),
309 )
310 }
311 AbstractLiteral::Tuple(elems) => {
312 let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
313 (
314 f1_tree,
315 Box::new(move |x| AbstractLiteral::Tuple(f1_ctx(x))),
316 )
317 }
318 AbstractLiteral::Record(entries) => {
319 let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(entries);
320 (
321 f1_tree,
322 Box::new(move |x| AbstractLiteral::Record(f1_ctx(x))),
323 )
324 }
325 AbstractLiteral::Function(entries) => {
326 let entry_count = entries.len();
327 let flattened: Vec<T> = entries
328 .iter()
329 .flat_map(|(lhs, rhs)| [lhs.clone(), rhs.clone()])
330 .collect();
331
332 let (f1_tree, f1_ctx) =
333 <Vec<T> as Biplate<AbstractLiteral<T>>>::biplate(&flattened);
334 (
335 f1_tree,
336 Box::new(move |x| {
337 let rebuilt = f1_ctx(x);
338 assert_eq!(
339 rebuilt.len(),
340 entry_count * 2,
341 "number of function literal children should remain unchanged"
342 );
343
344 let mut iter = rebuilt.into_iter();
345 let mut pairs = Vec::with_capacity(entry_count);
346 while let (Some(lhs), Some(rhs)) = (iter.next(), iter.next()) {
347 pairs.push((lhs, rhs));
348 }
349
350 AbstractLiteral::Function(pairs)
351 }),
352 )
353 }
354 }
355 }
356}
357
/// Traverses an abstract literal for all values of type `To`.
///
/// If `To` is `AbstractLiteral<U>` itself, the literal is its own (single) target.
/// Otherwise the traversal descends into the elements. Matrix index domains are not
/// traversed and are carried over unchanged.
impl<U, To> Biplate<To> for AbstractLiteral<U>
where
    To: Uniplate,
    U: AbstractLiteralValue + Biplate<AbstractLiteral<U>> + Biplate<To>,
    RecordValue<U>: Biplate<AbstractLiteral<U>> + Biplate<To>,
{
    fn biplate(&self) -> (Tree<To>, Box<dyn Fn(Tree<To>) -> Self>) {
        if std::any::TypeId::of::<To>() == std::any::TypeId::of::<AbstractLiteral<U>>() {
            // SAFETY: the `TypeId` comparison above proves that `To` and
            // `AbstractLiteral<U>` are the same type, so reinterpreting a reference to
            // one as a reference to the other (in either direction) is a no-op cast
            // between identical types with identical layout and lifetime.
            unsafe {
                let self_to = std::mem::transmute::<&AbstractLiteral<U>, &To>(self).clone();
                // The whole literal is the single target of this traversal.
                let tree = Tree::One(self_to);
                let ctx = Box::new(move |x| {
                    // The tree handed out above is a single node, so only a single node
                    // is a valid replacement.
                    let Tree::One(x) = x else {
                        panic!();
                    };

                    std::mem::transmute::<&To, &AbstractLiteral<U>>(&x).clone()
                });

                (tree, ctx)
            }
        } else {
            match self {
                AbstractLiteral::Set(vec) => {
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(vec);
                    (f1_tree, Box::new(move |x| AbstractLiteral::Set(f1_ctx(x))))
                }
                AbstractLiteral::Matrix(elems, index_domain) => {
                    let index_domain = index_domain.clone();
                    let (f1_tree, f1_ctx) = <Vec<U> as Biplate<To>>::biplate(elems);
                    (
                        f1_tree,
                        Box::new(move |x| AbstractLiteral::Matrix(f1_ctx(x), index_domain.clone())),
                    )
                }
                AbstractLiteral::Tuple(elems) => {
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(elems);
                    (
                        f1_tree,
                        Box::new(move |x| AbstractLiteral::Tuple(f1_ctx(x))),
                    )
                }
                AbstractLiteral::Record(entries) => {
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(entries);
                    (
                        f1_tree,
                        Box::new(move |x| AbstractLiteral::Record(f1_ctx(x))),
                    )
                }
                AbstractLiteral::Function(entries) => {
                    // Flatten the (input, output) pairs into one vector for traversal,
                    // then re-pair consecutive elements when rebuilding.
                    let entry_count = entries.len();
                    let flattened: Vec<U> = entries
                        .iter()
                        .flat_map(|(lhs, rhs)| [lhs.clone(), rhs.clone()])
                        .collect();

                    let (f1_tree, f1_ctx) = <Vec<U> as Biplate<To>>::biplate(&flattened);
                    (
                        f1_tree,
                        Box::new(move |x| {
                            let rebuilt = f1_ctx(x);
                            assert_eq!(
                                rebuilt.len(),
                                entry_count * 2,
                                "number of function literal children should remain unchanged"
                            );

                            let mut iter = rebuilt.into_iter();
                            let mut pairs = Vec::with_capacity(entry_count);
                            while let (Some(lhs), Some(rhs)) = (iter.next(), iter.next()) {
                                pairs.push((lhs, rhs));
                            }

                            AbstractLiteral::Function(pairs)
                        }),
                    )
                }
            }
        }
    }
}
443
444impl TryFrom<Literal> for i32 {
445 type Error = &'static str;
446
447 fn try_from(value: Literal) -> Result<Self, Self::Error> {
448 match value {
449 Literal::Int(i) => Ok(i),
450 _ => Err("Cannot convert non-i32 literal to i32"),
451 }
452 }
453}
454
455impl TryFrom<Box<Literal>> for i32 {
456 type Error = &'static str;
457
458 fn try_from(value: Box<Literal>) -> Result<Self, Self::Error> {
459 (*value).try_into()
460 }
461}
462
463impl TryFrom<&Box<Literal>> for i32 {
464 type Error = &'static str;
465
466 fn try_from(value: &Box<Literal>) -> Result<Self, Self::Error> {
467 TryFrom::<&Literal>::try_from(value.as_ref())
468 }
469}
470
471impl TryFrom<&Moo<Literal>> for i32 {
472 type Error = &'static str;
473
474 fn try_from(value: &Moo<Literal>) -> Result<Self, Self::Error> {
475 TryFrom::<&Literal>::try_from(value.as_ref())
476 }
477}
478
479impl TryFrom<&Literal> for i32 {
480 type Error = &'static str;
481
482 fn try_from(value: &Literal) -> Result<Self, Self::Error> {
483 match value {
484 Literal::Int(i) => Ok(*i),
485 _ => Err("Cannot convert non-i32 literal to i32"),
486 }
487 }
488}
489
490impl TryFrom<Literal> for bool {
491 type Error = &'static str;
492
493 fn try_from(value: Literal) -> Result<Self, Self::Error> {
494 match value {
495 Literal::Bool(b) => Ok(b),
496 _ => Err("Cannot convert non-bool literal to bool"),
497 }
498 }
499}
500
501impl TryFrom<&Literal> for bool {
502 type Error = &'static str;
503
504 fn try_from(value: &Literal) -> Result<Self, Self::Error> {
505 match value {
506 Literal::Bool(b) => Ok(*b),
507 _ => Err("Cannot convert non-bool literal to bool"),
508 }
509 }
510}
511
512impl From<i32> for Literal {
513 fn from(i: i32) -> Self {
514 Literal::Int(i)
515 }
516}
517
518impl From<bool> for Literal {
519 fn from(b: bool) -> Self {
520 Literal::Bool(b)
521 }
522}
523
524impl From<Literal> for Ustr {
525 fn from(value: Literal) -> Self {
526 Ustr::from(&format!("{value}"))
528 }
529}
530
531impl AbstractLiteral<Expression> {
532 pub fn into_literals(self) -> Option<AbstractLiteral<Literal>> {
535 match self {
536 AbstractLiteral::Set(_) => todo!(),
537 AbstractLiteral::Matrix(items, domain) => {
538 let mut literals = vec![];
539 for item in items {
540 let literal = match item {
541 Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
542 Expression::AbstractLiteral(_, abslit) => {
543 Some(Literal::AbstractLiteral(abslit.into_literals()?))
544 }
545 _ => None,
546 }?;
547 literals.push(literal);
548 }
549
550 Some(AbstractLiteral::Matrix(literals, domain.resolve()?))
551 }
552 AbstractLiteral::Tuple(items) => {
553 let mut literals = vec![];
554 for item in items {
555 let literal = match item {
556 Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
557 Expression::AbstractLiteral(_, abslit) => {
558 Some(Literal::AbstractLiteral(abslit.into_literals()?))
559 }
560 _ => None,
561 }?;
562 literals.push(literal);
563 }
564
565 Some(AbstractLiteral::Tuple(literals))
566 }
567 AbstractLiteral::Record(entries) => {
568 let mut literals = vec![];
569 for entry in entries {
570 let literal = match entry.value {
571 Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
572 Expression::AbstractLiteral(_, abslit) => {
573 Some(Literal::AbstractLiteral(abslit.into_literals()?))
574 }
575 _ => None,
576 }?;
577
578 literals.push((entry.name, literal));
579 }
580 Some(AbstractLiteral::Record(
581 literals
582 .into_iter()
583 .map(|(name, literal)| RecordValue {
584 name,
585 value: literal,
586 })
587 .collect(),
588 ))
589 }
590 AbstractLiteral::Function(_) => todo!("Implement into_literals for functions"),
591 }
592 }
593}
594
595impl Display for Literal {
597 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
598 match &self {
599 Literal::Int(i) => write!(f, "{i}"),
600 Literal::Bool(b) => write!(f, "{b}"),
601 Literal::AbstractLiteral(l) => write!(f, "{l:?}"),
602 }
603 }
604}
605
#[cfg(test)]
mod tests {

    use super::*;
    use crate::{into_matrix, matrix};
    use uniplate::Uniplate;

    /// An outer matrix of five inner matrices has six matrix nodes in total. Folding
    /// over it with `cata` should collect one index domain per node.
    #[test]
    fn matrix_uniplate_universe() {
        let my_matrix: AbstractLiteral<Literal> = into_matrix![
            vec![Literal::AbstractLiteral(matrix![Literal::Bool(true);Moo::new(GroundDomain::Bool)]); 5];
            Moo::new(GroundDomain::Bool)
        ];

        let actual_index_domains: Vec<Moo<GroundDomain>> =
            my_matrix.cata(&move |node, child_results| {
                let mut found: Vec<Moo<GroundDomain>> =
                    child_results.into_iter().flatten().collect();
                if let AbstractLiteral::Matrix(_, index_domain) = node {
                    found.push(index_domain);
                }
                found
            });

        let expected_index_domains = vec![Moo::new(GroundDomain::Bool); 6];
        assert_eq!(actual_index_domains, expected_index_domains);
    }
}