1use crate::ast::domains::attrs::MSetAttr;
2use crate::ast::domains::attrs::SetAttr;
3use crate::ast::{
4 DeclarationKind, DomainOpError, Expression, FuncAttr, Literal, Metadata, Moo,
5 RecordEntryGround, Reference, Typeable,
6 domains::{
7 GroundDomain,
8 domain::{DomainPtr, Int},
9 range::Range,
10 },
11};
12use crate::{bug, domain_int, matrix_expr, range};
13use conjure_cp_core::ast::pretty::pretty_vec;
14use conjure_cp_core::ast::{Name, ReturnType, eval_constant};
15use itertools::Itertools;
16use polyquine::Quine;
17use serde::{Deserialize, Serialize};
18use std::fmt::{Display, Formatter};
19use std::iter::zip;
20use std::ops::Deref;
21use uniplate::Uniplate;
22
/// An integer-valued term that can appear inside a domain that is not yet ground.
///
/// A value is either already a constant, or it is a reference / expression that has to be
/// evaluated first (see `IntVal::resolve`).
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Quine, Uniplate)]
#[path_prefix(conjure_cp::ast)]
#[biplate(to=Expression)]
#[biplate(to=Reference)]
pub enum IntVal {
    /// A literal integer constant.
    Const(Int),
    /// A reference to a declaration. `IntVal::new_ref` only builds this variant for
    /// integer-typed value lettings, givens and quantified variables.
    #[polyquine_skip]
    Reference(Reference),
    /// An arbitrary expression. `IntVal::new_expr` only builds this variant for expressions
    /// whose return type is `Int`.
    Expr(Moo<Expression>),
}
33
34impl From<Int> for IntVal {
35 fn from(value: Int) -> Self {
36 Self::Const(value)
37 }
38}
39
40impl TryInto<Int> for IntVal {
41 type Error = DomainOpError;
42
43 fn try_into(self) -> Result<Int, Self::Error> {
44 match self {
45 IntVal::Const(val) => Ok(val),
46 _ => Err(DomainOpError::NotGround),
47 }
48 }
49}
50
51impl From<Range<Int>> for Range<IntVal> {
52 fn from(value: Range<Int>) -> Self {
53 match value {
54 Range::Single(x) => Range::Single(x.into()),
55 Range::Bounded(l, r) => Range::Bounded(l.into(), r.into()),
56 Range::UnboundedL(r) => Range::UnboundedL(r.into()),
57 Range::UnboundedR(l) => Range::UnboundedR(l.into()),
58 Range::Unbounded => Range::Unbounded,
59 }
60 }
61}
62
63impl TryInto<Range<Int>> for Range<IntVal> {
64 type Error = DomainOpError;
65
66 fn try_into(self) -> Result<Range<Int>, Self::Error> {
67 match self {
68 Range::Single(x) => Ok(Range::Single(x.try_into()?)),
69 Range::Bounded(l, r) => Ok(Range::Bounded(l.try_into()?, r.try_into()?)),
70 Range::UnboundedL(r) => Ok(Range::UnboundedL(r.try_into()?)),
71 Range::UnboundedR(l) => Ok(Range::UnboundedR(l.try_into()?)),
72 Range::Unbounded => Ok(Range::Unbounded),
73 }
74 }
75}
76
77impl From<SetAttr<Int>> for SetAttr<IntVal> {
78 fn from(value: SetAttr<Int>) -> Self {
79 SetAttr {
80 size: value.size.into(),
81 }
82 }
83}
84
85impl TryInto<SetAttr<Int>> for SetAttr<IntVal> {
86 type Error = DomainOpError;
87
88 fn try_into(self) -> Result<SetAttr<Int>, Self::Error> {
89 let size: Range<Int> = self.size.try_into()?;
90 Ok(SetAttr { size })
91 }
92}
93
94impl From<MSetAttr<Int>> for MSetAttr<IntVal> {
95 fn from(value: MSetAttr<Int>) -> Self {
96 MSetAttr {
97 size: value.size.into(),
98 occurrence: value.occurrence.into(),
99 }
100 }
101}
102
103impl TryInto<MSetAttr<Int>> for MSetAttr<IntVal> {
104 type Error = DomainOpError;
105
106 fn try_into(self) -> Result<MSetAttr<Int>, Self::Error> {
107 let size: Range<Int> = self.size.try_into()?;
108 let occurrence: Range<Int> = self.occurrence.try_into()?;
109 Ok(MSetAttr { size, occurrence })
110 }
111}
112
113impl From<FuncAttr<Int>> for FuncAttr<IntVal> {
114 fn from(value: FuncAttr<Int>) -> Self {
115 FuncAttr {
116 size: value.size.into(),
117 partiality: value.partiality,
118 jectivity: value.jectivity,
119 }
120 }
121}
122
123impl TryInto<FuncAttr<Int>> for FuncAttr<IntVal> {
124 type Error = DomainOpError;
125
126 fn try_into(self) -> Result<FuncAttr<Int>, Self::Error> {
127 let size: Range<Int> = self.size.try_into()?;
128 Ok(FuncAttr {
129 size,
130 jectivity: self.jectivity,
131 partiality: self.partiality,
132 })
133 }
134}
135
136impl Display for IntVal {
137 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
138 match self {
139 IntVal::Const(val) => write!(f, "{val}"),
140 IntVal::Reference(re) => write!(f, "{re}"),
141 IntVal::Expr(expr) => write!(f, "({expr})"),
142 }
143 }
144}
145
146impl IntVal {
147 pub fn new_ref(re: &Reference) -> Option<IntVal> {
148 match re.ptr.kind().deref() {
149 DeclarationKind::ValueLetting(expr) => match expr.return_type() {
150 ReturnType::Int => Some(IntVal::Reference(re.clone())),
151 _ => None,
152 },
153 DeclarationKind::Given(dom) => match dom.return_type() {
154 ReturnType::Int => Some(IntVal::Reference(re.clone())),
155 _ => None,
156 },
157 DeclarationKind::Quantified(inner) => match inner.domain().return_type() {
158 ReturnType::Int => Some(IntVal::Reference(re.clone())),
159 _ => None,
160 },
161 DeclarationKind::DomainLetting(_)
162 | DeclarationKind::RecordField(_)
163 | DeclarationKind::Find(_) => None,
164 }
165 }
166
167 pub fn new_expr(value: Moo<Expression>) -> Option<IntVal> {
168 if value.return_type() != ReturnType::Int {
169 return None;
170 }
171 Some(IntVal::Expr(value))
172 }
173
174 pub fn resolve(&self) -> Option<Int> {
175 match self {
176 IntVal::Const(value) => Some(*value),
177 IntVal::Expr(expr) => eval_expr_to_int(expr),
178 IntVal::Reference(re) => match re.ptr.kind().deref() {
179 DeclarationKind::ValueLetting(expr) => eval_expr_to_int(expr),
180 DeclarationKind::Given(_) | DeclarationKind::Quantified(..) => None,
182 DeclarationKind::DomainLetting(_)
183 | DeclarationKind::RecordField(_)
184 | DeclarationKind::Find(_) => bug!(
185 "Expected integer expression, given, or letting inside int domain; Got: {re}"
186 ),
187 },
188 }
189 }
190}
191
192fn eval_expr_to_int(expr: &Expression) -> Option<Int> {
193 match eval_constant(expr)? {
194 Literal::Int(v) => Some(v),
195 _ => bug!("Expected integer expression, got: {expr}"),
196 }
197}
198
199impl From<IntVal> for Expression {
200 fn from(value: IntVal) -> Self {
201 match value {
202 IntVal::Const(val) => val.into(),
203 IntVal::Reference(re) => re.into(),
204 IntVal::Expr(expr) => expr.as_ref().clone(),
205 }
206 }
207}
208
209impl From<IntVal> for Moo<Expression> {
210 fn from(value: IntVal) -> Self {
211 match value {
212 IntVal::Const(val) => Moo::new(val.into()),
213 IntVal::Reference(re) => Moo::new(re.into()),
214 IntVal::Expr(expr) => expr,
215 }
216 }
217}
218
219impl std::ops::Neg for IntVal {
220 type Output = IntVal;
221
222 fn neg(self) -> Self::Output {
223 match self {
224 IntVal::Const(val) => IntVal::Const(-val),
225 IntVal::Reference(_) | IntVal::Expr(_) => {
226 IntVal::Expr(Moo::new(Expression::Neg(Metadata::new(), self.into())))
227 }
228 }
229 }
230}
231
232impl<T> std::ops::Add<T> for IntVal
233where
234 T: Into<Expression>,
235{
236 type Output = IntVal;
237
238 fn add(self, rhs: T) -> Self::Output {
239 let lhs: Expression = self.into();
240 let rhs: Expression = rhs.into();
241 let sum = matrix_expr!(lhs, rhs; domain_int!(1..));
242 IntVal::Expr(Moo::new(Expression::Sum(Metadata::new(), Moo::new(sum))))
243 }
244}
245
246impl Range<IntVal> {
247 pub fn resolve(&self) -> Option<Range<Int>> {
248 match self {
249 Range::Single(x) => Some(Range::Single(x.resolve()?)),
250 Range::Bounded(l, r) => Some(Range::Bounded(l.resolve()?, r.resolve()?)),
251 Range::UnboundedL(r) => Some(Range::UnboundedL(r.resolve()?)),
252 Range::UnboundedR(l) => Some(Range::UnboundedR(l.resolve()?)),
253 Range::Unbounded => Some(Range::Unbounded),
254 }
255 }
256}
257
258impl SetAttr<IntVal> {
259 pub fn resolve(&self) -> Option<SetAttr<Int>> {
260 Some(SetAttr {
261 size: self.size.resolve()?,
262 })
263 }
264}
265
266impl MSetAttr<IntVal> {
267 pub fn resolve(&self) -> Option<MSetAttr<Int>> {
268 Some(MSetAttr {
269 size: self.size.resolve()?,
270 occurrence: self.occurrence.resolve()?,
271 })
272 }
273}
274
275impl FuncAttr<IntVal> {
276 pub fn resolve(&self) -> Option<FuncAttr<Int>> {
277 Some(FuncAttr {
278 size: self.size.resolve()?,
279 partiality: self.partiality.clone(),
280 jectivity: self.jectivity.clone(),
281 })
282 }
283}
284
/// One named entry of a record domain whose entry domain is a [`DomainPtr`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Uniplate, Quine)]
#[path_prefix(conjure_cp::ast)]
pub struct RecordEntry {
    /// Name of the entry.
    pub name: Name,
    /// Domain of the entry.
    pub domain: DomainPtr,
}
291
292impl RecordEntry {
293 pub fn resolve(self) -> Option<RecordEntryGround> {
294 Some(RecordEntryGround {
295 name: self.name,
296 domain: self.domain.resolve()?,
297 })
298 }
299}
300
/// A domain that may still contain references or integer values that are not yet constants.
///
/// `UnresolvedDomain::resolve` attempts to turn it into a [`GroundDomain`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Quine, Uniplate)]
#[path_prefix(conjure_cp::ast)]
#[biplate(to=Expression)]
#[biplate(to=Reference)]
#[biplate(to=IntVal)]
#[biplate(to=DomainPtr)]
#[biplate(to=RecordEntry)]
pub enum UnresolvedDomain {
    /// Integer domain, given as a list of ranges over [`IntVal`]s.
    Int(Vec<Range<IntVal>>),
    /// Set domain: attributes, then the element domain.
    Set(SetAttr<IntVal>, DomainPtr),
    /// Multiset domain: attributes, then the element domain.
    MSet(MSetAttr<IntVal>, DomainPtr),
    /// Matrix domain: the value domain, then the index domains.
    Matrix(DomainPtr, Vec<DomainPtr>),
    /// Tuple domain: one domain per component.
    Tuple(Vec<DomainPtr>),
    /// Reference to another domain; `resolve` expects it to point to a domain letting.
    #[polyquine_skip]
    Reference(Reference),
    /// Record domain: a list of named entries.
    Record(Vec<RecordEntry>),
    /// Function domain: attributes, then the domain, then the codomain.
    Function(FuncAttr<IntVal>, DomainPtr, DomainPtr),
}
325
326impl UnresolvedDomain {
327 pub fn resolve(&self) -> Option<GroundDomain> {
328 match self {
329 UnresolvedDomain::Int(rngs) => rngs
330 .iter()
331 .map(Range::<IntVal>::resolve)
332 .collect::<Option<_>>()
333 .map(GroundDomain::Int),
334 UnresolvedDomain::Set(attr, inner) => {
335 Some(GroundDomain::Set(attr.resolve()?, inner.resolve()?))
336 }
337 UnresolvedDomain::MSet(attr, inner) => {
338 Some(GroundDomain::MSet(attr.resolve()?, inner.resolve()?))
339 }
340 UnresolvedDomain::Matrix(inner, idx_doms) => {
341 let inner_gd = inner.resolve()?;
342 idx_doms
343 .iter()
344 .map(DomainPtr::resolve)
345 .collect::<Option<_>>()
346 .map(|idx| GroundDomain::Matrix(inner_gd, idx))
347 }
348 UnresolvedDomain::Tuple(inners) => inners
349 .iter()
350 .map(DomainPtr::resolve)
351 .collect::<Option<_>>()
352 .map(GroundDomain::Tuple),
353 UnresolvedDomain::Record(entries) => entries
354 .iter()
355 .map(|f| {
356 f.domain.resolve().map(|gd| RecordEntryGround {
357 name: f.name.clone(),
358 domain: gd,
359 })
360 })
361 .collect::<Option<_>>()
362 .map(GroundDomain::Record),
363 UnresolvedDomain::Reference(re) => re
364 .ptr
365 .as_domain_letting()
366 .unwrap_or_else(|| {
367 bug!("Reference domain should point to domain letting, but got {re}")
368 })
369 .resolve()
370 .map(Moo::unwrap_or_clone),
371 UnresolvedDomain::Function(attr, dom, cdom) => {
372 if let Some(attr_gd) = attr.resolve()
373 && let Some(dom_gd) = dom.resolve()
374 && let Some(cdom_gd) = cdom.resolve()
375 {
376 return Some(GroundDomain::Function(attr_gd, dom_gd, cdom_gd));
377 }
378 None
379 }
380 }
381 }
382
383 pub(super) fn union_unresolved(
384 &self,
385 other: &UnresolvedDomain,
386 ) -> Result<UnresolvedDomain, DomainOpError> {
387 match (self, other) {
388 (UnresolvedDomain::Int(lhs), UnresolvedDomain::Int(rhs)) => {
389 let merged = lhs.iter().chain(rhs.iter()).cloned().collect_vec();
390 Ok(UnresolvedDomain::Int(merged))
391 }
392 (UnresolvedDomain::Int(_), _) | (_, UnresolvedDomain::Int(_)) => {
393 Err(DomainOpError::WrongType)
394 }
395 (UnresolvedDomain::Set(_, in1), UnresolvedDomain::Set(_, in2)) => {
396 Ok(UnresolvedDomain::Set(SetAttr::default(), in1.union(in2)?))
397 }
398 (UnresolvedDomain::Set(_, _), _) | (_, UnresolvedDomain::Set(_, _)) => {
399 Err(DomainOpError::WrongType)
400 }
401 (UnresolvedDomain::MSet(_, in1), UnresolvedDomain::MSet(_, in2)) => {
402 Ok(UnresolvedDomain::MSet(MSetAttr::default(), in1.union(in2)?))
403 }
404 (UnresolvedDomain::MSet(_, _), _) | (_, UnresolvedDomain::MSet(_, _)) => {
405 Err(DomainOpError::WrongType)
406 }
407 (UnresolvedDomain::Matrix(in1, idx1), UnresolvedDomain::Matrix(in2, idx2))
408 if idx1 == idx2 =>
409 {
410 Ok(UnresolvedDomain::Matrix(in1.union(in2)?, idx1.clone()))
411 }
412 (UnresolvedDomain::Matrix(_, _), _) | (_, UnresolvedDomain::Matrix(_, _)) => {
413 Err(DomainOpError::WrongType)
414 }
415 (UnresolvedDomain::Tuple(lhs), UnresolvedDomain::Tuple(rhs))
416 if lhs.len() == rhs.len() =>
417 {
418 let mut merged = Vec::new();
419 for (l, r) in zip(lhs, rhs) {
420 merged.push(l.union(r)?)
421 }
422 Ok(UnresolvedDomain::Tuple(merged))
423 }
424 (UnresolvedDomain::Tuple(_), _) | (_, UnresolvedDomain::Tuple(_)) => {
425 Err(DomainOpError::WrongType)
426 }
427 (UnresolvedDomain::Reference(_), _) | (_, UnresolvedDomain::Reference(_)) => {
429 Err(DomainOpError::NotGround)
430 }
431 #[allow(unreachable_patterns)] (UnresolvedDomain::Record(_), _) | (_, UnresolvedDomain::Record(_)) => {
434 Err(DomainOpError::WrongType)
435 }
436 #[allow(unreachable_patterns)]
437 (UnresolvedDomain::Function(_, _, _), _) | (_, UnresolvedDomain::Function(_, _, _)) => {
439 Err(DomainOpError::WrongType)
440 }
441 }
442 }
443
444 pub fn element_domain(&self) -> Option<DomainPtr> {
445 match self {
446 UnresolvedDomain::Set(_, inner_dom) => Some(inner_dom.clone()),
447 UnresolvedDomain::Matrix(_, _) => {
448 todo!("Unwrap one dimension of the domain")
449 }
450 _ => None,
451 }
452 }
453}
454
455impl Typeable for UnresolvedDomain {
456 fn return_type(&self) -> ReturnType {
457 match self {
458 UnresolvedDomain::Reference(re) => re.return_type(),
459 UnresolvedDomain::Int(_) => ReturnType::Int,
460 UnresolvedDomain::Set(_attr, inner) => ReturnType::Set(Box::new(inner.return_type())),
461 UnresolvedDomain::MSet(_attr, inner) => ReturnType::MSet(Box::new(inner.return_type())),
462 UnresolvedDomain::Matrix(inner, _idx) => {
463 ReturnType::Matrix(Box::new(inner.return_type()))
464 }
465 UnresolvedDomain::Tuple(inners) => {
466 let mut inner_types = Vec::new();
467 for inner in inners {
468 inner_types.push(inner.return_type());
469 }
470 ReturnType::Tuple(inner_types)
471 }
472 UnresolvedDomain::Record(entries) => {
473 let mut entry_types = Vec::new();
474 for entry in entries {
475 entry_types.push(entry.domain.return_type());
476 }
477 ReturnType::Record(entry_types)
478 }
479 UnresolvedDomain::Function(_, dom, cdom) => {
480 ReturnType::Function(Box::new(dom.return_type()), Box::new(cdom.return_type()))
481 }
482 }
483 }
484}
485
486impl Display for UnresolvedDomain {
487 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
488 match &self {
489 UnresolvedDomain::Reference(re) => write!(f, "{re}"),
490 UnresolvedDomain::Int(ranges) => {
491 if ranges.iter().all(Range::is_lower_or_upper_bounded) {
492 let rngs: String = ranges.iter().map(|r| format!("{r}")).join(", ");
493 write!(f, "int({})", rngs)
494 } else {
495 write!(f, "int")
496 }
497 }
498 UnresolvedDomain::Set(attrs, inner_dom) => write!(f, "set {attrs} of {inner_dom}"),
499 UnresolvedDomain::MSet(attrs, inner_dom) => write!(f, "mset {attrs} of {inner_dom}"),
500 UnresolvedDomain::Matrix(value_domain, index_domains) => {
501 write!(
502 f,
503 "matrix indexed by [{}] of {value_domain}",
504 pretty_vec(&index_domains.iter().collect_vec())
505 )
506 }
507 UnresolvedDomain::Tuple(domains) => {
508 write!(
509 f,
510 "tuple of ({})",
511 pretty_vec(&domains.iter().collect_vec())
512 )
513 }
514 UnresolvedDomain::Record(entries) => {
515 write!(
516 f,
517 "record of ({})",
518 pretty_vec(
519 &entries
520 .iter()
521 .map(|entry| format!("{}: {}", entry.name, entry.domain))
522 .collect_vec()
523 )
524 )
525 }
526 UnresolvedDomain::Function(attribute, domain, codomain) => {
527 write!(f, "function {} {} --> {} ", attribute, domain, codomain)
528 }
529 }
530 }
531}