conjure_core/ast/
matrix.rs

1//! Utility functions for working with matrices.
2
// TODO: Georgii's essence macro would look really nice in these examples!
4
5use std::{collections::VecDeque, sync::Arc};
6
7use itertools::{izip, Itertools};
8use uniplate::Uniplate as _;
9
10use crate::ast::{Domain, Range};
11
12use super::{AbstractLiteral, Literal};
13
14/// For some index domains, returns a list containing each of the possible indices.
15///
16/// Indices are traversed in row-major ordering.
17///
18/// This is an O(n^dim) operation, where dim is the number of dimensions in the matrix.
19///
20/// # Panics
21///
22/// + If any of the index domains are not finite or enumerable with [`Domain::values`].
23///
24/// # Example
25///
26/// ```
27/// use conjure_core::ast::{Domain,Range,Literal,matrix};
28/// let index_domains = vec![Domain::BoolDomain,Domain::IntDomain(vec![Range::Bounded(1,2)])];
29///
30/// let expected_indices = vec![
31///   vec![Literal::Bool(false),Literal::Int(1)],
32///   vec![Literal::Bool(false),Literal::Int(2)],
33///   vec![Literal::Bool(true),Literal::Int(1)],
34///   vec![Literal::Bool(true),Literal::Int(2)]
35///   ];
36///
37/// let actual_indices: Vec<_> = matrix::enumerate_indices(index_domains).collect();
38///
39/// assert_eq!(actual_indices, expected_indices);
40/// ```
41pub fn enumerate_indices(index_domains: Vec<Domain>) -> impl Iterator<Item = Vec<Literal>> {
42    index_domains
43        .into_iter()
44        .map(|x| {
45            x.values()
46                .expect("index domain should be enumerable with .values()")
47        })
48        .multi_cartesian_product()
49}
50
51/// Flattens a multi-dimensional matrix literal into a one-dimensional slice of its elements.
52///
53/// The elements of the matrix are returned in row-major ordering (see [`enumerate_indices`]).
54///
55/// # Panics
56///
57/// + If the number or type of elements in each dimension is inconsistent.
58///
59/// + If `matrix` is not a matrix.
60pub fn flatten(matrix: AbstractLiteral<Literal>) -> impl Iterator<Item = Literal> {
61    let AbstractLiteral::Matrix(elems, _) = matrix else {
62        panic!("matrix should be a matrix");
63    };
64
65    flatten_1(elems)
66}
67
68fn flatten_1(elems: Vec<Literal>) -> impl Iterator<Item = Literal> {
69    elems.into_iter().flat_map(|elem| {
70        if let Literal::AbstractLiteral(m @ AbstractLiteral::Matrix(_, _)) = elem {
71            Box::new(flatten(m)) as Box<dyn Iterator<Item = Literal>>
72        } else {
73            Box::new(std::iter::once(elem)) as Box<dyn Iterator<Item = Literal>>
74        }
75    })
76}
77/// Flattens a multi-dimensional matrix literal into an iterator over (indices,element).
78///
79/// # Panics
80///
81///   + If the number or type of elements in each dimension is inconsistent.
82///
83///   + If `matrix` is not a matrix.
84///
85///   + If any dimensions in the matrix are not finite or enumerable with [`Domain::values`].
86///     However, index domains in the form `int(i..)` are supported.
87pub fn flatten_enumerate(
88    matrix: AbstractLiteral<Literal>,
89) -> impl Iterator<Item = (Vec<Literal>, Literal)> {
90    let AbstractLiteral::Matrix(elems, _) = matrix.clone() else {
91        panic!("matrix should be a matrix");
92    };
93
94    let index_domains = index_domains(matrix)
95        .into_iter()
96        .map(|mut x| match x {
97            // give unboundedr index domains an end
98            Domain::IntDomain(ref mut ranges) if ranges.len() == 1 && !elems.is_empty() => {
99                if let Range::UnboundedR(start) = ranges[0] {
100                    ranges[0] = Range::Bounded(start, start + (elems.len() as i32 - 1));
101                };
102                x
103            }
104            _ => x,
105        })
106        .collect_vec();
107
108    izip!(enumerate_indices(index_domains), flatten_1(elems))
109}
110
111/// Gets the index domains for a matrix literal.
112///
113/// # Panics
114///
115/// + If `matrix` is not a matrix.
116///
117/// + If the number or type of elements in each dimension is inconsistent.
118pub fn index_domains(matrix: AbstractLiteral<Literal>) -> Vec<Domain> {
119    let AbstractLiteral::Matrix(_, _) = matrix else {
120        panic!("matrix should be a matrix");
121    };
122
123    matrix.cata(Arc::new(
124        move |element: AbstractLiteral<Literal>, child_index_domains: VecDeque<Vec<Domain>>| {
125            assert!(
126                child_index_domains.iter().all_equal(),
127                "each child of a matrix should have the same index domain"
128            );
129
130            let child_index_domains = child_index_domains.front().cloned().unwrap_or(vec![]);
131            match element {
132                AbstractLiteral::Set(_) => vec![],
133                AbstractLiteral::Matrix(_, domain) => {
134                    let mut index_domains = vec![domain];
135                    index_domains.extend(child_index_domains);
136                    index_domains
137                }
138            }
139        },
140    ))
141}