1
//! Utility functions for working with matrices.
2

            
3
// TODO: Georgii's essence macro would look really nice in these examples!
4

            
5
use std::collections::VecDeque;
6

            
7
use itertools::{Itertools, izip};
8
use uniplate::Uniplate as _;
9

            
10
use crate::ast::{
11
    DomainOpError, Expression as Expr, GroundDomain, Metadata, Moo, Range, domains::Int,
12
};
13

            
14
use super::{AbstractLiteral, Literal};
15

            
16
/// For some index domains, returns a list containing each of the possible indices.
17
///
18
/// Indices are traversed in row-major ordering.
19
///
20
/// This is an O(n^dim) operation, where dim is the number of dimensions in the matrix.
21
///
22
/// # Panics
23
///
24
/// + If any of the index domains are not finite or enumerable with [`Domain::values`].
25
///
26
/// # Example
27
///
28
/// ```
29
/// use std::collections::HashSet;
30
/// use conjure_cp_core::ast::{GroundDomain,Moo,Range,Literal,matrix};
31
/// let index_domains = vec![Moo::new(GroundDomain::Bool),Moo::new(GroundDomain::Int(vec![Range::Bounded(1,2)]))];
32
///
33
/// let expected_indices = HashSet::from([
34
///   vec![Literal::Bool(false),Literal::Int(1)],
35
///   vec![Literal::Bool(false),Literal::Int(2)],
36
///   vec![Literal::Bool(true),Literal::Int(1)],
37
///   vec![Literal::Bool(true),Literal::Int(2)]
38
///   ]);
39
///
40
/// let actual_indices: HashSet<_> = matrix::enumerate_indices(index_domains).collect();
41
///
42
/// assert_eq!(actual_indices, expected_indices);
43
/// ```
44
19559
pub fn enumerate_indices(
45
19559
    index_domains: Vec<Moo<GroundDomain>>,
46
19559
) -> impl Iterator<Item = Vec<Literal>> {
47
26342
    index_domains
48
26342
        .into_iter()
49
33502
        .map(|x| {
50
33502
            x.values()
51
33502
                .expect("index domain should be enumerable with .values()")
52
26719
                .collect_vec()
53
26719
        })
54
10040
        .multi_cartesian_product()
55
10040
}
56

            
57
/// Returns the number of possible elements indexable by the given index domains.
58
///
59
/// In short, returns the product of the sizes of the given indices.
60
390
pub fn num_elements(index_domains: &[Moo<GroundDomain>]) -> Result<Int, DomainOpError> {
61
580
    let idx_dom_lengths = index_domains
62
390
        .iter()
63
590
        .map(|d| d.length())
64
390
        .collect::<Result<Vec<_>, _>>()?;
65
200
    Ok(idx_dom_lengths.iter().product())
66
200
}
67

            
68
/// Flattens a multi-dimensional matrix literal into a one-dimensional slice of its elements.
69
///
70
/// The elements of the matrix are returned in row-major ordering (see [`enumerate_indices`]).
71
///
72
/// # Panics
73
///
74
/// + If the number or type of elements in each dimension is inconsistent.
75
///
76
/// + If `matrix` is not a matrix.
77
140
pub fn flatten(matrix: AbstractLiteral<Literal>) -> impl Iterator<Item = Literal> {
78
140
    let AbstractLiteral::Matrix(elems, _) = matrix else {
79
        panic!("matrix should be a matrix");
80
133
    };
81
133

            
82
140
    flatten_1(elems)
83
520
}
84
1102

            
85
533
fn flatten_1(elems: Vec<Literal>) -> impl Iterator<Item = Literal> {
86
1293
    elems.into_iter().flat_map(|elem| {
87
140
        if let Literal::AbstractLiteral(m @ AbstractLiteral::Matrix(_, _)) = elem {
88
1109
            Box::new(flatten(m)) as Box<dyn Iterator<Item = Literal>>
89
        } else {
90
2122
            Box::new(std::iter::once(elem)) as Box<dyn Iterator<Item = Literal>>
91
380
        }
92
1160
    })
93
400
}
94
/// Flattens a multi-dimensional matrix literal into an iterator over (indices,element).
95
///
96
/// # Panics
97
///
98
///   + If the number or type of elements in each dimension is inconsistent.
99
///
100
///   + If `matrix` is not a matrix.
101
///
102
///   + If any dimensions in the matrix are not finite or enumerable with [`Domain::values`].
103
///     However, index domains in the form `int(i..)` are supported.
104
507
pub fn flatten_enumerate(
105
507
    matrix: AbstractLiteral<Literal>,
106
260
) -> impl Iterator<Item = (Vec<Literal>, Literal)> {
107
260
    let AbstractLiteral::Matrix(elems, _) = matrix.clone() else {
108
        panic!("matrix should be a matrix");
109
247
    };
110
247

            
111
564
    let index_domains = index_domains(matrix)
112
260
        .into_iter()
113
624
        .map(|mut x| match Moo::make_mut(&mut x) {
114
            // give unboundedr index domains an end
115
320
            GroundDomain::Int(ranges) if ranges.len() == 1 && !elems.is_empty() => {
116
546
                if let Range::UnboundedR(start) = ranges[0] {
117
266
                    ranges[0] = Range::Bounded(start, start + (elems.len() as i32 - 1));
118
280
                };
119
318
                x
120
304
            }
121
287
            _ => x,
122
320
        })
123
507
        .collect_vec();
124
247

            
125
260
    izip!(enumerate_indices(index_domains), flatten_1(elems))
126
260
}
127

            
128
/// Gets the index domains for a matrix literal.
129
///
130
/// # Panics
131
///
132
/// + If `matrix` is not a matrix.
133
///
134
/// + If the number or type of elements in each dimension is inconsistent.
135
280
pub fn index_domains(matrix: AbstractLiteral<Literal>) -> Vec<Moo<GroundDomain>> {
136
280
    let AbstractLiteral::Matrix(_, _) = matrix else {
137
        panic!("matrix should be a matrix");
138
266
    };
139
494

            
140
774
    matrix.cata(&move |element: AbstractLiteral<Literal>,
141
1014
                       child_index_domains: VecDeque<Vec<Moo<GroundDomain>>>| {
142
520
        assert!(
143
520
            child_index_domains.iter().all_equal(),
144
            "each child of a matrix should have the same index domain"
145
494
        );
146
494

            
147
1014
        let child_index_domains = child_index_domains
148
1014
            .front()
149
1014
            .unwrap_or(&vec![])
150
1014
            .iter()
151
1014
            .cloned()
152
520
            .collect_vec();
153
520
        match element {
154
494
            AbstractLiteral::Set(_) => vec![],
155
494
            AbstractLiteral::MSet(_) => vec![],
156
1014
            AbstractLiteral::Matrix(_, domain) => {
157
1014
                let mut index_domains = vec![domain];
158
520
                index_domains.extend(child_index_domains);
159
520
                index_domains
160
            }
161
            AbstractLiteral::Tuple(_) => vec![],
162
            AbstractLiteral::Record(_) => vec![],
163
494
            AbstractLiteral::Function(_) => vec![],
164
266
        }
165
520
    })
166
280
}
167

            
168
/// See [`enumerate_indices`]. This function zips the two given lists of index domains, performs a
169
/// union on each pair, and returns an enumerating iterator over the new list of domains.
170
195
pub fn enumerate_index_union_indices(
171
195
    a_domains: &[Moo<GroundDomain>],
172
195
    b_domains: &[Moo<GroundDomain>],
173
100
) -> Result<impl Iterator<Item = Vec<Literal>>, DomainOpError> {
174
195
    if a_domains.len() != b_domains.len() {
175
95
        return Err(DomainOpError::WrongType);
176
195
    }
177
195
    let idx_domains: Result<Vec<_>, _> = a_domains
178
214
        .iter()
179
195
        .zip(b_domains.iter())
180
215
        .map(|(a, b)| a.union(b))
181
100
        .collect();
182
195
    let idx_domains = idx_domains?.into_iter().map(Moo::new).collect();
183
95

            
184
100
    Ok(enumerate_indices(idx_domains))
185
100
}
186
152

            
187
// Given index domains for a multi-dimensional matrix and the nth index in the flattened matrix, find the coordinates in the original matrix
188
312
pub fn flat_index_to_full_index(index_domains: &[Moo<GroundDomain>], index: Int) -> Vec<Literal> {
189
160
    let mut remaining = index;
190
312
    let mut multipliers = vec![1; index_domains.len()];
191
152

            
192
312
    for i in (1..index_domains.len()).rev() {
193
160
        multipliers[i - 1] = multipliers[i] * index_domains[i].as_ref().length().unwrap();
194
312
    }
195
304

            
196
464
    let mut coords = Vec::new();
197
624
    for m in multipliers.iter() {
198
        // adjust for 1-based indexing
199
624
        coords.push((remaining / m + 1).into());
200
320
        remaining %= *m;
201
472
    }
202
152

            
203
160
    coords
204
160
}
205

            
206
/// This is the same as `m[x]` except when `m` is of the forms:
207
///
208
/// - `n[..]`, then it produces n[x] instead of n[..][x]
209
/// - `flatten(n)`, then it produces `n[y]` instead of `flatten(n)[y]`,
210
///   where `y` is the full index corresponding to flat index `x`
211
///
212
/// # Returns
213
/// + `Some(expr)` if the safe indexing could be constructed
214
/// + `None` if it could not be constructed (e.g. invalid index type)
215
1416
pub fn safe_index_optimised(m: Expr, idx: Literal) -> Option<Expr> {
216
    match m {
217
480
        Expr::SafeSlice(_, mat, idxs) => {
218
            // TODO: support >1 slice index (i.e. multidimensional slices)
219
684

            
220
936
            let mut idxs = idxs;
221
720
            let (slice_idx, _) = idxs.iter().find_position(|opt| opt.is_none())?;
222
936
            let _ = idxs[slice_idx].replace(idx.into());
223

            
224
480
            let Some(idxs) = idxs.into_iter().collect::<Option<Vec<_>>>() else {
225
                todo!("slice expression should not contain more than one unspecified index")
226
456
            };
227

            
228
480
            Some(Expr::SafeIndex(Metadata::new(), mat, idxs))
229
        }
230
        Expr::Flatten(_, None, inner) => {
231
            // Similar to indexed_flatten_matrix rule, but we don't care about out of bounds here
232
            let Literal::Int(index) = idx else {
233
                return None;
234
            };
235

            
236
            let dom = inner.domain_of().and_then(|dom| dom.resolve())?;
237
            let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
238
                return None;
239
            };
240
            let flat_index = flat_index_to_full_index(index_domains, index - 1);
241
            let flat_index: Vec<Expr> = flat_index.into_iter().map(Into::into).collect();
242

            
243
456
            Some(Expr::SafeIndex(Metadata::new(), inner, flat_index))
244
456
        }
245
936
        _ => Some(Expr::SafeIndex(
246
936
            Metadata::new(),
247
936
            Moo::new(m),
248
480
            vec![idx.into()],
249
1392
        )),
250
    }
251
960
}