1
//! Normalisation rules for comprehensions.
2

            
3
use std::collections::HashSet;
4

            
5
use conjure_cp::{
6
    ast::{
7
        Expression as Expr, Metadata, Moo, Name, SymbolTable, SymbolTablePtr,
8
        ac_operators::ACOperatorKind, comprehension::Comprehension,
9
    },
10
    rule_engine::{
11
        ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
12
    },
13
};
14

            
15
/// Merges nested comprehensions inside the same AC operator into a single comprehension.
16
///
17
/// ```text
18
/// op([ op([ op([ body | qs3 ]) | qs2 ]) | qs1 ]) ~> op([ body | qs1, qs2, qs3 ])
19
/// ```
20
///
21
/// where `op` is one of `and`, `or`, `sum`, or `product`.
22
#[register_rule("Base", 8900)]
23
2436492
fn merge_nested_ac_comprehensions(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
24
2436492
    let new_expr = merge_nested_ac_comprehensions_impl(expr).ok_or(RuleNotApplicable)?;
25
202
    Ok(Reduction::pure(new_expr))
26
2436492
}
27

            
28
2436492
fn merge_nested_ac_comprehensions_impl(expr: &Expr) -> Option<Expr> {
29
2436492
    let ac_operator_kind = expr.to_ac_operator_kind()?;
30

            
31
202687
    let outer_comprehension = match expr {
32
45650
        Expr::And(_, child)
33
41227
        | Expr::Or(_, child)
34
87870
        | Expr::Sum(_, child)
35
27940
        | Expr::Product(_, child) => {
36
202687
            let Expr::Comprehension(_, comprehension) = child.as_ref() else {
37
185135
                return None;
38
            };
39
17552
            comprehension.as_ref().clone()
40
        }
41
        _ => return None,
42
    };
43

            
44
17552
    let parent_scope = outer_comprehension.symbols().parent().clone()?;
45

            
46
17552
    let mut merged_levels = vec![outer_comprehension.clone()];
47
17552
    let mut merged_names: HashSet<Name> = outer_comprehension
48
17552
        .quantified_vars()
49
17552
        .iter()
50
17552
        .cloned()
51
17552
        .collect();
52

            
53
17552
    let mut current_return_expression = outer_comprehension.return_expression();
54
202
    while let Some(inner_comprehension) =
55
17754
        extract_inner_comprehension(ac_operator_kind, &current_return_expression)
56
    {
57
        // Avoid changing semantics when inner quantifiers shadow outer ones.
58
202
        if inner_comprehension
59
202
            .quantified_vars()
60
202
            .iter()
61
202
            .any(|name| merged_names.contains(name))
62
        {
63
            break;
64
202
        }
65

            
66
202
        merged_names.extend(inner_comprehension.quantified_vars().iter().cloned());
67
202
        current_return_expression = inner_comprehension.clone().return_expression();
68
202
        merged_levels.push(inner_comprehension);
69
    }
70

            
71
17552
    if merged_levels.len() < 2 {
72
17350
        return None;
73
202
    }
74

            
75
202
    let merged_symbols = merge_symbols(parent_scope, &merged_levels);
76
202
    let merged_qualifiers = merged_levels
77
202
        .iter()
78
404
        .flat_map(|level| level.qualifiers.clone())
79
202
        .collect();
80
202
    let mut merged = merged_levels.first()?.clone();
81
202
    merged.return_expression = current_return_expression;
82
202
    merged.qualifiers = merged_qualifiers;
83
202
    merged.symbols = merged_symbols;
84

            
85
202
    let merged_comprehension = Expr::Comprehension(Metadata::new(), Moo::new(merged));
86
202
    let wrapped = match ac_operator_kind {
87
86
        ACOperatorKind::And => Expr::And(Metadata::new(), Moo::new(merged_comprehension)),
88
96
        ACOperatorKind::Or => Expr::Or(Metadata::new(), Moo::new(merged_comprehension)),
89
20
        ACOperatorKind::Sum => Expr::Sum(Metadata::new(), Moo::new(merged_comprehension)),
90
        ACOperatorKind::Product => Expr::Product(Metadata::new(), Moo::new(merged_comprehension)),
91
    };
92

            
93
202
    Some(wrapped)
94
2436492
}
95

            
96
17754
fn extract_inner_comprehension(
97
17754
    ac_operator_kind: ACOperatorKind,
98
17754
    expr: &Expr,
99
17754
) -> Option<Comprehension> {
100
17754
    let wrapped = match (ac_operator_kind, expr) {
101
1451
        (ACOperatorKind::And, Expr::And(_, child)) => child.as_ref(),
102
96
        (ACOperatorKind::Or, Expr::Or(_, child)) => child.as_ref(),
103
24
        (ACOperatorKind::Sum, Expr::Sum(_, child)) => child.as_ref(),
104
        (ACOperatorKind::Product, Expr::Product(_, child)) => child.as_ref(),
105
16183
        _ => return None,
106
    };
107

            
108
1571
    as_single_comprehension(wrapped)
109
17754
}
110

            
111
1571
fn as_single_comprehension(expr: &Expr) -> Option<Comprehension> {
112
1571
    if let Expr::Comprehension(_, comprehension) = expr {
113
202
        return Some(comprehension.as_ref().clone());
114
1369
    }
115

            
116
1369
    let exprs = expr.clone().unwrap_list()?;
117
1074
    let [Expr::Comprehension(_, comprehension)] = exprs.as_slice() else {
118
1074
        return None;
119
    };
120

            
121
    Some(comprehension.as_ref().clone())
122
1571
}
123

            
124
202
fn merge_symbols(parent_scope: SymbolTablePtr, levels: &[Comprehension]) -> SymbolTablePtr {
125
202
    let symbols = SymbolTablePtr::with_parent(parent_scope);
126
404
    for level in levels {
127
404
        for (_, decl) in level.symbols().clone().into_iter_local() {
128
404
            symbols.write().update_insert(decl);
129
404
        }
130
    }
131
202
    symbols
132
202
}