1use crate::ast::domains::attrs::SetAttr;
2use crate::ast::{
3 DeclarationKind, DomainOpError, Expression, FuncAttr, Literal, Metadata, Moo,
4 RecordEntryGround, Reference, Typeable,
5 domains::{
6 GroundDomain,
7 domain::{DomainPtr, Int},
8 range::Range,
9 },
10};
11use crate::{bug, domain_int, matrix_expr, range};
12use conjure_cp_core::ast::pretty::pretty_vec;
13use conjure_cp_core::ast::{Name, ReturnType, eval_constant};
14use itertools::Itertools;
15use polyquine::Quine;
16use serde::{Deserialize, Serialize};
17use std::fmt::{Display, Formatter};
18use std::iter::zip;
19use std::ops::Deref;
20use uniplate::Uniplate;
21
/// An integer value inside a domain that is not necessarily a known constant.
///
/// Use [`IntVal::resolve`] to obtain the concrete integer when it can be
/// determined.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Quine, Uniplate)]
#[path_prefix(conjure_cp::ast)]
#[biplate(to=Expression)]
#[biplate(to=Reference)]
pub enum IntVal {
    /// A literal integer constant.
    Const(Int),
    /// A reference to a declaration. [`IntVal::new_ref`] only builds this
    /// variant for integer-typed value lettings and givens.
    #[polyquine_skip]
    Reference(Reference),
    /// An arbitrary expression. [`IntVal::new_expr`] only builds this variant
    /// for expressions whose return type is `Int`.
    Expr(Moo<Expression>),
}
32
33impl From<Int> for IntVal {
34 fn from(value: Int) -> Self {
35 Self::Const(value)
36 }
37}
38
39impl TryInto<Int> for IntVal {
40 type Error = DomainOpError;
41
42 fn try_into(self) -> Result<Int, Self::Error> {
43 match self {
44 IntVal::Const(val) => Ok(val),
45 _ => Err(DomainOpError::NotGround),
46 }
47 }
48}
49
50impl From<Range<Int>> for Range<IntVal> {
51 fn from(value: Range<Int>) -> Self {
52 match value {
53 Range::Single(x) => Range::Single(x.into()),
54 Range::Bounded(l, r) => Range::Bounded(l.into(), r.into()),
55 Range::UnboundedL(r) => Range::UnboundedL(r.into()),
56 Range::UnboundedR(l) => Range::UnboundedR(l.into()),
57 Range::Unbounded => Range::Unbounded,
58 }
59 }
60}
61
62impl TryInto<Range<Int>> for Range<IntVal> {
63 type Error = DomainOpError;
64
65 fn try_into(self) -> Result<Range<Int>, Self::Error> {
66 match self {
67 Range::Single(x) => Ok(Range::Single(x.try_into()?)),
68 Range::Bounded(l, r) => Ok(Range::Bounded(l.try_into()?, r.try_into()?)),
69 Range::UnboundedL(r) => Ok(Range::UnboundedL(r.try_into()?)),
70 Range::UnboundedR(l) => Ok(Range::UnboundedR(l.try_into()?)),
71 Range::Unbounded => Ok(Range::Unbounded),
72 }
73 }
74}
75
76impl From<SetAttr<Int>> for SetAttr<IntVal> {
77 fn from(value: SetAttr<Int>) -> Self {
78 SetAttr {
79 size: value.size.into(),
80 }
81 }
82}
83
84impl TryInto<SetAttr<Int>> for SetAttr<IntVal> {
85 type Error = DomainOpError;
86
87 fn try_into(self) -> Result<SetAttr<Int>, Self::Error> {
88 let size: Range<Int> = self.size.try_into()?;
89 Ok(SetAttr { size })
90 }
91}
92
93impl From<FuncAttr<Int>> for FuncAttr<IntVal> {
94 fn from(value: FuncAttr<Int>) -> Self {
95 FuncAttr {
96 size: value.size.into(),
97 partiality: value.partiality,
98 jectivity: value.jectivity,
99 }
100 }
101}
102
103impl TryInto<FuncAttr<Int>> for FuncAttr<IntVal> {
104 type Error = DomainOpError;
105
106 fn try_into(self) -> Result<FuncAttr<Int>, Self::Error> {
107 let size: Range<Int> = self.size.try_into()?;
108 Ok(FuncAttr {
109 size,
110 jectivity: self.jectivity,
111 partiality: self.partiality,
112 })
113 }
114}
115
116impl Display for IntVal {
117 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
118 match self {
119 IntVal::Const(val) => write!(f, "{val}"),
120 IntVal::Reference(re) => write!(f, "{re}"),
121 IntVal::Expr(expr) => write!(f, "({expr})"),
122 }
123 }
124}
125
126impl IntVal {
127 pub fn new_ref(re: &Reference) -> Option<IntVal> {
128 match re.ptr.kind().deref() {
129 DeclarationKind::ValueLetting(expr) => match expr.return_type() {
130 ReturnType::Int => Some(IntVal::Reference(re.clone())),
131 _ => None,
132 },
133 DeclarationKind::Given(dom) => match dom.return_type() {
134 ReturnType::Int => Some(IntVal::Reference(re.clone())),
135 _ => None,
136 },
137 DeclarationKind::GivenQuantified(inner) => match inner.domain().return_type() {
138 ReturnType::Int => Some(IntVal::Reference(re.clone())),
139 _ => None,
140 },
141 DeclarationKind::DomainLetting(_)
142 | DeclarationKind::RecordField(_)
143 | DeclarationKind::DecisionVariable(_) => None,
144 }
145 }
146
147 pub fn new_expr(value: Moo<Expression>) -> Option<IntVal> {
148 if value.return_type() != ReturnType::Int {
149 return None;
150 }
151 Some(IntVal::Expr(value))
152 }
153
154 pub fn resolve(&self) -> Option<Int> {
155 match self {
156 IntVal::Const(value) => Some(*value),
157 IntVal::Expr(expr) => match eval_constant(expr)? {
158 Literal::Int(v) => Some(v),
159 _ => bug!("Expected integer expression, got: {expr}"),
160 },
161 IntVal::Reference(re) => match re.ptr.kind().deref() {
162 DeclarationKind::ValueLetting(expr) => match eval_constant(expr)? {
163 Literal::Int(v) => Some(v),
164 _ => bug!("Expected integer expression, got: {expr}"),
165 },
166 DeclarationKind::Given(_) | DeclarationKind::GivenQuantified(..) => None,
168 DeclarationKind::DomainLetting(_)
169 | DeclarationKind::RecordField(_)
170 | DeclarationKind::DecisionVariable(_) => bug!(
171 "Expected integer expression, given, or letting inside int domain; Got: {re}"
172 ),
173 },
174 }
175 }
176}
177
178impl From<IntVal> for Expression {
179 fn from(value: IntVal) -> Self {
180 match value {
181 IntVal::Const(val) => val.into(),
182 IntVal::Reference(re) => re.into(),
183 IntVal::Expr(expr) => expr.as_ref().clone(),
184 }
185 }
186}
187
188impl From<IntVal> for Moo<Expression> {
189 fn from(value: IntVal) -> Self {
190 match value {
191 IntVal::Const(val) => Moo::new(val.into()),
192 IntVal::Reference(re) => Moo::new(re.into()),
193 IntVal::Expr(expr) => expr,
194 }
195 }
196}
197
198impl std::ops::Neg for IntVal {
199 type Output = IntVal;
200
201 fn neg(self) -> Self::Output {
202 match self {
203 IntVal::Const(val) => IntVal::Const(-val),
204 IntVal::Reference(_) | IntVal::Expr(_) => {
205 IntVal::Expr(Moo::new(Expression::Neg(Metadata::new(), self.into())))
206 }
207 }
208 }
209}
210
211impl<T> std::ops::Add<T> for IntVal
212where
213 T: Into<Expression>,
214{
215 type Output = IntVal;
216
217 fn add(self, rhs: T) -> Self::Output {
218 let lhs: Expression = self.into();
219 let rhs: Expression = rhs.into();
220 let sum = matrix_expr!(lhs, rhs; domain_int!(1..));
221 IntVal::Expr(Moo::new(Expression::Sum(Metadata::new(), Moo::new(sum))))
222 }
223}
224
225impl Range<IntVal> {
226 pub fn resolve(&self) -> Option<Range<Int>> {
227 match self {
228 Range::Single(x) => Some(Range::Single(x.resolve()?)),
229 Range::Bounded(l, r) => Some(Range::Bounded(l.resolve()?, r.resolve()?)),
230 Range::UnboundedL(r) => Some(Range::UnboundedL(r.resolve()?)),
231 Range::UnboundedR(l) => Some(Range::UnboundedR(l.resolve()?)),
232 Range::Unbounded => Some(Range::Unbounded),
233 }
234 }
235}
236
237impl SetAttr<IntVal> {
238 pub fn resolve(&self) -> Option<SetAttr<Int>> {
239 Some(SetAttr {
240 size: self.size.resolve()?,
241 })
242 }
243}
244
245impl FuncAttr<IntVal> {
246 pub fn resolve(&self) -> Option<FuncAttr<Int>> {
247 Some(FuncAttr {
248 size: self.size.resolve()?,
249 partiality: self.partiality.clone(),
250 jectivity: self.jectivity.clone(),
251 })
252 }
253}
254
/// One field of a record domain: the field's name together with the domain of
/// its values (which may still be unresolved).
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Uniplate, Quine)]
#[path_prefix(conjure_cp::ast)]
pub struct RecordEntry {
    /// Name of the record field.
    pub name: Name,
    /// Domain of the field's values.
    pub domain: DomainPtr,
}
261
262impl RecordEntry {
263 pub fn resolve(self) -> Option<RecordEntryGround> {
264 Some(RecordEntryGround {
265 name: self.name,
266 domain: self.domain.resolve()?,
267 })
268 }
269}
270
/// A domain that may still mention references or non-constant expressions.
///
/// [`UnresolvedDomain::resolve`] turns it into a [`GroundDomain`] when every
/// part can be resolved.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Quine, Uniplate)]
#[path_prefix(conjure_cp::ast)]
#[biplate(to=Expression)]
#[biplate(to=Reference)]
#[biplate(to=IntVal)]
#[biplate(to=DomainPtr)]
#[biplate(to=RecordEntry)]
pub enum UnresolvedDomain {
    /// An integer domain given as a list of ranges whose bounds are [`IntVal`]s.
    Int(Vec<Range<IntVal>>),
    /// A set domain: its attributes and the domain of its elements.
    Set(SetAttr<IntVal>, DomainPtr),
    /// A matrix domain: the domain of its values, followed by its index domains.
    Matrix(DomainPtr, Vec<DomainPtr>),
    /// A tuple domain: one domain per component.
    Tuple(Vec<DomainPtr>),
    /// A reference to another declaration; `resolve` expects it to point to a
    /// domain letting.
    #[polyquine_skip]
    Reference(Reference),
    /// A record domain: one named entry per field.
    Record(Vec<RecordEntry>),
    /// A function domain: its attributes, its domain, and its codomain.
    Function(FuncAttr<IntVal>, DomainPtr, DomainPtr),
}
294
295impl UnresolvedDomain {
296 pub fn resolve(&self) -> Option<GroundDomain> {
297 match self {
298 UnresolvedDomain::Int(rngs) => rngs
299 .iter()
300 .map(Range::<IntVal>::resolve)
301 .collect::<Option<_>>()
302 .map(GroundDomain::Int),
303 UnresolvedDomain::Set(attr, inner) => {
304 Some(GroundDomain::Set(attr.resolve()?, inner.resolve()?))
305 }
306 UnresolvedDomain::Matrix(inner, idx_doms) => {
307 let inner_gd = inner.resolve()?;
308 idx_doms
309 .iter()
310 .map(DomainPtr::resolve)
311 .collect::<Option<_>>()
312 .map(|idx| GroundDomain::Matrix(inner_gd, idx))
313 }
314 UnresolvedDomain::Tuple(inners) => inners
315 .iter()
316 .map(DomainPtr::resolve)
317 .collect::<Option<_>>()
318 .map(GroundDomain::Tuple),
319 UnresolvedDomain::Record(entries) => entries
320 .iter()
321 .map(|f| {
322 f.domain.resolve().map(|gd| RecordEntryGround {
323 name: f.name.clone(),
324 domain: gd,
325 })
326 })
327 .collect::<Option<_>>()
328 .map(GroundDomain::Record),
329 UnresolvedDomain::Reference(re) => re
330 .ptr
331 .as_domain_letting()
332 .unwrap_or_else(|| {
333 bug!("Reference domain should point to domain letting, but got {re}")
334 })
335 .resolve()
336 .map(Moo::unwrap_or_clone),
337 UnresolvedDomain::Function(attr, dom, cdom) => {
338 if let Some(attr_gd) = attr.resolve()
339 && let Some(dom_gd) = dom.resolve()
340 && let Some(cdom_gd) = cdom.resolve()
341 {
342 return Some(GroundDomain::Function(attr_gd, dom_gd, cdom_gd));
343 }
344 None
345 }
346 }
347 }
348
349 pub(super) fn union_unresolved(
350 &self,
351 other: &UnresolvedDomain,
352 ) -> Result<UnresolvedDomain, DomainOpError> {
353 match (self, other) {
354 (UnresolvedDomain::Int(lhs), UnresolvedDomain::Int(rhs)) => {
355 let merged = lhs.iter().chain(rhs.iter()).cloned().collect_vec();
356 Ok(UnresolvedDomain::Int(merged))
357 }
358 (UnresolvedDomain::Int(_), _) | (_, UnresolvedDomain::Int(_)) => {
359 Err(DomainOpError::WrongType)
360 }
361 (UnresolvedDomain::Set(_, in1), UnresolvedDomain::Set(_, in2)) => {
362 Ok(UnresolvedDomain::Set(SetAttr::default(), in1.union(in2)?))
363 }
364 (UnresolvedDomain::Set(_, _), _) | (_, UnresolvedDomain::Set(_, _)) => {
365 Err(DomainOpError::WrongType)
366 }
367 (UnresolvedDomain::Matrix(in1, idx1), UnresolvedDomain::Matrix(in2, idx2))
368 if idx1 == idx2 =>
369 {
370 Ok(UnresolvedDomain::Matrix(in1.union(in2)?, idx1.clone()))
371 }
372 (UnresolvedDomain::Matrix(_, _), _) | (_, UnresolvedDomain::Matrix(_, _)) => {
373 Err(DomainOpError::WrongType)
374 }
375 (UnresolvedDomain::Tuple(lhs), UnresolvedDomain::Tuple(rhs))
376 if lhs.len() == rhs.len() =>
377 {
378 let mut merged = Vec::new();
379 for (l, r) in zip(lhs, rhs) {
380 merged.push(l.union(r)?)
381 }
382 Ok(UnresolvedDomain::Tuple(merged))
383 }
384 (UnresolvedDomain::Tuple(_), _) | (_, UnresolvedDomain::Tuple(_)) => {
385 Err(DomainOpError::WrongType)
386 }
387 (UnresolvedDomain::Reference(_), _) | (_, UnresolvedDomain::Reference(_)) => {
389 Err(DomainOpError::NotGround)
390 }
391 #[allow(unreachable_patterns)] (UnresolvedDomain::Record(_), _) | (_, UnresolvedDomain::Record(_)) => {
394 Err(DomainOpError::WrongType)
395 }
396 #[allow(unreachable_patterns)]
397 (UnresolvedDomain::Function(_, _, _), _) | (_, UnresolvedDomain::Function(_, _, _)) => {
399 Err(DomainOpError::WrongType)
400 }
401 }
402 }
403
404 pub fn element_domain(&self) -> Option<DomainPtr> {
405 match self {
406 UnresolvedDomain::Set(_, inner_dom) => Some(inner_dom.clone()),
407 UnresolvedDomain::Matrix(_, _) => {
408 todo!("Unwrap one dimension of the domain")
409 }
410 _ => None,
411 }
412 }
413}
414
415impl Typeable for UnresolvedDomain {
416 fn return_type(&self) -> ReturnType {
417 match self {
418 UnresolvedDomain::Reference(re) => re.return_type(),
419 UnresolvedDomain::Int(_) => ReturnType::Int,
420 UnresolvedDomain::Set(_attr, inner) => ReturnType::Set(Box::new(inner.return_type())),
421 UnresolvedDomain::Matrix(inner, _idx) => {
422 ReturnType::Matrix(Box::new(inner.return_type()))
423 }
424 UnresolvedDomain::Tuple(inners) => {
425 let mut inner_types = Vec::new();
426 for inner in inners {
427 inner_types.push(inner.return_type());
428 }
429 ReturnType::Tuple(inner_types)
430 }
431 UnresolvedDomain::Record(entries) => {
432 let mut entry_types = Vec::new();
433 for entry in entries {
434 entry_types.push(entry.domain.return_type());
435 }
436 ReturnType::Record(entry_types)
437 }
438 UnresolvedDomain::Function(_, dom, cdom) => {
439 ReturnType::Function(Box::new(dom.return_type()), Box::new(cdom.return_type()))
440 }
441 }
442 }
443}
444
445impl Display for UnresolvedDomain {
446 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
447 match &self {
448 UnresolvedDomain::Reference(re) => write!(f, "{re}"),
449 UnresolvedDomain::Int(ranges) => {
450 if ranges.iter().all(Range::is_lower_or_upper_bounded) {
451 let rngs: String = ranges.iter().map(|r| format!("{r}")).join(", ");
452 write!(f, "int({})", rngs)
453 } else {
454 write!(f, "int")
455 }
456 }
457 UnresolvedDomain::Set(attrs, inner_dom) => write!(f, "set {attrs} of {inner_dom}"),
458 UnresolvedDomain::Matrix(value_domain, index_domains) => {
459 write!(
460 f,
461 "matrix indexed by [{}] of {value_domain}",
462 pretty_vec(&index_domains.iter().collect_vec())
463 )
464 }
465 UnresolvedDomain::Tuple(domains) => {
466 write!(
467 f,
468 "tuple of ({})",
469 pretty_vec(&domains.iter().collect_vec())
470 )
471 }
472 UnresolvedDomain::Record(entries) => {
473 write!(
474 f,
475 "record of ({})",
476 pretty_vec(
477 &entries
478 .iter()
479 .map(|entry| format!("{}: {}", entry.name, entry.domain))
480 .collect_vec()
481 )
482 )
483 }
484 UnresolvedDomain::Function(attribute, domain, codomain) => {
485 write!(f, "function {} {} --> {} ", attribute, domain, codomain)
486 }
487 }
488 }
489}