1
use std::collections::{BTreeMap, HashMap, HashSet};
2
use std::fmt::Debug;
3
use std::path::Path;
4
use std::{io, mem, vec};
5

            
6
use conjure_cp::ast::records::RecordValue;
7
use conjure_cp::ast::serde::ObjId;
8
use conjure_cp::bug;
9
use itertools::Itertools as _;
10
use std::fs::File;
11
use std::hash::Hash;
12
use std::io::{BufRead, BufReader, Write};
13
use std::sync::{Arc, RwLock};
14
use uniplate::Uniplate;
15

            
16
use conjure_cp::ast::{AbstractLiteral, Expression, GroundDomain, Moo, SerdeModel};
17
use conjure_cp::context::Context;
18
use serde_json::{Error as JsonError, Value as JsonValue};
19

            
20
use conjure_cp::error::Error;
21

            
22
use crate::utils::conjure::solutions_to_json;
23
use crate::utils::json::sort_json_object;
24
use crate::utils::misc::to_set;
25
use conjure_cp::Model as ConjureModel;
26
use conjure_cp::ast::Name::User;
27
use conjure_cp::ast::{Literal, Name};
28
use conjure_cp::settings::SolverFamily;
29

            
30
/// Maximum number of lines of the `rewrite`-stage serialised model JSON that integration
/// tests write to disk and compare against.
pub const REWRITE_SERIALISED_JSON_MAX_LINES: usize = 1000;
32

            
33
/// Converts a SerdeModel to JSON with stable IDs.
34
///
35
/// This ensures that the same model structure always produces the same IDs,
36
/// regardless of the order in which objects were created in memory.
37
122
fn model_to_json_with_stable_ids(model: &SerdeModel) -> Result<JsonValue, JsonError> {
38
    // Collect stable ID mapping using uniplate traversal on the SerdeModel
39
122
    let id_map = model.collect_stable_id_mapping();
40

            
41
    // Serialize the model to JSON
42
122
    let mut json = serde_json::to_value(model)?;
43

            
44
    // Replace all IDs in the JSON with their stable counterparts
45
122
    replace_ids(&mut json, &id_map);
46

            
47
122
    Ok(json)
48
122
}
49

            
50
/// Recursively replaces all IDs in the JSON with their stable counterparts.
51
///
52
/// This is applied to all fields that are called "id" or "ptr" - be mindful
53
/// of potential naming clashes in the future!
54
8052
fn replace_ids(value: &mut JsonValue, id_map: &HashMap<ObjId, ObjId>) {
55
8052
    match value {
56
3422
        JsonValue::Object(map) => {
57
            // Replace IDs in three places:
58
            // - "id" fields (SymbolTable IDs)
59
            // - "parent" fields (SymbolTable nesting)
60
            // - "ptr" fields (DeclarationPtr IDs)
61
5510
            for (k, v) in map.iter_mut() {
62
5510
                if (k == "id" || k == "ptr" || k == "parent")
63
480
                    && let Ok(old_id) = serde_json::from_value::<ObjId>(mem::take(v))
64
358
                {
65
358
                    let new_id = id_map.get(&old_id).expect("all ids to be in the id map");
66
358
                    *v = serde_json::to_value(new_id)
67
358
                        .expect("serialization of an ObjId to always succeed");
68
5152
                }
69
            }
70

            
71
            // Recursively process all values
72
5510
            for val in map.values_mut() {
73
5510
                replace_ids(val, id_map);
74
5510
            }
75
        }
76
1504
        JsonValue::Array(arr) => {
77
2420
            for item in arr {
78
2420
                replace_ids(item, id_map);
79
2420
            }
80
        }
81
3126
        _ => {}
82
    }
83
8052
}
84

            
85
pub fn assert_eq_any_order<T: Eq + Hash + Debug + Clone>(a: &Vec<Vec<T>>, b: &Vec<Vec<T>>) {
86
    assert_eq!(a.len(), b.len());
87

            
88
    let mut a_rows: Vec<HashSet<T>> = Vec::new();
89
    for row in a {
90
        let hash_row = to_set(row);
91
        a_rows.push(hash_row);
92
    }
93

            
94
    let mut b_rows: Vec<HashSet<T>> = Vec::new();
95
    for row in b {
96
        let hash_row = to_set(row);
97
        b_rows.push(hash_row);
98
    }
99

            
100
    for row in a_rows {
101
        assert!(b_rows.contains(&row));
102
    }
103
}
104

            
105
122
pub fn serialize_model(model: &ConjureModel) -> Result<String, JsonError> {
106
122
    let serde_model: SerdeModel = model.clone().into();
107

            
108
    // Convert to JSON with stable IDs
109
122
    let json_with_stable_ids = model_to_json_with_stable_ids(&serde_model)?;
110

            
111
    // Sort JSON object keys for consistent output
112
122
    let sorted_json = sort_json_object(&json_with_stable_ids, false);
113

            
114
    // Serialize to pretty-printed string
115
122
    serde_json::to_string_pretty(&sorted_json)
116
122
}
117

            
118
12
pub fn serialize_domains(model: &ConjureModel) -> Result<String, JsonError> {
119
12
    let mut output = String::new();
120
12
    for constraint in model.constraints() {
121
12
        serialize_domains_expr(constraint, 0, &mut output);
122
12
    }
123
12
    Ok(output)
124
12
}
125

            
126
100
fn serialize_domains_expr(expr: &Expression, depth: usize, output: &mut String) {
127
100
    let domain = expr
128
100
        .domain_of()
129
100
        .map(|domain| domain.to_string())
130
100
        .unwrap_or_else(|| "<unknown>".to_owned());
131
100
    output.push_str(&" ".repeat(depth));
132
100
    output.push_str(&format!("{expr} :: {domain}\n"));
133

            
134
100
    for child in expr.children() {
135
88
        serialize_domains_expr(&child, depth + 1, output);
136
88
    }
137
100
}
138

            
139
pub fn save_model_json(
140
    model: &ConjureModel,
141
    path: &str,
142
    test_name: &str,
143
    test_stage: &str,
144
    solver: SolverFamily,
145
) -> Result<(), std::io::Error> {
146
    let marker = solver.as_str();
147
    let generated_json_str = serialize_model(model)?;
148
    let generated_json_str = maybe_truncate_serialised_json(generated_json_str, test_stage);
149
    let filename = format!("{path}/{test_name}-{marker}.generated-{test_stage}.serialised.json");
150
    println!("saving: {}", filename);
151
    File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
152
    Ok(())
153
}
154

            
155
1720
pub fn save_stats_json(
156
1720
    context: Arc<RwLock<Context<'static>>>,
157
1720
    path: &str,
158
1720
    test_name: &str,
159
1720
    solver: SolverFamily,
160
1720
) -> Result<(), std::io::Error> {
161
    #[allow(clippy::unwrap_used)]
162
1720
    let solver_name = solver.as_str();
163

            
164
1720
    let stats = context.read().unwrap().clone();
165
1720
    let generated_json = sort_json_object(&serde_json::to_value(stats)?, false);
166

            
167
    // serialise to string
168
1720
    let generated_json_str = serde_json::to_string_pretty(&generated_json)?;
169

            
170
1720
    File::create(format!("{path}/{test_name}-{solver_name}-stats.json"))?
171
1720
        .write_all(generated_json_str.as_bytes())?;
172

            
173
1720
    Ok(())
174
1720
}
175

            
176
/// Reads the entire file at `path` into a `String`.
///
/// On failure, the returned error keeps the original `io::ErrorKind`, and its message is
/// suffixed with `(path: ...)` so that the failing file is obvious.
fn read_with_path(path: String) -> Result<String, std::io::Error> {
    match std::fs::read_to_string(&path) {
        Ok(contents) => Ok(contents),
        Err(err) => Err(io::Error::new(
            err.kind(),
            format!("{} (path: {})", err, path),
        )),
    }
}
181

            
182
pub fn read_model_json(
183
    ctx: &Arc<RwLock<Context<'static>>>,
184
    path: &str,
185
    test_name: &str,
186
    prefix: &str,
187
    test_stage: &str,
188
    solver: SolverFamily,
189
) -> Result<ConjureModel, std::io::Error> {
190
    let marker = solver.as_str();
191
    let filepath = format!("{path}/{test_name}-{marker}.{prefix}-{test_stage}.serialised.json");
192
    let expected_json_str = std::fs::read_to_string(filepath)?;
193
    let expected_model: SerdeModel = serde_json::from_str(&expected_json_str)?;
194

            
195
    Ok(expected_model.initialise(ctx.clone()).unwrap())
196
}
197

            
198
/// Reads only the first `max_lines` from a serialised model JSON file.
199
pub fn read_model_json_prefix(
200
    path: &str,
201
    test_name: &str,
202
    prefix: &str,
203
    test_stage: &str,
204
    solver: SolverFamily,
205
    max_lines: usize,
206
) -> Result<String, std::io::Error> {
207
    let marker = solver.as_str();
208
    let filename = format!("{path}/{test_name}-{marker}.{prefix}-{test_stage}.serialised.json");
209
    println!("reading: {}", filename);
210
    read_first_n_lines(filename, max_lines)
211
}
212

            
213
pub fn minion_solutions_from_json(
214
    serialized: &str,
215
) -> Result<Vec<HashMap<Name, Literal>>, anyhow::Error> {
216
    let json: JsonValue = serde_json::from_str(serialized)?;
217

            
218
    let json_array = json
219
        .as_array()
220
        .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
221

            
222
    let mut solutions = Vec::new();
223

            
224
    for solution in json_array {
225
        let mut sol = HashMap::new();
226
        let solution = solution
227
            .as_object()
228
            .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
229

            
230
        for (var_name, constant) in solution {
231
            let constant = match constant {
232
                JsonValue::Number(n) => {
233
                    let n = n
234
                        .as_i64()
235
                        .ok_or(Error::Parse("Invalid integer".to_owned()))?;
236
                    Literal::Int(n as i32)
237
                }
238
                JsonValue::Bool(b) => Literal::Bool(*b),
239
                _ => return Err(Error::Parse("Invalid constant".to_owned()).into()),
240
            };
241

            
242
            sol.insert(User(var_name.into()), constant);
243
        }
244

            
245
        solutions.push(sol);
246
    }
247

            
248
    Ok(solutions)
249
}
250

            
251
/// Writes the minion solutions to a generated JSON file, and returns the JSON structure.
252
1720
pub fn save_solutions_json(
253
1720
    solutions: &Vec<BTreeMap<Name, Literal>>,
254
1720
    path: &str,
255
1720
    test_name: &str,
256
1720
    solver: SolverFamily,
257
1720
) -> Result<JsonValue, std::io::Error> {
258
1720
    let json_solutions = solutions_to_json(solutions);
259
1720
    let generated_json_str = serde_json::to_string_pretty(&json_solutions)?;
260

            
261
1720
    let solver_name = solver.as_str();
262
1720
    let filename = format!("{path}/{test_name}-{solver_name}.generated-solutions.json");
263
1720
    File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
264

            
265
1720
    Ok(json_solutions)
266
1720
}
267

            
268
1720
pub fn read_solutions_json(
269
1720
    path: &str,
270
1720
    test_name: &str,
271
1720
    prefix: &str,
272
1720
    solver: SolverFamily,
273
1720
) -> Result<JsonValue, anyhow::Error> {
274
1720
    let solver_name = solver.as_str();
275
1720
    let filename = format!("{path}/{test_name}-{solver_name}.{prefix}-solutions.json");
276
1720
    let expected_json_str = read_with_path(filename)?;
277

            
278
1720
    let expected_solutions: JsonValue =
279
1720
        sort_json_object(&serde_json::from_str(&expected_json_str)?, true);
280

            
281
1720
    Ok(expected_solutions)
282
1720
}
283

            
284
/// Reads a human-readable rule trace text file.
285
3440
pub fn read_human_rule_trace(
286
3440
    path: &str,
287
3440
    test_name: &str,
288
3440
    prefix: &str,
289
3440
    solver: &SolverFamily,
290
3440
) -> Result<Vec<String>, std::io::Error> {
291
3440
    let solver_name = solver.as_str();
292
3440
    let filename = format!("{path}/{test_name}-{solver_name}-{prefix}-rule-trace.txt");
293
3440
    let rules_trace: Vec<String> = read_with_path(filename)?
294
3440
        .lines()
295
3440
        .map(String::from)
296
3440
        .collect();
297

            
298
3440
    Ok(rules_trace)
299
3440
}
300

            
301
#[doc(hidden)]
302
pub fn normalize_solutions_for_comparison(
303
    input_solutions: &[BTreeMap<Name, Literal>],
304
) -> Vec<BTreeMap<Name, Literal>> {
305
    let mut normalized = input_solutions.to_vec();
306

            
307
    for solset in &mut normalized {
308
        // remove machine names
309
        let keys_to_remove: Vec<Name> = solset
310
            .keys()
311
            .filter(|k| matches!(k, Name::Machine(_)))
312
            .cloned()
313
            .collect();
314
        for k in keys_to_remove {
315
            solset.remove(&k);
316
        }
317

            
318
        let mut updates = vec![];
319
        for (k, v) in solset.clone() {
320
            if let Name::User(_) = k {
321
                match v {
322
                    Literal::Bool(true) => updates.push((k, Literal::Int(1))),
323
                    Literal::Bool(false) => updates.push((k, Literal::Int(0))),
324
                    Literal::Int(_) => {}
325
                    Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, _)) => {
326
                        // make all domains the same (this is just in the tester so the types dont
327
                        // actually matter)
328

            
329
                        let mut matrix =
330
                            AbstractLiteral::Matrix(elems, Moo::new(GroundDomain::Int(vec![])));
331
                        matrix = matrix.transform(&move |x: AbstractLiteral<Literal>| match x {
332
                            AbstractLiteral::Matrix(items, _) => {
333
                                let items = items
334
                                    .into_iter()
335
                                    .map(|x| match x {
336
                                        Literal::Bool(false) => Literal::Int(0),
337
                                        Literal::Bool(true) => Literal::Int(1),
338
                                        x => x,
339
                                    })
340
                                    .collect_vec();
341

            
342
                                AbstractLiteral::Matrix(items, Moo::new(GroundDomain::Int(vec![])))
343
                            }
344
                            x => x,
345
                        });
346
                        updates.push((k, Literal::AbstractLiteral(matrix)));
347
                    }
348
                    Literal::AbstractLiteral(AbstractLiteral::Tuple(elems)) => {
349
                        // just the same as matrix but with tuples instead
350
                        // only conversion needed is to convert bools to ints
351
                        let mut tuple = AbstractLiteral::Tuple(elems);
352
                        tuple = tuple.transform(
353
                            &(move |x: AbstractLiteral<Literal>| match x {
354
                                AbstractLiteral::Tuple(items) => {
355
                                    let items = items
356
                                        .into_iter()
357
                                        .map(|x| match x {
358
                                            Literal::Bool(false) => Literal::Int(0),
359
                                            Literal::Bool(true) => Literal::Int(1),
360
                                            x => x,
361
                                        })
362
                                        .collect_vec();
363

            
364
                                    AbstractLiteral::Tuple(items)
365
                                }
366
                                x => x,
367
                            }),
368
                        );
369
                        updates.push((k, Literal::AbstractLiteral(tuple)));
370
                    }
371
                    Literal::AbstractLiteral(AbstractLiteral::Record(entries)) => {
372
                        // just the same as matrix but with tuples instead
373
                        // only conversion needed is to convert bools to ints
374
                        let mut record = AbstractLiteral::Record(entries);
375
                        record = record.transform(&move |x: AbstractLiteral<Literal>| match x {
376
                            AbstractLiteral::Record(entries) => {
377
                                let entries = entries
378
                                    .into_iter()
379
                                    .map(|x| {
380
                                        let RecordValue { name, value } = x;
381
                                        {
382
                                            let value = match value {
383
                                                Literal::Bool(false) => Literal::Int(0),
384
                                                Literal::Bool(true) => Literal::Int(1),
385
                                                x => x,
386
                                            };
387
                                            RecordValue { name, value }
388
                                        }
389
                                    })
390
                                    .collect_vec();
391

            
392
                                AbstractLiteral::Record(entries)
393
                            }
394
                            x => x,
395
                        });
396
                        updates.push((k, Literal::AbstractLiteral(record)));
397
                    }
398
                    Literal::AbstractLiteral(AbstractLiteral::Set(members)) => {
399
                        let set = AbstractLiteral::Set(members).transform(&move |x| match x {
400
                            AbstractLiteral::Set(members) => {
401
                                let members = members
402
                                    .into_iter()
403
                                    .map(|x| match x {
404
                                        Literal::Bool(false) => Literal::Int(0),
405
                                        Literal::Bool(true) => Literal::Int(1),
406
                                        x => x,
407
                                    })
408
                                    .collect_vec();
409

            
410
                                AbstractLiteral::Set(members)
411
                            }
412
                            x => x,
413
                        });
414
                        updates.push((k, Literal::AbstractLiteral(set)));
415
                    }
416
                    e => bug!("unexpected literal type: {e:?}"),
417
                }
418
            }
419
        }
420

            
421
        for (k, v) in updates {
422
            solset.insert(k, v);
423
        }
424
    }
425

            
426
    // Remove duplicates
427
    normalized = normalized.into_iter().unique().collect();
428
    normalized
429
}
430

            
431
fn maybe_truncate_serialised_json(serialised: String, test_stage: &str) -> String {
432
    if test_stage == "rewrite" {
433
        truncate_to_first_lines(&serialised, REWRITE_SERIALISED_JSON_MAX_LINES)
434
    } else {
435
        serialised
436
    }
437
}
438

            
439
/// Keeps at most the first `max_lines` lines of `content`.
///
/// Lines are split with `str::lines` (so `\r\n` terminators become `\n`) and re-joined with
/// `\n`. No trailing newline is appended after the last kept line.
fn truncate_to_first_lines(content: &str, max_lines: usize) -> String {
    let kept: Vec<&str> = content.lines().take(max_lines).collect();
    kept.join("\n")
}
442

            
443
/// Reads at most the first `n` lines of `filename` and joins them with `\n`.
///
/// Line terminators are stripped by `BufRead::lines`, and none is appended after the last
/// line. An empty file, or `n == 0`, yields an empty string.
///
/// # Errors
/// Returns an error whose message includes the path if the file cannot be opened. Also
/// returns an error if reading one of the first `n` lines fails (e.g. invalid UTF-8).
fn read_first_n_lines<P: AsRef<Path>>(filename: P, n: usize) -> io::Result<String> {
    let path = filename.as_ref();
    let file = File::open(path)
        .map_err(|e| io::Error::new(e.kind(), format!("{} (path: {})", e, path.display())))?;

    // `take(n)` rather than `chunks(n).next().unwrap()`: the latter panics on an empty file
    // (there is no first chunk) and on `n == 0` (chunk size must be non-zero).
    let lines = BufReader::new(file)
        .lines()
        .take(n)
        .collect::<io::Result<Vec<_>>>()?;
    Ok(lines.join("\n"))
}