1
use crate::ast::{Atom, Domain, Literal, Moo, Name, Range};
2
use crate::bug;
3
use crate::solver::{SolverError, SolverResult};
4
use conjure_cp_core::ast::GroundDomain;
5
use z3::{Sort, Symbol, ast::*};
6

            
7
use super::store::SymbolStore;
8
use super::{IntTheory, TheoryConfig};
9

            
10
/// Width, in bits, of every bit-vector used to encode integers under the bit-vector theory.
/// Integers are represented as signed two's-complement values of the same width as `i32`.
pub const BV_SIZE: u32 = i32::BITS;
12

            
13
/// A function which encodes a restriction for a specific variable. Given an AST of the correct
/// sort, it constructs a boolean assertion which ensures the variable lies within its domain.
///
/// Boxed so that restrictions built for different kinds of domain (booleans, integer ranges,
/// matrices) can all be returned through one type by `domain_to_sort`.
type RestrictFn = Box<dyn Fn(&Dynamic) -> Bool>;
16

            
17
/// Returns the Oxide domain as a Z3 sort, along with a function to restrict a variable of that sort
18
/// to the original domain's restrictions.
19
pub fn domain_to_sort(
20
    domain: &GroundDomain,
21
    theories: &TheoryConfig,
22
) -> SolverResult<(Sort, RestrictFn)> {
23
    use IntTheory::{Bv, Lia};
24

            
25
    match (theories.ints, domain) {
26
        // Booleans of course have the same domain in SMT, so no restriction required
27
        (_, GroundDomain::Bool) => Ok((Sort::bool(), Box::new(|_| Bool::from_bool(true)))),
28

            
29
        // Return a disjunction of the restrictions each range of the domain enforces
30
        // I.e. `x: int(1, 3..5)` -> `or([x = 1, x >= 3 /\ x <= 5])`
31
        (Lia, GroundDomain::Int(ranges)) => {
32
            let ranges = ranges.clone();
33
            let restrict_fn = move |ast: &Dynamic| {
34
                let int = ast.as_int().unwrap();
35
                let restrictions: Vec<_> = ranges
36
                    .iter()
37
                    .map(|range| int_range_to_int_restriction(&int, range))
38
                    .collect();
39
                Bool::or(restrictions.as_slice())
40
            };
41
            Ok((Sort::int(), Box::new(restrict_fn)))
42
        }
43
        (Bv, GroundDomain::Int(ranges)) => {
44
            let ranges = ranges.clone();
45
            let restrict_fn = move |ast: &Dynamic| {
46
                let bv = ast.as_bv().unwrap();
47
                let restrictions: Vec<_> = ranges
48
                    .iter()
49
                    .map(|range| int_range_to_bv_restriction(&bv, range))
50
                    .collect();
51
                Bool::or(restrictions.as_slice())
52
            };
53
            Ok((Sort::bitvector(BV_SIZE), Box::new(restrict_fn)))
54
        }
55

            
56
        (_, GroundDomain::Matrix(val_domain, idx_domains)) => {
57
            // We constrain the inner values of the domain recursively
58
            // I.e. every way to index the array must give a value in the correct domain
59

            
60
            let (range_sort, restrict_val) = match idx_domains.as_slice() {
61
                [_] => domain_to_sort(val_domain, theories),
62
                [_, tail @ ..] => {
63
                    // Treat as a matrix containing (n-1)-dimensional matrices
64
                    let inner_domain = GroundDomain::Matrix(val_domain.clone(), tail.to_vec());
65
                    domain_to_sort(&inner_domain, theories)
66
                }
67
                [] => Err(SolverError::ModelInvalid(
68
                    "empty matrix index domain".into(),
69
                )),
70
            }?;
71
            let idx_domain = &idx_domains[0];
72
            let (domain_sort, _) = domain_to_sort(idx_domain.as_ref(), theories)?;
73

            
74
            let idx_asts = domain_to_ast_vec(theories, idx_domain.as_ref())?;
75
            let restrict_fn = move |ast: &Dynamic| {
76
                let arr = ast.as_array().unwrap();
77
                let restrictions: Vec<_> = idx_asts
78
                    .iter()
79
                    .map(|idx_ast| (restrict_val)(&arr.select(idx_ast)))
80
                    .collect();
81
                Bool::and(restrictions.as_slice())
82
            };
83
            Ok((
84
                Sort::array(&domain_sort, &range_sort),
85
                Box::new(restrict_fn),
86
            ))
87
        }
88

            
89
        _ => Err(SolverError::ModelFeatureNotImplemented(format!(
90
            "sort for '{domain}' not implemented"
91
        ))),
92
    }
93
}
94

            
95
/// Returns a domain as a vector of Z3 AST literals.
96
pub fn domain_to_ast_vec(
97
    theory_config: &TheoryConfig,
98
    domain: &GroundDomain,
99
) -> SolverResult<Vec<Dynamic>> {
100
    let lits = domain
101
        .values()
102
        .map_err(|err| SolverError::Runtime(err.to_string()))?;
103
    lits.map(|lit| literal_to_ast(theory_config, &lit))
104
        .collect()
105
}
106

            
107
/// Returns a boolean expression restricting the given integer variable to the given range.
108
pub fn int_range_to_int_restriction(var: &Int, range: &Range<i32>) -> Bool {
109
    match range {
110
        Range::Single(n) => var.eq(Int::from(*n)),
111
        Range::UnboundedL(r) => var.le(Int::from(*r)),
112
        Range::UnboundedR(l) => var.ge(Int::from(*l)),
113
        Range::Bounded(l, r) => Bool::and(&[var.ge(Int::from(*l)), var.le(Int::from(*r))]),
114
        _ => bug!("int ranges should not be unbounded"),
115
    }
116
}
117

            
118
/// Returns a boolean expression restricting the given bitvector variable to the given integer range.
119
pub fn int_range_to_bv_restriction(var: &BV, range: &Range<i32>) -> Bool {
120
    match range {
121
        Range::Single(n) => var.eq(BV::from_i64(*n as i64, BV_SIZE)),
122
        Range::UnboundedL(r) => var.bvsle(BV::from_i64(*r as i64, BV_SIZE)),
123
        Range::UnboundedR(l) => var.bvsge(BV::from_i64(*l as i64, BV_SIZE)),
124
        Range::Bounded(l, r) => Bool::and(&[
125
            var.bvsge(BV::from_i64(*l as i64, BV_SIZE)),
126
            var.bvsle(BV::from_i64(*r as i64, BV_SIZE)),
127
        ]),
128
        _ => bug!("int ranges should not be unbounded"),
129
    }
130
}
131

            
132
pub fn name_to_symbol(name: &Name) -> SolverResult<Symbol> {
133
    match name {
134
        Name::User(ustr) => Ok(Symbol::String((*ustr).into())),
135
        Name::Machine(num) => Ok(Symbol::Int(*num as u32)),
136
        Name::Represented(parts) => {
137
            let (name, rule_str, suffix) = parts.as_ref();
138
            Ok(Symbol::String(format!("{name}#{rule_str}_{suffix}")))
139
        }
140
        _ => Err(SolverError::ModelFeatureNotImplemented(format!(
141
            "variable '{name}' name is unsupported"
142
        ))),
143
    }
144
}
145

            
146
/// Converts an atom (literal or reference) into an AST node.
147
pub fn atom_to_ast(
148
    theory_config: &TheoryConfig,
149
    store: &SymbolStore,
150
    atom: &Atom,
151
) -> SolverResult<Dynamic> {
152
    match atom {
153
        Atom::Reference(decl) => store
154
            .get(&decl.name())
155
            .ok_or(SolverError::ModelInvalid(format!(
156
                "variable '{}' does not exist",
157
                decl.name()
158
            )))
159
            .map(|(_, ast, _)| ast)
160
            .cloned(),
161
        Atom::Literal(lit) => literal_to_ast(theory_config, lit),
162
        _ => Err(SolverError::ModelFeatureNotImplemented(format!(
163
            "atom sort not implemented: {atom}"
164
        ))),
165
    }
166
}
167

            
168
/// Converts a CO literal (expression containing no variables) into an AST node.
169
pub fn literal_to_ast(theory_config: &TheoryConfig, lit: &Literal) -> SolverResult<Dynamic> {
170
    match lit {
171
        Literal::Bool(b) => Ok(Bool::from_bool(*b).into()),
172
        Literal::Int(n) => Ok(match theory_config.ints {
173
            IntTheory::Lia => Int::from(*n).into(),
174
            IntTheory::Bv => BV::from_i64(*n as i64, BV_SIZE).into(),
175
        }),
176
        _ => Err(SolverError::ModelFeatureNotImplemented(format!(
177
            "literal type not implemented: {lit}"
178
        ))),
179
    }
180
}