1use crate::ast::domains::attrs::MSetAttr;
2use crate::ast::domains::attrs::SetAttr;
3use crate::ast::{
4 DeclarationKind, DomainOpError, Expression, FuncAttr, Literal, Metadata, Moo,
5 RecordEntryGround, Reference, Typeable,
6 domains::{
7 GroundDomain,
8 domain::{DomainPtr, Int},
9 range::Range,
10 },
11};
12use crate::{bug, domain_int, matrix_expr, range};
13use conjure_cp_core::ast::pretty::pretty_vec;
14use conjure_cp_core::ast::{Name, ReturnType, eval_constant};
15use itertools::Itertools;
16use polyquine::Quine;
17use serde::{Deserialize, Serialize};
18use std::fmt::{Display, Formatter};
19use std::iter::zip;
20use std::ops::Deref;
21use uniplate::Uniplate;
22
/// An integer value used inside a domain (for example a range bound or a size attribute) that
/// is not necessarily known yet.
///
/// Use [`IntVal::resolve`] to try to evaluate it to a concrete [`Int`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Quine, Uniplate)]
#[path_prefix(conjure_cp::ast)]
#[biplate(to=Expression)]
#[biplate(to=Reference)]
pub enum IntVal {
    /// A known integer constant.
    Const(Int),
    /// A reference to a declaration; [`IntVal::new_ref`] only builds this for integer-typed
    /// declarations.
    #[polyquine_skip]
    Reference(Reference),
    /// An arbitrary expression; [`IntVal::new_expr`] only builds this for expressions whose
    /// return type is integer.
    Expr(Moo<Expression>),
}
33
34impl From<Int> for IntVal {
35 fn from(value: Int) -> Self {
36 Self::Const(value)
37 }
38}
39
40impl TryInto<Int> for IntVal {
41 type Error = DomainOpError;
42
43 fn try_into(self) -> Result<Int, Self::Error> {
44 match self {
45 IntVal::Const(val) => Ok(val),
46 _ => Err(DomainOpError::NotGround),
47 }
48 }
49}
50
51impl From<Range<Int>> for Range<IntVal> {
52 fn from(value: Range<Int>) -> Self {
53 match value {
54 Range::Single(x) => Range::Single(x.into()),
55 Range::Bounded(l, r) => Range::Bounded(l.into(), r.into()),
56 Range::UnboundedL(r) => Range::UnboundedL(r.into()),
57 Range::UnboundedR(l) => Range::UnboundedR(l.into()),
58 Range::Unbounded => Range::Unbounded,
59 }
60 }
61}
62
63impl TryInto<Range<Int>> for Range<IntVal> {
64 type Error = DomainOpError;
65
66 fn try_into(self) -> Result<Range<Int>, Self::Error> {
67 match self {
68 Range::Single(x) => Ok(Range::Single(x.try_into()?)),
69 Range::Bounded(l, r) => Ok(Range::Bounded(l.try_into()?, r.try_into()?)),
70 Range::UnboundedL(r) => Ok(Range::UnboundedL(r.try_into()?)),
71 Range::UnboundedR(l) => Ok(Range::UnboundedR(l.try_into()?)),
72 Range::Unbounded => Ok(Range::Unbounded),
73 }
74 }
75}
76
77impl From<SetAttr<Int>> for SetAttr<IntVal> {
78 fn from(value: SetAttr<Int>) -> Self {
79 SetAttr {
80 size: value.size.into(),
81 }
82 }
83}
84
85impl TryInto<SetAttr<Int>> for SetAttr<IntVal> {
86 type Error = DomainOpError;
87
88 fn try_into(self) -> Result<SetAttr<Int>, Self::Error> {
89 let size: Range<Int> = self.size.try_into()?;
90 Ok(SetAttr { size })
91 }
92}
93
94impl From<MSetAttr<Int>> for MSetAttr<IntVal> {
95 fn from(value: MSetAttr<Int>) -> Self {
96 MSetAttr {
97 size: value.size.into(),
98 occurrence: value.occurrence.into(),
99 }
100 }
101}
102
103impl TryInto<MSetAttr<Int>> for MSetAttr<IntVal> {
104 type Error = DomainOpError;
105
106 fn try_into(self) -> Result<MSetAttr<Int>, Self::Error> {
107 let size: Range<Int> = self.size.try_into()?;
108 let occurrence: Range<Int> = self.occurrence.try_into()?;
109 Ok(MSetAttr { size, occurrence })
110 }
111}
112
113impl From<FuncAttr<Int>> for FuncAttr<IntVal> {
114 fn from(value: FuncAttr<Int>) -> Self {
115 FuncAttr {
116 size: value.size.into(),
117 partiality: value.partiality,
118 jectivity: value.jectivity,
119 }
120 }
121}
122
123impl TryInto<FuncAttr<Int>> for FuncAttr<IntVal> {
124 type Error = DomainOpError;
125
126 fn try_into(self) -> Result<FuncAttr<Int>, Self::Error> {
127 let size: Range<Int> = self.size.try_into()?;
128 Ok(FuncAttr {
129 size,
130 jectivity: self.jectivity,
131 partiality: self.partiality,
132 })
133 }
134}
135
136impl Display for IntVal {
137 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
138 match self {
139 IntVal::Const(val) => write!(f, "{val}"),
140 IntVal::Reference(re) => write!(f, "{re}"),
141 IntVal::Expr(expr) => write!(f, "({expr})"),
142 }
143 }
144}
145
146impl IntVal {
147 pub fn new_ref(re: &Reference) -> Option<IntVal> {
148 match re.ptr.kind().deref() {
149 DeclarationKind::ValueLetting(expr, _)
150 | DeclarationKind::TemporaryValueLetting(expr)
151 | DeclarationKind::QuantifiedExpr(expr) => match expr.return_type() {
152 ReturnType::Int => Some(IntVal::Reference(re.clone())),
153 _ => None,
154 },
155 DeclarationKind::Given(dom) => match dom.return_type() {
156 ReturnType::Int => Some(IntVal::Reference(re.clone())),
157 _ => None,
158 },
159 DeclarationKind::Quantified(inner) => match inner.domain().return_type() {
160 ReturnType::Int => Some(IntVal::Reference(re.clone())),
161 _ => None,
162 },
163 DeclarationKind::Find(var) => match var.return_type() {
164 ReturnType::Int => Some(IntVal::Reference(re.clone())),
165 _ => None,
166 },
167 DeclarationKind::DomainLetting(_) | DeclarationKind::RecordField(_) => None,
168 }
169 }
170
171 pub fn new_expr(value: Moo<Expression>) -> Option<IntVal> {
172 if value.return_type() != ReturnType::Int {
173 return None;
174 }
175 Some(IntVal::Expr(value))
176 }
177
178 pub fn resolve(&self) -> Option<Int> {
179 match self {
180 IntVal::Const(value) => Some(*value),
181 IntVal::Expr(expr) => eval_expr_to_int(expr),
182 IntVal::Reference(re) => match re.ptr.kind().deref() {
183 DeclarationKind::ValueLetting(expr, _)
184 | DeclarationKind::TemporaryValueLetting(expr) => eval_expr_to_int(expr),
185 DeclarationKind::Given(_) => None,
187 DeclarationKind::Quantified(inner) => {
188 if let Some(generator) = inner.generator()
189 && let Some(expr) = generator.as_value_letting()
190 {
191 eval_expr_to_int(&expr)
192 } else {
193 None
194 }
195 }
196 DeclarationKind::QuantifiedExpr(_) => None,
198 DeclarationKind::Find(_) => None,
200 DeclarationKind::DomainLetting(_) | DeclarationKind::RecordField(_) => bug!(
201 "Expected integer expression, given, or letting inside int domain; Got: {re}"
202 ),
203 },
204 }
205 }
206}
207
208fn eval_expr_to_int(expr: &Expression) -> Option<Int> {
209 match eval_constant(expr)? {
210 Literal::Int(v) => Some(v),
211 _ => bug!("Expected integer expression, got: {expr}"),
212 }
213}
214
215impl From<IntVal> for Expression {
216 fn from(value: IntVal) -> Self {
217 match value {
218 IntVal::Const(val) => val.into(),
219 IntVal::Reference(re) => re.into(),
220 IntVal::Expr(expr) => expr.as_ref().clone(),
221 }
222 }
223}
224
225impl From<IntVal> for Moo<Expression> {
226 fn from(value: IntVal) -> Self {
227 match value {
228 IntVal::Const(val) => Moo::new(val.into()),
229 IntVal::Reference(re) => Moo::new(re.into()),
230 IntVal::Expr(expr) => expr,
231 }
232 }
233}
234
235impl std::ops::Neg for IntVal {
236 type Output = IntVal;
237
238 fn neg(self) -> Self::Output {
239 match self {
240 IntVal::Const(val) => IntVal::Const(-val),
241 IntVal::Reference(_) | IntVal::Expr(_) => {
242 IntVal::Expr(Moo::new(Expression::Neg(Metadata::new(), self.into())))
243 }
244 }
245 }
246}
247
248impl<T> std::ops::Add<T> for IntVal
249where
250 T: Into<Expression>,
251{
252 type Output = IntVal;
253
254 fn add(self, rhs: T) -> Self::Output {
255 let lhs: Expression = self.into();
256 let rhs: Expression = rhs.into();
257 let sum = matrix_expr!(lhs, rhs; domain_int!(1..));
258 IntVal::Expr(Moo::new(Expression::Sum(Metadata::new(), Moo::new(sum))))
259 }
260}
261
262impl Range<IntVal> {
263 pub fn resolve(&self) -> Option<Range<Int>> {
264 match self {
265 Range::Single(x) => Some(Range::Single(x.resolve()?)),
266 Range::Bounded(l, r) => Some(Range::Bounded(l.resolve()?, r.resolve()?)),
267 Range::UnboundedL(r) => Some(Range::UnboundedL(r.resolve()?)),
268 Range::UnboundedR(l) => Some(Range::UnboundedR(l.resolve()?)),
269 Range::Unbounded => Some(Range::Unbounded),
270 }
271 }
272}
273
274impl SetAttr<IntVal> {
275 pub fn resolve(&self) -> Option<SetAttr<Int>> {
276 Some(SetAttr {
277 size: self.size.resolve()?,
278 })
279 }
280}
281
282impl MSetAttr<IntVal> {
283 pub fn resolve(&self) -> Option<MSetAttr<Int>> {
284 Some(MSetAttr {
285 size: self.size.resolve()?,
286 occurrence: self.occurrence.resolve()?,
287 })
288 }
289}
290
291impl FuncAttr<IntVal> {
292 pub fn resolve(&self) -> Option<FuncAttr<Int>> {
293 Some(FuncAttr {
294 size: self.size.resolve()?,
295 partiality: self.partiality.clone(),
296 jectivity: self.jectivity.clone(),
297 })
298 }
299}
300
/// A named field of a record domain whose domain may not be resolved yet.
///
/// [`RecordEntry::resolve`] turns it into a [`RecordEntryGround`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Uniplate, Quine)]
#[path_prefix(conjure_cp::ast)]
pub struct RecordEntry {
    /// Name of the field.
    pub name: Name,
    /// Domain of the field.
    pub domain: DomainPtr,
}
307
308impl RecordEntry {
309 pub fn resolve(self) -> Option<RecordEntryGround> {
310 Some(RecordEntryGround {
311 name: self.name,
312 domain: self.domain.resolve()?,
313 })
314 }
315}
316
/// A domain that may still contain unresolved parts: references to domain lettings, or integer
/// bounds and attributes that are not yet constants.
///
/// [`UnresolvedDomain::resolve`] attempts to turn it into a [`GroundDomain`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Quine, Uniplate)]
#[path_prefix(conjure_cp::ast)]
#[biplate(to=Expression)]
#[biplate(to=Reference)]
#[biplate(to=IntVal)]
#[biplate(to=DomainPtr)]
#[biplate(to=RecordEntry)]
pub enum UnresolvedDomain {
    /// Integer domain given as a list of ranges.
    Int(Vec<Range<IntVal>>),
    /// Set domain: attributes, then element domain.
    Set(SetAttr<IntVal>, DomainPtr),
    /// Multiset domain: attributes, then inner domain.
    MSet(MSetAttr<IntVal>, DomainPtr),
    /// Matrix domain: value domain, then the index domains.
    Matrix(DomainPtr, Vec<DomainPtr>),
    /// Tuple domain: one domain per component.
    Tuple(Vec<DomainPtr>),
    /// Reference to a domain letting (resolving any other kind of reference is a bug).
    #[polyquine_skip]
    Reference(Reference),
    /// Record domain: named entries.
    Record(Vec<RecordEntry>),
    /// Function domain: attributes, domain, then codomain.
    Function(FuncAttr<IntVal>, DomainPtr, DomainPtr),
}
341
342impl UnresolvedDomain {
343 pub fn resolve(&self) -> Option<GroundDomain> {
344 match self {
345 UnresolvedDomain::Int(rngs) => rngs
346 .iter()
347 .map(Range::<IntVal>::resolve)
348 .collect::<Option<_>>()
349 .map(GroundDomain::Int),
350 UnresolvedDomain::Set(attr, inner) => {
351 Some(GroundDomain::Set(attr.resolve()?, inner.resolve()?))
352 }
353 UnresolvedDomain::MSet(attr, inner) => {
354 Some(GroundDomain::MSet(attr.resolve()?, inner.resolve()?))
355 }
356 UnresolvedDomain::Matrix(inner, idx_doms) => {
357 let inner_gd = inner.resolve()?;
358 idx_doms
359 .iter()
360 .map(DomainPtr::resolve)
361 .collect::<Option<_>>()
362 .map(|idx| GroundDomain::Matrix(inner_gd, idx))
363 }
364 UnresolvedDomain::Tuple(inners) => inners
365 .iter()
366 .map(DomainPtr::resolve)
367 .collect::<Option<_>>()
368 .map(GroundDomain::Tuple),
369 UnresolvedDomain::Record(entries) => entries
370 .iter()
371 .map(|f| {
372 f.domain.resolve().map(|gd| RecordEntryGround {
373 name: f.name.clone(),
374 domain: gd,
375 })
376 })
377 .collect::<Option<_>>()
378 .map(GroundDomain::Record),
379 UnresolvedDomain::Reference(re) => re
380 .ptr
381 .as_domain_letting()
382 .unwrap_or_else(|| {
383 bug!("Reference domain should point to domain letting, but got {re}")
384 })
385 .resolve()
386 .map(Moo::unwrap_or_clone),
387 UnresolvedDomain::Function(attr, dom, cdom) => {
388 if let Some(attr_gd) = attr.resolve()
389 && let Some(dom_gd) = dom.resolve()
390 && let Some(cdom_gd) = cdom.resolve()
391 {
392 return Some(GroundDomain::Function(attr_gd, dom_gd, cdom_gd));
393 }
394 None
395 }
396 }
397 }
398
399 pub(super) fn union_unresolved(
400 &self,
401 other: &UnresolvedDomain,
402 ) -> Result<UnresolvedDomain, DomainOpError> {
403 match (self, other) {
404 (UnresolvedDomain::Int(lhs), UnresolvedDomain::Int(rhs)) => {
405 let merged = lhs.iter().chain(rhs.iter()).cloned().collect_vec();
406 Ok(UnresolvedDomain::Int(merged))
407 }
408 (UnresolvedDomain::Int(_), _) | (_, UnresolvedDomain::Int(_)) => {
409 Err(DomainOpError::WrongType)
410 }
411 (UnresolvedDomain::Set(_, in1), UnresolvedDomain::Set(_, in2)) => {
412 Ok(UnresolvedDomain::Set(SetAttr::default(), in1.union(in2)?))
413 }
414 (UnresolvedDomain::Set(_, _), _) | (_, UnresolvedDomain::Set(_, _)) => {
415 Err(DomainOpError::WrongType)
416 }
417 (UnresolvedDomain::MSet(_, in1), UnresolvedDomain::MSet(_, in2)) => {
418 Ok(UnresolvedDomain::MSet(MSetAttr::default(), in1.union(in2)?))
419 }
420 (UnresolvedDomain::MSet(_, _), _) | (_, UnresolvedDomain::MSet(_, _)) => {
421 Err(DomainOpError::WrongType)
422 }
423 (UnresolvedDomain::Matrix(in1, idx1), UnresolvedDomain::Matrix(in2, idx2))
424 if idx1 == idx2 =>
425 {
426 Ok(UnresolvedDomain::Matrix(in1.union(in2)?, idx1.clone()))
427 }
428 (UnresolvedDomain::Matrix(_, _), _) | (_, UnresolvedDomain::Matrix(_, _)) => {
429 Err(DomainOpError::WrongType)
430 }
431 (UnresolvedDomain::Tuple(lhs), UnresolvedDomain::Tuple(rhs))
432 if lhs.len() == rhs.len() =>
433 {
434 let mut merged = Vec::new();
435 for (l, r) in zip(lhs, rhs) {
436 merged.push(l.union(r)?)
437 }
438 Ok(UnresolvedDomain::Tuple(merged))
439 }
440 (UnresolvedDomain::Tuple(_), _) | (_, UnresolvedDomain::Tuple(_)) => {
441 Err(DomainOpError::WrongType)
442 }
443 (UnresolvedDomain::Reference(_), _) | (_, UnresolvedDomain::Reference(_)) => {
445 Err(DomainOpError::NotGround)
446 }
447 #[allow(unreachable_patterns)] (UnresolvedDomain::Record(_), _) | (_, UnresolvedDomain::Record(_)) => {
450 Err(DomainOpError::WrongType)
451 }
452 #[allow(unreachable_patterns)]
453 (UnresolvedDomain::Function(_, _, _), _) | (_, UnresolvedDomain::Function(_, _, _)) => {
455 Err(DomainOpError::WrongType)
456 }
457 }
458 }
459
460 pub fn element_domain(&self) -> Option<DomainPtr> {
461 match self {
462 UnresolvedDomain::Set(_, inner_dom) => Some(inner_dom.clone()),
463 UnresolvedDomain::Matrix(_, _) => {
464 todo!("Unwrap one dimension of the domain")
465 }
466 _ => None,
467 }
468 }
469}
470
471impl Typeable for UnresolvedDomain {
472 fn return_type(&self) -> ReturnType {
473 match self {
474 UnresolvedDomain::Reference(re) => re.return_type(),
475 UnresolvedDomain::Int(_) => ReturnType::Int,
476 UnresolvedDomain::Set(_attr, inner) => ReturnType::Set(Box::new(inner.return_type())),
477 UnresolvedDomain::MSet(_attr, inner) => ReturnType::MSet(Box::new(inner.return_type())),
478 UnresolvedDomain::Matrix(inner, _idx) => {
479 ReturnType::Matrix(Box::new(inner.return_type()))
480 }
481 UnresolvedDomain::Tuple(inners) => {
482 let mut inner_types = Vec::new();
483 for inner in inners {
484 inner_types.push(inner.return_type());
485 }
486 ReturnType::Tuple(inner_types)
487 }
488 UnresolvedDomain::Record(entries) => {
489 let mut entry_types = Vec::new();
490 for entry in entries {
491 entry_types.push(entry.domain.return_type());
492 }
493 ReturnType::Record(entry_types)
494 }
495 UnresolvedDomain::Function(_, dom, cdom) => {
496 ReturnType::Function(Box::new(dom.return_type()), Box::new(cdom.return_type()))
497 }
498 }
499 }
500}
501
502impl Display for UnresolvedDomain {
503 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
504 match &self {
505 UnresolvedDomain::Reference(re) => write!(f, "{re}"),
506 UnresolvedDomain::Int(ranges) => {
507 if ranges.iter().all(Range::is_lower_or_upper_bounded) {
508 let rngs: String = ranges.iter().map(|r| format!("{r}")).join(", ");
509 write!(f, "int({})", rngs)
510 } else {
511 write!(f, "int")
512 }
513 }
514 UnresolvedDomain::Set(attrs, inner_dom) => write!(f, "set {attrs} of {inner_dom}"),
515 UnresolvedDomain::MSet(attrs, inner_dom) => write!(f, "mset {attrs} of {inner_dom}"),
516 UnresolvedDomain::Matrix(value_domain, index_domains) => {
517 write!(
518 f,
519 "matrix indexed by {} of {value_domain}",
520 pretty_vec(&index_domains.iter().collect_vec())
521 )
522 }
523 UnresolvedDomain::Tuple(domains) => {
524 write!(
525 f,
526 "tuple of ({})",
527 pretty_vec(&domains.iter().collect_vec())
528 )
529 }
530 UnresolvedDomain::Record(entries) => {
531 write!(
532 f,
533 "record of ({})",
534 pretty_vec(
535 &entries
536 .iter()
537 .map(|entry| format!("{}: {}", entry.name, entry.domain))
538 .collect_vec()
539 )
540 )
541 }
542 UnresolvedDomain::Function(attribute, domain, codomain) => {
543 write!(f, "function {} {} --> {} ", attribute, domain, codomain)
544 }
545 }
546 }
547}