Skip to main content

conjure_cp_core/ast/domains/
unresolved.rs

1use crate::ast::domains::attrs::SetAttr;
2use crate::ast::{
3    DeclarationKind, DomainOpError, Expression, FuncAttr, Literal, Metadata, Moo,
4    RecordEntryGround, Reference, Typeable,
5    domains::{
6        GroundDomain,
7        domain::{DomainPtr, Int},
8        range::Range,
9    },
10};
11use crate::{bug, domain_int, matrix_expr, range};
12use conjure_cp_core::ast::pretty::pretty_vec;
13use conjure_cp_core::ast::{Name, ReturnType, eval_constant};
14use itertools::Itertools;
15use polyquine::Quine;
16use serde::{Deserialize, Serialize};
17use std::fmt::{Display, Formatter};
18use std::iter::zip;
19use std::ops::Deref;
20use uniplate::Uniplate;
21
/// An integer that may not be known yet, as used for bounds and size attributes in
/// unresolved domains (e.g. `Range<IntVal>`, `SetAttr<IntVal>`, `FuncAttr<IntVal>`).
///
/// Call [`IntVal::resolve`] to try to obtain the concrete [`Int`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Quine, Uniplate)]
#[path_prefix(conjure_cp::ast)]
#[biplate(to=Expression)]
#[biplate(to=Reference)]
pub enum IntVal {
    /// A concrete integer constant.
    Const(Int),
    /// A reference to a declaration; [`IntVal::new_ref`] only builds this for integer-typed
    /// value lettings and givens.
    #[polyquine_skip]
    Reference(Reference),
    /// An arbitrary expression; [`IntVal::new_expr`] only builds this for integer-typed
    /// expressions.
    Expr(Moo<Expression>),
}
32
33impl From<Int> for IntVal {
34    fn from(value: Int) -> Self {
35        Self::Const(value)
36    }
37}
38
39impl TryInto<Int> for IntVal {
40    type Error = DomainOpError;
41
42    fn try_into(self) -> Result<Int, Self::Error> {
43        match self {
44            IntVal::Const(val) => Ok(val),
45            _ => Err(DomainOpError::NotGround),
46        }
47    }
48}
49
50impl From<Range<Int>> for Range<IntVal> {
51    fn from(value: Range<Int>) -> Self {
52        match value {
53            Range::Single(x) => Range::Single(x.into()),
54            Range::Bounded(l, r) => Range::Bounded(l.into(), r.into()),
55            Range::UnboundedL(r) => Range::UnboundedL(r.into()),
56            Range::UnboundedR(l) => Range::UnboundedR(l.into()),
57            Range::Unbounded => Range::Unbounded,
58        }
59    }
60}
61
62impl TryInto<Range<Int>> for Range<IntVal> {
63    type Error = DomainOpError;
64
65    fn try_into(self) -> Result<Range<Int>, Self::Error> {
66        match self {
67            Range::Single(x) => Ok(Range::Single(x.try_into()?)),
68            Range::Bounded(l, r) => Ok(Range::Bounded(l.try_into()?, r.try_into()?)),
69            Range::UnboundedL(r) => Ok(Range::UnboundedL(r.try_into()?)),
70            Range::UnboundedR(l) => Ok(Range::UnboundedR(l.try_into()?)),
71            Range::Unbounded => Ok(Range::Unbounded),
72        }
73    }
74}
75
76impl From<SetAttr<Int>> for SetAttr<IntVal> {
77    fn from(value: SetAttr<Int>) -> Self {
78        SetAttr {
79            size: value.size.into(),
80        }
81    }
82}
83
84impl TryInto<SetAttr<Int>> for SetAttr<IntVal> {
85    type Error = DomainOpError;
86
87    fn try_into(self) -> Result<SetAttr<Int>, Self::Error> {
88        let size: Range<Int> = self.size.try_into()?;
89        Ok(SetAttr { size })
90    }
91}
92
93impl From<FuncAttr<Int>> for FuncAttr<IntVal> {
94    fn from(value: FuncAttr<Int>) -> Self {
95        FuncAttr {
96            size: value.size.into(),
97            partiality: value.partiality,
98            jectivity: value.jectivity,
99        }
100    }
101}
102
103impl TryInto<FuncAttr<Int>> for FuncAttr<IntVal> {
104    type Error = DomainOpError;
105
106    fn try_into(self) -> Result<FuncAttr<Int>, Self::Error> {
107        let size: Range<Int> = self.size.try_into()?;
108        Ok(FuncAttr {
109            size,
110            jectivity: self.jectivity,
111            partiality: self.partiality,
112        })
113    }
114}
115
116impl Display for IntVal {
117    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
118        match self {
119            IntVal::Const(val) => write!(f, "{val}"),
120            IntVal::Reference(re) => write!(f, "{re}"),
121            IntVal::Expr(expr) => write!(f, "({expr})"),
122        }
123    }
124}
125
126impl IntVal {
127    pub fn new_ref(re: &Reference) -> Option<IntVal> {
128        match re.ptr.kind().deref() {
129            DeclarationKind::ValueLetting(expr) => match expr.return_type() {
130                ReturnType::Int => Some(IntVal::Reference(re.clone())),
131                _ => None,
132            },
133            DeclarationKind::Given(dom) => match dom.return_type() {
134                ReturnType::Int => Some(IntVal::Reference(re.clone())),
135                _ => None,
136            },
137            DeclarationKind::GivenQuantified(inner) => match inner.domain().return_type() {
138                ReturnType::Int => Some(IntVal::Reference(re.clone())),
139                _ => None,
140            },
141            DeclarationKind::DomainLetting(_)
142            | DeclarationKind::RecordField(_)
143            | DeclarationKind::DecisionVariable(_) => None,
144        }
145    }
146
147    pub fn new_expr(value: Moo<Expression>) -> Option<IntVal> {
148        if value.return_type() != ReturnType::Int {
149            return None;
150        }
151        Some(IntVal::Expr(value))
152    }
153
154    pub fn resolve(&self) -> Option<Int> {
155        match self {
156            IntVal::Const(value) => Some(*value),
157            IntVal::Expr(expr) => match eval_constant(expr)? {
158                Literal::Int(v) => Some(v),
159                _ => bug!("Expected integer expression, got: {expr}"),
160            },
161            IntVal::Reference(re) => match re.ptr.kind().deref() {
162                DeclarationKind::ValueLetting(expr) => match eval_constant(expr)? {
163                    Literal::Int(v) => Some(v),
164                    _ => bug!("Expected integer expression, got: {expr}"),
165                },
166                // If this is an int given we will be able to resolve it eventually, but not yet
167                DeclarationKind::Given(_) | DeclarationKind::GivenQuantified(..) => None,
168                DeclarationKind::DomainLetting(_)
169                | DeclarationKind::RecordField(_)
170                | DeclarationKind::DecisionVariable(_) => bug!(
171                    "Expected integer expression, given, or letting inside int domain; Got: {re}"
172                ),
173            },
174        }
175    }
176}
177
178impl From<IntVal> for Expression {
179    fn from(value: IntVal) -> Self {
180        match value {
181            IntVal::Const(val) => val.into(),
182            IntVal::Reference(re) => re.into(),
183            IntVal::Expr(expr) => expr.as_ref().clone(),
184        }
185    }
186}
187
188impl From<IntVal> for Moo<Expression> {
189    fn from(value: IntVal) -> Self {
190        match value {
191            IntVal::Const(val) => Moo::new(val.into()),
192            IntVal::Reference(re) => Moo::new(re.into()),
193            IntVal::Expr(expr) => expr,
194        }
195    }
196}
197
198impl std::ops::Neg for IntVal {
199    type Output = IntVal;
200
201    fn neg(self) -> Self::Output {
202        match self {
203            IntVal::Const(val) => IntVal::Const(-val),
204            IntVal::Reference(_) | IntVal::Expr(_) => {
205                IntVal::Expr(Moo::new(Expression::Neg(Metadata::new(), self.into())))
206            }
207        }
208    }
209}
210
211impl<T> std::ops::Add<T> for IntVal
212where
213    T: Into<Expression>,
214{
215    type Output = IntVal;
216
217    fn add(self, rhs: T) -> Self::Output {
218        let lhs: Expression = self.into();
219        let rhs: Expression = rhs.into();
220        let sum = matrix_expr!(lhs, rhs; domain_int!(1..));
221        IntVal::Expr(Moo::new(Expression::Sum(Metadata::new(), Moo::new(sum))))
222    }
223}
224
225impl Range<IntVal> {
226    pub fn resolve(&self) -> Option<Range<Int>> {
227        match self {
228            Range::Single(x) => Some(Range::Single(x.resolve()?)),
229            Range::Bounded(l, r) => Some(Range::Bounded(l.resolve()?, r.resolve()?)),
230            Range::UnboundedL(r) => Some(Range::UnboundedL(r.resolve()?)),
231            Range::UnboundedR(l) => Some(Range::UnboundedR(l.resolve()?)),
232            Range::Unbounded => Some(Range::Unbounded),
233        }
234    }
235}
236
237impl SetAttr<IntVal> {
238    pub fn resolve(&self) -> Option<SetAttr<Int>> {
239        Some(SetAttr {
240            size: self.size.resolve()?,
241        })
242    }
243}
244
245impl FuncAttr<IntVal> {
246    pub fn resolve(&self) -> Option<FuncAttr<Int>> {
247        Some(FuncAttr {
248            size: self.size.resolve()?,
249            partiality: self.partiality.clone(),
250            jectivity: self.jectivity.clone(),
251        })
252    }
253}
254
/// A named field of a record domain, whose domain may not be resolved yet.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Uniplate, Quine)]
#[path_prefix(conjure_cp::ast)]
pub struct RecordEntry {
    /// The field's name.
    pub name: Name,
    /// The domain of the field's values.
    pub domain: DomainPtr,
}
261
262impl RecordEntry {
263    pub fn resolve(self) -> Option<RecordEntryGround> {
264        Some(RecordEntryGround {
265            name: self.name,
266            domain: self.domain.resolve()?,
267        })
268    }
269}
270
/// A domain that may still contain unresolved parts. These include integer bounds or size
/// attributes given as [`IntVal`]s, nested [`DomainPtr`]s, and references to domain lettings.
///
/// Use [`UnresolvedDomain::resolve`] to obtain a [`GroundDomain`] once every part is known.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Quine, Uniplate)]
#[path_prefix(conjure_cp::ast)]
#[biplate(to=Expression)]
#[biplate(to=Reference)]
#[biplate(to=IntVal)]
#[biplate(to=DomainPtr)]
#[biplate(to=RecordEntry)]
pub enum UnresolvedDomain {
    /// An integer domain made of a list of ranges whose bounds may be symbolic
    Int(Vec<Range<IntVal>>),
    /// A set of elements drawn from the inner domain
    Set(SetAttr<IntVal>, DomainPtr),
    /// An n-dimensional matrix with a value domain and n index domains
    Matrix(DomainPtr, Vec<DomainPtr>),
    /// A tuple of N elements, each with its own domain
    Tuple(Vec<DomainPtr>),
    /// A reference to a domain letting
    #[polyquine_skip]
    Reference(Reference),
    /// A record with named fields, each with its own domain
    Record(Vec<RecordEntry>),
    /// A function with attributes, domain, and range
    Function(FuncAttr<IntVal>, DomainPtr, DomainPtr),
}
294
295impl UnresolvedDomain {
296    pub fn resolve(&self) -> Option<GroundDomain> {
297        match self {
298            UnresolvedDomain::Int(rngs) => rngs
299                .iter()
300                .map(Range::<IntVal>::resolve)
301                .collect::<Option<_>>()
302                .map(GroundDomain::Int),
303            UnresolvedDomain::Set(attr, inner) => {
304                Some(GroundDomain::Set(attr.resolve()?, inner.resolve()?))
305            }
306            UnresolvedDomain::Matrix(inner, idx_doms) => {
307                let inner_gd = inner.resolve()?;
308                idx_doms
309                    .iter()
310                    .map(DomainPtr::resolve)
311                    .collect::<Option<_>>()
312                    .map(|idx| GroundDomain::Matrix(inner_gd, idx))
313            }
314            UnresolvedDomain::Tuple(inners) => inners
315                .iter()
316                .map(DomainPtr::resolve)
317                .collect::<Option<_>>()
318                .map(GroundDomain::Tuple),
319            UnresolvedDomain::Record(entries) => entries
320                .iter()
321                .map(|f| {
322                    f.domain.resolve().map(|gd| RecordEntryGround {
323                        name: f.name.clone(),
324                        domain: gd,
325                    })
326                })
327                .collect::<Option<_>>()
328                .map(GroundDomain::Record),
329            UnresolvedDomain::Reference(re) => re
330                .ptr
331                .as_domain_letting()
332                .unwrap_or_else(|| {
333                    bug!("Reference domain should point to domain letting, but got {re}")
334                })
335                .resolve()
336                .map(Moo::unwrap_or_clone),
337            UnresolvedDomain::Function(attr, dom, cdom) => {
338                if let Some(attr_gd) = attr.resolve()
339                    && let Some(dom_gd) = dom.resolve()
340                    && let Some(cdom_gd) = cdom.resolve()
341                {
342                    return Some(GroundDomain::Function(attr_gd, dom_gd, cdom_gd));
343                }
344                None
345            }
346        }
347    }
348
349    pub(super) fn union_unresolved(
350        &self,
351        other: &UnresolvedDomain,
352    ) -> Result<UnresolvedDomain, DomainOpError> {
353        match (self, other) {
354            (UnresolvedDomain::Int(lhs), UnresolvedDomain::Int(rhs)) => {
355                let merged = lhs.iter().chain(rhs.iter()).cloned().collect_vec();
356                Ok(UnresolvedDomain::Int(merged))
357            }
358            (UnresolvedDomain::Int(_), _) | (_, UnresolvedDomain::Int(_)) => {
359                Err(DomainOpError::WrongType)
360            }
361            (UnresolvedDomain::Set(_, in1), UnresolvedDomain::Set(_, in2)) => {
362                Ok(UnresolvedDomain::Set(SetAttr::default(), in1.union(in2)?))
363            }
364            (UnresolvedDomain::Set(_, _), _) | (_, UnresolvedDomain::Set(_, _)) => {
365                Err(DomainOpError::WrongType)
366            }
367            (UnresolvedDomain::Matrix(in1, idx1), UnresolvedDomain::Matrix(in2, idx2))
368                if idx1 == idx2 =>
369            {
370                Ok(UnresolvedDomain::Matrix(in1.union(in2)?, idx1.clone()))
371            }
372            (UnresolvedDomain::Matrix(_, _), _) | (_, UnresolvedDomain::Matrix(_, _)) => {
373                Err(DomainOpError::WrongType)
374            }
375            (UnresolvedDomain::Tuple(lhs), UnresolvedDomain::Tuple(rhs))
376                if lhs.len() == rhs.len() =>
377            {
378                let mut merged = Vec::new();
379                for (l, r) in zip(lhs, rhs) {
380                    merged.push(l.union(r)?)
381                }
382                Ok(UnresolvedDomain::Tuple(merged))
383            }
384            (UnresolvedDomain::Tuple(_), _) | (_, UnresolvedDomain::Tuple(_)) => {
385                Err(DomainOpError::WrongType)
386            }
387            // TODO: Could we support unions of reference domains symbolically?
388            (UnresolvedDomain::Reference(_), _) | (_, UnresolvedDomain::Reference(_)) => {
389                Err(DomainOpError::NotGround)
390            }
391            // TODO: Could we define semantics for merging record domains?
392            #[allow(unreachable_patterns)] // Technically redundant but logically makes sense
393            (UnresolvedDomain::Record(_), _) | (_, UnresolvedDomain::Record(_)) => {
394                Err(DomainOpError::WrongType)
395            }
396            #[allow(unreachable_patterns)]
397            // Technically redundant but logically clearer to have both
398            (UnresolvedDomain::Function(_, _, _), _) | (_, UnresolvedDomain::Function(_, _, _)) => {
399                Err(DomainOpError::WrongType)
400            }
401        }
402    }
403
404    pub fn element_domain(&self) -> Option<DomainPtr> {
405        match self {
406            UnresolvedDomain::Set(_, inner_dom) => Some(inner_dom.clone()),
407            UnresolvedDomain::Matrix(_, _) => {
408                todo!("Unwrap one dimension of the domain")
409            }
410            _ => None,
411        }
412    }
413}
414
415impl Typeable for UnresolvedDomain {
416    fn return_type(&self) -> ReturnType {
417        match self {
418            UnresolvedDomain::Reference(re) => re.return_type(),
419            UnresolvedDomain::Int(_) => ReturnType::Int,
420            UnresolvedDomain::Set(_attr, inner) => ReturnType::Set(Box::new(inner.return_type())),
421            UnresolvedDomain::Matrix(inner, _idx) => {
422                ReturnType::Matrix(Box::new(inner.return_type()))
423            }
424            UnresolvedDomain::Tuple(inners) => {
425                let mut inner_types = Vec::new();
426                for inner in inners {
427                    inner_types.push(inner.return_type());
428                }
429                ReturnType::Tuple(inner_types)
430            }
431            UnresolvedDomain::Record(entries) => {
432                let mut entry_types = Vec::new();
433                for entry in entries {
434                    entry_types.push(entry.domain.return_type());
435                }
436                ReturnType::Record(entry_types)
437            }
438            UnresolvedDomain::Function(_, dom, cdom) => {
439                ReturnType::Function(Box::new(dom.return_type()), Box::new(cdom.return_type()))
440            }
441        }
442    }
443}
444
445impl Display for UnresolvedDomain {
446    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
447        match &self {
448            UnresolvedDomain::Reference(re) => write!(f, "{re}"),
449            UnresolvedDomain::Int(ranges) => {
450                if ranges.iter().all(Range::is_lower_or_upper_bounded) {
451                    let rngs: String = ranges.iter().map(|r| format!("{r}")).join(", ");
452                    write!(f, "int({})", rngs)
453                } else {
454                    write!(f, "int")
455                }
456            }
457            UnresolvedDomain::Set(attrs, inner_dom) => write!(f, "set {attrs} of {inner_dom}"),
458            UnresolvedDomain::Matrix(value_domain, index_domains) => {
459                write!(
460                    f,
461                    "matrix indexed by [{}] of {value_domain}",
462                    pretty_vec(&index_domains.iter().collect_vec())
463                )
464            }
465            UnresolvedDomain::Tuple(domains) => {
466                write!(
467                    f,
468                    "tuple of ({})",
469                    pretty_vec(&domains.iter().collect_vec())
470                )
471            }
472            UnresolvedDomain::Record(entries) => {
473                write!(
474                    f,
475                    "record of ({})",
476                    pretty_vec(
477                        &entries
478                            .iter()
479                            .map(|entry| format!("{}: {}", entry.name, entry.domain))
480                            .collect_vec()
481                    )
482                )
483            }
484            UnresolvedDomain::Function(attribute, domain, codomain) => {
485                write!(f, "function {} {} --> {} ", attribute, domain, codomain)
486            }
487        }
488    }
489}