conjure_core/ast/matrix.rs
1//! Utility functions for working with matrices.
2
3// TODO: Georgii's essence macro would look really nice in these examples!
4
5use std::{collections::VecDeque, sync::Arc};
6
7use itertools::{izip, Itertools};
8use uniplate::Uniplate as _;
9
10use crate::ast::{Domain, Range};
11
12use super::{AbstractLiteral, Literal};
13
14/// For some index domains, returns a list containing each of the possible indices.
15///
16/// Indices are traversed in row-major ordering.
17///
18/// This is an O(n^dim) operation, where dim is the number of dimensions in the matrix.
19///
20/// # Panics
21///
22/// + If any of the index domains are not finite or enumerable with [`Domain::values`].
23///
24/// # Example
25///
26/// ```
27/// use conjure_core::ast::{Domain,Range,Literal,matrix};
28/// let index_domains = vec![Domain::BoolDomain,Domain::IntDomain(vec![Range::Bounded(1,2)])];
29///
30/// let expected_indices = vec![
31/// vec![Literal::Bool(false),Literal::Int(1)],
32/// vec![Literal::Bool(false),Literal::Int(2)],
33/// vec![Literal::Bool(true),Literal::Int(1)],
34/// vec![Literal::Bool(true),Literal::Int(2)]
35/// ];
36///
37/// let actual_indices: Vec<_> = matrix::enumerate_indices(index_domains).collect();
38///
39/// assert_eq!(actual_indices, expected_indices);
40/// ```
41pub fn enumerate_indices(index_domains: Vec<Domain>) -> impl Iterator<Item = Vec<Literal>> {
42 index_domains
43 .into_iter()
44 .map(|x| {
45 x.values()
46 .expect("index domain should be enumerable with .values()")
47 })
48 .multi_cartesian_product()
49}
50
51/// Flattens a multi-dimensional matrix literal into a one-dimensional slice of its elements.
52///
53/// The elements of the matrix are returned in row-major ordering (see [`enumerate_indices`]).
54///
55/// # Panics
56///
57/// + If the number or type of elements in each dimension is inconsistent.
58///
59/// + If `matrix` is not a matrix.
60pub fn flatten(matrix: AbstractLiteral<Literal>) -> impl Iterator<Item = Literal> {
61 let AbstractLiteral::Matrix(elems, _) = matrix else {
62 panic!("matrix should be a matrix");
63 };
64
65 flatten_1(elems)
66}
67
68fn flatten_1(elems: Vec<Literal>) -> impl Iterator<Item = Literal> {
69 elems.into_iter().flat_map(|elem| {
70 if let Literal::AbstractLiteral(m @ AbstractLiteral::Matrix(_, _)) = elem {
71 Box::new(flatten(m)) as Box<dyn Iterator<Item = Literal>>
72 } else {
73 Box::new(std::iter::once(elem)) as Box<dyn Iterator<Item = Literal>>
74 }
75 })
76}
77/// Flattens a multi-dimensional matrix literal into an iterator over (indices,element).
78///
79/// # Panics
80///
81/// + If the number or type of elements in each dimension is inconsistent.
82///
83/// + If `matrix` is not a matrix.
84///
85/// + If any dimensions in the matrix are not finite or enumerable with [`Domain::values`].
86/// However, index domains in the form `int(i..)` are supported.
87pub fn flatten_enumerate(
88 matrix: AbstractLiteral<Literal>,
89) -> impl Iterator<Item = (Vec<Literal>, Literal)> {
90 let AbstractLiteral::Matrix(elems, _) = matrix.clone() else {
91 panic!("matrix should be a matrix");
92 };
93
94 let index_domains = index_domains(matrix)
95 .into_iter()
96 .map(|mut x| match x {
97 // give unboundedr index domains an end
98 Domain::IntDomain(ref mut ranges) if ranges.len() == 1 && !elems.is_empty() => {
99 if let Range::UnboundedR(start) = ranges[0] {
100 ranges[0] = Range::Bounded(start, start + (elems.len() as i32 - 1));
101 };
102 x
103 }
104 _ => x,
105 })
106 .collect_vec();
107
108 izip!(enumerate_indices(index_domains), flatten_1(elems))
109}
110
111/// Gets the index domains for a matrix literal.
112///
113/// # Panics
114///
115/// + If `matrix` is not a matrix.
116///
117/// + If the number or type of elements in each dimension is inconsistent.
118pub fn index_domains(matrix: AbstractLiteral<Literal>) -> Vec<Domain> {
119 let AbstractLiteral::Matrix(_, _) = matrix else {
120 panic!("matrix should be a matrix");
121 };
122
123 matrix.cata(Arc::new(
124 move |element: AbstractLiteral<Literal>, child_index_domains: VecDeque<Vec<Domain>>| {
125 assert!(
126 child_index_domains.iter().all_equal(),
127 "each child of a matrix should have the same index domain"
128 );
129
130 let child_index_domains = child_index_domains.front().cloned().unwrap_or(vec![]);
131 match element {
132 AbstractLiteral::Set(_) => vec![],
133 AbstractLiteral::Matrix(_, domain) => {
134 let mut index_domains = vec![domain];
135 index_domains.extend(child_index_domains);
136 index_domains
137 }
138 }
139 },
140 ))
141}