1
use crate::ast::domains::attrs::MSetAttr;
2
use crate::ast::domains::attrs::SetAttr;
3
use crate::ast::{
4
    DeclarationKind, DomainOpError, Expression, FuncAttr, Literal, Metadata, Moo,
5
    RecordEntryGround, Reference, Typeable,
6
    domains::{
7
        GroundDomain,
8
        domain::{DomainPtr, Int},
9
        range::Range,
10
    },
11
};
12
use crate::{bug, domain_int, matrix_expr, range};
13
use conjure_cp_core::ast::pretty::pretty_vec;
14
use conjure_cp_core::ast::{Name, ReturnType, eval_constant};
15
use itertools::Itertools;
16
use polyquine::Quine;
17
use serde::{Deserialize, Serialize};
18
use std::fmt::{Display, Formatter};
19
use std::iter::zip;
20
use std::ops::Deref;
21
use uniplate::Uniplate;
22

            
23
/// An integer value used inside a domain: a bound of a range in an int domain, or a
/// bound in a set/mset/function size or occurrence attribute.
///
/// The value need not be known yet; [`IntVal::resolve`] attempts to evaluate it to a
/// concrete [`Int`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Quine, Uniplate)]
#[path_prefix(conjure_cp::ast)]
#[biplate(to=Expression)]
#[biplate(to=Reference)]
pub enum IntVal {
    /// A concrete integer constant.
    Const(Int),
    /// A reference to an integer-typed declaration (value letting, given, or
    /// quantified variable); see [`IntVal::new_ref`].
    #[polyquine_skip]
    Reference(Reference),
    /// An arbitrary integer-typed expression; see [`IntVal::new_expr`].
    Expr(Moo<Expression>),
}

impl From<Int> for IntVal {
35
8160
    fn from(value: Int) -> Self {
36
8160
        Self::Const(value)
37
8160
    }
38
}
39

            
40
impl TryInto<Int> for IntVal {
41
    type Error = DomainOpError;
42

            
43
35280
    fn try_into(self) -> Result<Int, Self::Error> {
44
35280
        match self {
45
34920
            IntVal::Const(val) => Ok(val),
46
360
            _ => Err(DomainOpError::NotGround),
47
        }
48
35280
    }
49
}
50

            
51
impl From<Range<Int>> for Range<IntVal> {
52
4120
    fn from(value: Range<Int>) -> Self {
53
4120
        match value {
54
80
            Range::Single(x) => Range::Single(x.into()),
55
4040
            Range::Bounded(l, r) => Range::Bounded(l.into(), r.into()),
56
            Range::UnboundedL(r) => Range::UnboundedL(r.into()),
57
            Range::UnboundedR(l) => Range::UnboundedR(l.into()),
58
            Range::Unbounded => Range::Unbounded,
59
        }
60
4120
    }
61
}
62

            
63
impl TryInto<Range<Int>> for Range<IntVal> {
64
    type Error = DomainOpError;
65

            
66
19326
    fn try_into(self) -> Result<Range<Int>, Self::Error> {
67
19326
        match self {
68
1400
            Range::Single(x) => Ok(Range::Single(x.try_into()?)),
69
16740
            Range::Bounded(l, r) => Ok(Range::Bounded(l.try_into()?, r.try_into()?)),
70
160
            Range::UnboundedL(r) => Ok(Range::UnboundedL(r.try_into()?)),
71
300
            Range::UnboundedR(l) => Ok(Range::UnboundedR(l.try_into()?)),
72
726
            Range::Unbounded => Ok(Range::Unbounded),
73
        }
74
19326
    }
75
}
76

            
77
impl From<SetAttr<Int>> for SetAttr<IntVal> {
78
    fn from(value: SetAttr<Int>) -> Self {
79
        SetAttr {
80
            size: value.size.into(),
81
        }
82
    }
83
}
84

            
85
impl TryInto<SetAttr<Int>> for SetAttr<IntVal> {
86
    type Error = DomainOpError;
87

            
88
226
    fn try_into(self) -> Result<SetAttr<Int>, Self::Error> {
89
226
        let size: Range<Int> = self.size.try_into()?;
90
226
        Ok(SetAttr { size })
91
226
    }
92
}
93

            
94
impl From<MSetAttr<Int>> for MSetAttr<IntVal> {
95
    fn from(value: MSetAttr<Int>) -> Self {
96
        MSetAttr {
97
            size: value.size.into(),
98
            occurrence: value.occurrence.into(),
99
        }
100
    }
101
}
102

            
103
impl TryInto<MSetAttr<Int>> for MSetAttr<IntVal> {
104
    type Error = DomainOpError;
105

            
106
360
    fn try_into(self) -> Result<MSetAttr<Int>, Self::Error> {
107
360
        let size: Range<Int> = self.size.try_into()?;
108
360
        let occurrence: Range<Int> = self.occurrence.try_into()?;
109
360
        Ok(MSetAttr { size, occurrence })
110
360
    }
111
}
112

            
113
impl From<FuncAttr<Int>> for FuncAttr<IntVal> {
114
    fn from(value: FuncAttr<Int>) -> Self {
115
        FuncAttr {
116
            size: value.size.into(),
117
            partiality: value.partiality,
118
            jectivity: value.jectivity,
119
        }
120
    }
121
}
122

            
123
impl TryInto<FuncAttr<Int>> for FuncAttr<IntVal> {
124
    type Error = DomainOpError;
125

            
126
520
    fn try_into(self) -> Result<FuncAttr<Int>, Self::Error> {
127
520
        let size: Range<Int> = self.size.try_into()?;
128
520
        Ok(FuncAttr {
129
520
            size,
130
520
            jectivity: self.jectivity,
131
520
            partiality: self.partiality,
132
520
        })
133
520
    }
134
}
135

            
136
impl Display for IntVal {
137
4080
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
138
4080
        match self {
139
2000
            IntVal::Const(val) => write!(f, "{val}"),
140
1680
            IntVal::Reference(re) => write!(f, "{re}"),
141
400
            IntVal::Expr(expr) => write!(f, "({expr})"),
142
        }
143
4080
    }
144
}
145

            
146
impl IntVal {
147
300
    pub fn new_ref(re: &Reference) -> Option<IntVal> {
148
300
        match re.ptr.kind().deref() {
149
300
            DeclarationKind::ValueLetting(expr) => match expr.return_type() {
150
300
                ReturnType::Int => Some(IntVal::Reference(re.clone())),
151
                _ => None,
152
            },
153
            DeclarationKind::Given(dom) => match dom.return_type() {
154
                ReturnType::Int => Some(IntVal::Reference(re.clone())),
155
                _ => None,
156
            },
157
            DeclarationKind::Quantified(inner) => match inner.domain().return_type() {
158
                ReturnType::Int => Some(IntVal::Reference(re.clone())),
159
                _ => None,
160
            },
161
            DeclarationKind::DomainLetting(_)
162
            | DeclarationKind::RecordField(_)
163
            | DeclarationKind::Find(_) => None,
164
        }
165
300
    }
166

            
167
80
    pub fn new_expr(value: Moo<Expression>) -> Option<IntVal> {
168
80
        if value.return_type() != ReturnType::Int {
169
            return None;
170
80
        }
171
80
        Some(IntVal::Expr(value))
172
80
    }
173

            
174
29000
    pub fn resolve(&self) -> Option<Int> {
175
29000
        match self {
176
14460
            IntVal::Const(value) => Some(*value),
177
160
            IntVal::Expr(expr) => eval_expr_to_int(expr),
178
14380
            IntVal::Reference(re) => match re.ptr.kind().deref() {
179
14380
                DeclarationKind::ValueLetting(expr) => eval_expr_to_int(expr),
180
                // If this is an int given we will be able to resolve it eventually, but not yet
181
                DeclarationKind::Given(_) | DeclarationKind::Quantified(..) => None,
182
                DeclarationKind::DomainLetting(_)
183
                | DeclarationKind::RecordField(_)
184
                | DeclarationKind::Find(_) => bug!(
185
                    "Expected integer expression, given, or letting inside int domain; Got: {re}"
186
                ),
187
            },
188
        }
189
29000
    }
190
}
191

            
192
14540
fn eval_expr_to_int(expr: &Expression) -> Option<Int> {
193
14540
    match eval_constant(expr)? {
194
14540
        Literal::Int(v) => Some(v),
195
        _ => bug!("Expected integer expression, got: {expr}"),
196
    }
197
14540
}
198

            
199
impl From<IntVal> for Expression {
200
    fn from(value: IntVal) -> Self {
201
        match value {
202
            IntVal::Const(val) => val.into(),
203
            IntVal::Reference(re) => re.into(),
204
            IntVal::Expr(expr) => expr.as_ref().clone(),
205
        }
206
    }
207
}
208

            
209
impl From<IntVal> for Moo<Expression> {
210
    fn from(value: IntVal) -> Self {
211
        match value {
212
            IntVal::Const(val) => Moo::new(val.into()),
213
            IntVal::Reference(re) => Moo::new(re.into()),
214
            IntVal::Expr(expr) => expr,
215
        }
216
    }
217
}
218

            
219
impl std::ops::Neg for IntVal {
220
    type Output = IntVal;
221

            
222
800
    fn neg(self) -> Self::Output {
223
800
        match self {
224
800
            IntVal::Const(val) => IntVal::Const(-val),
225
            IntVal::Reference(_) | IntVal::Expr(_) => {
226
                IntVal::Expr(Moo::new(Expression::Neg(Metadata::new(), self.into())))
227
            }
228
        }
229
800
    }
230
}
231

            
232
impl<T> std::ops::Add<T> for IntVal
233
where
234
    T: Into<Expression>,
235
{
236
    type Output = IntVal;
237

            
238
    fn add(self, rhs: T) -> Self::Output {
239
        let lhs: Expression = self.into();
240
        let rhs: Expression = rhs.into();
241
        let sum = matrix_expr!(lhs, rhs; domain_int!(1..));
242
        IntVal::Expr(Moo::new(Expression::Sum(Metadata::new(), Moo::new(sum))))
243
    }
244
}
245

            
246
impl Range<IntVal> {
247
14500
    pub fn resolve(&self) -> Option<Range<Int>> {
248
14500
        match self {
249
            Range::Single(x) => Some(Range::Single(x.resolve()?)),
250
14500
            Range::Bounded(l, r) => Some(Range::Bounded(l.resolve()?, r.resolve()?)),
251
            Range::UnboundedL(r) => Some(Range::UnboundedL(r.resolve()?)),
252
            Range::UnboundedR(l) => Some(Range::UnboundedR(l.resolve()?)),
253
            Range::Unbounded => Some(Range::Unbounded),
254
        }
255
14500
    }
256
}
257

            
258
impl SetAttr<IntVal> {
259
    pub fn resolve(&self) -> Option<SetAttr<Int>> {
260
        Some(SetAttr {
261
            size: self.size.resolve()?,
262
        })
263
    }
264
}
265

            
266
impl MSetAttr<IntVal> {
267
    pub fn resolve(&self) -> Option<MSetAttr<Int>> {
268
        Some(MSetAttr {
269
            size: self.size.resolve()?,
270
            occurrence: self.occurrence.resolve()?,
271
        })
272
    }
273
}
274

            
275
impl FuncAttr<IntVal> {
276
    pub fn resolve(&self) -> Option<FuncAttr<Int>> {
277
        Some(FuncAttr {
278
            size: self.size.resolve()?,
279
            partiality: self.partiality.clone(),
280
            jectivity: self.jectivity.clone(),
281
        })
282
    }
283
}
284

            
285
/// A named field of a record domain whose field domain may still be unresolved.
///
/// See [`RecordEntry::resolve`] for conversion to a [`RecordEntryGround`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Uniplate, Quine)]
#[path_prefix(conjure_cp::ast)]
pub struct RecordEntry {
    /// The field's name.
    pub name: Name,
    /// The field's domain (possibly unresolved).
    pub domain: DomainPtr,
}

impl RecordEntry {
293
    pub fn resolve(self) -> Option<RecordEntryGround> {
294
        Some(RecordEntryGround {
295
            name: self.name,
296
            domain: self.domain.resolve()?,
297
        })
298
    }
299
}
300

            
301
/// A domain that may still contain unresolved parts, as opposed to a [`GroundDomain`].
///
/// Unresolved parts are references to domain lettings, or integer bounds and
/// attributes held as [`IntVal`]s. Use [`UnresolvedDomain::resolve`] to attempt
/// the conversion to a [`GroundDomain`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Quine, Uniplate)]
#[path_prefix(conjure_cp::ast)]
#[biplate(to=Expression)]
#[biplate(to=Reference)]
#[biplate(to=IntVal)]
#[biplate(to=DomainPtr)]
#[biplate(to=RecordEntry)]
pub enum UnresolvedDomain {
    /// An integer domain described by a list of ranges whose bounds may be unresolved
    Int(Vec<Range<IntVal>>),
    /// A set of elements drawn from the inner domain
    Set(SetAttr<IntVal>, DomainPtr),
    /// A multiset of elements drawn from the inner domain
    MSet(MSetAttr<IntVal>, DomainPtr),
    /// An n-dimensional matrix with a value domain and n index domains
    Matrix(DomainPtr, Vec<DomainPtr>),
    /// A tuple of N elements, each with its own domain
    Tuple(Vec<DomainPtr>),
    /// A reference to a domain letting
    #[polyquine_skip]
    Reference(Reference),
    /// A record
    Record(Vec<RecordEntry>),
    /// A function with attributes, domain, and range
    Function(FuncAttr<IntVal>, DomainPtr, DomainPtr),
}

impl UnresolvedDomain {
327
45780
    pub fn resolve(&self) -> Option<GroundDomain> {
328
45780
        match self {
329
14500
            UnresolvedDomain::Int(rngs) => rngs
330
14500
                .iter()
331
14500
                .map(Range::<IntVal>::resolve)
332
14500
                .collect::<Option<_>>()
333
14500
                .map(GroundDomain::Int),
334
            UnresolvedDomain::Set(attr, inner) => {
335
                Some(GroundDomain::Set(attr.resolve()?, inner.resolve()?))
336
            }
337
            UnresolvedDomain::MSet(attr, inner) => {
338
                Some(GroundDomain::MSet(attr.resolve()?, inner.resolve()?))
339
            }
340
12300
            UnresolvedDomain::Matrix(inner, idx_doms) => {
341
12300
                let inner_gd = inner.resolve()?;
342
12300
                idx_doms
343
12300
                    .iter()
344
12300
                    .map(DomainPtr::resolve)
345
12300
                    .collect::<Option<_>>()
346
12300
                    .map(|idx| GroundDomain::Matrix(inner_gd, idx))
347
            }
348
            UnresolvedDomain::Tuple(inners) => inners
349
                .iter()
350
                .map(DomainPtr::resolve)
351
                .collect::<Option<_>>()
352
                .map(GroundDomain::Tuple),
353
            UnresolvedDomain::Record(entries) => entries
354
                .iter()
355
                .map(|f| {
356
                    f.domain.resolve().map(|gd| RecordEntryGround {
357
                        name: f.name.clone(),
358
                        domain: gd,
359
                    })
360
                })
361
                .collect::<Option<_>>()
362
                .map(GroundDomain::Record),
363
18980
            UnresolvedDomain::Reference(re) => re
364
18980
                .ptr
365
18980
                .as_domain_letting()
366
18980
                .unwrap_or_else(|| {
367
                    bug!("Reference domain should point to domain letting, but got {re}")
368
                })
369
18980
                .resolve()
370
18980
                .map(Moo::unwrap_or_clone),
371
            UnresolvedDomain::Function(attr, dom, cdom) => {
372
                if let Some(attr_gd) = attr.resolve()
373
                    && let Some(dom_gd) = dom.resolve()
374
                    && let Some(cdom_gd) = cdom.resolve()
375
                {
376
                    return Some(GroundDomain::Function(attr_gd, dom_gd, cdom_gd));
377
                }
378
                None
379
            }
380
        }
381
45780
    }
382

            
383
    pub(super) fn union_unresolved(
384
        &self,
385
        other: &UnresolvedDomain,
386
    ) -> Result<UnresolvedDomain, DomainOpError> {
387
        match (self, other) {
388
            (UnresolvedDomain::Int(lhs), UnresolvedDomain::Int(rhs)) => {
389
                let merged = lhs.iter().chain(rhs.iter()).cloned().collect_vec();
390
                Ok(UnresolvedDomain::Int(merged))
391
            }
392
            (UnresolvedDomain::Int(_), _) | (_, UnresolvedDomain::Int(_)) => {
393
                Err(DomainOpError::WrongType)
394
            }
395
            (UnresolvedDomain::Set(_, in1), UnresolvedDomain::Set(_, in2)) => {
396
                Ok(UnresolvedDomain::Set(SetAttr::default(), in1.union(in2)?))
397
            }
398
            (UnresolvedDomain::Set(_, _), _) | (_, UnresolvedDomain::Set(_, _)) => {
399
                Err(DomainOpError::WrongType)
400
            }
401
            (UnresolvedDomain::MSet(_, in1), UnresolvedDomain::MSet(_, in2)) => {
402
                Ok(UnresolvedDomain::MSet(MSetAttr::default(), in1.union(in2)?))
403
            }
404
            (UnresolvedDomain::MSet(_, _), _) | (_, UnresolvedDomain::MSet(_, _)) => {
405
                Err(DomainOpError::WrongType)
406
            }
407
            (UnresolvedDomain::Matrix(in1, idx1), UnresolvedDomain::Matrix(in2, idx2))
408
                if idx1 == idx2 =>
409
            {
410
                Ok(UnresolvedDomain::Matrix(in1.union(in2)?, idx1.clone()))
411
            }
412
            (UnresolvedDomain::Matrix(_, _), _) | (_, UnresolvedDomain::Matrix(_, _)) => {
413
                Err(DomainOpError::WrongType)
414
            }
415
            (UnresolvedDomain::Tuple(lhs), UnresolvedDomain::Tuple(rhs))
416
                if lhs.len() == rhs.len() =>
417
            {
418
                let mut merged = Vec::new();
419
                for (l, r) in zip(lhs, rhs) {
420
                    merged.push(l.union(r)?)
421
                }
422
                Ok(UnresolvedDomain::Tuple(merged))
423
            }
424
            (UnresolvedDomain::Tuple(_), _) | (_, UnresolvedDomain::Tuple(_)) => {
425
                Err(DomainOpError::WrongType)
426
            }
427
            // TODO: Could we support unions of reference domains symbolically?
428
            (UnresolvedDomain::Reference(_), _) | (_, UnresolvedDomain::Reference(_)) => {
429
                Err(DomainOpError::NotGround)
430
            }
431
            // TODO: Could we define semantics for merging record domains?
432
            #[allow(unreachable_patterns)] // Technically redundant but logically makes sense
433
            (UnresolvedDomain::Record(_), _) | (_, UnresolvedDomain::Record(_)) => {
434
                Err(DomainOpError::WrongType)
435
            }
436
            #[allow(unreachable_patterns)]
437
            // Technically redundant but logically clearer to have both
438
            (UnresolvedDomain::Function(_, _, _), _) | (_, UnresolvedDomain::Function(_, _, _)) => {
439
                Err(DomainOpError::WrongType)
440
            }
441
        }
442
    }
443

            
444
    pub fn element_domain(&self) -> Option<DomainPtr> {
445
        match self {
446
            UnresolvedDomain::Set(_, inner_dom) => Some(inner_dom.clone()),
447
            UnresolvedDomain::Matrix(_, _) => {
448
                todo!("Unwrap one dimension of the domain")
449
            }
450
            _ => None,
451
        }
452
    }
453
}
454

            
455
impl Typeable for UnresolvedDomain {
456
10480
    fn return_type(&self) -> ReturnType {
457
10480
        match self {
458
4460
            UnresolvedDomain::Reference(re) => re.return_type(),
459
1680
            UnresolvedDomain::Int(_) => ReturnType::Int,
460
            UnresolvedDomain::Set(_attr, inner) => ReturnType::Set(Box::new(inner.return_type())),
461
            UnresolvedDomain::MSet(_attr, inner) => ReturnType::MSet(Box::new(inner.return_type())),
462
4340
            UnresolvedDomain::Matrix(inner, _idx) => {
463
4340
                ReturnType::Matrix(Box::new(inner.return_type()))
464
            }
465
            UnresolvedDomain::Tuple(inners) => {
466
                let mut inner_types = Vec::new();
467
                for inner in inners {
468
                    inner_types.push(inner.return_type());
469
                }
470
                ReturnType::Tuple(inner_types)
471
            }
472
            UnresolvedDomain::Record(entries) => {
473
                let mut entry_types = Vec::new();
474
                for entry in entries {
475
                    entry_types.push(entry.domain.return_type());
476
                }
477
                ReturnType::Record(entry_types)
478
            }
479
            UnresolvedDomain::Function(_, dom, cdom) => {
480
                ReturnType::Function(Box::new(dom.return_type()), Box::new(cdom.return_type()))
481
            }
482
        }
483
10480
    }
484
}
485

            
486
impl Display for UnresolvedDomain {
487
5080
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
488
5080
        match &self {
489
2640
            UnresolvedDomain::Reference(re) => write!(f, "{re}"),
490
2040
            UnresolvedDomain::Int(ranges) => {
491
2040
                if ranges.iter().all(Range::is_lower_or_upper_bounded) {
492
2040
                    let rngs: String = ranges.iter().map(|r| format!("{r}")).join(", ");
493
2040
                    write!(f, "int({})", rngs)
494
                } else {
495
                    write!(f, "int")
496
                }
497
            }
498
            UnresolvedDomain::Set(attrs, inner_dom) => write!(f, "set {attrs} of {inner_dom}"),
499
            UnresolvedDomain::MSet(attrs, inner_dom) => write!(f, "mset {attrs} of {inner_dom}"),
500
400
            UnresolvedDomain::Matrix(value_domain, index_domains) => {
501
400
                write!(
502
400
                    f,
503
                    "matrix indexed by [{}] of {value_domain}",
504
400
                    pretty_vec(&index_domains.iter().collect_vec())
505
                )
506
            }
507
            UnresolvedDomain::Tuple(domains) => {
508
                write!(
509
                    f,
510
                    "tuple of ({})",
511
                    pretty_vec(&domains.iter().collect_vec())
512
                )
513
            }
514
            UnresolvedDomain::Record(entries) => {
515
                write!(
516
                    f,
517
                    "record of ({})",
518
                    pretty_vec(
519
                        &entries
520
                            .iter()
521
                            .map(|entry| format!("{}: {}", entry.name, entry.domain))
522
                            .collect_vec()
523
                    )
524
                )
525
            }
526
            UnresolvedDomain::Function(attribute, domain, codomain) => {
527
                write!(f, "function {} {} --> {} ", attribute, domain, codomain)
528
            }
529
        }
530
5080
    }
531
}