1
use conjure_cp::ast::{DeclarationPtr, DomainPtr, GroundDomain, Moo};
2
use conjure_cp::parse::tree_sitter::parse_literal;
3
use itertools::{Itertools, izip};
4
use std::collections::BTreeMap;
5

            
6
use super::prelude::*;
7

            
8
// Register `MatrixToAtom` via `register_representation!` under the string "matrix_to_atom",
// which is the same string `repr_name()` returns below.
register_representation!(MatrixToAtom, "matrix_to_atom");
9

            
10
#[derive(Clone, Debug)]
11
pub struct MatrixToAtom {
12
    src_var: Name,
13

            
14
    // all the possible indices in this matrix, in order.
15
    indices: Vec<Vec<Literal>>,
16

            
17
    // the element domain for the matrix.
18
    elem_domain: Moo<GroundDomain>,
19

            
20
    // the index domains for the matrix.
21
    index_domains: Vec<Moo<GroundDomain>>,
22
}
23

            
24
impl MatrixToAtom {
25
    /// Returns the names of the representation variables, in the same order as the indices.
26
    fn names(&self) -> impl Iterator<Item = Name> + '_ {
27
        self.indices.iter().map(|x| self.indices_to_name(x))
28
    }
29

            
30
    /// Gets the representation variable name for a specific set of indices.
31
    fn indices_to_name(&self, indices: &[Literal]) -> Name {
32
        Name::Represented(Box::new((
33
            self.src_var.clone(),
34
            self.repr_name().into(),
35
            indices.iter().join("_").into(),
36
        )))
37
    }
38

            
39
    /// Panics if name is invalid.
40
    #[allow(dead_code)]
41
    fn name_to_indices(&self, name: &Name) -> Vec<Literal> {
42
        let Name::Represented(fields) = name else {
43
            bug!("representation name should be Name::RepresentationOf");
44
        };
45

            
46
        let (src_var, rule_string, suffix) = fields.as_ref();
47

            
48
        assert_eq!(
49
            src_var,
50
            self.variable_name(),
51
            "name should have the same source var as self"
52
        );
53
        assert_eq!(
54
            rule_string,
55
            self.repr_name(),
56
            "name should have the same repr_name as self"
57
        );
58

            
59
        // FIXME: call the parser here to parse the literals properly; support more literal kinds
60
        // ~niklasdewally
61
        let indices = suffix.split("_").collect_vec();
62
        assert_eq!(
63
            indices.len(),
64
            self.indices[0].len(),
65
            "name should have same number of indices as self"
66
        );
67

            
68
        let parsed_indices = indices
69
            .into_iter()
70
            .map(|x| match parse_literal(x) {
71
                Ok(literal) => literal,
72
                Err(_) => bug!("{x} should be a string that can parse into a valid Literal"),
73
            })
74
            .collect_vec();
75

            
76
        assert!(
77
            self.indices.contains(&parsed_indices),
78
            "indices parsed from the representation name should be valid indices for this variable"
79
        );
80

            
81
        parsed_indices
82
    }
83
}
84

            
85
impl Representation for MatrixToAtom {
86
    fn init(name: &Name, symtab: &SymbolTable) -> Option<Self> {
87
        let domain = symtab.resolve_domain(name)?;
88

            
89
        if !domain.is_finite() {
90
            return None;
91
        }
92

            
93
        let GroundDomain::Matrix(elem_domain, index_domains) = domain.as_ref() else {
94
            return None;
95
        };
96

            
97
        let indices = matrix::enumerate_indices(index_domains.clone()).collect_vec();
98

            
99
        Some(MatrixToAtom {
100
            src_var: name.clone(),
101
            indices,
102
            elem_domain: elem_domain.clone(),
103
            index_domains: index_domains.clone(),
104
        })
105
    }
106

            
107
    fn variable_name(&self) -> &Name {
108
        &self.src_var
109
    }
110

            
111
    fn value_down(&self, value: Literal) -> Result<BTreeMap<Name, Literal>, ApplicationError> {
112
        let Literal::AbstractLiteral(matrix) = value else {
113
            return Err(RuleNotApplicable);
114
        };
115

            
116
        let AbstractLiteral::Matrix(_, ref index_domain) = matrix else {
117
            return Err(RuleNotApplicable);
118
        };
119

            
120
        if index_domain != &self.index_domains[0] {
121
            return Err(RuleNotApplicable);
122
        }
123

            
124
        Ok(izip!(self.names(), matrix::flatten(matrix)).collect())
125
    }
126

            
127
    fn value_up(&self, values: &BTreeMap<Name, Literal>) -> Result<Literal, ApplicationError> {
128
        // TODO: this has no error checking or failures that don't panic...
129

            
130
        let n_dims = self.index_domains.len();
131
        fn inner(
132
            current_index: Vec<Literal>,
133
            current_dim: usize,
134
            self1: &MatrixToAtom,
135
            values: &BTreeMap<Name, Literal>,
136
            n_dims: usize,
137
        ) -> Literal {
138
            if current_dim < n_dims {
139
                Literal::AbstractLiteral(into_matrix![
140
                    self1.index_domains[current_dim]
141
                        .values()
142
                        .unwrap()
143
                        .map(|i| {
144
                            let mut current_index_1 = current_index.clone();
145
                            current_index_1.push(i);
146
                            inner(current_index_1, current_dim + 1, self1, values, n_dims)
147
                        })
148
                        .collect_vec()
149
                ])
150
            } else {
151
                values
152
                    .get(&self1.indices_to_name(&current_index))
153
                    .unwrap()
154
                    .clone()
155
            }
156
        }
157

            
158
        Ok(inner(vec![], 0, self, values, n_dims))
159
    }
160

            
161
    fn expression_down(
162
        &self,
163
        symtab: &SymbolTable,
164
    ) -> Result<BTreeMap<Name, Expression>, ApplicationError> {
165
        Ok(self
166
            .names()
167
            .map(|name| {
168
                let declaration = symtab.lookup(&name).expect("declarations of the representation variables should exist in the symbol table before expression_down is called");
169
                (name, declaration)
170
            })
171
            .map(|(name, decl)| (name, Expression::Atomic(Metadata::new(), Atom::Reference(conjure_cp::ast::Reference::new(decl)))))
172
            .collect())
173
    }
174

            
175
    fn declaration_down(&self) -> Result<Vec<DeclarationPtr>, ApplicationError> {
176
        let dom: DomainPtr = self.elem_domain.clone().into();
177
        Ok(self
178
            .names()
179
            .map(|name| DeclarationPtr::new_var(name, dom.clone()))
180
            .collect_vec())
181
    }
182

            
183
    fn repr_name(&self) -> &str {
184
        "matrix_to_atom"
185
    }
186

            
187
    fn box_clone(&self) -> Box<dyn Representation> {
188
        Box::new(self.clone()) as _
189
    }
190
}