1
//! Utility functions for working with matrices.
2

            
3
// TODO: Georgii's essence macro would look really nice in these examples!
4

            
5
use std::collections::VecDeque;
6

            
7
use itertools::{Itertools, izip};
8
use uniplate::Uniplate as _;
9

            
10
use crate::ast::{DomainOpError, Expression as Expr, GroundDomain, Metadata, Moo, Range};
11

            
12
use super::{AbstractLiteral, Literal};
13

            
14
/// For some index domains, returns a list containing each of the possible indices.
15
///
16
/// Indices are traversed in row-major ordering.
17
///
18
/// This is an O(n^dim) operation, where dim is the number of dimensions in the matrix.
19
///
20
/// # Panics
21
///
22
/// + If any of the index domains are not finite or enumerable with [`Domain::values`].
23
///
24
/// # Example
25
///
26
/// ```
27
/// use std::collections::HashSet;
28
/// use conjure_cp_core::ast::{GroundDomain,Moo,Range,Literal,matrix};
29
/// let index_domains = vec![Moo::new(GroundDomain::Bool),Moo::new(GroundDomain::Int(vec![Range::Bounded(1,2)]))];
30
///
31
/// let expected_indices = HashSet::from([
32
///   vec![Literal::Bool(false),Literal::Int(1)],
33
///   vec![Literal::Bool(false),Literal::Int(2)],
34
///   vec![Literal::Bool(true),Literal::Int(1)],
35
///   vec![Literal::Bool(true),Literal::Int(2)]
36
///   ]);
37
///
38
/// let actual_indices: HashSet<_> = matrix::enumerate_indices(index_domains).collect();
39
///
40
/// assert_eq!(actual_indices, expected_indices);
41
/// ```
42
34766
pub fn try_enumerate_indices(
43
34766
    index_domains: Vec<Moo<GroundDomain>>,
44
34766
) -> Result<impl Iterator<Item = Vec<Literal>>, DomainOpError> {
45
34766
    let domains = index_domains
46
34766
        .into_iter()
47
49408
        .map(|x| x.values().map(|values| values.collect_vec()))
48
34766
        .collect::<Result<Vec<_>, _>>()?;
49
34766
    Ok(domains.into_iter().multi_cartesian_product())
50
34766
}
51

            
52
/// For some index domains, returns a list containing each of the possible indices.
53
///
54
/// See [`try_enumerate_indices`] for the fallible variant.
55
34246
pub fn enumerate_indices(
56
34246
    index_domains: Vec<Moo<GroundDomain>>,
57
34246
) -> impl Iterator<Item = Vec<Literal>> {
58
34246
    try_enumerate_indices(index_domains).expect("index domain should be enumerable with .values()")
59
34246
}
60

            
61
/// Returns the number of possible elements indexable by the given index domains.
62
///
63
/// In short, returns the product of the sizes of the given indices.
64
800
pub fn num_elements(index_domains: &[Moo<GroundDomain>]) -> Result<u64, DomainOpError> {
65
800
    let idx_dom_lengths = index_domains
66
800
        .iter()
67
1600
        .map(|d| d.length())
68
800
        .collect::<Result<Vec<_>, _>>()?;
69
800
    Ok(idx_dom_lengths.iter().product())
70
800
}
71

            
72
/// Flattens a multi-dimensional matrix literal into a one-dimensional slice of its elements.
73
///
74
/// The elements of the matrix are returned in row-major ordering (see [`enumerate_indices`]).
75
///
76
/// # Panics
77
///
78
/// + If the number or type of elements in each dimension is inconsistent.
79
///
80
/// + If `matrix` is not a matrix.
81
11520
pub fn flatten(matrix: AbstractLiteral<Literal>) -> impl Iterator<Item = Literal> {
82
11520
    let AbstractLiteral::Matrix(elems, _) = matrix else {
83
        panic!("matrix should be a matrix");
84
    };
85

            
86
11520
    flatten_1(elems)
87
11520
}
88

            
89
28140
fn flatten_1(elems: Vec<Literal>) -> impl Iterator<Item = Literal> {
90
56094
    elems.into_iter().flat_map(|elem| {
91
11520
        if let Literal::AbstractLiteral(m @ AbstractLiteral::Matrix(_, _)) = elem {
92
11520
            Box::new(flatten(m)) as Box<dyn Iterator<Item = Literal>>
93
        } else {
94
44574
            Box::new(std::iter::once(elem)) as Box<dyn Iterator<Item = Literal>>
95
        }
96
56094
    })
97
28140
}
98
/// Flattens a multi-dimensional matrix literal into an iterator over (indices,element).
99
///
100
/// # Panics
101
///
102
///   + If the number or type of elements in each dimension is inconsistent.
103
///
104
///   + If `matrix` is not a matrix.
105
///
106
///   + If any dimensions in the matrix are not finite or enumerable with [`Domain::values`].
107
///     However, index domains in the form `int(i..)` are supported.
108
16620
pub fn flatten_enumerate(
109
16620
    matrix: AbstractLiteral<Literal>,
110
16620
) -> impl Iterator<Item = (Vec<Literal>, Literal)> {
111
16620
    let AbstractLiteral::Matrix(elems, _) = matrix.clone() else {
112
        panic!("matrix should be a matrix");
113
    };
114

            
115
16620
    let index_domains = index_domains(matrix);
116

            
117
16620
    izip!(enumerate_indices(index_domains), flatten_1(elems))
118
16620
}
119

            
120
/// Gets the index domains for a matrix literal.
121
///
122
/// # Panics
123
///
124
/// + If `matrix` is not a matrix.
125
///
126
/// + If the number or type of elements in each dimension is inconsistent.
127
16700
pub fn index_domains(matrix: AbstractLiteral<Literal>) -> Vec<Moo<GroundDomain>> {
128
16700
    let AbstractLiteral::Matrix(_, _) = matrix else {
129
        panic!("matrix should be a matrix");
130
    };
131

            
132
16700
    matrix.cata(&move |element: AbstractLiteral<Literal>,
133
34540
                       child_index_domains: VecDeque<Vec<Moo<GroundDomain>>>| {
134
34540
        assert!(
135
34540
            child_index_domains.iter().all_equal(),
136
            "each child of a matrix should have the same index domain"
137
        );
138

            
139
34540
        let child_index_domains = child_index_domains
140
34540
            .front()
141
34540
            .unwrap_or(&vec![])
142
34540
            .iter()
143
34540
            .cloned()
144
34540
            .collect_vec();
145
34540
        match element {
146
            AbstractLiteral::Set(_) => vec![],
147
            AbstractLiteral::MSet(_) => vec![],
148
34540
            AbstractLiteral::Matrix(elems, domain) => {
149
34540
                let mut index_domains = vec![bound_index_domain_from_length(domain, elems.len())];
150
34540
                index_domains.extend(child_index_domains);
151
34540
                index_domains
152
            }
153
            AbstractLiteral::Tuple(_) => vec![],
154
            AbstractLiteral::Record(_) => vec![],
155
            AbstractLiteral::Function(_) => vec![],
156
            AbstractLiteral::Variant(_) => vec![],
157
            AbstractLiteral::Relation(_) => vec![],
158
            AbstractLiteral::Sequence(_) => vec![],
159
            AbstractLiteral::Partition(_) => vec![],
160
        }
161
34540
    })
162
16700
}
163

            
164
/// See [`enumerate_indices`]. This function zips the two given lists of index domains, performs a
165
/// union on each pair, and returns an enumerating iterator over the new list of domains.
166
400
pub fn enumerate_index_union_indices(
167
400
    a_domains: &[Moo<GroundDomain>],
168
400
    b_domains: &[Moo<GroundDomain>],
169
400
) -> Result<impl Iterator<Item = Vec<Literal>>, DomainOpError> {
170
400
    if a_domains.len() != b_domains.len() {
171
        return Err(DomainOpError::WrongType);
172
400
    }
173
400
    let idx_domains: Result<Vec<_>, _> = a_domains
174
400
        .iter()
175
400
        .zip(b_domains.iter())
176
480
        .map(|(a, b)| a.union(b))
177
400
        .collect();
178
400
    let idx_domains = idx_domains?.into_iter().map(Moo::new).collect();
179

            
180
400
    try_enumerate_indices(idx_domains)
181
400
}
182

            
183
// Given index domains for a multi-dimensional matrix and the nth index in the flattened matrix, find the coordinates in the original matrix
184
640
pub fn flat_index_to_full_index(index_domains: &[Moo<GroundDomain>], index: u64) -> Vec<Literal> {
185
640
    let mut remaining = index;
186
640
    let mut multipliers = vec![1; index_domains.len()];
187

            
188
640
    for i in (1..index_domains.len()).rev() {
189
640
        multipliers[i - 1] = multipliers[i] * index_domains[i].as_ref().length().unwrap();
190
640
    }
191

            
192
640
    let mut coords = Vec::new();
193
1280
    for m in multipliers.iter() {
194
        // adjust for 1-based indexing
195
1280
        coords.push(((remaining / m + 1) as i32).into());
196
1280
        remaining %= *m;
197
1280
    }
198

            
199
640
    coords
200
640
}
201

            
202
/// Gets concrete index domains for a matrix expression.
203
///
204
/// For matrix literals, right-unbounded integer index domains like `int(1..)` are bounded using
205
/// the literal's realised size in that dimension. For non-literals, this falls back to the
206
/// expression's resolved domain.
207
31626
pub fn bound_index_domains_of_expr(expr: &Expr) -> Option<Vec<Moo<GroundDomain>>> {
208
31626
    let dom = expr.domain_of().and_then(|dom| dom.resolve())?;
209
28346
    let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
210
26138
        return None;
211
    };
212

            
213
2208
    let Some(dimension_lengths) = expr_matrix_dimension_lengths(expr) else {
214
2208
        return Some(index_domains.clone());
215
    };
216

            
217
    assert_eq!(
218
        index_domains.len(),
219
        dimension_lengths.len(),
220
        "matrix literal domain rank should match its realised rank"
221
    );
222

            
223
    Some(
224
        index_domains
225
            .iter()
226
            .cloned()
227
            .zip(dimension_lengths)
228
            .map(|(domain, len)| bound_index_domain_from_length(domain, len))
229
            .collect(),
230
    )
231
31626
}
232

            
233
/// This is the same as `m[x]` except when `m` is of the forms:
234
///
235
/// - `n[..]`, then it produces n[x] instead of n[..][x]
236
/// - `flatten(n)`, then it produces `n[y]` instead of `flatten(n)[y]`,
237
///   where `y` is the full index corresponding to flat index `x`
238
///
239
/// # Returns
240
/// + `Some(expr)` if the safe indexing could be constructed
241
/// + `None` if it could not be constructed (e.g. invalid index type)
242
3360
pub fn safe_index_optimised(m: Expr, idx: Literal) -> Option<Expr> {
243
    match m {
244
1440
        Expr::SafeSlice(_, mat, idxs) => {
245
            // TODO: support >1 slice index (i.e. multidimensional slices)
246

            
247
1440
            let mut idxs = idxs;
248
2160
            let (slice_idx, _) = idxs.iter().find_position(|opt| opt.is_none())?;
249
1440
            let _ = idxs[slice_idx].replace(idx.into());
250

            
251
1440
            let Some(idxs) = idxs.into_iter().collect::<Option<Vec<_>>>() else {
252
                todo!("slice expression should not contain more than one unspecified index")
253
            };
254

            
255
1440
            Some(Expr::SafeIndex(Metadata::new(), mat, idxs))
256
        }
257
        Expr::Flatten(_, None, inner) => {
258
            // Similar to indexed_flatten_matrix rule, but we don't care about out of bounds here
259
            let Literal::Int(index) = idx else {
260
                return None;
261
            };
262

            
263
            let index_domains = bound_index_domains_of_expr(inner.as_ref())?;
264
            if index_domains.iter().any(|domain| domain.length().is_err()) {
265
                return None;
266
            }
267
            let flat_index = flat_index_to_full_index(&index_domains, (index - 1) as u64);
268
            let flat_index: Vec<Expr> = flat_index.into_iter().map(Into::into).collect();
269

            
270
            Some(Expr::SafeIndex(Metadata::new(), inner, flat_index))
271
        }
272
1920
        _ => Some(Expr::SafeIndex(
273
1920
            Metadata::new(),
274
1920
            Moo::new(m),
275
1920
            vec![idx.into()],
276
1920
        )),
277
    }
278
3360
}
279

            
280
34540
fn bound_index_domain_from_length(mut domain: Moo<GroundDomain>, len: usize) -> Moo<GroundDomain> {
281
34540
    match Moo::make_mut(&mut domain) {
282
34540
        GroundDomain::Int(ranges) if ranges.len() == 1 && len > 0 => {
283
33820
            if let Range::UnboundedR(start) = ranges[0] {
284
                let end = start + (len as i32 - 1);
285
                ranges[0] = Range::Bounded(start, end);
286
33820
            }
287
33820
            domain
288
        }
289
720
        _ => domain,
290
    }
291
34540
}
292

            
293
2208
fn expr_matrix_dimension_lengths(expr: &Expr) -> Option<Vec<usize>> {
294
2208
    let (elems, _) = expr.clone().unwrap_matrix_unchecked()?;
295

            
296
    let child_dimensions = elems
297
        .iter()
298
        .map(|elem| expr_matrix_dimension_lengths(elem).unwrap_or_default())
299
        .collect_vec();
300

            
301
    assert!(
302
        child_dimensions.iter().all_equal(),
303
        "each child of a matrix should have the same shape"
304
    );
305

            
306
    let mut dimensions = vec![elems.len()];
307
    if let Some(child_dimensions) = child_dimensions.into_iter().next() {
308
        dimensions.extend(child_dimensions);
309
    }
310
    Some(dimensions)
311
2208
}