conjure_oxide/utils/
testing.rs

1use std::collections::{BTreeMap, HashMap, HashSet};
2use std::fmt::Debug;
3use std::vec;
4
5use conjure_core::ast::records::RecordValue;
6use conjure_core::bug;
7use itertools::Itertools as _;
8use std::fs::File;
9use std::fs::{read_to_string, OpenOptions};
10use std::hash::Hash;
11use std::io::Write;
12use std::sync::{Arc, RwLock};
13use uniplate::Uniplate;
14
15use conjure_core::ast::{AbstractLiteral, Domain, SerdeModel};
16use conjure_core::context::Context;
17use serde_json::{json, Error as JsonError, Value as JsonValue};
18
19use conjure_core::error::Error;
20
21use crate::ast::Name::UserName;
22use crate::ast::{Literal, Name};
23use crate::utils::conjure::solutions_to_json;
24use crate::utils::json::sort_json_object;
25use crate::utils::misc::to_set;
26use crate::Model as ConjureModel;
27use crate::SolverFamily;
28
29pub fn assert_eq_any_order<T: Eq + Hash + Debug + Clone>(a: &Vec<Vec<T>>, b: &Vec<Vec<T>>) {
30    assert_eq!(a.len(), b.len());
31
32    let mut a_rows: Vec<HashSet<T>> = Vec::new();
33    for row in a {
34        let hash_row = to_set(row);
35        a_rows.push(hash_row);
36    }
37
38    let mut b_rows: Vec<HashSet<T>> = Vec::new();
39    for row in b {
40        let hash_row = to_set(row);
41        b_rows.push(hash_row);
42    }
43
44    println!("{:?},{:?}", a_rows, b_rows);
45    for row in a_rows {
46        assert!(b_rows.contains(&row));
47    }
48}
49
50pub fn serialise_model(model: &ConjureModel) -> Result<String, JsonError> {
51    // A consistent sorting of the keys of json objects
52    // only required for the generated version
53    // since the expected version will already be sorted
54    let serde_model: SerdeModel = model.clone().into();
55    let generated_json = sort_json_object(&serde_json::to_value(serde_model)?, false);
56
57    // serialise to string
58    let generated_json_str = serde_json::to_string_pretty(&generated_json)?;
59
60    Ok(generated_json_str)
61}
62
63pub fn save_model_json(
64    model: &ConjureModel,
65    path: &str,
66    test_name: &str,
67    test_stage: &str,
68) -> Result<(), std::io::Error> {
69    let generated_json_str = serialise_model(model)?;
70    let filename = format!("{path}/{test_name}.generated-{test_stage}.serialised.json");
71    File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
72    Ok(())
73}
74
75pub fn save_stats_json(
76    context: Arc<RwLock<Context<'static>>>,
77    path: &str,
78    test_name: &str,
79) -> Result<(), std::io::Error> {
80    #[allow(clippy::unwrap_used)]
81    let stats = context.read().unwrap().clone();
82    let generated_json = sort_json_object(&serde_json::to_value(stats)?, false);
83
84    // serialise to string
85    let generated_json_str = serde_json::to_string_pretty(&generated_json)?;
86
87    File::create(format!("{path}/{test_name}-stats.json"))?
88        .write_all(generated_json_str.as_bytes())?;
89
90    Ok(())
91}
92
93pub fn read_model_json(
94    ctx: &Arc<RwLock<Context<'static>>>,
95    path: &str,
96    test_name: &str,
97    prefix: &str,
98    test_stage: &str,
99) -> Result<ConjureModel, std::io::Error> {
100    let expected_json_str = std::fs::read_to_string(format!(
101        "{path}/{test_name}.{prefix}-{test_stage}.serialised.json"
102    ))?;
103    println!("{path}/{test_name}.{prefix}-{test_stage}.serialised.json");
104    let expected_model: SerdeModel = serde_json::from_str(&expected_json_str)?;
105
106    Ok(expected_model.initialise(ctx.clone()).unwrap())
107}
108
109pub fn minion_solutions_from_json(
110    serialized: &str,
111) -> Result<Vec<HashMap<Name, Literal>>, anyhow::Error> {
112    let json: JsonValue = serde_json::from_str(serialized)?;
113
114    let json_array = json
115        .as_array()
116        .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
117
118    let mut solutions = Vec::new();
119
120    for solution in json_array {
121        let mut sol = HashMap::new();
122        let solution = solution
123            .as_object()
124            .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
125
126        for (var_name, constant) in solution {
127            let constant = match constant {
128                JsonValue::Number(n) => {
129                    let n = n
130                        .as_i64()
131                        .ok_or(Error::Parse("Invalid integer".to_owned()))?;
132                    Literal::Int(n as i32)
133                }
134                JsonValue::Bool(b) => Literal::Bool(*b),
135                _ => return Err(Error::Parse("Invalid constant".to_owned()).into()),
136            };
137
138            sol.insert(UserName(var_name.into()), constant);
139        }
140
141        solutions.push(sol);
142    }
143
144    Ok(solutions)
145}
146
147/// Writes the minion solutions to a generated JSON file, and returns the JSON structure.
148pub fn save_solutions_json(
149    solutions: &Vec<BTreeMap<Name, Literal>>,
150    path: &str,
151    test_name: &str,
152    solver: SolverFamily,
153) -> Result<JsonValue, std::io::Error> {
154    let json_solutions = solutions_to_json(solutions);
155    let generated_json_str = serde_json::to_string_pretty(&json_solutions)?;
156
157    let solver_name = match solver {
158        SolverFamily::SAT => "sat",
159        SolverFamily::Minion => "minion",
160    };
161
162    let filename = format!("{path}/{test_name}.generated-{solver_name}.solutions.json");
163    File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
164
165    Ok(json_solutions)
166}
167
168pub fn read_solutions_json(
169    path: &str,
170    test_name: &str,
171    prefix: &str,
172    solver: SolverFamily,
173) -> Result<JsonValue, anyhow::Error> {
174    let solver_name = match solver {
175        SolverFamily::SAT => "sat",
176        SolverFamily::Minion => "minion",
177    };
178
179    let expected_json_str = std::fs::read_to_string(format!(
180        "{path}/{test_name}.{prefix}-{solver_name}.solutions.json"
181    ))?;
182
183    let expected_solutions: JsonValue =
184        sort_json_object(&serde_json::from_str(&expected_json_str)?, true);
185
186    Ok(expected_solutions)
187}
188
189/// Reads a rule trace from a file. For the generated prefix, it appends a count message.
190/// Returns the lines of the file as a vector of strings.
191pub fn read_rule_trace(
192    path: &str,
193    test_name: &str,
194    prefix: &str,
195) -> Result<Vec<String>, std::io::Error> {
196    let filename = format!("{path}/{test_name}-{prefix}-rule-trace.json");
197    let mut rules_trace: Vec<String> = read_to_string(&filename)?
198        .lines()
199        .map(String::from)
200        .collect();
201
202    // If prefix is "generated", append the count message
203    if prefix == "generated" {
204        let rule_count = rules_trace.len();
205        let count_message = json!({
206            "message": "Number of rules applied",
207            "count": rule_count
208        });
209        let count_message_string = serde_json::to_string(&count_message)?;
210        rules_trace.push(count_message_string);
211
212        // Overwrite the file with updated content (including the count message)
213        let mut file = OpenOptions::new()
214            .write(true)
215            .truncate(true)
216            .open(&filename)?;
217        writeln!(file, "{}", rules_trace.join("\n"))?;
218    }
219
220    Ok(rules_trace)
221}
222
/// Reads the human-readable rule trace `{path}/{test_name}-{prefix}-rule-trace-human.txt`
/// and returns its lines, without line terminators.
///
/// # Errors
///
/// Returns an error if the file cannot be read or is not valid UTF-8.
pub fn read_human_rule_trace(
    path: &str,
    test_name: &str,
    prefix: &str,
) -> Result<Vec<String>, std::io::Error> {
    let source = format!("{path}/{test_name}-{prefix}-rule-trace-human.txt");
    let text = read_to_string(source)?;
    Ok(text.lines().map(str::to_owned).collect())
}
237
238#[doc(hidden)]
239pub fn normalize_solutions_for_comparison(
240    input_solutions: &[BTreeMap<Name, Literal>],
241) -> Vec<BTreeMap<Name, Literal>> {
242    let mut normalized = input_solutions.to_vec();
243
244    for solset in &mut normalized {
245        // remove machine names
246        let keys_to_remove: Vec<Name> = solset
247            .keys()
248            .filter(|k| matches!(k, Name::MachineName(_)))
249            .cloned()
250            .collect();
251        for k in keys_to_remove {
252            solset.remove(&k);
253        }
254
255        let mut updates = vec![];
256        for (k, v) in solset.clone() {
257            if let Name::UserName(_) = k {
258                match v {
259                    Literal::Bool(true) => updates.push((k, Literal::Int(1))),
260                    Literal::Bool(false) => updates.push((k, Literal::Int(0))),
261                    Literal::Int(_) => {}
262                    Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, _)) => {
263                        // make all domains the same (this is just in the tester so the types dont
264                        // actually matter)
265
266                        let mut matrix = AbstractLiteral::Matrix(elems, Domain::IntDomain(vec![]));
267                        matrix =
268                            matrix.transform(Arc::new(
269                                move |x: AbstractLiteral<Literal>| match x {
270                                    AbstractLiteral::Matrix(items, _) => {
271                                        let items = items
272                                            .into_iter()
273                                            .map(|x| match x {
274                                                Literal::Bool(false) => Literal::Int(0),
275                                                Literal::Bool(true) => Literal::Int(1),
276                                                x => x,
277                                            })
278                                            .collect_vec();
279
280                                        AbstractLiteral::Matrix(items, Domain::IntDomain(vec![]))
281                                    }
282                                    x => x,
283                                },
284                            ));
285                        updates.push((k, Literal::AbstractLiteral(matrix)));
286                    }
287                    Literal::AbstractLiteral(AbstractLiteral::Tuple(elems)) => {
288                        // just the same as matrix but with tuples instead
289                        // only conversion needed is to convert bools to ints
290                        let mut tuple = AbstractLiteral::Tuple(elems);
291                        tuple =
292                            tuple.transform(Arc::new(move |x: AbstractLiteral<Literal>| match x {
293                                AbstractLiteral::Tuple(items) => {
294                                    let items = items
295                                        .into_iter()
296                                        .map(|x| match x {
297                                            Literal::Bool(false) => Literal::Int(0),
298                                            Literal::Bool(true) => Literal::Int(1),
299                                            x => x,
300                                        })
301                                        .collect_vec();
302
303                                    AbstractLiteral::Tuple(items)
304                                }
305                                x => x,
306                            }));
307                        updates.push((k, Literal::AbstractLiteral(tuple)));
308                    }
309                    Literal::AbstractLiteral(AbstractLiteral::Record(entries)) => {
310                        // just the same as matrix but with tuples instead
311                        // only conversion needed is to convert bools to ints
312                        let mut record = AbstractLiteral::Record(entries);
313                        record =
314                            record.transform(Arc::new(
315                                move |x: AbstractLiteral<Literal>| match x {
316                                    AbstractLiteral::Record(entries) => {
317                                        let entries = entries
318                                            .into_iter()
319                                            .map(|x| {
320                                                let RecordValue { name, value } = x;
321                                                {
322                                                    let value = match value {
323                                                        Literal::Bool(false) => Literal::Int(0),
324                                                        Literal::Bool(true) => Literal::Int(1),
325                                                        x => x,
326                                                    };
327                                                    RecordValue { name, value }
328                                                }
329                                            })
330                                            .collect_vec();
331
332                                        AbstractLiteral::Record(entries)
333                                    }
334                                    x => x,
335                                },
336                            ));
337                        updates.push((k, Literal::AbstractLiteral(record)));
338                    }
339                    e => bug!("unexpected literal type: {e:?}"),
340                }
341            }
342        }
343
344        for (k, v) in updates {
345            solset.insert(k, v);
346        }
347    }
348
349    // Remove duplicates
350    normalized = normalized.into_iter().unique().collect();
351    normalized
352}