1#![allow(dead_code)]
2use std::collections::HashSet;
3
4use conjure_core::ast::{Atom, Expression as Expr, Literal as Lit};
5use conjure_core::metadata::Metadata;
6use conjure_core::rule_engine::{
7 register_rule, register_rule_set, ApplicationError, ApplicationError::RuleNotApplicable,
8 ApplicationResult, Reduction,
9};
10use itertools::izip;
11
12use crate::ast::SymbolTable;
13
// Declares the "Constant" rule set; the `#[register_rule(("Constant", 9001))]` rules in this
// file register into it.
register_rule_set!("Constant", ());
15
16#[register_rule(("Constant", 9001))]
17fn apply_eval_constant(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
18 if let Expr::Atomic(_, Atom::Literal(_)) = expr {
19 return Err(ApplicationError::RuleNotApplicable);
20 }
21 eval_constant(expr)
22 .map(|c| Reduction::pure(Expr::Atomic(Metadata::new(), Atom::Literal(c))))
23 .ok_or(ApplicationError::RuleNotApplicable)
24}
25
26pub fn eval_constant(expr: &Expr) -> Option<Lit> {
31 match expr {
32 Expr::AbstractLiteral(_, _) => None,
33 Expr::FromSolution(_, _) => None,
35 Expr::DominanceRelation(_, _) => None,
37 Expr::UnsafeIndex(_, _, _) => None,
38 Expr::SafeIndex(_, _, _) => None,
40 Expr::UnsafeSlice(_, _, _) => None,
41 Expr::SafeSlice(_, _, _) => None,
43 Expr::InDomain(_, e, domain) => {
44 let Expr::Atomic(_, Atom::Literal(lit)) = e.as_ref() else {
45 return None;
46 };
47
48 domain.contains(lit).map(Into::into)
49 }
50 Expr::Atomic(_, Atom::Literal(c)) => Some(c.clone()),
51 Expr::Atomic(_, Atom::Reference(_c)) => None,
52 Expr::Abs(_, e) => un_op::<i32, i32>(|a| a.abs(), e).map(Lit::Int),
53 Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
54 .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
55 .map(Lit::Bool),
56 Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Lit::Bool),
57 Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Lit::Bool),
58 Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Lit::Bool),
59 Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Lit::Bool),
60 Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Lit::Bool),
61
62 Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Lit::Bool),
63
64 Expr::And(_, e) => {
65 vec_lit_op::<bool, bool>(|e| e.iter().all(|&e| e), e.as_ref()).map(Lit::Bool)
66 }
67 Expr::Root(_, _) => None,
70 Expr::Or(_, e) => {
71 vec_lit_op::<bool, bool>(|e| e.iter().any(|&e| e), e.as_ref()).map(Lit::Bool)
72 }
73 Expr::Imply(_, box1, box2) => {
74 let a: &Atom = (&**box1).try_into().ok()?;
75 let b: &Atom = (&**box2).try_into().ok()?;
76
77 let a: bool = a.try_into().ok()?;
78 let b: bool = b.try_into().ok()?;
79
80 if a {
81 Some(Lit::Bool(b))
83 } else {
84 Some(Lit::Bool(true))
86 }
87 }
88
89 Expr::Sum(_, exprs) => vec_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Lit::Int),
90 Expr::Product(_, exprs) => vec_op::<i32, i32>(|e| e.iter().product(), exprs).map(Lit::Int),
91
92 Expr::FlatIneq(_, a, b, c) => {
93 let a: i32 = a.try_into().ok()?;
94 let b: i32 = b.try_into().ok()?;
95 let c: i32 = c.try_into().ok()?;
96
97 Some(Lit::Bool(a <= b + c))
98 }
99
100 Expr::FlatSumGeq(_, exprs, a) => {
101 let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
102 let n: i32 = atom.try_into().ok()?;
103 let acc = acc + n;
104 Some(acc)
105 })?;
106
107 Some(Lit::Bool(sum >= a.try_into().ok()?))
108 }
109 Expr::FlatSumLeq(_, exprs, a) => {
110 let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
111 let n: i32 = atom.try_into().ok()?;
112 let acc = acc + n;
113 Some(acc)
114 })?;
115
116 Some(Lit::Bool(sum >= a.try_into().ok()?))
117 }
118 Expr::Min(_, e) => {
119 opt_vec_lit_op::<i32, i32>(|e| e.iter().min().copied(), e.as_ref()).map(Lit::Int)
120 }
121 Expr::Max(_, e) => {
122 opt_vec_lit_op::<i32, i32>(|e| e.iter().max().copied(), e.as_ref()).map(Lit::Int)
123 }
124 Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
125 if unwrap_expr::<i32>(b)? == 0 {
126 return None;
127 }
128 bin_op::<i32, i32>(|a, b| ((a as f32) / (b as f32)).floor() as i32, a, b).map(Lit::Int)
129 }
130 Expr::UnsafeMod(_, a, b) | Expr::SafeMod(_, a, b) => {
131 if unwrap_expr::<i32>(b)? == 0 {
132 return None;
133 }
134 bin_op::<i32, i32>(|a, b| a - b * (a as f32 / b as f32).floor() as i32, a, b)
135 .map(Lit::Int)
136 }
137 Expr::MinionDivEqUndefZero(_, a, b, c) => {
138 let a: i32 = a.try_into().ok()?;
140 let b: i32 = b.try_into().ok()?;
141 let c: i32 = c.try_into().ok()?;
142
143 if b == 0 {
144 return None;
145 }
146
147 let a = a as f32;
148 let b = b as f32;
149 let div: i32 = (a / b).floor() as i32;
150 Some(Lit::Bool(div == c))
151 }
152 Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Lit::Bool),
153
154 Expr::MinionReify(_, a, b) => {
155 let result = eval_constant(a)?;
156
157 let result: bool = result.try_into().ok()?;
158 let b: bool = b.try_into().ok()?;
159
160 Some(Lit::Bool(b == result))
161 }
162
163 Expr::MinionReifyImply(_, a, b) => {
164 let result = eval_constant(a)?;
165
166 let result: bool = result.try_into().ok()?;
167 let b: bool = b.try_into().ok()?;
168
169 if b {
170 Some(Lit::Bool(result))
171 } else {
172 Some(Lit::Bool(true))
173 }
174 }
175 Expr::MinionModuloEqUndefZero(_, a, b, c) => {
176 let a: i32 = a.try_into().ok()?;
184 let b: i32 = b.try_into().ok()?;
185 let c: i32 = c.try_into().ok()?;
186
187 if b == 0 {
188 return None;
189 }
190
191 let modulo = a - b * (a as f32 / b as f32).floor() as i32;
192 Some(Lit::Bool(modulo == c))
193 }
194
195 Expr::MinionPow(_, a, b, c) => {
196 let a: i32 = a.try_into().ok()?;
199 let b: i32 = b.try_into().ok()?;
200 let c: i32 = c.try_into().ok()?;
201
202 if a <= 0 {
203 return None;
204 }
205
206 if b <= 0 {
207 return None;
208 }
209
210 if c <= 0 {
211 return None;
212 }
213
214 Some(Lit::Bool(a ^ b == c))
215 }
216
217 Expr::AllDiff(_, e) => {
218 let es = e.clone().unwrap_list()?;
219 let mut lits: HashSet<Lit> = HashSet::new();
220 for expr in es {
221 let Expr::Atomic(_, Atom::Literal(x)) = expr else {
222 return None;
223 };
224 match x {
225 Lit::Int(_) | Lit::Bool(_) => {
226 if lits.contains(&x) {
227 return Some(Lit::Bool(false));
228 } else {
229 lits.insert(x.clone());
230 }
231 }
232 Lit::AbstractLiteral(_) => return None, }
234 }
235 Some(Lit::Bool(true))
236 }
237 Expr::FlatAllDiff(_, es) => {
238 let mut lits: HashSet<Lit> = HashSet::new();
239 for atom in es {
240 let Atom::Literal(x) = atom else {
241 return None;
242 };
243
244 match x {
245 Lit::Int(_) | Lit::Bool(_) => {
246 if lits.contains(x) {
247 return Some(Lit::Bool(false));
248 } else {
249 lits.insert(x.clone());
250 }
251 }
252 Lit::AbstractLiteral(_) => return None, }
254 }
255 Some(Lit::Bool(true))
256 }
257 Expr::FlatWatchedLiteral(_, _, _) => None,
258 Expr::AuxDeclaration(_, _, _) => None,
259 Expr::Neg(_, a) => {
260 let a: &Atom = a.try_into().ok()?;
261 let a: i32 = a.try_into().ok()?;
262 Some(Lit::Int(-a))
263 }
264 Expr::Minus(_, a, b) => {
265 let a: &Atom = a.try_into().ok()?;
266 let a: i32 = a.try_into().ok()?;
267
268 let b: &Atom = b.try_into().ok()?;
269 let b: i32 = b.try_into().ok()?;
270
271 Some(Lit::Int(a - b))
272 }
273 Expr::FlatMinusEq(_, a, b) => {
274 let a: i32 = a.try_into().ok()?;
275 let b: i32 = b.try_into().ok()?;
276 Some(Lit::Bool(a == -b))
277 }
278 Expr::FlatProductEq(_, a, b, c) => {
279 let a: i32 = a.try_into().ok()?;
280 let b: i32 = b.try_into().ok()?;
281 let c: i32 = c.try_into().ok()?;
282 Some(Lit::Bool(a * b == c))
283 }
284 Expr::FlatWeightedSumLeq(_, cs, vs, total) => {
285 let cs: Vec<i32> = cs
286 .iter()
287 .map(|x| TryInto::<i32>::try_into(x).ok())
288 .collect::<Option<Vec<i32>>>()?;
289 let vs: Vec<i32> = vs
290 .iter()
291 .map(|x| TryInto::<i32>::try_into(x).ok())
292 .collect::<Option<Vec<i32>>>()?;
293 let total: i32 = total.try_into().ok()?;
294
295 let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
296
297 Some(Lit::Bool(sum <= total))
298 }
299
300 Expr::FlatWeightedSumGeq(_, cs, vs, total) => {
301 let cs: Vec<i32> = cs
302 .iter()
303 .map(|x| TryInto::<i32>::try_into(x).ok())
304 .collect::<Option<Vec<i32>>>()?;
305 let vs: Vec<i32> = vs
306 .iter()
307 .map(|x| TryInto::<i32>::try_into(x).ok())
308 .collect::<Option<Vec<i32>>>()?;
309 let total: i32 = total.try_into().ok()?;
310
311 let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
312
313 Some(Lit::Bool(sum >= total))
314 }
315 Expr::FlatAbsEq(_, x, y) => {
316 let x: i32 = x.try_into().ok()?;
317 let y: i32 = y.try_into().ok()?;
318
319 Some(Lit::Bool(x == y.abs()))
320 }
321
322 Expr::UnsafePow(_, a, b) | Expr::SafePow(_, a, b) => {
323 let a: &Atom = a.try_into().ok()?;
324 let a: i32 = a.try_into().ok()?;
325
326 let b: &Atom = b.try_into().ok()?;
327 let b: i32 = b.try_into().ok()?;
328
329 if (a != 0 || b != 0) && b >= 0 {
330 Some(Lit::Int(a ^ b))
331 } else {
332 None
333 }
334 }
335 Expr::Scope(_, _) => None,
336 }
337}
338
339#[register_rule(("Constant", 9001))]
343fn eval_root(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
344 let Expr::Root(_, exprs) = expr else {
348 return Err(RuleNotApplicable);
349 };
350
351 match exprs.len() {
352 0 => Ok(Reduction::pure(Expr::Root(
353 Metadata::new(),
354 vec![true.into()],
355 ))),
356 1 => Err(RuleNotApplicable),
357 _ => {
358 let lit =
359 vec_op::<bool, bool>(|e| e.iter().all(|&e| e), exprs).ok_or(RuleNotApplicable)?;
360
361 Ok(Reduction::pure(Expr::Root(
362 Metadata::new(),
363 vec![lit.into()],
364 )))
365 }
366 }
367}
368
369fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
370where
371 T: TryFrom<Lit>,
372{
373 let a = unwrap_expr::<T>(a)?;
374 Some(f(a))
375}
376
377fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
378where
379 T: TryFrom<Lit>,
380{
381 let a = unwrap_expr::<T>(a)?;
382 let b = unwrap_expr::<T>(b)?;
383 Some(f(a, b))
384}
385
386#[allow(dead_code)]
387fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
388where
389 T: TryFrom<Lit>,
390{
391 let a = unwrap_expr::<T>(a)?;
392 let b = unwrap_expr::<T>(b)?;
393 let c = unwrap_expr::<T>(c)?;
394 Some(f(a, b, c))
395}
396
397fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
398where
399 T: TryFrom<Lit>,
400{
401 let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
402 Some(f(a))
403}
404
405fn vec_lit_op<T, A>(f: fn(Vec<T>) -> A, a: &Expr) -> Option<A>
406where
407 T: TryFrom<Lit>,
408{
409 let a = a.clone().unwrap_list()?;
410 let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
411 Some(f(a))
412}
413
414fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
415where
416 T: TryFrom<Lit>,
417{
418 let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
419 f(a)
420}
421
422fn opt_vec_lit_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &Expr) -> Option<A>
423where
424 T: TryFrom<Lit>,
425{
426 let a = a.clone().unwrap_list()?;
427 let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
429 f(a)
430}
431
432#[allow(dead_code)]
433fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
434where
435 T: TryFrom<Lit>,
436{
437 let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
438 let b = unwrap_expr::<T>(b)?;
439 Some(f(a, b))
440}
441
442fn unwrap_expr<T: TryFrom<Lit>>(expr: &Expr) -> Option<T> {
443 let c = eval_constant(expr)?;
444 TryInto::<T>::try_into(c).ok()
445}
446
#[cfg(test)]
mod tests {
    use crate::rules::eval_constant;
    use conjure_core::ast::{Atom, Expression, Literal};

    /// Builds a boxed literal integer expression.
    fn int(value: i32) -> Box<Expression> {
        Box::new(Expression::Atomic(
            Default::default(),
            Atom::Literal(Literal::Int(value)),
        ))
    }

    #[test]
    fn div_by_zero() {
        let expr = Expression::UnsafeDiv(Default::default(), int(1), int(0));
        assert_eq!(eval_constant(&expr), None);
    }

    #[test]
    fn safediv_by_zero() {
        let expr = Expression::SafeDiv(Default::default(), int(1), int(0));
        assert_eq!(eval_constant(&expr), None);
    }

    #[test]
    fn mod_by_zero() {
        let expr = Expression::UnsafeMod(Default::default(), int(1), int(0));
        assert_eq!(eval_constant(&expr), None);
    }

    #[test]
    fn safemod_by_zero() {
        let expr = Expression::SafeMod(Default::default(), int(1), int(0));
        assert_eq!(eval_constant(&expr), None);
    }

    #[test]
    fn division_rounds_towards_negative_infinity() {
        let expr = Expression::UnsafeDiv(Default::default(), int(-7), int(2));
        assert_eq!(eval_constant(&expr), Some(Literal::Int(-4)));
    }

    #[test]
    fn modulo_takes_sign_of_divisor() {
        let expr = Expression::UnsafeMod(Default::default(), int(-7), int(2));
        assert_eq!(eval_constant(&expr), Some(Literal::Int(1)));
    }
}