1
use std::collections::{BTreeMap, HashMap, HashSet};
2
use std::fmt::Debug;
3
use std::path::Path;
4
use std::{io, mem, vec};
5

            
6
use conjure_cp::ast::records::RecordValue;
7
use conjure_cp::ast::serde::ObjId;
8
use conjure_cp::bug;
9
use itertools::Itertools as _;
10
use std::fs::File;
11
use std::hash::Hash;
12
use std::io::{BufRead, BufReader, Write};
13
use std::sync::{Arc, RwLock};
14
use uniplate::Uniplate;
15

            
16
use conjure_cp::ast::{AbstractLiteral, GroundDomain, Moo, SerdeModel};
17
use conjure_cp::context::Context;
18
use serde_json::{Error as JsonError, Value as JsonValue};
19

            
20
use conjure_cp::error::Error;
21

            
22
use crate::utils::conjure::solutions_to_json;
23
use crate::utils::json::sort_json_object;
24
use crate::utils::misc::to_set;
25
use conjure_cp::Model as ConjureModel;
26
use conjure_cp::ast::Name::User;
27
use conjure_cp::ast::{Literal, Name};
28
use conjure_cp::settings::SolverFamily;
29

            
30
/// Limit how many lines of the rewrite serialisation we persist/compare in integration tests.
31
pub const REWRITE_SERIALISED_JSON_MAX_LINES: usize = 1000;
32

            
33
/// Converts a SerdeModel to JSON with stable IDs.
34
///
35
/// This ensures that the same model structure always produces the same IDs,
36
/// regardless of the order in which objects were created in memory.
37
123
fn model_to_json_with_stable_ids(model: &SerdeModel) -> Result<JsonValue, JsonError> {
38
    // Collect stable ID mapping using uniplate traversal on the SerdeModel
39
123
    let id_map = model.collect_stable_id_mapping();
40

            
41
    // Serialize the model to JSON
42
123
    let mut json = serde_json::to_value(model)?;
43

            
44
    // Replace all IDs in the JSON with their stable counterparts
45
123
    replace_ids(&mut json, &id_map);
46

            
47
123
    Ok(json)
48
123
}
49

            
50
/// Recursively replaces all IDs in the JSON with their stable counterparts.
51
///
52
/// This is applied to all fields that are called "id" or "ptr" - be mindful
53
/// of potential naming clashes in the future!
54
8110
fn replace_ids(value: &mut JsonValue, id_map: &HashMap<ObjId, ObjId>) {
55
8110
    match value {
56
3447
        JsonValue::Object(map) => {
57
            // Replace IDs in three places:
58
            // - "id" fields (SymbolTable IDs)
59
            // - "parent" fields (SymbolTable nesting)
60
            // - "ptr" fields (DeclarationPtr IDs)
61
5551
            for (k, v) in map.iter_mut() {
62
5551
                if (k == "id" || k == "ptr" || k == "parent")
63
484
                    && let Ok(old_id) = serde_json::from_value::<ObjId>(mem::take(v))
64
361
                {
65
361
                    let new_id = id_map.get(&old_id).expect("all ids to be in the id map");
66
361
                    *v = serde_json::to_value(new_id)
67
361
                        .expect("serialization of an ObjId to always succeed");
68
5190
                }
69
            }
70

            
71
            // Recursively process all values
72
5551
            for val in map.values_mut() {
73
5551
                replace_ids(val, id_map);
74
5551
            }
75
        }
76
1514
        JsonValue::Array(arr) => {
77
2436
            for item in arr {
78
2436
                replace_ids(item, id_map);
79
2436
            }
80
        }
81
3149
        _ => {}
82
    }
83
8110
}
84

            
85
pub fn assert_eq_any_order<T: Eq + Hash + Debug + Clone>(a: &Vec<Vec<T>>, b: &Vec<Vec<T>>) {
86
    assert_eq!(a.len(), b.len());
87

            
88
    let mut a_rows: Vec<HashSet<T>> = Vec::new();
89
    for row in a {
90
        let hash_row = to_set(row);
91
        a_rows.push(hash_row);
92
    }
93

            
94
    let mut b_rows: Vec<HashSet<T>> = Vec::new();
95
    for row in b {
96
        let hash_row = to_set(row);
97
        b_rows.push(hash_row);
98
    }
99

            
100
    for row in a_rows {
101
        assert!(b_rows.contains(&row));
102
    }
103
}
104

            
105
123
pub fn serialize_model(model: &ConjureModel) -> Result<String, JsonError> {
106
123
    let serde_model: SerdeModel = model.clone().into();
107

            
108
    // Convert to JSON with stable IDs
109
123
    let json_with_stable_ids = model_to_json_with_stable_ids(&serde_model)?;
110

            
111
    // Sort JSON object keys for consistent output
112
123
    let sorted_json = sort_json_object(&json_with_stable_ids, false);
113

            
114
    // Serialize to pretty-printed string
115
123
    serde_json::to_string_pretty(&sorted_json)
116
123
}
117

            
118
pub fn save_model_json(
119
    model: &ConjureModel,
120
    path: &str,
121
    test_name: &str,
122
    test_stage: &str,
123
    solver: Option<SolverFamily>,
124
) -> Result<(), std::io::Error> {
125
    let marker = solver.map_or("agnostic", |s| s.as_str());
126
    let generated_json_str = serialize_model(model)?;
127
    let generated_json_str = maybe_truncate_serialised_json(generated_json_str, test_stage);
128
    let filename = format!("{path}/{test_name}-{marker}.generated-{test_stage}.serialised.json");
129
    println!("saving: {}", filename);
130
    File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
131
    Ok(())
132
}
133

            
134
2524
pub fn save_stats_json(
135
2524
    context: Arc<RwLock<Context<'static>>>,
136
2524
    path: &str,
137
2524
    test_name: &str,
138
2524
    solver: SolverFamily,
139
2524
) -> Result<(), std::io::Error> {
140
    #[allow(clippy::unwrap_used)]
141
2524
    let solver_name = solver.as_str();
142

            
143
2524
    let stats = context.read().unwrap().clone();
144
2524
    let generated_json = sort_json_object(&serde_json::to_value(stats)?, false);
145

            
146
    // serialise to string
147
2524
    let generated_json_str = serde_json::to_string_pretty(&generated_json)?;
148

            
149
2524
    File::create(format!("{path}/{test_name}-{solver_name}-stats.json"))?
150
2524
        .write_all(generated_json_str.as_bytes())?;
151

            
152
2524
    Ok(())
153
2524
}
154

            
155
/// Reads a file into a `String`, providing a clearer error message that includes the file path.
156
7584
fn read_with_path(path: String) -> Result<String, std::io::Error> {
157
7584
    std::fs::read_to_string(&path)
158
7584
        .map_err(|e| io::Error::new(e.kind(), format!("{} (path: {})", e, path)))
159
7584
}
160

            
161
pub fn read_model_json(
162
    ctx: &Arc<RwLock<Context<'static>>>,
163
    path: &str,
164
    test_name: &str,
165
    prefix: &str,
166
    test_stage: &str,
167
    solver: Option<SolverFamily>,
168
) -> Result<ConjureModel, std::io::Error> {
169
    let marker = solver.map_or("agnostic", |s| s.as_str());
170
    let new_filepath = format!("{path}/{test_name}-{marker}.{prefix}-{test_stage}.serialised.json");
171
    let old_filepath = format!("{path}/{marker}-{test_name}.{prefix}-{test_stage}.serialised.json");
172
    let filepath = if Path::new(&new_filepath).exists() {
173
        new_filepath
174
    } else {
175
        old_filepath
176
    };
177
    let expected_json_str = std::fs::read_to_string(filepath)?;
178
    let expected_model: SerdeModel = serde_json::from_str(&expected_json_str)?;
179

            
180
    Ok(expected_model.initialise(ctx.clone()).unwrap())
181
}
182

            
183
/// Reads only the first `max_lines` from a serialised model JSON file.
184
pub fn read_model_json_prefix(
185
    path: &str,
186
    test_name: &str,
187
    prefix: &str,
188
    test_stage: &str,
189
    solver: Option<SolverFamily>,
190
    max_lines: usize,
191
) -> Result<String, std::io::Error> {
192
    let marker = solver.map_or("agnostic", |s| s.as_str());
193
    let new_filename = format!("{path}/{test_name}-{marker}.{prefix}-{test_stage}.serialised.json");
194
    let old_filename = format!("{path}/{marker}-{test_name}.{prefix}-{test_stage}.serialised.json");
195
    let filename = if Path::new(&new_filename).exists() {
196
        new_filename
197
    } else {
198
        old_filename
199
    };
200
    println!("reading: {}", filename);
201
    read_first_n_lines(filename, max_lines)
202
}
203

            
204
pub fn minion_solutions_from_json(
205
    serialized: &str,
206
) -> Result<Vec<HashMap<Name, Literal>>, anyhow::Error> {
207
    let json: JsonValue = serde_json::from_str(serialized)?;
208

            
209
    let json_array = json
210
        .as_array()
211
        .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
212

            
213
    let mut solutions = Vec::new();
214

            
215
    for solution in json_array {
216
        let mut sol = HashMap::new();
217
        let solution = solution
218
            .as_object()
219
            .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
220

            
221
        for (var_name, constant) in solution {
222
            let constant = match constant {
223
                JsonValue::Number(n) => {
224
                    let n = n
225
                        .as_i64()
226
                        .ok_or(Error::Parse("Invalid integer".to_owned()))?;
227
                    Literal::Int(n as i32)
228
                }
229
                JsonValue::Bool(b) => Literal::Bool(*b),
230
                _ => return Err(Error::Parse("Invalid constant".to_owned()).into()),
231
            };
232

            
233
            sol.insert(User(var_name.into()), constant);
234
        }
235

            
236
        solutions.push(sol);
237
    }
238

            
239
    Ok(solutions)
240
}
241

            
242
/// Writes the minion solutions to a generated JSON file, and returns the JSON structure.
243
2528
pub fn save_solutions_json(
244
2528
    solutions: &Vec<BTreeMap<Name, Literal>>,
245
2528
    path: &str,
246
2528
    test_name: &str,
247
2528
    solver: SolverFamily,
248
2528
) -> Result<JsonValue, std::io::Error> {
249
2528
    let json_solutions = solutions_to_json(solutions);
250
2528
    let generated_json_str = serde_json::to_string_pretty(&json_solutions)?;
251

            
252
2528
    let solver_name = solver.as_str();
253
2528
    let filename = format!("{path}/{test_name}-{solver_name}.generated-solutions.json");
254
2528
    File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
255

            
256
2528
    Ok(json_solutions)
257
2528
}
258

            
259
2528
pub fn read_solutions_json(
260
2528
    path: &str,
261
2528
    test_name: &str,
262
2528
    prefix: &str,
263
2528
    solver: SolverFamily,
264
2528
) -> Result<JsonValue, anyhow::Error> {
265
2528
    let solver_name = match solver {
266
8
        SolverFamily::Sat(_) => "sat",
267
        #[cfg(feature = "smt")]
268
360
        SolverFamily::Smt(..) => "smt",
269
2160
        SolverFamily::Minion => "minion",
270
    };
271
2528
    let new_filename = format!("{path}/{test_name}-{solver_name}.{prefix}-solutions.json");
272
2528
    let old_filename = format!("{path}/{solver_name}-{test_name}.{prefix}-solutions.json");
273
2528
    let expected_json_str = if Path::new(&new_filename).exists() {
274
2528
        read_with_path(new_filename)?
275
    } else {
276
        read_with_path(old_filename)?
277
    };
278

            
279
2528
    let expected_solutions: JsonValue =
280
2528
        sort_json_object(&serde_json::from_str(&expected_json_str)?, true);
281

            
282
2528
    Ok(expected_solutions)
283
2528
}
284

            
285
/// Reads a human-readable rule trace text file.
286
5056
pub fn read_human_rule_trace(
287
5056
    path: &str,
288
5056
    test_name: &str,
289
5056
    prefix: &str,
290
5056
    solver: &SolverFamily,
291
5056
) -> Result<Vec<String>, std::io::Error> {
292
5056
    let solver_name = solver.as_str();
293
5056
    let new_filename = format!("{path}/{test_name}-{solver_name}-{prefix}-rule-trace.txt");
294
5056
    let old_filename = format!("{path}/{solver_name}-{test_name}-{prefix}-rule-trace.txt");
295
5056
    let filename = if Path::new(&new_filename).exists() {
296
5056
        new_filename
297
    } else {
298
        old_filename
299
    };
300
5056
    let rules_trace: Vec<String> = read_with_path(filename)?
301
5056
        .lines()
302
5056
        .map(String::from)
303
5056
        .collect();
304

            
305
5056
    Ok(rules_trace)
306
5056
}
307

            
308
#[doc(hidden)]
309
pub fn normalize_solutions_for_comparison(
310
    input_solutions: &[BTreeMap<Name, Literal>],
311
) -> Vec<BTreeMap<Name, Literal>> {
312
    let mut normalized = input_solutions.to_vec();
313

            
314
    for solset in &mut normalized {
315
        // remove machine names
316
        let keys_to_remove: Vec<Name> = solset
317
            .keys()
318
            .filter(|k| matches!(k, Name::Machine(_)))
319
            .cloned()
320
            .collect();
321
        for k in keys_to_remove {
322
            solset.remove(&k);
323
        }
324

            
325
        let mut updates = vec![];
326
        for (k, v) in solset.clone() {
327
            if let Name::User(_) = k {
328
                match v {
329
                    Literal::Bool(true) => updates.push((k, Literal::Int(1))),
330
                    Literal::Bool(false) => updates.push((k, Literal::Int(0))),
331
                    Literal::Int(_) => {}
332
                    Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, _)) => {
333
                        // make all domains the same (this is just in the tester so the types dont
334
                        // actually matter)
335

            
336
                        let mut matrix =
337
                            AbstractLiteral::Matrix(elems, Moo::new(GroundDomain::Int(vec![])));
338
                        matrix = matrix.transform(&move |x: AbstractLiteral<Literal>| match x {
339
                            AbstractLiteral::Matrix(items, _) => {
340
                                let items = items
341
                                    .into_iter()
342
                                    .map(|x| match x {
343
                                        Literal::Bool(false) => Literal::Int(0),
344
                                        Literal::Bool(true) => Literal::Int(1),
345
                                        x => x,
346
                                    })
347
                                    .collect_vec();
348

            
349
                                AbstractLiteral::Matrix(items, Moo::new(GroundDomain::Int(vec![])))
350
                            }
351
                            x => x,
352
                        });
353
                        updates.push((k, Literal::AbstractLiteral(matrix)));
354
                    }
355
                    Literal::AbstractLiteral(AbstractLiteral::Tuple(elems)) => {
356
                        // just the same as matrix but with tuples instead
357
                        // only conversion needed is to convert bools to ints
358
                        let mut tuple = AbstractLiteral::Tuple(elems);
359
                        tuple = tuple.transform(
360
                            &(move |x: AbstractLiteral<Literal>| match x {
361
                                AbstractLiteral::Tuple(items) => {
362
                                    let items = items
363
                                        .into_iter()
364
                                        .map(|x| match x {
365
                                            Literal::Bool(false) => Literal::Int(0),
366
                                            Literal::Bool(true) => Literal::Int(1),
367
                                            x => x,
368
                                        })
369
                                        .collect_vec();
370

            
371
                                    AbstractLiteral::Tuple(items)
372
                                }
373
                                x => x,
374
                            }),
375
                        );
376
                        updates.push((k, Literal::AbstractLiteral(tuple)));
377
                    }
378
                    Literal::AbstractLiteral(AbstractLiteral::Record(entries)) => {
379
                        // just the same as matrix but with tuples instead
380
                        // only conversion needed is to convert bools to ints
381
                        let mut record = AbstractLiteral::Record(entries);
382
                        record = record.transform(&move |x: AbstractLiteral<Literal>| match x {
383
                            AbstractLiteral::Record(entries) => {
384
                                let entries = entries
385
                                    .into_iter()
386
                                    .map(|x| {
387
                                        let RecordValue { name, value } = x;
388
                                        {
389
                                            let value = match value {
390
                                                Literal::Bool(false) => Literal::Int(0),
391
                                                Literal::Bool(true) => Literal::Int(1),
392
                                                x => x,
393
                                            };
394
                                            RecordValue { name, value }
395
                                        }
396
                                    })
397
                                    .collect_vec();
398

            
399
                                AbstractLiteral::Record(entries)
400
                            }
401
                            x => x,
402
                        });
403
                        updates.push((k, Literal::AbstractLiteral(record)));
404
                    }
405
                    Literal::AbstractLiteral(AbstractLiteral::Set(members)) => {
406
                        let set = AbstractLiteral::Set(members).transform(&move |x| match x {
407
                            AbstractLiteral::Set(members) => {
408
                                let members = members
409
                                    .into_iter()
410
                                    .map(|x| match x {
411
                                        Literal::Bool(false) => Literal::Int(0),
412
                                        Literal::Bool(true) => Literal::Int(1),
413
                                        x => x,
414
                                    })
415
                                    .collect_vec();
416

            
417
                                AbstractLiteral::Set(members)
418
                            }
419
                            x => x,
420
                        });
421
                        updates.push((k, Literal::AbstractLiteral(set)));
422
                    }
423
                    e => bug!("unexpected literal type: {e:?}"),
424
                }
425
            }
426
        }
427

            
428
        for (k, v) in updates {
429
            solset.insert(k, v);
430
        }
431
    }
432

            
433
    // Remove duplicates
434
    normalized = normalized.into_iter().unique().collect();
435
    normalized
436
}
437

            
438
fn maybe_truncate_serialised_json(serialised: String, test_stage: &str) -> String {
439
    if test_stage == "rewrite" {
440
        truncate_to_first_lines(&serialised, REWRITE_SERIALISED_JSON_MAX_LINES)
441
    } else {
442
        serialised
443
    }
444
}
445

            
446
fn truncate_to_first_lines(content: &str, max_lines: usize) -> String {
447
    content.lines().take(max_lines).join("\n")
448
}
449

            
450
fn read_first_n_lines<P: AsRef<Path>>(filename: P, n: usize) -> io::Result<String> {
451
    let reader = BufReader::new(File::open(&filename)?);
452
    let lines = reader
453
        .lines()
454
        .chunks(n)
455
        .into_iter()
456
        .next()
457
        .unwrap()
458
        .collect::<Result<Vec<_>, _>>()?;
459
    Ok(lines.join("\n"))
460
}