Skip to main content

conjure_cp_core/ast/domains/
unresolved.rs

1use crate::ast::domains::attrs::MSetAttr;
2use crate::ast::domains::attrs::SetAttr;
3use crate::ast::{
4    DeclarationKind, DomainOpError, Expression, FuncAttr, Literal, Metadata, Moo,
5    RecordEntryGround, Reference, Typeable,
6    domains::{
7        GroundDomain,
8        domain::{DomainPtr, Int},
9        range::Range,
10    },
11};
12use crate::{bug, domain_int, matrix_expr, range};
13use conjure_cp_core::ast::pretty::pretty_vec;
14use conjure_cp_core::ast::{Name, ReturnType, eval_constant};
15use itertools::Itertools;
16use polyquine::Quine;
17use serde::{Deserialize, Serialize};
18use std::fmt::{Display, Formatter};
19use std::iter::zip;
20use std::ops::Deref;
21use uniplate::Uniplate;
22
/// An integer appearing inside a domain (e.g. a range bound or a size attribute) that may
/// be a constant, a reference to a declaration, or an expression.
///
/// Use [`IntVal::resolve`] to try to obtain the concrete value.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Quine, Uniplate)]
#[path_prefix(conjure_cp::ast)]
#[biplate(to=Expression)]
#[biplate(to=Reference)]
pub enum IntVal {
    /// A concrete integer constant.
    Const(Int),
    /// A reference to a declaration; `IntVal::new_ref` only builds this for declarations
    /// whose return type is `Int`.
    #[polyquine_skip]
    Reference(Reference),
    /// An expression; `IntVal::new_expr` only builds this for expressions whose return
    /// type is `Int`.
    Expr(Moo<Expression>),
}
33
34impl From<Int> for IntVal {
35    fn from(value: Int) -> Self {
36        Self::Const(value)
37    }
38}
39
40impl TryInto<Int> for IntVal {
41    type Error = DomainOpError;
42
43    fn try_into(self) -> Result<Int, Self::Error> {
44        match self {
45            IntVal::Const(val) => Ok(val),
46            _ => Err(DomainOpError::NotGround),
47        }
48    }
49}
50
51impl From<Range<Int>> for Range<IntVal> {
52    fn from(value: Range<Int>) -> Self {
53        match value {
54            Range::Single(x) => Range::Single(x.into()),
55            Range::Bounded(l, r) => Range::Bounded(l.into(), r.into()),
56            Range::UnboundedL(r) => Range::UnboundedL(r.into()),
57            Range::UnboundedR(l) => Range::UnboundedR(l.into()),
58            Range::Unbounded => Range::Unbounded,
59        }
60    }
61}
62
63impl TryInto<Range<Int>> for Range<IntVal> {
64    type Error = DomainOpError;
65
66    fn try_into(self) -> Result<Range<Int>, Self::Error> {
67        match self {
68            Range::Single(x) => Ok(Range::Single(x.try_into()?)),
69            Range::Bounded(l, r) => Ok(Range::Bounded(l.try_into()?, r.try_into()?)),
70            Range::UnboundedL(r) => Ok(Range::UnboundedL(r.try_into()?)),
71            Range::UnboundedR(l) => Ok(Range::UnboundedR(l.try_into()?)),
72            Range::Unbounded => Ok(Range::Unbounded),
73        }
74    }
75}
76
77impl From<SetAttr<Int>> for SetAttr<IntVal> {
78    fn from(value: SetAttr<Int>) -> Self {
79        SetAttr {
80            size: value.size.into(),
81        }
82    }
83}
84
85impl TryInto<SetAttr<Int>> for SetAttr<IntVal> {
86    type Error = DomainOpError;
87
88    fn try_into(self) -> Result<SetAttr<Int>, Self::Error> {
89        let size: Range<Int> = self.size.try_into()?;
90        Ok(SetAttr { size })
91    }
92}
93
94impl From<MSetAttr<Int>> for MSetAttr<IntVal> {
95    fn from(value: MSetAttr<Int>) -> Self {
96        MSetAttr {
97            size: value.size.into(),
98            occurrence: value.occurrence.into(),
99        }
100    }
101}
102
103impl TryInto<MSetAttr<Int>> for MSetAttr<IntVal> {
104    type Error = DomainOpError;
105
106    fn try_into(self) -> Result<MSetAttr<Int>, Self::Error> {
107        let size: Range<Int> = self.size.try_into()?;
108        let occurrence: Range<Int> = self.occurrence.try_into()?;
109        Ok(MSetAttr { size, occurrence })
110    }
111}
112
113impl From<FuncAttr<Int>> for FuncAttr<IntVal> {
114    fn from(value: FuncAttr<Int>) -> Self {
115        FuncAttr {
116            size: value.size.into(),
117            partiality: value.partiality,
118            jectivity: value.jectivity,
119        }
120    }
121}
122
123impl TryInto<FuncAttr<Int>> for FuncAttr<IntVal> {
124    type Error = DomainOpError;
125
126    fn try_into(self) -> Result<FuncAttr<Int>, Self::Error> {
127        let size: Range<Int> = self.size.try_into()?;
128        Ok(FuncAttr {
129            size,
130            jectivity: self.jectivity,
131            partiality: self.partiality,
132        })
133    }
134}
135
136impl Display for IntVal {
137    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
138        match self {
139            IntVal::Const(val) => write!(f, "{val}"),
140            IntVal::Reference(re) => write!(f, "{re}"),
141            IntVal::Expr(expr) => write!(f, "({expr})"),
142        }
143    }
144}
145
146impl IntVal {
147    pub fn new_ref(re: &Reference) -> Option<IntVal> {
148        match re.ptr.kind().deref() {
149            DeclarationKind::ValueLetting(expr, _)
150            | DeclarationKind::TemporaryValueLetting(expr)
151            | DeclarationKind::QuantifiedExpr(expr) => match expr.return_type() {
152                ReturnType::Int => Some(IntVal::Reference(re.clone())),
153                _ => None,
154            },
155            DeclarationKind::Given(dom) => match dom.return_type() {
156                ReturnType::Int => Some(IntVal::Reference(re.clone())),
157                _ => None,
158            },
159            DeclarationKind::Quantified(inner) => match inner.domain().return_type() {
160                ReturnType::Int => Some(IntVal::Reference(re.clone())),
161                _ => None,
162            },
163            DeclarationKind::Find(var) => match var.return_type() {
164                ReturnType::Int => Some(IntVal::Reference(re.clone())),
165                _ => None,
166            },
167            DeclarationKind::DomainLetting(_) | DeclarationKind::RecordField(_) => None,
168        }
169    }
170
171    pub fn new_expr(value: Moo<Expression>) -> Option<IntVal> {
172        if value.return_type() != ReturnType::Int {
173            return None;
174        }
175        Some(IntVal::Expr(value))
176    }
177
178    pub fn resolve(&self) -> Option<Int> {
179        match self {
180            IntVal::Const(value) => Some(*value),
181            IntVal::Expr(expr) => eval_expr_to_int(expr),
182            IntVal::Reference(re) => match re.ptr.kind().deref() {
183                DeclarationKind::ValueLetting(expr, _)
184                | DeclarationKind::TemporaryValueLetting(expr) => eval_expr_to_int(expr),
185                // If this is an int given we will be able to resolve it eventually, but not yet
186                DeclarationKind::Given(_) => None,
187                DeclarationKind::Quantified(inner) => {
188                    if let Some(generator) = inner.generator()
189                        && let Some(expr) = generator.as_value_letting()
190                    {
191                        eval_expr_to_int(&expr)
192                    } else {
193                        None
194                    }
195                }
196                // TODO: idk what this whole file does but I very much doubt it affects this
197                DeclarationKind::QuantifiedExpr(_) => None,
198                // Decision variables inside domains are unresolved until solving.
199                DeclarationKind::Find(_) => None,
200                DeclarationKind::DomainLetting(_) | DeclarationKind::RecordField(_) => bug!(
201                    "Expected integer expression, given, or letting inside int domain; Got: {re}"
202                ),
203            },
204        }
205    }
206}
207
208fn eval_expr_to_int(expr: &Expression) -> Option<Int> {
209    match eval_constant(expr)? {
210        Literal::Int(v) => Some(v),
211        _ => bug!("Expected integer expression, got: {expr}"),
212    }
213}
214
215impl From<IntVal> for Expression {
216    fn from(value: IntVal) -> Self {
217        match value {
218            IntVal::Const(val) => val.into(),
219            IntVal::Reference(re) => re.into(),
220            IntVal::Expr(expr) => expr.as_ref().clone(),
221        }
222    }
223}
224
225impl From<IntVal> for Moo<Expression> {
226    fn from(value: IntVal) -> Self {
227        match value {
228            IntVal::Const(val) => Moo::new(val.into()),
229            IntVal::Reference(re) => Moo::new(re.into()),
230            IntVal::Expr(expr) => expr,
231        }
232    }
233}
234
235impl std::ops::Neg for IntVal {
236    type Output = IntVal;
237
238    fn neg(self) -> Self::Output {
239        match self {
240            IntVal::Const(val) => IntVal::Const(-val),
241            IntVal::Reference(_) | IntVal::Expr(_) => {
242                IntVal::Expr(Moo::new(Expression::Neg(Metadata::new(), self.into())))
243            }
244        }
245    }
246}
247
248impl<T> std::ops::Add<T> for IntVal
249where
250    T: Into<Expression>,
251{
252    type Output = IntVal;
253
254    fn add(self, rhs: T) -> Self::Output {
255        let lhs: Expression = self.into();
256        let rhs: Expression = rhs.into();
257        let sum = matrix_expr!(lhs, rhs; domain_int!(1..));
258        IntVal::Expr(Moo::new(Expression::Sum(Metadata::new(), Moo::new(sum))))
259    }
260}
261
262impl Range<IntVal> {
263    pub fn resolve(&self) -> Option<Range<Int>> {
264        match self {
265            Range::Single(x) => Some(Range::Single(x.resolve()?)),
266            Range::Bounded(l, r) => Some(Range::Bounded(l.resolve()?, r.resolve()?)),
267            Range::UnboundedL(r) => Some(Range::UnboundedL(r.resolve()?)),
268            Range::UnboundedR(l) => Some(Range::UnboundedR(l.resolve()?)),
269            Range::Unbounded => Some(Range::Unbounded),
270        }
271    }
272}
273
274impl SetAttr<IntVal> {
275    pub fn resolve(&self) -> Option<SetAttr<Int>> {
276        Some(SetAttr {
277            size: self.size.resolve()?,
278        })
279    }
280}
281
282impl MSetAttr<IntVal> {
283    pub fn resolve(&self) -> Option<MSetAttr<Int>> {
284        Some(MSetAttr {
285            size: self.size.resolve()?,
286            occurrence: self.occurrence.resolve()?,
287        })
288    }
289}
290
291impl FuncAttr<IntVal> {
292    pub fn resolve(&self) -> Option<FuncAttr<Int>> {
293        Some(FuncAttr {
294            size: self.size.resolve()?,
295            partiality: self.partiality.clone(),
296            jectivity: self.jectivity.clone(),
297        })
298    }
299}
300
/// A named field of a record domain whose domain may not be resolved yet.
///
/// See `RecordEntryGround` for the resolved counterpart produced by [`RecordEntry::resolve`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Uniplate, Quine)]
#[path_prefix(conjure_cp::ast)]
pub struct RecordEntry {
    /// The field name.
    pub name: Name,
    /// The field's domain.
    pub domain: DomainPtr,
}
307
308impl RecordEntry {
309    pub fn resolve(self) -> Option<RecordEntryGround> {
310        Some(RecordEntryGround {
311            name: self.name,
312            domain: self.domain.resolve()?,
313        })
314    }
315}
316
/// A domain that may contain symbolic parts (integer values that are not constants,
/// references to domain lettings, or nested domains), and so may not yet be expressible as
/// a [`GroundDomain`]. Use [`UnresolvedDomain::resolve`] to attempt the conversion.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Quine, Uniplate)]
#[path_prefix(conjure_cp::ast)]
#[biplate(to=Expression)]
#[biplate(to=Reference)]
#[biplate(to=IntVal)]
#[biplate(to=DomainPtr)]
#[biplate(to=RecordEntry)]
pub enum UnresolvedDomain {
    /// An integer domain given as a list of ranges whose bounds are [`IntVal`]s
    Int(Vec<Range<IntVal>>),
    /// A set of elements drawn from the inner domain
    Set(SetAttr<IntVal>, DomainPtr),
    /// A multiset of elements drawn from the inner domain
    MSet(MSetAttr<IntVal>, DomainPtr),
    /// A n-dimensional matrix with a value domain and n-index domains
    Matrix(DomainPtr, Vec<DomainPtr>),
    /// A tuple of N elements, each with its own domain
    Tuple(Vec<DomainPtr>),
    /// A reference to a domain letting
    #[polyquine_skip]
    Reference(Reference),
    /// A record
    Record(Vec<RecordEntry>),
    /// A function with attributes, domain, and range
    Function(FuncAttr<IntVal>, DomainPtr, DomainPtr),
}
341
342impl UnresolvedDomain {
343    pub fn resolve(&self) -> Option<GroundDomain> {
344        match self {
345            UnresolvedDomain::Int(rngs) => rngs
346                .iter()
347                .map(Range::<IntVal>::resolve)
348                .collect::<Option<_>>()
349                .map(GroundDomain::Int),
350            UnresolvedDomain::Set(attr, inner) => {
351                Some(GroundDomain::Set(attr.resolve()?, inner.resolve()?))
352            }
353            UnresolvedDomain::MSet(attr, inner) => {
354                Some(GroundDomain::MSet(attr.resolve()?, inner.resolve()?))
355            }
356            UnresolvedDomain::Matrix(inner, idx_doms) => {
357                let inner_gd = inner.resolve()?;
358                idx_doms
359                    .iter()
360                    .map(DomainPtr::resolve)
361                    .collect::<Option<_>>()
362                    .map(|idx| GroundDomain::Matrix(inner_gd, idx))
363            }
364            UnresolvedDomain::Tuple(inners) => inners
365                .iter()
366                .map(DomainPtr::resolve)
367                .collect::<Option<_>>()
368                .map(GroundDomain::Tuple),
369            UnresolvedDomain::Record(entries) => entries
370                .iter()
371                .map(|f| {
372                    f.domain.resolve().map(|gd| RecordEntryGround {
373                        name: f.name.clone(),
374                        domain: gd,
375                    })
376                })
377                .collect::<Option<_>>()
378                .map(GroundDomain::Record),
379            UnresolvedDomain::Reference(re) => re
380                .ptr
381                .as_domain_letting()
382                .unwrap_or_else(|| {
383                    bug!("Reference domain should point to domain letting, but got {re}")
384                })
385                .resolve()
386                .map(Moo::unwrap_or_clone),
387            UnresolvedDomain::Function(attr, dom, cdom) => {
388                if let Some(attr_gd) = attr.resolve()
389                    && let Some(dom_gd) = dom.resolve()
390                    && let Some(cdom_gd) = cdom.resolve()
391                {
392                    return Some(GroundDomain::Function(attr_gd, dom_gd, cdom_gd));
393                }
394                None
395            }
396        }
397    }
398
399    pub(super) fn union_unresolved(
400        &self,
401        other: &UnresolvedDomain,
402    ) -> Result<UnresolvedDomain, DomainOpError> {
403        match (self, other) {
404            (UnresolvedDomain::Int(lhs), UnresolvedDomain::Int(rhs)) => {
405                let merged = lhs.iter().chain(rhs.iter()).cloned().collect_vec();
406                Ok(UnresolvedDomain::Int(merged))
407            }
408            (UnresolvedDomain::Int(_), _) | (_, UnresolvedDomain::Int(_)) => {
409                Err(DomainOpError::WrongType)
410            }
411            (UnresolvedDomain::Set(_, in1), UnresolvedDomain::Set(_, in2)) => {
412                Ok(UnresolvedDomain::Set(SetAttr::default(), in1.union(in2)?))
413            }
414            (UnresolvedDomain::Set(_, _), _) | (_, UnresolvedDomain::Set(_, _)) => {
415                Err(DomainOpError::WrongType)
416            }
417            (UnresolvedDomain::MSet(_, in1), UnresolvedDomain::MSet(_, in2)) => {
418                Ok(UnresolvedDomain::MSet(MSetAttr::default(), in1.union(in2)?))
419            }
420            (UnresolvedDomain::MSet(_, _), _) | (_, UnresolvedDomain::MSet(_, _)) => {
421                Err(DomainOpError::WrongType)
422            }
423            (UnresolvedDomain::Matrix(in1, idx1), UnresolvedDomain::Matrix(in2, idx2))
424                if idx1 == idx2 =>
425            {
426                Ok(UnresolvedDomain::Matrix(in1.union(in2)?, idx1.clone()))
427            }
428            (UnresolvedDomain::Matrix(_, _), _) | (_, UnresolvedDomain::Matrix(_, _)) => {
429                Err(DomainOpError::WrongType)
430            }
431            (UnresolvedDomain::Tuple(lhs), UnresolvedDomain::Tuple(rhs))
432                if lhs.len() == rhs.len() =>
433            {
434                let mut merged = Vec::new();
435                for (l, r) in zip(lhs, rhs) {
436                    merged.push(l.union(r)?)
437                }
438                Ok(UnresolvedDomain::Tuple(merged))
439            }
440            (UnresolvedDomain::Tuple(_), _) | (_, UnresolvedDomain::Tuple(_)) => {
441                Err(DomainOpError::WrongType)
442            }
443            // TODO: Could we support unions of reference domains symbolically?
444            (UnresolvedDomain::Reference(_), _) | (_, UnresolvedDomain::Reference(_)) => {
445                Err(DomainOpError::NotGround)
446            }
447            // TODO: Could we define semantics for merging record domains?
448            #[allow(unreachable_patterns)] // Technically redundant but logically makes sense
449            (UnresolvedDomain::Record(_), _) | (_, UnresolvedDomain::Record(_)) => {
450                Err(DomainOpError::WrongType)
451            }
452            #[allow(unreachable_patterns)]
453            // Technically redundant but logically clearer to have both
454            (UnresolvedDomain::Function(_, _, _), _) | (_, UnresolvedDomain::Function(_, _, _)) => {
455                Err(DomainOpError::WrongType)
456            }
457        }
458    }
459
460    pub fn element_domain(&self) -> Option<DomainPtr> {
461        match self {
462            UnresolvedDomain::Set(_, inner_dom) => Some(inner_dom.clone()),
463            UnresolvedDomain::Matrix(_, _) => {
464                todo!("Unwrap one dimension of the domain")
465            }
466            _ => None,
467        }
468    }
469}
470
471impl Typeable for UnresolvedDomain {
472    fn return_type(&self) -> ReturnType {
473        match self {
474            UnresolvedDomain::Reference(re) => re.return_type(),
475            UnresolvedDomain::Int(_) => ReturnType::Int,
476            UnresolvedDomain::Set(_attr, inner) => ReturnType::Set(Box::new(inner.return_type())),
477            UnresolvedDomain::MSet(_attr, inner) => ReturnType::MSet(Box::new(inner.return_type())),
478            UnresolvedDomain::Matrix(inner, _idx) => {
479                ReturnType::Matrix(Box::new(inner.return_type()))
480            }
481            UnresolvedDomain::Tuple(inners) => {
482                let mut inner_types = Vec::new();
483                for inner in inners {
484                    inner_types.push(inner.return_type());
485                }
486                ReturnType::Tuple(inner_types)
487            }
488            UnresolvedDomain::Record(entries) => {
489                let mut entry_types = Vec::new();
490                for entry in entries {
491                    entry_types.push(entry.domain.return_type());
492                }
493                ReturnType::Record(entry_types)
494            }
495            UnresolvedDomain::Function(_, dom, cdom) => {
496                ReturnType::Function(Box::new(dom.return_type()), Box::new(cdom.return_type()))
497            }
498        }
499    }
500}
501
502impl Display for UnresolvedDomain {
503    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
504        match &self {
505            UnresolvedDomain::Reference(re) => write!(f, "{re}"),
506            UnresolvedDomain::Int(ranges) => {
507                if ranges.iter().all(Range::is_lower_or_upper_bounded) {
508                    let rngs: String = ranges.iter().map(|r| format!("{r}")).join(", ");
509                    write!(f, "int({})", rngs)
510                } else {
511                    write!(f, "int")
512                }
513            }
514            UnresolvedDomain::Set(attrs, inner_dom) => write!(f, "set {attrs} of {inner_dom}"),
515            UnresolvedDomain::MSet(attrs, inner_dom) => write!(f, "mset {attrs} of {inner_dom}"),
516            UnresolvedDomain::Matrix(value_domain, index_domains) => {
517                write!(
518                    f,
519                    "matrix indexed by {} of {value_domain}",
520                    pretty_vec(&index_domains.iter().collect_vec())
521                )
522            }
523            UnresolvedDomain::Tuple(domains) => {
524                write!(
525                    f,
526                    "tuple of ({})",
527                    pretty_vec(&domains.iter().collect_vec())
528                )
529            }
530            UnresolvedDomain::Record(entries) => {
531                write!(
532                    f,
533                    "record of ({})",
534                    pretty_vec(
535                        &entries
536                            .iter()
537                            .map(|entry| format!("{}: {}", entry.name, entry.domain))
538                            .collect_vec()
539                    )
540                )
541            }
542            UnresolvedDomain::Function(attribute, domain, codomain) => {
543                write!(f, "function {} {} --> {} ", attribute, domain, codomain)
544            }
545        }
546    }
547}