Skip to main content

conjure_cp_cli/utils/
testing.rs

1use std::collections::{BTreeMap, HashMap, HashSet};
2use std::fmt::Debug;
3use std::path::Path;
4use std::{io, mem, vec};
5
6use conjure_cp::ast::records::RecordValue;
7use conjure_cp::ast::serde::ObjId;
8use conjure_cp::bug;
9use itertools::Itertools as _;
10use std::fs::File;
11use std::hash::Hash;
12use std::io::{BufRead, BufReader, Write};
13use std::sync::{Arc, RwLock};
14use uniplate::Uniplate;
15
16use conjure_cp::ast::{AbstractLiteral, GroundDomain, Moo, SerdeModel};
17use conjure_cp::context::Context;
18use serde_json::{Error as JsonError, Value as JsonValue};
19
20use conjure_cp::error::Error;
21
22use crate::utils::conjure::solutions_to_json;
23use crate::utils::json::sort_json_object;
24use crate::utils::misc::to_set;
25use conjure_cp::Model as ConjureModel;
26use conjure_cp::ast::Name::User;
27use conjure_cp::ast::{Literal, Name};
28use conjure_cp::solver::SolverFamily;
29
/// Maximum number of lines of the serialised `rewrite`-stage model JSON that integration tests
/// write to disk and compare against the expected file.
pub const REWRITE_SERIALISED_JSON_MAX_LINES: usize = 1_000;
32
33/// Converts a SerdeModel to JSON with stable IDs.
34///
35/// This ensures that the same model structure always produces the same IDs,
36/// regardless of the order in which objects were created in memory.
37fn model_to_json_with_stable_ids(model: &SerdeModel) -> Result<JsonValue, JsonError> {
38    // Collect stable ID mapping using uniplate traversal on the SerdeModel
39    let id_map = model.collect_stable_id_mapping();
40
41    // Serialize the model to JSON
42    let mut json = serde_json::to_value(model)?;
43
44    // Replace all IDs in the JSON with their stable counterparts
45    replace_ids(&mut json, &id_map);
46
47    Ok(json)
48}
49
50/// Recursively replaces all IDs in the JSON with their stable counterparts.
51///
52/// This is applied to all fields that are called "id" or "ptr" - be mindful
53/// of potential naming clashes in the future!
54fn replace_ids(value: &mut JsonValue, id_map: &HashMap<ObjId, ObjId>) {
55    match value {
56        JsonValue::Object(map) => {
57            // Replace IDs in three places:
58            // - "id" fields (SymbolTable IDs)
59            // - "parent" fields (SymbolTable nesting)
60            // - "ptr" fields (DeclarationPtr IDs)
61            for (k, v) in map.iter_mut() {
62                if (k == "id" || k == "ptr" || k == "parent")
63                    && let Ok(old_id) = serde_json::from_value::<ObjId>(mem::take(v))
64                {
65                    let new_id = id_map.get(&old_id).expect("all ids to be in the id map");
66                    *v = serde_json::to_value(new_id)
67                        .expect("serialization of an ObjId to always succeed");
68                }
69            }
70
71            // Recursively process all values
72            for val in map.values_mut() {
73                replace_ids(val, id_map);
74            }
75        }
76        JsonValue::Array(arr) => {
77            for item in arr {
78                replace_ids(item, id_map);
79            }
80        }
81        _ => {}
82    }
83}
84
85pub fn assert_eq_any_order<T: Eq + Hash + Debug + Clone>(a: &Vec<Vec<T>>, b: &Vec<Vec<T>>) {
86    assert_eq!(a.len(), b.len());
87
88    let mut a_rows: Vec<HashSet<T>> = Vec::new();
89    for row in a {
90        let hash_row = to_set(row);
91        a_rows.push(hash_row);
92    }
93
94    let mut b_rows: Vec<HashSet<T>> = Vec::new();
95    for row in b {
96        let hash_row = to_set(row);
97        b_rows.push(hash_row);
98    }
99
100    println!("{a_rows:?},{b_rows:?}");
101    for row in a_rows {
102        assert!(b_rows.contains(&row));
103    }
104}
105
106pub fn serialize_model(model: &ConjureModel) -> Result<String, JsonError> {
107    let serde_model: SerdeModel = model.clone().into();
108
109    // Convert to JSON with stable IDs
110    let json_with_stable_ids = model_to_json_with_stable_ids(&serde_model)?;
111
112    // Sort JSON object keys for consistent output
113    let sorted_json = sort_json_object(&json_with_stable_ids, false);
114
115    // Serialize to pretty-printed string
116    serde_json::to_string_pretty(&sorted_json)
117}
118
119pub fn save_model_json(
120    model: &ConjureModel,
121    path: &str,
122    test_name: &str,
123    test_stage: &str,
124) -> Result<(), std::io::Error> {
125    let generated_json_str = serialize_model(model)?;
126    let generated_json_str = maybe_truncate_serialised_json(generated_json_str, test_stage);
127    let filename = format!("{path}/{test_name}.generated-{test_stage}.serialised.json");
128    File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
129    Ok(())
130}
131
132pub fn save_stats_json(
133    context: Arc<RwLock<Context<'static>>>,
134    path: &str,
135    test_name: &str,
136) -> Result<(), std::io::Error> {
137    #[allow(clippy::unwrap_used)]
138    let stats = context.read().unwrap().stats.clone();
139    let generated_json = sort_json_object(&serde_json::to_value(stats)?, false);
140
141    // serialise to string
142    let generated_json_str = serde_json::to_string_pretty(&generated_json)?;
143
144    File::create(format!("{path}/{test_name}-stats.json"))?
145        .write_all(generated_json_str.as_bytes())?;
146
147    Ok(())
148}
149
/// Reads the whole file at `path` into a `String`.
///
/// # Errors
///
/// On failure, returns an I/O error of the same kind as the underlying one, whose message is the
/// original message followed by ` (path: <path>)`, so the offending file is easy to identify.
fn read_with_path(path: String) -> Result<String, std::io::Error> {
    match std::fs::read_to_string(&path) {
        Ok(contents) => Ok(contents),
        Err(err) => {
            let kind = err.kind();
            Err(io::Error::new(kind, format!("{err} (path: {path})")))
        }
    }
}
155
156pub fn read_model_json(
157    ctx: &Arc<RwLock<Context<'static>>>,
158    path: &str,
159    test_name: &str,
160    prefix: &str,
161    test_stage: &str,
162) -> Result<ConjureModel, std::io::Error> {
163    let expected_json_str = read_with_path(format!(
164        "{path}/{test_name}.{prefix}-{test_stage}.serialised.json"
165    ))?;
166    println!("{path}/{test_name}.{prefix}-{test_stage}.serialised.json");
167    let expected_model: SerdeModel = serde_json::from_str(&expected_json_str)?;
168
169    Ok(expected_model.initialise(ctx.clone()).unwrap())
170}
171
172/// Reads only the first `max_lines` from a serialised model JSON file.
173pub fn read_model_json_prefix(
174    path: &str,
175    test_name: &str,
176    prefix: &str,
177    test_stage: &str,
178    max_lines: usize,
179) -> Result<String, std::io::Error> {
180    let filename = format!("{path}/{test_name}.{prefix}-{test_stage}.serialised.json");
181    read_first_n_lines(filename, max_lines)
182}
183
184pub fn minion_solutions_from_json(
185    serialized: &str,
186) -> Result<Vec<HashMap<Name, Literal>>, anyhow::Error> {
187    let json: JsonValue = serde_json::from_str(serialized)?;
188
189    let json_array = json
190        .as_array()
191        .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
192
193    let mut solutions = Vec::new();
194
195    for solution in json_array {
196        let mut sol = HashMap::new();
197        let solution = solution
198            .as_object()
199            .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
200
201        for (var_name, constant) in solution {
202            let constant = match constant {
203                JsonValue::Number(n) => {
204                    let n = n
205                        .as_i64()
206                        .ok_or(Error::Parse("Invalid integer".to_owned()))?;
207                    Literal::Int(n as i32)
208                }
209                JsonValue::Bool(b) => Literal::Bool(*b),
210                _ => return Err(Error::Parse("Invalid constant".to_owned()).into()),
211            };
212
213            sol.insert(User(var_name.into()), constant);
214        }
215
216        solutions.push(sol);
217    }
218
219    Ok(solutions)
220}
221
222/// Writes the minion solutions to a generated JSON file, and returns the JSON structure.
223pub fn save_solutions_json(
224    solutions: &Vec<BTreeMap<Name, Literal>>,
225    path: &str,
226    test_name: &str,
227    solver: SolverFamily,
228) -> Result<JsonValue, std::io::Error> {
229    let json_solutions = solutions_to_json(solutions);
230    let generated_json_str = serde_json::to_string_pretty(&json_solutions)?;
231
232    let solver_name = match solver {
233        SolverFamily::Sat => "sat",
234        #[cfg(feature = "smt")]
235        SolverFamily::Smt(..) => "smt",
236        SolverFamily::Minion => "minion",
237    };
238
239    let filename = format!("{path}/{test_name}.generated-{solver_name}.solutions.json");
240    File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
241
242    Ok(json_solutions)
243}
244
245pub fn read_solutions_json(
246    path: &str,
247    test_name: &str,
248    prefix: &str,
249    solver: SolverFamily,
250) -> Result<JsonValue, anyhow::Error> {
251    let solver_name = match solver {
252        SolverFamily::Sat => "sat",
253        #[cfg(feature = "smt")]
254        SolverFamily::Smt(..) => "smt",
255        SolverFamily::Minion => "minion",
256    };
257
258    let expected_json_str = read_with_path(format!(
259        "{path}/{test_name}.{prefix}-{solver_name}.solutions.json"
260    ))?;
261
262    let expected_solutions: JsonValue =
263        sort_json_object(&serde_json::from_str(&expected_json_str)?, true);
264
265    Ok(expected_solutions)
266}
267
268/// Reads a human-readable rule trace text file.
269pub fn read_human_rule_trace(
270    path: &str,
271    test_name: &str,
272    prefix: &str,
273) -> Result<Vec<String>, std::io::Error> {
274    let filename = format!("{path}/{test_name}-{prefix}-rule-trace-human.txt");
275    let rules_trace: Vec<String> = read_with_path(filename)?
276        .lines()
277        .map(String::from)
278        .collect();
279
280    Ok(rules_trace)
281}
282
283#[doc(hidden)]
284pub fn normalize_solutions_for_comparison(
285    input_solutions: &[BTreeMap<Name, Literal>],
286) -> Vec<BTreeMap<Name, Literal>> {
287    let mut normalized = input_solutions.to_vec();
288
289    for solset in &mut normalized {
290        // remove machine names
291        let keys_to_remove: Vec<Name> = solset
292            .keys()
293            .filter(|k| matches!(k, Name::Machine(_)))
294            .cloned()
295            .collect();
296        for k in keys_to_remove {
297            solset.remove(&k);
298        }
299
300        let mut updates = vec![];
301        for (k, v) in solset.clone() {
302            if let Name::User(_) = k {
303                match v {
304                    Literal::Bool(true) => updates.push((k, Literal::Int(1))),
305                    Literal::Bool(false) => updates.push((k, Literal::Int(0))),
306                    Literal::Int(_) => {}
307                    Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, _)) => {
308                        // make all domains the same (this is just in the tester so the types dont
309                        // actually matter)
310
311                        let mut matrix =
312                            AbstractLiteral::Matrix(elems, Moo::new(GroundDomain::Int(vec![])));
313                        matrix = matrix.transform(&move |x: AbstractLiteral<Literal>| match x {
314                            AbstractLiteral::Matrix(items, _) => {
315                                let items = items
316                                    .into_iter()
317                                    .map(|x| match x {
318                                        Literal::Bool(false) => Literal::Int(0),
319                                        Literal::Bool(true) => Literal::Int(1),
320                                        x => x,
321                                    })
322                                    .collect_vec();
323
324                                AbstractLiteral::Matrix(items, Moo::new(GroundDomain::Int(vec![])))
325                            }
326                            x => x,
327                        });
328                        updates.push((k, Literal::AbstractLiteral(matrix)));
329                    }
330                    Literal::AbstractLiteral(AbstractLiteral::Tuple(elems)) => {
331                        // just the same as matrix but with tuples instead
332                        // only conversion needed is to convert bools to ints
333                        let mut tuple = AbstractLiteral::Tuple(elems);
334                        tuple = tuple.transform(
335                            &(move |x: AbstractLiteral<Literal>| match x {
336                                AbstractLiteral::Tuple(items) => {
337                                    let items = items
338                                        .into_iter()
339                                        .map(|x| match x {
340                                            Literal::Bool(false) => Literal::Int(0),
341                                            Literal::Bool(true) => Literal::Int(1),
342                                            x => x,
343                                        })
344                                        .collect_vec();
345
346                                    AbstractLiteral::Tuple(items)
347                                }
348                                x => x,
349                            }),
350                        );
351                        updates.push((k, Literal::AbstractLiteral(tuple)));
352                    }
353                    Literal::AbstractLiteral(AbstractLiteral::Record(entries)) => {
354                        // just the same as matrix but with tuples instead
355                        // only conversion needed is to convert bools to ints
356                        let mut record = AbstractLiteral::Record(entries);
357                        record = record.transform(&move |x: AbstractLiteral<Literal>| match x {
358                            AbstractLiteral::Record(entries) => {
359                                let entries = entries
360                                    .into_iter()
361                                    .map(|x| {
362                                        let RecordValue { name, value } = x;
363                                        {
364                                            let value = match value {
365                                                Literal::Bool(false) => Literal::Int(0),
366                                                Literal::Bool(true) => Literal::Int(1),
367                                                x => x,
368                                            };
369                                            RecordValue { name, value }
370                                        }
371                                    })
372                                    .collect_vec();
373
374                                AbstractLiteral::Record(entries)
375                            }
376                            x => x,
377                        });
378                        updates.push((k, Literal::AbstractLiteral(record)));
379                    }
380                    Literal::AbstractLiteral(AbstractLiteral::Set(members)) => {
381                        let set = AbstractLiteral::Set(members).transform(&move |x| match x {
382                            AbstractLiteral::Set(members) => {
383                                let members = members
384                                    .into_iter()
385                                    .map(|x| match x {
386                                        Literal::Bool(false) => Literal::Int(0),
387                                        Literal::Bool(true) => Literal::Int(1),
388                                        x => x,
389                                    })
390                                    .collect_vec();
391
392                                AbstractLiteral::Set(members)
393                            }
394                            x => x,
395                        });
396                        updates.push((k, Literal::AbstractLiteral(set)));
397                    }
398                    e => bug!("unexpected literal type: {e:?}"),
399                }
400            }
401        }
402
403        for (k, v) in updates {
404            solset.insert(k, v);
405        }
406    }
407
408    // Remove duplicates
409    normalized = normalized.into_iter().unique().collect();
410    normalized
411}
412
413fn maybe_truncate_serialised_json(serialised: String, test_stage: &str) -> String {
414    if test_stage == "rewrite" {
415        truncate_to_first_lines(&serialised, REWRITE_SERIALISED_JSON_MAX_LINES)
416    } else {
417        serialised
418    }
419}
420
/// Returns at most the first `max_lines` lines of `content`, joined with `\n`.
///
/// Line terminators (`\n` or `\r\n`) are stripped by [`str::lines`], and no trailing newline is
/// appended.
fn truncate_to_first_lines(content: &str, max_lines: usize) -> String {
    let kept: Vec<&str> = content.lines().take(max_lines).collect();
    kept.join("\n")
}
424
/// Reads at most the first `n` lines of `filename`, joined with `\n`.
///
/// Line terminators (`\n` or `\r\n`) are stripped and no trailing newline is added, which
/// matches the output of `truncate_to_first_lines`. An empty file, or `n == 0`, yields an empty
/// string.
///
/// # Errors
///
/// Returns an I/O error if the file cannot be opened (the message includes the path) or if
/// reading one of the requested lines fails (e.g. invalid UTF-8).
fn read_first_n_lines<P: AsRef<Path>>(filename: P, n: usize) -> io::Result<String> {
    let path = filename.as_ref();
    let file = File::open(path)
        .map_err(|e| io::Error::new(e.kind(), format!("{e} (path: {})", path.display())))?;

    // `take` stops reading after `n` lines. Unlike taking the first chunk of `chunks(n)`, it
    // does not panic on an empty file or when `n == 0`.
    let lines = BufReader::new(file)
        .lines()
        .take(n)
        .collect::<Result<Vec<_>, _>>()?;

    Ok(lines.join("\n"))
}