Skip to main content

conjure_cp_cli/utils/
testing.rs

1use std::collections::{BTreeMap, HashMap, HashSet};
2use std::fmt::Debug;
3use std::path::Path;
4use std::{io, mem, vec};
5
6use conjure_cp::ast::records::RecordValue;
7use conjure_cp::ast::serde::ObjId;
8use conjure_cp::bug;
9use itertools::Itertools as _;
10use std::fs::File;
11use std::hash::Hash;
12use std::io::{BufRead, BufReader, Write};
13use std::sync::{Arc, RwLock};
14use uniplate::Uniplate;
15
16use conjure_cp::ast::{AbstractLiteral, GroundDomain, Moo, SerdeModel};
17use conjure_cp::context::Context;
18use serde_json::{Error as JsonError, Value as JsonValue};
19
20use conjure_cp::error::Error;
21
22use crate::utils::conjure::solutions_to_json;
23use crate::utils::json::sort_json_object;
24use crate::utils::misc::to_set;
25use conjure_cp::Model as ConjureModel;
26use conjure_cp::ast::Name::User;
27use conjure_cp::ast::{Literal, Name};
28use conjure_cp::settings::SolverFamily;
29
/// Maximum number of lines of a `rewrite`-stage serialised model that integration tests
/// persist and compare; lines beyond this limit are dropped.
pub const REWRITE_SERIALISED_JSON_MAX_LINES: usize = 1_000;
32
33/// Converts a SerdeModel to JSON with stable IDs.
34///
35/// This ensures that the same model structure always produces the same IDs,
36/// regardless of the order in which objects were created in memory.
37fn model_to_json_with_stable_ids(model: &SerdeModel) -> Result<JsonValue, JsonError> {
38    // Collect stable ID mapping using uniplate traversal on the SerdeModel
39    let id_map = model.collect_stable_id_mapping();
40
41    // Serialize the model to JSON
42    let mut json = serde_json::to_value(model)?;
43
44    // Replace all IDs in the JSON with their stable counterparts
45    replace_ids(&mut json, &id_map);
46
47    Ok(json)
48}
49
50/// Recursively replaces all IDs in the JSON with their stable counterparts.
51///
52/// This is applied to all fields that are called "id" or "ptr" - be mindful
53/// of potential naming clashes in the future!
54fn replace_ids(value: &mut JsonValue, id_map: &HashMap<ObjId, ObjId>) {
55    match value {
56        JsonValue::Object(map) => {
57            // Replace IDs in three places:
58            // - "id" fields (SymbolTable IDs)
59            // - "parent" fields (SymbolTable nesting)
60            // - "ptr" fields (DeclarationPtr IDs)
61            for (k, v) in map.iter_mut() {
62                if (k == "id" || k == "ptr" || k == "parent")
63                    && let Ok(old_id) = serde_json::from_value::<ObjId>(mem::take(v))
64                {
65                    let new_id = id_map.get(&old_id).expect("all ids to be in the id map");
66                    *v = serde_json::to_value(new_id)
67                        .expect("serialization of an ObjId to always succeed");
68                }
69            }
70
71            // Recursively process all values
72            for val in map.values_mut() {
73                replace_ids(val, id_map);
74            }
75        }
76        JsonValue::Array(arr) => {
77            for item in arr {
78                replace_ids(item, id_map);
79            }
80        }
81        _ => {}
82    }
83}
84
85pub fn assert_eq_any_order<T: Eq + Hash + Debug + Clone>(a: &Vec<Vec<T>>, b: &Vec<Vec<T>>) {
86    assert_eq!(a.len(), b.len());
87
88    let mut a_rows: Vec<HashSet<T>> = Vec::new();
89    for row in a {
90        let hash_row = to_set(row);
91        a_rows.push(hash_row);
92    }
93
94    let mut b_rows: Vec<HashSet<T>> = Vec::new();
95    for row in b {
96        let hash_row = to_set(row);
97        b_rows.push(hash_row);
98    }
99
100    for row in a_rows {
101        assert!(b_rows.contains(&row));
102    }
103}
104
105pub fn serialize_model(model: &ConjureModel) -> Result<String, JsonError> {
106    let serde_model: SerdeModel = model.clone().into();
107
108    // Convert to JSON with stable IDs
109    let json_with_stable_ids = model_to_json_with_stable_ids(&serde_model)?;
110
111    // Sort JSON object keys for consistent output
112    let sorted_json = sort_json_object(&json_with_stable_ids, false);
113
114    // Serialize to pretty-printed string
115    serde_json::to_string_pretty(&sorted_json)
116}
117
118pub fn save_model_json(
119    model: &ConjureModel,
120    path: &str,
121    test_name: &str,
122    test_stage: &str,
123    solver: Option<SolverFamily>,
124) -> Result<(), std::io::Error> {
125    let marker = solver.map_or("agnostic", |s| s.as_str());
126    let generated_json_str = serialize_model(model)?;
127    let generated_json_str = maybe_truncate_serialised_json(generated_json_str, test_stage);
128    let filename = format!("{path}/{test_name}-{marker}.generated-{test_stage}.serialised.json");
129    println!("saving: {}", filename);
130    File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
131    Ok(())
132}
133
134pub fn save_stats_json(
135    context: Arc<RwLock<Context<'static>>>,
136    path: &str,
137    test_name: &str,
138    solver: SolverFamily,
139) -> Result<(), std::io::Error> {
140    #[allow(clippy::unwrap_used)]
141    let solver_name = solver.as_str();
142
143    let stats = context.read().unwrap().clone();
144    let generated_json = sort_json_object(&serde_json::to_value(stats)?, false);
145
146    // serialise to string
147    let generated_json_str = serde_json::to_string_pretty(&generated_json)?;
148
149    File::create(format!("{path}/{test_name}-{solver_name}-stats.json"))?
150        .write_all(generated_json_str.as_bytes())?;
151
152    Ok(())
153}
154
/// Reads the whole file at `path` into a `String`.
///
/// On failure the returned error keeps the original [`io::ErrorKind`], but its message is
/// suffixed with ` (path: …)` so that test failures point at the offending file.
fn read_with_path(path: String) -> Result<String, std::io::Error> {
    match std::fs::read_to_string(&path) {
        Ok(contents) => Ok(contents),
        Err(err) => {
            let message = format!("{} (path: {})", err, path);
            Err(io::Error::new(err.kind(), message))
        }
    }
}
160
161pub fn read_model_json(
162    ctx: &Arc<RwLock<Context<'static>>>,
163    path: &str,
164    test_name: &str,
165    prefix: &str,
166    test_stage: &str,
167    solver: Option<SolverFamily>,
168) -> Result<ConjureModel, std::io::Error> {
169    let marker = solver.map_or("agnostic", |s| s.as_str());
170    let new_filepath = format!("{path}/{test_name}-{marker}.{prefix}-{test_stage}.serialised.json");
171    let old_filepath = format!("{path}/{marker}-{test_name}.{prefix}-{test_stage}.serialised.json");
172    let filepath = if Path::new(&new_filepath).exists() {
173        new_filepath
174    } else {
175        old_filepath
176    };
177    let expected_json_str = std::fs::read_to_string(filepath)?;
178    let expected_model: SerdeModel = serde_json::from_str(&expected_json_str)?;
179
180    Ok(expected_model.initialise(ctx.clone()).unwrap())
181}
182
183/// Reads only the first `max_lines` from a serialised model JSON file.
184pub fn read_model_json_prefix(
185    path: &str,
186    test_name: &str,
187    prefix: &str,
188    test_stage: &str,
189    solver: Option<SolverFamily>,
190    max_lines: usize,
191) -> Result<String, std::io::Error> {
192    let marker = solver.map_or("agnostic", |s| s.as_str());
193    let new_filename = format!("{path}/{test_name}-{marker}.{prefix}-{test_stage}.serialised.json");
194    let old_filename = format!("{path}/{marker}-{test_name}.{prefix}-{test_stage}.serialised.json");
195    let filename = if Path::new(&new_filename).exists() {
196        new_filename
197    } else {
198        old_filename
199    };
200    println!("reading: {}", filename);
201    read_first_n_lines(filename, max_lines)
202}
203
204pub fn minion_solutions_from_json(
205    serialized: &str,
206) -> Result<Vec<HashMap<Name, Literal>>, anyhow::Error> {
207    let json: JsonValue = serde_json::from_str(serialized)?;
208
209    let json_array = json
210        .as_array()
211        .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
212
213    let mut solutions = Vec::new();
214
215    for solution in json_array {
216        let mut sol = HashMap::new();
217        let solution = solution
218            .as_object()
219            .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
220
221        for (var_name, constant) in solution {
222            let constant = match constant {
223                JsonValue::Number(n) => {
224                    let n = n
225                        .as_i64()
226                        .ok_or(Error::Parse("Invalid integer".to_owned()))?;
227                    Literal::Int(n as i32)
228                }
229                JsonValue::Bool(b) => Literal::Bool(*b),
230                _ => return Err(Error::Parse("Invalid constant".to_owned()).into()),
231            };
232
233            sol.insert(User(var_name.into()), constant);
234        }
235
236        solutions.push(sol);
237    }
238
239    Ok(solutions)
240}
241
242/// Writes the minion solutions to a generated JSON file, and returns the JSON structure.
243pub fn save_solutions_json(
244    solutions: &Vec<BTreeMap<Name, Literal>>,
245    path: &str,
246    test_name: &str,
247    solver: SolverFamily,
248) -> Result<JsonValue, std::io::Error> {
249    let json_solutions = solutions_to_json(solutions);
250    let generated_json_str = serde_json::to_string_pretty(&json_solutions)?;
251
252    let solver_name = solver.as_str();
253    let filename = format!("{path}/{test_name}-{solver_name}.generated-solutions.json");
254    File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
255
256    Ok(json_solutions)
257}
258
259pub fn read_solutions_json(
260    path: &str,
261    test_name: &str,
262    prefix: &str,
263    solver: SolverFamily,
264) -> Result<JsonValue, anyhow::Error> {
265    let solver_name = match solver {
266        SolverFamily::Sat(_) => "sat",
267        #[cfg(feature = "smt")]
268        SolverFamily::Smt(..) => "smt",
269        SolverFamily::Minion => "minion",
270    };
271    let new_filename = format!("{path}/{test_name}-{solver_name}.{prefix}-solutions.json");
272    let old_filename = format!("{path}/{solver_name}-{test_name}.{prefix}-solutions.json");
273    let expected_json_str = if Path::new(&new_filename).exists() {
274        read_with_path(new_filename)?
275    } else {
276        read_with_path(old_filename)?
277    };
278
279    let expected_solutions: JsonValue =
280        sort_json_object(&serde_json::from_str(&expected_json_str)?, true);
281
282    Ok(expected_solutions)
283}
284
285/// Reads a human-readable rule trace text file.
286pub fn read_human_rule_trace(
287    path: &str,
288    test_name: &str,
289    prefix: &str,
290    solver: &SolverFamily,
291) -> Result<Vec<String>, std::io::Error> {
292    let solver_name = solver.as_str();
293    let new_filename = format!("{path}/{test_name}-{solver_name}-{prefix}-rule-trace.txt");
294    let old_filename = format!("{path}/{solver_name}-{test_name}-{prefix}-rule-trace.txt");
295    let filename = if Path::new(&new_filename).exists() {
296        new_filename
297    } else {
298        old_filename
299    };
300    let rules_trace: Vec<String> = read_with_path(filename)?
301        .lines()
302        .map(String::from)
303        .collect();
304
305    Ok(rules_trace)
306}
307
308#[doc(hidden)]
309pub fn normalize_solutions_for_comparison(
310    input_solutions: &[BTreeMap<Name, Literal>],
311) -> Vec<BTreeMap<Name, Literal>> {
312    let mut normalized = input_solutions.to_vec();
313
314    for solset in &mut normalized {
315        // remove machine names
316        let keys_to_remove: Vec<Name> = solset
317            .keys()
318            .filter(|k| matches!(k, Name::Machine(_)))
319            .cloned()
320            .collect();
321        for k in keys_to_remove {
322            solset.remove(&k);
323        }
324
325        let mut updates = vec![];
326        for (k, v) in solset.clone() {
327            if let Name::User(_) = k {
328                match v {
329                    Literal::Bool(true) => updates.push((k, Literal::Int(1))),
330                    Literal::Bool(false) => updates.push((k, Literal::Int(0))),
331                    Literal::Int(_) => {}
332                    Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, _)) => {
333                        // make all domains the same (this is just in the tester so the types dont
334                        // actually matter)
335
336                        let mut matrix =
337                            AbstractLiteral::Matrix(elems, Moo::new(GroundDomain::Int(vec![])));
338                        matrix = matrix.transform(&move |x: AbstractLiteral<Literal>| match x {
339                            AbstractLiteral::Matrix(items, _) => {
340                                let items = items
341                                    .into_iter()
342                                    .map(|x| match x {
343                                        Literal::Bool(false) => Literal::Int(0),
344                                        Literal::Bool(true) => Literal::Int(1),
345                                        x => x,
346                                    })
347                                    .collect_vec();
348
349                                AbstractLiteral::Matrix(items, Moo::new(GroundDomain::Int(vec![])))
350                            }
351                            x => x,
352                        });
353                        updates.push((k, Literal::AbstractLiteral(matrix)));
354                    }
355                    Literal::AbstractLiteral(AbstractLiteral::Tuple(elems)) => {
356                        // just the same as matrix but with tuples instead
357                        // only conversion needed is to convert bools to ints
358                        let mut tuple = AbstractLiteral::Tuple(elems);
359                        tuple = tuple.transform(
360                            &(move |x: AbstractLiteral<Literal>| match x {
361                                AbstractLiteral::Tuple(items) => {
362                                    let items = items
363                                        .into_iter()
364                                        .map(|x| match x {
365                                            Literal::Bool(false) => Literal::Int(0),
366                                            Literal::Bool(true) => Literal::Int(1),
367                                            x => x,
368                                        })
369                                        .collect_vec();
370
371                                    AbstractLiteral::Tuple(items)
372                                }
373                                x => x,
374                            }),
375                        );
376                        updates.push((k, Literal::AbstractLiteral(tuple)));
377                    }
378                    Literal::AbstractLiteral(AbstractLiteral::Record(entries)) => {
379                        // just the same as matrix but with tuples instead
380                        // only conversion needed is to convert bools to ints
381                        let mut record = AbstractLiteral::Record(entries);
382                        record = record.transform(&move |x: AbstractLiteral<Literal>| match x {
383                            AbstractLiteral::Record(entries) => {
384                                let entries = entries
385                                    .into_iter()
386                                    .map(|x| {
387                                        let RecordValue { name, value } = x;
388                                        {
389                                            let value = match value {
390                                                Literal::Bool(false) => Literal::Int(0),
391                                                Literal::Bool(true) => Literal::Int(1),
392                                                x => x,
393                                            };
394                                            RecordValue { name, value }
395                                        }
396                                    })
397                                    .collect_vec();
398
399                                AbstractLiteral::Record(entries)
400                            }
401                            x => x,
402                        });
403                        updates.push((k, Literal::AbstractLiteral(record)));
404                    }
405                    Literal::AbstractLiteral(AbstractLiteral::Set(members)) => {
406                        let set = AbstractLiteral::Set(members).transform(&move |x| match x {
407                            AbstractLiteral::Set(members) => {
408                                let members = members
409                                    .into_iter()
410                                    .map(|x| match x {
411                                        Literal::Bool(false) => Literal::Int(0),
412                                        Literal::Bool(true) => Literal::Int(1),
413                                        x => x,
414                                    })
415                                    .collect_vec();
416
417                                AbstractLiteral::Set(members)
418                            }
419                            x => x,
420                        });
421                        updates.push((k, Literal::AbstractLiteral(set)));
422                    }
423                    e => bug!("unexpected literal type: {e:?}"),
424                }
425            }
426        }
427
428        for (k, v) in updates {
429            solset.insert(k, v);
430        }
431    }
432
433    // Remove duplicates
434    normalized = normalized.into_iter().unique().collect();
435    normalized
436}
437
438fn maybe_truncate_serialised_json(serialised: String, test_stage: &str) -> String {
439    if test_stage == "rewrite" {
440        truncate_to_first_lines(&serialised, REWRITE_SERIALISED_JSON_MAX_LINES)
441    } else {
442        serialised
443    }
444}
445
/// Returns at most the first `max_lines` lines of `content`, joined with `\n`.
///
/// Lines are split with [`str::lines`], so `\r\n` terminators become `\n` and the final
/// line terminator is not reproduced.
fn truncate_to_first_lines(content: &str, max_lines: usize) -> String {
    let kept: Vec<&str> = content.lines().take(max_lines).collect();
    kept.join("\n")
}
449
/// Reads at most the first `n` lines of `filename` and joins them with `\n`.
///
/// Line terminators (`\n` or `\r\n`) are stripped by [`BufRead::lines`]. An empty file, or
/// `n == 0`, yields an empty string.
///
/// # Errors
/// Returns an I/O error in either of these cases:
/// - the file cannot be opened (the message includes the path);
/// - reading one of the first `n` lines fails, e.g. because of invalid UTF-8.
fn read_first_n_lines<P: AsRef<Path>>(filename: P, n: usize) -> io::Result<String> {
    let path = filename.as_ref();
    let file = File::open(path).map_err(|e| {
        io::Error::new(e.kind(), format!("{} (path: {})", e, path.display()))
    })?;

    // `take(n)` stops after `n` lines and handles `n == 0` and empty files without panicking.
    let lines = BufReader::new(file)
        .lines()
        .take(n)
        .collect::<io::Result<Vec<String>>>()?;
    Ok(lines.join("\n"))
}