// conjure_cp_core/ast/domains/unresolved.rs

1use crate::ast::domains::attrs::MSetAttr;
2use crate::ast::domains::attrs::SetAttr;
3use crate::ast::{
4    DeclarationKind, DomainOpError, Expression, FuncAttr, Literal, Metadata, Moo,
5    RecordEntryGround, Reference, Typeable,
6    domains::{
7        GroundDomain,
8        domain::{DomainPtr, Int},
9        range::Range,
10    },
11};
12use crate::{bug, domain_int, matrix_expr, range};
13use conjure_cp_core::ast::pretty::pretty_vec;
14use conjure_cp_core::ast::{Name, ReturnType, eval_constant};
15use itertools::Itertools;
16use polyquine::Quine;
17use serde::{Deserialize, Serialize};
18use std::fmt::{Display, Formatter};
19use std::iter::zip;
20use std::ops::Deref;
21use uniplate::Uniplate;
22
23#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Quine, Uniplate)]
24#[path_prefix(conjure_cp::ast)]
25#[biplate(to=Expression)]
26#[biplate(to=Reference)]
27pub enum IntVal {
28    Const(Int),
29    #[polyquine_skip]
30    Reference(Reference),
31    Expr(Moo<Expression>),
32}
33
34impl From<Int> for IntVal {
35    fn from(value: Int) -> Self {
36        Self::Const(value)
37    }
38}
39
40impl TryInto<Int> for IntVal {
41    type Error = DomainOpError;
42
43    fn try_into(self) -> Result<Int, Self::Error> {
44        match self {
45            IntVal::Const(val) => Ok(val),
46            _ => Err(DomainOpError::NotGround),
47        }
48    }
49}
50
51impl From<Range<Int>> for Range<IntVal> {
52    fn from(value: Range<Int>) -> Self {
53        match value {
54            Range::Single(x) => Range::Single(x.into()),
55            Range::Bounded(l, r) => Range::Bounded(l.into(), r.into()),
56            Range::UnboundedL(r) => Range::UnboundedL(r.into()),
57            Range::UnboundedR(l) => Range::UnboundedR(l.into()),
58            Range::Unbounded => Range::Unbounded,
59        }
60    }
61}
62
63impl TryInto<Range<Int>> for Range<IntVal> {
64    type Error = DomainOpError;
65
66    fn try_into(self) -> Result<Range<Int>, Self::Error> {
67        match self {
68            Range::Single(x) => Ok(Range::Single(x.try_into()?)),
69            Range::Bounded(l, r) => Ok(Range::Bounded(l.try_into()?, r.try_into()?)),
70            Range::UnboundedL(r) => Ok(Range::UnboundedL(r.try_into()?)),
71            Range::UnboundedR(l) => Ok(Range::UnboundedR(l.try_into()?)),
72            Range::Unbounded => Ok(Range::Unbounded),
73        }
74    }
75}
76
77impl From<SetAttr<Int>> for SetAttr<IntVal> {
78    fn from(value: SetAttr<Int>) -> Self {
79        SetAttr {
80            size: value.size.into(),
81        }
82    }
83}
84
85impl TryInto<SetAttr<Int>> for SetAttr<IntVal> {
86    type Error = DomainOpError;
87
88    fn try_into(self) -> Result<SetAttr<Int>, Self::Error> {
89        let size: Range<Int> = self.size.try_into()?;
90        Ok(SetAttr { size })
91    }
92}
93
94impl From<MSetAttr<Int>> for MSetAttr<IntVal> {
95    fn from(value: MSetAttr<Int>) -> Self {
96        MSetAttr {
97            size: value.size.into(),
98            occurrence: value.occurrence.into(),
99        }
100    }
101}
102
103impl TryInto<MSetAttr<Int>> for MSetAttr<IntVal> {
104    type Error = DomainOpError;
105
106    fn try_into(self) -> Result<MSetAttr<Int>, Self::Error> {
107        let size: Range<Int> = self.size.try_into()?;
108        let occurrence: Range<Int> = self.occurrence.try_into()?;
109        Ok(MSetAttr { size, occurrence })
110    }
111}
112
113impl From<FuncAttr<Int>> for FuncAttr<IntVal> {
114    fn from(value: FuncAttr<Int>) -> Self {
115        FuncAttr {
116            size: value.size.into(),
117            partiality: value.partiality,
118            jectivity: value.jectivity,
119        }
120    }
121}
122
123impl TryInto<FuncAttr<Int>> for FuncAttr<IntVal> {
124    type Error = DomainOpError;
125
126    fn try_into(self) -> Result<FuncAttr<Int>, Self::Error> {
127        let size: Range<Int> = self.size.try_into()?;
128        Ok(FuncAttr {
129            size,
130            jectivity: self.jectivity,
131            partiality: self.partiality,
132        })
133    }
134}
135
136impl Display for IntVal {
137    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
138        match self {
139            IntVal::Const(val) => write!(f, "{val}"),
140            IntVal::Reference(re) => write!(f, "{re}"),
141            IntVal::Expr(expr) => write!(f, "({expr})"),
142        }
143    }
144}
145
146impl IntVal {
147    pub fn new_ref(re: &Reference) -> Option<IntVal> {
148        match re.ptr.kind().deref() {
149            DeclarationKind::ValueLetting(expr) => match expr.return_type() {
150                ReturnType::Int => Some(IntVal::Reference(re.clone())),
151                _ => None,
152            },
153            DeclarationKind::Given(dom) => match dom.return_type() {
154                ReturnType::Int => Some(IntVal::Reference(re.clone())),
155                _ => None,
156            },
157            DeclarationKind::Quantified(inner) => match inner.domain().return_type() {
158                ReturnType::Int => Some(IntVal::Reference(re.clone())),
159                _ => None,
160            },
161            DeclarationKind::DomainLetting(_)
162            | DeclarationKind::RecordField(_)
163            | DeclarationKind::Find(_) => None,
164        }
165    }
166
167    pub fn new_expr(value: Moo<Expression>) -> Option<IntVal> {
168        if value.return_type() != ReturnType::Int {
169            return None;
170        }
171        Some(IntVal::Expr(value))
172    }
173
174    pub fn resolve(&self) -> Option<Int> {
175        match self {
176            IntVal::Const(value) => Some(*value),
177            IntVal::Expr(expr) => eval_expr_to_int(expr),
178            IntVal::Reference(re) => match re.ptr.kind().deref() {
179                DeclarationKind::ValueLetting(expr) => eval_expr_to_int(expr),
180                // If this is an int given we will be able to resolve it eventually, but not yet
181                DeclarationKind::Given(_) | DeclarationKind::Quantified(..) => None,
182                DeclarationKind::DomainLetting(_)
183                | DeclarationKind::RecordField(_)
184                | DeclarationKind::Find(_) => bug!(
185                    "Expected integer expression, given, or letting inside int domain; Got: {re}"
186                ),
187            },
188        }
189    }
190}
191
192fn eval_expr_to_int(expr: &Expression) -> Option<Int> {
193    match eval_constant(expr)? {
194        Literal::Int(v) => Some(v),
195        _ => bug!("Expected integer expression, got: {expr}"),
196    }
197}
198
199impl From<IntVal> for Expression {
200    fn from(value: IntVal) -> Self {
201        match value {
202            IntVal::Const(val) => val.into(),
203            IntVal::Reference(re) => re.into(),
204            IntVal::Expr(expr) => expr.as_ref().clone(),
205        }
206    }
207}
208
209impl From<IntVal> for Moo<Expression> {
210    fn from(value: IntVal) -> Self {
211        match value {
212            IntVal::Const(val) => Moo::new(val.into()),
213            IntVal::Reference(re) => Moo::new(re.into()),
214            IntVal::Expr(expr) => expr,
215        }
216    }
217}
218
219impl std::ops::Neg for IntVal {
220    type Output = IntVal;
221
222    fn neg(self) -> Self::Output {
223        match self {
224            IntVal::Const(val) => IntVal::Const(-val),
225            IntVal::Reference(_) | IntVal::Expr(_) => {
226                IntVal::Expr(Moo::new(Expression::Neg(Metadata::new(), self.into())))
227            }
228        }
229    }
230}
231
232impl<T> std::ops::Add<T> for IntVal
233where
234    T: Into<Expression>,
235{
236    type Output = IntVal;
237
238    fn add(self, rhs: T) -> Self::Output {
239        let lhs: Expression = self.into();
240        let rhs: Expression = rhs.into();
241        let sum = matrix_expr!(lhs, rhs; domain_int!(1..));
242        IntVal::Expr(Moo::new(Expression::Sum(Metadata::new(), Moo::new(sum))))
243    }
244}
245
246impl Range<IntVal> {
247    pub fn resolve(&self) -> Option<Range<Int>> {
248        match self {
249            Range::Single(x) => Some(Range::Single(x.resolve()?)),
250            Range::Bounded(l, r) => Some(Range::Bounded(l.resolve()?, r.resolve()?)),
251            Range::UnboundedL(r) => Some(Range::UnboundedL(r.resolve()?)),
252            Range::UnboundedR(l) => Some(Range::UnboundedR(l.resolve()?)),
253            Range::Unbounded => Some(Range::Unbounded),
254        }
255    }
256}
257
258impl SetAttr<IntVal> {
259    pub fn resolve(&self) -> Option<SetAttr<Int>> {
260        Some(SetAttr {
261            size: self.size.resolve()?,
262        })
263    }
264}
265
266impl MSetAttr<IntVal> {
267    pub fn resolve(&self) -> Option<MSetAttr<Int>> {
268        Some(MSetAttr {
269            size: self.size.resolve()?,
270            occurrence: self.occurrence.resolve()?,
271        })
272    }
273}
274
275impl FuncAttr<IntVal> {
276    pub fn resolve(&self) -> Option<FuncAttr<Int>> {
277        Some(FuncAttr {
278            size: self.size.resolve()?,
279            partiality: self.partiality.clone(),
280            jectivity: self.jectivity.clone(),
281        })
282    }
283}
284
285#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Uniplate, Quine)]
286#[path_prefix(conjure_cp::ast)]
287pub struct RecordEntry {
288    pub name: Name,
289    pub domain: DomainPtr,
290}
291
292impl RecordEntry {
293    pub fn resolve(self) -> Option<RecordEntryGround> {
294        Some(RecordEntryGround {
295            name: self.name,
296            domain: self.domain.resolve()?,
297        })
298    }
299}
300
301#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Quine, Uniplate)]
302#[path_prefix(conjure_cp::ast)]
303#[biplate(to=Expression)]
304#[biplate(to=Reference)]
305#[biplate(to=IntVal)]
306#[biplate(to=DomainPtr)]
307#[biplate(to=RecordEntry)]
308pub enum UnresolvedDomain {
309    Int(Vec<Range<IntVal>>),
310    /// A set of elements drawn from the inner domain
311    Set(SetAttr<IntVal>, DomainPtr),
312    MSet(MSetAttr<IntVal>, DomainPtr),
313    /// A n-dimensional matrix with a value domain and n-index domains
314    Matrix(DomainPtr, Vec<DomainPtr>),
315    /// A tuple of N elements, each with its own domain
316    Tuple(Vec<DomainPtr>),
317    /// A reference to a domain letting
318    #[polyquine_skip]
319    Reference(Reference),
320    /// A record
321    Record(Vec<RecordEntry>),
322    /// A function with attributes, domain, and range
323    Function(FuncAttr<IntVal>, DomainPtr, DomainPtr),
324}
325
326impl UnresolvedDomain {
327    pub fn resolve(&self) -> Option<GroundDomain> {
328        match self {
329            UnresolvedDomain::Int(rngs) => rngs
330                .iter()
331                .map(Range::<IntVal>::resolve)
332                .collect::<Option<_>>()
333                .map(GroundDomain::Int),
334            UnresolvedDomain::Set(attr, inner) => {
335                Some(GroundDomain::Set(attr.resolve()?, inner.resolve()?))
336            }
337            UnresolvedDomain::MSet(attr, inner) => {
338                Some(GroundDomain::MSet(attr.resolve()?, inner.resolve()?))
339            }
340            UnresolvedDomain::Matrix(inner, idx_doms) => {
341                let inner_gd = inner.resolve()?;
342                idx_doms
343                    .iter()
344                    .map(DomainPtr::resolve)
345                    .collect::<Option<_>>()
346                    .map(|idx| GroundDomain::Matrix(inner_gd, idx))
347            }
348            UnresolvedDomain::Tuple(inners) => inners
349                .iter()
350                .map(DomainPtr::resolve)
351                .collect::<Option<_>>()
352                .map(GroundDomain::Tuple),
353            UnresolvedDomain::Record(entries) => entries
354                .iter()
355                .map(|f| {
356                    f.domain.resolve().map(|gd| RecordEntryGround {
357                        name: f.name.clone(),
358                        domain: gd,
359                    })
360                })
361                .collect::<Option<_>>()
362                .map(GroundDomain::Record),
363            UnresolvedDomain::Reference(re) => re
364                .ptr
365                .as_domain_letting()
366                .unwrap_or_else(|| {
367                    bug!("Reference domain should point to domain letting, but got {re}")
368                })
369                .resolve()
370                .map(Moo::unwrap_or_clone),
371            UnresolvedDomain::Function(attr, dom, cdom) => {
372                if let Some(attr_gd) = attr.resolve()
373                    && let Some(dom_gd) = dom.resolve()
374                    && let Some(cdom_gd) = cdom.resolve()
375                {
376                    return Some(GroundDomain::Function(attr_gd, dom_gd, cdom_gd));
377                }
378                None
379            }
380        }
381    }
382
383    pub(super) fn union_unresolved(
384        &self,
385        other: &UnresolvedDomain,
386    ) -> Result<UnresolvedDomain, DomainOpError> {
387        match (self, other) {
388            (UnresolvedDomain::Int(lhs), UnresolvedDomain::Int(rhs)) => {
389                let merged = lhs.iter().chain(rhs.iter()).cloned().collect_vec();
390                Ok(UnresolvedDomain::Int(merged))
391            }
392            (UnresolvedDomain::Int(_), _) | (_, UnresolvedDomain::Int(_)) => {
393                Err(DomainOpError::WrongType)
394            }
395            (UnresolvedDomain::Set(_, in1), UnresolvedDomain::Set(_, in2)) => {
396                Ok(UnresolvedDomain::Set(SetAttr::default(), in1.union(in2)?))
397            }
398            (UnresolvedDomain::Set(_, _), _) | (_, UnresolvedDomain::Set(_, _)) => {
399                Err(DomainOpError::WrongType)
400            }
401            (UnresolvedDomain::MSet(_, in1), UnresolvedDomain::MSet(_, in2)) => {
402                Ok(UnresolvedDomain::MSet(MSetAttr::default(), in1.union(in2)?))
403            }
404            (UnresolvedDomain::MSet(_, _), _) | (_, UnresolvedDomain::MSet(_, _)) => {
405                Err(DomainOpError::WrongType)
406            }
407            (UnresolvedDomain::Matrix(in1, idx1), UnresolvedDomain::Matrix(in2, idx2))
408                if idx1 == idx2 =>
409            {
410                Ok(UnresolvedDomain::Matrix(in1.union(in2)?, idx1.clone()))
411            }
412            (UnresolvedDomain::Matrix(_, _), _) | (_, UnresolvedDomain::Matrix(_, _)) => {
413                Err(DomainOpError::WrongType)
414            }
415            (UnresolvedDomain::Tuple(lhs), UnresolvedDomain::Tuple(rhs))
416                if lhs.len() == rhs.len() =>
417            {
418                let mut merged = Vec::new();
419                for (l, r) in zip(lhs, rhs) {
420                    merged.push(l.union(r)?)
421                }
422                Ok(UnresolvedDomain::Tuple(merged))
423            }
424            (UnresolvedDomain::Tuple(_), _) | (_, UnresolvedDomain::Tuple(_)) => {
425                Err(DomainOpError::WrongType)
426            }
427            // TODO: Could we support unions of reference domains symbolically?
428            (UnresolvedDomain::Reference(_), _) | (_, UnresolvedDomain::Reference(_)) => {
429                Err(DomainOpError::NotGround)
430            }
431            // TODO: Could we define semantics for merging record domains?
432            #[allow(unreachable_patterns)] // Technically redundant but logically makes sense
433            (UnresolvedDomain::Record(_), _) | (_, UnresolvedDomain::Record(_)) => {
434                Err(DomainOpError::WrongType)
435            }
436            #[allow(unreachable_patterns)]
437            // Technically redundant but logically clearer to have both
438            (UnresolvedDomain::Function(_, _, _), _) | (_, UnresolvedDomain::Function(_, _, _)) => {
439                Err(DomainOpError::WrongType)
440            }
441        }
442    }
443
444    pub fn element_domain(&self) -> Option<DomainPtr> {
445        match self {
446            UnresolvedDomain::Set(_, inner_dom) => Some(inner_dom.clone()),
447            UnresolvedDomain::Matrix(_, _) => {
448                todo!("Unwrap one dimension of the domain")
449            }
450            _ => None,
451        }
452    }
453}
454
455impl Typeable for UnresolvedDomain {
456    fn return_type(&self) -> ReturnType {
457        match self {
458            UnresolvedDomain::Reference(re) => re.return_type(),
459            UnresolvedDomain::Int(_) => ReturnType::Int,
460            UnresolvedDomain::Set(_attr, inner) => ReturnType::Set(Box::new(inner.return_type())),
461            UnresolvedDomain::MSet(_attr, inner) => ReturnType::MSet(Box::new(inner.return_type())),
462            UnresolvedDomain::Matrix(inner, _idx) => {
463                ReturnType::Matrix(Box::new(inner.return_type()))
464            }
465            UnresolvedDomain::Tuple(inners) => {
466                let mut inner_types = Vec::new();
467                for inner in inners {
468                    inner_types.push(inner.return_type());
469                }
470                ReturnType::Tuple(inner_types)
471            }
472            UnresolvedDomain::Record(entries) => {
473                let mut entry_types = Vec::new();
474                for entry in entries {
475                    entry_types.push(entry.domain.return_type());
476                }
477                ReturnType::Record(entry_types)
478            }
479            UnresolvedDomain::Function(_, dom, cdom) => {
480                ReturnType::Function(Box::new(dom.return_type()), Box::new(cdom.return_type()))
481            }
482        }
483    }
484}
485
486impl Display for UnresolvedDomain {
487    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
488        match &self {
489            UnresolvedDomain::Reference(re) => write!(f, "{re}"),
490            UnresolvedDomain::Int(ranges) => {
491                if ranges.iter().all(Range::is_lower_or_upper_bounded) {
492                    let rngs: String = ranges.iter().map(|r| format!("{r}")).join(", ");
493                    write!(f, "int({})", rngs)
494                } else {
495                    write!(f, "int")
496                }
497            }
498            UnresolvedDomain::Set(attrs, inner_dom) => write!(f, "set {attrs} of {inner_dom}"),
499            UnresolvedDomain::MSet(attrs, inner_dom) => write!(f, "mset {attrs} of {inner_dom}"),
500            UnresolvedDomain::Matrix(value_domain, index_domains) => {
501                write!(
502                    f,
503                    "matrix indexed by [{}] of {value_domain}",
504                    pretty_vec(&index_domains.iter().collect_vec())
505                )
506            }
507            UnresolvedDomain::Tuple(domains) => {
508                write!(
509                    f,
510                    "tuple of ({})",
511                    pretty_vec(&domains.iter().collect_vec())
512                )
513            }
514            UnresolvedDomain::Record(entries) => {
515                write!(
516                    f,
517                    "record of ({})",
518                    pretty_vec(
519                        &entries
520                            .iter()
521                            .map(|entry| format!("{}: {}", entry.name, entry.domain))
522                            .collect_vec()
523                    )
524                )
525            }
526            UnresolvedDomain::Function(attribute, domain, codomain) => {
527                write!(f, "function {} {} --> {} ", attribute, domain, codomain)
528            }
529        }
530    }
531}