1use crate::ast::domains::attrs::MSetAttr;
2use crate::ast::domains::attrs::SetAttr;
3use crate::ast::{
4 DeclarationKind, DomainOpError, Expression, FuncAttr, Literal, Metadata, Moo,
5 RecordEntryGround, Reference, Typeable,
6 domains::{
7 GroundDomain,
8 domain::{DomainPtr, Int},
9 range::Range,
10 },
11};
12use crate::{bug, domain_int, matrix_expr, range};
13use conjure_cp_core::ast::pretty::pretty_vec;
14use conjure_cp_core::ast::{Name, ReturnType, eval_constant};
15use itertools::Itertools;
16use polyquine::Quine;
17use serde::{Deserialize, Serialize};
18use std::fmt::{Display, Formatter};
19use std::iter::zip;
20use std::ops::Deref;
21use uniplate::Uniplate;
22
/// An integer value inside a domain that is not necessarily a known constant.
///
/// A value is either a literal constant, a reference to a declaration, or an
/// arbitrary expression. Use [`IntVal::resolve`] to try to obtain a concrete integer.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Quine, Uniplate)]
#[path_prefix(conjure_cp::ast)]
#[biplate(to=Expression)]
#[biplate(to=Reference)]
pub enum IntVal {
    /// A literal integer constant.
    Const(Int),
    /// A reference to a declaration. [`IntVal::new_ref`] only builds this variant for
    /// declarations whose return type is an integer.
    #[polyquine_skip]
    Reference(Reference),
    /// An expression. [`IntVal::new_expr`] only builds this variant for expressions
    /// whose return type is an integer.
    Expr(Moo<Expression>),
}
33
34impl From<Int> for IntVal {
35 fn from(value: Int) -> Self {
36 Self::Const(value)
37 }
38}
39
40impl TryInto<Int> for IntVal {
41 type Error = DomainOpError;
42
43 fn try_into(self) -> Result<Int, Self::Error> {
44 match self {
45 IntVal::Const(val) => Ok(val),
46 _ => Err(DomainOpError::NotGround),
47 }
48 }
49}
50
51impl From<Range<Int>> for Range<IntVal> {
52 fn from(value: Range<Int>) -> Self {
53 match value {
54 Range::Single(x) => Range::Single(x.into()),
55 Range::Bounded(l, r) => Range::Bounded(l.into(), r.into()),
56 Range::UnboundedL(r) => Range::UnboundedL(r.into()),
57 Range::UnboundedR(l) => Range::UnboundedR(l.into()),
58 Range::Unbounded => Range::Unbounded,
59 }
60 }
61}
62
63impl TryInto<Range<Int>> for Range<IntVal> {
64 type Error = DomainOpError;
65
66 fn try_into(self) -> Result<Range<Int>, Self::Error> {
67 match self {
68 Range::Single(x) => Ok(Range::Single(x.try_into()?)),
69 Range::Bounded(l, r) => Ok(Range::Bounded(l.try_into()?, r.try_into()?)),
70 Range::UnboundedL(r) => Ok(Range::UnboundedL(r.try_into()?)),
71 Range::UnboundedR(l) => Ok(Range::UnboundedR(l.try_into()?)),
72 Range::Unbounded => Ok(Range::Unbounded),
73 }
74 }
75}
76
77impl From<SetAttr<Int>> for SetAttr<IntVal> {
78 fn from(value: SetAttr<Int>) -> Self {
79 SetAttr {
80 size: value.size.into(),
81 }
82 }
83}
84
85impl TryInto<SetAttr<Int>> for SetAttr<IntVal> {
86 type Error = DomainOpError;
87
88 fn try_into(self) -> Result<SetAttr<Int>, Self::Error> {
89 let size: Range<Int> = self.size.try_into()?;
90 Ok(SetAttr { size })
91 }
92}
93
94impl From<MSetAttr<Int>> for MSetAttr<IntVal> {
95 fn from(value: MSetAttr<Int>) -> Self {
96 MSetAttr {
97 size: value.size.into(),
98 occurrence: value.occurrence.into(),
99 }
100 }
101}
102
103impl TryInto<MSetAttr<Int>> for MSetAttr<IntVal> {
104 type Error = DomainOpError;
105
106 fn try_into(self) -> Result<MSetAttr<Int>, Self::Error> {
107 let size: Range<Int> = self.size.try_into()?;
108 let occurrence: Range<Int> = self.occurrence.try_into()?;
109 Ok(MSetAttr { size, occurrence })
110 }
111}
112
113impl From<FuncAttr<Int>> for FuncAttr<IntVal> {
114 fn from(value: FuncAttr<Int>) -> Self {
115 FuncAttr {
116 size: value.size.into(),
117 partiality: value.partiality,
118 jectivity: value.jectivity,
119 }
120 }
121}
122
123impl TryInto<FuncAttr<Int>> for FuncAttr<IntVal> {
124 type Error = DomainOpError;
125
126 fn try_into(self) -> Result<FuncAttr<Int>, Self::Error> {
127 let size: Range<Int> = self.size.try_into()?;
128 Ok(FuncAttr {
129 size,
130 jectivity: self.jectivity,
131 partiality: self.partiality,
132 })
133 }
134}
135
136impl Display for IntVal {
137 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
138 match self {
139 IntVal::Const(val) => write!(f, "{val}"),
140 IntVal::Reference(re) => write!(f, "{re}"),
141 IntVal::Expr(expr) => write!(f, "({expr})"),
142 }
143 }
144}
145
146impl IntVal {
147 pub fn new_ref(re: &Reference) -> Option<IntVal> {
148 match re.ptr.kind().deref() {
149 DeclarationKind::ValueLetting(expr, _)
150 | DeclarationKind::TemporaryValueLetting(expr) => match expr.return_type() {
151 ReturnType::Int => Some(IntVal::Reference(re.clone())),
152 _ => None,
153 },
154 DeclarationKind::Given(dom) => match dom.return_type() {
155 ReturnType::Int => Some(IntVal::Reference(re.clone())),
156 _ => None,
157 },
158 DeclarationKind::Quantified(inner) => match inner.domain().return_type() {
159 ReturnType::Int => Some(IntVal::Reference(re.clone())),
160 _ => None,
161 },
162 DeclarationKind::Find(var) => match var.return_type() {
163 ReturnType::Int => Some(IntVal::Reference(re.clone())),
164 _ => None,
165 },
166 DeclarationKind::DomainLetting(_) | DeclarationKind::RecordField(_) => None,
167 }
168 }
169
170 pub fn new_expr(value: Moo<Expression>) -> Option<IntVal> {
171 if value.return_type() != ReturnType::Int {
172 return None;
173 }
174 Some(IntVal::Expr(value))
175 }
176
177 pub fn resolve(&self) -> Option<Int> {
178 match self {
179 IntVal::Const(value) => Some(*value),
180 IntVal::Expr(expr) => eval_expr_to_int(expr),
181 IntVal::Reference(re) => match re.ptr.kind().deref() {
182 DeclarationKind::ValueLetting(expr, _)
183 | DeclarationKind::TemporaryValueLetting(expr) => eval_expr_to_int(expr),
184 DeclarationKind::Given(_) => None,
186 DeclarationKind::Quantified(inner) => {
187 if let Some(generator) = inner.generator()
188 && let Some(expr) = generator.as_value_letting()
189 {
190 eval_expr_to_int(&expr)
191 } else {
192 None
193 }
194 }
195 DeclarationKind::Find(_) => None,
197 DeclarationKind::DomainLetting(_) | DeclarationKind::RecordField(_) => bug!(
198 "Expected integer expression, given, or letting inside int domain; Got: {re}"
199 ),
200 },
201 }
202 }
203}
204
205fn eval_expr_to_int(expr: &Expression) -> Option<Int> {
206 match eval_constant(expr)? {
207 Literal::Int(v) => Some(v),
208 _ => bug!("Expected integer expression, got: {expr}"),
209 }
210}
211
212impl From<IntVal> for Expression {
213 fn from(value: IntVal) -> Self {
214 match value {
215 IntVal::Const(val) => val.into(),
216 IntVal::Reference(re) => re.into(),
217 IntVal::Expr(expr) => expr.as_ref().clone(),
218 }
219 }
220}
221
222impl From<IntVal> for Moo<Expression> {
223 fn from(value: IntVal) -> Self {
224 match value {
225 IntVal::Const(val) => Moo::new(val.into()),
226 IntVal::Reference(re) => Moo::new(re.into()),
227 IntVal::Expr(expr) => expr,
228 }
229 }
230}
231
232impl std::ops::Neg for IntVal {
233 type Output = IntVal;
234
235 fn neg(self) -> Self::Output {
236 match self {
237 IntVal::Const(val) => IntVal::Const(-val),
238 IntVal::Reference(_) | IntVal::Expr(_) => {
239 IntVal::Expr(Moo::new(Expression::Neg(Metadata::new(), self.into())))
240 }
241 }
242 }
243}
244
245impl<T> std::ops::Add<T> for IntVal
246where
247 T: Into<Expression>,
248{
249 type Output = IntVal;
250
251 fn add(self, rhs: T) -> Self::Output {
252 let lhs: Expression = self.into();
253 let rhs: Expression = rhs.into();
254 let sum = matrix_expr!(lhs, rhs; domain_int!(1..));
255 IntVal::Expr(Moo::new(Expression::Sum(Metadata::new(), Moo::new(sum))))
256 }
257}
258
259impl Range<IntVal> {
260 pub fn resolve(&self) -> Option<Range<Int>> {
261 match self {
262 Range::Single(x) => Some(Range::Single(x.resolve()?)),
263 Range::Bounded(l, r) => Some(Range::Bounded(l.resolve()?, r.resolve()?)),
264 Range::UnboundedL(r) => Some(Range::UnboundedL(r.resolve()?)),
265 Range::UnboundedR(l) => Some(Range::UnboundedR(l.resolve()?)),
266 Range::Unbounded => Some(Range::Unbounded),
267 }
268 }
269}
270
271impl SetAttr<IntVal> {
272 pub fn resolve(&self) -> Option<SetAttr<Int>> {
273 Some(SetAttr {
274 size: self.size.resolve()?,
275 })
276 }
277}
278
279impl MSetAttr<IntVal> {
280 pub fn resolve(&self) -> Option<MSetAttr<Int>> {
281 Some(MSetAttr {
282 size: self.size.resolve()?,
283 occurrence: self.occurrence.resolve()?,
284 })
285 }
286}
287
288impl FuncAttr<IntVal> {
289 pub fn resolve(&self) -> Option<FuncAttr<Int>> {
290 Some(FuncAttr {
291 size: self.size.resolve()?,
292 partiality: self.partiality.clone(),
293 jectivity: self.jectivity.clone(),
294 })
295 }
296}
297
/// A named entry of a record domain, pairing a field name with its domain.
///
/// The domain is stored as a [`DomainPtr`]; see [`RecordEntry::resolve`] for turning
/// the entry into a `RecordEntryGround`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Uniplate, Quine)]
#[path_prefix(conjure_cp::ast)]
pub struct RecordEntry {
    /// The name of the record field.
    pub name: Name,
    /// The domain of the record field.
    pub domain: DomainPtr,
}
304
305impl RecordEntry {
306 pub fn resolve(self) -> Option<RecordEntryGround> {
307 Some(RecordEntryGround {
308 name: self.name,
309 domain: self.domain.resolve()?,
310 })
311 }
312}
313
/// A domain that may still contain references or non-constant integer values.
///
/// [`UnresolvedDomain::resolve`] attempts to turn it into a `GroundDomain`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Quine, Uniplate)]
#[path_prefix(conjure_cp::ast)]
#[biplate(to=Expression)]
#[biplate(to=Reference)]
#[biplate(to=IntVal)]
#[biplate(to=DomainPtr)]
#[biplate(to=RecordEntry)]
pub enum UnresolvedDomain {
    /// An integer domain given as a list of ranges over [`IntVal`].
    Int(Vec<Range<IntVal>>),
    /// A set with its attributes and the domain of its elements.
    Set(SetAttr<IntVal>, DomainPtr),
    /// A multiset with its attributes and the domain of its elements.
    MSet(MSetAttr<IntVal>, DomainPtr),
    /// A matrix: the value domain, followed by the index domains.
    Matrix(DomainPtr, Vec<DomainPtr>),
    /// A tuple with one domain per component.
    Tuple(Vec<DomainPtr>),
    /// A reference to another declaration; `resolve` expects it to point to a domain letting.
    #[polyquine_skip]
    Reference(Reference),
    /// A record made of named entries.
    Record(Vec<RecordEntry>),
    /// A function with its attributes, its domain, and its codomain.
    Function(FuncAttr<IntVal>, DomainPtr, DomainPtr),
}
338
339impl UnresolvedDomain {
340 pub fn resolve(&self) -> Option<GroundDomain> {
341 match self {
342 UnresolvedDomain::Int(rngs) => rngs
343 .iter()
344 .map(Range::<IntVal>::resolve)
345 .collect::<Option<_>>()
346 .map(GroundDomain::Int),
347 UnresolvedDomain::Set(attr, inner) => {
348 Some(GroundDomain::Set(attr.resolve()?, inner.resolve()?))
349 }
350 UnresolvedDomain::MSet(attr, inner) => {
351 Some(GroundDomain::MSet(attr.resolve()?, inner.resolve()?))
352 }
353 UnresolvedDomain::Matrix(inner, idx_doms) => {
354 let inner_gd = inner.resolve()?;
355 idx_doms
356 .iter()
357 .map(DomainPtr::resolve)
358 .collect::<Option<_>>()
359 .map(|idx| GroundDomain::Matrix(inner_gd, idx))
360 }
361 UnresolvedDomain::Tuple(inners) => inners
362 .iter()
363 .map(DomainPtr::resolve)
364 .collect::<Option<_>>()
365 .map(GroundDomain::Tuple),
366 UnresolvedDomain::Record(entries) => entries
367 .iter()
368 .map(|f| {
369 f.domain.resolve().map(|gd| RecordEntryGround {
370 name: f.name.clone(),
371 domain: gd,
372 })
373 })
374 .collect::<Option<_>>()
375 .map(GroundDomain::Record),
376 UnresolvedDomain::Reference(re) => re
377 .ptr
378 .as_domain_letting()
379 .unwrap_or_else(|| {
380 bug!("Reference domain should point to domain letting, but got {re}")
381 })
382 .resolve()
383 .map(Moo::unwrap_or_clone),
384 UnresolvedDomain::Function(attr, dom, cdom) => {
385 if let Some(attr_gd) = attr.resolve()
386 && let Some(dom_gd) = dom.resolve()
387 && let Some(cdom_gd) = cdom.resolve()
388 {
389 return Some(GroundDomain::Function(attr_gd, dom_gd, cdom_gd));
390 }
391 None
392 }
393 }
394 }
395
396 pub(super) fn union_unresolved(
397 &self,
398 other: &UnresolvedDomain,
399 ) -> Result<UnresolvedDomain, DomainOpError> {
400 match (self, other) {
401 (UnresolvedDomain::Int(lhs), UnresolvedDomain::Int(rhs)) => {
402 let merged = lhs.iter().chain(rhs.iter()).cloned().collect_vec();
403 Ok(UnresolvedDomain::Int(merged))
404 }
405 (UnresolvedDomain::Int(_), _) | (_, UnresolvedDomain::Int(_)) => {
406 Err(DomainOpError::WrongType)
407 }
408 (UnresolvedDomain::Set(_, in1), UnresolvedDomain::Set(_, in2)) => {
409 Ok(UnresolvedDomain::Set(SetAttr::default(), in1.union(in2)?))
410 }
411 (UnresolvedDomain::Set(_, _), _) | (_, UnresolvedDomain::Set(_, _)) => {
412 Err(DomainOpError::WrongType)
413 }
414 (UnresolvedDomain::MSet(_, in1), UnresolvedDomain::MSet(_, in2)) => {
415 Ok(UnresolvedDomain::MSet(MSetAttr::default(), in1.union(in2)?))
416 }
417 (UnresolvedDomain::MSet(_, _), _) | (_, UnresolvedDomain::MSet(_, _)) => {
418 Err(DomainOpError::WrongType)
419 }
420 (UnresolvedDomain::Matrix(in1, idx1), UnresolvedDomain::Matrix(in2, idx2))
421 if idx1 == idx2 =>
422 {
423 Ok(UnresolvedDomain::Matrix(in1.union(in2)?, idx1.clone()))
424 }
425 (UnresolvedDomain::Matrix(_, _), _) | (_, UnresolvedDomain::Matrix(_, _)) => {
426 Err(DomainOpError::WrongType)
427 }
428 (UnresolvedDomain::Tuple(lhs), UnresolvedDomain::Tuple(rhs))
429 if lhs.len() == rhs.len() =>
430 {
431 let mut merged = Vec::new();
432 for (l, r) in zip(lhs, rhs) {
433 merged.push(l.union(r)?)
434 }
435 Ok(UnresolvedDomain::Tuple(merged))
436 }
437 (UnresolvedDomain::Tuple(_), _) | (_, UnresolvedDomain::Tuple(_)) => {
438 Err(DomainOpError::WrongType)
439 }
440 (UnresolvedDomain::Reference(_), _) | (_, UnresolvedDomain::Reference(_)) => {
442 Err(DomainOpError::NotGround)
443 }
444 #[allow(unreachable_patterns)] (UnresolvedDomain::Record(_), _) | (_, UnresolvedDomain::Record(_)) => {
447 Err(DomainOpError::WrongType)
448 }
449 #[allow(unreachable_patterns)]
450 (UnresolvedDomain::Function(_, _, _), _) | (_, UnresolvedDomain::Function(_, _, _)) => {
452 Err(DomainOpError::WrongType)
453 }
454 }
455 }
456
457 pub fn element_domain(&self) -> Option<DomainPtr> {
458 match self {
459 UnresolvedDomain::Set(_, inner_dom) => Some(inner_dom.clone()),
460 UnresolvedDomain::Matrix(_, _) => {
461 todo!("Unwrap one dimension of the domain")
462 }
463 _ => None,
464 }
465 }
466}
467
468impl Typeable for UnresolvedDomain {
469 fn return_type(&self) -> ReturnType {
470 match self {
471 UnresolvedDomain::Reference(re) => re.return_type(),
472 UnresolvedDomain::Int(_) => ReturnType::Int,
473 UnresolvedDomain::Set(_attr, inner) => ReturnType::Set(Box::new(inner.return_type())),
474 UnresolvedDomain::MSet(_attr, inner) => ReturnType::MSet(Box::new(inner.return_type())),
475 UnresolvedDomain::Matrix(inner, _idx) => {
476 ReturnType::Matrix(Box::new(inner.return_type()))
477 }
478 UnresolvedDomain::Tuple(inners) => {
479 let mut inner_types = Vec::new();
480 for inner in inners {
481 inner_types.push(inner.return_type());
482 }
483 ReturnType::Tuple(inner_types)
484 }
485 UnresolvedDomain::Record(entries) => {
486 let mut entry_types = Vec::new();
487 for entry in entries {
488 entry_types.push(entry.domain.return_type());
489 }
490 ReturnType::Record(entry_types)
491 }
492 UnresolvedDomain::Function(_, dom, cdom) => {
493 ReturnType::Function(Box::new(dom.return_type()), Box::new(cdom.return_type()))
494 }
495 }
496 }
497}
498
499impl Display for UnresolvedDomain {
500 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
501 match &self {
502 UnresolvedDomain::Reference(re) => write!(f, "{re}"),
503 UnresolvedDomain::Int(ranges) => {
504 if ranges.iter().all(Range::is_lower_or_upper_bounded) {
505 let rngs: String = ranges.iter().map(|r| format!("{r}")).join(", ");
506 write!(f, "int({})", rngs)
507 } else {
508 write!(f, "int")
509 }
510 }
511 UnresolvedDomain::Set(attrs, inner_dom) => write!(f, "set {attrs} of {inner_dom}"),
512 UnresolvedDomain::MSet(attrs, inner_dom) => write!(f, "mset {attrs} of {inner_dom}"),
513 UnresolvedDomain::Matrix(value_domain, index_domains) => {
514 write!(
515 f,
516 "matrix indexed by {} of {value_domain}",
517 pretty_vec(&index_domains.iter().collect_vec())
518 )
519 }
520 UnresolvedDomain::Tuple(domains) => {
521 write!(
522 f,
523 "tuple of ({})",
524 pretty_vec(&domains.iter().collect_vec())
525 )
526 }
527 UnresolvedDomain::Record(entries) => {
528 write!(
529 f,
530 "record of ({})",
531 pretty_vec(
532 &entries
533 .iter()
534 .map(|entry| format!("{}: {}", entry.name, entry.domain))
535 .collect_vec()
536 )
537 )
538 }
539 UnresolvedDomain::Function(attribute, domain, codomain) => {
540 write!(f, "function {} {} --> {} ", attribute, domain, codomain)
541 }
542 }
543 }
544}