1
use conjure_cp::{
2
    ast::{
3
        Atom, Expression as Expr, GroundDomain, Metadata, Name, SubModel, SymbolTable, serde::HasId,
4
    },
5
    bug,
6
    representation::Representation,
7
    rule_engine::{
8
        ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
9
        register_rule_set,
10
    },
11
    settings::SolverFamily,
12
};
13
use itertools::Itertools;
14
use std::sync::Arc;
15
use std::sync::atomic::AtomicBool;
16
use std::sync::atomic::Ordering;
17
use uniplate::Biplate;
18

            
19
#[cfg(feature = "smt")]
20
use conjure_cp::solver::adaptors::smt::{MatrixTheory, TheoryConfig};
21

            
22
// Predicate for the "Representations" rule set: it returns true for the SAT and Minion solver
// families and, when the `smt` feature is compiled in, for SMT configured with the atomic
// matrix theory.
register_rule_set!("Representations", ("Base"), |f: &SolverFamily| {
    #[cfg(feature = "smt")]
    let smt_with_atomic_matrices = matches!(
        f,
        SolverFamily::Smt(TheoryConfig {
            matrices: MatrixTheory::Atomic,
            ..
        })
    );
    #[cfg(not(feature = "smt"))]
    let smt_with_atomic_matrices = false;

    smt_with_atomic_matrices || matches!(f, SolverFamily::Sat(_) | SolverFamily::Minion)
});
35

            
36
// special case rule to select representations for matrices in one go.
37
//
38
// we know that they only have one possible representation, so this rule adds a representation for all matrices in the model.
39
#[register_rule(("Representations", 8001))]
40
71846
fn select_representation_matrix(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
41
71846
    let Expr::Root(_, _) = expr else {
42
67483
        return Err(RuleNotApplicable);
43
    };
44

            
45
    // cannot create representations on non-local variables, so use lookup_local.
46
21091
    let matrix_vars = symbols.clone().into_iter_local().filter_map(|(n, decl)| {
47
21091
        let id = decl.id();
48
21091
        let var = decl.as_var()?.clone();
49
19678
        let resolved_domain = var.domain.resolve()?;
50

            
51
19678
        let GroundDomain::Matrix(valdom, indexdoms) = resolved_domain.as_ref() else {
52
18373
            return None;
53
        };
54

            
55
        // TODO: loosen these requirements once we are able to
56
1305
        if !matches!(valdom.as_ref(), GroundDomain::Bool | GroundDomain::Int(_)) {
57
            return None;
58
1305
        }
59

            
60
1305
        if indexdoms
61
1305
            .iter()
62
1581
            .any(|x| !matches!(x.as_ref(), GroundDomain::Bool | GroundDomain::Int(_)))
63
        {
64
            return None;
65
1305
        }
66

            
67
1305
        Some((n, id))
68
21091
    });
69

            
70
4363
    let mut symbols = symbols.clone();
71
4363
    let mut expr = expr.clone();
72
4363
    let has_changed = Arc::new(AtomicBool::new(false));
73
4363
    for (name, id) in matrix_vars {
74
        // Even if we have no references to this matrix, still give it the matrix_to_atom
75
        // representation, as we still currently need to give it to minion even if its unused.
76
        //
77
        // If this var has no represnetation yet, the below call to get_or_add will modify the
78
        // symbol table by adding the representation and represented variable declarations to the
79
        // symbol table.
80
1305
        if symbols.representations_for(&name).unwrap().is_empty() {
81
111
            has_changed.store(true, Ordering::Relaxed);
82
1194
        }
83

            
84
        // (creates the represented variables as a side effect)
85
1305
        let _ = symbols
86
1305
            .get_or_add_representation(&name, &["matrix_to_atom"])
87
1305
            .unwrap();
88

            
89
1305
        let old_name = name.clone();
90
1305
        let new_name =
91
1305
            Name::WithRepresentation(Box::new(old_name.clone()), vec!["matrix_to_atom".into()]);
92
        // give all references to this matrix this representation
93
        // also do this inside subscopes, as long as they dont define their own variable that shadows this
94
        // one.
95

            
96
1305
        let old_name_2 = old_name.clone();
97
1305
        let new_name_2 = new_name.clone();
98
1305
        let has_changed_ptr = Arc::clone(&has_changed);
99
25611
        expr = expr.transform_bi(&move |n: Name| {
100
25611
            if n == old_name_2 {
101
249
                has_changed_ptr.store(true, Ordering::SeqCst);
102
249
                new_name_2.clone()
103
            } else {
104
25362
                n
105
            }
106
25611
        });
107

            
108
1305
        let has_changed_ptr = Arc::clone(&has_changed);
109
1305
        let old_name = old_name.clone();
110
1305
        let new_name = new_name.clone();
111
1305
        expr = expr.transform_bi(&move |mut x: SubModel| {
112
240
            let old_name = old_name.clone();
113
240
            let new_name = new_name.clone();
114
240
            let has_changed_ptr = Arc::clone(&has_changed_ptr);
115

            
116
            // only do things if this inscope and not shadowed..
117
240
            if x.symbols().lookup(&old_name).is_none_or(|x| x.id() == id) {
118
240
                let root = x.root_mut_unchecked();
119
846
                *root = root.transform_bi(&move |n: Name| {
120
846
                    if n == old_name {
121
51
                        has_changed_ptr.store(true, Ordering::SeqCst);
122
51
                        new_name.clone()
123
                    } else {
124
795
                        n
125
                    }
126
846
                });
127
            }
128
240
            x
129
240
        });
130
    }
131

            
132
4363
    if has_changed.load(Ordering::Relaxed) {
133
84
        Ok(Reduction::with_symbols(expr, symbols))
134
    } else {
135
4279
        Err(RuleNotApplicable)
136
    }
137
71846
}
138

            
139
#[register_rule(("Representations", 8000))]
140
71399
fn select_representation(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
141
    // thing we are representing must be a reference
142
30853
    let Expr::Atomic(_, Atom::Reference(decl)) = expr else {
143
51767
        return Err(RuleNotApplicable);
144
    };
145

            
146
19632
    let name: Name = decl.name().clone();
147

            
148
    // thing we are representing must be a variable
149
    {
150
19632
        let guard = decl.ptr().as_var().ok_or(RuleNotApplicable)?;
151
18042
        drop(guard);
152
    }
153

            
154
18042
    if !needs_representation(&name, symbols) {
155
17730
        return Err(RuleNotApplicable);
156
312
    }
157

            
158
312
    let mut symbols = symbols.clone();
159
39
    let representation =
160
312
        get_or_create_representation(&name, &mut symbols).ok_or(RuleNotApplicable)?;
161

            
162
39
    let representation_names = representation
163
39
        .into_iter()
164
39
        .map(|x| x.repr_name().into())
165
39
        .collect_vec();
166

            
167
39
    let new_name = Name::WithRepresentation(Box::new(name.clone()), representation_names);
168

            
169
    // HACK: this is suspicious, but hopefully will work until we clean up representations
170
    // properly...
171
    //
172
    // In general, we should not use names atall anymore, including for representations /
173
    // represented variables.
174
    //
175
    // * instead of storing the link from a variable that has a representation to the variable it
176
    // is representing in the name as WithRepresentation, we should use declaration pointers instead.
177
    //
178
    //
179
    // see: issue #932
180
39
    let mut decl_ptr = decl.clone().into_ptr().detach();
181
39
    decl_ptr.replace_name(new_name);
182

            
183
39
    Ok(Reduction::with_symbols(
184
39
        Expr::Atomic(
185
39
            Metadata::new(),
186
39
            Atom::Reference(conjure_cp::ast::Reference::new(decl_ptr)),
187
39
        ),
188
39
        symbols,
189
39
    ))
190
71399
}
191

            
192
/// Returns whether `name` needs representing.
193
///
194
/// # Panics
195
///
196
///   + If `name` is not in `symbols`.
197
18042
fn needs_representation(name: &Name, symbols: &SymbolTable) -> bool {
198
    // if name already has a representation, false
199
18042
    if let Name::Represented(_) = name {
200
3252
        return false;
201
14790
    }
202
    // might be more logic here in the future?
203
14790
    domain_needs_representation(&symbols.resolve_domain(name).unwrap())
204
18042
}
205

            
206
/// Returns whether `domain` needs representing.
207
15414
fn domain_needs_representation(domain: &GroundDomain) -> bool {
208
    // very simple implementation for nows
209
15414
    match domain {
210
11841
        GroundDomain::Bool | GroundDomain::Int(_) => false,
211
3261
        GroundDomain::Matrix(_, _) => false, // we special case these elsewhere
212
        GroundDomain::Set(_, _)
213
        | GroundDomain::MSet(_, _)
214
        | GroundDomain::Tuple(_)
215
        | GroundDomain::Record(_)
216
312
        | GroundDomain::Function(_, _, _) => true,
217
        GroundDomain::Empty(_) => false,
218
    }
219
15414
}
220

            
221
/// Returns representations for `name`, creating them if they don't exist.
222
///
223
///
224
/// Returns None if there is no valid representation for `name`.
225
///
226
/// # Panics
227
///
228
///   + If `name` is not in `symbols`.
229
312
fn get_or_create_representation(
230
312
    name: &Name,
231
312
    symbols: &mut SymbolTable,
232
312
) -> Option<Vec<Box<dyn Representation>>> {
233
    // TODO: pick representations recursively for nested abstract domains: e.g. sets in sets.
234

            
235
312
    let dom = symbols.resolve_domain(name).unwrap();
236
312
    match dom.as_ref() {
237
        GroundDomain::Set(_, _) => None, // has no representations yet!
238
144
        GroundDomain::Tuple(elem_domains) => {
239
144
            if elem_domains
240
144
                .iter()
241
288
                .any(|d| domain_needs_representation(d.as_ref()))
242
            {
243
                bug!("representing nested abstract domains is not implemented");
244
144
            }
245

            
246
144
            symbols.get_or_add_representation(name, &["tuple_to_atom"])
247
        }
248
168
        GroundDomain::Record(entries) => {
249
168
            if entries
250
168
                .iter()
251
336
                .any(|entry| domain_needs_representation(&entry.domain))
252
            {
253
                bug!("representing nested abstract domains is not implemented");
254
168
            }
255

            
256
168
            symbols.get_or_add_representation(name, &["record_to_atom"])
257
        }
258
        _ => unreachable!("non abstract domains should never need representations"),
259
    }
260
312
}