1use itertools::Itertools;
2use serde::{Deserialize, Serialize};
3use std::fmt::{Display, Formatter};
4use std::hash::Hash;
5use std::hash::Hasher;
6
7use uniplate::derive::Uniplate;
8use uniplate::{Biplate, Tree, Uniplate};
9
10use super::{records::RecordValue, Atom, Domain, Expression, Range};
11use super::{ReturnType, Typeable};
12
/// A literal value: a 32-bit integer, a boolean, or an abstract literal whose elements are
/// themselves literals.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate, Hash)]
// NOTE(review): `walk_into` asks the uniplate derive to look through nested
// `AbstractLiteral<Literal>` values when searching for children — confirm against uniplate docs.
#[uniplate(walk_into=[AbstractLiteral<Literal>])]
#[biplate(to=Atom)]
#[biplate(to=AbstractLiteral<Literal>)]
#[biplate(to=AbstractLiteral<Expression>)]
#[biplate(to=RecordValue<Literal>,walk_into=[AbstractLiteral<Literal>])]
#[biplate(to=RecordValue<Expression>)]
#[biplate(to=Expression)]
pub enum Literal {
    /// A 32-bit signed integer.
    Int(i32),
    /// A boolean.
    Bool(bool),
    /// A compound literal (set, matrix, tuple or record) of literals.
    AbstractLiteral(AbstractLiteral<Literal>),
}
27
/// Bound for element types that can be stored inside an [`AbstractLiteral`].
///
/// Requires `Display` (used when formatting abstract literals), the uniplate traversal traits
/// (including traversal into record fields), and `'static` (needed by the type-identity check in
/// `AbstractLiteral`'s `Biplate` impl). In this file it is implemented for [`Expression`] and
/// [`Literal`].
pub trait AbstractLiteralValue:
    Clone + Eq + PartialEq + Display + Uniplate + Biplate<RecordValue<Self>> + 'static
{
}
impl AbstractLiteralValue for Expression {}
impl AbstractLiteralValue for Literal {}
35
/// A compound literal whose elements are of type `T`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AbstractLiteral<T: AbstractLiteralValue> {
    /// A set of elements; displayed as `{a,b,...}`.
    Set(Vec<T>),

    /// A matrix: its elements together with the domain that indexes them; displayed as
    /// `[a,b,...;index_domain]`.
    Matrix(Vec<T>, Domain),

    /// A tuple of elements; displayed as `(a,b,...)`.
    Tuple(Vec<T>),

    /// A record: an ordered list of named fields (see [`RecordValue`]).
    Record(Vec<RecordValue<T>>),
}
48
49impl Typeable for Literal {
50 fn return_type(&self) -> Option<ReturnType> {
51 match self {
52 Literal::Int(_) => Some(ReturnType::Int),
53 Literal::Bool(_) => Some(ReturnType::Bool),
54 Literal::AbstractLiteral(a) => a.return_type(),
55 }
56 }
57}
58
59impl<T: AbstractLiteralValue + Typeable> Typeable for AbstractLiteral<T> {
61 fn return_type(&self) -> Option<ReturnType> {
62 match self {
63 AbstractLiteral::Set(vector) => {
64 Some(ReturnType::Set(Box::new(vector.first()?.return_type()?)))
65 }
66 AbstractLiteral::Matrix(vector, _) => {
67 Some(ReturnType::Matrix(Box::new(vector.first()?.return_type()?)))
68 }
69 _ => None,
70 }
71 }
72}
73
74impl<T> AbstractLiteral<T>
75where
76 T: AbstractLiteralValue,
77{
78 pub fn matrix_implied_indices(elems: Vec<T>) -> Self {
82 AbstractLiteral::Matrix(elems, Domain::IntDomain(vec![Range::UnboundedR(1)]))
83 }
84
85 pub fn unwrap_list(&self) -> Option<&Vec<T>> {
90 let AbstractLiteral::Matrix(elems, Domain::IntDomain(ranges)) = self else {
91 return None;
92 };
93
94 let [Range::UnboundedR(1)] = ranges[..] else {
95 return None;
96 };
97
98 Some(elems)
99 }
100}
101
102impl<T> Display for AbstractLiteral<T>
103where
104 T: AbstractLiteralValue,
105{
106 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
107 match self {
108 AbstractLiteral::Set(elems) => {
109 let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
110 write!(f, "{{{elems_str}}}")
111 }
112 AbstractLiteral::Matrix(elems, index_domain) => {
113 let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
114 write!(f, "[{elems_str};{index_domain}]")
115 }
116 AbstractLiteral::Tuple(elems) => {
117 let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
118 write!(f, "({elems_str})")
119 }
120 AbstractLiteral::Record(entries) => {
121 let entries_str: String = entries
122 .iter()
123 .map(|entry| format!("{}: {}", entry.name, entry.value))
124 .join(",");
125 write!(f, "{{{entries_str}}}")
126 }
127 }
128 }
129}
130
131impl Hash for AbstractLiteral<Literal> {
132 fn hash<H: Hasher>(&self, state: &mut H) {
133 match self {
134 AbstractLiteral::Set(vec) => {
135 0.hash(state);
136 vec.hash(state);
137 }
138 AbstractLiteral::Matrix(elems, index_domain) => {
139 1.hash(state);
140 elems.hash(state);
141 index_domain.hash(state);
142 }
143 AbstractLiteral::Tuple(elems) => {
144 2.hash(state);
145 elems.hash(state);
146 }
147 AbstractLiteral::Record(entries) => {
148 3.hash(state);
149 entries.hash(state);
150 }
151 }
152 }
153}
154
/// Uniplate instance for abstract literals: the children of an `AbstractLiteral<T>` are the
/// `AbstractLiteral<T>` values found inside its elements.
///
/// Each arm delegates to the element container's `Biplate<AbstractLiteral<T>>` instance and
/// returns a context that rebuilds the same variant from a new children tree.
impl<T> Uniplate for AbstractLiteral<T>
where
    T: AbstractLiteralValue + Biplate<AbstractLiteral<T>>,
{
    fn uniplate(&self) -> (Tree<Self>, Box<dyn Fn(Tree<Self>) -> Self>) {
        match self {
            AbstractLiteral::Set(vec) => {
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(vec);
                (f1_tree, Box::new(move |x| AbstractLiteral::Set(f1_ctx(x))))
            }
            AbstractLiteral::Matrix(elems, index_domain) => {
                // The index domain is not traversed; the context owns a copy of it and clones
                // that copy on every rebuild (the context is `Fn`, so it may be called repeatedly).
                let index_domain = index_domain.clone();
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
                (
                    f1_tree,
                    Box::new(move |x| AbstractLiteral::Matrix(f1_ctx(x), index_domain.clone())),
                )
            }
            AbstractLiteral::Tuple(elems) => {
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
                (
                    f1_tree,
                    Box::new(move |x| AbstractLiteral::Tuple(f1_ctx(x))),
                )
            }
            AbstractLiteral::Record(entries) => {
                // NOTE(review): relies on `RecordValue<T>: Biplate<AbstractLiteral<T>>` being
                // provided elsewhere (it is not in this `where` clause) — confirm.
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(entries);
                (
                    f1_tree,
                    Box::new(move |x| AbstractLiteral::Record(f1_ctx(x))),
                )
            }
        }
    }
}
191
192impl<U, To> Biplate<To> for AbstractLiteral<U>
193where
194 To: Uniplate,
195 U: AbstractLiteralValue + Biplate<AbstractLiteral<U>> + Biplate<To>,
196 RecordValue<U>: Biplate<AbstractLiteral<U>> + Biplate<To>,
197{
198 fn biplate(&self) -> (Tree<To>, Box<dyn Fn(Tree<To>) -> Self>) {
199 if std::any::TypeId::of::<To>() == std::any::TypeId::of::<AbstractLiteral<U>>() {
200 unsafe {
203 let self_to = std::mem::transmute::<&AbstractLiteral<U>, &To>(self).clone();
205 let tree = Tree::One(self_to.clone());
206 let ctx = Box::new(move |x| {
207 let Tree::One(x) = x else {
208 panic!();
209 };
210
211 std::mem::transmute::<&To, &AbstractLiteral<U>>(&x).clone()
212 });
213
214 (tree, ctx)
215 }
216 } else {
217 match self {
219 AbstractLiteral::Set(vec) => {
220 let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(vec);
221 (f1_tree, Box::new(move |x| AbstractLiteral::Set(f1_ctx(x))))
222 }
223 AbstractLiteral::Matrix(elems, index_domain) => {
224 let index_domain = index_domain.clone();
225 let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(elems);
226 (
227 f1_tree,
228 Box::new(move |x| AbstractLiteral::Matrix(f1_ctx(x), index_domain.clone())),
229 )
230 }
231 AbstractLiteral::Tuple(elems) => {
232 let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(elems);
233 (
234 f1_tree,
235 Box::new(move |x| AbstractLiteral::Tuple(f1_ctx(x))),
236 )
237 }
238 AbstractLiteral::Record(entries) => {
239 let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(entries);
240 (
241 f1_tree,
242 Box::new(move |x| AbstractLiteral::Record(f1_ctx(x))),
243 )
244 }
245 }
246 }
247 }
248}
249
250impl TryFrom<Literal> for i32 {
251 type Error = &'static str;
252
253 fn try_from(value: Literal) -> Result<Self, Self::Error> {
254 match value {
255 Literal::Int(i) => Ok(i),
256 _ => Err("Cannot convert non-i32 literal to i32"),
257 }
258 }
259}
260
261impl TryFrom<&Literal> for i32 {
262 type Error = &'static str;
263
264 fn try_from(value: &Literal) -> Result<Self, Self::Error> {
265 match value {
266 Literal::Int(i) => Ok(*i),
267 _ => Err("Cannot convert non-i32 literal to i32"),
268 }
269 }
270}
271
272impl TryFrom<Literal> for bool {
273 type Error = &'static str;
274
275 fn try_from(value: Literal) -> Result<Self, Self::Error> {
276 match value {
277 Literal::Bool(b) => Ok(b),
278 _ => Err("Cannot convert non-bool literal to bool"),
279 }
280 }
281}
282
283impl TryFrom<&Literal> for bool {
284 type Error = &'static str;
285
286 fn try_from(value: &Literal) -> Result<Self, Self::Error> {
287 match value {
288 Literal::Bool(b) => Ok(*b),
289 _ => Err("Cannot convert non-bool literal to bool"),
290 }
291 }
292}
293
294impl From<i32> for Literal {
295 fn from(i: i32) -> Self {
296 Literal::Int(i)
297 }
298}
299
300impl From<bool> for Literal {
301 fn from(b: bool) -> Self {
302 Literal::Bool(b)
303 }
304}
305
306impl AbstractLiteral<Expression> {
307 pub fn as_literals(self) -> Option<AbstractLiteral<Literal>> {
310 match self {
311 AbstractLiteral::Set(_) => todo!(),
312 AbstractLiteral::Matrix(items, domain) => {
313 let mut literals = vec![];
314 for item in items {
315 let literal = match item {
316 Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
317 Expression::AbstractLiteral(_, abslit) => {
318 Some(Literal::AbstractLiteral(abslit.as_literals()?))
319 }
320 _ => None,
321 }?;
322 literals.push(literal);
323 }
324
325 Some(AbstractLiteral::Matrix(literals, domain))
326 }
327 AbstractLiteral::Tuple(items) => {
328 let mut literals = vec![];
329 for item in items {
330 let literal = match item {
331 Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
332 Expression::AbstractLiteral(_, abslit) => {
333 Some(Literal::AbstractLiteral(abslit.as_literals()?))
334 }
335 _ => None,
336 }?;
337 literals.push(literal);
338 }
339
340 Some(AbstractLiteral::Tuple(literals))
341 }
342 AbstractLiteral::Record(entries) => {
343 let mut literals = vec![];
344 for entry in entries {
345 let literal = match entry.value {
346 Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
347 Expression::AbstractLiteral(_, abslit) => {
348 Some(Literal::AbstractLiteral(abslit.as_literals()?))
349 }
350 _ => None,
351 }?;
352
353 literals.push((entry.name, literal));
354 }
355 Some(AbstractLiteral::Record(
356 literals
357 .into_iter()
358 .map(|(name, literal)| RecordValue {
359 name,
360 value: literal,
361 })
362 .collect(),
363 ))
364 }
365 }
366 }
367}
368
369impl Display for Literal {
371 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
372 match &self {
373 Literal::Int(i) => write!(f, "{}", i),
374 Literal::Bool(b) => write!(f, "{}", b),
375 Literal::AbstractLiteral(l) => write!(f, "{:?}", l),
376 }
377 }
378}
379
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::{into_matrix, matrix};
    use uniplate::Uniplate;

    /// A bool-indexed matrix holding five bool-indexed single-element matrices contains six
    /// matrices in total; a catamorphism collecting every index domain must find all six.
    #[test]
    fn matrix_uniplate_universe() {
        let inner = Literal::AbstractLiteral(matrix![Literal::Bool(true);Domain::BoolDomain]);
        let my_matrix: AbstractLiteral<Literal> = into_matrix![
            vec![inner; 5];
            Domain::BoolDomain
        ];

        let collected_domains: Vec<Domain> = my_matrix.cata(Arc::new(move |elem, children| {
            // Children's domains first, then this node's own (post-order).
            let mut domains: Vec<Domain> = children.into_iter().flatten().collect();
            if let AbstractLiteral::Matrix(_, index_domain) = elem {
                domains.push(index_domain);
            }
            domains
        }));

        assert_eq!(collected_domains, vec![Domain::BoolDomain; 6]);
    }
}