1
use std::collections::{BTreeMap, HashMap, HashSet};
2
use std::fmt::Debug;
3
use std::path::Path;
4
use std::{io, mem, vec};
5

            
6
use conjure_cp::ast::records::RecordValue;
7
use conjure_cp::ast::serde::ObjId;
8
use conjure_cp::bug;
9
use itertools::Itertools as _;
10
use std::fs::File;
11
use std::hash::Hash;
12
use std::io::{BufRead, BufReader, Write};
13
use std::sync::{Arc, RwLock};
14
use uniplate::Uniplate;
15

            
16
use conjure_cp::ast::{AbstractLiteral, ExprInfo, GroundDomain, Moo, SerdeModel};
17
use conjure_cp::context::Context;
18
use serde_json::{Error as JsonError, Value as JsonValue};
19

            
20
use conjure_cp::error::Error;
21

            
22
use crate::utils::conjure::solutions_to_json;
23
use crate::utils::json::sort_json_object;
24
use crate::utils::misc::to_set;
25
use conjure_cp::Model as ConjureModel;
26
use conjure_cp::ast::Name::User;
27
use conjure_cp::ast::{Literal, Name};
28
use conjure_cp::settings::SolverFamily;
29

            
30
/// Upper bound on how many lines of the rewrite-stage serialisation integration tests
/// persist to disk and compare against.
pub const REWRITE_SERIALISED_JSON_MAX_LINES: usize = 1_000;
32

            
33
/// Converts a SerdeModel to JSON with stable IDs.
34
///
35
/// This ensures that the same model structure always produces the same IDs,
36
/// regardless of the order in which objects were created in memory.
37
122
fn model_to_json_with_stable_ids(model: &SerdeModel) -> Result<JsonValue, JsonError> {
38
    // Collect stable ID mapping using uniplate traversal on the SerdeModel
39
122
    let id_map = model.collect_stable_id_mapping();
40

            
41
    // Serialize the model to JSON
42
122
    let mut json = serde_json::to_value(model)?;
43

            
44
    // Replace all IDs in the JSON with their stable counterparts
45
122
    replace_ids(&mut json, &id_map);
46

            
47
122
    Ok(json)
48
122
}
49

            
50
/// Recursively replaces all IDs in the JSON with their stable counterparts.
51
///
52
/// This is applied to all fields that are called "id" or "ptr" - be mindful
53
/// of potential naming clashes in the future!
54
8052
fn replace_ids(value: &mut JsonValue, id_map: &HashMap<ObjId, ObjId>) {
55
8052
    match value {
56
3422
        JsonValue::Object(map) => {
57
            // Replace IDs in three places:
58
            // - "id" fields (SymbolTable IDs)
59
            // - "parent" fields (SymbolTable nesting)
60
            // - "ptr" fields (DeclarationPtr IDs)
61
5510
            for (k, v) in map.iter_mut() {
62
5510
                if (k == "id" || k == "ptr" || k == "parent")
63
480
                    && let Ok(old_id) = serde_json::from_value::<ObjId>(mem::take(v))
64
358
                {
65
358
                    let new_id = id_map.get(&old_id).expect("all ids to be in the id map");
66
358
                    *v = serde_json::to_value(new_id)
67
358
                        .expect("serialization of an ObjId to always succeed");
68
5152
                }
69
            }
70

            
71
            // Recursively process all values
72
5510
            for val in map.values_mut() {
73
5510
                replace_ids(val, id_map);
74
5510
            }
75
        }
76
1504
        JsonValue::Array(arr) => {
77
2420
            for item in arr {
78
2420
                replace_ids(item, id_map);
79
2420
            }
80
        }
81
3126
        _ => {}
82
    }
83
8052
}
84

            
85
pub fn assert_eq_any_order<T: Eq + Hash + Debug + Clone>(a: &Vec<Vec<T>>, b: &Vec<Vec<T>>) {
86
    assert_eq!(a.len(), b.len());
87

            
88
    let mut a_rows: Vec<HashSet<T>> = Vec::new();
89
    for row in a {
90
        let hash_row = to_set(row);
91
        a_rows.push(hash_row);
92
    }
93

            
94
    let mut b_rows: Vec<HashSet<T>> = Vec::new();
95
    for row in b {
96
        let hash_row = to_set(row);
97
        b_rows.push(hash_row);
98
    }
99

            
100
    for row in a_rows {
101
        assert!(b_rows.contains(&row));
102
    }
103
}
104

            
105
122
pub fn serialize_model(model: &ConjureModel) -> Result<String, JsonError> {
106
122
    let serde_model: SerdeModel = model.clone().into();
107

            
108
    // Convert to JSON with stable IDs
109
122
    let json_with_stable_ids = model_to_json_with_stable_ids(&serde_model)?;
110

            
111
    // Sort JSON object keys for consistent output
112
122
    let sorted_json = sort_json_object(&json_with_stable_ids, false);
113

            
114
    // Serialize to pretty-printed string
115
122
    serde_json::to_string_pretty(&sorted_json)
116
122
}
117

            
118
pub fn serialize_domains(model: &ConjureModel) -> Result<String, JsonError> {
119
    let exprs: Vec<ExprInfo> = model.constraints().iter().map(ExprInfo::create).collect();
120
    serde_json::to_string_pretty(&exprs)
121
}
122

            
123
pub fn save_model_json(
124
    model: &ConjureModel,
125
    path: &str,
126
    test_name: &str,
127
    test_stage: &str,
128
    solver: SolverFamily,
129
) -> Result<(), std::io::Error> {
130
    let marker = solver.as_str();
131
    let generated_json_str = serialize_model(model)?;
132
    let generated_json_str = maybe_truncate_serialised_json(generated_json_str, test_stage);
133
    let filename = format!("{path}/{test_name}-{marker}.generated-{test_stage}.serialised.json");
134
    println!("saving: {}", filename);
135
    File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
136
    Ok(())
137
}
138

            
139
1720
pub fn save_stats_json(
140
1720
    context: Arc<RwLock<Context<'static>>>,
141
1720
    path: &str,
142
1720
    test_name: &str,
143
1720
    solver: SolverFamily,
144
1720
) -> Result<(), std::io::Error> {
145
    #[allow(clippy::unwrap_used)]
146
1720
    let solver_name = solver.as_str();
147

            
148
1720
    let stats = context.read().unwrap().clone();
149
1720
    let generated_json = sort_json_object(&serde_json::to_value(stats)?, false);
150

            
151
    // serialise to string
152
1720
    let generated_json_str = serde_json::to_string_pretty(&generated_json)?;
153

            
154
1720
    File::create(format!("{path}/{test_name}-{solver_name}-stats.json"))?
155
1720
        .write_all(generated_json_str.as_bytes())?;
156

            
157
1720
    Ok(())
158
1720
}
159

            
160
/// Reads the whole file at `path` into a `String`.
///
/// On failure the returned error keeps the original [`io::ErrorKind`], but its message is
/// suffixed with ` (path: <path>)` so that failures name the offending file.
fn read_with_path(path: String) -> Result<String, std::io::Error> {
    let contents = std::fs::read_to_string(&path).map_err(|err| {
        let kind = err.kind();
        io::Error::new(kind, format!("{} (path: {})", err, path))
    })?;
    Ok(contents)
}
165

            
166
pub fn read_model_json(
167
    ctx: &Arc<RwLock<Context<'static>>>,
168
    path: &str,
169
    test_name: &str,
170
    prefix: &str,
171
    test_stage: &str,
172
    solver: SolverFamily,
173
) -> Result<ConjureModel, std::io::Error> {
174
    let marker = solver.as_str();
175
    let filepath = format!("{path}/{test_name}-{marker}.{prefix}-{test_stage}.serialised.json");
176
    let expected_json_str = std::fs::read_to_string(filepath)?;
177
    let expected_model: SerdeModel = serde_json::from_str(&expected_json_str)?;
178

            
179
    Ok(expected_model.initialise(ctx.clone()).unwrap())
180
}
181

            
182
/// Reads only the first `max_lines` from a serialised model JSON file.
183
pub fn read_model_json_prefix(
184
    path: &str,
185
    test_name: &str,
186
    prefix: &str,
187
    test_stage: &str,
188
    solver: SolverFamily,
189
    max_lines: usize,
190
) -> Result<String, std::io::Error> {
191
    let marker = solver.as_str();
192
    let filename = format!("{path}/{test_name}-{marker}.{prefix}-{test_stage}.serialised.json");
193
    println!("reading: {}", filename);
194
    read_first_n_lines(filename, max_lines)
195
}
196

            
197
pub fn minion_solutions_from_json(
198
    serialized: &str,
199
) -> Result<Vec<HashMap<Name, Literal>>, anyhow::Error> {
200
    let json: JsonValue = serde_json::from_str(serialized)?;
201

            
202
    let json_array = json
203
        .as_array()
204
        .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
205

            
206
    let mut solutions = Vec::new();
207

            
208
    for solution in json_array {
209
        let mut sol = HashMap::new();
210
        let solution = solution
211
            .as_object()
212
            .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
213

            
214
        for (var_name, constant) in solution {
215
            let constant = match constant {
216
                JsonValue::Number(n) => {
217
                    let n = n
218
                        .as_i64()
219
                        .ok_or(Error::Parse("Invalid integer".to_owned()))?;
220
                    Literal::Int(n as i32)
221
                }
222
                JsonValue::Bool(b) => Literal::Bool(*b),
223
                _ => return Err(Error::Parse("Invalid constant".to_owned()).into()),
224
            };
225

            
226
            sol.insert(User(var_name.into()), constant);
227
        }
228

            
229
        solutions.push(sol);
230
    }
231

            
232
    Ok(solutions)
233
}
234

            
235
/// Writes the minion solutions to a generated JSON file, and returns the JSON structure.
236
1720
pub fn save_solutions_json(
237
1720
    solutions: &Vec<BTreeMap<Name, Literal>>,
238
1720
    path: &str,
239
1720
    test_name: &str,
240
1720
    solver: SolverFamily,
241
1720
) -> Result<JsonValue, std::io::Error> {
242
1720
    let json_solutions = solutions_to_json(solutions);
243
1720
    let generated_json_str = serde_json::to_string_pretty(&json_solutions)?;
244

            
245
1720
    let solver_name = solver.as_str();
246
1720
    let filename = format!("{path}/{test_name}-{solver_name}.generated-solutions.json");
247
1720
    File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
248

            
249
1720
    Ok(json_solutions)
250
1720
}
251

            
252
1720
pub fn read_solutions_json(
253
1720
    path: &str,
254
1720
    test_name: &str,
255
1720
    prefix: &str,
256
1720
    solver: SolverFamily,
257
1720
) -> Result<JsonValue, anyhow::Error> {
258
1720
    let solver_name = solver.as_str();
259
1720
    let filename = format!("{path}/{test_name}-{solver_name}.{prefix}-solutions.json");
260
1720
    let expected_json_str = read_with_path(filename)?;
261

            
262
1720
    let expected_solutions: JsonValue =
263
1720
        sort_json_object(&serde_json::from_str(&expected_json_str)?, true);
264

            
265
1720
    Ok(expected_solutions)
266
1720
}
267

            
268
/// Reads a human-readable rule trace text file.
269
3440
pub fn read_human_rule_trace(
270
3440
    path: &str,
271
3440
    test_name: &str,
272
3440
    prefix: &str,
273
3440
    solver: &SolverFamily,
274
3440
) -> Result<Vec<String>, std::io::Error> {
275
3440
    let solver_name = solver.as_str();
276
3440
    let filename = format!("{path}/{test_name}-{solver_name}-{prefix}-rule-trace.txt");
277
3440
    let rules_trace: Vec<String> = read_with_path(filename)?
278
3440
        .lines()
279
3440
        .map(String::from)
280
3440
        .collect();
281

            
282
3440
    Ok(rules_trace)
283
3440
}
284

            
285
#[doc(hidden)]
286
pub fn normalize_solutions_for_comparison(
287
    input_solutions: &[BTreeMap<Name, Literal>],
288
) -> Vec<BTreeMap<Name, Literal>> {
289
    let mut normalized = input_solutions.to_vec();
290

            
291
    for solset in &mut normalized {
292
        // remove machine names
293
        let keys_to_remove: Vec<Name> = solset
294
            .keys()
295
            .filter(|k| matches!(k, Name::Machine(_)))
296
            .cloned()
297
            .collect();
298
        for k in keys_to_remove {
299
            solset.remove(&k);
300
        }
301

            
302
        let mut updates = vec![];
303
        for (k, v) in solset.clone() {
304
            if let Name::User(_) = k {
305
                match v {
306
                    Literal::Bool(true) => updates.push((k, Literal::Int(1))),
307
                    Literal::Bool(false) => updates.push((k, Literal::Int(0))),
308
                    Literal::Int(_) => {}
309
                    Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, _)) => {
310
                        // make all domains the same (this is just in the tester so the types dont
311
                        // actually matter)
312

            
313
                        let mut matrix =
314
                            AbstractLiteral::Matrix(elems, Moo::new(GroundDomain::Int(vec![])));
315
                        matrix = matrix.transform(&move |x: AbstractLiteral<Literal>| match x {
316
                            AbstractLiteral::Matrix(items, _) => {
317
                                let items = items
318
                                    .into_iter()
319
                                    .map(|x| match x {
320
                                        Literal::Bool(false) => Literal::Int(0),
321
                                        Literal::Bool(true) => Literal::Int(1),
322
                                        x => x,
323
                                    })
324
                                    .collect_vec();
325

            
326
                                AbstractLiteral::Matrix(items, Moo::new(GroundDomain::Int(vec![])))
327
                            }
328
                            x => x,
329
                        });
330
                        updates.push((k, Literal::AbstractLiteral(matrix)));
331
                    }
332
                    Literal::AbstractLiteral(AbstractLiteral::Tuple(elems)) => {
333
                        // just the same as matrix but with tuples instead
334
                        // only conversion needed is to convert bools to ints
335
                        let mut tuple = AbstractLiteral::Tuple(elems);
336
                        tuple = tuple.transform(
337
                            &(move |x: AbstractLiteral<Literal>| match x {
338
                                AbstractLiteral::Tuple(items) => {
339
                                    let items = items
340
                                        .into_iter()
341
                                        .map(|x| match x {
342
                                            Literal::Bool(false) => Literal::Int(0),
343
                                            Literal::Bool(true) => Literal::Int(1),
344
                                            x => x,
345
                                        })
346
                                        .collect_vec();
347

            
348
                                    AbstractLiteral::Tuple(items)
349
                                }
350
                                x => x,
351
                            }),
352
                        );
353
                        updates.push((k, Literal::AbstractLiteral(tuple)));
354
                    }
355
                    Literal::AbstractLiteral(AbstractLiteral::Record(entries)) => {
356
                        // just the same as matrix but with tuples instead
357
                        // only conversion needed is to convert bools to ints
358
                        let mut record = AbstractLiteral::Record(entries);
359
                        record = record.transform(&move |x: AbstractLiteral<Literal>| match x {
360
                            AbstractLiteral::Record(entries) => {
361
                                let entries = entries
362
                                    .into_iter()
363
                                    .map(|x| {
364
                                        let RecordValue { name, value } = x;
365
                                        {
366
                                            let value = match value {
367
                                                Literal::Bool(false) => Literal::Int(0),
368
                                                Literal::Bool(true) => Literal::Int(1),
369
                                                x => x,
370
                                            };
371
                                            RecordValue { name, value }
372
                                        }
373
                                    })
374
                                    .collect_vec();
375

            
376
                                AbstractLiteral::Record(entries)
377
                            }
378
                            x => x,
379
                        });
380
                        updates.push((k, Literal::AbstractLiteral(record)));
381
                    }
382
                    Literal::AbstractLiteral(AbstractLiteral::Set(members)) => {
383
                        let set = AbstractLiteral::Set(members).transform(&move |x| match x {
384
                            AbstractLiteral::Set(members) => {
385
                                let members = members
386
                                    .into_iter()
387
                                    .map(|x| match x {
388
                                        Literal::Bool(false) => Literal::Int(0),
389
                                        Literal::Bool(true) => Literal::Int(1),
390
                                        x => x,
391
                                    })
392
                                    .collect_vec();
393

            
394
                                AbstractLiteral::Set(members)
395
                            }
396
                            x => x,
397
                        });
398
                        updates.push((k, Literal::AbstractLiteral(set)));
399
                    }
400
                    e => bug!("unexpected literal type: {e:?}"),
401
                }
402
            }
403
        }
404

            
405
        for (k, v) in updates {
406
            solset.insert(k, v);
407
        }
408
    }
409

            
410
    // Remove duplicates
411
    normalized = normalized.into_iter().unique().collect();
412
    normalized
413
}
414

            
415
fn maybe_truncate_serialised_json(serialised: String, test_stage: &str) -> String {
416
    if test_stage == "rewrite" {
417
        truncate_to_first_lines(&serialised, REWRITE_SERIALISED_JSON_MAX_LINES)
418
    } else {
419
        serialised
420
    }
421
}
422

            
423
/// Returns the first `max_lines` lines of `content` joined with `\n`, with no trailing
/// newline. Line terminators (`\n` or `\r\n`) are stripped as by [`str::lines`].
fn truncate_to_first_lines(content: &str, max_lines: usize) -> String {
    let kept: Vec<&str> = content.lines().take(max_lines).collect();
    kept.join("\n")
}
426

            
427
/// Reads at most the first `n` lines of `filename` and joins them with `\n` (no trailing
/// newline).
///
/// Returns an empty string when `n` is zero or the file is empty.
///
/// # Errors
///
/// Returns an error if the file cannot be opened (the message then includes the path) or if
/// reading one of the first `n` lines fails (e.g. invalid UTF-8).
fn read_first_n_lines<P: AsRef<Path>>(filename: P, n: usize) -> io::Result<String> {
    let path = filename.as_ref();
    let file = File::open(path).map_err(|err| {
        io::Error::new(err.kind(), format!("{} (path: {})", err, path.display()))
    })?;

    // `take` copes with `n == 0` and with files shorter than `n` lines without panicking,
    // and stops reading once `n` lines have been consumed.
    let lines = BufReader::new(file)
        .lines()
        .take(n)
        .collect::<io::Result<Vec<_>>>()?;
    Ok(lines.join("\n"))
}