1
use conjure_cp::{
2
    ast::{
3
        Atom, Expression as Expr, GroundDomain, Metadata, Name, SubModel, SymbolTable, serde::HasId,
4
    },
5
    bug,
6
    representation::Representation,
7
    rule_engine::{
8
        ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
9
        register_rule_set,
10
    },
11
    solver::{
12
        SolverFamily,
13
        adaptors::smt::{MatrixTheory, TheoryConfig},
14
    },
15
};
16
use itertools::Itertools;
17
use std::sync::Arc;
18
use std::sync::atomic::AtomicBool;
19
use std::sync::atomic::Ordering;
20
use uniplate::Biplate;
21

            
22
// The "Representations" rule set (with "Base" listed as its dependency) is enabled for Minion,
// SAT, and for SMT only when the SMT theory config stores matrices as `MatrixTheory::Atomic`.
register_rule_set!("Representations", ("Base"), |f: &SolverFamily| matches!(
    f,
    SolverFamily::Minion
        | SolverFamily::Sat
        | SolverFamily::Smt(TheoryConfig {
            matrices: MatrixTheory::Atomic,
            ..
        })
));
31

            
32
// special case rule to select representations for matrices in one go.
33
//
34
// we know that they only have one possible representation, so this rule adds a representation for all matrices in the model.
35
#[register_rule(("Representations", 8001))]
36
fn select_representation_matrix(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
37
    let Expr::Root(_, _) = expr else {
38
        return Err(RuleNotApplicable);
39
    };
40

            
41
    // cannot create representations on non-local variables, so use lookup_local.
42
    let matrix_vars = symbols
43
        .clone()
44
        .into_iter_local()
45
        .filter_map(|(n, decl)| {
46
            let id = decl.id();
47
            decl.as_var().map(|x| (n, id, x.clone()))
48
        })
49
        .filter(|(_, _, var)| {
50
            let Some((valdom, indexdoms)) = var.domain.as_matrix_ground() else {
51
                return false;
52
            };
53

            
54
            // TODO: loosen these requirements once we are able to
55
            if !matches!(valdom.as_ref(), GroundDomain::Bool | GroundDomain::Int(_)) {
56
                return false;
57
            }
58

            
59
            if indexdoms
60
                .iter()
61
                .any(|x| !matches!(x.as_ref(), GroundDomain::Bool | GroundDomain::Int(_)))
62
            {
63
                return false;
64
            }
65

            
66
            true
67
        });
68

            
69
    let mut symbols = symbols.clone();
70
    let mut expr = expr.clone();
71
    let has_changed = Arc::new(AtomicBool::new(false));
72
    for (name, id, _) in matrix_vars {
73
        // Even if we have no references to this matrix, still give it the matrix_to_atom
74
        // representation, as we still currently need to give it to minion even if its unused.
75
        //
76
        // If this var has no represnetation yet, the below call to get_or_add will modify the
77
        // symbol table by adding the representation and represented variable declarations to the
78
        // symbol table.
79
        if symbols.representations_for(&name).unwrap().is_empty() {
80
            has_changed.store(true, Ordering::Relaxed);
81
        }
82

            
83
        // (creates the represented variables as a side effect)
84
        let _ = symbols
85
            .get_or_add_representation(&name, &["matrix_to_atom"])
86
            .unwrap();
87

            
88
        let old_name = name.clone();
89
        let new_name =
90
            Name::WithRepresentation(Box::new(old_name.clone()), vec!["matrix_to_atom".into()]);
91
        // give all references to this matrix this representation
92
        // also do this inside subscopes, as long as they dont define their own variable that shadows this
93
        // one.
94

            
95
        let old_name_2 = old_name.clone();
96
        let new_name_2 = new_name.clone();
97
        let has_changed_ptr = Arc::clone(&has_changed);
98
        expr = expr.transform_bi(&move |n: Name| {
99
            if n == old_name_2 {
100
                has_changed_ptr.store(true, Ordering::SeqCst);
101
                new_name_2.clone()
102
            } else {
103
                n
104
            }
105
        });
106

            
107
        let has_changed_ptr = Arc::clone(&has_changed);
108
        let old_name = old_name.clone();
109
        let new_name = new_name.clone();
110
        expr = expr.transform_bi(&move |mut x: SubModel| {
111
            let old_name = old_name.clone();
112
            let new_name = new_name.clone();
113
            let has_changed_ptr = Arc::clone(&has_changed_ptr);
114

            
115
            // only do things if this inscope and not shadowed..
116
            if x.symbols().lookup(&old_name).is_none_or(|x| x.id() == id) {
117
                let root = x.root_mut_unchecked();
118
                *root = root.transform_bi(&move |n: Name| {
119
                    if n == old_name {
120
                        has_changed_ptr.store(true, Ordering::SeqCst);
121
                        new_name.clone()
122
                    } else {
123
                        n
124
                    }
125
                });
126
            }
127
            x
128
        });
129
    }
130

            
131
    if has_changed.load(Ordering::Relaxed) {
132
        Ok(Reduction::with_symbols(expr, symbols))
133
    } else {
134
        Err(RuleNotApplicable)
135
    }
136
}
137

            
138
#[register_rule(("Representations", 8000))]
139
fn select_representation(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
140
    // thing we are representing must be a reference
141
    let Expr::Atomic(_, Atom::Reference(decl)) = expr else {
142
        return Err(RuleNotApplicable);
143
    };
144

            
145
    let name: Name = decl.name().clone();
146

            
147
    // thing we are representing must be a variable
148
    {
149
        decl.ptr().as_var().ok_or(RuleNotApplicable)?;
150
    }
151

            
152
    if !needs_representation(&name, symbols) {
153
        return Err(RuleNotApplicable);
154
    }
155

            
156
    let mut symbols = symbols.clone();
157
    let representation =
158
        get_or_create_representation(&name, &mut symbols).ok_or(RuleNotApplicable)?;
159

            
160
    let representation_names = representation
161
        .into_iter()
162
        .map(|x| x.repr_name().into())
163
        .collect_vec();
164

            
165
    let new_name = Name::WithRepresentation(Box::new(name.clone()), representation_names);
166

            
167
    // HACK: this is suspicious, but hopefully will work until we clean up representations
168
    // properly...
169
    //
170
    // In general, we should not use names atall anymore, including for representations /
171
    // represented variables.
172
    //
173
    // * instead of storing the link from a variable that has a representation to the variable it
174
    // is representing in the name as WithRepresentation, we should use declaration pointers instead.
175
    //
176
    //
177
    // see: issue #932
178
    let mut decl_ptr = decl.clone().into_ptr().detach();
179
    decl_ptr.replace_name(new_name);
180

            
181
    Ok(Reduction::with_symbols(
182
        Expr::Atomic(
183
            Metadata::new(),
184
            Atom::Reference(conjure_cp::ast::Reference::new(decl_ptr)),
185
        ),
186
        symbols,
187
    ))
188
}
189

            
190
/// Returns whether `name` needs representing.
191
///
192
/// # Panics
193
///
194
///   + If `name` is not in `symbols`.
195
fn needs_representation(name: &Name, symbols: &SymbolTable) -> bool {
196
    // if name already has a representation, false
197
    if let Name::Represented(_) = name {
198
        return false;
199
    }
200
    // might be more logic here in the future?
201
    domain_needs_representation(&symbols.resolve_domain(name).unwrap())
202
}
203

            
204
/// Returns whether `domain` needs representing.
205
fn domain_needs_representation(domain: &GroundDomain) -> bool {
206
    // very simple implementation for nows
207
    match domain {
208
        GroundDomain::Bool | GroundDomain::Int(_) => false,
209
        GroundDomain::Matrix(_, _) => false, // we special case these elsewhere
210
        GroundDomain::Set(_, _)
211
        | GroundDomain::Tuple(_)
212
        | GroundDomain::Record(_)
213
        | GroundDomain::Function(_, _, _) => true,
214
        GroundDomain::Empty(_) => false,
215
    }
216
}
217

            
218
/// Returns representations for `name`, creating them if they don't exist.
219
///
220
///
221
/// Returns None if there is no valid representation for `name`.
222
///
223
/// # Panics
224
///
225
///   + If `name` is not in `symbols`.
226
fn get_or_create_representation(
227
    name: &Name,
228
    symbols: &mut SymbolTable,
229
) -> Option<Vec<Box<dyn Representation>>> {
230
    // TODO: pick representations recursively for nested abstract domains: e.g. sets in sets.
231

            
232
    let dom = symbols.resolve_domain(name).unwrap();
233
    match dom.as_ref() {
234
        GroundDomain::Set(_, _) => None, // has no representations yet!
235
        GroundDomain::Tuple(elem_domains) => {
236
            if elem_domains
237
                .iter()
238
                .any(|d| domain_needs_representation(d.as_ref()))
239
            {
240
                bug!("representing nested abstract domains is not implemented");
241
            }
242

            
243
            symbols.get_or_add_representation(name, &["tuple_to_atom"])
244
        }
245
        GroundDomain::Record(entries) => {
246
            if entries
247
                .iter()
248
                .any(|entry| domain_needs_representation(&entry.domain))
249
            {
250
                bug!("representing nested abstract domains is not implemented");
251
            }
252

            
253
            symbols.get_or_add_representation(name, &["record_to_atom"])
254
        }
255
        _ => unreachable!("non abstract domains should never need representations"),
256
    }
257
}