conjure_cp_core/ast/
matrix.rs1use std::collections::VecDeque;
6
7use itertools::{Itertools, izip};
8use uniplate::Uniplate as _;
9
10use crate::ast::{DomainOpError, Expression as Expr, GroundDomain, Metadata, Moo, Range};
11
12use super::{AbstractLiteral, Literal};
13
14pub fn try_enumerate_indices(
43 index_domains: Vec<Moo<GroundDomain>>,
44) -> Result<impl Iterator<Item = Vec<Literal>>, DomainOpError> {
45 let domains = index_domains
46 .into_iter()
47 .map(|x| x.values().map(|values| values.collect_vec()))
48 .collect::<Result<Vec<_>, _>>()?;
49 Ok(domains.into_iter().multi_cartesian_product())
50}
51
52pub fn enumerate_indices(
56 index_domains: Vec<Moo<GroundDomain>>,
57) -> impl Iterator<Item = Vec<Literal>> {
58 try_enumerate_indices(index_domains).expect("index domain should be enumerable with .values()")
59}
60
61pub fn num_elements(index_domains: &[Moo<GroundDomain>]) -> Result<u64, DomainOpError> {
65 let idx_dom_lengths = index_domains
66 .iter()
67 .map(|d| d.length())
68 .collect::<Result<Vec<_>, _>>()?;
69 Ok(idx_dom_lengths.iter().product())
70}
71
72pub fn flatten(matrix: AbstractLiteral<Literal>) -> impl Iterator<Item = Literal> {
82 let AbstractLiteral::Matrix(elems, _) = matrix else {
83 panic!("matrix should be a matrix");
84 };
85
86 flatten_1(elems)
87}
88
89fn flatten_1(elems: Vec<Literal>) -> impl Iterator<Item = Literal> {
90 elems.into_iter().flat_map(|elem| {
91 if let Literal::AbstractLiteral(m @ AbstractLiteral::Matrix(_, _)) = elem {
92 Box::new(flatten(m)) as Box<dyn Iterator<Item = Literal>>
93 } else {
94 Box::new(std::iter::once(elem)) as Box<dyn Iterator<Item = Literal>>
95 }
96 })
97}
98pub fn flatten_enumerate(
109 matrix: AbstractLiteral<Literal>,
110) -> impl Iterator<Item = (Vec<Literal>, Literal)> {
111 let AbstractLiteral::Matrix(elems, _) = matrix.clone() else {
112 panic!("matrix should be a matrix");
113 };
114
115 let index_domains = index_domains(matrix);
116
117 izip!(enumerate_indices(index_domains), flatten_1(elems))
118}
119
120pub fn index_domains(matrix: AbstractLiteral<Literal>) -> Vec<Moo<GroundDomain>> {
128 let AbstractLiteral::Matrix(_, _) = matrix else {
129 panic!("matrix should be a matrix");
130 };
131
132 matrix.cata(&move |element: AbstractLiteral<Literal>,
133 child_index_domains: VecDeque<Vec<Moo<GroundDomain>>>| {
134 assert!(
135 child_index_domains.iter().all_equal(),
136 "each child of a matrix should have the same index domain"
137 );
138
139 let child_index_domains = child_index_domains
140 .front()
141 .unwrap_or(&vec![])
142 .iter()
143 .cloned()
144 .collect_vec();
145 match element {
146 AbstractLiteral::Set(_) => vec![],
147 AbstractLiteral::MSet(_) => vec![],
148 AbstractLiteral::Matrix(elems, domain) => {
149 let mut index_domains = vec![bound_index_domain_from_length(domain, elems.len())];
150 index_domains.extend(child_index_domains);
151 index_domains
152 }
153 AbstractLiteral::Tuple(_) => vec![],
154 AbstractLiteral::Record(_) => vec![],
155 AbstractLiteral::Function(_) => vec![],
156 AbstractLiteral::Variant(_) => vec![],
157 AbstractLiteral::Relation(_) => vec![],
158 AbstractLiteral::Sequence(_) => vec![],
159 AbstractLiteral::Partition(_) => vec![],
160 }
161 })
162}
163
164pub fn enumerate_index_union_indices(
167 a_domains: &[Moo<GroundDomain>],
168 b_domains: &[Moo<GroundDomain>],
169) -> Result<impl Iterator<Item = Vec<Literal>>, DomainOpError> {
170 if a_domains.len() != b_domains.len() {
171 return Err(DomainOpError::WrongType);
172 }
173 let idx_domains: Result<Vec<_>, _> = a_domains
174 .iter()
175 .zip(b_domains.iter())
176 .map(|(a, b)| a.union(b))
177 .collect();
178 let idx_domains = idx_domains?.into_iter().map(Moo::new).collect();
179
180 try_enumerate_indices(idx_domains)
181}
182
183pub fn flat_index_to_full_index(index_domains: &[Moo<GroundDomain>], index: u64) -> Vec<Literal> {
185 let mut remaining = index;
186 let mut multipliers = vec![1; index_domains.len()];
187
188 for i in (1..index_domains.len()).rev() {
189 multipliers[i - 1] = multipliers[i] * index_domains[i].as_ref().length().unwrap();
190 }
191
192 let mut coords = Vec::new();
193 for m in multipliers.iter() {
194 coords.push(((remaining / m + 1) as i32).into());
196 remaining %= *m;
197 }
198
199 coords
200}
201
202pub fn bound_index_domains_of_expr(expr: &Expr) -> Option<Vec<Moo<GroundDomain>>> {
208 let dom = expr.domain_of().and_then(|dom| dom.resolve())?;
209 let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
210 return None;
211 };
212
213 let Some(dimension_lengths) = expr_matrix_dimension_lengths(expr) else {
214 return Some(index_domains.clone());
215 };
216
217 assert_eq!(
218 index_domains.len(),
219 dimension_lengths.len(),
220 "matrix literal domain rank should match its realised rank"
221 );
222
223 Some(
224 index_domains
225 .iter()
226 .cloned()
227 .zip(dimension_lengths)
228 .map(|(domain, len)| bound_index_domain_from_length(domain, len))
229 .collect(),
230 )
231}
232
233pub fn safe_index_optimised(m: Expr, idx: Literal) -> Option<Expr> {
243 match m {
244 Expr::SafeSlice(_, mat, idxs) => {
245 let mut idxs = idxs;
248 let (slice_idx, _) = idxs.iter().find_position(|opt| opt.is_none())?;
249 let _ = idxs[slice_idx].replace(idx.into());
250
251 let Some(idxs) = idxs.into_iter().collect::<Option<Vec<_>>>() else {
252 todo!("slice expression should not contain more than one unspecified index")
253 };
254
255 Some(Expr::SafeIndex(Metadata::new(), mat, idxs))
256 }
257 Expr::Flatten(_, None, inner) => {
258 let Literal::Int(index) = idx else {
260 return None;
261 };
262
263 let index_domains = bound_index_domains_of_expr(inner.as_ref())?;
264 if index_domains.iter().any(|domain| domain.length().is_err()) {
265 return None;
266 }
267 let flat_index = flat_index_to_full_index(&index_domains, (index - 1) as u64);
268 let flat_index: Vec<Expr> = flat_index.into_iter().map(Into::into).collect();
269
270 Some(Expr::SafeIndex(Metadata::new(), inner, flat_index))
271 }
272 _ => Some(Expr::SafeIndex(
273 Metadata::new(),
274 Moo::new(m),
275 vec![idx.into()],
276 )),
277 }
278}
279
280fn bound_index_domain_from_length(mut domain: Moo<GroundDomain>, len: usize) -> Moo<GroundDomain> {
281 match Moo::make_mut(&mut domain) {
282 GroundDomain::Int(ranges) if ranges.len() == 1 && len > 0 => {
283 if let Range::UnboundedR(start) = ranges[0] {
284 let end = start + (len as i32 - 1);
285 ranges[0] = Range::Bounded(start, end);
286 }
287 domain
288 }
289 _ => domain,
290 }
291}
292
293fn expr_matrix_dimension_lengths(expr: &Expr) -> Option<Vec<usize>> {
294 let (elems, _) = expr.clone().unwrap_matrix_unchecked()?;
295
296 let child_dimensions = elems
297 .iter()
298 .map(|elem| expr_matrix_dimension_lengths(elem).unwrap_or_default())
299 .collect_vec();
300
301 assert!(
302 child_dimensions.iter().all_equal(),
303 "each child of a matrix should have the same shape"
304 );
305
306 let mut dimensions = vec![elems.len()];
307 if let Some(child_dimensions) = child_dimensions.into_iter().next() {
308 dimensions.extend(child_dimensions);
309 }
310 Some(dimensions)
311}