1use std::collections::{BTreeMap, HashMap, HashSet};
2use std::fmt::Debug;
3use std::path::Path;
4use std::{io, mem, vec};
5
6use conjure_cp::ast::records::RecordValue;
7use conjure_cp::ast::serde::ObjId;
8use conjure_cp::bug;
9use itertools::Itertools as _;
10use std::fs::File;
11use std::hash::Hash;
12use std::io::{BufRead, BufReader, Write};
13use std::sync::{Arc, RwLock};
14use uniplate::Uniplate;
15
16use conjure_cp::ast::{AbstractLiteral, GroundDomain, Moo, SerdeModel};
17use conjure_cp::context::Context;
18use serde_json::{Error as JsonError, Value as JsonValue};
19
20use conjure_cp::error::Error;
21
22use crate::utils::conjure::solutions_to_json;
23use crate::utils::json::sort_json_object;
24use crate::utils::misc::to_set;
25use conjure_cp::Model as ConjureModel;
26use conjure_cp::ast::Name::User;
27use conjure_cp::ast::{Literal, Name};
28use conjure_cp::solver::SolverFamily;
29
/// Maximum number of lines kept when a model serialised at the `"rewrite"` stage is written
/// to disk by `save_model_json` (later lines are dropped).
pub const REWRITE_SERIALISED_JSON_MAX_LINES: usize = 1_000;
32
33fn model_to_json_with_stable_ids(model: &SerdeModel) -> Result<JsonValue, JsonError> {
38 let id_map = model.collect_stable_id_mapping();
40
41 let mut json = serde_json::to_value(model)?;
43
44 replace_ids(&mut json, &id_map);
46
47 Ok(json)
48}
49
50fn replace_ids(value: &mut JsonValue, id_map: &HashMap<ObjId, ObjId>) {
55 match value {
56 JsonValue::Object(map) => {
57 for (k, v) in map.iter_mut() {
62 if (k == "id" || k == "ptr" || k == "parent")
63 && let Ok(old_id) = serde_json::from_value::<ObjId>(mem::take(v))
64 {
65 let new_id = id_map.get(&old_id).expect("all ids to be in the id map");
66 *v = serde_json::to_value(new_id)
67 .expect("serialization of an ObjId to always succeed");
68 }
69 }
70
71 for val in map.values_mut() {
73 replace_ids(val, id_map);
74 }
75 }
76 JsonValue::Array(arr) => {
77 for item in arr {
78 replace_ids(item, id_map);
79 }
80 }
81 _ => {}
82 }
83}
84
85pub fn assert_eq_any_order<T: Eq + Hash + Debug + Clone>(a: &Vec<Vec<T>>, b: &Vec<Vec<T>>) {
86 assert_eq!(a.len(), b.len());
87
88 let mut a_rows: Vec<HashSet<T>> = Vec::new();
89 for row in a {
90 let hash_row = to_set(row);
91 a_rows.push(hash_row);
92 }
93
94 let mut b_rows: Vec<HashSet<T>> = Vec::new();
95 for row in b {
96 let hash_row = to_set(row);
97 b_rows.push(hash_row);
98 }
99
100 println!("{a_rows:?},{b_rows:?}");
101 for row in a_rows {
102 assert!(b_rows.contains(&row));
103 }
104}
105
106pub fn serialize_model(model: &ConjureModel) -> Result<String, JsonError> {
107 let serde_model: SerdeModel = model.clone().into();
108
109 let json_with_stable_ids = model_to_json_with_stable_ids(&serde_model)?;
111
112 let sorted_json = sort_json_object(&json_with_stable_ids, false);
114
115 serde_json::to_string_pretty(&sorted_json)
117}
118
119pub fn save_model_json(
120 model: &ConjureModel,
121 path: &str,
122 test_name: &str,
123 test_stage: &str,
124) -> Result<(), std::io::Error> {
125 let generated_json_str = serialize_model(model)?;
126 let generated_json_str = maybe_truncate_serialised_json(generated_json_str, test_stage);
127 let filename = format!("{path}/{test_name}.generated-{test_stage}.serialised.json");
128 File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
129 Ok(())
130}
131
132pub fn save_stats_json(
133 context: Arc<RwLock<Context<'static>>>,
134 path: &str,
135 test_name: &str,
136) -> Result<(), std::io::Error> {
137 #[allow(clippy::unwrap_used)]
138 let stats = context.read().unwrap().stats.clone();
139 let generated_json = sort_json_object(&serde_json::to_value(stats)?, false);
140
141 let generated_json_str = serde_json::to_string_pretty(&generated_json)?;
143
144 File::create(format!("{path}/{test_name}-stats.json"))?
145 .write_all(generated_json_str.as_bytes())?;
146
147 Ok(())
148}
149
/// Reads the whole file at `path` into a string.
///
/// # Errors
///
/// On failure, returns an error of the same kind as the underlying one, with the path
/// appended to the message as ` (path: ...)`.
fn read_with_path(path: String) -> Result<String, std::io::Error> {
    match std::fs::read_to_string(&path) {
        Ok(contents) => Ok(contents),
        Err(err) => {
            let kind = err.kind();
            Err(io::Error::new(kind, format!("{err} (path: {path})")))
        }
    }
}
155
156pub fn read_model_json(
157 ctx: &Arc<RwLock<Context<'static>>>,
158 path: &str,
159 test_name: &str,
160 prefix: &str,
161 test_stage: &str,
162) -> Result<ConjureModel, std::io::Error> {
163 let expected_json_str = read_with_path(format!(
164 "{path}/{test_name}.{prefix}-{test_stage}.serialised.json"
165 ))?;
166 println!("{path}/{test_name}.{prefix}-{test_stage}.serialised.json");
167 let expected_model: SerdeModel = serde_json::from_str(&expected_json_str)?;
168
169 Ok(expected_model.initialise(ctx.clone()).unwrap())
170}
171
172pub fn read_model_json_prefix(
174 path: &str,
175 test_name: &str,
176 prefix: &str,
177 test_stage: &str,
178 max_lines: usize,
179) -> Result<String, std::io::Error> {
180 let filename = format!("{path}/{test_name}.{prefix}-{test_stage}.serialised.json");
181 read_first_n_lines(filename, max_lines)
182}
183
184pub fn minion_solutions_from_json(
185 serialized: &str,
186) -> Result<Vec<HashMap<Name, Literal>>, anyhow::Error> {
187 let json: JsonValue = serde_json::from_str(serialized)?;
188
189 let json_array = json
190 .as_array()
191 .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
192
193 let mut solutions = Vec::new();
194
195 for solution in json_array {
196 let mut sol = HashMap::new();
197 let solution = solution
198 .as_object()
199 .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
200
201 for (var_name, constant) in solution {
202 let constant = match constant {
203 JsonValue::Number(n) => {
204 let n = n
205 .as_i64()
206 .ok_or(Error::Parse("Invalid integer".to_owned()))?;
207 Literal::Int(n as i32)
208 }
209 JsonValue::Bool(b) => Literal::Bool(*b),
210 _ => return Err(Error::Parse("Invalid constant".to_owned()).into()),
211 };
212
213 sol.insert(User(var_name.into()), constant);
214 }
215
216 solutions.push(sol);
217 }
218
219 Ok(solutions)
220}
221
222pub fn save_solutions_json(
224 solutions: &Vec<BTreeMap<Name, Literal>>,
225 path: &str,
226 test_name: &str,
227 solver: SolverFamily,
228) -> Result<JsonValue, std::io::Error> {
229 let json_solutions = solutions_to_json(solutions);
230 let generated_json_str = serde_json::to_string_pretty(&json_solutions)?;
231
232 let solver_name = match solver {
233 SolverFamily::Sat => "sat",
234 #[cfg(feature = "smt")]
235 SolverFamily::Smt(..) => "smt",
236 SolverFamily::Minion => "minion",
237 };
238
239 let filename = format!("{path}/{test_name}.generated-{solver_name}.solutions.json");
240 File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
241
242 Ok(json_solutions)
243}
244
245pub fn read_solutions_json(
246 path: &str,
247 test_name: &str,
248 prefix: &str,
249 solver: SolverFamily,
250) -> Result<JsonValue, anyhow::Error> {
251 let solver_name = match solver {
252 SolverFamily::Sat => "sat",
253 #[cfg(feature = "smt")]
254 SolverFamily::Smt(..) => "smt",
255 SolverFamily::Minion => "minion",
256 };
257
258 let expected_json_str = read_with_path(format!(
259 "{path}/{test_name}.{prefix}-{solver_name}.solutions.json"
260 ))?;
261
262 let expected_solutions: JsonValue =
263 sort_json_object(&serde_json::from_str(&expected_json_str)?, true);
264
265 Ok(expected_solutions)
266}
267
268pub fn read_human_rule_trace(
270 path: &str,
271 test_name: &str,
272 prefix: &str,
273) -> Result<Vec<String>, std::io::Error> {
274 let filename = format!("{path}/{test_name}-{prefix}-rule-trace-human.txt");
275 let rules_trace: Vec<String> = read_with_path(filename)?
276 .lines()
277 .map(String::from)
278 .collect();
279
280 Ok(rules_trace)
281}
282
283#[doc(hidden)]
284pub fn normalize_solutions_for_comparison(
285 input_solutions: &[BTreeMap<Name, Literal>],
286) -> Vec<BTreeMap<Name, Literal>> {
287 let mut normalized = input_solutions.to_vec();
288
289 for solset in &mut normalized {
290 let keys_to_remove: Vec<Name> = solset
292 .keys()
293 .filter(|k| matches!(k, Name::Machine(_)))
294 .cloned()
295 .collect();
296 for k in keys_to_remove {
297 solset.remove(&k);
298 }
299
300 let mut updates = vec![];
301 for (k, v) in solset.clone() {
302 if let Name::User(_) = k {
303 match v {
304 Literal::Bool(true) => updates.push((k, Literal::Int(1))),
305 Literal::Bool(false) => updates.push((k, Literal::Int(0))),
306 Literal::Int(_) => {}
307 Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, _)) => {
308 let mut matrix =
312 AbstractLiteral::Matrix(elems, Moo::new(GroundDomain::Int(vec![])));
313 matrix = matrix.transform(&move |x: AbstractLiteral<Literal>| match x {
314 AbstractLiteral::Matrix(items, _) => {
315 let items = items
316 .into_iter()
317 .map(|x| match x {
318 Literal::Bool(false) => Literal::Int(0),
319 Literal::Bool(true) => Literal::Int(1),
320 x => x,
321 })
322 .collect_vec();
323
324 AbstractLiteral::Matrix(items, Moo::new(GroundDomain::Int(vec![])))
325 }
326 x => x,
327 });
328 updates.push((k, Literal::AbstractLiteral(matrix)));
329 }
330 Literal::AbstractLiteral(AbstractLiteral::Tuple(elems)) => {
331 let mut tuple = AbstractLiteral::Tuple(elems);
334 tuple = tuple.transform(
335 &(move |x: AbstractLiteral<Literal>| match x {
336 AbstractLiteral::Tuple(items) => {
337 let items = items
338 .into_iter()
339 .map(|x| match x {
340 Literal::Bool(false) => Literal::Int(0),
341 Literal::Bool(true) => Literal::Int(1),
342 x => x,
343 })
344 .collect_vec();
345
346 AbstractLiteral::Tuple(items)
347 }
348 x => x,
349 }),
350 );
351 updates.push((k, Literal::AbstractLiteral(tuple)));
352 }
353 Literal::AbstractLiteral(AbstractLiteral::Record(entries)) => {
354 let mut record = AbstractLiteral::Record(entries);
357 record = record.transform(&move |x: AbstractLiteral<Literal>| match x {
358 AbstractLiteral::Record(entries) => {
359 let entries = entries
360 .into_iter()
361 .map(|x| {
362 let RecordValue { name, value } = x;
363 {
364 let value = match value {
365 Literal::Bool(false) => Literal::Int(0),
366 Literal::Bool(true) => Literal::Int(1),
367 x => x,
368 };
369 RecordValue { name, value }
370 }
371 })
372 .collect_vec();
373
374 AbstractLiteral::Record(entries)
375 }
376 x => x,
377 });
378 updates.push((k, Literal::AbstractLiteral(record)));
379 }
380 Literal::AbstractLiteral(AbstractLiteral::Set(members)) => {
381 let set = AbstractLiteral::Set(members).transform(&move |x| match x {
382 AbstractLiteral::Set(members) => {
383 let members = members
384 .into_iter()
385 .map(|x| match x {
386 Literal::Bool(false) => Literal::Int(0),
387 Literal::Bool(true) => Literal::Int(1),
388 x => x,
389 })
390 .collect_vec();
391
392 AbstractLiteral::Set(members)
393 }
394 x => x,
395 });
396 updates.push((k, Literal::AbstractLiteral(set)));
397 }
398 e => bug!("unexpected literal type: {e:?}"),
399 }
400 }
401 }
402
403 for (k, v) in updates {
404 solset.insert(k, v);
405 }
406 }
407
408 normalized = normalized.into_iter().unique().collect();
410 normalized
411}
412
413fn maybe_truncate_serialised_json(serialised: String, test_stage: &str) -> String {
414 if test_stage == "rewrite" {
415 truncate_to_first_lines(&serialised, REWRITE_SERIALISED_JSON_MAX_LINES)
416 } else {
417 serialised
418 }
419}
420
/// Keeps at most the first `max_lines` lines of `content` and joins them with `\n`.
///
/// Line terminators (`\n` or `\r\n`) are stripped by `str::lines`, so the result has no
/// trailing newline.
fn truncate_to_first_lines(content: &str, max_lines: usize) -> String {
    let kept: Vec<&str> = content.lines().take(max_lines).collect();
    kept.join("\n")
}
424
/// Reads at most the first `n` lines of `filename` and joins them with `\n`.
///
/// Line terminators are stripped, so the result never ends with a newline. An empty file,
/// or `n == 0`, yields an empty string (previously both cases panicked).
///
/// # Errors
///
/// Returns an error if the file cannot be opened (the message includes the path, matching
/// `read_with_path`) or if reading any of the first `n` lines fails, e.g. on invalid UTF-8.
fn read_first_n_lines<P: AsRef<Path>>(filename: P, n: usize) -> io::Result<String> {
    let path = filename.as_ref();
    let file = File::open(path).map_err(|e| {
        io::Error::new(e.kind(), format!("{} (path: {})", e, path.display()))
    })?;

    // `take(n)` handles `n == 0` and short or empty files without panicking.
    let lines = BufReader::new(file)
        .lines()
        .take(n)
        .collect::<io::Result<Vec<String>>>()?;

    Ok(lines.join("\n"))
}