1
use crate::errors::{FatalParseError, RecoverableParseError};
2
use crate::expression::parse_expression;
3
use crate::field;
4
use crate::parser::ParseContext;
5
use crate::util::{TypecheckingContext, named_children};
6
use conjure_cp_core::ast::{
7
    Atom, DeclarationKind, Expression, Metadata, Moo, ReturnType, Typeable,
8
};
9
use conjure_cp_core::into_matrix_expr;
10
use tree_sitter::Node;
11
use uniplate::Uniplate;
12

            
13
/// Optimisation direction of a single `pareto(...)` component, taken from its
/// `minimising` / `maximising` keyword.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ParetoDirection {
    /// Smaller values of the component are better.
    Minimising,
    /// Larger values of the component are better.
    Maximising,
}
18

            
19
/// What `rewrite_references` does with a single reference atom it encounters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ReferenceRewriteAction {
    /// Keep the reference exactly as it is.
    LeaveAsIs,
    /// Replace the reference with the expression its value letting resolves to.
    ExpandValueLetting,
    /// Wrap the reference in `fromSolution(...)`.
    WrapInFromSolution,
}
25

            
26
5697
pub fn parse_dominance_relation(
27
5697
    ctx: &mut ParseContext,
28
5697
    node: &Node,
29
5697
) -> Result<Option<Expression>, FatalParseError> {
30
5697
    if ctx.root.kind() == "dominance_relation" {
31
        ctx.record_error(RecoverableParseError::new(
32
            "Nested dominance relations are not allowed".to_string(),
33
            Some(node.range()),
34
        ));
35
        return Ok(None);
36
5697
    }
37

            
38
5697
    let Some(inner_node) = field!(recover, ctx, node, "expression") else {
39
        return Ok(None);
40
    };
41

            
42
    // Create a nested context so downstream parsing knows it is inside a dominance relation.
43
5697
    let mut inner_ctx = ParseContext {
44
5697
        source_code: ctx.source_code,
45
5697
        root: node,
46
5697
        symbols: ctx.symbols.clone(),
47
5697
        errors: ctx.errors,
48
5697
        source_map: &mut *ctx.source_map,
49
5697
        decl_spans: ctx.decl_spans,
50
5697
        typechecking_context: TypecheckingContext::Unknown,
51
5697
        inner_typechecking_context: TypecheckingContext::Unknown,
52
5697
    };
53

            
54
5697
    let Some(inner) = parse_expression(&mut inner_ctx, inner_node)? else {
55
        return Ok(None);
56
    };
57

            
58
5697
    Ok(Some(Expression::DominanceRelation(
59
5697
        Metadata::new(),
60
5697
        Moo::new(inner),
61
5697
    )))
62
5697
}
63

            
64
1095
pub fn parse_pareto_expression(
65
1095
    ctx: &mut ParseContext,
66
1095
    node: &Node,
67
1095
) -> Result<Option<Expression>, FatalParseError> {
68
1095
    if ctx.root.kind() != "dominance_relation" {
69
        ctx.record_error(RecoverableParseError::new(
70
            "pareto(...) only allowed inside dominance relations".to_string(),
71
            Some(node.range()),
72
        ));
73
        return Ok(None);
74
1095
    }
75

            
76
1095
    let mut non_worsening = Vec::new();
77
1095
    let mut strict_improvements = Vec::new();
78
1095
    let components = field!(node, "components");
79

            
80
1095
    if components.kind() != "pareto_items" {
81
        return Err(FatalParseError::internal_error(
82
            format!("Unexpected pareto component list: '{}'", components.kind()),
83
            Some(components.range()),
84
        ));
85
1095
    }
86

            
87
2188
    for item_node in named_children(&components) {
88
2188
        let direction_node = field!(item_node, "direction");
89
2188
        let direction_str =
90
2188
            &ctx.source_code[direction_node.start_byte()..direction_node.end_byte()];
91
2188
        let direction = match direction_str {
92
2188
            "minimising" => ParetoDirection::Minimising,
93
1
            "maximising" => ParetoDirection::Maximising,
94
            _ => {
95
                return Err(FatalParseError::internal_error(
96
                    format!("Unexpected pareto direction: '{direction_str}'"),
97
                    Some(direction_node.range()),
98
                ));
99
            }
100
        };
101

            
102
2188
        let component_node = field!(item_node, "expression");
103
2188
        let Some(component_expr) = parse_pareto_component(ctx, &component_node)? else {
104
            return Ok(None);
105
        };
106
2188
        let Some((non_worse, strict)) =
107
2188
            build_pareto_constraints(ctx, &component_node, component_expr, direction)
108
        else {
109
            return Ok(None);
110
        };
111
2188
        non_worsening.push(non_worse);
112
2188
        strict_improvements.push(strict);
113
    }
114

            
115
1095
    let mut conjuncts = non_worsening;
116
1095
    conjuncts.push(combine_with_and_or(strict_improvements, true));
117

            
118
1095
    Ok(Some(combine_with_and_or(conjuncts, false)))
119
1095
}
120

            
121
2188
fn parse_pareto_component(
122
2188
    ctx: &mut ParseContext,
123
2188
    node: &Node,
124
2188
) -> Result<Option<Expression>, FatalParseError> {
125
2188
    let saved_context = ctx.typechecking_context;
126
2188
    let saved_inner_context = ctx.inner_typechecking_context;
127
2188
    ctx.typechecking_context = TypecheckingContext::Unknown;
128
2188
    ctx.inner_typechecking_context = TypecheckingContext::Unknown;
129

            
130
2188
    let parsed = parse_expression(ctx, *node);
131

            
132
2188
    ctx.typechecking_context = saved_context;
133
2188
    ctx.inner_typechecking_context = saved_inner_context;
134
2188
    parsed
135
2188
}
136

            
137
2188
fn build_pareto_constraints(
138
2188
    ctx: &mut ParseContext,
139
2188
    node: &Node,
140
2188
    component: Expression,
141
2188
    direction: ParetoDirection,
142
2188
) -> Option<(Expression, Expression)> {
143
2188
    if component
144
2188
        .universe()
145
2188
        .iter()
146
3829
        .any(|expr| matches!(expr, Expression::FromSolution(_, _)))
147
    {
148
        ctx.record_error(RecoverableParseError::new(
149
            "pareto(...) components cannot contain fromSolution(...) explicitly".to_string(),
150
            Some(node.range()),
151
        ));
152
        return None;
153
2188
    }
154

            
155
2188
    let current = expand_value_lettings(&component);
156
2188
    let previous = lift_to_previous_solution(&current);
157

            
158
2188
    match current.return_type() {
159
2188
        ReturnType::Int => Some(match direction {
160
2187
            ParetoDirection::Minimising => (
161
2187
                Expression::Leq(
162
2187
                    Metadata::new(),
163
2187
                    Moo::new(current.clone()),
164
2187
                    Moo::new(previous.clone()),
165
2187
                ),
166
2187
                Expression::Lt(Metadata::new(), Moo::new(current), Moo::new(previous)),
167
2187
            ),
168
1
            ParetoDirection::Maximising => (
169
1
                Expression::Geq(
170
1
                    Metadata::new(),
171
1
                    Moo::new(current.clone()),
172
1
                    Moo::new(previous.clone()),
173
1
                ),
174
1
                Expression::Gt(Metadata::new(), Moo::new(current), Moo::new(previous)),
175
1
            ),
176
        }),
177
        ReturnType::Bool => Some(match direction {
178
            ParetoDirection::Minimising => (
179
                Expression::Imply(
180
                    Metadata::new(),
181
                    Moo::new(current.clone()),
182
                    Moo::new(previous.clone()),
183
                ),
184
                combine_with_and_or(
185
                    vec![
186
                        Expression::Not(Metadata::new(), Moo::new(current)),
187
                        previous,
188
                    ],
189
                    false,
190
                ),
191
            ),
192
            ParetoDirection::Maximising => (
193
                Expression::Imply(
194
                    Metadata::new(),
195
                    Moo::new(previous.clone()),
196
                    Moo::new(current.clone()),
197
                ),
198
                combine_with_and_or(
199
                    vec![
200
                        current,
201
                        Expression::Not(Metadata::new(), Moo::new(previous)),
202
                    ],
203
                    false,
204
                ),
205
            ),
206
        }),
207
        found => {
208
            ctx.record_error(RecoverableParseError::new(
209
                format!(
210
                    "pareto(...) only supports int or bool components, found '{}'",
211
                    found
212
                ),
213
                Some(node.range()),
214
            ));
215
            None
216
        }
217
    }
218
2188
}
219

            
220
2188
fn expand_value_lettings(expr: &Expression) -> Expression {
221
2188
    rewrite_references(expr, false)
222
2188
}
223

            
224
2188
fn lift_to_previous_solution(expr: &Expression) -> Expression {
225
2188
    rewrite_references(expr, true)
226
2188
}
227

            
228
4376
fn rewrite_references(expr: &Expression, to_previous_solution: bool) -> Expression {
229
4376
    let mut lifted = expr.clone();
230

            
231
    loop {
232
6564
        let next = lifted.rewrite(&|subexpr| match subexpr {
233
4376
            Expression::Atomic(_, Atom::Reference(ref reference)) => {
234
4376
                let action = {
235
4376
                    let kind = reference.ptr.kind();
236
4376
                    match &*kind {
237
2188
                        DeclarationKind::Find(_) if to_previous_solution => {
238
2188
                            ReferenceRewriteAction::WrapInFromSolution
239
                        }
240
2188
                        DeclarationKind::Find(_) => ReferenceRewriteAction::LeaveAsIs,
241
                        DeclarationKind::ValueLetting(_, _)
242
                        | DeclarationKind::TemporaryValueLetting(_) => {
243
                            ReferenceRewriteAction::ExpandValueLetting
244
                        }
245
                        DeclarationKind::Given(_)
246
                        | DeclarationKind::Quantified(_)
247
                        | DeclarationKind::QuantifiedExpr(_)
248
                        | DeclarationKind::DomainLetting(_)
249
                        | DeclarationKind::Field(_)
250
                        | _ => ReferenceRewriteAction::LeaveAsIs,
251
                    }
252
                };
253

            
254
4376
                match action {
255
2188
                    ReferenceRewriteAction::LeaveAsIs => Some(subexpr),
256
                    ReferenceRewriteAction::ExpandValueLetting => reference.resolve_expression(),
257
2188
                    ReferenceRewriteAction::WrapInFromSolution => Some(Expression::FromSolution(
258
2188
                        Metadata::new(),
259
2188
                        Moo::new(Atom::Reference(reference.clone())),
260
2188
                    )),
261
                }
262
            }
263
7111
            _ => Some(subexpr),
264
11487
        });
265

            
266
6564
        if next == lifted {
267
4376
            return lifted;
268
2188
        }
269

            
270
2188
        lifted = next;
271
    }
272
4376
}
273

            
274
2190
fn combine_with_and_or(exprs: Vec<Expression>, is_or: bool) -> Expression {
275
2190
    match exprs.len() {
276
        0 => {
277
            if is_or {
278
                Expression::Or(Metadata::new(), Moo::new(into_matrix_expr![exprs]))
279
            } else {
280
                Expression::And(Metadata::new(), Moo::new(into_matrix_expr![exprs]))
281
            }
282
        }
283
2
        1 => match exprs.into_iter().next() {
284
2
            Some(expr) => expr,
285
            None => unreachable!("vector length already checked"),
286
        },
287
        _ => {
288
2188
            if is_or {
289
1093
                Expression::Or(Metadata::new(), Moo::new(into_matrix_expr![exprs]))
290
            } else {
291
1095
                Expression::And(Metadata::new(), Moo::new(into_matrix_expr![exprs]))
292
            }
293
        }
294
    }
295
2190
}