1
use conjure_cp::ast::{DeclarationPtr, DomainPtr, GroundDomain, Moo, records::RecordValue};
2
use itertools::Itertools;
3

            
4
use super::prelude::*;
5

            
6
// Registers `RecordToAtom` under the string "record_to_atom" (the same string `repr_name`
// returns). NOTE(review): the macro is defined elsewhere — presumably this adds the type to the
// representation registry so it can be looked up by that name; confirm against its definition.
register_representation!(RecordToAtom, "record_to_atom");
7

            
8
#[derive(Clone, Debug)]
9
pub struct RecordToAtom {
10
    src_var: Name,
11

            
12
    entry_names: Vec<Name>,
13

            
14
    // all the possible indices in this record, in order.
15
    indices: Vec<Literal>,
16

            
17
    // the element domains for each item in the tuple.
18
    elem_domain: Vec<Moo<GroundDomain>>,
19
}
20

            
21
impl RecordToAtom {
22
    /// Returns the names of the representation variable (there must be a much easier way to do this but oh well)
23
168
    fn names(&self) -> impl Iterator<Item = Name> + '_ {
24
168
        self.indices
25
168
            .iter()
26
336
            .map(move |index| self.indices_to_name(std::slice::from_ref(index)))
27
168
    }
28

            
29
    /// Gets the representation variable name for a specific set of indices.
30
336
    fn indices_to_name(&self, indices: &[Literal]) -> Name {
31
336
        Name::Represented(Box::new((
32
336
            self.src_var.clone(),
33
336
            self.repr_name().into(),
34
336
            indices.iter().join("_").into(),
35
336
        )))
36
336
    }
37
}
38

            
39
impl Representation for RecordToAtom {
40
36
    fn init(name: &Name, symtab: &SymbolTable) -> Option<Self> {
41
36
        let domain = symtab.resolve_domain(name)?;
42

            
43
36
        if !domain.is_finite() {
44
            return None;
45
36
        }
46

            
47
36
        let GroundDomain::Record(entries) = domain.as_ref() else {
48
            return None;
49
        };
50

            
51
        //indices may not be needed as a field as we can always use the length of the record
52
36
        let indices = (1..(entries.len() + 1) as i32).map(Literal::Int).collect();
53

            
54
        Some(RecordToAtom {
55
36
            src_var: name.clone(),
56
72
            entry_names: entries.iter().map(|entry| entry.name.clone()).collect(),
57
36
            indices,
58
72
            elem_domain: entries.iter().map(|entry| entry.domain.clone()).collect(),
59
        })
60
36
    }
61

            
62
    fn variable_name(&self) -> &Name {
63
        &self.src_var
64
    }
65

            
66
    fn value_down(
67
        &self,
68
        value: Literal,
69
    ) -> Result<std::collections::BTreeMap<Name, Literal>, ApplicationError> {
70
        let Literal::AbstractLiteral(record) = value else {
71
            return Err(ApplicationError::RuleNotApplicable);
72
        };
73

            
74
        let AbstractLiteral::Tuple(entries) = record else {
75
            return Err(ApplicationError::RuleNotApplicable);
76
        };
77

            
78
        let mut result = std::collections::BTreeMap::new();
79

            
80
        for (i, elem) in entries.into_iter().enumerate() {
81
            let name = format!("{}_{}", self.src_var, i + 1);
82
            result.insert(Name::user(&name), elem);
83
        }
84

            
85
        Ok(result)
86
    }
87

            
88
36
    fn value_up(
89
36
        &self,
90
36
        values: &std::collections::BTreeMap<Name, Literal>,
91
36
    ) -> Result<Literal, ApplicationError> {
92
36
        let mut record = Vec::new();
93

            
94
72
        for name in self.names() {
95
72
            let value = values
96
72
                .get(&name)
97
72
                .ok_or(ApplicationError::RuleNotApplicable)?;
98
72
            let Name::Represented(fields) = name.clone() else {
99
                return Err(ApplicationError::RuleNotApplicable);
100
            };
101

            
102
72
            let (_, _, idx) = *fields;
103

            
104
72
            let idx: usize = idx
105
72
                .parse()
106
72
                .map_err(|_| ApplicationError::RuleNotApplicable)?;
107
72
            if idx == 0 {
108
                return Err(ApplicationError::RuleNotApplicable);
109
72
            }
110
72
            let idx = idx - 1;
111
72
            record.push(RecordValue {
112
72
                name: self.entry_names[idx].clone(),
113
72
                value: value.clone(),
114
72
            });
115
        }
116

            
117
36
        Ok(Literal::AbstractLiteral(AbstractLiteral::Record(record)))
118
36
    }
119

            
120
96
    fn expression_down(
121
96
        &self,
122
96
        st: &SymbolTable,
123
96
    ) -> Result<std::collections::BTreeMap<Name, Expression>, ApplicationError> {
124
96
        Ok(self
125
96
            .names()
126
192
            .map(|name| {
127
192
                let decl = st.lookup(&name).unwrap();
128
192
                (
129
192
                    name,
130
192
                    Expression::Atomic(
131
192
                        Metadata::new(),
132
192
                        Atom::Reference(conjure_cp::ast::Reference::new(decl)),
133
192
                    ),
134
192
                )
135
192
            })
136
96
            .collect())
137
96
    }
138

            
139
36
    fn declaration_down(&self) -> Result<Vec<DeclarationPtr>, ApplicationError> {
140
36
        Ok(self
141
36
            .names()
142
36
            .zip(self.elem_domain.iter().cloned())
143
72
            .map(|(name, domain)| DeclarationPtr::new_find(name, DomainPtr::from(domain)))
144
36
            .collect())
145
36
    }
146

            
147
516
    fn repr_name(&self) -> &str {
148
516
        "record_to_atom"
149
516
    }
150

            
151
7080
    fn box_clone(&self) -> Box<dyn Representation> {
152
7080
        Box::new(self.clone()) as _
153
7080
    }
154
}