1
use itertools::Itertools;
2
use serde::{Deserialize, Serialize};
3
use std::fmt::{Display, Formatter};
4
use std::hash::Hash;
5
use std::hash::Hasher;
6
use uniplate::derive::Uniplate;
7
use uniplate::{Biplate, Tree, Uniplate};
8

            
9
use super::{Atom, Domain, Expression, Range};
10

            
11
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate, Hash)]
12
#[uniplate(walk_into=[AbstractLiteral<Literal>])]
13
#[biplate(to=Atom)]
14
#[biplate(to=AbstractLiteral<Literal>)]
15
#[biplate(to=AbstractLiteral<Expression>)]
16
#[biplate(to=Expression)]
17
/// A literal value, equivalent to constants in Conjure.
18
pub enum Literal {
19
    Int(i32),
20
    Bool(bool),
21
    AbstractLiteral(AbstractLiteral<Literal>),
22
}
23

            
24
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
25
pub enum AbstractLiteral<T: Uniplate + Biplate<AbstractLiteral<T>> + Biplate<T>> {
26
    Set(Vec<T>),
27

            
28
    /// A 1 dimensional matrix slice with an index domain.
29
    Matrix(Vec<T>, Domain),
30
}
31

            
32
impl<T> AbstractLiteral<T>
33
where
34
    T: Uniplate + Biplate<AbstractLiteral<T>> + Biplate<T>,
35
{
36
    /// Creates a matrix with elements `elems`, with domain `int(1..)`.
37
    ///
38
    /// This acts as a variable sized list.
39
7164
    pub fn matrix_implied_indices(elems: Vec<T>) -> Self {
40
7164
        AbstractLiteral::Matrix(elems, Domain::IntDomain(vec![Range::UnboundedR(1)]))
41
7164
    }
42

            
43
    /// If the AbstractLiteral is a list, returns its elements.
44
    ///
45
    /// A list is any a matrix with the domain `int(1..)`. This includes matrix literals without
46
    /// any explicitly specified domain.
47
134406
    pub fn unwrap_list(&self) -> Option<&Vec<T>> {
48
134406
        let AbstractLiteral::Matrix(elems, Domain::IntDomain(ranges)) = self else {
49
            return None;
50
        };
51

            
52
134406
        let [Range::UnboundedR(1)] = ranges[..] else {
53
84708
            return None;
54
        };
55

            
56
49698
        Some(elems)
57
134406
    }
58
}
59

            
60
impl<T> Display for AbstractLiteral<T>
61
where
62
    T: Uniplate + Biplate<AbstractLiteral<T>> + Biplate<T> + Display,
63
{
64
20790
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
65
20790
        match self {
66
            AbstractLiteral::Set(elems) => {
67
                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
68
                write!(f, "{{{elems_str}}}")
69
            }
70
20790
            AbstractLiteral::Matrix(elems, index_domain) => {
71
34920
                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
72
20790
                write!(f, "[{elems_str};{index_domain}]")
73
            }
74
        }
75
20790
    }
76
}
77

            
78
impl Hash for AbstractLiteral<Literal> {
79
    fn hash<H: Hasher>(&self, state: &mut H) {
80
        match self {
81
            AbstractLiteral::Set(vec) => {
82
                0.hash(state);
83
                vec.hash(state);
84
            }
85
            AbstractLiteral::Matrix(elems, index_domain) => {
86
                1.hash(state);
87
                elems.hash(state);
88
                index_domain.hash(state);
89
            }
90
        }
91
    }
92
}
93

            
94
impl<T> Uniplate for AbstractLiteral<T>
95
where
96
    T: Uniplate + Biplate<AbstractLiteral<T>> + Biplate<T>,
97
{
98
    fn uniplate(&self) -> (Tree<Self>, Box<dyn Fn(Tree<Self>) -> Self>) {
99
        // walking into T
100
        match self {
101
            AbstractLiteral::Set(vec) => {
102
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(vec);
103
                (f1_tree, Box::new(move |x| AbstractLiteral::Set(f1_ctx(x))))
104
            }
105
            AbstractLiteral::Matrix(elems, index_domain) => {
106
                let index_domain = index_domain.clone();
107
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
108
                (
109
                    f1_tree,
110
                    Box::new(move |x| AbstractLiteral::Matrix(f1_ctx(x), index_domain.clone())),
111
                )
112
            }
113
        }
114
    }
115
}
116

            
117
impl<U, To> Biplate<To> for AbstractLiteral<U>
118
where
119
    To: Uniplate,
120
    U: Biplate<To> + Biplate<U> + Biplate<AbstractLiteral<U>>,
121
{
122
701838
    fn biplate(&self) -> (Tree<To>, Box<dyn Fn(Tree<To>) -> Self>) {
123
701838
        // walking into T
124
701838
        match self {
125
            AbstractLiteral::Set(vec) => {
126
                let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(vec);
127
                (f1_tree, Box::new(move |x| AbstractLiteral::Set(f1_ctx(x))))
128
            }
129
701838
            AbstractLiteral::Matrix(elems, index_domain) => {
130
701838
                let index_domain = index_domain.clone();
131
701838
                let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(elems);
132
701838
                (
133
701838
                    f1_tree,
134
701838
                    Box::new(move |x| AbstractLiteral::Matrix(f1_ctx(x), index_domain.clone())),
135
701838
                )
136
            }
137
        }
138
701838
    }
139
}
140

            
141
impl TryFrom<Literal> for i32 {
142
    type Error = &'static str;
143

            
144
99218
    fn try_from(value: Literal) -> Result<Self, Self::Error> {
145
99218
        match value {
146
99110
            Literal::Int(i) => Ok(i),
147
108
            _ => Err("Cannot convert non-i32 literal to i32"),
148
        }
149
99218
    }
150
}
151

            
152
impl TryFrom<&Literal> for i32 {
153
    type Error = &'static str;
154

            
155
5328
    fn try_from(value: &Literal) -> Result<Self, Self::Error> {
156
5328
        match value {
157
5328
            Literal::Int(i) => Ok(*i),
158
            _ => Err("Cannot convert non-i32 literal to i32"),
159
        }
160
5328
    }
161
}
162

            
163
impl TryFrom<Literal> for bool {
164
    type Error = &'static str;
165

            
166
2484
    fn try_from(value: Literal) -> Result<Self, Self::Error> {
167
2484
        match value {
168
2160
            Literal::Bool(b) => Ok(b),
169
324
            _ => Err("Cannot convert non-bool literal to bool"),
170
        }
171
2484
    }
172
}
173

            
174
impl TryFrom<&Literal> for bool {
175
    type Error = &'static str;
176

            
177
    fn try_from(value: &Literal) -> Result<Self, Self::Error> {
178
        match value {
179
            Literal::Bool(b) => Ok(*b),
180
            _ => Err("Cannot convert non-bool literal to bool"),
181
        }
182
    }
183
}
184

            
185
impl From<i32> for Literal {
186
216
    fn from(i: i32) -> Self {
187
216
        Literal::Int(i)
188
216
    }
189
}
190

            
191
impl From<bool> for Literal {
192
630
    fn from(b: bool) -> Self {
193
630
        Literal::Bool(b)
194
630
    }
195
}
196

            
197
// need display implementations for other types as well
198
impl Display for Literal {
199
69984
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
200
69984
        match &self {
201
60732
            Literal::Int(i) => write!(f, "{}", i),
202
9252
            Literal::Bool(b) => write!(f, "{}", b),
203
            Literal::AbstractLiteral(l) => write!(f, "{:?}", l),
204
        }
205
69984
    }
206
}