conjure_core/ast/
comprehension.rs

1use std::{
2    cell::RefCell,
3    collections::HashSet,
4    fmt::Display,
5    rc::Rc,
6    sync::{Arc, Mutex, RwLock},
7};
8
9use itertools::Itertools as _;
10use serde::{Deserialize, Serialize};
11use uniplate::{derive::Uniplate, Biplate as _};
12
13use crate::{
14    ast::Atom,
15    context::Context,
16    into_matrix_expr, matrix_expr,
17    metadata::Metadata,
18    solver::{Solver, SolverError},
19};
20
21use super::{Declaration, Domain, Expression, Model, Name, Range, SubModel, SymbolTable};
22
/// The kind of a comprehension.
///
/// [`ComprehensionBuilder::with_return_value`] uses this to decide how guards that cannot be
/// placed inside the comprehension's submodel are folded into the return expression instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComprehensionKind {
    /// Folding such guards into the return expression is not yet implemented for this kind.
    Sum,
    /// Such guards are folded in as `guard -> expression`.
    And,
    /// Such guards are folded in as a conjunction of the guard and the expression.
    Or,
}
28/// A comprehension.
29#[derive(Clone, PartialEq, Eq, Uniplate, Serialize, Deserialize, Debug)]
30#[uniplate(walk_into=[SubModel])]
31#[biplate(to=SubModel,walk_into=[Expression])]
32#[biplate(to=Expression,walk_into=[SubModel])]
33pub struct Comprehension {
34    expression: Expression,
35    submodel: SubModel,
36    induction_vars: Vec<Name>,
37}
38
39impl Comprehension {
40    // Solves this comprehension using Minion, returning the resulting expressions.
41    pub fn solve_with_minion(self) -> Result<Vec<Expression>, SolverError> {
42        let minion = Solver::new(crate::solver::adaptors::Minion::new());
43        // FIXME: weave proper context through
44        let mut model = Model::new(Arc::new(RwLock::new(Context::default())));
45
46        // only branch on the induction variables.
47        model.search_order = Some(self.induction_vars.clone());
48
49        *model.as_submodel_mut() = self.submodel.clone();
50
51        let minion = minion.load_model(model.clone())?;
52
53        let values = Arc::new(Mutex::new(Vec::new()));
54        let values_ptr = Arc::clone(&values);
55
56        tracing::debug!(model=%model.clone(),comprehension=%self.clone(),"Minion solving comprehension");
57        let expression = self.expression;
58        minion.solve(Box::new(move |sols| {
59            // TODO: deal with represented names if induction variables are abslits.
60            let values = &mut *values_ptr.lock().unwrap();
61            values.push(sols);
62            true
63        }))?;
64
65        let values = values.lock().unwrap().clone();
66        Ok(values
67            .clone()
68            .into_iter()
69            .map(|sols| {
70                // substitute in values
71                expression
72                    .clone()
73                    .transform_bi(Arc::new(move |atom: Atom| match atom {
74                        Atom::Reference(name) if sols.contains_key(&name) => {
75                            Atom::Literal(sols.get(&name).unwrap().clone())
76                        }
77                        x => x,
78                    }))
79            })
80            .collect_vec())
81    }
82
83    pub fn domain_of(&self) -> Option<Domain> {
84        self.expression
85            .domain_of(&self.submodel.symbols())
86            .map(|domain| {
87                Domain::DomainMatrix(
88                    Box::new(domain),
89                    vec![Domain::IntDomain(vec![Range::UnboundedR(1)])],
90                )
91            })
92    }
93}
94
95impl Display for Comprehension {
96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97        let generators: String = self
98            .submodel
99            .symbols()
100            .clone()
101            .into_iter_local()
102            .map(|(name, decl)| (name, decl.domain().unwrap().clone()))
103            .map(|(name, domain)| format!("{name}: {domain}"))
104            .join(",");
105
106        let guards = self
107            .submodel
108            .constraints()
109            .iter()
110            .map(|x| format!("{x}"))
111            .join(",");
112
113        let generators_and_guards = itertools::join([generators, guards], ",");
114
115        let expression = &self.expression;
116        write!(f, "[{expression} | {generators_and_guards}]")
117    }
118}
119
120/// A builder for a comprehension.
121#[derive(Clone, Debug, PartialEq, Eq, Default)]
122pub struct ComprehensionBuilder {
123    guards: Vec<Expression>,
124    generators: Vec<(Name, Domain)>,
125    induction_variables: HashSet<Name>,
126}
127
128impl ComprehensionBuilder {
129    pub fn new() -> Self {
130        Default::default()
131    }
132    pub fn guard(mut self, guard: Expression) -> Self {
133        self.guards.push(guard);
134        self
135    }
136
137    pub fn generator(mut self, name: Name, domain: Domain) -> Self {
138        assert!(!self.induction_variables.contains(&name));
139        self.induction_variables.insert(name.clone());
140        self.generators.push((name, domain));
141        self
142    }
143
144    /// Creates a comprehension with the given return expression.
145    ///
146    /// If a comprehension kind is not given, comprehension guards containing decision variables
147    /// are invalid, and will cause a panic.
148    pub fn with_return_value(
149        self,
150        mut expression: Expression,
151        parent: Rc<RefCell<SymbolTable>>,
152        comprehension_kind: Option<ComprehensionKind>,
153    ) -> Comprehension {
154        let mut submodel = SubModel::new(parent);
155
156        // TODO:also allow guards that reference lettings and givens.
157
158        let induction_variables = self.induction_variables;
159
160        // only guards referencing induction variables can go inside the comprehension
161        let (induction_guards, other_guards): (Vec<_>, Vec<_>) = self
162            .guards
163            .into_iter()
164            .partition(|x| is_induction_guard(&induction_variables, x));
165
166        // handle guards that reference non-induction variables
167        if !other_guards.is_empty() {
168            let comprehension_kind = comprehension_kind.expect(
169                "if any guards reference decision variables, a comprehension kind should be given",
170            );
171
172            let guard_expr = match other_guards.as_slice() {
173                [x] => x.clone(),
174                xs => Expression::And(Metadata::new(), Box::new(into_matrix_expr!(xs.to_vec()))),
175            };
176
177            expression = match comprehension_kind {
178                ComprehensionKind::And => {
179                    Expression::Imply(Metadata::new(), Box::new(guard_expr), Box::new(expression))
180                }
181                ComprehensionKind::Or => Expression::And(
182                    Metadata::new(),
183                    Box::new(Expression::And(
184                        Metadata::new(),
185                        Box::new(matrix_expr![guard_expr, expression]),
186                    )),
187                ),
188
189                ComprehensionKind::Sum => {
190                    panic!("guards that reference decision variables not yet implemented for sum");
191                }
192            }
193        }
194
195        submodel.add_constraints(induction_guards);
196        for (name, domain) in self.generators {
197            submodel
198                .symbols_mut()
199                .insert(Rc::new(Declaration::new_var(name, domain)));
200        }
201
202        Comprehension {
203            expression,
204            submodel,
205            induction_vars: induction_variables.into_iter().collect_vec(),
206        }
207    }
208}
209
210/// True iff the guard only references induction variables.
211fn is_induction_guard(induction_variables: &HashSet<Name>, guard: &Expression) -> bool {
212    guard
213        .universe_bi()
214        .iter()
215        .all(|x| induction_variables.contains(x))
216}