1
use itertools::Itertools;
2
use serde::{Deserialize, Serialize};
3
use std::fmt::{Display, Formatter};
4
use std::hash::Hash;
5
use ustr::Ustr;
6

            
7
use super::{
8
    Atom, Domain, DomainPtr, Expression, GroundDomain, Metadata, Moo, PartitionAttr, Range,
9
    ReturnType, SetAttr, Typeable, domains::HasDomain, domains::Int, records::FieldValue,
10
};
11
use crate::ast::domains::{MSetAttr, SequenceAttr};
12
use crate::ast::pretty::pretty_vec;
13
use crate::bug;
14
use polyquine::Quine;
15
use uniplate::{Biplate, Tree, Uniplate};
16

            
17
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate, Hash, Quine)]
18
#[uniplate(walk_into=[AbstractLiteral<Literal>])]
19
#[biplate(to=Atom)]
20
#[biplate(to=AbstractLiteral<Literal>)]
21
#[biplate(to=AbstractLiteral<Expression>)]
22
#[biplate(to=FieldValue<Literal>)]
23
#[biplate(to=FieldValue<Expression>)]
24
#[biplate(to=Expression)]
25
#[path_prefix(conjure_cp::ast)]
26
/// A literal value, equivalent to constants in Conjure.
27
pub enum Literal {
28
    Int(i32),
29
    Bool(bool),
30
    //abstract literal variant ends in Literal, but that's ok
31
    #[allow(clippy::enum_variant_names)]
32
    AbstractLiteral(AbstractLiteral<Literal>),
33
}
34

            
35
impl HasDomain for Literal {
36
2432795
    fn domain_of(&self) -> DomainPtr {
37
2432795
        match self {
38
2362944
            Literal::Int(i) => Domain::int(vec![Range::Single(*i)]),
39
52633
            Literal::Bool(_) => Domain::bool(),
40
17218
            Literal::AbstractLiteral(abstract_literal) => abstract_literal.domain_of(),
41
        }
42
2432795
    }
43
}
44

            
45
// make possible values of an AbstractLiteral a closed world to make the trait bounds more sane (particularly in Uniplate instances!!)
46
pub trait AbstractLiteralValue:
47
    Clone + Eq + PartialEq + Display + Uniplate + Biplate<FieldValue<Self>> + 'static
48
{
49
    type Dom: Clone + Eq + PartialEq + Display + Quine + From<GroundDomain> + Into<DomainPtr>;
50
}
51
impl AbstractLiteralValue for Expression {
52
    type Dom = DomainPtr;
53
}
54
impl AbstractLiteralValue for Literal {
55
    type Dom = Moo<GroundDomain>;
56
}
57

            
58
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Quine)]
59
#[path_prefix(conjure_cp::ast)]
60
pub enum AbstractLiteral<T: AbstractLiteralValue> {
61
    Set(Vec<T>),
62

            
63
    MSet(Vec<T>),
64

            
65
    /// A 1 dimensional matrix slice with an index domain.
66
    Matrix(Vec<T>, T::Dom),
67

            
68
    // a tuple of literals
69
    Tuple(Vec<T>),
70

            
71
    Record(Vec<FieldValue<T>>),
72

            
73
    Sequence(Vec<T>),
74

            
75
    Function(Vec<(T, T)>),
76

            
77
    // Variants only contain one of their name-domain pairs
78
    Variant(Moo<FieldValue<T>>),
79

            
80
    // A list of partitions, each part has a set of values
81
    Partition(Vec<Vec<T>>),
82
    Relation(Vec<Vec<T>>),
83
}
84

            
85
// TODO: use HasDomain instead once Expression::domain_of returns Domain not Option<Domain>
86
impl AbstractLiteral<Expression> {
87
198020
    pub fn domain_of(&self) -> Option<DomainPtr> {
88
198020
        match self {
89
8200
            AbstractLiteral::Set(items) => {
90
                // ensure that all items have a domain, or return None
91
8200
                let item_domains: Vec<DomainPtr> = items
92
8200
                    .iter()
93
20280
                    .map(|x| x.domain_of())
94
8200
                    .collect::<Option<Vec<DomainPtr>>>()?;
95

            
96
                // union all item domains together
97
8200
                let mut item_domain_iter = item_domains.iter().cloned();
98
8200
                let first_item = item_domain_iter.next()?;
99
8200
                let item_domain = item_domains
100
8200
                    .iter()
101
20280
                    .try_fold(first_item, |x, y| x.union(y))
102
8200
                    .expect("taking the union of all item domains of a set literal should succeed");
103

            
104
8200
                Some(Domain::set(SetAttr::<Int>::default(), item_domain))
105
            }
106

            
107
40
            AbstractLiteral::MSet(items) => {
108
                // ensure that all items have a domain, or return None
109
40
                let item_domains: Vec<DomainPtr> = items
110
40
                    .iter()
111
120
                    .map(|x| x.domain_of())
112
40
                    .collect::<Option<Vec<DomainPtr>>>()?;
113

            
114
                // union all item domains together
115
40
                let mut item_domain_iter = item_domains.iter().cloned();
116
40
                let first_item = item_domain_iter.next()?;
117
40
                let item_domain = item_domains
118
40
                    .iter()
119
120
                    .try_fold(first_item, |x, y| x.union(y))
120
40
                    .expect("taking the union of all item domains of a set literal should succeed");
121

            
122
40
                Some(Domain::mset(MSetAttr::<Int>::default(), item_domain))
123
            }
124

            
125
120
            AbstractLiteral::Sequence(elems) => {
126
120
                let item_domains: Vec<DomainPtr> = elems
127
120
                    .iter()
128
400
                    .map(|x| x.domain_of())
129
120
                    .collect::<Option<Vec<DomainPtr>>>()?;
130

            
131
                // Get the union of all domains in the sequence.
132
                // i.e. if <(1..3), (1..3), (5), (8..9)> then seq dom is (1..3, 5, 8..9)
133
120
                let mut item_domain_iter = item_domains.iter().cloned();
134
120
                let first_item = item_domain_iter.next()?;
135
120
                let item_domain = item_domains
136
120
                    .iter()
137
400
                    .try_fold(first_item, |x, y| x.union(y))
138
120
                    .expect("taking the union of all item domains of a set literal should succeed");
139

            
140
120
                Some(Domain::sequence(
141
120
                    SequenceAttr::<Int>::default(),
142
120
                    item_domain,
143
120
                ))
144
            }
145

            
146
40
            AbstractLiteral::Partition(items) => {
147
                // Flatten the Vec<Vec< into a single vec
148
                // ensure that all elemes in each part have a domain, or return None
149

            
150
40
                let item_domains: Vec<DomainPtr> = items
151
40
                    .iter()
152
40
                    .flatten()
153
160
                    .map(|x| x.domain_of())
154
40
                    .collect::<Option<Vec<DomainPtr>>>()?;
155

            
156
                // union all item domains together
157
40
                let mut item_domain_iter = item_domains.iter().cloned();
158
40
                let first_item = item_domain_iter.next()?;
159
40
                let item_domain = item_domains
160
40
                    .iter()
161
160
                    .try_fold(first_item, |x, y| x.union(y))
162
40
                    .expect("taking the union of all item domains of a partition literal should succeed");
163

            
164
40
                Some(Domain::partition(
165
40
                    PartitionAttr::<Int>::default(),
166
40
                    item_domain,
167
40
                ))
168
            }
169

            
170
189100
            AbstractLiteral::Matrix(items, _) => {
171
                // ensure that all items have a domain, or return None
172
189100
                let item_domains = items
173
189100
                    .iter()
174
1297738
                    .map(|x| x.domain_of())
175
189100
                    .collect::<Option<Vec<DomainPtr>>>()?;
176

            
177
                // union all item domains together
178
189096
                let mut item_domain_iter = item_domains.iter().cloned();
179

            
180
189096
                let first_item = item_domain_iter.next()?;
181

            
182
189096
                let item_domain = item_domains
183
189096
                    .iter()
184
1297734
                    .try_fold(first_item, |x, y| x.union(y))
185
189096
                    .expect(
186
189096
                        "taking the union of all item domains of a matrix literal should succeed",
187
                    );
188

            
189
189096
                let mut new_index_domain = vec![];
190

            
191
                // flatten index domains of n-d matrix into list
192
189096
                let mut e = Expression::AbstractLiteral(Metadata::new(), self.clone());
193
197090
                while let Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, idx)) = e {
194
197090
                    assert!(
195
197090
                        idx.as_matrix().is_none(),
196
                        "n-dimensional matrix literals should be represented as a matrix inside a matrix, got {idx}"
197
                    );
198
197090
                    new_index_domain.push(idx);
199
197090
                    e = elems[0].clone();
200
                }
201
189096
                Some(Domain::matrix(item_domain, new_index_domain))
202
            }
203
280
            AbstractLiteral::Tuple(_) => None,
204
40
            AbstractLiteral::Record(_) => None,
205
80
            AbstractLiteral::Function(_) => None,
206
40
            AbstractLiteral::Variant(_) => None,
207
80
            AbstractLiteral::Relation(_) => None,
208
        }
209
198020
    }
210
}
211

            
212
impl HasDomain for AbstractLiteral<Literal> {
213
17218
    fn domain_of(&self) -> DomainPtr {
214
17218
        Domain::from_literal_vec(&[Literal::AbstractLiteral(self.clone())])
215
17218
            .expect("abstract literals should be correctly typed")
216
17218
    }
217
}
218

            
219
impl Typeable for AbstractLiteral<Expression> {
220
133590
    fn return_type(&self) -> ReturnType {
221
        match self {
222
            AbstractLiteral::Set(items) if items.is_empty() => {
223
                ReturnType::Set(Box::new(ReturnType::Unknown))
224
            }
225
            AbstractLiteral::Set(items) => {
226
                let item_type = items[0].return_type();
227

            
228
                // if any items do not have a type, return none.
229
                let item_types: Vec<ReturnType> = items.iter().map(|x| x.return_type()).collect();
230

            
231
                assert!(
232
                    item_types.iter().all(|x| x == &item_type),
233
                    "all items in a set should have the same type"
234
                );
235

            
236
                ReturnType::Set(Box::new(item_type))
237
            }
238
            AbstractLiteral::MSet(items) if items.is_empty() => {
239
                ReturnType::MSet(Box::new(ReturnType::Unknown))
240
            }
241
            AbstractLiteral::MSet(items) => {
242
                let item_type = items[0].return_type();
243

            
244
                // if any items do not have a type, return none.
245
                let item_types: Vec<ReturnType> = items.iter().map(|x| x.return_type()).collect();
246

            
247
                assert!(
248
                    item_types.iter().all(|x| x == &item_type),
249
                    "all items in a set should have the same type"
250
                );
251

            
252
                ReturnType::MSet(Box::new(item_type))
253
            }
254
            AbstractLiteral::Sequence(items) if items.is_empty() => {
255
                ReturnType::Sequence(Box::new(ReturnType::Unknown))
256
            }
257
            AbstractLiteral::Sequence(items) => {
258
                let item_type = items[0].return_type();
259

            
260
                // if any items do not have a type, return none.
261
                let item_types: Vec<ReturnType> = items.iter().map(|x| x.return_type()).collect();
262

            
263
                assert!(
264
                    item_types.iter().all(|x| x == &item_type),
265
                    "all items in a sequence should have the same type"
266
                );
267

            
268
                ReturnType::Sequence(Box::new(item_type))
269
            }
270
            AbstractLiteral::Partition(items) if items.is_empty() || items[0].is_empty() => {
271
                ReturnType::Partition(Box::new(ReturnType::Unknown))
272
            }
273
            AbstractLiteral::Partition(items) => {
274
                let item_type = items[0][0].return_type();
275

            
276
                // if any items do not have a type, return none.
277
                let item_types: Vec<ReturnType> =
278
                    items.iter().flatten().map(|x| x.return_type()).collect();
279

            
280
                assert!(
281
                    item_types.iter().all(|x| x == &item_type),
282
                    "all items in every part of a partition should have the same type"
283
                );
284

            
285
                ReturnType::Partition(Box::new(item_type))
286
            }
287
133590
            AbstractLiteral::Matrix(items, _) if items.is_empty() => {
288
                ReturnType::Matrix(Box::new(ReturnType::Unknown))
289
            }
290
133590
            AbstractLiteral::Matrix(items, _) => {
291
133590
                let item_type = items[0].return_type();
292

            
293
                // if any items do not have a type, return none.
294
963672
                let item_types: Vec<ReturnType> = items.iter().map(|x| x.return_type()).collect();
295

            
296
133590
                assert!(
297
963672
                    item_types.iter().all(|x| x == &item_type),
298
                    "all items in a matrix should have the same type. items: {items} types: {types:#?}",
299
                    items = pretty_vec(items),
300
                    types = items
301
                        .iter()
302
                        .map(|x| x.return_type())
303
                        .collect::<Vec<ReturnType>>()
304
                );
305

            
306
133590
                ReturnType::Matrix(Box::new(item_type))
307
            }
308
            AbstractLiteral::Tuple(items) => {
309
                let mut item_types = vec![];
310
                for item in items {
311
                    item_types.push(item.return_type());
312
                }
313
                ReturnType::Tuple(item_types)
314
            }
315
            AbstractLiteral::Record(items) => {
316
                let mut item_types = vec![];
317
                for item in items {
318
                    item_types.push(item.value.return_type());
319
                }
320
                ReturnType::Record(item_types)
321
            }
322
            AbstractLiteral::Function(items) => {
323
                if items.is_empty() {
324
                    return ReturnType::Function(
325
                        Box::new(ReturnType::Unknown),
326
                        Box::new(ReturnType::Unknown),
327
                    );
328
                }
329

            
330
                // Check that all items have the same return type
331
                let (x1, y1) = &items[0];
332
                let (t1, t2) = (x1.return_type(), y1.return_type());
333
                for (x, y) in items {
334
                    let (tx, ty) = (x.return_type(), y.return_type());
335
                    if tx != t1 {
336
                        bug!("Expected {t1}, got {x}: {tx}");
337
                    }
338
                    if ty != t2 {
339
                        bug!("Expected {t2}, got {y}: {ty}");
340
                    }
341
                }
342

            
343
                ReturnType::Function(Box::new(t1), Box::new(t2))
344
            }
345
            AbstractLiteral::Variant(item) => {
346
                // Variants hold multiple possible types. In the case of a literal we know which type it chose
347
                ReturnType::Variant(vec![item.value.return_type()])
348
            }
349
            AbstractLiteral::Relation(items) => {
350
                if items.is_empty() {
351
                    return ReturnType::Relation(vec![ReturnType::Unknown]);
352
                }
353
                let mut item_types = vec![];
354
                let x1 = &items[0];
355
                let size = x1.len();
356
                for item in x1 {
357
                    item_types.push(item.return_type());
358
                }
359
                for x in items {
360
                    if x.len() != size {
361
                        let strs = item_types.iter().map(|x| format!("{}", x)).join(",");
362
                        bug!("Expected ({strs}) with length {size}, got size {}", x.len());
363
                    }
364
                    for i in 1..size {
365
                        if let Some(new_type) = x.get(i)
366
                            && let Some(old_type) = item_types.get(i)
367
                            && new_type.return_type() != *old_type
368
                        {
369
                            bug!("Expected {old_type}, got {new_type}");
370
                        }
371
                    }
372
                }
373
                ReturnType::Relation(item_types)
374
            }
375
        }
376
133590
    }
377
}
378

            
379
impl<T> AbstractLiteral<T>
380
where
381
    T: AbstractLiteralValue,
382
{
383
    /// Creates a matrix with elements `elems`, with domain `int(1..)`.
384
    ///
385
    /// This acts as a variable sized list.
386
949723
    pub fn matrix_implied_indices(elems: Vec<T>) -> Self {
387
949723
        AbstractLiteral::Matrix(elems, GroundDomain::Int(vec![Range::UnboundedR(1)]).into())
388
949723
    }
389

            
390
    /// If the AbstractLiteral is a list, returns its elements.
391
    ///
392
    /// A list is any a matrix with the domain `int(1..)`. This includes matrix literals without
393
    /// any explicitly specified domain.
394
7481322
    pub fn unwrap_list(&self) -> Option<&Vec<T>> {
395
7481322
        let AbstractLiteral::Matrix(elems, domain) = self else {
396
            return None;
397
        };
398

            
399
7481322
        let domain: DomainPtr = domain.clone().into();
400
7481322
        let Some(GroundDomain::Int(ranges)) = domain.as_ground() else {
401
            return None;
402
        };
403

            
404
7481322
        let [Range::UnboundedR(1)] = ranges[..] else {
405
1410148
            return None;
406
        };
407

            
408
6071174
        Some(elems)
409
7481322
    }
410
}
411

            
412
impl<T> Display for AbstractLiteral<T>
413
where
414
    T: AbstractLiteralValue,
415
{
416
2451690
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
417
2451690
        match self {
418
9228
            AbstractLiteral::Set(elems) => {
419
24002
                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
420
9228
                write!(f, "{{{elems_str}}}")
421
            }
422
40
            AbstractLiteral::MSet(elems) => {
423
120
                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
424
40
                write!(f, "mset({elems_str})")
425
            }
426
2440382
            AbstractLiteral::Matrix(elems, index_domain) => {
427
9137992
                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
428
2440382
                write!(f, "[{elems_str};{index_domain}]")
429
            }
430
1080
            AbstractLiteral::Tuple(elems) => {
431
2480
                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
432
1080
                write!(f, "({elems_str})")
433
            }
434
120
            AbstractLiteral::Sequence(elems) => {
435
400
                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
436
120
                write!(f, "sequence({elems_str})")
437
            }
438
40
            AbstractLiteral::Partition(parts) => {
439
40
                let elems_str: String = parts
440
40
                    .iter()
441
120
                    .map(|inner| {
442
160
                        let elems_str = inner.iter().map(|x| format!("{x}")).join(",");
443
120
                        format!("{{{}}}", elems_str)
444
120
                    })
445
40
                    .join(", ");
446

            
447
40
                write!(f, "partition({elems_str})")
448
            }
449
560
            AbstractLiteral::Record(entries) => {
450
560
                let entries_str: String = entries
451
560
                    .iter()
452
1120
                    .map(|entry| format!("{} = {}", entry.name, entry.value))
453
560
                    .join(",");
454
560
                write!(f, "record {{{entries_str}}}")
455
            }
456
120
            AbstractLiteral::Function(entries) => {
457
120
                let entries_str: String = entries
458
120
                    .iter()
459
240
                    .map(|entry| format!("{} --> {}", entry.0, entry.1))
460
120
                    .join(",");
461
120
                write!(f, "function({entries_str})")
462
            }
463
40
            AbstractLiteral::Variant(entry) => {
464
40
                write!(f, "variant{{{} = {}}}", entry.name, entry.value)
465
            }
466
80
            AbstractLiteral::Relation(elems) => {
467
80
                let elems_str: String = elems
468
80
                    .iter()
469
480
                    .map(|x| format!("({})", x.iter().map(|x| format!("{x}")).join(",")))
470
80
                    .join(",");
471
80
                write!(f, "relation({elems_str})")
472
            }
473
        }
474
2451690
    }
475
}
476

            
477
impl<T> Uniplate for AbstractLiteral<T>
478
where
479
    T: AbstractLiteralValue + Biplate<AbstractLiteral<T>>,
480
{
481
34546
    fn uniplate(&self) -> (Tree<Self>, Box<dyn Fn(Tree<Self>) -> Self>) {
482
        // walking into T
483
34546
        match self {
484
            AbstractLiteral::Set(vec) => {
485
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(vec);
486
                (f1_tree, Box::new(move |x| AbstractLiteral::Set(f1_ctx(x))))
487
            }
488
            AbstractLiteral::MSet(vec) => {
489
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(vec);
490
                (f1_tree, Box::new(move |x| AbstractLiteral::MSet(f1_ctx(x))))
491
            }
492
34546
            AbstractLiteral::Matrix(elems, index_domain) => {
493
34546
                let index_domain = index_domain.clone();
494
34546
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
495
                (
496
34546
                    f1_tree,
497
34546
                    Box::new(move |x| AbstractLiteral::Matrix(f1_ctx(x), index_domain.clone())),
498
                )
499
            }
500
            AbstractLiteral::Sequence(vec) => {
501
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(vec);
502
                (
503
                    f1_tree,
504
                    Box::new(move |x| AbstractLiteral::Sequence(f1_ctx(x))),
505
                )
506
            }
507
            AbstractLiteral::Tuple(elems) => {
508
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
509
                (
510
                    f1_tree,
511
                    Box::new(move |x| AbstractLiteral::Tuple(f1_ctx(x))),
512
                )
513
            }
514
            AbstractLiteral::Record(entries) => {
515
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(entries);
516
                (
517
                    f1_tree,
518
                    Box::new(move |x| AbstractLiteral::Record(f1_ctx(x))),
519
                )
520
            }
521
            AbstractLiteral::Function(entries) => {
522
                let entry_count = entries.len();
523
                let flattened: Vec<T> = entries
524
                    .iter()
525
                    .flat_map(|(lhs, rhs)| [lhs.clone(), rhs.clone()])
526
                    .collect();
527

            
528
                let (f1_tree, f1_ctx) =
529
                    <Vec<T> as Biplate<AbstractLiteral<T>>>::biplate(&flattened);
530
                (
531
                    f1_tree,
532
                    Box::new(move |x| {
533
                        let rebuilt = f1_ctx(x);
534
                        assert_eq!(
535
                            rebuilt.len(),
536
                            entry_count * 2,
537
                            "number of function literal children should remain unchanged"
538
                        );
539

            
540
                        let mut iter = rebuilt.into_iter();
541
                        let mut pairs = Vec::with_capacity(entry_count);
542
                        while let (Some(lhs), Some(rhs)) = (iter.next(), iter.next()) {
543
                            pairs.push((lhs, rhs));
544
                        }
545

            
546
                        AbstractLiteral::Function(pairs)
547
                    }),
548
                )
549
            }
550
            AbstractLiteral::Variant(entries) => {
551
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(entries);
552
                (
553
                    f1_tree,
554
                    Box::new(move |x| AbstractLiteral::Variant(f1_ctx(x))),
555
                )
556
            }
557
            AbstractLiteral::Relation(elems) => {
558
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
559
                (
560
                    f1_tree,
561
                    Box::new(move |x| AbstractLiteral::Relation(f1_ctx(x))),
562
                )
563
            }
564
            AbstractLiteral::Partition(elems) => {
565
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
566
                (
567
                    f1_tree,
568
                    Box::new(move |x| AbstractLiteral::Partition(f1_ctx(x))),
569
                )
570
            }
571
        }
572
34546
    }
573
}
574

            
575
impl<U, To> Biplate<To> for AbstractLiteral<U>
576
where
577
    To: Uniplate,
578
    U: AbstractLiteralValue + Biplate<AbstractLiteral<U>> + Biplate<To>,
579
    FieldValue<U>: Biplate<AbstractLiteral<U>> + Biplate<To>,
580
{
581
87535847
    fn biplate(&self) -> (Tree<To>, Box<dyn Fn(Tree<To>) -> Self>) {
582
87535847
        if std::any::TypeId::of::<To>() == std::any::TypeId::of::<AbstractLiteral<U>>() {
583
            // To ==From => return One(self)
584

            
585
            unsafe {
586
                // SAFETY: asserted the type equality above
587
17845
                let self_to = std::mem::transmute::<&AbstractLiteral<U>, &To>(self).clone();
588
17845
                let tree = Tree::One(self_to);
589
17845
                let ctx = Box::new(move |x| {
590
                    let Tree::One(x) = x else {
591
                        panic!();
592
                    };
593

            
594
                    std::mem::transmute::<&To, &AbstractLiteral<U>>(&x).clone()
595
                });
596

            
597
17845
                (tree, ctx)
598
            }
599
        } else {
600
            // walking into T
601
87518002
            match self {
602
204434
                AbstractLiteral::Set(vec) => {
603
204434
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(vec);
604
204434
                    (f1_tree, Box::new(move |x| AbstractLiteral::Set(f1_ctx(x))))
605
                }
606
240
                AbstractLiteral::MSet(vec) => {
607
240
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(vec);
608
240
                    (f1_tree, Box::new(move |x| AbstractLiteral::MSet(f1_ctx(x))))
609
                }
610
87268708
                AbstractLiteral::Matrix(elems, index_domain) => {
611
87268708
                    let index_domain = index_domain.clone();
612
87268708
                    let (f1_tree, f1_ctx) = <Vec<U> as Biplate<To>>::biplate(elems);
613
                    (
614
87268708
                        f1_tree,
615
87268708
                        Box::new(move |x| AbstractLiteral::Matrix(f1_ctx(x), index_domain.clone())),
616
                    )
617
                }
618
620
                AbstractLiteral::Sequence(vec) => {
619
620
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(vec);
620
                    (
621
620
                        f1_tree,
622
620
                        Box::new(move |x| AbstractLiteral::Sequence(f1_ctx(x))),
623
                    )
624
                }
625
17200
                AbstractLiteral::Tuple(elems) => {
626
17200
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(elems);
627
                    (
628
17200
                        f1_tree,
629
17200
                        Box::new(move |x| AbstractLiteral::Tuple(f1_ctx(x))),
630
                    )
631
                }
632
25300
                AbstractLiteral::Record(entries) => {
633
25300
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(entries);
634
                    (
635
25300
                        f1_tree,
636
25300
                        Box::new(move |x| AbstractLiteral::Record(f1_ctx(x))),
637
                    )
638
                }
639
940
                AbstractLiteral::Function(entries) => {
640
940
                    let entry_count = entries.len();
641
940
                    let flattened: Vec<U> = entries
642
940
                        .iter()
643
1880
                        .flat_map(|(lhs, rhs)| [lhs.clone(), rhs.clone()])
644
940
                        .collect();
645

            
646
940
                    let (f1_tree, f1_ctx) = <Vec<U> as Biplate<To>>::biplate(&flattened);
647
                    (
648
940
                        f1_tree,
649
940
                        Box::new(move |x| {
650
                            let rebuilt = f1_ctx(x);
651
                            assert_eq!(
652
                                rebuilt.len(),
653
                                entry_count * 2,
654
                                "number of function literal children should remain unchanged"
655
                            );
656

            
657
                            let mut iter = rebuilt.into_iter();
658
                            let mut pairs = Vec::with_capacity(entry_count);
659
                            while let (Some(lhs), Some(rhs)) = (iter.next(), iter.next()) {
660
                                pairs.push((lhs, rhs));
661
                            }
662

            
663
                            AbstractLiteral::Function(pairs)
664
                        }),
665
                    )
666
                }
667
140
                AbstractLiteral::Variant(entries) => {
668
140
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(entries);
669
                    (
670
140
                        f1_tree,
671
140
                        Box::new(move |x| AbstractLiteral::Variant(f1_ctx(x))),
672
                    )
673
                }
674
280
                AbstractLiteral::Relation(elems) => {
675
280
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(elems);
676
                    (
677
280
                        f1_tree,
678
280
                        Box::new(move |x| AbstractLiteral::Relation(f1_ctx(x))),
679
                    )
680
                }
681
140
                AbstractLiteral::Partition(elems) => {
682
140
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(elems);
683
                    (
684
140
                        f1_tree,
685
140
                        Box::new(move |x| AbstractLiteral::Partition(f1_ctx(x))),
686
                    )
687
                }
688
            }
689
        }
690
87535847
    }
691
}
692

            
693
impl TryFrom<Literal> for i32 {
694
    type Error = &'static str;
695

            
696
2722488
    fn try_from(value: Literal) -> Result<Self, Self::Error> {
697
2722488
        match value {
698
2560840
            Literal::Int(i) => Ok(i),
699
161648
            _ => Err("Cannot convert non-i32 literal to i32"),
700
        }
701
2722488
    }
702
}
703

            
704
impl TryFrom<Box<Literal>> for i32 {
705
    type Error = &'static str;
706

            
707
    fn try_from(value: Box<Literal>) -> Result<Self, Self::Error> {
708
        (*value).try_into()
709
    }
710
}
711

            
712
impl TryFrom<&Box<Literal>> for i32 {
713
    type Error = &'static str;
714

            
715
162800
    fn try_from(value: &Box<Literal>) -> Result<Self, Self::Error> {
716
162800
        TryFrom::<&Literal>::try_from(value.as_ref())
717
162800
    }
718
}
719

            
720
impl TryFrom<&Moo<Literal>> for i32 {
721
    type Error = &'static str;
722

            
723
    fn try_from(value: &Moo<Literal>) -> Result<Self, Self::Error> {
724
        TryFrom::<&Literal>::try_from(value.as_ref())
725
    }
726
}
727

            
728
impl TryFrom<&Literal> for i32 {
729
    type Error = &'static str;
730

            
731
3075116
    fn try_from(value: &Literal) -> Result<Self, Self::Error> {
732
3075116
        match value {
733
3075116
            Literal::Int(i) => Ok(*i),
734
            _ => Err("Cannot convert non-i32 literal to i32"),
735
        }
736
3075116
    }
737
}
738

            
739
impl TryFrom<Literal> for bool {
740
    type Error = &'static str;
741

            
742
1703256
    fn try_from(value: Literal) -> Result<Self, Self::Error> {
743
1703256
        match value {
744
1679300
            Literal::Bool(b) => Ok(b),
745
23956
            _ => Err("Cannot convert non-bool literal to bool"),
746
        }
747
1703256
    }
748
}
749

            
750
impl TryFrom<&Literal> for bool {
751
    type Error = &'static str;
752

            
753
246896
    fn try_from(value: &Literal) -> Result<Self, Self::Error> {
754
246896
        match value {
755
246896
            Literal::Bool(b) => Ok(*b),
756
            _ => Err("Cannot convert non-bool literal to bool"),
757
        }
758
246896
    }
759
}
760

            
761
impl From<i32> for Literal {
762
3795542
    fn from(i: i32) -> Self {
763
3795542
        Literal::Int(i)
764
3795542
    }
765
}
766

            
767
impl From<bool> for Literal {
768
166938
    fn from(b: bool) -> Self {
769
166938
        Literal::Bool(b)
770
166938
    }
771
}
772

            
773
impl From<Literal> for Ustr {
774
1760
    fn from(value: Literal) -> Self {
775
        // TODO: avoid the temporary-allocation of a string by format! here?
776
1760
        Ustr::from(&format!("{value}"))
777
1760
    }
778
}
779

            
780
impl AbstractLiteral<Expression> {
781
    /// If all the elements are literals, returns this as an AbstractLiteral<Literal>.
782
    /// Otherwise, returns `None`.
783
5011884
    pub fn into_literals(self) -> Option<AbstractLiteral<Literal>> {
784
5011884
        match self {
785
4724
            AbstractLiteral::Set(elements) => {
786
4724
                let literals = elements
787
4724
                    .into_iter()
788
11846
                    .map(|expr| match expr {
789
11846
                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
790
                        Expression::AbstractLiteral(_, abslit) => {
791
                            Some(Literal::AbstractLiteral(abslit.into_literals()?))
792
                        }
793
                        _ => None,
794
11846
                    })
795
4724
                    .collect::<Option<Vec<_>>>()?;
796
4724
                Some(AbstractLiteral::Set(literals))
797
            }
798
            AbstractLiteral::MSet(elements) => {
799
                let literals = elements
800
                    .into_iter()
801
                    .map(|expr| match expr {
802
                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
803
                        Expression::AbstractLiteral(_, abslit) => {
804
                            Some(Literal::AbstractLiteral(abslit.into_literals()?))
805
                        }
806
                        _ => None,
807
                    })
808
                    .collect::<Option<Vec<_>>>()?;
809
                Some(AbstractLiteral::MSet(literals))
810
            }
811
            AbstractLiteral::Partition(elems) => {
812
                // want to ascertain if every elem in Vec<Vec<Expr>> is a literal. If any are not, return none
813
                // otherwise confirm it is an abslit<lit>
814
                let mut partition: Vec<Vec<_>> = Vec::new();
815

            
816
                for part in elems {
817
                    let literals = part
818
                        .into_iter()
819
                        .map(|expr| match expr {
820
                            Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
821
                            Expression::AbstractLiteral(_, abslit) => {
822
                                Some(Literal::AbstractLiteral(abslit.into_literals()?))
823
                            }
824
                            _ => None,
825
                        })
826
                        .collect::<Option<Vec<_>>>()?;
827

            
828
                    partition.push(literals);
829
                }
830

            
831
                Some(AbstractLiteral::Partition(partition))
832
            }
833
5006840
            AbstractLiteral::Matrix(items, domain) => {
834
5006840
                let mut literals = vec![];
835
7052988
                for item in items {
836
5662932
                    let literal = match item {
837
2535952
                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
838
257320
                        Expression::AbstractLiteral(_, abslit) => {
839
257320
                            Some(Literal::AbstractLiteral(abslit.into_literals()?))
840
                        }
841
4259716
                        _ => None,
842
4259716
                    }?;
843
2792152
                    literals.push(literal);
844
                }
845

            
846
746004
                Some(AbstractLiteral::Matrix(literals, domain.resolve()?))
847
            }
848
            AbstractLiteral::Sequence(elements) => {
849
                let literals = elements
850
                    .into_iter()
851
                    .map(|expr| match expr {
852
                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
853
                        Expression::AbstractLiteral(_, abslit) => {
854
                            Some(Literal::AbstractLiteral(abslit.into_literals()?))
855
                        }
856
                        _ => None,
857
                    })
858
                    .collect::<Option<Vec<_>>>()?;
859
                Some(AbstractLiteral::Sequence(literals))
860
            }
861
240
            AbstractLiteral::Tuple(items) => {
862
240
                let mut literals = vec![];
863
480
                for item in items {
864
480
                    let literal = match item {
865
480
                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
866
                        Expression::AbstractLiteral(_, abslit) => {
867
                            Some(Literal::AbstractLiteral(abslit.into_literals()?))
868
                        }
869
                        _ => None,
870
                    }?;
871
480
                    literals.push(literal);
872
                }
873

            
874
240
                Some(AbstractLiteral::Tuple(literals))
875
            }
876
80
            AbstractLiteral::Record(entries) => {
877
80
                let mut literals = vec![];
878
160
                for entry in entries {
879
160
                    let literal = match entry.value {
880
160
                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
881
                        Expression::AbstractLiteral(_, abslit) => {
882
                            Some(Literal::AbstractLiteral(abslit.into_literals()?))
883
                        }
884
                        _ => None,
885
                    }?;
886

            
887
160
                    literals.push((entry.name, literal));
888
                }
889
                Some(AbstractLiteral::Record(
890
80
                    literals
891
80
                        .into_iter()
892
80
                        .map(|(name, literal)| FieldValue {
893
160
                            name,
894
160
                            value: literal,
895
160
                        })
896
80
                        .collect(),
897
                ))
898
            }
899
            AbstractLiteral::Function(_) => todo!("Implement into_literals for functions"),
900
            AbstractLiteral::Variant(entry) => {
901
                let literal = match entry.value.clone() {
902
                    Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
903
                    Expression::AbstractLiteral(_, abslit) => {
904
                        Some(Literal::AbstractLiteral(abslit.into_literals()?))
905
                    }
906
                    _ => None,
907
                }?;
908
                Some(AbstractLiteral::Variant(Moo::new(FieldValue {
909
                    name: entry.name.clone(),
910
                    value: literal,
911
                })))
912
            }
913
            AbstractLiteral::Relation(_) => todo!("Implement into_literals for relations"),
914
        }
915
5011884
    }
916
}
917

            
918
// need display implementations for other types as well
919
impl Display for Literal {
920
5957760
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
921
5957760
        match &self {
922
4921772
            Literal::Int(i) => write!(f, "{i}"),
923
857148
            Literal::Bool(b) => write!(f, "{b}"),
924
178840
            Literal::AbstractLiteral(l) => write!(f, "{l}"),
925
        }
926
5957760
    }
927
}
928

            
929
#[cfg(test)]
mod tests {

    use super::*;
    use crate::{into_matrix, matrix};
    use uniplate::Uniplate;

    /// A catamorphism over a matrix of matrices should reach the outer matrix and each of its
    /// five inner matrices, so six index domains are collected in total.
    #[test]
    fn matrix_uniplate_universe() {
        let inner = Literal::AbstractLiteral(matrix![Literal::Bool(true);Moo::new(GroundDomain::Bool)]);
        let outer: AbstractLiteral<Literal> = into_matrix![
            vec![inner; 5];
            Moo::new(GroundDomain::Bool)
        ];

        let collected: Vec<Moo<GroundDomain>> = outer.cata(&|node, child_results| {
            // Children's domains come first, then this node's own index domain (if any).
            let mut domains: Vec<_> = child_results.into_iter().flatten().collect();
            if let AbstractLiteral::Matrix(_, index_domain) = node {
                domains.push(index_domain);
            }
            domains
        });

        let expected = vec![Moo::new(GroundDomain::Bool); 6];
        assert_eq!(collected, expected);
    }
}