1
//! Normalisation rules for comprehensions.
2

            
3
use std::collections::HashSet;
4

            
5
use conjure_cp::{
6
    ast::{
7
        Expression as Expr, Metadata, Moo, Name, SymbolTable, SymbolTablePtr,
8
        ac_operators::ACOperatorKind, comprehension::Comprehension,
9
    },
10
    rule_engine::{
11
        ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
12
    },
13
};
14

            
15
/// Merges nested comprehensions inside the same AC operator into a single comprehension.
16
///
17
/// ```text
18
/// op([ op([ op([ body | qs3 ]) | qs2 ]) | qs1 ]) ~> op([ body | qs1, qs2, qs3 ])
19
/// ```
20
///
21
/// where `op` is one of `and`, `or`, `sum`, or `product`.
22
#[register_rule(("Base", 8900))]
23
1350318
fn merge_nested_ac_comprehensions(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
24
1350318
    let new_expr = merge_nested_ac_comprehensions_impl(expr).ok_or(RuleNotApplicable)?;
25
54
    Ok(Reduction::pure(new_expr))
26
1350318
}
27

            
28
1350318
fn merge_nested_ac_comprehensions_impl(expr: &Expr) -> Option<Expr> {
29
1350318
    let ac_operator_kind = expr.to_ac_operator_kind()?;
30

            
31
128709
    let outer_comprehension = match expr {
32
10764
        Expr::And(_, child)
33
5832
        | Expr::Or(_, child)
34
70203
        | Expr::Sum(_, child)
35
41910
        | Expr::Product(_, child) => {
36
128709
            let Expr::Comprehension(_, comprehension) = child.as_ref() else {
37
124659
                return None;
38
            };
39
4050
            comprehension.as_ref().clone()
40
        }
41
        _ => return None,
42
    };
43

            
44
4050
    let parent_scope = outer_comprehension.symbols().parent().clone()?;
45

            
46
4050
    let mut merged_levels = vec![outer_comprehension.clone()];
47
4050
    let mut merged_names: HashSet<Name> = outer_comprehension
48
4050
        .quantified_vars()
49
4050
        .iter()
50
4050
        .cloned()
51
4050
        .collect();
52

            
53
4050
    let mut current_return_expression = outer_comprehension.return_expression();
54
60
    while let Some(inner_comprehension) =
55
4110
        extract_inner_comprehension(ac_operator_kind, &current_return_expression)
56
    {
57
        // Avoid changing semantics when inner quantifiers shadow outer ones.
58
60
        if inner_comprehension
59
60
            .quantified_vars()
60
60
            .iter()
61
60
            .any(|name| merged_names.contains(name))
62
        {
63
            break;
64
60
        }
65

            
66
60
        merged_names.extend(inner_comprehension.quantified_vars().iter().cloned());
67
60
        current_return_expression = inner_comprehension.clone().return_expression();
68
60
        merged_levels.push(inner_comprehension);
69
    }
70

            
71
4050
    if merged_levels.len() < 2 {
72
3996
        return None;
73
54
    }
74

            
75
54
    let merged_symbols = merge_symbols(parent_scope, &merged_levels);
76
54
    let merged_qualifiers = merged_levels
77
54
        .iter()
78
114
        .flat_map(|level| level.qualifiers.clone())
79
54
        .collect();
80
54
    let mut merged = merged_levels.first()?.clone();
81
54
    merged.return_expression = current_return_expression;
82
54
    merged.qualifiers = merged_qualifiers;
83
54
    merged.symbols = merged_symbols;
84

            
85
54
    let merged_comprehension = Expr::Comprehension(Metadata::new(), Moo::new(merged));
86
54
    let wrapped = match ac_operator_kind {
87
48
        ACOperatorKind::And => Expr::And(Metadata::new(), Moo::new(merged_comprehension)),
88
        ACOperatorKind::Or => Expr::Or(Metadata::new(), Moo::new(merged_comprehension)),
89
6
        ACOperatorKind::Sum => Expr::Sum(Metadata::new(), Moo::new(merged_comprehension)),
90
        ACOperatorKind::Product => Expr::Product(Metadata::new(), Moo::new(merged_comprehension)),
91
    };
92

            
93
54
    Some(wrapped)
94
1350318
}
95

            
96
4110
fn extract_inner_comprehension(
97
4110
    ac_operator_kind: ACOperatorKind,
98
4110
    expr: &Expr,
99
4110
) -> Option<Comprehension> {
100
4110
    let wrapped = match (ac_operator_kind, expr) {
101
156
        (ACOperatorKind::And, Expr::And(_, child)) => child.as_ref(),
102
        (ACOperatorKind::Or, Expr::Or(_, child)) => child.as_ref(),
103
12
        (ACOperatorKind::Sum, Expr::Sum(_, child)) => child.as_ref(),
104
        (ACOperatorKind::Product, Expr::Product(_, child)) => child.as_ref(),
105
3942
        _ => return None,
106
    };
107

            
108
168
    as_single_comprehension(wrapped)
109
4110
}
110

            
111
168
fn as_single_comprehension(expr: &Expr) -> Option<Comprehension> {
112
168
    if let Expr::Comprehension(_, comprehension) = expr {
113
60
        return Some(comprehension.as_ref().clone());
114
108
    }
115

            
116
108
    let exprs = expr.clone().unwrap_list()?;
117
    let [Expr::Comprehension(_, comprehension)] = exprs.as_slice() else {
118
        return None;
119
    };
120

            
121
    Some(comprehension.as_ref().clone())
122
168
}
123

            
124
54
fn merge_symbols(parent_scope: SymbolTablePtr, levels: &[Comprehension]) -> SymbolTablePtr {
125
54
    let symbols = SymbolTablePtr::with_parent(parent_scope);
126
114
    for level in levels {
127
114
        for (_, decl) in level.symbols().clone().into_iter_local() {
128
114
            symbols.write().update_insert(decl);
129
114
        }
130
    }
131
54
    symbols
132
54
}