1
use conjure_cp::{
2
    ast::{Atom, Expression as Expr, GroundDomain, Metadata, Name, SymbolTable, serde::HasId},
3
    bug,
4
    representation::Representation,
5
    rule_engine::{
6
        ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
7
        register_rule_set,
8
    },
9
    settings::SolverFamily,
10
};
11
use itertools::Itertools;
12
use std::sync::Arc;
13
use std::sync::atomic::AtomicBool;
14
use std::sync::atomic::Ordering;
15
use uniplate::Biplate;
16

            
17
use conjure_cp::solver::adaptors::smt::{MatrixTheory, TheoryConfig};
18

            
19
register_rule_set!("Representations", ("Base"), |f: &SolverFamily| {
    // Representation rules are enabled for Minion and SAT, and for SMT only when its matrix
    // theory is `MatrixTheory::Atomic`.
    matches!(
        f,
        SolverFamily::Minion
            | SolverFamily::Sat(_)
            | SolverFamily::Smt(TheoryConfig {
                matrices: MatrixTheory::Atomic,
                ..
            })
    )
});
31

            
32
// special case rule to select representations for matrices in one go.
33
//
34
// we know that they only have one possible representation, so this rule adds a representation for all matrices in the model.
35
#[register_rule("Representations", 8001, [Root])]
36
1154070
fn select_representation_matrix(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
37
1154070
    let Expr::Root(_, _) = expr else {
38
1120416
        return Err(RuleNotApplicable);
39
    };
40

            
41
    // cannot create representations on non-local variables, so use lookup_local.
42
1364804
    let matrix_vars = symbols.clone().into_iter_local().filter_map(|(n, decl)| {
43
1364804
        let id = decl.id();
44
1364804
        let var = decl.as_find()?.clone();
45
1348314
        let resolved_domain = var.domain.resolve()?;
46

            
47
1348314
        let GroundDomain::Matrix(valdom, indexdoms) = resolved_domain.as_ref() else {
48
1311192
            return None;
49
        };
50

            
51
        // TODO: loosen these requirements once we are able to
52
37122
        if !matches!(valdom.as_ref(), GroundDomain::Bool | GroundDomain::Int(_)) {
53
4
            return None;
54
37118
        }
55

            
56
37118
        if indexdoms
57
37118
            .iter()
58
54252
            .any(|x| !matches!(x.as_ref(), GroundDomain::Bool | GroundDomain::Int(_)))
59
        {
60
            return None;
61
37118
        }
62

            
63
37118
        Some((n, id))
64
1364804
    });
65

            
66
33654
    let mut symbols = symbols.clone();
67
33654
    let mut expr = expr.clone();
68
33654
    let has_changed = Arc::new(AtomicBool::new(false));
69
38132
    for (name, _id) in matrix_vars {
70
        // Even if we have no references to this matrix, still give it the matrix_to_atom
71
        // representation, as we still currently need to give it to minion even if its unused.
72
        //
73
        // If this var has no represnetation yet, the below call to get_or_add will modify the
74
        // symbol table by adding the representation and represented variable declarations to the
75
        // symbol table.
76
37118
        if symbols.representations_for(&name).unwrap().is_empty() {
77
1302
            has_changed.store(true, Ordering::Relaxed);
78
35816
        }
79

            
80
        // (creates the represented variables as a side effect)
81
37118
        let _ = symbols
82
37118
            .get_or_add_representation(&name, &["matrix_to_atom"])
83
37118
            .unwrap();
84

            
85
37118
        let old_name = name.clone();
86
37118
        let new_name =
87
37118
            Name::WithRepresentation(Box::new(old_name.clone()), vec!["matrix_to_atom".into()]);
88
        // give all references to this matrix this representation
89
        // also do this inside subscopes, as long as they dont define their own variable that shadows this
90
        // one.
91

            
92
37118
        let old_name_2 = old_name.clone();
93
37118
        let new_name_2 = new_name.clone();
94
37118
        let has_changed_ptr = Arc::clone(&has_changed);
95
1756056
        expr = expr.transform_bi(&move |n: Name| {
96
1756056
            if n == old_name_2 {
97
3440
                has_changed_ptr.store(true, Ordering::SeqCst);
98
3440
                new_name_2.clone()
99
            } else {
100
1752616
                n
101
            }
102
1756056
        });
103
    }
104

            
105
33654
    if has_changed.load(Ordering::Relaxed) {
106
1334
        Ok(Reduction::with_symbols(expr, symbols))
107
    } else {
108
32320
        Err(RuleNotApplicable)
109
    }
110
1154070
}
111

            
112
#[register_rule("Representations", 8000, [Atomic])]
113
1151284
fn select_representation(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
114
    // thing we are representing must be a reference
115
555104
    let Expr::Atomic(_, Atom::Reference(decl)) = expr else {
116
753300
        return Err(RuleNotApplicable);
117
    };
118

            
119
397984
    let name: Name = decl.name().clone();
120

            
121
    // thing we are representing must be a variable
122
    {
123
397984
        let guard = decl.ptr().as_find().ok_or(RuleNotApplicable)?;
124
394297
        drop(guard);
125
    }
126

            
127
394297
    if !needs_representation(&name, symbols) {
128
393049
        return Err(RuleNotApplicable);
129
1248
    }
130

            
131
1248
    let mut symbols = symbols.clone();
132
156
    let representation =
133
1248
        get_or_create_representation(&name, &mut symbols).ok_or(RuleNotApplicable)?;
134

            
135
156
    let representation_names = representation
136
156
        .into_iter()
137
156
        .map(|x| x.repr_name().into())
138
156
        .collect_vec();
139

            
140
156
    let new_name = Name::WithRepresentation(Box::new(name.clone()), representation_names);
141

            
142
    // HACK: this is suspicious, but hopefully will work until we clean up representations
143
    // properly...
144
    //
145
    // In general, we should not use names atall anymore, including for representations /
146
    // represented variables.
147
    //
148
    // * instead of storing the link from a variable that has a representation to the variable it
149
    // is representing in the name as WithRepresentation, we should use declaration pointers instead.
150
    //
151
    //
152
    // see: issue #932
153
156
    let mut decl_ptr = decl.clone().into_ptr().detach();
154
156
    decl_ptr.replace_name(new_name);
155

            
156
156
    Ok(Reduction::with_symbols(
157
156
        Expr::Atomic(
158
156
            Metadata::new(),
159
156
            Atom::Reference(conjure_cp::ast::Reference::new(decl_ptr)),
160
156
        ),
161
156
        symbols,
162
156
    ))
163
1151284
}
164

            
165
/// Returns whether `name` needs representing.
166
///
167
394297
fn needs_representation(name: &Name, symbols: &SymbolTable) -> bool {
168
    // if name already has a representation, false
169
394297
    if let Name::Represented(_) = name {
170
259252
        return false;
171
135045
    }
172
    // might be more logic here in the future?
173
135045
    symbols
174
135045
        .resolve_domain(name)
175
135045
        .is_some_and(|domain| domain_needs_representation(domain.as_ref()))
176
394297
}
177

            
178
/// Returns whether `domain` needs representing.
179
137541
fn domain_needs_representation(domain: &GroundDomain) -> bool {
180
    // very simple implementation for nows
181
137541
    match domain {
182
103465
        GroundDomain::Bool | GroundDomain::Int(_) => false,
183
32828
        GroundDomain::Matrix(_, _) => false, // we special case these elsewhere
184
        GroundDomain::Set(_, _)
185
        | GroundDomain::MSet(_, _)
186
        | GroundDomain::Tuple(_)
187
        | GroundDomain::Record(_)
188
1248
        | GroundDomain::Function(_, _, _) => true,
189
        GroundDomain::Empty(_) => false,
190
    }
191
137541
}
192

            
193
/// Returns representations for `name`, creating them if they don't exist.
194
///
195
///
196
/// Returns None if there is no valid representation for `name`.
197
///
198
1248
fn get_or_create_representation(
199
1248
    name: &Name,
200
1248
    symbols: &mut SymbolTable,
201
1248
) -> Option<Vec<Box<dyn Representation>>> {
202
    // TODO: pick representations recursively for nested abstract domains: e.g. sets in sets.
203

            
204
1248
    let dom = symbols.resolve_domain(name)?;
205
1248
    match dom.as_ref() {
206
        GroundDomain::Set(_, _) => None, // has no representations yet!
207
576
        GroundDomain::Tuple(elem_domains) => {
208
576
            if elem_domains
209
576
                .iter()
210
1152
                .any(|d| domain_needs_representation(d.as_ref()))
211
            {
212
                bug!("representing nested abstract domains is not implemented");
213
576
            }
214

            
215
576
            symbols.get_or_add_representation(name, &["tuple_to_atom"])
216
        }
217
672
        GroundDomain::Record(entries) => {
218
672
            if entries
219
672
                .iter()
220
1344
                .any(|entry| domain_needs_representation(&entry.domain))
221
            {
222
                bug!("representing nested abstract domains is not implemented");
223
672
            }
224

            
225
672
            symbols.get_or_add_representation(name, &["record_to_atom"])
226
        }
227
        _ => unreachable!("non abstract domains should never need representations"),
228
    }
229
1248
}