1
//! Normalisation rules for comprehensions.
2

            
3
use std::collections::HashSet;
4

            
5
use conjure_cp::{
6
    ast::{
7
        Expression as Expr, Metadata, Moo, Name, SymbolTable, SymbolTablePtr,
8
        ac_operators::ACOperatorKind, comprehension::Comprehension,
9
    },
10
    rule_engine::{
11
        ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
12
    },
13
};
14

            
15
/// Merges nested comprehensions inside the same AC operator into a single comprehension.
16
///
17
/// ```text
18
/// op([ op([ op([ body | qs3 ]) | qs2 ]) | qs1 ]) ~> op([ body | qs1, qs2, qs3 ])
19
/// ```
20
///
21
/// where `op` is one of `and`, `or`, `sum`, or `product`.
22
#[register_rule(("Base", 8900))]
23
448997
fn merge_nested_ac_comprehensions(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
24
448997
    let new_expr = merge_nested_ac_comprehensions_impl(expr).ok_or(RuleNotApplicable)?;
25
39
    Ok(Reduction::pure(new_expr))
26
448997
}
27

            
28
448997
fn merge_nested_ac_comprehensions_impl(expr: &Expr) -> Option<Expr> {
29
448997
    let ac_operator_kind = expr.to_ac_operator_kind()?;
30

            
31
32910
    let outer_comprehension = match expr {
32
5859
        Expr::And(_, child)
33
3501
        | Expr::Or(_, child)
34
17754
        | Expr::Sum(_, child)
35
5796
        | Expr::Product(_, child) => {
36
32910
            let Expr::Comprehension(_, comprehension) = child.as_ref() else {
37
31116
                return None;
38
            };
39
1794
            comprehension.as_ref().clone()
40
        }
41
        _ => return None,
42
    };
43

            
44
1794
    let parent_scope = outer_comprehension.symbols().parent().clone()?;
45

            
46
1794
    let mut merged_levels = vec![outer_comprehension.clone()];
47
1794
    let mut merged_names: HashSet<Name> = outer_comprehension
48
1794
        .quantified_vars()
49
1794
        .iter()
50
1794
        .cloned()
51
1794
        .collect();
52

            
53
1794
    let mut current_return_expression = outer_comprehension.return_expression();
54
42
    while let Some(inner_comprehension) =
55
1836
        extract_inner_comprehension(ac_operator_kind, &current_return_expression)
56
    {
57
        // Avoid changing semantics when inner quantifiers shadow outer ones.
58
42
        if inner_comprehension
59
42
            .quantified_vars()
60
42
            .iter()
61
42
            .any(|name| merged_names.contains(name))
62
        {
63
            break;
64
42
        }
65

            
66
42
        merged_names.extend(inner_comprehension.quantified_vars().iter().cloned());
67
42
        current_return_expression = inner_comprehension.clone().return_expression();
68
42
        merged_levels.push(inner_comprehension);
69
    }
70

            
71
1794
    if merged_levels.len() < 2 {
72
1755
        return None;
73
39
    }
74

            
75
39
    let merged_symbols = merge_symbols(parent_scope, &merged_levels);
76
39
    let merged_qualifiers = merged_levels
77
39
        .iter()
78
81
        .flat_map(|level| level.qualifiers.clone())
79
39
        .collect();
80
39
    let mut merged = merged_levels.first()?.clone();
81
39
    merged.return_expression = current_return_expression;
82
39
    merged.qualifiers = merged_qualifiers;
83
39
    merged.symbols = merged_symbols;
84

            
85
39
    let merged_comprehension = Expr::Comprehension(Metadata::new(), Moo::new(merged));
86
39
    let wrapped = match ac_operator_kind {
87
24
        ACOperatorKind::And => Expr::And(Metadata::new(), Moo::new(merged_comprehension)),
88
12
        ACOperatorKind::Or => Expr::Or(Metadata::new(), Moo::new(merged_comprehension)),
89
3
        ACOperatorKind::Sum => Expr::Sum(Metadata::new(), Moo::new(merged_comprehension)),
90
        ACOperatorKind::Product => Expr::Product(Metadata::new(), Moo::new(merged_comprehension)),
91
    };
92

            
93
39
    Some(wrapped)
94
448997
}
95

            
96
1836
fn extract_inner_comprehension(
97
1836
    ac_operator_kind: ACOperatorKind,
98
1836
    expr: &Expr,
99
1836
) -> Option<Comprehension> {
100
1836
    let wrapped = match (ac_operator_kind, expr) {
101
78
        (ACOperatorKind::And, Expr::And(_, child)) => child.as_ref(),
102
12
        (ACOperatorKind::Or, Expr::Or(_, child)) => child.as_ref(),
103
6
        (ACOperatorKind::Sum, Expr::Sum(_, child)) => child.as_ref(),
104
        (ACOperatorKind::Product, Expr::Product(_, child)) => child.as_ref(),
105
1740
        _ => return None,
106
    };
107

            
108
96
    as_single_comprehension(wrapped)
109
1836
}
110

            
111
96
fn as_single_comprehension(expr: &Expr) -> Option<Comprehension> {
112
96
    if let Expr::Comprehension(_, comprehension) = expr {
113
42
        return Some(comprehension.as_ref().clone());
114
54
    }
115

            
116
54
    let exprs = expr.clone().unwrap_list()?;
117
    let [Expr::Comprehension(_, comprehension)] = exprs.as_slice() else {
118
        return None;
119
    };
120

            
121
    Some(comprehension.as_ref().clone())
122
96
}
123

            
124
39
fn merge_symbols(parent_scope: SymbolTablePtr, levels: &[Comprehension]) -> SymbolTablePtr {
125
39
    let symbols = SymbolTablePtr::with_parent(parent_scope);
126
81
    for level in levels {
127
81
        for (_, decl) in level.symbols().clone().into_iter_local() {
128
81
            symbols.write().update_insert(decl);
129
81
        }
130
    }
131
39
    symbols
132
39
}