1
use conjure_cp::{
2
    ast::{Atom, Expression as Expr, GroundDomain, Metadata, Name, SymbolTable, serde::HasId},
3
    bug,
4
    representation::Representation,
5
    rule_engine::{
6
        ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
7
        register_rule_set,
8
    },
9
    settings::SolverFamily,
10
};
11
use itertools::Itertools;
12
use std::sync::Arc;
13
use std::sync::atomic::AtomicBool;
14
use std::sync::atomic::Ordering;
15
use uniplate::Biplate;
16

            
17
use conjure_cp::solver::adaptors::smt::{MatrixTheory, TheoryConfig};
18

            
19
// The representation rules run for SAT, Minion, and SMT solvers configured with the atomic
// matrix theory.
register_rule_set!("Representations", ("Base"), |f: &SolverFamily| {
    matches!(
        f,
        SolverFamily::Sat(_)
            | SolverFamily::Minion
            | SolverFamily::Smt(TheoryConfig {
                matrices: MatrixTheory::Atomic,
                ..
            })
    )
});
31

            
32
// special case rule to select representations for matrices in one go.
33
//
34
// we know that they only have one possible representation, so this rule adds a representation for all matrices in the model.
35
#[register_rule(("Representations", 8001))]
36
748546
fn select_representation_matrix(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
37
748546
    let Expr::Root(_, _) = expr else {
38
719204
        return Err(RuleNotApplicable);
39
    };
40

            
41
    // cannot create representations on non-local variables, so use lookup_local.
42
835137
    let matrix_vars = symbols.clone().into_iter_local().filter_map(|(n, decl)| {
43
835137
        let id = decl.id();
44
835137
        let var = decl.as_find()?.clone();
45
828117
        let resolved_domain = var.domain.resolve()?;
46

            
47
828117
        let GroundDomain::Matrix(valdom, indexdoms) = resolved_domain.as_ref() else {
48
790298
            return None;
49
        };
50

            
51
        // TODO: loosen these requirements once we are able to
52
37819
        if !matches!(valdom.as_ref(), GroundDomain::Bool | GroundDomain::Int(_)) {
53
            return None;
54
37819
        }
55

            
56
37819
        if indexdoms
57
37819
            .iter()
58
58528
            .any(|x| !matches!(x.as_ref(), GroundDomain::Bool | GroundDomain::Int(_)))
59
        {
60
            return None;
61
37819
        }
62

            
63
37819
        Some((n, id))
64
835137
    });
65

            
66
29342
    let mut symbols = symbols.clone();
67
29342
    let mut expr = expr.clone();
68
29342
    let has_changed = Arc::new(AtomicBool::new(false));
69
37841
    for (name, _id) in matrix_vars {
70
        // Even if we have no references to this matrix, still give it the matrix_to_atom
71
        // representation, as we still currently need to give it to minion even if its unused.
72
        //
73
        // If this var has no represnetation yet, the below call to get_or_add will modify the
74
        // symbol table by adding the representation and represented variable declarations to the
75
        // symbol table.
76
37819
        if symbols.representations_for(&name).unwrap().is_empty() {
77
1177
            has_changed.store(true, Ordering::Relaxed);
78
36642
        }
79

            
80
        // (creates the represented variables as a side effect)
81
37819
        let _ = symbols
82
37819
            .get_or_add_representation(&name, &["matrix_to_atom"])
83
37819
            .unwrap();
84

            
85
37819
        let old_name = name.clone();
86
37819
        let new_name =
87
37819
            Name::WithRepresentation(Box::new(old_name.clone()), vec!["matrix_to_atom".into()]);
88
        // give all references to this matrix this representation
89
        // also do this inside subscopes, as long as they dont define their own variable that shadows this
90
        // one.
91

            
92
37819
        let old_name_2 = old_name.clone();
93
37819
        let new_name_2 = new_name.clone();
94
37819
        let has_changed_ptr = Arc::clone(&has_changed);
95
776715
        expr = expr.transform_bi(&move |n: Name| {
96
776715
            if n == old_name_2 {
97
2928
                has_changed_ptr.store(true, Ordering::SeqCst);
98
2928
                new_name_2.clone()
99
            } else {
100
773787
                n
101
            }
102
776715
        });
103
    }
104

            
105
29342
    if has_changed.load(Ordering::Relaxed) {
106
1257
        Ok(Reduction::with_symbols(expr, symbols))
107
    } else {
108
28085
        Err(RuleNotApplicable)
109
    }
110
748546
}
111

            
112
#[register_rule(("Representations", 8000))]
113
746200
fn select_representation(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
114
    // thing we are representing must be a reference
115
355709
    let Expr::Atomic(_, Atom::Reference(decl)) = expr else {
116
498451
        return Err(RuleNotApplicable);
117
    };
118

            
119
247749
    let name: Name = decl.name().clone();
120

            
121
    // thing we are representing must be a variable
122
    {
123
247749
        let guard = decl.ptr().as_find().ok_or(RuleNotApplicable)?;
124
247749
        drop(guard);
125
    }
126

            
127
247749
    if !needs_representation(&name, symbols) {
128
246813
        return Err(RuleNotApplicable);
129
936
    }
130

            
131
936
    let mut symbols = symbols.clone();
132
117
    let representation =
133
936
        get_or_create_representation(&name, &mut symbols).ok_or(RuleNotApplicable)?;
134

            
135
117
    let representation_names = representation
136
117
        .into_iter()
137
117
        .map(|x| x.repr_name().into())
138
117
        .collect_vec();
139

            
140
117
    let new_name = Name::WithRepresentation(Box::new(name.clone()), representation_names);
141

            
142
    // HACK: this is suspicious, but hopefully will work until we clean up representations
143
    // properly...
144
    //
145
    // In general, we should not use names atall anymore, including for representations /
146
    // represented variables.
147
    //
148
    // * instead of storing the link from a variable that has a representation to the variable it
149
    // is representing in the name as WithRepresentation, we should use declaration pointers instead.
150
    //
151
    //
152
    // see: issue #932
153
117
    let mut decl_ptr = decl.clone().into_ptr().detach();
154
117
    decl_ptr.replace_name(new_name);
155

            
156
117
    Ok(Reduction::with_symbols(
157
117
        Expr::Atomic(
158
117
            Metadata::new(),
159
117
            Atom::Reference(conjure_cp::ast::Reference::new(decl_ptr)),
160
117
        ),
161
117
        symbols,
162
117
    ))
163
746200
}
164

            
165
/// Returns whether `name` needs representing.
166
///
167
/// # Panics
168
///
169
///   + If `name` is not in `symbols`.
170
247749
fn needs_representation(name: &Name, symbols: &SymbolTable) -> bool {
171
    // if name already has a representation, false
172
247749
    if let Name::Represented(_) = name {
173
130557
        return false;
174
117192
    }
175
    // might be more logic here in the future?
176
117192
    domain_needs_representation(&symbols.resolve_domain(name).unwrap())
177
247749
}
178

            
179
/// Returns whether `domain` needs representing.
180
119064
fn domain_needs_representation(domain: &GroundDomain) -> bool {
181
    // very simple implementation for nows
182
119064
    match domain {
183
84105
        GroundDomain::Bool | GroundDomain::Int(_) => false,
184
34023
        GroundDomain::Matrix(_, _) => false, // we special case these elsewhere
185
        GroundDomain::Set(_, _)
186
        | GroundDomain::MSet(_, _)
187
        | GroundDomain::Tuple(_)
188
        | GroundDomain::Record(_)
189
936
        | GroundDomain::Function(_, _, _) => true,
190
        GroundDomain::Empty(_) => false,
191
    }
192
119064
}
193

            
194
/// Returns representations for `name`, creating them if they don't exist.
195
///
196
///
197
/// Returns None if there is no valid representation for `name`.
198
///
199
/// # Panics
200
///
201
///   + If `name` is not in `symbols`.
202
624
fn get_or_create_representation(
203
624
    name: &Name,
204
624
    symbols: &mut SymbolTable,
205
624
) -> Option<Vec<Box<dyn Representation>>> {
206
    // TODO: pick representations recursively for nested abstract domains: e.g. sets in sets.
207

            
208
624
    let dom = symbols.resolve_domain(name).unwrap();
209
624
    match dom.as_ref() {
210
        GroundDomain::Set(_, _) => None, // has no representations yet!
211
288
        GroundDomain::Tuple(elem_domains) => {
212
288
            if elem_domains
213
288
                .iter()
214
576
                .any(|d| domain_needs_representation(d.as_ref()))
215
            {
216
                bug!("representing nested abstract domains is not implemented");
217
288
            }
218

            
219
288
            symbols.get_or_add_representation(name, &["tuple_to_atom"])
220
        }
221
336
        GroundDomain::Record(entries) => {
222
336
            if entries
223
336
                .iter()
224
672
                .any(|entry| domain_needs_representation(&entry.domain))
225
            {
226
                bug!("representing nested abstract domains is not implemented");
227
336
            }
228

            
229
336
            symbols.get_or_add_representation(name, &["record_to_atom"])
230
        }
231
        _ => unreachable!("non abstract domains should never need representations"),
232
    }
233
624
}