Skip to main content

conjure_cp_cli/utils/
testing.rs

1use std::collections::{BTreeMap, HashMap, HashSet};
2use std::fmt::Debug;
3use std::path::Path;
4use std::{io, mem, vec};
5
6use conjure_cp::ast::records::RecordValue;
7use conjure_cp::ast::serde::ObjId;
8use conjure_cp::bug;
9use itertools::Itertools as _;
10use std::fs::File;
11use std::hash::Hash;
12use std::io::{BufRead, BufReader, Write};
13use std::sync::{Arc, RwLock};
14use uniplate::Uniplate;
15
16use conjure_cp::ast::{AbstractLiteral, Expression, GroundDomain, Moo, SerdeModel};
17use conjure_cp::context::Context;
18use serde_json::{Error as JsonError, Value as JsonValue};
19
20use conjure_cp::error::Error;
21
22use crate::utils::conjure::solutions_to_json;
23use crate::utils::json::sort_json_object;
24use crate::utils::misc::to_set;
25use conjure_cp::Model as ConjureModel;
26use conjure_cp::ast::Name::User;
27use conjure_cp::ast::{Literal, Name};
28use conjure_cp::settings::SolverFamily;
29
30/// Limit how many lines of the rewrite serialisation we persist/compare in integration tests.
31pub const REWRITE_SERIALISED_JSON_MAX_LINES: usize = 1000;
32
33/// Converts a SerdeModel to JSON with stable IDs.
34///
35/// This ensures that the same model structure always produces the same IDs,
36/// regardless of the order in which objects were created in memory.
37fn model_to_json_with_stable_ids(model: &SerdeModel) -> Result<JsonValue, JsonError> {
38    // Collect stable ID mapping using uniplate traversal on the SerdeModel
39    let id_map = model.collect_stable_id_mapping();
40
41    // Serialize the model to JSON
42    let mut json = serde_json::to_value(model)?;
43
44    // Replace all IDs in the JSON with their stable counterparts
45    replace_ids(&mut json, &id_map);
46
47    Ok(json)
48}
49
50/// Recursively replaces all IDs in the JSON with their stable counterparts.
51///
52/// This is applied to all fields that are called "id" or "ptr" - be mindful
53/// of potential naming clashes in the future!
54fn replace_ids(value: &mut JsonValue, id_map: &HashMap<ObjId, ObjId>) {
55    match value {
56        JsonValue::Object(map) => {
57            // Replace IDs in three places:
58            // - "id" fields (SymbolTable IDs)
59            // - "parent" fields (SymbolTable nesting)
60            // - "ptr" fields (DeclarationPtr IDs)
61            for (k, v) in map.iter_mut() {
62                if (k == "id" || k == "ptr" || k == "parent")
63                    && let Ok(old_id) = serde_json::from_value::<ObjId>(mem::take(v))
64                {
65                    let new_id = id_map.get(&old_id).expect("all ids to be in the id map");
66                    *v = serde_json::to_value(new_id)
67                        .expect("serialization of an ObjId to always succeed");
68                }
69            }
70
71            // Recursively process all values
72            for val in map.values_mut() {
73                replace_ids(val, id_map);
74            }
75        }
76        JsonValue::Array(arr) => {
77            for item in arr {
78                replace_ids(item, id_map);
79            }
80        }
81        _ => {}
82    }
83}
84
85pub fn assert_eq_any_order<T: Eq + Hash + Debug + Clone>(a: &Vec<Vec<T>>, b: &Vec<Vec<T>>) {
86    assert_eq!(a.len(), b.len());
87
88    let mut a_rows: Vec<HashSet<T>> = Vec::new();
89    for row in a {
90        let hash_row = to_set(row);
91        a_rows.push(hash_row);
92    }
93
94    let mut b_rows: Vec<HashSet<T>> = Vec::new();
95    for row in b {
96        let hash_row = to_set(row);
97        b_rows.push(hash_row);
98    }
99
100    for row in a_rows {
101        assert!(b_rows.contains(&row));
102    }
103}
104
105pub fn serialize_model(model: &ConjureModel) -> Result<String, JsonError> {
106    let serde_model: SerdeModel = model.clone().into();
107
108    // Convert to JSON with stable IDs
109    let json_with_stable_ids = model_to_json_with_stable_ids(&serde_model)?;
110
111    // Sort JSON object keys for consistent output
112    let sorted_json = sort_json_object(&json_with_stable_ids, false);
113
114    // Serialize to pretty-printed string
115    serde_json::to_string_pretty(&sorted_json)
116}
117
118pub fn serialize_domains(model: &ConjureModel) -> Result<String, JsonError> {
119    let mut output = String::new();
120    for constraint in model.constraints() {
121        serialize_domains_expr(constraint, 0, &mut output);
122    }
123    Ok(output)
124}
125
126fn serialize_domains_expr(expr: &Expression, depth: usize, output: &mut String) {
127    let domain = expr
128        .domain_of()
129        .map(|domain| domain.to_string())
130        .unwrap_or_else(|| "<unknown>".to_owned());
131    output.push_str(&" ".repeat(depth));
132    output.push_str(&format!("{expr} :: {domain}\n"));
133
134    for child in expr.children() {
135        serialize_domains_expr(&child, depth + 1, output);
136    }
137}
138
139pub fn save_model_json(
140    model: &ConjureModel,
141    path: &str,
142    test_name: &str,
143    test_stage: &str,
144    solver: SolverFamily,
145) -> Result<(), std::io::Error> {
146    let marker = solver.as_str();
147    let generated_json_str = serialize_model(model)?;
148    let generated_json_str = maybe_truncate_serialised_json(generated_json_str, test_stage);
149    let filename = format!("{path}/{test_name}-{marker}.generated-{test_stage}.serialised.json");
150    println!("saving: {}", filename);
151    File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
152    Ok(())
153}
154
155pub fn save_stats_json(
156    context: Arc<RwLock<Context<'static>>>,
157    path: &str,
158    test_name: &str,
159    solver: SolverFamily,
160) -> Result<(), std::io::Error> {
161    #[allow(clippy::unwrap_used)]
162    let solver_name = solver.as_str();
163
164    let stats = context.read().unwrap().clone();
165    let generated_json = sort_json_object(&serde_json::to_value(stats)?, false);
166
167    // serialise to string
168    let generated_json_str = serde_json::to_string_pretty(&generated_json)?;
169
170    File::create(format!("{path}/{test_name}-{solver_name}-stats.json"))?
171        .write_all(generated_json_str.as_bytes())?;
172
173    Ok(())
174}
175
/// Loads the whole file at `path` as a `String`.
///
/// On failure the returned error keeps the original error kind, and its message is the
/// original message followed by ` (path: <path>)` so the offending file is easy to identify.
fn read_with_path(path: String) -> Result<String, std::io::Error> {
    match std::fs::read_to_string(&path) {
        Ok(contents) => Ok(contents),
        Err(source) => Err(io::Error::new(
            source.kind(),
            format!("{} (path: {})", source, path),
        )),
    }
}
181
182pub fn read_model_json(
183    ctx: &Arc<RwLock<Context<'static>>>,
184    path: &str,
185    test_name: &str,
186    prefix: &str,
187    test_stage: &str,
188    solver: SolverFamily,
189) -> Result<ConjureModel, std::io::Error> {
190    let marker = solver.as_str();
191    let filepath = format!("{path}/{test_name}-{marker}.{prefix}-{test_stage}.serialised.json");
192    let expected_json_str = std::fs::read_to_string(filepath)?;
193    let expected_model: SerdeModel = serde_json::from_str(&expected_json_str)?;
194
195    Ok(expected_model.initialise(ctx.clone()).unwrap())
196}
197
198/// Reads only the first `max_lines` from a serialised model JSON file.
199pub fn read_model_json_prefix(
200    path: &str,
201    test_name: &str,
202    prefix: &str,
203    test_stage: &str,
204    solver: SolverFamily,
205    max_lines: usize,
206) -> Result<String, std::io::Error> {
207    let marker = solver.as_str();
208    let filename = format!("{path}/{test_name}-{marker}.{prefix}-{test_stage}.serialised.json");
209    println!("reading: {}", filename);
210    read_first_n_lines(filename, max_lines)
211}
212
213pub fn minion_solutions_from_json(
214    serialized: &str,
215) -> Result<Vec<HashMap<Name, Literal>>, anyhow::Error> {
216    let json: JsonValue = serde_json::from_str(serialized)?;
217
218    let json_array = json
219        .as_array()
220        .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
221
222    let mut solutions = Vec::new();
223
224    for solution in json_array {
225        let mut sol = HashMap::new();
226        let solution = solution
227            .as_object()
228            .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
229
230        for (var_name, constant) in solution {
231            let constant = match constant {
232                JsonValue::Number(n) => {
233                    let n = n
234                        .as_i64()
235                        .ok_or(Error::Parse("Invalid integer".to_owned()))?;
236                    Literal::Int(n as i32)
237                }
238                JsonValue::Bool(b) => Literal::Bool(*b),
239                _ => return Err(Error::Parse("Invalid constant".to_owned()).into()),
240            };
241
242            sol.insert(User(var_name.into()), constant);
243        }
244
245        solutions.push(sol);
246    }
247
248    Ok(solutions)
249}
250
251/// Writes the minion solutions to a generated JSON file, and returns the JSON structure.
252pub fn save_solutions_json(
253    solutions: &Vec<BTreeMap<Name, Literal>>,
254    path: &str,
255    test_name: &str,
256    solver: SolverFamily,
257) -> Result<JsonValue, std::io::Error> {
258    let json_solutions = solutions_to_json(solutions);
259    let generated_json_str = serde_json::to_string_pretty(&json_solutions)?;
260
261    let solver_name = solver.as_str();
262    let filename = format!("{path}/{test_name}-{solver_name}.generated-solutions.json");
263    File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
264
265    Ok(json_solutions)
266}
267
268pub fn read_solutions_json(
269    path: &str,
270    test_name: &str,
271    prefix: &str,
272    solver: SolverFamily,
273) -> Result<JsonValue, anyhow::Error> {
274    let solver_name = solver.as_str();
275    let filename = format!("{path}/{test_name}-{solver_name}.{prefix}-solutions.json");
276    let expected_json_str = read_with_path(filename)?;
277
278    let expected_solutions: JsonValue =
279        sort_json_object(&serde_json::from_str(&expected_json_str)?, true);
280
281    Ok(expected_solutions)
282}
283
284/// Reads a human-readable rule trace text file.
285pub fn read_human_rule_trace(
286    path: &str,
287    test_name: &str,
288    prefix: &str,
289    solver: &SolverFamily,
290) -> Result<Vec<String>, std::io::Error> {
291    let solver_name = solver.as_str();
292    let filename = format!("{path}/{test_name}-{solver_name}-{prefix}-rule-trace.txt");
293    let rules_trace: Vec<String> = read_with_path(filename)?
294        .lines()
295        .map(String::from)
296        .collect();
297
298    Ok(rules_trace)
299}
300
301#[doc(hidden)]
302pub fn normalize_solutions_for_comparison(
303    input_solutions: &[BTreeMap<Name, Literal>],
304) -> Vec<BTreeMap<Name, Literal>> {
305    let mut normalized = input_solutions.to_vec();
306
307    for solset in &mut normalized {
308        // remove machine names
309        let keys_to_remove: Vec<Name> = solset
310            .keys()
311            .filter(|k| matches!(k, Name::Machine(_)))
312            .cloned()
313            .collect();
314        for k in keys_to_remove {
315            solset.remove(&k);
316        }
317
318        let mut updates = vec![];
319        for (k, v) in solset.clone() {
320            if let Name::User(_) = k {
321                match v {
322                    Literal::Bool(true) => updates.push((k, Literal::Int(1))),
323                    Literal::Bool(false) => updates.push((k, Literal::Int(0))),
324                    Literal::Int(_) => {}
325                    Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, _)) => {
326                        // make all domains the same (this is just in the tester so the types dont
327                        // actually matter)
328
329                        let mut matrix =
330                            AbstractLiteral::Matrix(elems, Moo::new(GroundDomain::Int(vec![])));
331                        matrix = matrix.transform(&move |x: AbstractLiteral<Literal>| match x {
332                            AbstractLiteral::Matrix(items, _) => {
333                                let items = items
334                                    .into_iter()
335                                    .map(|x| match x {
336                                        Literal::Bool(false) => Literal::Int(0),
337                                        Literal::Bool(true) => Literal::Int(1),
338                                        x => x,
339                                    })
340                                    .collect_vec();
341
342                                AbstractLiteral::Matrix(items, Moo::new(GroundDomain::Int(vec![])))
343                            }
344                            x => x,
345                        });
346                        updates.push((k, Literal::AbstractLiteral(matrix)));
347                    }
348                    Literal::AbstractLiteral(AbstractLiteral::Tuple(elems)) => {
349                        // just the same as matrix but with tuples instead
350                        // only conversion needed is to convert bools to ints
351                        let mut tuple = AbstractLiteral::Tuple(elems);
352                        tuple = tuple.transform(
353                            &(move |x: AbstractLiteral<Literal>| match x {
354                                AbstractLiteral::Tuple(items) => {
355                                    let items = items
356                                        .into_iter()
357                                        .map(|x| match x {
358                                            Literal::Bool(false) => Literal::Int(0),
359                                            Literal::Bool(true) => Literal::Int(1),
360                                            x => x,
361                                        })
362                                        .collect_vec();
363
364                                    AbstractLiteral::Tuple(items)
365                                }
366                                x => x,
367                            }),
368                        );
369                        updates.push((k, Literal::AbstractLiteral(tuple)));
370                    }
371                    Literal::AbstractLiteral(AbstractLiteral::Record(entries)) => {
372                        // just the same as matrix but with tuples instead
373                        // only conversion needed is to convert bools to ints
374                        let mut record = AbstractLiteral::Record(entries);
375                        record = record.transform(&move |x: AbstractLiteral<Literal>| match x {
376                            AbstractLiteral::Record(entries) => {
377                                let entries = entries
378                                    .into_iter()
379                                    .map(|x| {
380                                        let RecordValue { name, value } = x;
381                                        {
382                                            let value = match value {
383                                                Literal::Bool(false) => Literal::Int(0),
384                                                Literal::Bool(true) => Literal::Int(1),
385                                                x => x,
386                                            };
387                                            RecordValue { name, value }
388                                        }
389                                    })
390                                    .collect_vec();
391
392                                AbstractLiteral::Record(entries)
393                            }
394                            x => x,
395                        });
396                        updates.push((k, Literal::AbstractLiteral(record)));
397                    }
398                    Literal::AbstractLiteral(AbstractLiteral::Set(members)) => {
399                        let set = AbstractLiteral::Set(members).transform(&move |x| match x {
400                            AbstractLiteral::Set(members) => {
401                                let members = members
402                                    .into_iter()
403                                    .map(|x| match x {
404                                        Literal::Bool(false) => Literal::Int(0),
405                                        Literal::Bool(true) => Literal::Int(1),
406                                        x => x,
407                                    })
408                                    .collect_vec();
409
410                                AbstractLiteral::Set(members)
411                            }
412                            x => x,
413                        });
414                        updates.push((k, Literal::AbstractLiteral(set)));
415                    }
416                    e => bug!("unexpected literal type: {e:?}"),
417                }
418            }
419        }
420
421        for (k, v) in updates {
422            solset.insert(k, v);
423        }
424    }
425
426    // Remove duplicates
427    normalized = normalized.into_iter().unique().collect();
428    normalized
429}
430
431fn maybe_truncate_serialised_json(serialised: String, test_stage: &str) -> String {
432    if test_stage == "rewrite" {
433        truncate_to_first_lines(&serialised, REWRITE_SERIALISED_JSON_MAX_LINES)
434    } else {
435        serialised
436    }
437}
438
/// Keeps at most the first `max_lines` lines of `content`, joined with `\n`.
///
/// Line terminators are normalised by `str::lines`, and the result carries no trailing
/// newline.
fn truncate_to_first_lines(content: &str, max_lines: usize) -> String {
    let kept: Vec<&str> = content.lines().take(max_lines).collect();
    kept.join("\n")
}
442
/// Reads at most the first `n` lines of the file at `filename` and returns them joined with
/// `\n` (no trailing newline).
///
/// An empty file, or `n == 0`, yields an empty string. Reading stops as soon as `n` lines
/// have been collected.
///
/// # Errors
///
/// Returns an error if the file cannot be opened or if reading one of the lines fails.
fn read_first_n_lines<P: AsRef<Path>>(filename: P, n: usize) -> io::Result<String> {
    let reader = BufReader::new(File::open(filename)?);

    // `take` handles both an empty file and `n == 0`. The previous
    // `chunks(n).into_iter().next().unwrap()` panicked in both cases.
    let lines = reader.lines().take(n).collect::<Result<Vec<_>, _>>()?;

    Ok(lines.join("\n"))
}