1
use std::collections::{BTreeMap, HashMap, HashSet};
2
use std::fmt::Debug;
3
use std::path::Path;
4
use std::{io, mem, vec};
5

            
6
use conjure_cp::ast::records::RecordValue;
7
use conjure_cp::ast::serde::ObjId;
8
use conjure_cp::bug;
9
use itertools::Itertools as _;
10
use std::fs::File;
11
use std::hash::Hash;
12
use std::io::{BufRead, BufReader, Write};
13
use std::sync::{Arc, RwLock};
14
use uniplate::Uniplate;
15

            
16
use conjure_cp::ast::{AbstractLiteral, GroundDomain, Moo, SerdeModel};
17
use conjure_cp::context::Context;
18
use serde_json::{Error as JsonError, Value as JsonValue};
19

            
20
use conjure_cp::error::Error;
21

            
22
use crate::utils::conjure::solutions_to_json;
23
use crate::utils::json::sort_json_object;
24
use crate::utils::misc::to_set;
25
use conjure_cp::Model as ConjureModel;
26
use conjure_cp::ast::Name::User;
27
use conjure_cp::ast::{Literal, Name};
28
use conjure_cp::settings::SolverFamily;
29

            
30
/// Limit how many lines of the rewrite serialisation we persist/compare in integration tests.
31
pub const REWRITE_SERIALISED_JSON_MAX_LINES: usize = 1000;
32

            
33
/// Converts a SerdeModel to JSON with stable IDs.
34
///
35
/// This ensures that the same model structure always produces the same IDs,
36
/// regardless of the order in which objects were created in memory.
37
122
fn model_to_json_with_stable_ids(model: &SerdeModel) -> Result<JsonValue, JsonError> {
38
    // Collect stable ID mapping using uniplate traversal on the SerdeModel
39
122
    let id_map = model.collect_stable_id_mapping();
40

            
41
    // Serialize the model to JSON
42
122
    let mut json = serde_json::to_value(model)?;
43

            
44
    // Replace all IDs in the JSON with their stable counterparts
45
122
    replace_ids(&mut json, &id_map);
46

            
47
122
    Ok(json)
48
122
}
49

            
50
/// Recursively replaces all IDs in the JSON with their stable counterparts.
51
///
52
/// This is applied to all fields that are called "id" or "ptr" - be mindful
53
/// of potential naming clashes in the future!
54
8052
fn replace_ids(value: &mut JsonValue, id_map: &HashMap<ObjId, ObjId>) {
55
8052
    match value {
56
3422
        JsonValue::Object(map) => {
57
            // Replace IDs in three places:
58
            // - "id" fields (SymbolTable IDs)
59
            // - "parent" fields (SymbolTable nesting)
60
            // - "ptr" fields (DeclarationPtr IDs)
61
5510
            for (k, v) in map.iter_mut() {
62
5510
                if (k == "id" || k == "ptr" || k == "parent")
63
480
                    && let Ok(old_id) = serde_json::from_value::<ObjId>(mem::take(v))
64
358
                {
65
358
                    let new_id = id_map.get(&old_id).expect("all ids to be in the id map");
66
358
                    *v = serde_json::to_value(new_id)
67
358
                        .expect("serialization of an ObjId to always succeed");
68
5152
                }
69
            }
70

            
71
            // Recursively process all values
72
5510
            for val in map.values_mut() {
73
5510
                replace_ids(val, id_map);
74
5510
            }
75
        }
76
1504
        JsonValue::Array(arr) => {
77
2420
            for item in arr {
78
2420
                replace_ids(item, id_map);
79
2420
            }
80
        }
81
3126
        _ => {}
82
    }
83
8052
}
84

            
85
pub fn assert_eq_any_order<T: Eq + Hash + Debug + Clone>(a: &Vec<Vec<T>>, b: &Vec<Vec<T>>) {
86
    assert_eq!(a.len(), b.len());
87

            
88
    let mut a_rows: Vec<HashSet<T>> = Vec::new();
89
    for row in a {
90
        let hash_row = to_set(row);
91
        a_rows.push(hash_row);
92
    }
93

            
94
    let mut b_rows: Vec<HashSet<T>> = Vec::new();
95
    for row in b {
96
        let hash_row = to_set(row);
97
        b_rows.push(hash_row);
98
    }
99

            
100
    for row in a_rows {
101
        assert!(b_rows.contains(&row));
102
    }
103
}
104

            
105
122
pub fn serialize_model(model: &ConjureModel) -> Result<String, JsonError> {
106
122
    let serde_model: SerdeModel = model.clone().into();
107

            
108
    // Convert to JSON with stable IDs
109
122
    let json_with_stable_ids = model_to_json_with_stable_ids(&serde_model)?;
110

            
111
    // Sort JSON object keys for consistent output
112
122
    let sorted_json = sort_json_object(&json_with_stable_ids, false);
113

            
114
    // Serialize to pretty-printed string
115
122
    serde_json::to_string_pretty(&sorted_json)
116
122
}
117

            
118
120
pub fn save_model_json(
119
120
    model: &ConjureModel,
120
120
    path: &str,
121
120
    test_name: &str,
122
120
    test_stage: &str,
123
120
    solver: Option<SolverFamily>,
124
120
) -> Result<(), std::io::Error> {
125
120
    let marker = solver.map_or("agnostic", |s| s.as_str());
126
120
    let generated_json_str = serialize_model(model)?;
127
120
    let generated_json_str = maybe_truncate_serialised_json(generated_json_str, test_stage);
128
120
    let filename = format!("{path}/{test_name}-{marker}.generated-{test_stage}.serialised.json");
129
120
    println!("saving: {}", filename);
130
120
    File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
131
120
    Ok(())
132
120
}
133

            
134
1116
pub fn save_stats_json(
135
1116
    context: Arc<RwLock<Context<'static>>>,
136
1116
    path: &str,
137
1116
    test_name: &str,
138
1116
    solver: SolverFamily,
139
1116
) -> Result<(), std::io::Error> {
140
    #[allow(clippy::unwrap_used)]
141
1116
    let solver_name = solver.as_str();
142

            
143
1116
    let stats = context.read().unwrap().clone();
144
1116
    let generated_json = sort_json_object(&serde_json::to_value(stats)?, false);
145

            
146
    // serialise to string
147
1116
    let generated_json_str = serde_json::to_string_pretty(&generated_json)?;
148

            
149
1116
    File::create(format!("{path}/{test_name}-{solver_name}-stats.json"))?
150
1116
        .write_all(generated_json_str.as_bytes())?;
151

            
152
1116
    Ok(())
153
1116
}
154

            
155
/// Reads a file into a `String`, providing a clearer error message that includes the file path.
156
3348
fn read_with_path(path: String) -> Result<String, std::io::Error> {
157
3348
    std::fs::read_to_string(&path)
158
3348
        .map_err(|e| io::Error::new(e.kind(), format!("{} (path: {})", e, path)))
159
3348
}
160

            
161
240
pub fn read_model_json(
162
240
    ctx: &Arc<RwLock<Context<'static>>>,
163
240
    path: &str,
164
240
    test_name: &str,
165
240
    prefix: &str,
166
240
    test_stage: &str,
167
240
    solver: Option<SolverFamily>,
168
240
) -> Result<ConjureModel, std::io::Error> {
169
240
    let marker = solver.map_or("agnostic", |s| s.as_str());
170
240
    let new_filepath = format!("{path}/{test_name}-{marker}.{prefix}-{test_stage}.serialised.json");
171
240
    let old_filepath = format!("{path}/{marker}-{test_name}.{prefix}-{test_stage}.serialised.json");
172
240
    let filepath = if Path::new(&new_filepath).exists() {
173
120
        new_filepath
174
    } else {
175
120
        old_filepath
176
    };
177
240
    let expected_json_str = std::fs::read_to_string(filepath)?;
178
240
    let expected_model: SerdeModel = serde_json::from_str(&expected_json_str)?;
179

            
180
240
    Ok(expected_model.initialise(ctx.clone()).unwrap())
181
240
}
182

            
183
/// Reads only the first `max_lines` from a serialised model JSON file.
184
pub fn read_model_json_prefix(
185
    path: &str,
186
    test_name: &str,
187
    prefix: &str,
188
    test_stage: &str,
189
    solver: Option<SolverFamily>,
190
    max_lines: usize,
191
) -> Result<String, std::io::Error> {
192
    let marker = solver.map_or("agnostic", |s| s.as_str());
193
    let new_filename = format!("{path}/{test_name}-{marker}.{prefix}-{test_stage}.serialised.json");
194
    let old_filename = format!("{path}/{marker}-{test_name}.{prefix}-{test_stage}.serialised.json");
195
    let filename = if Path::new(&new_filename).exists() {
196
        new_filename
197
    } else {
198
        old_filename
199
    };
200
    println!("reading: {}", filename);
201
    read_first_n_lines(filename, max_lines)
202
}
203

            
204
pub fn minion_solutions_from_json(
205
    serialized: &str,
206
) -> Result<Vec<HashMap<Name, Literal>>, anyhow::Error> {
207
    let json: JsonValue = serde_json::from_str(serialized)?;
208

            
209
    let json_array = json
210
        .as_array()
211
        .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
212

            
213
    let mut solutions = Vec::new();
214

            
215
    for solution in json_array {
216
        let mut sol = HashMap::new();
217
        let solution = solution
218
            .as_object()
219
            .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
220

            
221
        for (var_name, constant) in solution {
222
            let constant = match constant {
223
                JsonValue::Number(n) => {
224
                    let n = n
225
                        .as_i64()
226
                        .ok_or(Error::Parse("Invalid integer".to_owned()))?;
227
                    Literal::Int(n as i32)
228
                }
229
                JsonValue::Bool(b) => Literal::Bool(*b),
230
                _ => return Err(Error::Parse("Invalid constant".to_owned()).into()),
231
            };
232

            
233
            sol.insert(User(var_name.into()), constant);
234
        }
235

            
236
        solutions.push(sol);
237
    }
238

            
239
    Ok(solutions)
240
}
241

            
242
/// Writes the minion solutions to a generated JSON file, and returns the JSON structure.
243
1116
pub fn save_solutions_json(
244
1116
    solutions: &Vec<BTreeMap<Name, Literal>>,
245
1116
    path: &str,
246
1116
    test_name: &str,
247
1116
    solver: SolverFamily,
248
1116
) -> Result<JsonValue, std::io::Error> {
249
1116
    let json_solutions = solutions_to_json(solutions);
250
1116
    let generated_json_str = serde_json::to_string_pretty(&json_solutions)?;
251

            
252
1116
    let solver_name = solver.as_str();
253
1116
    let filename = format!("{path}/{test_name}-{solver_name}.generated-solutions.json");
254
1116
    File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
255

            
256
1116
    Ok(json_solutions)
257
1116
}
258

            
259
1116
pub fn read_solutions_json(
260
1116
    path: &str,
261
1116
    test_name: &str,
262
1116
    prefix: &str,
263
1116
    solver: SolverFamily,
264
1116
) -> Result<JsonValue, anyhow::Error> {
265
1116
    let solver_name = match solver {
266
        SolverFamily::Sat(_) => "sat",
267
        #[cfg(feature = "smt")]
268
180
        SolverFamily::Smt(..) => "smt",
269
936
        SolverFamily::Minion => "minion",
270
    };
271
1116
    let new_filename = format!("{path}/{test_name}-{solver_name}.{prefix}-solutions.json");
272
1116
    let old_filename = format!("{path}/{solver_name}-{test_name}.{prefix}-solutions.json");
273
1116
    let expected_json_str = if Path::new(&new_filename).exists() {
274
1116
        read_with_path(new_filename)?
275
    } else {
276
        read_with_path(old_filename)?
277
    };
278

            
279
1116
    let expected_solutions: JsonValue =
280
1116
        sort_json_object(&serde_json::from_str(&expected_json_str)?, true);
281

            
282
1116
    Ok(expected_solutions)
283
1116
}
284

            
285
/// Reads a human-readable rule trace text file.
286
2232
pub fn read_human_rule_trace(
287
2232
    path: &str,
288
2232
    test_name: &str,
289
2232
    prefix: &str,
290
2232
    solver: &SolverFamily,
291
2232
) -> Result<Vec<String>, std::io::Error> {
292
2232
    let solver_name = solver.as_str();
293
2232
    let new_filename = format!("{path}/{test_name}-{solver_name}-{prefix}-rule-trace.txt");
294
2232
    let old_filename = format!("{path}/{solver_name}-{test_name}-{prefix}-rule-trace.txt");
295
2232
    let filename = if Path::new(&new_filename).exists() {
296
2232
        new_filename
297
    } else {
298
        old_filename
299
    };
300
2232
    let rules_trace: Vec<String> = read_with_path(filename)?
301
2232
        .lines()
302
2232
        .map(String::from)
303
2232
        .collect();
304

            
305
2232
    Ok(rules_trace)
306
2232
}
307

            
308
#[doc(hidden)]
309
pub fn normalize_solutions_for_comparison(
310
    input_solutions: &[BTreeMap<Name, Literal>],
311
) -> Vec<BTreeMap<Name, Literal>> {
312
    let mut normalized = input_solutions.to_vec();
313

            
314
    for solset in &mut normalized {
315
        // remove machine names
316
        let keys_to_remove: Vec<Name> = solset
317
            .keys()
318
            .filter(|k| matches!(k, Name::Machine(_)))
319
            .cloned()
320
            .collect();
321
        for k in keys_to_remove {
322
            solset.remove(&k);
323
        }
324

            
325
        let mut updates = vec![];
326
        for (k, v) in solset.clone() {
327
            if let Name::User(_) = k {
328
                match v {
329
                    Literal::Bool(true) => updates.push((k, Literal::Int(1))),
330
                    Literal::Bool(false) => updates.push((k, Literal::Int(0))),
331
                    Literal::Int(_) => {}
332
                    Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, _)) => {
333
                        // make all domains the same (this is just in the tester so the types dont
334
                        // actually matter)
335

            
336
                        let mut matrix =
337
                            AbstractLiteral::Matrix(elems, Moo::new(GroundDomain::Int(vec![])));
338
                        matrix = matrix.transform(&move |x: AbstractLiteral<Literal>| match x {
339
                            AbstractLiteral::Matrix(items, _) => {
340
                                let items = items
341
                                    .into_iter()
342
                                    .map(|x| match x {
343
                                        Literal::Bool(false) => Literal::Int(0),
344
                                        Literal::Bool(true) => Literal::Int(1),
345
                                        x => x,
346
                                    })
347
                                    .collect_vec();
348

            
349
                                AbstractLiteral::Matrix(items, Moo::new(GroundDomain::Int(vec![])))
350
                            }
351
                            x => x,
352
                        });
353
                        updates.push((k, Literal::AbstractLiteral(matrix)));
354
                    }
355
                    Literal::AbstractLiteral(AbstractLiteral::Tuple(elems)) => {
356
                        // just the same as matrix but with tuples instead
357
                        // only conversion needed is to convert bools to ints
358
                        let mut tuple = AbstractLiteral::Tuple(elems);
359
                        tuple = tuple.transform(
360
                            &(move |x: AbstractLiteral<Literal>| match x {
361
                                AbstractLiteral::Tuple(items) => {
362
                                    let items = items
363
                                        .into_iter()
364
                                        .map(|x| match x {
365
                                            Literal::Bool(false) => Literal::Int(0),
366
                                            Literal::Bool(true) => Literal::Int(1),
367
                                            x => x,
368
                                        })
369
                                        .collect_vec();
370

            
371
                                    AbstractLiteral::Tuple(items)
372
                                }
373
                                x => x,
374
                            }),
375
                        );
376
                        updates.push((k, Literal::AbstractLiteral(tuple)));
377
                    }
378
                    Literal::AbstractLiteral(AbstractLiteral::Record(entries)) => {
379
                        // just the same as matrix but with tuples instead
380
                        // only conversion needed is to convert bools to ints
381
                        let mut record = AbstractLiteral::Record(entries);
382
                        record = record.transform(&move |x: AbstractLiteral<Literal>| match x {
383
                            AbstractLiteral::Record(entries) => {
384
                                let entries = entries
385
                                    .into_iter()
386
                                    .map(|x| {
387
                                        let RecordValue { name, value } = x;
388
                                        {
389
                                            let value = match value {
390
                                                Literal::Bool(false) => Literal::Int(0),
391
                                                Literal::Bool(true) => Literal::Int(1),
392
                                                x => x,
393
                                            };
394
                                            RecordValue { name, value }
395
                                        }
396
                                    })
397
                                    .collect_vec();
398

            
399
                                AbstractLiteral::Record(entries)
400
                            }
401
                            x => x,
402
                        });
403
                        updates.push((k, Literal::AbstractLiteral(record)));
404
                    }
405
                    Literal::AbstractLiteral(AbstractLiteral::Set(members)) => {
406
                        let set = AbstractLiteral::Set(members).transform(&move |x| match x {
407
                            AbstractLiteral::Set(members) => {
408
                                let members = members
409
                                    .into_iter()
410
                                    .map(|x| match x {
411
                                        Literal::Bool(false) => Literal::Int(0),
412
                                        Literal::Bool(true) => Literal::Int(1),
413
                                        x => x,
414
                                    })
415
                                    .collect_vec();
416

            
417
                                AbstractLiteral::Set(members)
418
                            }
419
                            x => x,
420
                        });
421
                        updates.push((k, Literal::AbstractLiteral(set)));
422
                    }
423
                    e => bug!("unexpected literal type: {e:?}"),
424
                }
425
            }
426
        }
427

            
428
        for (k, v) in updates {
429
            solset.insert(k, v);
430
        }
431
    }
432

            
433
    // Remove duplicates
434
    normalized = normalized.into_iter().unique().collect();
435
    normalized
436
}
437

            
438
120
fn maybe_truncate_serialised_json(serialised: String, test_stage: &str) -> String {
439
120
    if test_stage == "rewrite" {
440
        truncate_to_first_lines(&serialised, REWRITE_SERIALISED_JSON_MAX_LINES)
441
    } else {
442
120
        serialised
443
    }
444
120
}
445

            
446
fn truncate_to_first_lines(content: &str, max_lines: usize) -> String {
447
    content.lines().take(max_lines).join("\n")
448
}
449

            
450
fn read_first_n_lines<P: AsRef<Path>>(filename: P, n: usize) -> io::Result<String> {
451
    let reader = BufReader::new(File::open(&filename)?);
452
    let lines = reader
453
        .lines()
454
        .chunks(n)
455
        .into_iter()
456
        .next()
457
        .unwrap()
458
        .collect::<Result<Vec<_>, _>>()?;
459
    Ok(lines.join("\n"))
460
}