1
use crate::ast::domains::attrs::SetAttr;
2
use crate::ast::{
3
    DeclarationKind, DomainOpError, Expression, FuncAttr, Literal, Metadata, Moo,
4
    RecordEntryGround, Reference, Typeable,
5
    domains::{
6
        GroundDomain,
7
        domain::{DomainPtr, Int},
8
        range::Range,
9
    },
10
};
11
use crate::{bug, domain_int, matrix_expr, range};
12
use conjure_cp_core::ast::pretty::pretty_vec;
13
use conjure_cp_core::ast::{Name, ReturnType, eval_constant};
14
use itertools::Itertools;
15
use polyquine::Quine;
16
use serde::{Deserialize, Serialize};
17
use std::fmt::{Display, Formatter};
18
use std::iter::zip;
19
use std::ops::Deref;
20
use uniplate::Uniplate;
21

            
22
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Quine, Uniplate)]
23
#[path_prefix(conjure_cp::ast)]
24
#[biplate(to=Expression)]
25
#[biplate(to=Reference)]
26
pub enum IntVal {
27
    Const(Int),
28
    #[polyquine_skip]
29
    Reference(Reference),
30
    Expr(Moo<Expression>),
31
}
32

            
33
impl From<Int> for IntVal {
34
    fn from(value: Int) -> Self {
35
        Self::Const(value)
36
    }
37
}
38

            
39
impl TryInto<Int> for IntVal {
40
    type Error = DomainOpError;
41

            
42
    fn try_into(self) -> Result<Int, Self::Error> {
43
        match self {
44
            IntVal::Const(val) => Ok(val),
45
            _ => Err(DomainOpError::NotGround),
46
        }
47
    }
48
}
49

            
50
impl From<Range<Int>> for Range<IntVal> {
51
    fn from(value: Range<Int>) -> Self {
52
        match value {
53
            Range::Single(x) => Range::Single(x.into()),
54
            Range::Bounded(l, r) => Range::Bounded(l.into(), r.into()),
55
            Range::UnboundedL(r) => Range::UnboundedL(r.into()),
56
            Range::UnboundedR(l) => Range::UnboundedR(l.into()),
57
            Range::Unbounded => Range::Unbounded,
58
        }
59
    }
60
}
61

            
62
impl TryInto<Range<Int>> for Range<IntVal> {
63
    type Error = DomainOpError;
64

            
65
    fn try_into(self) -> Result<Range<Int>, Self::Error> {
66
        match self {
67
            Range::Single(x) => Ok(Range::Single(x.try_into()?)),
68
            Range::Bounded(l, r) => Ok(Range::Bounded(l.try_into()?, r.try_into()?)),
69
            Range::UnboundedL(r) => Ok(Range::UnboundedL(r.try_into()?)),
70
            Range::UnboundedR(l) => Ok(Range::UnboundedR(l.try_into()?)),
71
            Range::Unbounded => Ok(Range::Unbounded),
72
        }
73
    }
74
}
75

            
76
impl From<SetAttr<Int>> for SetAttr<IntVal> {
77
    fn from(value: SetAttr<Int>) -> Self {
78
        SetAttr {
79
            size: value.size.into(),
80
        }
81
    }
82
}
83

            
84
impl TryInto<SetAttr<Int>> for SetAttr<IntVal> {
85
    type Error = DomainOpError;
86

            
87
    fn try_into(self) -> Result<SetAttr<Int>, Self::Error> {
88
        let size: Range<Int> = self.size.try_into()?;
89
        Ok(SetAttr { size })
90
    }
91
}
92

            
93
impl From<FuncAttr<Int>> for FuncAttr<IntVal> {
94
    fn from(value: FuncAttr<Int>) -> Self {
95
        FuncAttr {
96
            size: value.size.into(),
97
            partiality: value.partiality,
98
            jectivity: value.jectivity,
99
        }
100
    }
101
}
102

            
103
impl TryInto<FuncAttr<Int>> for FuncAttr<IntVal> {
104
    type Error = DomainOpError;
105

            
106
    fn try_into(self) -> Result<FuncAttr<Int>, Self::Error> {
107
        let size: Range<Int> = self.size.try_into()?;
108
        Ok(FuncAttr {
109
            size,
110
            jectivity: self.jectivity,
111
            partiality: self.partiality,
112
        })
113
    }
114
}
115

            
116
impl Display for IntVal {
117
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
118
        match self {
119
            IntVal::Const(val) => write!(f, "{val}"),
120
            IntVal::Reference(re) => write!(f, "{re}"),
121
            IntVal::Expr(expr) => write!(f, "({expr})"),
122
        }
123
    }
124
}
125

            
126
impl IntVal {
127
    pub fn new_ref(re: &Reference) -> Option<IntVal> {
128
        match re.ptr.kind().deref() {
129
            DeclarationKind::ValueLetting(expr) => match expr.return_type() {
130
                ReturnType::Int => Some(IntVal::Reference(re.clone())),
131
                _ => None,
132
            },
133
            DeclarationKind::Given(dom) => match dom.return_type() {
134
                ReturnType::Int => Some(IntVal::Reference(re.clone())),
135
                _ => None,
136
            },
137
            DeclarationKind::DomainLetting(_)
138
            | DeclarationKind::RecordField(_)
139
            | DeclarationKind::DecisionVariable(_) => None,
140
        }
141
    }
142

            
143
    pub fn new_expr(value: Moo<Expression>) -> Option<IntVal> {
144
        if value.return_type() != ReturnType::Int {
145
            return None;
146
        }
147
        Some(IntVal::Expr(value))
148
    }
149

            
150
    pub fn resolve(&self) -> Option<Int> {
151
        match self {
152
            IntVal::Const(value) => Some(*value),
153
            IntVal::Expr(expr) => match eval_constant(expr)? {
154
                Literal::Int(v) => Some(v),
155
                _ => bug!("Expected integer expression, got: {expr}"),
156
            },
157
            IntVal::Reference(re) => match re.ptr.kind().deref() {
158
                DeclarationKind::ValueLetting(expr) => match eval_constant(expr)? {
159
                    Literal::Int(v) => Some(v),
160
                    _ => bug!("Expected integer expression, got: {expr}"),
161
                },
162
                // If this is an int given we will be able to resolve it eventually, but not yet
163
                DeclarationKind::Given(_) => None,
164
                DeclarationKind::DomainLetting(_)
165
                | DeclarationKind::RecordField(_)
166
                | DeclarationKind::DecisionVariable(_) => bug!(
167
                    "Expected integer expression, given, or letting inside int domain; Got: {re}"
168
                ),
169
            },
170
        }
171
    }
172
}
173

            
174
impl From<IntVal> for Expression {
175
    fn from(value: IntVal) -> Self {
176
        match value {
177
            IntVal::Const(val) => val.into(),
178
            IntVal::Reference(re) => re.into(),
179
            IntVal::Expr(expr) => expr.as_ref().clone(),
180
        }
181
    }
182
}
183

            
184
impl From<IntVal> for Moo<Expression> {
185
    fn from(value: IntVal) -> Self {
186
        match value {
187
            IntVal::Const(val) => Moo::new(val.into()),
188
            IntVal::Reference(re) => Moo::new(re.into()),
189
            IntVal::Expr(expr) => expr,
190
        }
191
    }
192
}
193

            
194
impl std::ops::Neg for IntVal {
195
    type Output = IntVal;
196

            
197
    fn neg(self) -> Self::Output {
198
        match self {
199
            IntVal::Const(val) => IntVal::Const(-val),
200
            IntVal::Reference(_) | IntVal::Expr(_) => {
201
                IntVal::Expr(Moo::new(Expression::Neg(Metadata::new(), self.into())))
202
            }
203
        }
204
    }
205
}
206

            
207
impl<T> std::ops::Add<T> for IntVal
208
where
209
    T: Into<Expression>,
210
{
211
    type Output = IntVal;
212

            
213
    fn add(self, rhs: T) -> Self::Output {
214
        let lhs: Expression = self.into();
215
        let rhs: Expression = rhs.into();
216
        let sum = matrix_expr!(lhs, rhs; domain_int!(1..));
217
        IntVal::Expr(Moo::new(Expression::Sum(Metadata::new(), Moo::new(sum))))
218
    }
219
}
220

            
221
impl Range<IntVal> {
222
    pub fn resolve(&self) -> Option<Range<Int>> {
223
        match self {
224
            Range::Single(x) => Some(Range::Single(x.resolve()?)),
225
            Range::Bounded(l, r) => Some(Range::Bounded(l.resolve()?, r.resolve()?)),
226
            Range::UnboundedL(r) => Some(Range::UnboundedL(r.resolve()?)),
227
            Range::UnboundedR(l) => Some(Range::UnboundedR(l.resolve()?)),
228
            Range::Unbounded => Some(Range::Unbounded),
229
        }
230
    }
231
}
232

            
233
impl SetAttr<IntVal> {
234
    pub fn resolve(&self) -> Option<SetAttr<Int>> {
235
        Some(SetAttr {
236
            size: self.size.resolve()?,
237
        })
238
    }
239
}
240

            
241
impl FuncAttr<IntVal> {
242
    pub fn resolve(&self) -> Option<FuncAttr<Int>> {
243
        Some(FuncAttr {
244
            size: self.size.resolve()?,
245
            partiality: self.partiality.clone(),
246
            jectivity: self.jectivity.clone(),
247
        })
248
    }
249
}
250

            
251
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Uniplate, Quine)]
252
#[path_prefix(conjure_cp::ast)]
253
pub struct RecordEntry {
254
    pub name: Name,
255
    pub domain: DomainPtr,
256
}
257

            
258
impl RecordEntry {
259
    pub fn resolve(self) -> Option<RecordEntryGround> {
260
        Some(RecordEntryGround {
261
            name: self.name,
262
            domain: self.domain.resolve()?,
263
        })
264
    }
265
}
266

            
267
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Quine, Uniplate)]
268
#[path_prefix(conjure_cp::ast)]
269
#[biplate(to=Expression)]
270
#[biplate(to=Reference)]
271
#[biplate(to=IntVal)]
272
#[biplate(to=DomainPtr)]
273
#[biplate(to=RecordEntry)]
274
pub enum UnresolvedDomain {
275
    Int(Vec<Range<IntVal>>),
276
    /// A set of elements drawn from the inner domain
277
    Set(SetAttr<IntVal>, DomainPtr),
278
    /// A n-dimensional matrix with a value domain and n-index domains
279
    Matrix(DomainPtr, Vec<DomainPtr>),
280
    /// A tuple of N elements, each with its own domain
281
    Tuple(Vec<DomainPtr>),
282
    /// A reference to a domain letting
283
    #[polyquine_skip]
284
    Reference(Reference),
285
    /// A record
286
    Record(Vec<RecordEntry>),
287
    /// A function with attributes, domain, and range
288
    Function(FuncAttr<IntVal>, DomainPtr, DomainPtr),
289
}
290

            
291
impl UnresolvedDomain {
292
    pub fn resolve(&self) -> Option<GroundDomain> {
293
        match self {
294
            UnresolvedDomain::Int(rngs) => rngs
295
                .iter()
296
                .map(Range::<IntVal>::resolve)
297
                .collect::<Option<_>>()
298
                .map(GroundDomain::Int),
299
            UnresolvedDomain::Set(attr, inner) => {
300
                Some(GroundDomain::Set(attr.resolve()?, inner.resolve()?))
301
            }
302
            UnresolvedDomain::Matrix(inner, idx_doms) => {
303
                let inner_gd = inner.resolve()?;
304
                idx_doms
305
                    .iter()
306
                    .map(DomainPtr::resolve)
307
                    .collect::<Option<_>>()
308
                    .map(|idx| GroundDomain::Matrix(inner_gd, idx))
309
            }
310
            UnresolvedDomain::Tuple(inners) => inners
311
                .iter()
312
                .map(DomainPtr::resolve)
313
                .collect::<Option<_>>()
314
                .map(GroundDomain::Tuple),
315
            UnresolvedDomain::Record(entries) => entries
316
                .iter()
317
                .map(|f| {
318
                    f.domain.resolve().map(|gd| RecordEntryGround {
319
                        name: f.name.clone(),
320
                        domain: gd,
321
                    })
322
                })
323
                .collect::<Option<_>>()
324
                .map(GroundDomain::Record),
325
            UnresolvedDomain::Reference(re) => re
326
                .ptr
327
                .as_domain_letting()
328
                .unwrap_or_else(|| {
329
                    bug!("Reference domain should point to domain letting, but got {re}")
330
                })
331
                .resolve()
332
                .map(Moo::unwrap_or_clone),
333
            UnresolvedDomain::Function(attr, dom, cdom) => {
334
                if let Some(attr_gd) = attr.resolve()
335
                    && let Some(dom_gd) = dom.resolve()
336
                    && let Some(cdom_gd) = cdom.resolve()
337
                {
338
                    return Some(GroundDomain::Function(attr_gd, dom_gd, cdom_gd));
339
                }
340
                None
341
            }
342
        }
343
    }
344

            
345
    pub(super) fn union_unresolved(
346
        &self,
347
        other: &UnresolvedDomain,
348
    ) -> Result<UnresolvedDomain, DomainOpError> {
349
        match (self, other) {
350
            (UnresolvedDomain::Int(lhs), UnresolvedDomain::Int(rhs)) => {
351
                let merged = lhs.iter().chain(rhs.iter()).cloned().collect_vec();
352
                Ok(UnresolvedDomain::Int(merged))
353
            }
354
            (UnresolvedDomain::Int(_), _) | (_, UnresolvedDomain::Int(_)) => {
355
                Err(DomainOpError::WrongType)
356
            }
357
            (UnresolvedDomain::Set(_, in1), UnresolvedDomain::Set(_, in2)) => {
358
                Ok(UnresolvedDomain::Set(SetAttr::default(), in1.union(in2)?))
359
            }
360
            (UnresolvedDomain::Set(_, _), _) | (_, UnresolvedDomain::Set(_, _)) => {
361
                Err(DomainOpError::WrongType)
362
            }
363
            (UnresolvedDomain::Matrix(in1, idx1), UnresolvedDomain::Matrix(in2, idx2))
364
                if idx1 == idx2 =>
365
            {
366
                Ok(UnresolvedDomain::Matrix(in1.union(in2)?, idx1.clone()))
367
            }
368
            (UnresolvedDomain::Matrix(_, _), _) | (_, UnresolvedDomain::Matrix(_, _)) => {
369
                Err(DomainOpError::WrongType)
370
            }
371
            (UnresolvedDomain::Tuple(lhs), UnresolvedDomain::Tuple(rhs))
372
                if lhs.len() == rhs.len() =>
373
            {
374
                let mut merged = Vec::new();
375
                for (l, r) in zip(lhs, rhs) {
376
                    merged.push(l.union(r)?)
377
                }
378
                Ok(UnresolvedDomain::Tuple(merged))
379
            }
380
            (UnresolvedDomain::Tuple(_), _) | (_, UnresolvedDomain::Tuple(_)) => {
381
                Err(DomainOpError::WrongType)
382
            }
383
            // TODO: Could we support unions of reference domains symbolically?
384
            (UnresolvedDomain::Reference(_), _) | (_, UnresolvedDomain::Reference(_)) => {
385
                Err(DomainOpError::NotGround)
386
            }
387
            // TODO: Could we define semantics for merging record domains?
388
            #[allow(unreachable_patterns)] // Technically redundant but logically makes sense
389
            (UnresolvedDomain::Record(_), _) | (_, UnresolvedDomain::Record(_)) => {
390
                Err(DomainOpError::WrongType)
391
            }
392
            #[allow(unreachable_patterns)]
393
            // Technically redundant but logically clearer to have both
394
            (UnresolvedDomain::Function(_, _, _), _) | (_, UnresolvedDomain::Function(_, _, _)) => {
395
                Err(DomainOpError::WrongType)
396
            }
397
        }
398
    }
399
}
400

            
401
impl Typeable for UnresolvedDomain {
402
    fn return_type(&self) -> ReturnType {
403
        match self {
404
            UnresolvedDomain::Reference(re) => re.return_type(),
405
            UnresolvedDomain::Int(_) => ReturnType::Int,
406
            UnresolvedDomain::Set(_attr, inner) => ReturnType::Set(Box::new(inner.return_type())),
407
            UnresolvedDomain::Matrix(inner, _idx) => {
408
                ReturnType::Matrix(Box::new(inner.return_type()))
409
            }
410
            UnresolvedDomain::Tuple(inners) => {
411
                let mut inner_types = Vec::new();
412
                for inner in inners {
413
                    inner_types.push(inner.return_type());
414
                }
415
                ReturnType::Tuple(inner_types)
416
            }
417
            UnresolvedDomain::Record(entries) => {
418
                let mut entry_types = Vec::new();
419
                for entry in entries {
420
                    entry_types.push(entry.domain.return_type());
421
                }
422
                ReturnType::Record(entry_types)
423
            }
424
            UnresolvedDomain::Function(_, dom, cdom) => {
425
                ReturnType::Function(Box::new(dom.return_type()), Box::new(cdom.return_type()))
426
            }
427
        }
428
    }
429
}
430

            
431
impl Display for UnresolvedDomain {
432
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
433
        match &self {
434
            UnresolvedDomain::Reference(re) => write!(f, "{re}"),
435
            UnresolvedDomain::Int(ranges) => {
436
                if ranges.iter().all(Range::is_lower_or_upper_bounded) {
437
                    let rngs: String = ranges.iter().map(|r| format!("{r}")).join(", ");
438
                    write!(f, "int({})", rngs)
439
                } else {
440
                    write!(f, "int")
441
                }
442
            }
443
            UnresolvedDomain::Set(attrs, inner_dom) => write!(f, "set {attrs} of {inner_dom}"),
444
            UnresolvedDomain::Matrix(value_domain, index_domains) => {
445
                write!(
446
                    f,
447
                    "matrix indexed by [{}] of {value_domain}",
448
                    pretty_vec(&index_domains.iter().collect_vec())
449
                )
450
            }
451
            UnresolvedDomain::Tuple(domains) => {
452
                write!(
453
                    f,
454
                    "tuple of ({})",
455
                    pretty_vec(&domains.iter().collect_vec())
456
                )
457
            }
458
            UnresolvedDomain::Record(entries) => {
459
                write!(
460
                    f,
461
                    "record of ({})",
462
                    pretty_vec(
463
                        &entries
464
                            .iter()
465
                            .map(|entry| format!("{}: {}", entry.name, entry.domain))
466
                            .collect_vec()
467
                    )
468
                )
469
            }
470
            UnresolvedDomain::Function(attribute, domain, codomain) => {
471
                write!(f, "function {} {} --> {} ", attribute, domain, codomain)
472
            }
473
        }
474
    }
475
}