1
use conjure_cp::ast::{DeclarationPtr, DomainPtr, GroundDomain, Moo, records::RecordValue};
2
use itertools::Itertools;
3

            
4
use super::prelude::*;
5

            
6
// Register `RecordToAtom` via `register_representation!` under the name "record_to_atom",
// which is the same string returned by `RecordToAtom::repr_name`.
register_representation!(RecordToAtom, "record_to_atom");
7

            
8
/// Representation that splits a record variable into one variable per record entry.
#[derive(Clone, Debug)]
pub struct RecordToAtom {
    /// The name of the record variable being represented.
    src_var: Name,

    /// The names of the record's entries, in the order they appear in the record domain.
    entry_names: Vec<Name>,

    /// All the possible indices in this record, in order (1-based, one per entry).
    indices: Vec<Literal>,

    /// The domain of each entry in the record, in the same order as `entry_names`.
    elem_domain: Vec<Moo<GroundDomain>>,
}
20

            
21
impl RecordToAtom {
22
    /// Returns the names of the representation variable (there must be a much easier way to do this but oh well)
23
    fn names(&self) -> impl Iterator<Item = Name> + '_ {
24
        self.indices
25
            .iter()
26
            .map(move |index| self.indices_to_name(std::slice::from_ref(index)))
27
    }
28

            
29
    /// Gets the representation variable name for a specific set of indices.
30
    fn indices_to_name(&self, indices: &[Literal]) -> Name {
31
        Name::Represented(Box::new((
32
            self.src_var.clone(),
33
            self.repr_name().into(),
34
            indices.iter().join("_").into(),
35
        )))
36
    }
37
}
38

            
39
impl Representation for RecordToAtom {
40
    fn init(name: &Name, symtab: &SymbolTable) -> Option<Self> {
41
        let domain = symtab.resolve_domain(name)?;
42

            
43
        if !domain.is_finite() {
44
            return None;
45
        }
46

            
47
        let GroundDomain::Record(entries) = domain.as_ref() else {
48
            return None;
49
        };
50

            
51
        //indices may not be needed as a field as we can always use the length of the record
52
        let indices = (1..(entries.len() + 1) as i32).map(Literal::Int).collect();
53

            
54
        Some(RecordToAtom {
55
            src_var: name.clone(),
56
            entry_names: entries.iter().map(|entry| entry.name.clone()).collect(),
57
            indices,
58
            elem_domain: entries.iter().map(|entry| entry.domain.clone()).collect(),
59
        })
60
    }
61

            
62
    fn variable_name(&self) -> &Name {
63
        &self.src_var
64
    }
65

            
66
    fn value_down(
67
        &self,
68
        value: Literal,
69
    ) -> Result<std::collections::BTreeMap<Name, Literal>, ApplicationError> {
70
        let Literal::AbstractLiteral(record) = value else {
71
            return Err(ApplicationError::RuleNotApplicable);
72
        };
73

            
74
        let AbstractLiteral::Tuple(entries) = record else {
75
            return Err(ApplicationError::RuleNotApplicable);
76
        };
77

            
78
        let mut result = std::collections::BTreeMap::new();
79

            
80
        for (i, elem) in entries.into_iter().enumerate() {
81
            let name = format!("{}_{}", self.src_var, i + 1);
82
            result.insert(Name::user(&name), elem);
83
        }
84

            
85
        Ok(result)
86
    }
87

            
88
    fn value_up(
89
        &self,
90
        values: &std::collections::BTreeMap<Name, Literal>,
91
    ) -> Result<Literal, ApplicationError> {
92
        let mut record = Vec::new();
93

            
94
        for name in self.names() {
95
            let value = values
96
                .get(&name)
97
                .ok_or(ApplicationError::RuleNotApplicable)?;
98
            let Name::Represented(fields) = name.clone() else {
99
                return Err(ApplicationError::RuleNotApplicable);
100
            };
101

            
102
            let (_, _, idx) = *fields;
103

            
104
            let idx: usize = idx
105
                .parse()
106
                .map_err(|_| ApplicationError::RuleNotApplicable)?;
107
            if idx == 0 {
108
                return Err(ApplicationError::RuleNotApplicable);
109
            }
110
            let idx = idx - 1;
111
            record.push(RecordValue {
112
                name: self.entry_names[idx].clone(),
113
                value: value.clone(),
114
            });
115
        }
116

            
117
        Ok(Literal::AbstractLiteral(AbstractLiteral::Record(record)))
118
    }
119

            
120
    fn expression_down(
121
        &self,
122
        st: &SymbolTable,
123
    ) -> Result<std::collections::BTreeMap<Name, Expression>, ApplicationError> {
124
        Ok(self
125
            .names()
126
            .map(|name| {
127
                let decl = st.lookup(&name).unwrap();
128
                (
129
                    name,
130
                    Expression::Atomic(
131
                        Metadata::new(),
132
                        Atom::Reference(conjure_cp::ast::Reference::new(decl)),
133
                    ),
134
                )
135
            })
136
            .collect())
137
    }
138

            
139
    fn declaration_down(&self) -> Result<Vec<DeclarationPtr>, ApplicationError> {
140
        Ok(self
141
            .names()
142
            .zip(self.elem_domain.iter().cloned())
143
            .map(|(name, domain)| DeclarationPtr::new_var(name, DomainPtr::from(domain)))
144
            .collect())
145
    }
146

            
147
    fn repr_name(&self) -> &str {
148
        "record_to_atom"
149
    }
150

            
151
    fn box_clone(&self) -> Box<dyn Representation> {
152
        Box::new(self.clone()) as _
153
    }
154
}