1
//! Utility functions for working with matrices.
2

            
3
// TODO: Georgii's Essence macro would look really nice in these examples!
4

            
5
use std::collections::VecDeque;
6

            
7
use itertools::{Itertools, izip};
8
use uniplate::Uniplate as _;
9

            
10
use crate::ast::{DomainOpError, Expression as Expr, GroundDomain, Metadata, Moo, Range};
11

            
12
use super::{AbstractLiteral, Literal};
13

            
14
/// For some index domains, returns a list containing each of the possible indices.
15
///
16
/// Indices are traversed in row-major ordering.
17
///
18
/// This is an O(n^dim) operation, where dim is the number of dimensions in the matrix.
19
///
20
/// # Panics
21
///
22
/// + If any of the index domains are not finite or enumerable with [`Domain::values`].
23
///
24
/// # Example
25
///
26
/// ```
27
/// use std::collections::HashSet;
28
/// use conjure_cp_core::ast::{GroundDomain,Moo,Range,Literal,matrix};
29
/// let index_domains = vec![Moo::new(GroundDomain::Bool),Moo::new(GroundDomain::Int(vec![Range::Bounded(1,2)]))];
30
///
31
/// let expected_indices = HashSet::from([
32
///   vec![Literal::Bool(false),Literal::Int(1)],
33
///   vec![Literal::Bool(false),Literal::Int(2)],
34
///   vec![Literal::Bool(true),Literal::Int(1)],
35
///   vec![Literal::Bool(true),Literal::Int(2)]
36
///   ]);
37
///
38
/// let actual_indices: HashSet<_> = matrix::enumerate_indices(index_domains).collect();
39
///
40
/// assert_eq!(actual_indices, expected_indices);
41
/// ```
42
15454
pub fn enumerate_indices(
43
15454
    index_domains: Vec<Moo<GroundDomain>>,
44
15454
) -> impl Iterator<Item = Vec<Literal>> {
45
15454
    index_domains
46
15454
        .into_iter()
47
22724
        .map(|x| {
48
22724
            x.values()
49
22724
                .expect("index domain should be enumerable with .values()")
50
22724
                .collect_vec()
51
22724
        })
52
15454
        .multi_cartesian_product()
53
15454
}
54

            
55
/// Returns the number of possible elements indexable by the given index domains.
56
///
57
/// In short, returns the product of the sizes of the given indices.
58
580
pub fn num_elements(index_domains: &[Moo<GroundDomain>]) -> Result<u64, DomainOpError> {
59
580
    let idx_dom_lengths = index_domains
60
580
        .iter()
61
1160
        .map(|d| d.length())
62
580
        .collect::<Result<Vec<_>, _>>()?;
63
580
    Ok(idx_dom_lengths.iter().product())
64
580
}
65

            
66
/// Flattens a multi-dimensional matrix literal into a one-dimensional slice of its elements.
67
///
68
/// The elements of the matrix are returned in row-major ordering (see [`enumerate_indices`]).
69
///
70
/// # Panics
71
///
72
/// + If the number or type of elements in each dimension is inconsistent.
73
///
74
/// + If `matrix` is not a matrix.
75
406
pub fn flatten(matrix: AbstractLiteral<Literal>) -> impl Iterator<Item = Literal> {
76
406
    let AbstractLiteral::Matrix(elems, _) = matrix else {
77
        panic!("matrix should be a matrix");
78
    };
79

            
80
406
    flatten_1(elems)
81
406
}
82

            
83
1160
fn flatten_1(elems: Vec<Literal>) -> impl Iterator<Item = Literal> {
84
3364
    elems.into_iter().flat_map(|elem| {
85
406
        if let Literal::AbstractLiteral(m @ AbstractLiteral::Matrix(_, _)) = elem {
86
406
            Box::new(flatten(m)) as Box<dyn Iterator<Item = Literal>>
87
        } else {
88
2958
            Box::new(std::iter::once(elem)) as Box<dyn Iterator<Item = Literal>>
89
        }
90
3364
    })
91
1160
}
92
/// Flattens a multi-dimensional matrix literal into an iterator over (indices,element).
93
///
94
/// # Panics
95
///
96
///   + If the number or type of elements in each dimension is inconsistent.
97
///
98
///   + If `matrix` is not a matrix.
99
///
100
///   + If any dimensions in the matrix are not finite or enumerable with [`Domain::values`].
101
///     However, index domains in the form `int(i..)` are supported.
102
754
pub fn flatten_enumerate(
103
754
    matrix: AbstractLiteral<Literal>,
104
754
) -> impl Iterator<Item = (Vec<Literal>, Literal)> {
105
754
    let AbstractLiteral::Matrix(elems, _) = matrix.clone() else {
106
        panic!("matrix should be a matrix");
107
    };
108

            
109
754
    let index_domains = index_domains(matrix)
110
754
        .into_iter()
111
928
        .map(|mut x| match Moo::make_mut(&mut x) {
112
            // give unboundedr index domains an end
113
928
            GroundDomain::Int(ranges) if ranges.len() == 1 && !elems.is_empty() => {
114
812
                if let Range::UnboundedR(start) = ranges[0] {
115
                    ranges[0] = Range::Bounded(start, start + (elems.len() as i32 - 1));
116
812
                };
117
812
                x
118
            }
119
116
            _ => x,
120
928
        })
121
754
        .collect_vec();
122

            
123
754
    izip!(enumerate_indices(index_domains), flatten_1(elems))
124
754
}
125

            
126
/// Gets the index domains for a matrix literal.
127
///
128
/// # Panics
129
///
130
/// + If `matrix` is not a matrix.
131
///
132
/// + If the number or type of elements in each dimension is inconsistent.
133
812
pub fn index_domains(matrix: AbstractLiteral<Literal>) -> Vec<Moo<GroundDomain>> {
134
812
    let AbstractLiteral::Matrix(_, _) = matrix else {
135
        panic!("matrix should be a matrix");
136
    };
137

            
138
812
    matrix.cata(&move |element: AbstractLiteral<Literal>,
139
1508
                       child_index_domains: VecDeque<Vec<Moo<GroundDomain>>>| {
140
1508
        assert!(
141
1508
            child_index_domains.iter().all_equal(),
142
            "each child of a matrix should have the same index domain"
143
        );
144

            
145
1508
        let child_index_domains = child_index_domains
146
1508
            .front()
147
1508
            .unwrap_or(&vec![])
148
1508
            .iter()
149
1508
            .cloned()
150
1508
            .collect_vec();
151
1508
        match element {
152
            AbstractLiteral::Set(_) => vec![],
153
            AbstractLiteral::MSet(_) => vec![],
154
1508
            AbstractLiteral::Matrix(_, domain) => {
155
1508
                let mut index_domains = vec![domain];
156
1508
                index_domains.extend(child_index_domains);
157
1508
                index_domains
158
            }
159
            AbstractLiteral::Tuple(_) => vec![],
160
            AbstractLiteral::Record(_) => vec![],
161
            AbstractLiteral::Function(_) => vec![],
162
        }
163
1508
    })
164
812
}
165

            
166
/// See [`enumerate_indices`]. This function zips the two given lists of index domains, performs a
167
/// union on each pair, and returns an enumerating iterator over the new list of domains.
168
290
pub fn enumerate_index_union_indices(
169
290
    a_domains: &[Moo<GroundDomain>],
170
290
    b_domains: &[Moo<GroundDomain>],
171
290
) -> Result<impl Iterator<Item = Vec<Literal>>, DomainOpError> {
172
290
    if a_domains.len() != b_domains.len() {
173
        return Err(DomainOpError::WrongType);
174
290
    }
175
290
    let idx_domains: Result<Vec<_>, _> = a_domains
176
290
        .iter()
177
290
        .zip(b_domains.iter())
178
348
        .map(|(a, b)| a.union(b))
179
290
        .collect();
180
290
    let idx_domains = idx_domains?.into_iter().map(Moo::new).collect();
181

            
182
290
    Ok(enumerate_indices(idx_domains))
183
290
}
184

            
185
// Given index domains for a multi-dimensional matrix and the nth index in the flattened matrix, find the coordinates in the original matrix
186
464
pub fn flat_index_to_full_index(index_domains: &[Moo<GroundDomain>], index: u64) -> Vec<Literal> {
187
464
    let mut remaining = index;
188
464
    let mut multipliers = vec![1; index_domains.len()];
189

            
190
464
    for i in (1..index_domains.len()).rev() {
191
464
        multipliers[i - 1] = multipliers[i] * index_domains[i].as_ref().length().unwrap();
192
464
    }
193

            
194
464
    let mut coords = Vec::new();
195
928
    for m in multipliers.iter() {
196
        // adjust for 1-based indexing
197
928
        coords.push(((remaining / m + 1) as i32).into());
198
928
        remaining %= *m;
199
928
    }
200

            
201
464
    coords
202
464
}
203

            
204
/// This is the same as `m[x]` except when `m` is of the forms:
205
///
206
/// - `n[..]`, then it produces n[x] instead of n[..][x]
207
/// - `flatten(n)`, then it produces `n[y]` instead of `flatten(n)[y]`,
208
///   where `y` is the full index corresponding to flat index `x`
209
///
210
/// # Returns
211
/// + `Some(expr)` if the safe indexing could be constructed
212
/// + `None` if it could not be constructed (e.g. invalid index type)
213
2784
pub fn safe_index_optimised(m: Expr, idx: Literal) -> Option<Expr> {
214
    match m {
215
1392
        Expr::SafeSlice(_, mat, idxs) => {
216
            // TODO: support >1 slice index (i.e. multidimensional slices)
217

            
218
1392
            let mut idxs = idxs;
219
2088
            let (slice_idx, _) = idxs.iter().find_position(|opt| opt.is_none())?;
220
1392
            let _ = idxs[slice_idx].replace(idx.into());
221

            
222
1392
            let Some(idxs) = idxs.into_iter().collect::<Option<Vec<_>>>() else {
223
                todo!("slice expression should not contain more than one unspecified index")
224
            };
225

            
226
1392
            Some(Expr::SafeIndex(Metadata::new(), mat, idxs))
227
        }
228
        Expr::Flatten(_, None, inner) => {
229
            // Similar to indexed_flatten_matrix rule, but we don't care about out of bounds here
230
            let Literal::Int(index) = idx else {
231
                return None;
232
            };
233

            
234
            let dom = inner.domain_of().and_then(|dom| dom.resolve())?;
235
            let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
236
                return None;
237
            };
238
            let flat_index = flat_index_to_full_index(index_domains, (index - 1) as u64);
239
            let flat_index: Vec<Expr> = flat_index.into_iter().map(Into::into).collect();
240

            
241
            Some(Expr::SafeIndex(Metadata::new(), inner, flat_index))
242
        }
243
1392
        _ => Some(Expr::SafeIndex(
244
1392
            Metadata::new(),
245
1392
            Moo::new(m),
246
1392
            vec![idx.into()],
247
1392
        )),
248
    }
249
2784
}