// conjure_core/rules/matrix/repr_matrix.rs
1use conjure_core::ast::Expression as Expr;
2use conjure_core::ast::{matrix, SymbolTable};
3use conjure_core::rule_engine::{
4 register_rule, ApplicationError::RuleNotApplicable, ApplicationResult, Reduction,
5};
6use itertools::{chain, izip, Itertools};
7use uniplate::Uniplate;
8
9use crate::ast::Domain;
10use crate::ast::Literal;
11use crate::ast::Name;
12use crate::ast::{Atom, Range};
13use crate::into_matrix_expr;
14use crate::metadata::Metadata;
15
16#[register_rule(("Base", 2000))]
18fn index_matrix_to_atom(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
19 let Expr::SafeIndex(_, subject, indices) = expr else {
21 return Err(RuleNotApplicable);
22 };
23
24 let Expr::Atomic(_, Atom::Reference(Name::WithRepresentation(name, reprs))) = &**subject else {
27 return Err(RuleNotApplicable);
28 };
29
30 if reprs.first().is_none_or(|x| x.as_str() != "matrix_to_atom") {
31 return Err(RuleNotApplicable);
32 }
33
34 let repr = symbols
35 .get_representation(name, &["matrix_to_atom"])
36 .unwrap()[0]
37 .clone();
38
39 let decl = symbols.lookup(name).unwrap();
41
42 let Some(Domain::DomainMatrix(_, index_domains)) =
44 decl.domain().cloned().map(|x| x.resolve(symbols))
45 else {
46 return Err(RuleNotApplicable);
47 };
48
49 let mut indices_are_const = true;
56 let mut indices_as_lits: Vec<Literal> = vec![];
57
58 for index in indices {
59 let Some(index) = index.clone().to_literal() else {
60 indices_are_const = false;
61 break;
62 };
63 indices_as_lits.push(index);
64 }
65
66 if indices_are_const {
67 let indices_as_name = Name::RepresentedName(
70 name.clone(),
71 "matrix_to_atom".into(),
72 indices_as_lits.iter().join("_"),
73 );
74
75 let subject = repr.expression_down(symbols)?[&indices_as_name].clone();
76
77 Ok(Reduction::pure(subject))
78 } else {
79 let n_dims = index_domains.len();
96 if n_dims <= 1 {
97 return Err(RuleNotApplicable);
98 };
99
100 let bounds = index_domains
104 .iter()
105 .map(|dom| {
106 let Domain::IntDomain(ranges) = dom else {
107 return Err(RuleNotApplicable);
108 };
109
110 let &[Range::Bounded(from, to)] = &ranges[..] else {
111 return Err(RuleNotApplicable);
112 };
113
114 Ok((from, to))
115 })
116 .process_results(|it| it.collect_vec())?;
117
118 let sizes = bounds
120 .iter()
121 .map(|(from, to)| (to - from) + 1)
122 .collect_vec();
123
124 let lower_bounds = bounds.iter().map(|(from, _)| from).collect_vec();
126
127 let mut coeffs: Vec<Expr> = chain!(std::iter::once(&1), sizes.iter().skip(1).rev())
158 .scan(1, |state, &x| {
159 *state *= x;
160 Some(*state)
161 })
162 .map(|x| Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Int(x))))
163 .collect_vec();
164
165 coeffs.reverse();
166
167 let terms: Vec<Expr> = izip!(indices, lower_bounds)
169 .map(|(i, lbi)| {
170 Expr::Minus(
171 Metadata::new(),
172 Box::new(i.clone()),
173 Box::new(Expr::Atomic(
174 Metadata::new(),
175 Atom::Literal(Literal::Int(*lbi)),
176 )),
177 )
178 })
179 .collect_vec();
180
181 let mut sum_terms: Vec<Expr> = izip!(coeffs, terms)
183 .map(|(coeff, term)| Expr::Product(Metadata::new(), vec![coeff, term]))
184 .collect_vec();
185
186 sum_terms.push(Expr::Atomic(
188 Metadata::new(),
189 Atom::Literal(Literal::Int(1)),
190 ));
191
192 let flat_index = Expr::Sum(Metadata::new(), Box::new(into_matrix_expr![sum_terms]));
193
194 let repr_exprs = repr.expression_down(symbols)?;
197 let flat_elems = matrix::enumerate_indices(index_domains.clone())
198 .map(|xs| {
199 Name::RepresentedName(
200 name.clone(),
201 "matrix_to_atom".into(),
202 xs.into_iter().join("_"),
203 )
204 })
205 .map(|x| repr_exprs[&x].clone())
206 .collect_vec();
207
208 let flat_matrix = into_matrix_expr![flat_elems];
209
210 Ok(Reduction::pure(Expr::SafeIndex(
211 Metadata::new(),
212 Box::new(flat_matrix),
213 vec![flat_index],
214 )))
215 }
216}
217
218#[register_rule(("Base", 2000))]
220fn slice_matrix_to_atom(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
221 let Expr::SafeSlice(_, subject, indices) = expr else {
222 return Err(RuleNotApplicable);
223 };
224
225 let Expr::Atomic(_, Atom::Reference(Name::WithRepresentation(name, reprs))) = &**subject else {
226 return Err(RuleNotApplicable);
227 };
228
229 if reprs.first().is_none_or(|x| x.as_str() != "matrix_to_atom") {
230 return Err(RuleNotApplicable);
231 }
232
233 let decl = symbols.lookup(name).unwrap();
234 let repr = symbols
235 .get_representation(name, &["matrix_to_atom"])
236 .unwrap()[0]
237 .clone();
238
239 let Some(Domain::DomainMatrix(_, index_domains)) =
241 decl.domain().cloned().map(|x| x.resolve(symbols))
242 else {
243 return Err(RuleNotApplicable);
244 };
245
246 let mut indices_as_lits: Vec<Option<Literal>> = vec![];
247 let mut hole_dim: i32 = -1;
248 for (i, index) in indices.iter().enumerate() {
249 match index {
250 Some(e) => {
251 let lit = e.clone().to_literal().ok_or(RuleNotApplicable)?;
252 indices_as_lits.push(Some(lit.clone()));
253 }
254 None => {
255 indices_as_lits.push(None);
256 assert_eq!(hole_dim, -1);
257 hole_dim = i as _;
258 }
259 }
260 }
261
262 assert_ne!(hole_dim, -1);
263
264 let repr_values = repr.expression_down(symbols)?;
265
266 let slice = index_domains[hole_dim as usize]
267 .values()
268 .expect("index domain should be finite and enumerable")
269 .into_iter()
270 .map(|i| {
271 let mut indices_as_lits = indices_as_lits.clone();
272 indices_as_lits[hole_dim as usize] = Some(i);
273 let name = Name::RepresentedName(
274 name.clone(),
275 "matrix_to_atom".into(),
276 indices_as_lits.into_iter().map(|x| x.unwrap()).join("_"),
277 );
278 repr_values[&name].clone()
279 })
280 .collect_vec();
281
282 let new_expr = into_matrix_expr!(slice);
283
284 Ok(Reduction::pure(new_expr))
285}
286
287#[register_rule(("Base", 2000))]
289fn matrix_ref_to_atom(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
290 if let Expr::SafeSlice(_, _, _)
291 | Expr::UnsafeSlice(_, _, _)
292 | Expr::SafeIndex(_, _, _)
293 | Expr::UnsafeIndex(_, _, _) = expr
294 {
295 return Err(RuleNotApplicable);
296 };
297
298 for (child, ctx) in expr.holes() {
299 let Expr::Atomic(_, Atom::Reference(Name::WithRepresentation(name, reprs))) = child else {
300 continue;
301 };
302
303 if reprs.first().is_none_or(|x| x.as_str() != "matrix_to_atom") {
304 continue;
305 }
306
307 let decl = symbols.lookup(name.as_ref()).unwrap();
308 let repr = symbols
309 .get_representation(name.as_ref(), &["matrix_to_atom"])
310 .unwrap()[0]
311 .clone();
312
313 let Some(Domain::DomainMatrix(_, index_domains)) =
315 decl.domain().cloned().map(|x| x.resolve(symbols))
316 else {
317 continue;
318 };
319
320 if index_domains.len() > 1 {
321 continue;
322 }
323
324 let Ok(matrix_values) = repr.expression_down(symbols) else {
325 continue;
326 };
327
328 let flat_values = matrix::enumerate_indices(index_domains)
329 .map(|i| {
330 matrix_values[&Name::RepresentedName(
331 name.clone(),
332 "matrix_to_atom".into(),
333 i.iter().join("_"),
334 )]
335 .clone()
336 })
337 .collect_vec();
338 return Ok(Reduction::pure(ctx(into_matrix_expr![flat_values])));
339 }
340
341 Err(RuleNotApplicable)
342}