1
use std::collections::{BTreeMap, HashMap, HashSet};
2
use std::fmt::Debug;
3
use std::path::Path;
4
use std::{io, mem, vec};
5

            
6
use conjure_cp::ast::records::RecordValue;
7
use conjure_cp::ast::serde::ObjId;
8
use conjure_cp::bug;
9
use itertools::Itertools as _;
10
use std::fs::File;
11
use std::hash::Hash;
12
use std::io::{BufRead, BufReader, Write};
13
use std::sync::{Arc, RwLock};
14
use uniplate::Uniplate;
15

            
16
use conjure_cp::ast::{AbstractLiteral, GroundDomain, Moo, SerdeModel};
17
use conjure_cp::context::Context;
18
use serde_json::{Error as JsonError, Value as JsonValue};
19

            
20
use conjure_cp::error::Error;
21

            
22
use crate::utils::conjure::solutions_to_json;
23
use crate::utils::json::sort_json_object;
24
use crate::utils::misc::to_set;
25
use conjure_cp::Model as ConjureModel;
26
use conjure_cp::ast::Name::User;
27
use conjure_cp::ast::{Literal, Name};
28
use conjure_cp::settings::SolverFamily;
29

            
30
/// Maximum number of lines of the serialised model kept for the `"rewrite"` stage.
///
/// Integration tests only persist and compare this many leading lines of the rewrite
/// serialisation, keeping snapshot files bounded in size.
pub const REWRITE_SERIALISED_JSON_MAX_LINES: usize = 1_000;
32

            
33
/// Converts a SerdeModel to JSON with stable IDs.
34
///
35
/// This ensures that the same model structure always produces the same IDs,
36
/// regardless of the order in which objects were created in memory.
37
126
fn model_to_json_with_stable_ids(model: &SerdeModel) -> Result<JsonValue, JsonError> {
38
    // Collect stable ID mapping using uniplate traversal on the SerdeModel
39
126
    let id_map = model.collect_stable_id_mapping();
40

            
41
    // Serialize the model to JSON
42
126
    let mut json = serde_json::to_value(model)?;
43

            
44
    // Replace all IDs in the JSON with their stable counterparts
45
126
    replace_ids(&mut json, &id_map);
46

            
47
126
    Ok(json)
48
126
}
49

            
50
/// Recursively replaces all IDs in the JSON with their stable counterparts.
51
///
52
/// This is applied to all fields that are called "id" or "ptr" - be mindful
53
/// of potential naming clashes in the future!
54
8284
fn replace_ids(value: &mut JsonValue, id_map: &HashMap<ObjId, ObjId>) {
55
8284
    match value {
56
3522
        JsonValue::Object(map) => {
57
            // Replace IDs in three places:
58
            // - "id" fields (SymbolTable IDs)
59
            // - "parent" fields (SymbolTable nesting)
60
            // - "ptr" fields (DeclarationPtr IDs)
61
5674
            for (k, v) in map.iter_mut() {
62
5674
                if (k == "id" || k == "ptr" || k == "parent")
63
496
                    && let Ok(old_id) = serde_json::from_value::<ObjId>(mem::take(v))
64
370
                {
65
370
                    let new_id = id_map.get(&old_id).expect("all ids to be in the id map");
66
370
                    *v = serde_json::to_value(new_id)
67
370
                        .expect("serialization of an ObjId to always succeed");
68
5304
                }
69
            }
70

            
71
            // Recursively process all values
72
5674
            for val in map.values_mut() {
73
5674
                replace_ids(val, id_map);
74
5674
            }
75
        }
76
1544
        JsonValue::Array(arr) => {
77
2484
            for item in arr {
78
2484
                replace_ids(item, id_map);
79
2484
            }
80
        }
81
3218
        _ => {}
82
    }
83
8284
}
84

            
85
pub fn assert_eq_any_order<T: Eq + Hash + Debug + Clone>(a: &Vec<Vec<T>>, b: &Vec<Vec<T>>) {
86
    assert_eq!(a.len(), b.len());
87

            
88
    let mut a_rows: Vec<HashSet<T>> = Vec::new();
89
    for row in a {
90
        let hash_row = to_set(row);
91
        a_rows.push(hash_row);
92
    }
93

            
94
    let mut b_rows: Vec<HashSet<T>> = Vec::new();
95
    for row in b {
96
        let hash_row = to_set(row);
97
        b_rows.push(hash_row);
98
    }
99

            
100
    for row in a_rows {
101
        assert!(b_rows.contains(&row));
102
    }
103
}
104

            
105
126
pub fn serialize_model(model: &ConjureModel) -> Result<String, JsonError> {
106
126
    let serde_model: SerdeModel = model.clone().into();
107

            
108
    // Convert to JSON with stable IDs
109
126
    let json_with_stable_ids = model_to_json_with_stable_ids(&serde_model)?;
110

            
111
    // Sort JSON object keys for consistent output
112
126
    let sorted_json = sort_json_object(&json_with_stable_ids, false);
113

            
114
    // Serialize to pretty-printed string
115
126
    serde_json::to_string_pretty(&sorted_json)
116
126
}
117

            
118
pub fn save_model_json(
119
    model: &ConjureModel,
120
    path: &str,
121
    test_name: &str,
122
    test_stage: &str,
123
    solver: SolverFamily,
124
) -> Result<(), std::io::Error> {
125
    let marker = solver.as_str();
126
    let generated_json_str = serialize_model(model)?;
127
    let generated_json_str = maybe_truncate_serialised_json(generated_json_str, test_stage);
128
    let filename = format!("{path}/{test_name}-{marker}.generated-{test_stage}.serialised.json");
129
    println!("saving: {}", filename);
130
    File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
131
    Ok(())
132
}
133

            
134
5140
pub fn save_stats_json(
135
5140
    context: Arc<RwLock<Context<'static>>>,
136
5140
    path: &str,
137
5140
    test_name: &str,
138
5140
    solver: SolverFamily,
139
5140
) -> Result<(), std::io::Error> {
140
    #[allow(clippy::unwrap_used)]
141
5140
    let solver_name = solver.as_str();
142

            
143
5140
    let stats = context.read().unwrap().clone();
144
5140
    let generated_json = sort_json_object(&serde_json::to_value(stats)?, false);
145

            
146
    // serialise to string
147
5140
    let generated_json_str = serde_json::to_string_pretty(&generated_json)?;
148

            
149
5140
    File::create(format!("{path}/{test_name}-{solver_name}-stats.json"))?
150
5140
        .write_all(generated_json_str.as_bytes())?;
151

            
152
5140
    Ok(())
153
5140
}
154

            
155
/// Reads the whole file at `path` into a `String`.
///
/// # Errors
///
/// On failure, returns an error of the same kind as the underlying I/O error. Its message is
/// the original message followed by ` (path: <path>)`, so the failing file is identifiable.
fn read_with_path(path: String) -> Result<String, std::io::Error> {
    match std::fs::read_to_string(&path) {
        Ok(contents) => Ok(contents),
        Err(err) => {
            let kind = err.kind();
            Err(io::Error::new(kind, format!("{} (path: {})", err, path)))
        }
    }
}
160

            
161
pub fn read_model_json(
162
    ctx: &Arc<RwLock<Context<'static>>>,
163
    path: &str,
164
    test_name: &str,
165
    prefix: &str,
166
    test_stage: &str,
167
    solver: SolverFamily,
168
) -> Result<ConjureModel, std::io::Error> {
169
    let marker = solver.as_str();
170
    let filepath = format!("{path}/{test_name}-{marker}.{prefix}-{test_stage}.serialised.json");
171
    let expected_json_str = std::fs::read_to_string(filepath)?;
172
    let expected_model: SerdeModel = serde_json::from_str(&expected_json_str)?;
173

            
174
    Ok(expected_model.initialise(ctx.clone()).unwrap())
175
}
176

            
177
/// Reads only the first `max_lines` from a serialised model JSON file.
178
pub fn read_model_json_prefix(
179
    path: &str,
180
    test_name: &str,
181
    prefix: &str,
182
    test_stage: &str,
183
    solver: SolverFamily,
184
    max_lines: usize,
185
) -> Result<String, std::io::Error> {
186
    let marker = solver.as_str();
187
    let filename = format!("{path}/{test_name}-{marker}.{prefix}-{test_stage}.serialised.json");
188
    println!("reading: {}", filename);
189
    read_first_n_lines(filename, max_lines)
190
}
191

            
192
pub fn minion_solutions_from_json(
193
    serialized: &str,
194
) -> Result<Vec<HashMap<Name, Literal>>, anyhow::Error> {
195
    let json: JsonValue = serde_json::from_str(serialized)?;
196

            
197
    let json_array = json
198
        .as_array()
199
        .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
200

            
201
    let mut solutions = Vec::new();
202

            
203
    for solution in json_array {
204
        let mut sol = HashMap::new();
205
        let solution = solution
206
            .as_object()
207
            .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
208

            
209
        for (var_name, constant) in solution {
210
            let constant = match constant {
211
                JsonValue::Number(n) => {
212
                    let n = n
213
                        .as_i64()
214
                        .ok_or(Error::Parse("Invalid integer".to_owned()))?;
215
                    Literal::Int(n as i32)
216
                }
217
                JsonValue::Bool(b) => Literal::Bool(*b),
218
                _ => return Err(Error::Parse("Invalid constant".to_owned()).into()),
219
            };
220

            
221
            sol.insert(User(var_name.into()), constant);
222
        }
223

            
224
        solutions.push(sol);
225
    }
226

            
227
    Ok(solutions)
228
}
229

            
230
/// Writes the minion solutions to a generated JSON file, and returns the JSON structure.
231
5148
pub fn save_solutions_json(
232
5148
    solutions: &Vec<BTreeMap<Name, Literal>>,
233
5148
    path: &str,
234
5148
    test_name: &str,
235
5148
    solver: SolverFamily,
236
5148
) -> Result<JsonValue, std::io::Error> {
237
5148
    let json_solutions = solutions_to_json(solutions);
238
5148
    let generated_json_str = serde_json::to_string_pretty(&json_solutions)?;
239

            
240
5148
    let solver_name = solver.as_str();
241
5148
    let filename = format!("{path}/{test_name}-{solver_name}.generated-solutions.json");
242
5148
    File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
243

            
244
5148
    Ok(json_solutions)
245
5148
}
246

            
247
5148
pub fn read_solutions_json(
248
5148
    path: &str,
249
5148
    test_name: &str,
250
5148
    prefix: &str,
251
5148
    solver: SolverFamily,
252
5148
) -> Result<JsonValue, anyhow::Error> {
253
5148
    let solver_name = solver.as_str();
254
5148
    let filename = format!("{path}/{test_name}-{solver_name}.{prefix}-solutions.json");
255
5148
    let expected_json_str = read_with_path(filename)?;
256

            
257
5140
    let expected_solutions: JsonValue =
258
5140
        sort_json_object(&serde_json::from_str(&expected_json_str)?, true);
259

            
260
5140
    Ok(expected_solutions)
261
5148
}
262

            
263
/// Reads a human-readable rule trace text file.
264
10280
pub fn read_human_rule_trace(
265
10280
    path: &str,
266
10280
    test_name: &str,
267
10280
    prefix: &str,
268
10280
    solver: &SolverFamily,
269
10280
) -> Result<Vec<String>, std::io::Error> {
270
10280
    let solver_name = solver.as_str();
271
10280
    let filename = format!("{path}/{test_name}-{solver_name}-{prefix}-rule-trace.txt");
272
10280
    let rules_trace: Vec<String> = read_with_path(filename)?
273
10280
        .lines()
274
10280
        .map(String::from)
275
10280
        .collect();
276

            
277
10280
    Ok(rules_trace)
278
10280
}
279

            
280
#[doc(hidden)]
281
pub fn normalize_solutions_for_comparison(
282
    input_solutions: &[BTreeMap<Name, Literal>],
283
) -> Vec<BTreeMap<Name, Literal>> {
284
    let mut normalized = input_solutions.to_vec();
285

            
286
    for solset in &mut normalized {
287
        // remove machine names
288
        let keys_to_remove: Vec<Name> = solset
289
            .keys()
290
            .filter(|k| matches!(k, Name::Machine(_)))
291
            .cloned()
292
            .collect();
293
        for k in keys_to_remove {
294
            solset.remove(&k);
295
        }
296

            
297
        let mut updates = vec![];
298
        for (k, v) in solset.clone() {
299
            if let Name::User(_) = k {
300
                match v {
301
                    Literal::Bool(true) => updates.push((k, Literal::Int(1))),
302
                    Literal::Bool(false) => updates.push((k, Literal::Int(0))),
303
                    Literal::Int(_) => {}
304
                    Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, _)) => {
305
                        // make all domains the same (this is just in the tester so the types dont
306
                        // actually matter)
307

            
308
                        let mut matrix =
309
                            AbstractLiteral::Matrix(elems, Moo::new(GroundDomain::Int(vec![])));
310
                        matrix = matrix.transform(&move |x: AbstractLiteral<Literal>| match x {
311
                            AbstractLiteral::Matrix(items, _) => {
312
                                let items = items
313
                                    .into_iter()
314
                                    .map(|x| match x {
315
                                        Literal::Bool(false) => Literal::Int(0),
316
                                        Literal::Bool(true) => Literal::Int(1),
317
                                        x => x,
318
                                    })
319
                                    .collect_vec();
320

            
321
                                AbstractLiteral::Matrix(items, Moo::new(GroundDomain::Int(vec![])))
322
                            }
323
                            x => x,
324
                        });
325
                        updates.push((k, Literal::AbstractLiteral(matrix)));
326
                    }
327
                    Literal::AbstractLiteral(AbstractLiteral::Tuple(elems)) => {
328
                        // just the same as matrix but with tuples instead
329
                        // only conversion needed is to convert bools to ints
330
                        let mut tuple = AbstractLiteral::Tuple(elems);
331
                        tuple = tuple.transform(
332
                            &(move |x: AbstractLiteral<Literal>| match x {
333
                                AbstractLiteral::Tuple(items) => {
334
                                    let items = items
335
                                        .into_iter()
336
                                        .map(|x| match x {
337
                                            Literal::Bool(false) => Literal::Int(0),
338
                                            Literal::Bool(true) => Literal::Int(1),
339
                                            x => x,
340
                                        })
341
                                        .collect_vec();
342

            
343
                                    AbstractLiteral::Tuple(items)
344
                                }
345
                                x => x,
346
                            }),
347
                        );
348
                        updates.push((k, Literal::AbstractLiteral(tuple)));
349
                    }
350
                    Literal::AbstractLiteral(AbstractLiteral::Record(entries)) => {
351
                        // just the same as matrix but with tuples instead
352
                        // only conversion needed is to convert bools to ints
353
                        let mut record = AbstractLiteral::Record(entries);
354
                        record = record.transform(&move |x: AbstractLiteral<Literal>| match x {
355
                            AbstractLiteral::Record(entries) => {
356
                                let entries = entries
357
                                    .into_iter()
358
                                    .map(|x| {
359
                                        let RecordValue { name, value } = x;
360
                                        {
361
                                            let value = match value {
362
                                                Literal::Bool(false) => Literal::Int(0),
363
                                                Literal::Bool(true) => Literal::Int(1),
364
                                                x => x,
365
                                            };
366
                                            RecordValue { name, value }
367
                                        }
368
                                    })
369
                                    .collect_vec();
370

            
371
                                AbstractLiteral::Record(entries)
372
                            }
373
                            x => x,
374
                        });
375
                        updates.push((k, Literal::AbstractLiteral(record)));
376
                    }
377
                    Literal::AbstractLiteral(AbstractLiteral::Set(members)) => {
378
                        let set = AbstractLiteral::Set(members).transform(&move |x| match x {
379
                            AbstractLiteral::Set(members) => {
380
                                let members = members
381
                                    .into_iter()
382
                                    .map(|x| match x {
383
                                        Literal::Bool(false) => Literal::Int(0),
384
                                        Literal::Bool(true) => Literal::Int(1),
385
                                        x => x,
386
                                    })
387
                                    .collect_vec();
388

            
389
                                AbstractLiteral::Set(members)
390
                            }
391
                            x => x,
392
                        });
393
                        updates.push((k, Literal::AbstractLiteral(set)));
394
                    }
395
                    e => bug!("unexpected literal type: {e:?}"),
396
                }
397
            }
398
        }
399

            
400
        for (k, v) in updates {
401
            solset.insert(k, v);
402
        }
403
    }
404

            
405
    // Remove duplicates
406
    normalized = normalized.into_iter().unique().collect();
407
    normalized
408
}
409

            
410
fn maybe_truncate_serialised_json(serialised: String, test_stage: &str) -> String {
411
    if test_stage == "rewrite" {
412
        truncate_to_first_lines(&serialised, REWRITE_SERIALISED_JSON_MAX_LINES)
413
    } else {
414
        serialised
415
    }
416
}
417

            
418
/// Keeps at most the first `max_lines` lines of `content`, joined with `\n`.
///
/// Line terminators (`\n` or `\r\n`) are stripped by `str::lines`, and no trailing newline is
/// appended.
fn truncate_to_first_lines(content: &str, max_lines: usize) -> String {
    let kept: Vec<&str> = content.lines().take(max_lines).collect();
    kept.join("\n")
}
421

            
422
/// Reads at most `n` lines from the start of `filename`, joined with `\n`.
///
/// Line terminators (`\n` or `\r\n`) are stripped and no trailing newline is added. An empty
/// file, or `n == 0`, yields an empty string. Lines after the first `n` are never read.
///
/// # Errors
///
/// Returns an error if the file cannot be opened (the message includes the path), or if
/// reading any of the first `n` lines fails (e.g. invalid UTF-8).
fn read_first_n_lines<P: AsRef<Path>>(filename: P, n: usize) -> io::Result<String> {
    let path = filename.as_ref();
    let file = File::open(path)
        .map_err(|e| io::Error::new(e.kind(), format!("{} (path: {})", e, path.display())))?;

    // `take` stops after `n` lines; a short or empty file simply yields fewer lines.
    let lines = BufReader::new(file)
        .lines()
        .take(n)
        .collect::<io::Result<Vec<_>>>()?;
    Ok(lines.join("\n"))
}