1
//! Normalisation rules for comprehensions.
2

            
3
use std::collections::HashSet;
4

            
5
use conjure_cp::{
6
    ast::{
7
        Expression as Expr, Metadata, Moo, Name, SymbolTable, SymbolTablePtr,
8
        ac_operators::ACOperatorKind, comprehension::Comprehension,
9
    },
10
    rule_engine::{
11
        ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
12
    },
13
};
14

            
15
/// Merges nested comprehensions inside the same AC operator into a single comprehension.
16
///
17
/// ```text
18
/// op([ op([ op([ body | qs3 ]) | qs2 ]) | qs1 ]) ~> op([ body | qs1, qs2, qs3 ])
19
/// ```
20
///
21
/// where `op` is one of `and`, `or`, `sum`, or `product`.
22
#[register_rule(("Base", 8900))]
23
1347039
fn merge_nested_ac_comprehensions(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
24
1347039
    let new_expr = merge_nested_ac_comprehensions_impl(expr).ok_or(RuleNotApplicable)?;
25
117
    Ok(Reduction::pure(new_expr))
26
1347039
}
27

            
28
1347039
fn merge_nested_ac_comprehensions_impl(expr: &Expr) -> Option<Expr> {
29
1347039
    let ac_operator_kind = expr.to_ac_operator_kind()?;
30

            
31
98748
    let outer_comprehension = match expr {
32
17586
        Expr::And(_, child)
33
10512
        | Expr::Or(_, child)
34
53262
        | Expr::Sum(_, child)
35
17388
        | Expr::Product(_, child) => {
36
98748
            let Expr::Comprehension(_, comprehension) = child.as_ref() else {
37
93366
                return None;
38
            };
39
5382
            comprehension.as_ref().clone()
40
        }
41
        _ => return None,
42
    };
43

            
44
5382
    let parent_scope = outer_comprehension.symbols().parent().clone()?;
45

            
46
5382
    let mut merged_levels = vec![outer_comprehension.clone()];
47
5382
    let mut merged_names: HashSet<Name> = outer_comprehension
48
5382
        .quantified_vars()
49
5382
        .iter()
50
5382
        .cloned()
51
5382
        .collect();
52

            
53
5382
    let mut current_return_expression = outer_comprehension.return_expression();
54
126
    while let Some(inner_comprehension) =
55
5508
        extract_inner_comprehension(ac_operator_kind, &current_return_expression)
56
    {
57
        // Avoid changing semantics when inner quantifiers shadow outer ones.
58
126
        if inner_comprehension
59
126
            .quantified_vars()
60
126
            .iter()
61
126
            .any(|name| merged_names.contains(name))
62
        {
63
            break;
64
126
        }
65

            
66
126
        merged_names.extend(inner_comprehension.quantified_vars().iter().cloned());
67
126
        current_return_expression = inner_comprehension.clone().return_expression();
68
126
        merged_levels.push(inner_comprehension);
69
    }
70

            
71
5382
    if merged_levels.len() < 2 {
72
5265
        return None;
73
117
    }
74

            
75
117
    let merged_symbols = merge_symbols(parent_scope, &merged_levels);
76
117
    let merged_qualifiers = merged_levels
77
117
        .iter()
78
243
        .flat_map(|level| level.qualifiers.clone())
79
117
        .collect();
80
117
    let mut merged = merged_levels.first()?.clone();
81
117
    merged.return_expression = current_return_expression;
82
117
    merged.qualifiers = merged_qualifiers;
83
117
    merged.symbols = merged_symbols;
84

            
85
117
    let merged_comprehension = Expr::Comprehension(Metadata::new(), Moo::new(merged));
86
117
    let wrapped = match ac_operator_kind {
87
72
        ACOperatorKind::And => Expr::And(Metadata::new(), Moo::new(merged_comprehension)),
88
36
        ACOperatorKind::Or => Expr::Or(Metadata::new(), Moo::new(merged_comprehension)),
89
9
        ACOperatorKind::Sum => Expr::Sum(Metadata::new(), Moo::new(merged_comprehension)),
90
        ACOperatorKind::Product => Expr::Product(Metadata::new(), Moo::new(merged_comprehension)),
91
    };
92

            
93
117
    Some(wrapped)
94
1347039
}
95

            
96
5508
fn extract_inner_comprehension(
97
5508
    ac_operator_kind: ACOperatorKind,
98
5508
    expr: &Expr,
99
5508
) -> Option<Comprehension> {
100
5508
    let wrapped = match (ac_operator_kind, expr) {
101
234
        (ACOperatorKind::And, Expr::And(_, child)) => child.as_ref(),
102
36
        (ACOperatorKind::Or, Expr::Or(_, child)) => child.as_ref(),
103
18
        (ACOperatorKind::Sum, Expr::Sum(_, child)) => child.as_ref(),
104
        (ACOperatorKind::Product, Expr::Product(_, child)) => child.as_ref(),
105
5220
        _ => return None,
106
    };
107

            
108
288
    as_single_comprehension(wrapped)
109
5508
}
110

            
111
288
fn as_single_comprehension(expr: &Expr) -> Option<Comprehension> {
112
288
    if let Expr::Comprehension(_, comprehension) = expr {
113
126
        return Some(comprehension.as_ref().clone());
114
162
    }
115

            
116
162
    let exprs = expr.clone().unwrap_list()?;
117
    let [Expr::Comprehension(_, comprehension)] = exprs.as_slice() else {
118
        return None;
119
    };
120

            
121
    Some(comprehension.as_ref().clone())
122
288
}
123

            
124
117
fn merge_symbols(parent_scope: SymbolTablePtr, levels: &[Comprehension]) -> SymbolTablePtr {
125
117
    let symbols = SymbolTablePtr::with_parent(parent_scope);
126
243
    for level in levels {
127
243
        for (_, decl) in level.symbols().clone().into_iter_local() {
128
243
            symbols.write().update_insert(decl);
129
243
        }
130
    }
131
117
    symbols
132
117
}