1
use crate::ast::domains::attrs::MSetAttr;
2
use crate::ast::domains::attrs::SetAttr;
3
use crate::ast::{
4
    DeclarationKind, DomainOpError, Expression, FuncAttr, Literal, Metadata, Moo,
5
    RecordEntryGround, Reference, Typeable,
6
    domains::{
7
        GroundDomain,
8
        domain::{DomainPtr, Int},
9
        range::Range,
10
    },
11
};
12
use crate::{bug, domain_int, matrix_expr, range};
13
use conjure_cp_core::ast::pretty::pretty_vec;
14
use conjure_cp_core::ast::{Name, ReturnType, eval_constant};
15
use itertools::Itertools;
16
use polyquine::Quine;
17
use serde::{Deserialize, Serialize};
18
use std::fmt::{Display, Formatter};
19
use std::iter::zip;
20
use std::ops::Deref;
21
use uniplate::Uniplate;
22

            
23
/// An integer value used inside a domain (for example a range bound or a size/occurrence
/// attribute) that is not necessarily known yet.
///
/// Use [`IntVal::resolve`] to obtain a concrete [`Int`] once all referenced values are known.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Quine, Uniplate)]
#[path_prefix(conjure_cp::ast)]
#[biplate(to=Expression)]
#[biplate(to=Reference)]
pub enum IntVal {
    /// A ground integer constant.
    Const(Int),
    /// A reference to an integer-typed declaration (see [`IntVal::new_ref`] for which
    /// declaration kinds are accepted).
    #[polyquine_skip]
    Reference(Reference),
    /// An integer-valued expression (see [`IntVal::new_expr`]).
    Expr(Moo<Expression>),
}
33

            
34
impl From<Int> for IntVal {
35
42508
    fn from(value: Int) -> Self {
36
42508
        Self::Const(value)
37
42508
    }
38
}
39

            
40
impl TryInto<Int> for IntVal {
41
    type Error = DomainOpError;
42

            
43
121468
    fn try_into(self) -> Result<Int, Self::Error> {
44
121468
        match self {
45
116632
            IntVal::Const(val) => Ok(val),
46
4836
            _ => Err(DomainOpError::NotGround),
47
        }
48
121468
    }
49
}
50

            
51
impl From<Range<Int>> for Range<IntVal> {
52
21454
    fn from(value: Range<Int>) -> Self {
53
21454
        match value {
54
400
            Range::Single(x) => Range::Single(x.into()),
55
21054
            Range::Bounded(l, r) => Range::Bounded(l.into(), r.into()),
56
            Range::UnboundedL(r) => Range::UnboundedL(r.into()),
57
            Range::UnboundedR(l) => Range::UnboundedR(l.into()),
58
            Range::Unbounded => Range::Unbounded,
59
        }
60
21454
    }
61
}
62

            
63
impl TryInto<Range<Int>> for Range<IntVal> {
64
    type Error = DomainOpError;
65

            
66
64780
    fn try_into(self) -> Result<Range<Int>, Self::Error> {
67
64780
        match self {
68
4100
            Range::Single(x) => Ok(Range::Single(x.try_into()?)),
69
59100
            Range::Bounded(l, r) => Ok(Range::Bounded(l.try_into()?, r.try_into()?)),
70
176
            Range::UnboundedL(r) => Ok(Range::UnboundedL(r.try_into()?)),
71
400
            Range::UnboundedR(l) => Ok(Range::UnboundedR(l.try_into()?)),
72
1004
            Range::Unbounded => Ok(Range::Unbounded),
73
        }
74
64780
    }
75
}
76

            
77
impl From<SetAttr<Int>> for SetAttr<IntVal> {
78
    fn from(value: SetAttr<Int>) -> Self {
79
        SetAttr {
80
            size: value.size.into(),
81
        }
82
    }
83
}
84

            
85
impl TryInto<SetAttr<Int>> for SetAttr<IntVal> {
86
    type Error = DomainOpError;
87

            
88
644
    fn try_into(self) -> Result<SetAttr<Int>, Self::Error> {
89
644
        let size: Range<Int> = self.size.try_into()?;
90
644
        Ok(SetAttr { size })
91
644
    }
92
}
93

            
94
impl From<MSetAttr<Int>> for MSetAttr<IntVal> {
95
    fn from(value: MSetAttr<Int>) -> Self {
96
        MSetAttr {
97
            size: value.size.into(),
98
            occurrence: value.occurrence.into(),
99
        }
100
    }
101
}
102

            
103
impl TryInto<MSetAttr<Int>> for MSetAttr<IntVal> {
104
    type Error = DomainOpError;
105

            
106
360
    fn try_into(self) -> Result<MSetAttr<Int>, Self::Error> {
107
360
        let size: Range<Int> = self.size.try_into()?;
108
360
        let occurrence: Range<Int> = self.occurrence.try_into()?;
109
360
        Ok(MSetAttr { size, occurrence })
110
360
    }
111
}
112

            
113
impl From<FuncAttr<Int>> for FuncAttr<IntVal> {
114
    fn from(value: FuncAttr<Int>) -> Self {
115
        FuncAttr {
116
            size: value.size.into(),
117
            partiality: value.partiality,
118
            jectivity: value.jectivity,
119
        }
120
    }
121
}
122

            
123
impl TryInto<FuncAttr<Int>> for FuncAttr<IntVal> {
124
    type Error = DomainOpError;
125

            
126
520
    fn try_into(self) -> Result<FuncAttr<Int>, Self::Error> {
127
520
        let size: Range<Int> = self.size.try_into()?;
128
520
        Ok(FuncAttr {
129
520
            size,
130
520
            jectivity: self.jectivity,
131
520
            partiality: self.partiality,
132
520
        })
133
520
    }
134
}
135

            
136
impl Display for IntVal {
137
6017224
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
138
6017224
        match self {
139
2917554
            IntVal::Const(val) => write!(f, "{val}"),
140
2983476
            IntVal::Reference(re) => write!(f, "{re}"),
141
116194
            IntVal::Expr(expr) => write!(f, "({expr})"),
142
        }
143
6017224
    }
144
}
145

            
146
impl IntVal {
147
1636
    pub fn new_ref(re: &Reference) -> Option<IntVal> {
148
1636
        match re.ptr.kind().deref() {
149
1450
            DeclarationKind::ValueLetting(expr, _)
150
            | DeclarationKind::TemporaryValueLetting(expr)
151
1450
            | DeclarationKind::QuantifiedExpr(expr) => match expr.return_type() {
152
1450
                ReturnType::Int => Some(IntVal::Reference(re.clone())),
153
                _ => None,
154
            },
155
106
            DeclarationKind::Given(dom) => match dom.return_type() {
156
106
                ReturnType::Int => Some(IntVal::Reference(re.clone())),
157
                _ => None,
158
            },
159
80
            DeclarationKind::Quantified(inner) => match inner.domain().return_type() {
160
80
                ReturnType::Int => Some(IntVal::Reference(re.clone())),
161
                _ => None,
162
            },
163
            DeclarationKind::Find(var) => match var.return_type() {
164
                ReturnType::Int => Some(IntVal::Reference(re.clone())),
165
                _ => None,
166
            },
167
            DeclarationKind::DomainLetting(_) | DeclarationKind::RecordField(_) => None,
168
        }
169
1636
    }
170

            
171
1188
    pub fn new_expr(value: Moo<Expression>) -> Option<IntVal> {
172
1188
        if value.return_type() != ReturnType::Int {
173
            return None;
174
1188
        }
175
1188
        Some(IntVal::Expr(value))
176
1188
    }
177

            
178
1148810
    pub fn resolve(&self) -> Option<Int> {
179
1148810
        match self {
180
558838
            IntVal::Const(value) => Some(*value),
181
58116
            IntVal::Expr(expr) => eval_expr_to_int(expr),
182
531856
            IntVal::Reference(re) => match re.ptr.kind().deref() {
183
525264
                DeclarationKind::ValueLetting(expr, _)
184
530064
                | DeclarationKind::TemporaryValueLetting(expr) => eval_expr_to_int(expr),
185
                // If this is an int given we will be able to resolve it eventually, but not yet
186
72
                DeclarationKind::Given(_) => None,
187
200
                DeclarationKind::Quantified(inner) => {
188
200
                    if let Some(generator) = inner.generator()
189
                        && let Some(expr) = generator.as_value_letting()
190
                    {
191
                        eval_expr_to_int(&expr)
192
                    } else {
193
200
                        None
194
                    }
195
                }
196
                // TODO: idk what this whole file does but I very much doubt it affects this
197
                DeclarationKind::QuantifiedExpr(_) => None,
198
                // Decision variables inside domains are unresolved until solving.
199
1520
                DeclarationKind::Find(_) => None,
200
                DeclarationKind::DomainLetting(_) | DeclarationKind::RecordField(_) => bug!(
201
                    "Expected integer expression, given, or letting inside int domain; Got: {re}"
202
                ),
203
            },
204
        }
205
1148810
    }
206
}
207

            
208
588180
fn eval_expr_to_int(expr: &Expression) -> Option<Int> {
209
588180
    match eval_constant(expr)? {
210
581274
        Literal::Int(v) => Some(v),
211
        _ => bug!("Expected integer expression, got: {expr}"),
212
    }
213
588180
}
214

            
215
impl From<IntVal> for Expression {
216
1360
    fn from(value: IntVal) -> Self {
217
1360
        match value {
218
320
            IntVal::Const(val) => val.into(),
219
360
            IntVal::Reference(re) => re.into(),
220
680
            IntVal::Expr(expr) => expr.as_ref().clone(),
221
        }
222
1360
    }
223
}
224

            
225
impl From<IntVal> for Moo<Expression> {
226
88
    fn from(value: IntVal) -> Self {
227
88
        match value {
228
            IntVal::Const(val) => Moo::new(val.into()),
229
88
            IntVal::Reference(re) => Moo::new(re.into()),
230
            IntVal::Expr(expr) => expr,
231
        }
232
88
    }
233
}
234

            
235
impl std::ops::Neg for IntVal {
236
    type Output = IntVal;
237

            
238
6976
    fn neg(self) -> Self::Output {
239
6976
        match self {
240
6888
            IntVal::Const(val) => IntVal::Const(-val),
241
            IntVal::Reference(_) | IntVal::Expr(_) => {
242
88
                IntVal::Expr(Moo::new(Expression::Neg(Metadata::new(), self.into())))
243
            }
244
        }
245
6976
    }
246
}
247

            
248
impl<T> std::ops::Add<T> for IntVal
249
where
250
    T: Into<Expression>,
251
{
252
    type Output = IntVal;
253

            
254
    fn add(self, rhs: T) -> Self::Output {
255
        let lhs: Expression = self.into();
256
        let rhs: Expression = rhs.into();
257
        let sum = matrix_expr!(lhs, rhs; domain_int!(1..));
258
        IntVal::Expr(Moo::new(Expression::Sum(Metadata::new(), Moo::new(sum))))
259
    }
260
}
261

            
262
impl Range<IntVal> {
263
577276
    pub fn resolve(&self) -> Option<Range<Int>> {
264
577276
        match self {
265
7104
            Range::Single(x) => Some(Range::Single(x.resolve()?)),
266
570172
            Range::Bounded(l, r) => Some(Range::Bounded(l.resolve()?, r.resolve()?)),
267
            Range::UnboundedL(r) => Some(Range::UnboundedL(r.resolve()?)),
268
            Range::UnboundedR(l) => Some(Range::UnboundedR(l.resolve()?)),
269
            Range::Unbounded => Some(Range::Unbounded),
270
        }
271
577276
    }
272
}
273

            
274
impl SetAttr<IntVal> {
275
    pub fn resolve(&self) -> Option<SetAttr<Int>> {
276
        Some(SetAttr {
277
            size: self.size.resolve()?,
278
        })
279
    }
280
}
281

            
282
impl MSetAttr<IntVal> {
283
    pub fn resolve(&self) -> Option<MSetAttr<Int>> {
284
        Some(MSetAttr {
285
            size: self.size.resolve()?,
286
            occurrence: self.occurrence.resolve()?,
287
        })
288
    }
289
}
290

            
291
impl FuncAttr<IntVal> {
292
    pub fn resolve(&self) -> Option<FuncAttr<Int>> {
293
        Some(FuncAttr {
294
            size: self.size.resolve()?,
295
            partiality: self.partiality.clone(),
296
            jectivity: self.jectivity.clone(),
297
        })
298
    }
299
}
300

            
301
/// A single named field of a record domain, whose domain may not be resolved yet.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Uniplate, Quine)]
#[path_prefix(conjure_cp::ast)]
pub struct RecordEntry {
    /// The field's name.
    pub name: Name,
    /// The (possibly unresolved) domain of the field.
    pub domain: DomainPtr,
}
307

            
308
impl RecordEntry {
309
    pub fn resolve(self) -> Option<RecordEntryGround> {
310
        Some(RecordEntryGround {
311
            name: self.name,
312
            domain: self.domain.resolve()?,
313
        })
314
    }
315
}
316

            
317
/// A domain that may still contain unresolved parts: integer values that are not constants,
/// or references to domain lettings.
///
/// Use [`UnresolvedDomain::resolve`] to obtain a [`GroundDomain`] once everything is known.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Quine, Uniplate)]
#[path_prefix(conjure_cp::ast)]
#[biplate(to=Expression)]
#[biplate(to=Reference)]
#[biplate(to=IntVal)]
#[biplate(to=DomainPtr)]
#[biplate(to=RecordEntry)]
pub enum UnresolvedDomain {
    /// An integer domain given as a list of (possibly unresolved) ranges
    Int(Vec<Range<IntVal>>),
    /// A set of elements drawn from the inner domain
    Set(SetAttr<IntVal>, DomainPtr),
    /// A multiset of elements drawn from the inner domain
    MSet(MSetAttr<IntVal>, DomainPtr),
    /// An n-dimensional matrix with a value domain and n index domains
    Matrix(DomainPtr, Vec<DomainPtr>),
    /// A tuple of N elements, each with its own domain
    Tuple(Vec<DomainPtr>),
    /// A reference to a domain letting
    #[polyquine_skip]
    Reference(Reference),
    /// A record
    Record(Vec<RecordEntry>),
    /// A function with attributes, domain, and range
    Function(FuncAttr<IntVal>, DomainPtr, DomainPtr),
}
341

            
342
impl UnresolvedDomain {
343
1457736
    pub fn resolve(&self) -> Option<GroundDomain> {
344
1457736
        match self {
345
577276
            UnresolvedDomain::Int(rngs) => rngs
346
577276
                .iter()
347
577276
                .map(Range::<IntVal>::resolve)
348
577276
                .collect::<Option<_>>()
349
577276
                .map(GroundDomain::Int),
350
            UnresolvedDomain::Set(attr, inner) => {
351
                Some(GroundDomain::Set(attr.resolve()?, inner.resolve()?))
352
            }
353
            UnresolvedDomain::MSet(attr, inner) => {
354
                Some(GroundDomain::MSet(attr.resolve()?, inner.resolve()?))
355
            }
356
317560
            UnresolvedDomain::Matrix(inner, idx_doms) => {
357
317560
                let inner_gd = inner.resolve()?;
358
317522
                idx_doms
359
317522
                    .iter()
360
317522
                    .map(DomainPtr::resolve)
361
317522
                    .collect::<Option<_>>()
362
317522
                    .map(|idx| GroundDomain::Matrix(inner_gd, idx))
363
            }
364
            UnresolvedDomain::Tuple(inners) => inners
365
                .iter()
366
                .map(DomainPtr::resolve)
367
                .collect::<Option<_>>()
368
                .map(GroundDomain::Tuple),
369
            UnresolvedDomain::Record(entries) => entries
370
                .iter()
371
                .map(|f| {
372
                    f.domain.resolve().map(|gd| RecordEntryGround {
373
                        name: f.name.clone(),
374
                        domain: gd,
375
                    })
376
                })
377
                .collect::<Option<_>>()
378
                .map(GroundDomain::Record),
379
562900
            UnresolvedDomain::Reference(re) => re
380
562900
                .ptr
381
562900
                .as_domain_letting()
382
562900
                .unwrap_or_else(|| {
383
                    bug!("Reference domain should point to domain letting, but got {re}")
384
                })
385
562900
                .resolve()
386
562900
                .map(Moo::unwrap_or_clone),
387
            UnresolvedDomain::Function(attr, dom, cdom) => {
388
                if let Some(attr_gd) = attr.resolve()
389
                    && let Some(dom_gd) = dom.resolve()
390
                    && let Some(cdom_gd) = cdom.resolve()
391
                {
392
                    return Some(GroundDomain::Function(attr_gd, dom_gd, cdom_gd));
393
                }
394
                None
395
            }
396
        }
397
1457736
    }
398

            
399
    pub(super) fn union_unresolved(
400
        &self,
401
        other: &UnresolvedDomain,
402
    ) -> Result<UnresolvedDomain, DomainOpError> {
403
        match (self, other) {
404
            (UnresolvedDomain::Int(lhs), UnresolvedDomain::Int(rhs)) => {
405
                let merged = lhs.iter().chain(rhs.iter()).cloned().collect_vec();
406
                Ok(UnresolvedDomain::Int(merged))
407
            }
408
            (UnresolvedDomain::Int(_), _) | (_, UnresolvedDomain::Int(_)) => {
409
                Err(DomainOpError::WrongType)
410
            }
411
            (UnresolvedDomain::Set(_, in1), UnresolvedDomain::Set(_, in2)) => {
412
                Ok(UnresolvedDomain::Set(SetAttr::default(), in1.union(in2)?))
413
            }
414
            (UnresolvedDomain::Set(_, _), _) | (_, UnresolvedDomain::Set(_, _)) => {
415
                Err(DomainOpError::WrongType)
416
            }
417
            (UnresolvedDomain::MSet(_, in1), UnresolvedDomain::MSet(_, in2)) => {
418
                Ok(UnresolvedDomain::MSet(MSetAttr::default(), in1.union(in2)?))
419
            }
420
            (UnresolvedDomain::MSet(_, _), _) | (_, UnresolvedDomain::MSet(_, _)) => {
421
                Err(DomainOpError::WrongType)
422
            }
423
            (UnresolvedDomain::Matrix(in1, idx1), UnresolvedDomain::Matrix(in2, idx2))
424
                if idx1 == idx2 =>
425
            {
426
                Ok(UnresolvedDomain::Matrix(in1.union(in2)?, idx1.clone()))
427
            }
428
            (UnresolvedDomain::Matrix(_, _), _) | (_, UnresolvedDomain::Matrix(_, _)) => {
429
                Err(DomainOpError::WrongType)
430
            }
431
            (UnresolvedDomain::Tuple(lhs), UnresolvedDomain::Tuple(rhs))
432
                if lhs.len() == rhs.len() =>
433
            {
434
                let mut merged = Vec::new();
435
                for (l, r) in zip(lhs, rhs) {
436
                    merged.push(l.union(r)?)
437
                }
438
                Ok(UnresolvedDomain::Tuple(merged))
439
            }
440
            (UnresolvedDomain::Tuple(_), _) | (_, UnresolvedDomain::Tuple(_)) => {
441
                Err(DomainOpError::WrongType)
442
            }
443
            // TODO: Could we support unions of reference domains symbolically?
444
            (UnresolvedDomain::Reference(_), _) | (_, UnresolvedDomain::Reference(_)) => {
445
                Err(DomainOpError::NotGround)
446
            }
447
            // TODO: Could we define semantics for merging record domains?
448
            #[allow(unreachable_patterns)] // Technically redundant but logically makes sense
449
            (UnresolvedDomain::Record(_), _) | (_, UnresolvedDomain::Record(_)) => {
450
                Err(DomainOpError::WrongType)
451
            }
452
            #[allow(unreachable_patterns)]
453
            // Technically redundant but logically clearer to have both
454
            (UnresolvedDomain::Function(_, _, _), _) | (_, UnresolvedDomain::Function(_, _, _)) => {
455
                Err(DomainOpError::WrongType)
456
            }
457
        }
458
    }
459

            
460
    pub fn element_domain(&self) -> Option<DomainPtr> {
461
        match self {
462
            UnresolvedDomain::Set(_, inner_dom) => Some(inner_dom.clone()),
463
            UnresolvedDomain::Matrix(_, _) => {
464
                todo!("Unwrap one dimension of the domain")
465
            }
466
            _ => None,
467
        }
468
    }
469
}
470

            
471
impl Typeable for UnresolvedDomain {
472
86564
    fn return_type(&self) -> ReturnType {
473
86564
        match self {
474
18120
            UnresolvedDomain::Reference(re) => re.return_type(),
475
7184
            UnresolvedDomain::Int(_) => ReturnType::Int,
476
            UnresolvedDomain::Set(_attr, inner) => ReturnType::Set(Box::new(inner.return_type())),
477
            UnresolvedDomain::MSet(_attr, inner) => ReturnType::MSet(Box::new(inner.return_type())),
478
61260
            UnresolvedDomain::Matrix(inner, _idx) => {
479
61260
                ReturnType::Matrix(Box::new(inner.return_type()))
480
            }
481
            UnresolvedDomain::Tuple(inners) => {
482
                let mut inner_types = Vec::new();
483
                for inner in inners {
484
                    inner_types.push(inner.return_type());
485
                }
486
                ReturnType::Tuple(inner_types)
487
            }
488
            UnresolvedDomain::Record(entries) => {
489
                let mut entry_types = Vec::new();
490
                for entry in entries {
491
                    entry_types.push(entry.domain.return_type());
492
                }
493
                ReturnType::Record(entry_types)
494
            }
495
            UnresolvedDomain::Function(_, dom, cdom) => {
496
                ReturnType::Function(Box::new(dom.return_type()), Box::new(cdom.return_type()))
497
            }
498
        }
499
86564
    }
500
}
501

            
502
impl Display for UnresolvedDomain {
503
13020278
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
504
13020278
        match &self {
505
5511038
            UnresolvedDomain::Reference(re) => write!(f, "{re}"),
506
3079686
            UnresolvedDomain::Int(ranges) => {
507
3079686
                if ranges.iter().all(Range::is_lower_or_upper_bounded) {
508
3080166
                    let rngs: String = ranges.iter().map(|r| format!("{r}")).join(", ");
509
3079686
                    write!(f, "int({})", rngs)
510
                } else {
511
                    write!(f, "int")
512
                }
513
            }
514
            UnresolvedDomain::Set(attrs, inner_dom) => write!(f, "set {attrs} of {inner_dom}"),
515
            UnresolvedDomain::MSet(attrs, inner_dom) => write!(f, "mset {attrs} of {inner_dom}"),
516
4429554
            UnresolvedDomain::Matrix(value_domain, index_domains) => {
517
4429554
                write!(
518
4429554
                    f,
519
                    "matrix indexed by {} of {value_domain}",
520
4429554
                    pretty_vec(&index_domains.iter().collect_vec())
521
                )
522
            }
523
            UnresolvedDomain::Tuple(domains) => {
524
                write!(
525
                    f,
526
                    "tuple of ({})",
527
                    pretty_vec(&domains.iter().collect_vec())
528
                )
529
            }
530
            UnresolvedDomain::Record(entries) => {
531
                write!(
532
                    f,
533
                    "record of ({})",
534
                    pretty_vec(
535
                        &entries
536
                            .iter()
537
                            .map(|entry| format!("{}: {}", entry.name, entry.domain))
538
                            .collect_vec()
539
                    )
540
                )
541
            }
542
            UnresolvedDomain::Function(attribute, domain, codomain) => {
543
                write!(f, "function {} {} --> {} ", attribute, domain, codomain)
544
            }
545
        }
546
13020278
    }
547
}