1
use conjure_cp::ast::comprehension::{Comprehension, ComprehensionQualifier};
2
use conjure_cp::ast::{Atom, DeclarationPtr, Metadata};
3
use conjure_cp::ast::{Expression as Expr, Moo, SymbolTable};
4
use conjure_cp::into_matrix_expr;
5
use conjure_cp::rule_engine::Reduction;
6
use conjure_cp::rule_engine::{
7
    ApplicationError::RuleNotApplicable, ApplicationResult, register_rule,
8
};
9
use uniplate::Biplate;
10

            
11
// [ return_expr | i <- A union B, qualifiers...] -> flatten([[ return_expr | i <- A, qualifiers...], [ return_expr | i <- B, !(i in A), qualifiers...]; int(1..2)])
12
#[register_rule("Base", 8700, [Comprehension])]
13
1958068
fn union_set(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
14
1958068
    match expr {
15
16342
        Expr::Comprehension(_, comp) => {
16
            // find if any of the generators are generating from expressions
17
27285
            for qualifier in &comp.qualifiers {
18
27285
                if let ComprehensionQualifier::ExpressionGenerator { ptr } = qualifier {
19
70
                    let gen_decl = ptr.clone();
20

            
21
                    // match on expression being of form A union B
22
70
                    let Some((a, b)) = (match ptr.as_quantified_expr() {
23
70
                        Some(expr_guard) => match &*expr_guard {
24
10
                            Expr::Union(_, a, b) => Some((a.clone(), b.clone())),
25
60
                            _ => None,
26
                        },
27
                        None => None,
28
                    }) else {
29
60
                        continue;
30
                    };
31

            
32
                    // [ return_expr | i <- A, guards...] part
33
10
                    let (comprehension1, _) =
34
10
                        rewrite_union_branch(comp.as_ref(), &gen_decl, a.clone().into());
35

            
36
                    // [ return_expr | i <- B, !(i in A), guards...] part
37
10
                    let (mut comprehension2, b_ptr) =
38
10
                        rewrite_union_branch(comp.as_ref(), &gen_decl, b.into());
39

            
40
                    // add the condition !(i in A)
41
10
                    comprehension2
42
10
                        .qualifiers
43
10
                        .push(ComprehensionQualifier::Condition(Expr::Not(
44
10
                            Metadata::new(),
45
10
                            Moo::new(Expr::In(
46
10
                                Metadata::new(),
47
10
                                Moo::new(Expr::Atomic(Metadata::new(), Atom::new_ref(b_ptr))),
48
10
                                a,
49
10
                            )),
50
10
                        )));
51

            
52
10
                    return Ok(Reduction::pure(Expr::Flatten(
53
10
                        Metadata::new(),
54
10
                        None,
55
10
                        Moo::new(into_matrix_expr!(vec![
56
10
                            Expr::Comprehension(Metadata::new(), comprehension1.into()),
57
10
                            Expr::Comprehension(Metadata::new(), comprehension2.into())
58
10
                        ])),
59
10
                    )));
60
27215
                }
61
            }
62

            
63
16332
            Err(RuleNotApplicable)
64
        }
65
1941726
        _ => Err(RuleNotApplicable),
66
    }
67
1958068
}
68

            
69
/// Clone one union branch into its own detached comprehension scope and rewrite all uses of the
70
/// original quantified declaration to a fresh branch-local expression generator.
71
20
fn rewrite_union_branch(
72
20
    comp: &Comprehension,
73
20
    gen_decl: &DeclarationPtr,
74
20
    replacement_expr: Expr,
75
20
) -> (Comprehension, DeclarationPtr) {
76
20
    let replacement_ptr =
77
20
        DeclarationPtr::new_quantified_expr(gen_decl.name().clone(), replacement_expr);
78
20
    let mut comprehension = comp.clone();
79

            
80
    // detach the scope so rewriting this branch does not mutate the original
81
    // comprehension through shared pointers
82
20
    comprehension.symbols = comprehension.symbols.detach();
83

            
84
    // rewrite all uses of the original quantified declaration to the branch-local
85
    // generator declaration
86
20
    comprehension.return_expression =
87
20
        comprehension
88
20
            .return_expression
89
92
            .transform_bi(&|decl: DeclarationPtr| {
90
92
                if decl == *gen_decl {
91
20
                    replacement_ptr.clone()
92
                } else {
93
72
                    decl
94
                }
95
92
            });
96

            
97
20
    comprehension.qualifiers = comprehension
98
20
        .qualifiers
99
20
        .into_iter()
100
40
        .map(|qualifier| {
101
104
            qualifier.transform_bi(&|decl: DeclarationPtr| {
102
104
                if decl == *gen_decl {
103
20
                    replacement_ptr.clone()
104
                } else {
105
84
                    decl
106
                }
107
104
            })
108
40
        })
109
20
        .collect();
110

            
111
    // keep the detached local scope in sync with the rewritten generator
112
    // declarations used by this branch
113
20
    comprehension
114
20
        .symbols
115
20
        .write()
116
20
        .update_insert(replacement_ptr.clone());
117
40
    for qualifier in &comprehension.qualifiers {
118
40
        match qualifier {
119
32
            ComprehensionQualifier::ExpressionGenerator { ptr }
120
36
            | ComprehensionQualifier::Generator { ptr } => {
121
36
                comprehension.symbols.write().update_insert(ptr.clone());
122
36
            }
123
4
            ComprehensionQualifier::Condition(_) => {}
124
        }
125
    }
126

            
127
20
    (comprehension, replacement_ptr)
128
20
}