Lines
0 %
Functions
#![allow(dead_code)]
use crate::ast::{AbstractLiteral, Atom, Expression as Expr, Literal as Lit, Metadata, matrix};
use crate::into_matrix;
use itertools::{Itertools as _, izip};
use std::cmp::Ordering as CmpOrdering;
use std::collections::HashSet;
/// Simplify an expression to a constant literal if possible.
///
/// Returns:
/// - `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
/// - `Some(lit)` if the expression can be simplified to the constant `lit`
///
/// # Panics
///
/// - On the variants that are still `todo!()` (`Defined`, `Range`, `Image`, `ImageSet`,
///   `PreImage`, `Inverse`, `Restrict`) instead of returning `None`.
/// - When a tuple or record literal is indexed with more than one index, and when a slice
///   expression has no hole dimension.
/// - On `i32` overflow in debug builds (e.g. `abs`, `pow`, `sum`, `product`).
pub fn eval_constant(expr: &Expr) -> Option<Lit> {
match expr {
// Strict superset: every element of `b` is in `a`, and `a` has at least one element not in `b`.
Expr::Supset(_, a, b) => match (a.as_ref(), b.as_ref()) {
(
Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
) => {
let a_set: HashSet<Lit> = a.iter().cloned().collect();
let b_set: HashSet<Lit> = b.iter().cloned().collect();
if a_set.difference(&b_set).count() > 0 {
Some(Lit::Bool(a_set.is_superset(&b_set)))
} else {
Some(Lit::Bool(false))
}
_ => None,
},
Expr::SupsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
) => Some(Lit::Bool(
a.iter()
.cloned()
.collect::<HashSet<Lit>>()
.is_superset(&b.iter().cloned().collect::<HashSet<Lit>>()),
)),
// Strict subset: the mirror image of the `Supset` arm.
Expr::Subset(_, a, b) => match (a.as_ref(), b.as_ref()) {
if b_set.difference(&a_set).count() > 0 {
Some(Lit::Bool(a_set.is_subset(&b_set)))
Expr::SubsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
.is_subset(&b.iter().cloned().collect::<HashSet<Lit>>()),
// Intersection keeps the iteration order of `a` and skips duplicate elements.
Expr::Intersect(_, a, b) => match (a.as_ref(), b.as_ref()) {
let mut res: Vec<Lit> = Vec::new();
for lit in a.iter() {
if b.contains(lit) && !res.contains(lit) {
res.push(lit.clone());
Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
Expr::Union(_, a, b) => match (a.as_ref(), b.as_ref()) {
// Elements of `b` are appended to `res` only if they are not already present.
for lit in b.iter() {
if !res.contains(lit) {
// Only membership of an integer literal in a literal set is folded here.
Expr::In(_, a, b) => {
if let (
Expr::Atomic(_, Atom::Literal(Lit::Int(c))),
Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(d)))),
) = (a.as_ref(), b.as_ref())
{
for lit in d.iter() {
if let Lit::Int(x) = lit
&& c == x
return Some(Lit::Bool(true));
None
Expr::FromSolution(_, _) => None,
Expr::DominanceRelation(_, _) => None,
Expr::InDomain(_, e, domain) => {
let Expr::Atomic(_, Atom::Literal(lit)) = e.as_ref() else {
return None;
};
// An `Err` from `domain.contains` becomes `None` rather than a constant.
domain.contains(lit).ok().map(Into::into)
Expr::Atomic(_, Atom::Literal(c)) => Some(c.clone()),
Expr::Atomic(_, Atom::Reference(_)) => None,
Expr::AbstractLiteral(_, a) => {
if let AbstractLiteral::Set(s) = a {
let mut copy = Vec::new();
for expr in s.iter() {
if let Expr::Atomic(_, Atom::Literal(lit)) = expr {
copy.push(lit.clone());
Some(Lit::AbstractLiteral(AbstractLiteral::Set(copy)))
Expr::Comprehension(_, _) => None,
Expr::UnsafeIndex(_, subject, indices) | Expr::SafeIndex(_, subject, indices) => {
let subject: Lit = subject.as_ref().clone().into_literal()?;
let indices: Vec<Lit> = indices
.iter()
.map(|x| x.into_literal())
.collect::<Option<Vec<Lit>>>()?;
match subject {
Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) => {
// Linear scan of the flattened matrix for the element whose full index matches.
matrix::flatten_enumerate(subject)
.find(|(i, _)| i == &indices)
.map(|(_, x)| x)
Lit::AbstractLiteral(subject @ AbstractLiteral::Tuple(_)) => {
let AbstractLiteral::Tuple(elems) = subject else {
// NOTE(review): panics instead of returning `None` for multi-index access.
assert!(indices.len() == 1, "nested tuples not supported yet");
let Lit::Int(index) = indices[0].clone() else {
// Tuple indices are 1-based; indices outside 1..=elems.len() are rejected here.
if elems.len() < index as usize || index < 1 {
// -1 for 0-indexing vs 1-indexing
let item = elems[index as usize - 1].clone();
Some(item)
Lit::AbstractLiteral(subject @ AbstractLiteral::Record(_)) => {
let AbstractLiteral::Record(elems) = subject else {
// NOTE(review): panics instead of returning `None` for multi-index access.
assert!(indices.len() == 1, "nested record not supported yet");
Some(item.value)
// A slice has exactly one hole (`None`) index; every value of the hole dimension's
// domain is substituted into it to build the full index of each element in the slice.
Expr::UnsafeSlice(_, subject, indices) | Expr::SafeSlice(_, subject, indices) => {
let Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) = subject else {
let hole_dim = indices
.position(|x| x.is_none())
.expect("slice expression should have a hole dimension");
let missing_domain = matrix::index_domains(subject.clone())[hole_dim].clone();
let indices: Vec<Option<Lit>> = indices
.map(|x| {
// the outer option represents success of this iterator, the inner the index
// slice.
match x {
Some(x) => x.into_literal().map(Some),
None => Some(None),
})
.collect::<Option<Vec<Option<Lit>>>>()?;
let indices_in_slice: Vec<Vec<Lit>> = missing_domain
.values()
.ok()?
.map(|i| {
let mut indices = indices.clone();
indices[hole_dim] = Some(i);
// These unwraps will only fail if we have multiple holes.
// As this is invalid, panicking is fine.
indices.into_iter().map(|x| x.unwrap()).collect_vec()
.collect_vec();
// Note: indices_in_slice is not necessarily sorted, so this is the best way.
// NOTE(review): `Vec::contains` per element makes this O(elements * slice_len); since
// `Lit: Hash` (it is stored in `HashSet` above) a `HashSet<Vec<Lit>>` would make it linear.
let elems = matrix::flatten_enumerate(subject)
.filter(|(i, _)| indices_in_slice.contains(i))
.map(|(_, elem)| elem)
.collect();
Some(Lit::AbstractLiteral(into_matrix![elems]))
// NOTE(review): `i32::abs` overflows for `i32::MIN` (panics in debug builds).
Expr::Abs(_, e) => un_op::<i32, i32>(|a| a.abs(), e).map(Lit::Int),
// Only `Eq` also compares booleans; the other comparisons fold integer operands only.
Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
.or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
.map(Lit::Bool),
Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Lit::Bool),
Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Lit::Bool),
Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Lit::Bool),
Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Lit::Bool),
Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Lit::Bool),
Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Lit::Bool),
Expr::And(_, e) => {
vec_lit_op::<bool, bool>(|e| e.iter().all(|&e| e), e.as_ref()).map(Lit::Bool)
Expr::Root(_, _) => None,
Expr::Or(_, es) => {
// possibly cheating; definitely should be in partial eval instead
// A literal `true` element presumably short-circuits the whole disjunction, even when
// other elements are not constant.
for e in (**es).clone().unwrap_list()? {
if let Expr::Atomic(_, Atom::Literal(Lit::Bool(true))) = e {
vec_lit_op::<bool, bool>(|e| e.iter().any(|&e| e), es.as_ref()).map(Lit::Bool)
// Both operands must already be boolean atoms; `false -> x` is not folded when `x` is
// not a constant.
Expr::Imply(_, box1, box2) => {
let a: &Atom = (&**box1).try_into().ok()?;
let b: &Atom = (&**box2).try_into().ok()?;
let a: bool = a.try_into().ok()?;
let b: bool = b.try_into().ok()?;
if a {
// true -> b ~> b
Some(Lit::Bool(b))
// false -> b ~> true
Some(Lit::Bool(true))
Expr::Iff(_, box1, box2) => {
Some(Lit::Bool(a == b))
// NOTE(review): `i32` sums and products can overflow (panics in debug builds).
Expr::Sum(_, exprs) => vec_lit_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Lit::Int),
Expr::Product(_, exprs) => {
vec_lit_op::<i32, i32>(|e| e.iter().product(), exprs).map(Lit::Int)
// FlatIneq(a, b, c) holds when a <= b + c.
Expr::FlatIneq(_, a, b, c) => {
let a: i32 = a.try_into().ok()?;
let b: i32 = b.try_into().ok()?;
let c: i32 = c.try_into().ok()?;
Some(Lit::Bool(a <= b + c))
Expr::FlatSumGeq(_, exprs, a) => {
let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
let n: i32 = atom.try_into().ok()?;
let acc = acc + n;
Some(acc)
})?;
Some(Lit::Bool(sum >= a.try_into().ok()?))
Expr::FlatSumLeq(_, exprs, a) => {
// An empty list yields `None`, because `Iterator::min`/`max` return `None` for it.
Expr::Min(_, e) => {
opt_vec_lit_op::<i32, i32>(|e| e.iter().min().copied(), e.as_ref()).map(Lit::Int)
Expr::Max(_, e) => {
opt_vec_lit_op::<i32, i32>(|e| e.iter().max().copied(), e.as_ref()).map(Lit::Int)
// Integer division rounding towards negative infinity; a zero divisor is special-cased.
// NOTE(review): the `f32` round-trip loses precision once |a| or |b| exceeds 2^24, so the
// result can be wrong for large operands; an exact integer floor division would be safer.
Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
if unwrap_expr::<i32>(b)? == 0 {
bin_op::<i32, i32>(|a, b| ((a as f32) / (b as f32)).floor() as i32, a, b).map(Lit::Int)
// NOTE(review): same `f32` precision caveat as the division arm.
Expr::UnsafeMod(_, a, b) | Expr::SafeMod(_, a, b) => {
bin_op::<i32, i32>(|a, b| a - b * (a as f32 / b as f32).floor() as i32, a, b)
.map(Lit::Int)
// NOTE(review): same `f32` precision caveat as the division arm.
Expr::MinionDivEqUndefZero(_, a, b, c) => {
// div always rounds down
if b == 0 {
let a = a as f32;
let b = b as f32;
let div: i32 = (a / b).floor() as i32;
Some(Lit::Bool(div == c))
Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Lit::Bool),
Expr::MinionReify(_, a, b) => {
let result = eval_constant(a)?;
let result: bool = result.try_into().ok()?;
Some(Lit::Bool(b == result))
Expr::MinionReifyImply(_, a, b) => {
if b {
Some(Lit::Bool(result))
// NOTE(review): same `f32` precision caveat as the division arm.
Expr::MinionModuloEqUndefZero(_, a, b, c) => {
// From Savile Row. Same semantics as division.
//
// a - (b * floor(a/b))
// We don't use % as it has the same semantics as /. We don't use / as we want to round
// down instead, not towards zero.
let modulo = a - b * (a as f32 / b as f32).floor() as i32;
Some(Lit::Bool(modulo == c))
Expr::MinionPow(_, a, b, c) => {
// only available for positive a b c
if a <= 0 {
if b <= 0 {
if c <= 0 {
// NOTE(review): `^` is bitwise XOR in Rust, not exponentiation, so this does not check
// a^b = c. It likely should be `a.checked_pow(b as u32) == Some(c)`.
Some(Lit::Bool(a ^ b == c))
Expr::MinionWInSet(_, _, _) => None,
// `intervals` is read as consecutive inclusive `[lower, upper]` pairs.
Expr::MinionWInIntervalSet(_, x, intervals) => {
let x_lit: &Lit = x.try_into().ok()?;
let x_lit = match x_lit.clone() {
Lit::Int(i) => Some(i),
Lit::Bool(true) => Some(1),
Lit::Bool(false) => Some(0),
}?;
let mut intervals = intervals.iter();
loop {
let Some(lower) = intervals.next() else {
break;
let Some(upper) = intervals.next() else {
if &x_lit >= lower && &x_lit <= upper {
Expr::Flatten(_, _, _) => {
// TODO
// A repeated integer/boolean literal makes the constraint false; abstract literals are
// not folded.
Expr::AllDiff(_, e) => {
let es = (**e).clone().unwrap_list()?;
let mut lits: HashSet<Lit> = HashSet::new();
for expr in es {
let Expr::Atomic(_, Atom::Literal(x)) = expr else {
Lit::Int(_) | Lit::Bool(_) => {
if lits.contains(&x) {
return Some(Lit::Bool(false));
lits.insert(x.clone());
Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
Expr::FlatAllDiff(_, es) => {
for atom in es {
let Atom::Literal(x) = atom else {
if lits.contains(x) {
Expr::FlatWatchedLiteral(_, _, _) => None,
Expr::AuxDeclaration(_, _, _) => None,
// NOTE(review): if `a` is an `i32` here, `-a` overflows for `i32::MIN`.
Expr::Neg(_, a) => {
let a: &Atom = a.try_into().ok()?;
Some(Lit::Int(-a))
Expr::Minus(_, a, b) => {
let b: &Atom = b.try_into().ok()?;
Some(Lit::Int(a - b))
Expr::FlatMinusEq(_, a, b) => {
Some(Lit::Bool(a == -b))
Expr::FlatProductEq(_, a, b, c) => {
Some(Lit::Bool(a * b == c))
Expr::FlatWeightedSumLeq(_, cs, vs, total) => {
let cs: Vec<i32> = cs
.map(|x| TryInto::<i32>::try_into(x).ok())
.collect::<Option<Vec<i32>>>()?;
let vs: Vec<i32> = vs
let total: i32 = total.try_into().ok()?;
// `izip!` stops at the shorter of `cs` and `vs`, so surplus coefficients or variables are
// silently ignored. NOTE(review): the `i32` fold can overflow.
let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
Some(Lit::Bool(sum <= total))
Expr::FlatWeightedSumGeq(_, cs, vs, total) => {
Some(Lit::Bool(sum >= total))
Expr::FlatAbsEq(_, x, y) => {
let x: i32 = x.try_into().ok()?;
let y: i32 = y.try_into().ok()?;
Some(Lit::Bool(x == y.abs()))
Expr::UnsafePow(_, a, b) | Expr::SafePow(_, a, b) => {
// Folded only when the power is defined: not 0^0 and a non-negative exponent.
// NOTE(review): `i32::pow` can overflow (panics in debug builds); `checked_pow` would not.
if (a != 0 || b != 0) && b >= 0 {
Some(Lit::Int(a.pow(b as u32)))
Expr::Scope(_, _) => None,
Expr::Metavar(_, _) => None,
Expr::MinionElementOne(_, _, _, _) => None,
Expr::ToInt(_, expression) => {
let lit = (**expression).clone().into_literal()?;
match lit {
Lit::Int(_) => Some(lit),
Lit::Bool(true) => Some(Lit::Int(1)),
Lit::Bool(false) => Some(Lit::Int(0)),
Expr::SATInt(_, _) => None,
Expr::PairwiseSum(_, a, b) => {
match ((**a).clone().into_literal()?, (**b).clone().into_literal()?) {
(Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int + b_int)),
Expr::PairwiseProduct(_, a, b) => {
(Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int * b_int)),
// NOTE(review): these variants panic via `todo!()` instead of returning `None`
// ("cannot simplify"), so calling `eval_constant` on any of them aborts.
Expr::Defined(_, _) => todo!(),
Expr::Range(_, _) => todo!(),
Expr::Image(_, _, _) => todo!(),
Expr::ImageSet(_, _, _) => todo!(),
Expr::PreImage(_, _, _) => todo!(),
Expr::Inverse(_, _, _) => todo!(),
Expr::Restrict(_, _, _) => todo!(),
// Lexicographic `<`: the first differing pair decides; if all compared pairs are equal,
// the lengths decide (a proper prefix is smaller).
Expr::LexLt(_, a, b) => {
let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
pairs
.find_map(|(a, b)| match a.cmp(b) {
CmpOrdering::Less => Some(true), // First difference is <
CmpOrdering::Greater => Some(false), // First difference is >
CmpOrdering::Equal => None, // No difference
.unwrap_or(a_len < b_len) // [1,1] <lex [1,1,x]
Some(lt.into())
Expr::LexLeq(_, a, b) => {
CmpOrdering::Less => Some(true),
CmpOrdering::Greater => Some(false),
CmpOrdering::Equal => None,
.unwrap_or(a_len <= b_len) // [1,1] <=lex [1,1,x]
// `>` and `>=` are evaluated by swapping the operands of `<` and `<=`.
Expr::LexGt(_, a, b) => eval_constant(&Expr::LexLt(Metadata::new(), b.clone(), a.clone())),
Expr::LexGeq(_, a, b) => {
eval_constant(&Expr::LexLeq(Metadata::new(), b.clone(), a.clone()))
Expr::FlatLexLt(_, a, b) => {
let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
.unwrap_or(a_len < b_len)
Expr::FlatLexLeq(_, a, b) => {
.unwrap_or(a_len <= b_len)
/// Evaluates `a` to a constant of type `T` and applies the unary function `f` to it.
///
/// Returns `None` if `a` does not evaluate to a constant convertible to `T`.
pub fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
where
T: TryFrom<Lit>,
let a = unwrap_expr::<T>(a)?;
Some(f(a))
/// Evaluates `a` and `b` to constants of type `T` and applies the binary function `f`.
///
/// Returns `None` when an operand does not evaluate to a constant convertible to `T`.
pub fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
let b = unwrap_expr::<T>(b)?;
Some(f(a, b))
/// Ternary counterpart of [`bin_op`]: evaluates `a`, `b` and `c` to constants of type `T` and
/// applies `f` to them.
// NOTE(review): this `#[allow(dead_code)]` is redundant with the crate-level
// `#![allow(dead_code)]` at the top of the file.
#[allow(dead_code)]
pub fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
let c = unwrap_expr::<T>(c)?;
Some(f(a, b, c))
/// Evaluates every expression in `a` to a constant of type `T`, returning `None` if any of
/// them cannot be evaluated; the collected values are presumably then passed to `f`.
pub fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
/// Like [`vec_op`], but takes a single matrix expression and works on its elements.
///
/// The matrix's index domain is discarded. Returns `None` if `a` is not a matrix (and,
/// presumably, if an element does not evaluate to a `T`).
pub fn vec_lit_op<T, A>(f: fn(Vec<T>) -> A, a: &Expr) -> Option<A>
// we don't care about preserving indices here, as we will be getting rid of the vector
// anyways!
let a = a.clone().unwrap_matrix_unchecked()?.0;
/// Callback used by the pairwise helpers: receives the element pairs `(a[i], b[i])` and the
/// original lengths `(a_len, b_len)` of the two lists.
type PairsCallback<T, A> = fn(Vec<(T, T)>, (usize, usize)) -> A;
/// Evaluates two list expressions position-wise and passes the pairs `(a[i], b[i])` to `f`,
/// together with the original lengths of both lists.
///
/// Pairs are built with `zip`, so only the common prefix is evaluated: elements past the end
/// of the shorter list are never converted and may be non-constant. Returns `None` if either
/// expression is not a matrix or if a paired element does not evaluate to a `T`.
fn vec_expr_pairs_op<T, A>(a: &Expr, b: &Expr, f: PairsCallback<T, A>) -> Option<A>
let a_exprs = a.clone().unwrap_matrix_unchecked()?.0;
let b_exprs = b.clone().unwrap_matrix_unchecked()?.0;
let lens = (a_exprs.len(), b_exprs.len());
let lit_pairs = std::iter::zip(a_exprs, b_exprs)
.map(|(a, b)| Some((unwrap_expr(&a)?, unwrap_expr(&b)?)))
.collect::<Option<Vec<(T, T)>>>()?;
Some(f(lit_pairs, lens))
/// Same as [`vec_expr_pairs_op`], but over slices of atoms.
///
/// Each paired atom is converted with `TryFrom<Atom>`; a failed conversion yields `None`.
/// Atoms beyond the end of the shorter slice are not converted.
fn atoms_pairs_op<T, A>(a: &[Atom], b: &[Atom], f: PairsCallback<T, A>) -> Option<A>
T: TryFrom<Atom>,
let lit_pairs = Iterator::zip(a.iter(), b.iter())
.map(|(a, b)| Some((a.clone().try_into().ok()?, b.clone().try_into().ok()?)))
Some(f(lit_pairs, (a.len(), b.len())))
/// Like [`vec_op`], but `f` itself may fail: its `Option` result is returned as-is.
pub fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
f(a)
/// Like [`vec_lit_op`], but `f` may fail (e.g. `min`/`max` of an empty list).
///
/// Returns `None` if `a` is not a list or if `f` returns `None` (and, presumably, if an
/// element does not evaluate to a `T`).
pub fn opt_vec_lit_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &Expr) -> Option<A>
let a = a.clone().unwrap_list()?;
// FIXME: deal with explicit matrix domains
/// Applies `f` to the constant values of the list `a` and of `b` (presumably returning `None`
/// if any of them is not a constant of type `T` — TODO confirm).
pub fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
/// Evaluates `expr` with [`eval_constant`] and converts the resulting literal to `T`.
///
/// Returns `None` if `expr` is not constant or the literal cannot be converted to `T`.
/// Panics wherever [`eval_constant`] panics.
pub fn unwrap_expr<T: TryFrom<Lit>>(expr: &Expr) -> Option<T> {
let c = eval_constant(expr)?;
TryInto::<T>::try_into(c).ok()