1
//! Utility functions for working with matrices.
2

            
3
// TODO: Georgii's Essence macro would look really nice in these examples!
4

            
5
use std::collections::VecDeque;
6

            
7
use itertools::{Itertools, izip};
8
use uniplate::Uniplate as _;
9

            
10
use crate::ast::{DomainOpError, GroundDomain, Moo, Range};
11

            
12
use super::{AbstractLiteral, Literal};
13

            
14
/// For some index domains, returns a list containing each of the possible indices.
15
///
16
/// Indices are traversed in row-major ordering.
17
///
18
/// This is an O(n^dim) operation, where dim is the number of dimensions in the matrix.
19
///
20
/// # Panics
21
///
22
/// + If any of the index domains are not finite or enumerable with [`Domain::values`].
23
///
24
/// # Example
25
///
26
/// ```
27
/// use std::collections::HashSet;
28
/// use conjure_cp_core::ast::{GroundDomain,Moo,Range,Literal,matrix};
29
/// let index_domains = vec![Moo::new(GroundDomain::Bool),Moo::new(GroundDomain::Int(vec![Range::Bounded(1,2)]))];
30
///
31
/// let expected_indices = HashSet::from([
32
///   vec![Literal::Bool(false),Literal::Int(1)],
33
///   vec![Literal::Bool(false),Literal::Int(2)],
34
///   vec![Literal::Bool(true),Literal::Int(1)],
35
///   vec![Literal::Bool(true),Literal::Int(2)]
36
///   ]);
37
///
38
/// let actual_indices: HashSet<_> = matrix::enumerate_indices(index_domains).collect();
39
///
40
/// assert_eq!(actual_indices, expected_indices);
41
/// ```
42
pub fn enumerate_indices(
43
    index_domains: Vec<Moo<GroundDomain>>,
44
) -> impl Iterator<Item = Vec<Literal>> {
45
    index_domains
46
        .into_iter()
47
        .map(|x| {
48
            x.values()
49
                .expect("index domain should be enumerable with .values()")
50
                .collect_vec()
51
        })
52
        .multi_cartesian_product()
53
}
54

            
55
/// Flattens a multi-dimensional matrix literal into a one-dimensional slice of its elements.
56
///
57
/// The elements of the matrix are returned in row-major ordering (see [`enumerate_indices`]).
58
///
59
/// # Panics
60
///
61
/// + If the number or type of elements in each dimension is inconsistent.
62
///
63
/// + If `matrix` is not a matrix.
64
pub fn flatten(matrix: AbstractLiteral<Literal>) -> impl Iterator<Item = Literal> {
65
    let AbstractLiteral::Matrix(elems, _) = matrix else {
66
        panic!("matrix should be a matrix");
67
    };
68

            
69
    flatten_1(elems)
70
}
71

            
72
fn flatten_1(elems: Vec<Literal>) -> impl Iterator<Item = Literal> {
73
    elems.into_iter().flat_map(|elem| {
74
        if let Literal::AbstractLiteral(m @ AbstractLiteral::Matrix(_, _)) = elem {
75
            Box::new(flatten(m)) as Box<dyn Iterator<Item = Literal>>
76
        } else {
77
            Box::new(std::iter::once(elem)) as Box<dyn Iterator<Item = Literal>>
78
        }
79
    })
80
}
81
/// Flattens a multi-dimensional matrix literal into an iterator over (indices,element).
82
///
83
/// # Panics
84
///
85
///   + If the number or type of elements in each dimension is inconsistent.
86
///
87
///   + If `matrix` is not a matrix.
88
///
89
///   + If any dimensions in the matrix are not finite or enumerable with [`Domain::values`].
90
///     However, index domains in the form `int(i..)` are supported.
91
pub fn flatten_enumerate(
92
    matrix: AbstractLiteral<Literal>,
93
) -> impl Iterator<Item = (Vec<Literal>, Literal)> {
94
    let AbstractLiteral::Matrix(elems, _) = matrix.clone() else {
95
        panic!("matrix should be a matrix");
96
    };
97

            
98
    let index_domains = index_domains(matrix)
99
        .into_iter()
100
        .map(|mut x| match Moo::make_mut(&mut x) {
101
            // give unboundedr index domains an end
102
            GroundDomain::Int(ranges) if ranges.len() == 1 && !elems.is_empty() => {
103
                if let Range::UnboundedR(start) = ranges[0] {
104
                    ranges[0] = Range::Bounded(start, start + (elems.len() as i32 - 1));
105
                };
106
                x
107
            }
108
            _ => x,
109
        })
110
        .collect_vec();
111

            
112
    izip!(enumerate_indices(index_domains), flatten_1(elems))
113
}
114

            
115
/// Gets the index domains for a matrix literal.
116
///
117
/// # Panics
118
///
119
/// + If `matrix` is not a matrix.
120
///
121
/// + If the number or type of elements in each dimension is inconsistent.
122
pub fn index_domains(matrix: AbstractLiteral<Literal>) -> Vec<Moo<GroundDomain>> {
123
    let AbstractLiteral::Matrix(_, _) = matrix else {
124
        panic!("matrix should be a matrix");
125
    };
126

            
127
    matrix.cata(&move |element: AbstractLiteral<Literal>,
128
                       child_index_domains: VecDeque<Vec<Moo<GroundDomain>>>| {
129
        assert!(
130
            child_index_domains.iter().all_equal(),
131
            "each child of a matrix should have the same index domain"
132
        );
133

            
134
        let child_index_domains = child_index_domains
135
            .front()
136
            .unwrap_or(&vec![])
137
            .iter()
138
            .cloned()
139
            .collect_vec();
140
        match element {
141
            AbstractLiteral::Set(_) => vec![],
142
            AbstractLiteral::Matrix(_, domain) => {
143
                let mut index_domains = vec![domain];
144
                index_domains.extend(child_index_domains);
145
                index_domains
146
            }
147
            AbstractLiteral::Tuple(_) => vec![],
148
            AbstractLiteral::Record(_) => vec![],
149
            AbstractLiteral::Function(_) => vec![],
150
        }
151
    })
152
}
153

            
154
/// See [`enumerate_indices`]. This function zips the two given lists of index domains, performs a
155
/// union on each pair, and returns an enumerating iterator over the new list of domains.
156
pub fn enumerate_index_union_indices(
157
    a_domains: &[Moo<GroundDomain>],
158
    b_domains: &[Moo<GroundDomain>],
159
) -> Result<impl Iterator<Item = Vec<Literal>>, DomainOpError> {
160
    if a_domains.len() != b_domains.len() {
161
        return Err(DomainOpError::WrongType);
162
    }
163
    let idx_domains: Result<Vec<_>, _> = a_domains
164
        .iter()
165
        .zip(b_domains.iter())
166
        .map(|(a, b)| a.union(b))
167
        .collect();
168
    let idx_domains = idx_domains?.into_iter().map(Moo::new).collect();
169

            
170
    Ok(enumerate_indices(idx_domains))
171
}