use std::fmt::{Display, Formatter};
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use enum_compatability_macro::document_compatibility;
use uniplate::derive::Uniplate;
use uniplate::Biplate;
use crate::ast::literals::Literal;
use crate::ast::symbol_table::{Name, SymbolTable};
use crate::ast::Factor;
use crate::ast::ReturnType;
use crate::metadata::Metadata;
use super::{Domain, Range};
/// A node in the expression tree of a model.
///
/// Every variant stores its [`Metadata`] as the first field; the remaining fields are the
/// operands. The `#[compatible(...)]` attributes record which targets (`Minion`, `JsonInput`,
/// `SAT`) a variant is marked compatible with; variants without one carry no such marking.
#[document_compatibility]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate)]
#[uniplate(walk_into=[Factor])]
#[biplate(to=Literal)]
#[biplate(to=Metadata)]
#[biplate(to=Factor)]
#[biplate(to=Name)]
pub enum Expression {
    /// A pair of expressions, displayed as `{a @ b}`; its `return_type` is `None`.
    ///
    /// NOTE(review): presumably the second expression is a condition that must hold for the
    /// first to be well-defined — confirm against the rules that construct it.
    Bubble(Metadata, Box<Expression>, Box<Expression>),
    /// A single factor, e.g. a literal or a reference to a named variable.
    FactorE(Metadata, Factor),
    /// Sum of the operands (integer-valued).
    #[compatible(Minion, JsonInput)]
    Sum(Metadata, Vec<Expression>),
    /// Minimum of the operands (integer-valued).
    #[compatible(JsonInput)]
    Min(Metadata, Vec<Expression>),
    /// Maximum of the operands (integer-valued).
    #[compatible(JsonInput)]
    Max(Metadata, Vec<Expression>),
    /// Logical negation of the operand (boolean-valued).
    #[compatible(JsonInput, SAT)]
    Not(Metadata, Box<Expression>),
    /// Disjunction of the operands (boolean-valued).
    #[compatible(JsonInput, SAT)]
    Or(Metadata, Vec<Expression>),
    /// Conjunction of the operands (boolean-valued).
    #[compatible(JsonInput, SAT)]
    And(Metadata, Vec<Expression>),
    /// Equality `a = b` (boolean-valued).
    #[compatible(JsonInput)]
    Eq(Metadata, Box<Expression>, Box<Expression>),
    /// Disequality `a != b` (boolean-valued).
    #[compatible(JsonInput)]
    Neq(Metadata, Box<Expression>, Box<Expression>),
    /// Comparison `a >= b` (boolean-valued).
    #[compatible(JsonInput)]
    Geq(Metadata, Box<Expression>, Box<Expression>),
    /// Comparison `a <= b` (boolean-valued).
    #[compatible(JsonInput)]
    Leq(Metadata, Box<Expression>, Box<Expression>),
    /// Comparison `a > b` (boolean-valued).
    #[compatible(JsonInput)]
    Gt(Metadata, Box<Expression>, Box<Expression>),
    /// Comparison `a < b` (boolean-valued).
    #[compatible(JsonInput)]
    Lt(Metadata, Box<Expression>, Box<Expression>),
    /// Integer division `a / b`; `domain_of` treats it exactly like `UnsafeDiv`.
    ///
    /// NOTE(review): the distinction from `UnsafeDiv` (presumably how division by zero is
    /// handled) is not visible in this file — confirm.
    SafeDiv(Metadata, Box<Expression>, Box<Expression>),
    /// Integer division `a / b` as it appears in the input.
    #[compatible(JsonInput)]
    UnsafeDiv(Metadata, Box<Expression>, Box<Expression>),
    /// Displayed as `SumEq([xs], y)`; presumably `sum(xs) = y` (boolean-valued).
    SumEq(Metadata, Vec<Expression>, Box<Expression>),
    /// Displayed as `SumGeq([xs], y)`; presumably `sum(xs) >= y` (boolean-valued).
    #[compatible(Minion)]
    SumGeq(Metadata, Vec<Expression>, Box<Expression>),
    /// Displayed as `SumLeq([xs], y)`; presumably `sum(xs) <= y` (boolean-valued).
    #[compatible(Minion)]
    SumLeq(Metadata, Vec<Expression>, Box<Expression>),
    /// Displayed as `DivEq(a, b, c)`; presumably `a / b = c` — confirm operand order.
    #[compatible(Minion)]
    DivEq(Metadata, Box<Expression>, Box<Expression>, Box<Expression>),
    /// Displayed as `Ineq(a, b, k)`; presumably Minion's `ineq`, i.e. `a <= b + k` — confirm.
    #[compatible(Minion)]
    Ineq(Metadata, Box<Expression>, Box<Expression>, Box<Expression>),
    /// All-different constraint over the operands (boolean-valued).
    #[compatible(Minion)]
    AllDiff(Metadata, Vec<Expression>),
    /// Displayed as `WatchedLiteral(name,lit)`; presumably the constraint `name = lit`.
    #[compatible(Minion)]
    WatchedLiteral(Metadata, Name, Literal),
    /// Displayed as `Reify(c, r)`; presumably `r` holds iff constraint `c` holds — confirm.
    #[compatible(Minion)]
    Reify(Metadata, Box<Expression>, Box<Expression>),
    /// Displayed as `name =aux e`; presumably declares auxiliary variable `name` equal to `e`.
    #[compatible(Minion)]
    AuxDeclaration(Metadata, Name, Box<Expression>),
}
/// Combines the domains of `exprs` left-to-right with the integer operation `op`.
///
/// All sub-expression domains are computed up front with [`Expression::domain_of`]. They are
/// then folded pairwise via `Domain::apply_i32`. The result is `None` when `exprs` is empty,
/// when any sub-expression has no domain, or when any pairwise `apply_i32` yields `None`.
fn expr_vec_to_domain_i32(
    exprs: &[Expression],
    op: fn(i32, i32) -> Option<i32>,
    vars: &SymbolTable,
) -> Option<Domain> {
    // Evaluate every domain before combining any, so the order of `domain_of` calls (and any
    // panic they raise) is independent of the combination step.
    let all_domains: Vec<Option<Domain>> =
        exprs.iter().map(|expr| expr.domain_of(vars)).collect();

    let mut pending = all_domains.into_iter();
    let seed = pending.next()?;
    pending.fold(seed, |acc, next| match (acc, next) {
        (Some(lhs), Some(rhs)) => lhs.apply_i32(op, &rhs),
        _ => None,
    })
}
/// Returns `(min, max)`: the smallest lower bound and largest upper bound across `ranges`.
///
/// A `Range::Single(i)` contributes `i` to both ends. A `Range::Bounded(lo, hi)` contributes
/// `lo` to the minimum and `hi` to the maximum.
///
/// For an empty slice the result is the inverted pair `(i32::MAX, i32::MIN)`. Callers should
/// only rely on the result for non-empty input.
fn range_vec_bounds_i32(ranges: &[Range<i32>]) -> (i32, i32) {
    ranges
        .iter()
        .fold((i32::MAX, i32::MIN), |(min, max), range| {
            let (lo, hi) = match range {
                Range::Single(i) => (*i, *i),
                Range::Bounded(i, j) => (*i, *j),
            };
            (min.min(lo), max.max(hi))
        })
}
impl Expression {
pub fn domain_of(&self, vars: &SymbolTable) -> Option<Domain> {
let ret = match self {
Expression::FactorE(_, Factor::Reference(name)) => Some(vars.get(name)?.domain.clone()),
Expression::FactorE(_, Factor::Literal(Literal::Int(n))) => {
Some(Domain::IntDomain(vec![Range::Single(*n)]))
}
Expression::FactorE(_, Factor::Literal(Literal::Bool(_))) => Some(Domain::BoolDomain),
Expression::Sum(_, exprs) => expr_vec_to_domain_i32(exprs, |x, y| Some(x + y), vars),
Expression::Min(_, exprs) => {
expr_vec_to_domain_i32(exprs, |x, y| Some(if x < y { x } else { y }), vars)
}
Expression::Max(_, exprs) => {
expr_vec_to_domain_i32(exprs, |x, y| Some(if x > y { x } else { y }), vars)
}
Expression::UnsafeDiv(_, a, b) | Expression::SafeDiv(_, a, b) => {
a.domain_of(vars)?.apply_i32(
|x, y| if y != 0 { Some(x / y) } else { None },
&b.domain_of(vars)?,
)
}
_ => todo!("Calculate domain of {:?}", self),
};
match ret {
Some(Domain::IntDomain(ranges)) if ranges.len() > 1 => {
let (min, max) = range_vec_bounds_i32(&ranges);
Some(Domain::IntDomain(vec![Range::Bounded(min, max)]))
}
_ => ret,
}
}
pub fn get_meta(&self) -> Metadata {
<Expression as Biplate<Metadata>>::children_bi(self)[0].clone()
}
pub fn set_meta(&self, meta: Metadata) {
<Expression as Biplate<Metadata>>::transform_bi(self, Arc::new(move |_| meta.clone()));
}
pub fn can_be_undefined(&self) -> bool {
match self {
Expression::FactorE(_, _) => false,
_ => true,
}
}
pub fn return_type(&self) -> Option<ReturnType> {
match self {
Expression::FactorE(_, Factor::Literal(Literal::Int(_))) => Some(ReturnType::Int),
Expression::FactorE(_, Factor::Literal(Literal::Bool(_))) => Some(ReturnType::Bool),
Expression::FactorE(_, Factor::Reference(_)) => None,
Expression::Sum(_, _) => Some(ReturnType::Int),
Expression::Min(_, _) => Some(ReturnType::Int),
Expression::Max(_, _) => Some(ReturnType::Int),
Expression::Not(_, _) => Some(ReturnType::Bool),
Expression::Or(_, _) => Some(ReturnType::Bool),
Expression::And(_, _) => Some(ReturnType::Bool),
Expression::Eq(_, _, _) => Some(ReturnType::Bool),
Expression::Neq(_, _, _) => Some(ReturnType::Bool),
Expression::Geq(_, _, _) => Some(ReturnType::Bool),
Expression::Leq(_, _, _) => Some(ReturnType::Bool),
Expression::Gt(_, _, _) => Some(ReturnType::Bool),
Expression::Lt(_, _, _) => Some(ReturnType::Bool),
Expression::SafeDiv(_, _, _) => Some(ReturnType::Int),
Expression::UnsafeDiv(_, _, _) => Some(ReturnType::Int),
Expression::SumEq(_, _, _) => Some(ReturnType::Bool),
Expression::SumGeq(_, _, _) => Some(ReturnType::Bool),
Expression::SumLeq(_, _, _) => Some(ReturnType::Bool),
Expression::DivEq(_, _, _, _) => Some(ReturnType::Bool),
Expression::Ineq(_, _, _, _) => Some(ReturnType::Bool),
Expression::AllDiff(_, _) => Some(ReturnType::Bool),
Expression::Bubble(_, _, _) => None, Expression::WatchedLiteral(_, _, _) => Some(ReturnType::Bool),
Expression::Reify(_, _, _) => Some(ReturnType::Bool),
Expression::AuxDeclaration(_, _, _) => Some(ReturnType::Bool),
}
}
pub fn is_clean(&self) -> bool {
let metadata = self.get_meta();
metadata.clean
}
pub fn set_clean(&mut self, bool_value: bool) {
let mut metadata = self.get_meta();
metadata.clean = bool_value;
self.set_meta(metadata);
}
pub fn as_factor(&self) -> Option<Factor> {
if let Expression::FactorE(_m, f) = self {
Some(f.clone())
} else {
None
}
}
}
/// Formats a slice of displayable items as `[a, b, c]`, using each item's `Display` impl.
///
/// An empty slice renders as `[]`. The function is generic so it can render any list of
/// displayable values; existing callers pass `&Vec<Expression>` and coerce to `&[Expression]`.
fn display_expressions<T: Display>(expressions: &[T]) -> String {
    let rendered: Vec<String> = expressions.iter().map(ToString::to_string).collect();
    format!("[{}]", rendered.join(", "))
}
impl From<Factor> for Expression {
    /// Wraps `value` in an `Expression::FactorE` node with fresh metadata.
    fn from(value: Factor) -> Self {
        Self::FactorE(Metadata::new(), value)
    }
}

impl From<i32> for Expression {
    /// Builds an integer-literal expression with fresh metadata.
    fn from(i: i32) -> Self {
        Factor::Literal(Literal::Int(i)).into()
    }
}

impl From<bool> for Expression {
    /// Builds a boolean-literal expression with fresh metadata.
    fn from(b: bool) -> Self {
        Factor::Literal(Literal::Bool(b)).into()
    }
}
impl Display for Expression {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match &self {
Expression::FactorE(_, factor) => factor.fmt(f),
Expression::Sum(_, expressions) => {
write!(f, "Sum({})", display_expressions(expressions))
}
Expression::Min(_, expressions) => {
write!(f, "Min({})", display_expressions(expressions))
}
Expression::Max(_, expressions) => {
write!(f, "Max({})", display_expressions(expressions))
}
Expression::Not(_, expr_box) => {
write!(f, "Not({})", expr_box.clone())
}
Expression::Or(_, expressions) => {
write!(f, "Or({})", display_expressions(expressions))
}
Expression::And(_, expressions) => {
write!(f, "And({})", display_expressions(expressions))
}
Expression::Eq(_, box1, box2) => {
write!(f, "({} = {})", box1.clone(), box2.clone())
}
Expression::Neq(_, box1, box2) => {
write!(f, "({} != {})", box1.clone(), box2.clone())
}
Expression::Geq(_, box1, box2) => {
write!(f, "({} >= {})", box1.clone(), box2.clone())
}
Expression::Leq(_, box1, box2) => {
write!(f, "({} <= {})", box1.clone(), box2.clone())
}
Expression::Gt(_, box1, box2) => {
write!(f, "({} > {})", box1.clone(), box2.clone())
}
Expression::Lt(_, box1, box2) => {
write!(f, "({} < {})", box1.clone(), box2.clone())
}
Expression::SumEq(_, expressions, expr_box) => {
write!(
f,
"SumEq({}, {})",
display_expressions(expressions),
expr_box.clone()
)
}
Expression::SumGeq(_, box1, box2) => {
write!(f, "SumGeq({}, {})", display_expressions(box1), box2.clone())
}
Expression::SumLeq(_, box1, box2) => {
write!(f, "SumLeq({}, {})", display_expressions(box1), box2.clone())
}
Expression::Ineq(_, box1, box2, box3) => write!(
f,
"Ineq({}, {}, {})",
box1.clone(),
box2.clone(),
box3.clone()
),
Expression::AllDiff(_, expressions) => {
write!(f, "AllDiff({})", display_expressions(expressions))
}
Expression::Bubble(_, box1, box2) => {
write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
}
Expression::SafeDiv(_, box1, box2) => {
write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
}
Expression::UnsafeDiv(_, box1, box2) => {
write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
}
Expression::DivEq(_, box1, box2, box3) => {
write!(
f,
"DivEq({}, {}, {})",
box1.clone(),
box2.clone(),
box3.clone()
)
}
Expression::WatchedLiteral(_, x, l) => {
write!(f, "WatchedLiteral({},{})", x, l)
}
Expression::Reify(_, box1, box2) => {
write!(f, "Reify({}, {})", box1.clone(), box2.clone())
}
Expression::AuxDeclaration(_, n, e) => {
write!(f, "{} =aux {}", n, e.clone())
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::DecisionVariable;

    /// Integer literal expression with fresh metadata.
    fn int_lit(n: i32) -> Expression {
        Expression::FactorE(Metadata::new(), Factor::Literal(Literal::Int(n)))
    }

    /// Boolean literal expression with fresh metadata.
    fn bool_lit(b: bool) -> Expression {
        Expression::FactorE(Metadata::new(), Factor::Literal(Literal::Bool(b)))
    }

    /// Reference to the variable `MachineName(0)`.
    fn ref0() -> Expression {
        Expression::FactorE(Metadata::new(), Factor::Reference(Name::MachineName(0)))
    }

    /// `Sum` over the given operands.
    fn sum_of(operands: Vec<Expression>) -> Expression {
        Expression::Sum(Metadata::new(), operands)
    }

    /// Symbol table declaring `MachineName(0)` with `domain`.
    fn table_with_ref0(domain: Domain) -> SymbolTable {
        let mut table = SymbolTable::new();
        table.insert(Name::MachineName(0), DecisionVariable::new(domain));
        table
    }

    #[test]
    fn test_domain_of_constant_sum() {
        let expr = sum_of(vec![int_lit(1), int_lit(2)]);
        let expected = Some(Domain::IntDomain(vec![Range::Single(3)]));
        assert_eq!(expr.domain_of(&SymbolTable::new()), expected);
    }

    #[test]
    fn test_domain_of_constant_invalid_type() {
        let expr = sum_of(vec![int_lit(1), bool_lit(true)]);
        assert_eq!(expr.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_empty_sum() {
        let expr = sum_of(Vec::new());
        assert_eq!(expr.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference() {
        let table = table_with_ref0(Domain::IntDomain(vec![Range::Single(1)]));
        let expected = Some(Domain::IntDomain(vec![Range::Single(1)]));
        assert_eq!(ref0().domain_of(&table), expected);
    }

    #[test]
    fn test_domain_of_reference_not_found() {
        assert_eq!(ref0().domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference_sum_single() {
        let table = table_with_ref0(Domain::IntDomain(vec![Range::Single(1)]));
        let expr = sum_of(vec![ref0(), ref0()]);
        let expected = Some(Domain::IntDomain(vec![Range::Single(2)]));
        assert_eq!(expr.domain_of(&table), expected);
    }

    #[test]
    fn test_domain_of_reference_sum_bounded() {
        let table = table_with_ref0(Domain::IntDomain(vec![Range::Bounded(1, 2)]));
        let expr = sum_of(vec![ref0(), ref0()]);
        let expected = Some(Domain::IntDomain(vec![Range::Bounded(2, 4)]));
        assert_eq!(expr.domain_of(&table), expected);
    }
}