Skip to main content

conjure_cp_core/ast/
matrix.rs

1//! Utility functions for working with matrices.
2
// TODO: Georgii's essence macro would look really nice in these examples!
4
5use std::collections::VecDeque;
6
7use itertools::{Itertools, izip};
8use uniplate::Uniplate as _;
9
10use crate::ast::{DomainOpError, Expression as Expr, GroundDomain, Metadata, Moo, Range};
11
12use super::{AbstractLiteral, Literal};
13
14/// For some index domains, returns a list containing each of the possible indices.
15///
16/// Indices are traversed in row-major ordering.
17///
18/// This is an O(n^dim) operation, where dim is the number of dimensions in the matrix.
19///
20/// # Panics
21///
22/// + If any of the index domains are not finite or enumerable with [`Domain::values`].
23///
24/// # Example
25///
26/// ```
27/// use std::collections::HashSet;
28/// use conjure_cp_core::ast::{GroundDomain,Moo,Range,Literal,matrix};
29/// let index_domains = vec![Moo::new(GroundDomain::Bool),Moo::new(GroundDomain::Int(vec![Range::Bounded(1,2)]))];
30///
31/// let expected_indices = HashSet::from([
32///   vec![Literal::Bool(false),Literal::Int(1)],
33///   vec![Literal::Bool(false),Literal::Int(2)],
34///   vec![Literal::Bool(true),Literal::Int(1)],
35///   vec![Literal::Bool(true),Literal::Int(2)]
36///   ]);
37///
38/// let actual_indices: HashSet<_> = matrix::enumerate_indices(index_domains).collect();
39///
40/// assert_eq!(actual_indices, expected_indices);
41/// ```
42pub fn try_enumerate_indices(
43    index_domains: Vec<Moo<GroundDomain>>,
44) -> Result<impl Iterator<Item = Vec<Literal>>, DomainOpError> {
45    let domains = index_domains
46        .into_iter()
47        .map(|x| x.values().map(|values| values.collect_vec()))
48        .collect::<Result<Vec<_>, _>>()?;
49    Ok(domains.into_iter().multi_cartesian_product())
50}
51
52/// For some index domains, returns a list containing each of the possible indices.
53///
54/// See [`try_enumerate_indices`] for the fallible variant.
55pub fn enumerate_indices(
56    index_domains: Vec<Moo<GroundDomain>>,
57) -> impl Iterator<Item = Vec<Literal>> {
58    try_enumerate_indices(index_domains).expect("index domain should be enumerable with .values()")
59}
60
61/// Returns the number of possible elements indexable by the given index domains.
62///
63/// In short, returns the product of the sizes of the given indices.
64pub fn num_elements(index_domains: &[Moo<GroundDomain>]) -> Result<u64, DomainOpError> {
65    let idx_dom_lengths = index_domains
66        .iter()
67        .map(|d| d.length())
68        .collect::<Result<Vec<_>, _>>()?;
69    Ok(idx_dom_lengths.iter().product())
70}
71
72/// Flattens a multi-dimensional matrix literal into a one-dimensional slice of its elements.
73///
74/// The elements of the matrix are returned in row-major ordering (see [`enumerate_indices`]).
75///
76/// # Panics
77///
78/// + If the number or type of elements in each dimension is inconsistent.
79///
80/// + If `matrix` is not a matrix.
81pub fn flatten(matrix: AbstractLiteral<Literal>) -> impl Iterator<Item = Literal> {
82    let AbstractLiteral::Matrix(elems, _) = matrix else {
83        panic!("matrix should be a matrix");
84    };
85
86    flatten_1(elems)
87}
88
89fn flatten_1(elems: Vec<Literal>) -> impl Iterator<Item = Literal> {
90    elems.into_iter().flat_map(|elem| {
91        if let Literal::AbstractLiteral(m @ AbstractLiteral::Matrix(_, _)) = elem {
92            Box::new(flatten(m)) as Box<dyn Iterator<Item = Literal>>
93        } else {
94            Box::new(std::iter::once(elem)) as Box<dyn Iterator<Item = Literal>>
95        }
96    })
97}
98/// Flattens a multi-dimensional matrix literal into an iterator over (indices,element).
99///
100/// # Panics
101///
102///   + If the number or type of elements in each dimension is inconsistent.
103///
104///   + If `matrix` is not a matrix.
105///
106///   + If any dimensions in the matrix are not finite or enumerable with [`Domain::values`].
107///     However, index domains in the form `int(i..)` are supported.
108pub fn flatten_enumerate(
109    matrix: AbstractLiteral<Literal>,
110) -> impl Iterator<Item = (Vec<Literal>, Literal)> {
111    let AbstractLiteral::Matrix(elems, _) = matrix.clone() else {
112        panic!("matrix should be a matrix");
113    };
114
115    let index_domains = index_domains(matrix);
116
117    izip!(enumerate_indices(index_domains), flatten_1(elems))
118}
119
120/// Gets the index domains for a matrix literal.
121///
122/// # Panics
123///
124/// + If `matrix` is not a matrix.
125///
126/// + If the number or type of elements in each dimension is inconsistent.
127pub fn index_domains(matrix: AbstractLiteral<Literal>) -> Vec<Moo<GroundDomain>> {
128    let AbstractLiteral::Matrix(_, _) = matrix else {
129        panic!("matrix should be a matrix");
130    };
131
132    matrix.cata(&move |element: AbstractLiteral<Literal>,
133                       child_index_domains: VecDeque<Vec<Moo<GroundDomain>>>| {
134        assert!(
135            child_index_domains.iter().all_equal(),
136            "each child of a matrix should have the same index domain"
137        );
138
139        let child_index_domains = child_index_domains
140            .front()
141            .unwrap_or(&vec![])
142            .iter()
143            .cloned()
144            .collect_vec();
145        match element {
146            AbstractLiteral::Set(_) => vec![],
147            AbstractLiteral::MSet(_) => vec![],
148            AbstractLiteral::Matrix(elems, domain) => {
149                let mut index_domains = vec![bound_index_domain_from_length(domain, elems.len())];
150                index_domains.extend(child_index_domains);
151                index_domains
152            }
153            AbstractLiteral::Tuple(_) => vec![],
154            AbstractLiteral::Record(_) => vec![],
155            AbstractLiteral::Function(_) => vec![],
156            AbstractLiteral::Variant(_) => vec![],
157            AbstractLiteral::Relation(_) => vec![],
158            AbstractLiteral::Sequence(_) => vec![],
159            AbstractLiteral::Partition(_) => vec![],
160        }
161    })
162}
163
164/// See [`enumerate_indices`]. This function zips the two given lists of index domains, performs a
165/// union on each pair, and returns an enumerating iterator over the new list of domains.
166pub fn enumerate_index_union_indices(
167    a_domains: &[Moo<GroundDomain>],
168    b_domains: &[Moo<GroundDomain>],
169) -> Result<impl Iterator<Item = Vec<Literal>>, DomainOpError> {
170    if a_domains.len() != b_domains.len() {
171        return Err(DomainOpError::WrongType);
172    }
173    let idx_domains: Result<Vec<_>, _> = a_domains
174        .iter()
175        .zip(b_domains.iter())
176        .map(|(a, b)| a.union(b))
177        .collect();
178    let idx_domains = idx_domains?.into_iter().map(Moo::new).collect();
179
180    try_enumerate_indices(idx_domains)
181}
182
183// Given index domains for a multi-dimensional matrix and the nth index in the flattened matrix, find the coordinates in the original matrix
184pub fn flat_index_to_full_index(index_domains: &[Moo<GroundDomain>], index: u64) -> Vec<Literal> {
185    let mut remaining = index;
186    let mut multipliers = vec![1; index_domains.len()];
187
188    for i in (1..index_domains.len()).rev() {
189        multipliers[i - 1] = multipliers[i] * index_domains[i].as_ref().length().unwrap();
190    }
191
192    let mut coords = Vec::new();
193    for m in multipliers.iter() {
194        // adjust for 1-based indexing
195        coords.push(((remaining / m + 1) as i32).into());
196        remaining %= *m;
197    }
198
199    coords
200}
201
202/// Gets concrete index domains for a matrix expression.
203///
204/// For matrix literals, right-unbounded integer index domains like `int(1..)` are bounded using
205/// the literal's realised size in that dimension. For non-literals, this falls back to the
206/// expression's resolved domain.
207pub fn bound_index_domains_of_expr(expr: &Expr) -> Option<Vec<Moo<GroundDomain>>> {
208    let dom = expr.domain_of().and_then(|dom| dom.resolve())?;
209    let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
210        return None;
211    };
212
213    let Some(dimension_lengths) = expr_matrix_dimension_lengths(expr) else {
214        return Some(index_domains.clone());
215    };
216
217    assert_eq!(
218        index_domains.len(),
219        dimension_lengths.len(),
220        "matrix literal domain rank should match its realised rank"
221    );
222
223    Some(
224        index_domains
225            .iter()
226            .cloned()
227            .zip(dimension_lengths)
228            .map(|(domain, len)| bound_index_domain_from_length(domain, len))
229            .collect(),
230    )
231}
232
233/// This is the same as `m[x]` except when `m` is of the forms:
234///
235/// - `n[..]`, then it produces n[x] instead of n[..][x]
236/// - `flatten(n)`, then it produces `n[y]` instead of `flatten(n)[y]`,
237///   where `y` is the full index corresponding to flat index `x`
238///
239/// # Returns
240/// + `Some(expr)` if the safe indexing could be constructed
241/// + `None` if it could not be constructed (e.g. invalid index type)
242pub fn safe_index_optimised(m: Expr, idx: Literal) -> Option<Expr> {
243    match m {
244        Expr::SafeSlice(_, mat, idxs) => {
245            // TODO: support >1 slice index (i.e. multidimensional slices)
246
247            let mut idxs = idxs;
248            let (slice_idx, _) = idxs.iter().find_position(|opt| opt.is_none())?;
249            let _ = idxs[slice_idx].replace(idx.into());
250
251            let Some(idxs) = idxs.into_iter().collect::<Option<Vec<_>>>() else {
252                todo!("slice expression should not contain more than one unspecified index")
253            };
254
255            Some(Expr::SafeIndex(Metadata::new(), mat, idxs))
256        }
257        Expr::Flatten(_, None, inner) => {
258            // Similar to indexed_flatten_matrix rule, but we don't care about out of bounds here
259            let Literal::Int(index) = idx else {
260                return None;
261            };
262
263            let index_domains = bound_index_domains_of_expr(inner.as_ref())?;
264            if index_domains.iter().any(|domain| domain.length().is_err()) {
265                return None;
266            }
267            let flat_index = flat_index_to_full_index(&index_domains, (index - 1) as u64);
268            let flat_index: Vec<Expr> = flat_index.into_iter().map(Into::into).collect();
269
270            Some(Expr::SafeIndex(Metadata::new(), inner, flat_index))
271        }
272        _ => Some(Expr::SafeIndex(
273            Metadata::new(),
274            Moo::new(m),
275            vec![idx.into()],
276        )),
277    }
278}
279
280fn bound_index_domain_from_length(mut domain: Moo<GroundDomain>, len: usize) -> Moo<GroundDomain> {
281    match Moo::make_mut(&mut domain) {
282        GroundDomain::Int(ranges) if ranges.len() == 1 && len > 0 => {
283            if let Range::UnboundedR(start) = ranges[0] {
284                let end = start + (len as i32 - 1);
285                ranges[0] = Range::Bounded(start, end);
286            }
287            domain
288        }
289        _ => domain,
290    }
291}
292
293fn expr_matrix_dimension_lengths(expr: &Expr) -> Option<Vec<usize>> {
294    let (elems, _) = expr.clone().unwrap_matrix_unchecked()?;
295
296    let child_dimensions = elems
297        .iter()
298        .map(|elem| expr_matrix_dimension_lengths(elem).unwrap_or_default())
299        .collect_vec();
300
301    assert!(
302        child_dimensions.iter().all_equal(),
303        "each child of a matrix should have the same shape"
304    );
305
306    let mut dimensions = vec![elems.len()];
307    if let Some(child_dimensions) = child_dimensions.into_iter().next() {
308        dimensions.extend(child_dimensions);
309    }
310    Some(dimensions)
311}