conjure_cp_core/ast/
matrix.rs

//! Utility functions for working with matrices.

// TODO: Georgii's essence macro would look really nice in these examples!
4
5use std::collections::VecDeque;
6
7use itertools::{Itertools, izip};
8use uniplate::Uniplate as _;
9
10use crate::ast::{DomainOpError, GroundDomain, Moo, Range};
11
12use super::{AbstractLiteral, Literal};
13
14/// For some index domains, returns a list containing each of the possible indices.
15///
16/// Indices are traversed in row-major ordering.
17///
18/// This is an O(n^dim) operation, where dim is the number of dimensions in the matrix.
19///
20/// # Panics
21///
22/// + If any of the index domains are not finite or enumerable with [`Domain::values`].
23///
24/// # Example
25///
26/// ```
27/// use std::collections::HashSet;
28/// use conjure_cp_core::ast::{GroundDomain,Moo,Range,Literal,matrix};
29/// let index_domains = vec![Moo::new(GroundDomain::Bool),Moo::new(GroundDomain::Int(vec![Range::Bounded(1,2)]))];
30///
31/// let expected_indices = HashSet::from([
32///   vec![Literal::Bool(false),Literal::Int(1)],
33///   vec![Literal::Bool(false),Literal::Int(2)],
34///   vec![Literal::Bool(true),Literal::Int(1)],
35///   vec![Literal::Bool(true),Literal::Int(2)]
36///   ]);
37///
38/// let actual_indices: HashSet<_> = matrix::enumerate_indices(index_domains).collect();
39///
40/// assert_eq!(actual_indices, expected_indices);
41/// ```
42pub fn enumerate_indices(
43    index_domains: Vec<Moo<GroundDomain>>,
44) -> impl Iterator<Item = Vec<Literal>> {
45    index_domains
46        .into_iter()
47        .map(|x| {
48            x.values()
49                .expect("index domain should be enumerable with .values()")
50                .collect_vec()
51        })
52        .multi_cartesian_product()
53}
54
55/// Returns the number of possible elements indexable by the given index domains.
56///
57/// In short, returns the product of the sizes of the given indices.
58pub fn num_elements(index_domains: &[Moo<GroundDomain>]) -> Result<u64, DomainOpError> {
59    let idx_dom_lengths = index_domains
60        .iter()
61        .map(|d| d.length())
62        .collect::<Result<Vec<_>, _>>()?;
63    Ok(idx_dom_lengths.iter().product())
64}
65
66/// Flattens a multi-dimensional matrix literal into a one-dimensional slice of its elements.
67///
68/// The elements of the matrix are returned in row-major ordering (see [`enumerate_indices`]).
69///
70/// # Panics
71///
72/// + If the number or type of elements in each dimension is inconsistent.
73///
74/// + If `matrix` is not a matrix.
75pub fn flatten(matrix: AbstractLiteral<Literal>) -> impl Iterator<Item = Literal> {
76    let AbstractLiteral::Matrix(elems, _) = matrix else {
77        panic!("matrix should be a matrix");
78    };
79
80    flatten_1(elems)
81}
82
83fn flatten_1(elems: Vec<Literal>) -> impl Iterator<Item = Literal> {
84    elems.into_iter().flat_map(|elem| {
85        if let Literal::AbstractLiteral(m @ AbstractLiteral::Matrix(_, _)) = elem {
86            Box::new(flatten(m)) as Box<dyn Iterator<Item = Literal>>
87        } else {
88            Box::new(std::iter::once(elem)) as Box<dyn Iterator<Item = Literal>>
89        }
90    })
91}
92/// Flattens a multi-dimensional matrix literal into an iterator over (indices,element).
93///
94/// # Panics
95///
96///   + If the number or type of elements in each dimension is inconsistent.
97///
98///   + If `matrix` is not a matrix.
99///
100///   + If any dimensions in the matrix are not finite or enumerable with [`Domain::values`].
101///     However, index domains in the form `int(i..)` are supported.
102pub fn flatten_enumerate(
103    matrix: AbstractLiteral<Literal>,
104) -> impl Iterator<Item = (Vec<Literal>, Literal)> {
105    let AbstractLiteral::Matrix(elems, _) = matrix.clone() else {
106        panic!("matrix should be a matrix");
107    };
108
109    let index_domains = index_domains(matrix)
110        .into_iter()
111        .map(|mut x| match Moo::make_mut(&mut x) {
112            // give unboundedr index domains an end
113            GroundDomain::Int(ranges) if ranges.len() == 1 && !elems.is_empty() => {
114                if let Range::UnboundedR(start) = ranges[0] {
115                    ranges[0] = Range::Bounded(start, start + (elems.len() as i32 - 1));
116                };
117                x
118            }
119            _ => x,
120        })
121        .collect_vec();
122
123    izip!(enumerate_indices(index_domains), flatten_1(elems))
124}
125
126/// Gets the index domains for a matrix literal.
127///
128/// # Panics
129///
130/// + If `matrix` is not a matrix.
131///
132/// + If the number or type of elements in each dimension is inconsistent.
133pub fn index_domains(matrix: AbstractLiteral<Literal>) -> Vec<Moo<GroundDomain>> {
134    let AbstractLiteral::Matrix(_, _) = matrix else {
135        panic!("matrix should be a matrix");
136    };
137
138    matrix.cata(&move |element: AbstractLiteral<Literal>,
139                       child_index_domains: VecDeque<Vec<Moo<GroundDomain>>>| {
140        assert!(
141            child_index_domains.iter().all_equal(),
142            "each child of a matrix should have the same index domain"
143        );
144
145        let child_index_domains = child_index_domains
146            .front()
147            .unwrap_or(&vec![])
148            .iter()
149            .cloned()
150            .collect_vec();
151        match element {
152            AbstractLiteral::Set(_) => vec![],
153            AbstractLiteral::Matrix(_, domain) => {
154                let mut index_domains = vec![domain];
155                index_domains.extend(child_index_domains);
156                index_domains
157            }
158            AbstractLiteral::Tuple(_) => vec![],
159            AbstractLiteral::Record(_) => vec![],
160            AbstractLiteral::Function(_) => vec![],
161        }
162    })
163}
164
165/// See [`enumerate_indices`]. This function zips the two given lists of index domains, performs a
166/// union on each pair, and returns an enumerating iterator over the new list of domains.
167pub fn enumerate_index_union_indices(
168    a_domains: &[Moo<GroundDomain>],
169    b_domains: &[Moo<GroundDomain>],
170) -> Result<impl Iterator<Item = Vec<Literal>>, DomainOpError> {
171    if a_domains.len() != b_domains.len() {
172        return Err(DomainOpError::WrongType);
173    }
174    let idx_domains: Result<Vec<_>, _> = a_domains
175        .iter()
176        .zip(b_domains.iter())
177        .map(|(a, b)| a.union(b))
178        .collect();
179    let idx_domains = idx_domains?.into_iter().map(Moo::new).collect();
180
181    Ok(enumerate_indices(idx_domains))
182}
183
184// Given index domains for a multi-dimensional matrix and the nth index in the flattened matrix, find the coordinates in the original matrix
185pub fn flat_index_to_full_index(index_domains: &[Moo<GroundDomain>], index: u64) -> Vec<Literal> {
186    let mut remaining = index;
187    let mut multipliers = vec![1; index_domains.len()];
188
189    for i in (1..index_domains.len()).rev() {
190        multipliers[i - 1] = multipliers[i] * index_domains[i].as_ref().length().unwrap();
191    }
192
193    let mut coords = Vec::new();
194    for m in multipliers.iter() {
195        // adjust for 1-based indexing
196        coords.push(((remaining / m + 1) as i32).into());
197        remaining %= *m;
198    }
199
200    coords
201}