1
use conjure_cp::{
2
    ast::{Atom, Expression as Expr, GroundDomain, Metadata, Name, SymbolTable, serde::HasId},
3
    bug,
4
    representation::Representation,
5
    rule_engine::{
6
        ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
7
        register_rule_set,
8
    },
9
    settings::SolverFamily,
10
};
11
use itertools::Itertools;
12
use std::sync::Arc;
13
use std::sync::atomic::AtomicBool;
14
use std::sync::atomic::Ordering;
15
use uniplate::Biplate;
16

            
17
use conjure_cp::solver::adaptors::smt::{MatrixTheory, TheoryConfig};
18

            
19
// The `Representations` rule set is enabled for Minion and SAT, and for SMT only when matrices
// are encoded with the atomic matrix theory.
register_rule_set!("Representations", ("Base"), |f: &SolverFamily| {
    matches!(
        f,
        SolverFamily::Minion
            | SolverFamily::Sat(_)
            | SolverFamily::Smt(TheoryConfig {
                matrices: MatrixTheory::Atomic,
                ..
            })
    )
});
31

            
32
// special case rule to select representations for matrices in one go.
33
//
34
// we know that they only have one possible representation, so this rule adds a representation for all matrices in the model.
35
#[register_rule(("Representations", 8001))]
36
254978
fn select_representation_matrix(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
37
254978
    let Expr::Root(_, _) = expr else {
38
245113
        return Err(RuleNotApplicable);
39
    };
40

            
41
    // cannot create representations on non-local variables, so use lookup_local.
42
279480
    let matrix_vars = symbols.clone().into_iter_local().filter_map(|(n, decl)| {
43
279480
        let id = decl.id();
44
279480
        let var = decl.as_find()?.clone();
45
277116
        let resolved_domain = var.domain.resolve()?;
46

            
47
277116
        let GroundDomain::Matrix(valdom, indexdoms) = resolved_domain.as_ref() else {
48
264421
            return None;
49
        };
50

            
51
        // TODO: loosen these requirements once we are able to
52
12695
        if !matches!(valdom.as_ref(), GroundDomain::Bool | GroundDomain::Int(_)) {
53
            return None;
54
12695
        }
55

            
56
12695
        if indexdoms
57
12695
            .iter()
58
19598
            .any(|x| !matches!(x.as_ref(), GroundDomain::Bool | GroundDomain::Int(_)))
59
        {
60
            return None;
61
12695
        }
62

            
63
12695
        Some((n, id))
64
279480
    });
65

            
66
9865
    let mut symbols = symbols.clone();
67
9865
    let mut expr = expr.clone();
68
9865
    let has_changed = Arc::new(AtomicBool::new(false));
69
12709
    for (name, _id) in matrix_vars {
70
        // Even if we have no references to this matrix, still give it the matrix_to_atom
71
        // representation, as we still currently need to give it to minion even if its unused.
72
        //
73
        // If this var has no represnetation yet, the below call to get_or_add will modify the
74
        // symbol table by adding the representation and represented variable declarations to the
75
        // symbol table.
76
12695
        if symbols.representations_for(&name).unwrap().is_empty() {
77
395
            has_changed.store(true, Ordering::Relaxed);
78
12300
        }
79

            
80
        // (creates the represented variables as a side effect)
81
12695
        let _ = symbols
82
12695
            .get_or_add_representation(&name, &["matrix_to_atom"])
83
12695
            .unwrap();
84

            
85
12695
        let old_name = name.clone();
86
12695
        let new_name =
87
12695
            Name::WithRepresentation(Box::new(old_name.clone()), vec!["matrix_to_atom".into()]);
88
        // give all references to this matrix this representation
89
        // also do this inside subscopes, as long as they dont define their own variable that shadows this
90
        // one.
91

            
92
12695
        let old_name_2 = old_name.clone();
93
12695
        let new_name_2 = new_name.clone();
94
12695
        let has_changed_ptr = Arc::clone(&has_changed);
95
263841
        expr = expr.transform_bi(&move |n: Name| {
96
263841
            if n == old_name_2 {
97
996
                has_changed_ptr.store(true, Ordering::SeqCst);
98
996
                new_name_2.clone()
99
            } else {
100
262845
                n
101
            }
102
263841
        });
103
    }
104

            
105
9865
    if has_changed.load(Ordering::Relaxed) {
106
423
        Ok(Reduction::with_symbols(expr, symbols))
107
    } else {
108
9442
        Err(RuleNotApplicable)
109
    }
110
254978
}
111

            
112
#[register_rule(("Representations", 8000))]
113
254192
fn select_representation(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
114
    // thing we are representing must be a reference
115
121300
    let Expr::Atomic(_, Atom::Reference(decl)) = expr else {
116
169379
        return Err(RuleNotApplicable);
117
    };
118

            
119
84813
    let name: Name = decl.name().clone();
120

            
121
    // thing we are representing must be a variable
122
    {
123
84813
        let guard = decl.ptr().as_find().ok_or(RuleNotApplicable)?;
124
84813
        drop(guard);
125
    }
126

            
127
84813
    if !needs_representation(&name, symbols) {
128
84501
        return Err(RuleNotApplicable);
129
312
    }
130

            
131
312
    let mut symbols = symbols.clone();
132
39
    let representation =
133
312
        get_or_create_representation(&name, &mut symbols).ok_or(RuleNotApplicable)?;
134

            
135
39
    let representation_names = representation
136
39
        .into_iter()
137
39
        .map(|x| x.repr_name().into())
138
39
        .collect_vec();
139

            
140
39
    let new_name = Name::WithRepresentation(Box::new(name.clone()), representation_names);
141

            
142
    // HACK: this is suspicious, but hopefully will work until we clean up representations
143
    // properly...
144
    //
145
    // In general, we should not use names atall anymore, including for representations /
146
    // represented variables.
147
    //
148
    // * instead of storing the link from a variable that has a representation to the variable it
149
    // is representing in the name as WithRepresentation, we should use declaration pointers instead.
150
    //
151
    //
152
    // see: issue #932
153
39
    let mut decl_ptr = decl.clone().into_ptr().detach();
154
39
    decl_ptr.replace_name(new_name);
155

            
156
39
    Ok(Reduction::with_symbols(
157
39
        Expr::Atomic(
158
39
            Metadata::new(),
159
39
            Atom::Reference(conjure_cp::ast::Reference::new(decl_ptr)),
160
39
        ),
161
39
        symbols,
162
39
    ))
163
254192
}
164

            
165
/// Returns whether `name` needs representing.
166
///
167
/// # Panics
168
///
169
///   + If `name` is not in `symbols`.
170
84813
fn needs_representation(name: &Name, symbols: &SymbolTable) -> bool {
171
    // if name already has a representation, false
172
84813
    if let Name::Represented(_) = name {
173
44631
        return false;
174
40182
    }
175
    // might be more logic here in the future?
176
40182
    domain_needs_representation(&symbols.resolve_domain(name).unwrap())
177
84813
}
178

            
179
/// Returns whether `domain` needs representing.
180
40806
fn domain_needs_representation(domain: &GroundDomain) -> bool {
181
    // very simple implementation for nows
182
40806
    match domain {
183
28941
        GroundDomain::Bool | GroundDomain::Int(_) => false,
184
11553
        GroundDomain::Matrix(_, _) => false, // we special case these elsewhere
185
        GroundDomain::Set(_, _)
186
        | GroundDomain::MSet(_, _)
187
        | GroundDomain::Tuple(_)
188
        | GroundDomain::Record(_)
189
312
        | GroundDomain::Function(_, _, _) => true,
190
        GroundDomain::Empty(_) => false,
191
    }
192
40806
}
193

            
194
/// Returns representations for `name`, creating them if they don't exist.
195
///
196
///
197
/// Returns None if there is no valid representation for `name`.
198
///
199
/// # Panics
200
///
201
///   + If `name` is not in `symbols`.
202
312
fn get_or_create_representation(
203
312
    name: &Name,
204
312
    symbols: &mut SymbolTable,
205
312
) -> Option<Vec<Box<dyn Representation>>> {
206
    // TODO: pick representations recursively for nested abstract domains: e.g. sets in sets.
207

            
208
312
    let dom = symbols.resolve_domain(name).unwrap();
209
312
    match dom.as_ref() {
210
        GroundDomain::Set(_, _) => None, // has no representations yet!
211
144
        GroundDomain::Tuple(elem_domains) => {
212
144
            if elem_domains
213
144
                .iter()
214
288
                .any(|d| domain_needs_representation(d.as_ref()))
215
            {
216
                bug!("representing nested abstract domains is not implemented");
217
144
            }
218

            
219
144
            symbols.get_or_add_representation(name, &["tuple_to_atom"])
220
        }
221
168
        GroundDomain::Record(entries) => {
222
168
            if entries
223
168
                .iter()
224
336
                .any(|entry| domain_needs_representation(&entry.domain))
225
            {
226
                bug!("representing nested abstract domains is not implemented");
227
168
            }
228

            
229
168
            symbols.get_or_add_representation(name, &["record_to_atom"])
230
        }
231
        _ => unreachable!("non abstract domains should never need representations"),
232
    }
233
312
}