conjure_cp_core/ast/
matrix.rs1use std::collections::VecDeque;
6
7use itertools::{Itertools, izip};
8use uniplate::Uniplate as _;
9
10use crate::ast::{DomainOpError, Expression as Expr, GroundDomain, Metadata, Moo, Range};
11
12use super::{AbstractLiteral, Literal};
13
14pub fn try_enumerate_indices(
43 index_domains: Vec<Moo<GroundDomain>>,
44) -> Result<impl Iterator<Item = Vec<Literal>>, DomainOpError> {
45 let domains = index_domains
46 .into_iter()
47 .map(|x| x.values().map(|values| values.collect_vec()))
48 .collect::<Result<Vec<_>, _>>()?;
49 Ok(domains.into_iter().multi_cartesian_product())
50}
51
52pub fn enumerate_indices(
56 index_domains: Vec<Moo<GroundDomain>>,
57) -> impl Iterator<Item = Vec<Literal>> {
58 try_enumerate_indices(index_domains).expect("index domain should be enumerable with .values()")
59}
60
61pub fn num_elements(index_domains: &[Moo<GroundDomain>]) -> Result<u64, DomainOpError> {
65 let idx_dom_lengths = index_domains
66 .iter()
67 .map(|d| d.length())
68 .collect::<Result<Vec<_>, _>>()?;
69 Ok(idx_dom_lengths.iter().product())
70}
71
72pub fn flatten(matrix: AbstractLiteral<Literal>) -> impl Iterator<Item = Literal> {
82 let AbstractLiteral::Matrix(elems, _) = matrix else {
83 panic!("matrix should be a matrix");
84 };
85
86 flatten_1(elems)
87}
88
89fn flatten_1(elems: Vec<Literal>) -> impl Iterator<Item = Literal> {
90 elems.into_iter().flat_map(|elem| {
91 if let Literal::AbstractLiteral(m @ AbstractLiteral::Matrix(_, _)) = elem {
92 Box::new(flatten(m)) as Box<dyn Iterator<Item = Literal>>
93 } else {
94 Box::new(std::iter::once(elem)) as Box<dyn Iterator<Item = Literal>>
95 }
96 })
97}
98pub fn flatten_enumerate(
109 matrix: AbstractLiteral<Literal>,
110) -> impl Iterator<Item = (Vec<Literal>, Literal)> {
111 let AbstractLiteral::Matrix(elems, _) = matrix.clone() else {
112 panic!("matrix should be a matrix");
113 };
114
115 let index_domains = index_domains(matrix);
116
117 izip!(enumerate_indices(index_domains), flatten_1(elems))
118}
119
120pub fn index_domains(matrix: AbstractLiteral<Literal>) -> Vec<Moo<GroundDomain>> {
128 let AbstractLiteral::Matrix(_, _) = matrix else {
129 panic!("matrix should be a matrix");
130 };
131
132 matrix.cata(&move |element: AbstractLiteral<Literal>,
133 child_index_domains: VecDeque<Vec<Moo<GroundDomain>>>| {
134 assert!(
135 child_index_domains.iter().all_equal(),
136 "each child of a matrix should have the same index domain"
137 );
138
139 let child_index_domains = child_index_domains
140 .front()
141 .unwrap_or(&vec![])
142 .iter()
143 .cloned()
144 .collect_vec();
145 match element {
146 AbstractLiteral::Set(_) => vec![],
147 AbstractLiteral::MSet(_) => vec![],
148 AbstractLiteral::Matrix(elems, domain) => {
149 let mut index_domains = vec![bound_index_domain_from_length(domain, elems.len())];
150 index_domains.extend(child_index_domains);
151 index_domains
152 }
153 AbstractLiteral::Tuple(_) => vec![],
154 AbstractLiteral::Record(_) => vec![],
155 AbstractLiteral::Function(_) => vec![],
156 }
157 })
158}
159
160pub fn enumerate_index_union_indices(
163 a_domains: &[Moo<GroundDomain>],
164 b_domains: &[Moo<GroundDomain>],
165) -> Result<impl Iterator<Item = Vec<Literal>>, DomainOpError> {
166 if a_domains.len() != b_domains.len() {
167 return Err(DomainOpError::WrongType);
168 }
169 let idx_domains: Result<Vec<_>, _> = a_domains
170 .iter()
171 .zip(b_domains.iter())
172 .map(|(a, b)| a.union(b))
173 .collect();
174 let idx_domains = idx_domains?.into_iter().map(Moo::new).collect();
175
176 try_enumerate_indices(idx_domains)
177}
178
179pub fn flat_index_to_full_index(index_domains: &[Moo<GroundDomain>], index: u64) -> Vec<Literal> {
181 let mut remaining = index;
182 let mut multipliers = vec![1; index_domains.len()];
183
184 for i in (1..index_domains.len()).rev() {
185 multipliers[i - 1] = multipliers[i] * index_domains[i].as_ref().length().unwrap();
186 }
187
188 let mut coords = Vec::new();
189 for m in multipliers.iter() {
190 coords.push(((remaining / m + 1) as i32).into());
192 remaining %= *m;
193 }
194
195 coords
196}
197
198pub fn bound_index_domains_of_expr(expr: &Expr) -> Option<Vec<Moo<GroundDomain>>> {
204 let dom = expr.domain_of().and_then(|dom| dom.resolve())?;
205 let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
206 return None;
207 };
208
209 let Some(dimension_lengths) = expr_matrix_dimension_lengths(expr) else {
210 return Some(index_domains.clone());
211 };
212
213 assert_eq!(
214 index_domains.len(),
215 dimension_lengths.len(),
216 "matrix literal domain rank should match its realised rank"
217 );
218
219 Some(
220 index_domains
221 .iter()
222 .cloned()
223 .zip(dimension_lengths)
224 .map(|(domain, len)| bound_index_domain_from_length(domain, len))
225 .collect(),
226 )
227}
228
229pub fn safe_index_optimised(m: Expr, idx: Literal) -> Option<Expr> {
239 match m {
240 Expr::SafeSlice(_, mat, idxs) => {
241 let mut idxs = idxs;
244 let (slice_idx, _) = idxs.iter().find_position(|opt| opt.is_none())?;
245 let _ = idxs[slice_idx].replace(idx.into());
246
247 let Some(idxs) = idxs.into_iter().collect::<Option<Vec<_>>>() else {
248 todo!("slice expression should not contain more than one unspecified index")
249 };
250
251 Some(Expr::SafeIndex(Metadata::new(), mat, idxs))
252 }
253 Expr::Flatten(_, None, inner) => {
254 let Literal::Int(index) = idx else {
256 return None;
257 };
258
259 let index_domains = bound_index_domains_of_expr(inner.as_ref())?;
260 if index_domains.iter().any(|domain| domain.length().is_err()) {
261 return None;
262 }
263 let flat_index = flat_index_to_full_index(&index_domains, (index - 1) as u64);
264 let flat_index: Vec<Expr> = flat_index.into_iter().map(Into::into).collect();
265
266 Some(Expr::SafeIndex(Metadata::new(), inner, flat_index))
267 }
268 _ => Some(Expr::SafeIndex(
269 Metadata::new(),
270 Moo::new(m),
271 vec![idx.into()],
272 )),
273 }
274}
275
276fn bound_index_domain_from_length(mut domain: Moo<GroundDomain>, len: usize) -> Moo<GroundDomain> {
277 match Moo::make_mut(&mut domain) {
278 GroundDomain::Int(ranges) if ranges.len() == 1 && len > 0 => {
279 if let Range::UnboundedR(start) = ranges[0] {
280 let end = start + (len as i32 - 1);
281 ranges[0] = Range::Bounded(start, end);
282 }
283 domain
284 }
285 _ => domain,
286 }
287}
288
289fn expr_matrix_dimension_lengths(expr: &Expr) -> Option<Vec<usize>> {
290 let (elems, _) = expr.clone().unwrap_matrix_unchecked()?;
291
292 let child_dimensions = elems
293 .iter()
294 .map(|elem| expr_matrix_dimension_lengths(elem).unwrap_or_default())
295 .collect_vec();
296
297 assert!(
298 child_dimensions.iter().all_equal(),
299 "each child of a matrix should have the same shape"
300 );
301
302 let mut dimensions = vec![elems.len()];
303 if let Some(child_dimensions) = child_dimensions.into_iter().next() {
304 dimensions.extend(child_dimensions);
305 }
306 Some(dimensions)
307}