1
use itertools::Itertools;
2
use serde::{Deserialize, Serialize};
3
use std::fmt::{Display, Formatter};
4
use std::hash::Hash;
5
use ustr::Ustr;
6

            
7
use super::{
8
    Atom, Domain, DomainPtr, Expression, GroundDomain, Metadata, Moo, Range, ReturnType, SetAttr,
9
    Typeable, domains::HasDomain, domains::Int, records::RecordValue,
10
};
11
use crate::ast::pretty::pretty_vec;
12
use crate::bug;
13
use polyquine::Quine;
14
use uniplate::{Biplate, Tree, Uniplate};
15

            
16
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate, Hash, Quine)]
#[uniplate(walk_into=[AbstractLiteral<Literal>])]
#[biplate(to=Atom)]
#[biplate(to=AbstractLiteral<Literal>)]
#[biplate(to=AbstractLiteral<Expression>)]
#[biplate(to=RecordValue<Literal>)]
#[biplate(to=RecordValue<Expression>)]
#[biplate(to=Expression)]
#[path_prefix(conjure_cp::ast)]
/// A literal value, equivalent to constants in Conjure.
///
/// A literal is either a scalar (`Int`, `Bool`) or a compound [`AbstractLiteral`] whose elements
/// are themselves literals.
pub enum Literal {
    /// An integer constant.
    Int(i32),
    /// A boolean constant.
    Bool(bool),
    /// A compound literal (set, matrix, tuple, record or function) of literal values.
    // The variant name ends in `Literal`, which trips clippy's `enum_variant_names` lint; the
    // name is kept because it mirrors the wrapped `AbstractLiteral` type.
    #[allow(clippy::enum_variant_names)]
    AbstractLiteral(AbstractLiteral<Literal>),
}
33

            
34
impl HasDomain for Literal {
35
    fn domain_of(&self) -> DomainPtr {
36
        match self {
37
            Literal::Int(i) => Domain::int(vec![Range::Single(*i)]),
38
            Literal::Bool(_) => Domain::bool(),
39
            Literal::AbstractLiteral(abstract_literal) => abstract_literal.domain_of(),
40
        }
41
    }
42
}
43

            
44
/// Element types that may appear inside an [`AbstractLiteral`].
///
/// The set of possible element types is deliberately kept closed (see the implementations for
/// [`Expression`] and [`Literal`] below) to keep the trait bounds manageable, particularly in
/// the `Uniplate`/`Biplate` instances for `AbstractLiteral`.
pub trait AbstractLiteralValue:
    Clone + Eq + PartialEq + Display + Uniplate + Biplate<RecordValue<Self>> + 'static
{
    /// The type used for the index domain of a matrix literal with this element type
    /// (the second field of `AbstractLiteral::Matrix`).
    type Dom: Clone + Eq + PartialEq + Display + Quine + From<GroundDomain> + Into<DomainPtr>;
}
/// Abstract literals over expressions use general [`DomainPtr`] index domains.
impl AbstractLiteralValue for Expression {
    type Dom = DomainPtr;
}
/// Abstract literals over literals use ground index domains.
impl AbstractLiteralValue for Literal {
    type Dom = Moo<GroundDomain>;
}
56

            
57
/// A compound literal whose elements have type `T` ([`Expression`] or [`Literal`]).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Quine)]
#[path_prefix(conjure_cp::ast)]
pub enum AbstractLiteral<T: AbstractLiteralValue> {
    /// A set literal, with its elements stored as a `Vec`.
    Set(Vec<T>),

    /// A 1 dimensional matrix slice with an index domain.
    ///
    /// An n-dimensional matrix is represented as a matrix whose elements are matrices.
    Matrix(Vec<T>, T::Dom),

    /// A tuple of values.
    Tuple(Vec<T>),

    /// A record: a sequence of named fields.
    Record(Vec<RecordValue<T>>),

    /// A function literal, given as `(argument, result)` pairs (displayed as `a --> b`).
    Function(Vec<(T, T)>),
}
72

            
73
// TODO: use HasDomain instead once Expression::domain_of returns Domain not Option<Domain>
74
impl AbstractLiteral<Expression> {
75
    pub fn domain_of(&self) -> Option<DomainPtr> {
76
        match self {
77
            AbstractLiteral::Set(items) => {
78
                // ensure that all items have a domain, or return None
79
                let item_domains: Vec<DomainPtr> = items
80
                    .iter()
81
                    .map(|x| x.domain_of())
82
                    .collect::<Option<Vec<DomainPtr>>>()?;
83

            
84
                // union all item domains together
85
                let mut item_domain_iter = item_domains.iter().cloned();
86
                let first_item = item_domain_iter.next()?;
87
                let item_domain = item_domains
88
                    .iter()
89
                    .try_fold(first_item, |x, y| x.union(y))
90
                    .expect("taking the union of all item domains of a set literal should succeed");
91

            
92
                Some(Domain::set(SetAttr::<Int>::default(), item_domain))
93
            }
94

            
95
            AbstractLiteral::Matrix(items, _) => {
96
                // ensure that all items have a domain, or return None
97
                let item_domains = items
98
                    .iter()
99
                    .map(|x| x.domain_of())
100
                    .collect::<Option<Vec<DomainPtr>>>()?;
101

            
102
                // union all item domains together
103
                let mut item_domain_iter = item_domains.iter().cloned();
104

            
105
                let first_item = item_domain_iter.next()?;
106

            
107
                let item_domain = item_domains
108
                    .iter()
109
                    .try_fold(first_item, |x, y| x.union(y))
110
                    .expect(
111
                        "taking the union of all item domains of a matrix literal should succeed",
112
                    );
113

            
114
                let mut new_index_domain = vec![];
115

            
116
                // flatten index domains of n-d matrix into list
117
                let mut e = Expression::AbstractLiteral(Metadata::new(), self.clone());
118
                while let Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, idx)) = e {
119
                    assert!(
120
                        idx.as_matrix().is_none(),
121
                        "n-dimensional matrix literals should be represented as a matrix inside a matrix, got {idx}"
122
                    );
123
                    new_index_domain.push(idx);
124
                    e = elems[0].clone();
125
                }
126
                Some(Domain::matrix(item_domain, new_index_domain))
127
            }
128
            AbstractLiteral::Tuple(_) => None,
129
            AbstractLiteral::Record(_) => None,
130
            AbstractLiteral::Function(_) => None,
131
        }
132
    }
133

            
134
    pub fn list(exprs: Vec<Expression>) -> Self {
135
        let domain = Domain::int_ground(vec![Range::UnboundedR(1)]);
136
        AbstractLiteral::Matrix(exprs, domain)
137
    }
138
}
139

            
140
impl HasDomain for AbstractLiteral<Literal> {
141
    fn domain_of(&self) -> DomainPtr {
142
        Domain::from_literal_vec(&[Literal::AbstractLiteral(self.clone())])
143
            .expect("abstract literals should be correctly typed")
144
    }
145
}
146

            
147
impl Typeable for AbstractLiteral<Expression> {
148
    fn return_type(&self) -> ReturnType {
149
        match self {
150
            AbstractLiteral::Set(items) if items.is_empty() => {
151
                ReturnType::Set(Box::new(ReturnType::Unknown))
152
            }
153
            AbstractLiteral::Set(items) => {
154
                let item_type = items[0].return_type();
155

            
156
                // if any items do not have a type, return none.
157
                let item_types: Vec<ReturnType> = items.iter().map(|x| x.return_type()).collect();
158

            
159
                assert!(
160
                    item_types.iter().all(|x| x == &item_type),
161
                    "all items in a set should have the same type"
162
                );
163

            
164
                ReturnType::Set(Box::new(item_type))
165
            }
166
            AbstractLiteral::Matrix(items, _) if items.is_empty() => {
167
                ReturnType::Matrix(Box::new(ReturnType::Unknown))
168
            }
169
            AbstractLiteral::Matrix(items, _) => {
170
                let item_type = items[0].return_type();
171

            
172
                // if any items do not have a type, return none.
173
                let item_types: Vec<ReturnType> = items.iter().map(|x| x.return_type()).collect();
174

            
175
                assert!(
176
                    item_types.iter().all(|x| x == &item_type),
177
                    "all items in a matrix should have the same type. items: {items} types: {types:#?}",
178
                    items = pretty_vec(items),
179
                    types = items
180
                        .iter()
181
                        .map(|x| x.return_type())
182
                        .collect::<Vec<ReturnType>>()
183
                );
184

            
185
                ReturnType::Matrix(Box::new(item_type))
186
            }
187
            AbstractLiteral::Tuple(items) => {
188
                let mut item_types = vec![];
189
                for item in items {
190
                    item_types.push(item.return_type());
191
                }
192
                ReturnType::Tuple(item_types)
193
            }
194
            AbstractLiteral::Record(items) => {
195
                let mut item_types = vec![];
196
                for item in items {
197
                    item_types.push(item.value.return_type());
198
                }
199
                ReturnType::Record(item_types)
200
            }
201
            AbstractLiteral::Function(items) => {
202
                if items.is_empty() {
203
                    return ReturnType::Function(
204
                        Box::new(ReturnType::Unknown),
205
                        Box::new(ReturnType::Unknown),
206
                    );
207
                }
208

            
209
                // Check that all items have the same return type
210
                let (x1, y1) = &items[0];
211
                let (t1, t2) = (x1.return_type(), y1.return_type());
212
                for (x, y) in items {
213
                    let (tx, ty) = (x.return_type(), y.return_type());
214
                    if tx != t1 {
215
                        bug!("Expected {t1}, got {x}: {tx}");
216
                    }
217
                    if ty != t2 {
218
                        bug!("Expected {t2}, got {y}: {ty}");
219
                    }
220
                }
221

            
222
                ReturnType::Function(Box::new(t1), Box::new(t2))
223
            }
224
        }
225
    }
226
}
227

            
228
impl<T> AbstractLiteral<T>
229
where
230
    T: AbstractLiteralValue,
231
{
232
    /// Creates a matrix with elements `elems`, with domain `int(1..)`.
233
    ///
234
    /// This acts as a variable sized list.
235
    pub fn matrix_implied_indices(elems: Vec<T>) -> Self {
236
        AbstractLiteral::Matrix(elems, GroundDomain::Int(vec![Range::UnboundedR(1)]).into())
237
    }
238

            
239
    /// If the AbstractLiteral is a list, returns its elements.
240
    ///
241
    /// A list is any a matrix with the domain `int(1..)`. This includes matrix literals without
242
    /// any explicitly specified domain.
243
    pub fn unwrap_list(&self) -> Option<&Vec<T>> {
244
        let AbstractLiteral::Matrix(elems, domain) = self else {
245
            return None;
246
        };
247

            
248
        let domain: DomainPtr = domain.clone().into();
249
        let Some(GroundDomain::Int(ranges)) = domain.as_ground() else {
250
            return None;
251
        };
252

            
253
        let [Range::UnboundedR(1)] = ranges[..] else {
254
            return None;
255
        };
256

            
257
        Some(elems)
258
    }
259
}
260

            
261
impl<T> Display for AbstractLiteral<T>
262
where
263
    T: AbstractLiteralValue,
264
{
265
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
266
        match self {
267
            AbstractLiteral::Set(elems) => {
268
                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
269
                write!(f, "{{{elems_str}}}")
270
            }
271
            AbstractLiteral::Matrix(elems, index_domain) => {
272
                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
273
                write!(f, "[{elems_str};{index_domain}]")
274
            }
275
            AbstractLiteral::Tuple(elems) => {
276
                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
277
                write!(f, "({elems_str})")
278
            }
279
            AbstractLiteral::Record(entries) => {
280
                let entries_str: String = entries
281
                    .iter()
282
                    .map(|entry| format!("{}: {}", entry.name, entry.value))
283
                    .join(",");
284
                write!(f, "{{{entries_str}}}")
285
            }
286
            AbstractLiteral::Function(entries) => {
287
                let entries_str: String = entries
288
                    .iter()
289
                    .map(|entry| format!("{} --> {}", entry.0, entry.1))
290
                    .join(",");
291
                write!(f, "function({entries_str})")
292
            }
293
        }
294
    }
295
}
296

            
297
/// `Uniplate` instance for abstract literals.
///
/// The children of an abstract literal are the `AbstractLiteral<T>` values reachable from its
/// elements through `T: Biplate<AbstractLiteral<T>>`. A matrix's index domain is carried through
/// unchanged and is not traversed.
impl<T> Uniplate for AbstractLiteral<T>
where
    T: AbstractLiteralValue + Biplate<AbstractLiteral<T>>,
{
    fn uniplate(&self) -> (Tree<Self>, Box<dyn Fn(Tree<Self>) -> Self>) {
        // walking into T: each variant delegates to the Biplate instance of its element
        // container and wraps the rebuilt container back into the same variant.
        match self {
            AbstractLiteral::Set(vec) => {
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(vec);
                (f1_tree, Box::new(move |x| AbstractLiteral::Set(f1_ctx(x))))
            }
            AbstractLiteral::Matrix(elems, index_domain) => {
                // The index domain is not a child; it is moved into the context so the rebuilt
                // matrix keeps the same index domain.
                let index_domain = index_domain.clone();
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
                (
                    f1_tree,
                    Box::new(move |x| AbstractLiteral::Matrix(f1_ctx(x), index_domain.clone())),
                )
            }
            AbstractLiteral::Tuple(elems) => {
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
                (
                    f1_tree,
                    Box::new(move |x| AbstractLiteral::Tuple(f1_ctx(x))),
                )
            }
            AbstractLiteral::Record(entries) => {
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(entries);
                (
                    f1_tree,
                    Box::new(move |x| AbstractLiteral::Record(f1_ctx(x))),
                )
            }
            AbstractLiteral::Function(entries) => {
                // Flatten `[(a, b), (c, d), ..]` into `[a, b, c, d, ..]` so the pairs can be
                // traversed as a plain `Vec<T>`; the context re-pairs consecutive elements.
                let entry_count = entries.len();
                let flattened: Vec<T> = entries
                    .iter()
                    .flat_map(|(lhs, rhs)| [lhs.clone(), rhs.clone()])
                    .collect();

                let (f1_tree, f1_ctx) =
                    <Vec<T> as Biplate<AbstractLiteral<T>>>::biplate(&flattened);
                (
                    f1_tree,
                    Box::new(move |x| {
                        let rebuilt = f1_ctx(x);
                        // Re-pairing is only meaningful if the element count is preserved.
                        assert_eq!(
                            rebuilt.len(),
                            entry_count * 2,
                            "number of function literal children should remain unchanged"
                        );

                        // Tuple fields are evaluated left to right, so on each iteration `lhs`
                        // is taken before `rhs`.
                        let mut iter = rebuilt.into_iter();
                        let mut pairs = Vec::with_capacity(entry_count);
                        while let (Some(lhs), Some(rhs)) = (iter.next(), iter.next()) {
                            pairs.push((lhs, rhs));
                        }

                        AbstractLiteral::Function(pairs)
                    }),
                )
            }
        }
    }
}
362

            
363
/// `Biplate` instance from abstract literals to an arbitrary target type `To`.
///
/// If `To` is `AbstractLiteral<U>` itself, the literal is its own single target. Otherwise the
/// traversal walks into the elements (or record entries) using their `Biplate<To>` instances; a
/// matrix's index domain is carried through unchanged.
impl<U, To> Biplate<To> for AbstractLiteral<U>
where
    To: Uniplate,
    U: AbstractLiteralValue + Biplate<AbstractLiteral<U>> + Biplate<To>,
    RecordValue<U>: Biplate<AbstractLiteral<U>> + Biplate<To>,
{
    fn biplate(&self) -> (Tree<To>, Box<dyn Fn(Tree<To>) -> Self>) {
        if std::any::TypeId::of::<To>() == std::any::TypeId::of::<AbstractLiteral<U>>() {
            // To == From => return One(self)

            unsafe {
                // SAFETY: the `TypeId` comparison above shows that `To` and `AbstractLiteral<U>`
                // are the same type, so reinterpreting a reference to one as a reference to the
                // other is sound; the referenced value is then cloned out.
                let self_to = std::mem::transmute::<&AbstractLiteral<U>, &To>(self).clone();
                let tree = Tree::One(self_to);
                let ctx = Box::new(move |x| {
                    // The context is expected to receive a tree of the same shape as the one
                    // built above (a single `Tree::One`); anything else is a caller bug.
                    let Tree::One(x) = x else {
                        panic!();
                    };

                    // SAFETY: as above, `To` and `AbstractLiteral<U>` are the same type.
                    std::mem::transmute::<&To, &AbstractLiteral<U>>(&x).clone()
                });

                (tree, ctx)
            }
        } else {
            // walking into T
            match self {
                AbstractLiteral::Set(vec) => {
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(vec);
                    (f1_tree, Box::new(move |x| AbstractLiteral::Set(f1_ctx(x))))
                }
                AbstractLiteral::Matrix(elems, index_domain) => {
                    // The index domain is not traversed; it is moved into the context so the
                    // rebuilt matrix keeps the same index domain.
                    let index_domain = index_domain.clone();
                    let (f1_tree, f1_ctx) = <Vec<U> as Biplate<To>>::biplate(elems);
                    (
                        f1_tree,
                        Box::new(move |x| AbstractLiteral::Matrix(f1_ctx(x), index_domain.clone())),
                    )
                }
                AbstractLiteral::Tuple(elems) => {
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(elems);
                    (
                        f1_tree,
                        Box::new(move |x| AbstractLiteral::Tuple(f1_ctx(x))),
                    )
                }
                AbstractLiteral::Record(entries) => {
                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(entries);
                    (
                        f1_tree,
                        Box::new(move |x| AbstractLiteral::Record(f1_ctx(x))),
                    )
                }
                AbstractLiteral::Function(entries) => {
                    // Flatten `[(a, b), (c, d), ..]` into `[a, b, c, d, ..]` for traversal; the
                    // context re-pairs consecutive elements afterwards.
                    let entry_count = entries.len();
                    let flattened: Vec<U> = entries
                        .iter()
                        .flat_map(|(lhs, rhs)| [lhs.clone(), rhs.clone()])
                        .collect();

                    let (f1_tree, f1_ctx) = <Vec<U> as Biplate<To>>::biplate(&flattened);
                    (
                        f1_tree,
                        Box::new(move |x| {
                            let rebuilt = f1_ctx(x);
                            // Re-pairing is only meaningful if the element count is preserved.
                            assert_eq!(
                                rebuilt.len(),
                                entry_count * 2,
                                "number of function literal children should remain unchanged"
                            );

                            // Tuple fields are evaluated left to right, so `lhs` is taken
                            // before `rhs` on each iteration.
                            let mut iter = rebuilt.into_iter();
                            let mut pairs = Vec::with_capacity(entry_count);
                            while let (Some(lhs), Some(rhs)) = (iter.next(), iter.next()) {
                                pairs.push((lhs, rhs));
                            }

                            AbstractLiteral::Function(pairs)
                        }),
                    )
                }
            }
        }
    }
}
448

            
449
impl TryFrom<Literal> for i32 {
450
    type Error = &'static str;
451

            
452
    fn try_from(value: Literal) -> Result<Self, Self::Error> {
453
        match value {
454
            Literal::Int(i) => Ok(i),
455
            _ => Err("Cannot convert non-i32 literal to i32"),
456
        }
457
    }
458
}
459

            
460
impl TryFrom<Box<Literal>> for i32 {
461
    type Error = &'static str;
462

            
463
    fn try_from(value: Box<Literal>) -> Result<Self, Self::Error> {
464
        (*value).try_into()
465
    }
466
}
467

            
468
impl TryFrom<&Box<Literal>> for i32 {
469
    type Error = &'static str;
470

            
471
    fn try_from(value: &Box<Literal>) -> Result<Self, Self::Error> {
472
        TryFrom::<&Literal>::try_from(value.as_ref())
473
    }
474
}
475

            
476
impl TryFrom<&Moo<Literal>> for i32 {
477
    type Error = &'static str;
478

            
479
    fn try_from(value: &Moo<Literal>) -> Result<Self, Self::Error> {
480
        TryFrom::<&Literal>::try_from(value.as_ref())
481
    }
482
}
483

            
484
impl TryFrom<&Literal> for i32 {
485
    type Error = &'static str;
486

            
487
    fn try_from(value: &Literal) -> Result<Self, Self::Error> {
488
        match value {
489
            Literal::Int(i) => Ok(*i),
490
            _ => Err("Cannot convert non-i32 literal to i32"),
491
        }
492
    }
493
}
494

            
495
impl TryFrom<Literal> for bool {
496
    type Error = &'static str;
497

            
498
    fn try_from(value: Literal) -> Result<Self, Self::Error> {
499
        match value {
500
            Literal::Bool(b) => Ok(b),
501
            _ => Err("Cannot convert non-bool literal to bool"),
502
        }
503
    }
504
}
505

            
506
impl TryFrom<&Literal> for bool {
507
    type Error = &'static str;
508

            
509
    fn try_from(value: &Literal) -> Result<Self, Self::Error> {
510
        match value {
511
            Literal::Bool(b) => Ok(*b),
512
            _ => Err("Cannot convert non-bool literal to bool"),
513
        }
514
    }
515
}
516

            
517
impl From<i32> for Literal {
518
    fn from(i: i32) -> Self {
519
        Literal::Int(i)
520
    }
521
}
522

            
523
impl From<bool> for Literal {
524
    fn from(b: bool) -> Self {
525
        Literal::Bool(b)
526
    }
527
}
528

            
529
impl From<Literal> for Ustr {
530
    fn from(value: Literal) -> Self {
531
        // TODO: avoid the temporary-allocation of a string by format! here?
532
        Ustr::from(&format!("{value}"))
533
    }
534
}
535

            
536
impl AbstractLiteral<Expression> {
537
    /// If all the elements are literals, returns this as an AbstractLiteral<Literal>.
538
    /// Otherwise, returns `None`.
539
    pub fn into_literals(self) -> Option<AbstractLiteral<Literal>> {
540
        match self {
541
            AbstractLiteral::Set(_) => todo!(),
542
            AbstractLiteral::Matrix(items, domain) => {
543
                let mut literals = vec![];
544
                for item in items {
545
                    let literal = match item {
546
                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
547
                        Expression::AbstractLiteral(_, abslit) => {
548
                            Some(Literal::AbstractLiteral(abslit.into_literals()?))
549
                        }
550
                        _ => None,
551
                    }?;
552
                    literals.push(literal);
553
                }
554

            
555
                Some(AbstractLiteral::Matrix(literals, domain.resolve()?))
556
            }
557
            AbstractLiteral::Tuple(items) => {
558
                let mut literals = vec![];
559
                for item in items {
560
                    let literal = match item {
561
                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
562
                        Expression::AbstractLiteral(_, abslit) => {
563
                            Some(Literal::AbstractLiteral(abslit.into_literals()?))
564
                        }
565
                        _ => None,
566
                    }?;
567
                    literals.push(literal);
568
                }
569

            
570
                Some(AbstractLiteral::Tuple(literals))
571
            }
572
            AbstractLiteral::Record(entries) => {
573
                let mut literals = vec![];
574
                for entry in entries {
575
                    let literal = match entry.value {
576
                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
577
                        Expression::AbstractLiteral(_, abslit) => {
578
                            Some(Literal::AbstractLiteral(abslit.into_literals()?))
579
                        }
580
                        _ => None,
581
                    }?;
582

            
583
                    literals.push((entry.name, literal));
584
                }
585
                Some(AbstractLiteral::Record(
586
                    literals
587
                        .into_iter()
588
                        .map(|(name, literal)| RecordValue {
589
                            name,
590
                            value: literal,
591
                        })
592
                        .collect(),
593
                ))
594
            }
595
            AbstractLiteral::Function(_) => todo!("Implement into_literals for functions"),
596
        }
597
    }
598
}
599

            
600
// need display implementations for other types as well
601
impl Display for Literal {
602
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
603
        match &self {
604
            Literal::Int(i) => write!(f, "{i}"),
605
            Literal::Bool(b) => write!(f, "{b}"),
606
            Literal::AbstractLiteral(l) => write!(f, "{l:?}"),
607
        }
608
    }
609
}
610

            
611
#[cfg(test)]
mod tests {

    use super::*;
    use crate::{into_matrix, matrix};
    use uniplate::Uniplate;

    /// Checks that folding with `cata` over a nested matrix literal visits every matrix: an
    /// outer matrix of five inner matrices should yield six index domains, all `bool`.
    #[test]
    fn matrix_uniplate_universe() {
        // Can we traverse through matrices with uniplate?
        let my_matrix: AbstractLiteral<Literal> = into_matrix![
            vec![Literal::AbstractLiteral(matrix![Literal::Bool(true);Moo::new(GroundDomain::Bool)]); 5];
            Moo::new(GroundDomain::Bool)
        ];

        // One index domain for the outer matrix plus one for each of the five inner matrices.
        let expected_index_domains = vec![Moo::new(GroundDomain::Bool); 6];
        let actual_index_domains: Vec<Moo<GroundDomain>> =
            my_matrix.cata(&move |elem, children| {
                // Gather the index domains collected from the children, then this node's own.
                let mut res = vec![];
                res.extend(children.into_iter().flatten());
                if let AbstractLiteral::Matrix(_, index_domain) = elem {
                    res.push(index_domain);
                }

                res
            });

        assert_eq!(actual_index_domains, expected_index_domains);
    }
}