1
use std::collections::{BTreeMap, HashMap, HashSet};
2
use std::fmt::Debug;
3
use std::path::Path;
4
use std::{io, vec};
5

            
6
use conjure_cp::ast::records::RecordValue;
7
use conjure_cp::bug;
8
use itertools::Itertools as _;
9
use std::fs::File;
10
use std::hash::Hash;
11
use std::io::{BufRead, BufReader, Write};
12
use std::sync::{Arc, RwLock};
13
use uniplate::Uniplate;
14

            
15
use conjure_cp::ast::{AbstractLiteral, GroundDomain, Moo, SerdeModel};
16
use conjure_cp::context::Context;
17
use serde_json::{Error as JsonError, Value as JsonValue};
18

            
19
use conjure_cp::error::Error;
20

            
21
use crate::utils::conjure::solutions_to_json;
22
use crate::utils::json::sort_json_object;
23
use crate::utils::misc::to_set;
24
use conjure_cp::Model as ConjureModel;
25
use conjure_cp::ast::Name::User;
26
use conjure_cp::ast::{Literal, Name};
27
use conjure_cp::solver::SolverFamily;
28

            
29
/// Maximum number of lines of the rewrite-stage model serialisation that integration tests
/// persist and compare.
pub const REWRITE_SERIALISED_JSON_MAX_LINES: usize = 1000;
31

            
32
/// Converts a SerdeModel to JSON with stable IDs.
33
///
34
/// This ensures that the same model structure always produces the same IDs,
35
/// regardless of the order in which objects were created in memory.
36
fn model_tojson_with_stable_ids(model: &SerdeModel) -> Result<JsonValue, JsonError> {
37
    // Collect stable ID mapping using uniplate traversal on the SerdeModel
38
    let id_map_u32 = model.collect_stable_id_mapping();
39

            
40
    // Convert to u64 for JSON processing (serde_json::Number only has as_u64())
41
    let id_map: HashMap<u64, u64> = id_map_u32
42
        .into_iter()
43
        .map(|(k, v)| (k as u64, v as u64))
44
        .collect();
45

            
46
    // Serialize the model to JSON
47
    let mut json = serde_json::to_value(model)?;
48

            
49
    // Replace all IDs in the JSON with their stable counterparts
50
    replace_ids(&mut json, &id_map);
51

            
52
    Ok(json)
53
}
54

            
55
/// Recursively replaces all IDs in the JSON with their stable counterparts.
56
///
57
/// This is applied to all fields that are called "id" or "ptr" - be mindful
58
/// of potential naming clashes in the future!
59
fn replace_ids(value: &mut JsonValue, id_map: &HashMap<u64, u64>) {
60
    match value {
61
        JsonValue::Object(map) => {
62
            // Replace IDs in three places:
63
            // - "id" fields (SymbolTable IDs)
64
            // - "parent" fields (SymbolTable nesting)
65
            // - "ptr" fields (DeclarationPtr IDs)
66
            for (k, v) in map.iter_mut() {
67
                if (k == "id" || k == "ptr" || k == "parent")
68
                    && let JsonValue::Number(n) = v
69
                    && let Some(id) = n.as_u64()
70
                    && let Some(&stable_id) = id_map.get(&id)
71
                {
72
                    *v = JsonValue::Number(stable_id.into());
73
                }
74
            }
75

            
76
            // Recursively process all values
77
            for val in map.values_mut() {
78
                replace_ids(val, id_map);
79
            }
80
        }
81
        JsonValue::Array(arr) => {
82
            for item in arr {
83
                replace_ids(item, id_map);
84
            }
85
        }
86
        _ => {}
87
    }
88
}
89

            
90
pub fn assert_eq_any_order<T: Eq + Hash + Debug + Clone>(a: &Vec<Vec<T>>, b: &Vec<Vec<T>>) {
91
    assert_eq!(a.len(), b.len());
92

            
93
    let mut a_rows: Vec<HashSet<T>> = Vec::new();
94
    for row in a {
95
        let hash_row = to_set(row);
96
        a_rows.push(hash_row);
97
    }
98

            
99
    let mut b_rows: Vec<HashSet<T>> = Vec::new();
100
    for row in b {
101
        let hash_row = to_set(row);
102
        b_rows.push(hash_row);
103
    }
104

            
105
    println!("{a_rows:?},{b_rows:?}");
106
    for row in a_rows {
107
        assert!(b_rows.contains(&row));
108
    }
109
}
110

            
111
pub fn serialize_model(model: &ConjureModel) -> Result<String, JsonError> {
112
    let serde_model: SerdeModel = model.clone().into();
113

            
114
    // Convert to JSON with stable IDs
115
    let json_with_stable_ids = model_tojson_with_stable_ids(&serde_model)?;
116

            
117
    // Sort JSON object keys for consistent output
118
    let sorted_json = sort_json_object(&json_with_stable_ids, false);
119

            
120
    // Serialize to pretty-printed string
121
    serde_json::to_string_pretty(&sorted_json)
122
}
123

            
124
pub fn save_model_json(
125
    model: &ConjureModel,
126
    path: &str,
127
    test_name: &str,
128
    test_stage: &str,
129
) -> Result<(), std::io::Error> {
130
    let generated_json_str = serialize_model(model)?;
131
    let generated_json_str = maybe_truncate_serialised_json(generated_json_str, test_stage);
132
    let filename = format!("{path}/{test_name}.generated-{test_stage}.serialised.json");
133
    File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
134
    Ok(())
135
}
136

            
137
pub fn save_stats_json(
138
    context: Arc<RwLock<Context<'static>>>,
139
    path: &str,
140
    test_name: &str,
141
) -> Result<(), std::io::Error> {
142
    #[allow(clippy::unwrap_used)]
143
    let stats = context.read().unwrap().stats.clone();
144
    let generated_json = sort_json_object(&serde_json::to_value(stats)?, false);
145

            
146
    // serialise to string
147
    let generated_json_str = serde_json::to_string_pretty(&generated_json)?;
148

            
149
    File::create(format!("{path}/{test_name}-stats.json"))?
150
        .write_all(generated_json_str.as_bytes())?;
151

            
152
    Ok(())
153
}
154

            
155
/// Reads the file at `path` into a `String`.
///
/// # Errors
///
/// On failure, returns an error of the same kind as the underlying one, whose message is the
/// original message followed by the offending path.
fn read_with_path(path: String) -> Result<String, std::io::Error> {
    match std::fs::read_to_string(&path) {
        Ok(contents) => Ok(contents),
        Err(e) => Err(io::Error::new(
            e.kind(),
            format!("{} (path: {})", e, path),
        )),
    }
}
160

            
161
pub fn read_model_json(
162
    ctx: &Arc<RwLock<Context<'static>>>,
163
    path: &str,
164
    test_name: &str,
165
    prefix: &str,
166
    test_stage: &str,
167
) -> Result<ConjureModel, std::io::Error> {
168
    let expected_json_str = read_with_path(format!(
169
        "{path}/{test_name}.{prefix}-{test_stage}.serialised.json"
170
    ))?;
171
    println!("{path}/{test_name}.{prefix}-{test_stage}.serialised.json");
172
    let expected_model: SerdeModel = serde_json::from_str(&expected_json_str)?;
173

            
174
    Ok(expected_model.initialise(ctx.clone()).unwrap())
175
}
176

            
177
/// Reads only the first `max_lines` from a serialised model JSON file.
178
pub fn read_model_json_prefix(
179
    path: &str,
180
    test_name: &str,
181
    prefix: &str,
182
    test_stage: &str,
183
    max_lines: usize,
184
) -> Result<String, std::io::Error> {
185
    let filename = format!("{path}/{test_name}.{prefix}-{test_stage}.serialised.json");
186
    read_first_n_lines(filename, max_lines)
187
}
188

            
189
pub fn minion_solutions_from_json(
190
    serialized: &str,
191
) -> Result<Vec<HashMap<Name, Literal>>, anyhow::Error> {
192
    let json: JsonValue = serde_json::from_str(serialized)?;
193

            
194
    let json_array = json
195
        .as_array()
196
        .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
197

            
198
    let mut solutions = Vec::new();
199

            
200
    for solution in json_array {
201
        let mut sol = HashMap::new();
202
        let solution = solution
203
            .as_object()
204
            .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
205

            
206
        for (var_name, constant) in solution {
207
            let constant = match constant {
208
                JsonValue::Number(n) => {
209
                    let n = n
210
                        .as_i64()
211
                        .ok_or(Error::Parse("Invalid integer".to_owned()))?;
212
                    Literal::Int(n as i32)
213
                }
214
                JsonValue::Bool(b) => Literal::Bool(*b),
215
                _ => return Err(Error::Parse("Invalid constant".to_owned()).into()),
216
            };
217

            
218
            sol.insert(User(var_name.into()), constant);
219
        }
220

            
221
        solutions.push(sol);
222
    }
223

            
224
    Ok(solutions)
225
}
226

            
227
/// Writes the minion solutions to a generated JSON file, and returns the JSON structure.
228
pub fn save_solutions_json(
229
    solutions: &Vec<BTreeMap<Name, Literal>>,
230
    path: &str,
231
    test_name: &str,
232
    solver: SolverFamily,
233
) -> Result<JsonValue, std::io::Error> {
234
    let json_solutions = solutions_to_json(solutions);
235
    let generated_json_str = serde_json::to_string_pretty(&json_solutions)?;
236

            
237
    let solver_name = match solver {
238
        SolverFamily::Sat => "sat",
239
        SolverFamily::Smt(..) => "smt",
240
        SolverFamily::Minion => "minion",
241
    };
242

            
243
    let filename = format!("{path}/{test_name}.generated-{solver_name}.solutions.json");
244
    File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
245

            
246
    Ok(json_solutions)
247
}
248

            
249
pub fn read_solutions_json(
250
    path: &str,
251
    test_name: &str,
252
    prefix: &str,
253
    solver: SolverFamily,
254
) -> Result<JsonValue, anyhow::Error> {
255
    let solver_name = match solver {
256
        SolverFamily::Sat => "sat",
257
        SolverFamily::Smt(..) => "smt",
258
        SolverFamily::Minion => "minion",
259
    };
260

            
261
    let expected_json_str = read_with_path(format!(
262
        "{path}/{test_name}.{prefix}-{solver_name}.solutions.json"
263
    ))?;
264

            
265
    let expected_solutions: JsonValue =
266
        sort_json_object(&serde_json::from_str(&expected_json_str)?, true);
267

            
268
    Ok(expected_solutions)
269
}
270

            
271
/// Reads a human-readable rule trace text file.
272
pub fn read_human_rule_trace(
273
    path: &str,
274
    test_name: &str,
275
    prefix: &str,
276
) -> Result<Vec<String>, std::io::Error> {
277
    let filename = format!("{path}/{test_name}-{prefix}-rule-trace-human.txt");
278
    let rules_trace: Vec<String> = read_with_path(filename)?
279
        .lines()
280
        .map(String::from)
281
        .collect();
282

            
283
    Ok(rules_trace)
284
}
285

            
286
#[doc(hidden)]
287
pub fn normalize_solutions_for_comparison(
288
    input_solutions: &[BTreeMap<Name, Literal>],
289
) -> Vec<BTreeMap<Name, Literal>> {
290
    let mut normalized = input_solutions.to_vec();
291

            
292
    for solset in &mut normalized {
293
        // remove machine names
294
        let keys_to_remove: Vec<Name> = solset
295
            .keys()
296
            .filter(|k| matches!(k, Name::Machine(_)))
297
            .cloned()
298
            .collect();
299
        for k in keys_to_remove {
300
            solset.remove(&k);
301
        }
302

            
303
        let mut updates = vec![];
304
        for (k, v) in solset.clone() {
305
            if let Name::User(_) = k {
306
                match v {
307
                    Literal::Bool(true) => updates.push((k, Literal::Int(1))),
308
                    Literal::Bool(false) => updates.push((k, Literal::Int(0))),
309
                    Literal::Int(_) => {}
310
                    Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, _)) => {
311
                        // make all domains the same (this is just in the tester so the types dont
312
                        // actually matter)
313

            
314
                        let mut matrix =
315
                            AbstractLiteral::Matrix(elems, Moo::new(GroundDomain::Int(vec![])));
316
                        matrix = matrix.transform(&move |x: AbstractLiteral<Literal>| match x {
317
                            AbstractLiteral::Matrix(items, _) => {
318
                                let items = items
319
                                    .into_iter()
320
                                    .map(|x| match x {
321
                                        Literal::Bool(false) => Literal::Int(0),
322
                                        Literal::Bool(true) => Literal::Int(1),
323
                                        x => x,
324
                                    })
325
                                    .collect_vec();
326

            
327
                                AbstractLiteral::Matrix(items, Moo::new(GroundDomain::Int(vec![])))
328
                            }
329
                            x => x,
330
                        });
331
                        updates.push((k, Literal::AbstractLiteral(matrix)));
332
                    }
333
                    Literal::AbstractLiteral(AbstractLiteral::Tuple(elems)) => {
334
                        // just the same as matrix but with tuples instead
335
                        // only conversion needed is to convert bools to ints
336
                        let mut tuple = AbstractLiteral::Tuple(elems);
337
                        tuple = tuple.transform(
338
                            &(move |x: AbstractLiteral<Literal>| match x {
339
                                AbstractLiteral::Tuple(items) => {
340
                                    let items = items
341
                                        .into_iter()
342
                                        .map(|x| match x {
343
                                            Literal::Bool(false) => Literal::Int(0),
344
                                            Literal::Bool(true) => Literal::Int(1),
345
                                            x => x,
346
                                        })
347
                                        .collect_vec();
348

            
349
                                    AbstractLiteral::Tuple(items)
350
                                }
351
                                x => x,
352
                            }),
353
                        );
354
                        updates.push((k, Literal::AbstractLiteral(tuple)));
355
                    }
356
                    Literal::AbstractLiteral(AbstractLiteral::Record(entries)) => {
357
                        // just the same as matrix but with tuples instead
358
                        // only conversion needed is to convert bools to ints
359
                        let mut record = AbstractLiteral::Record(entries);
360
                        record = record.transform(&move |x: AbstractLiteral<Literal>| match x {
361
                            AbstractLiteral::Record(entries) => {
362
                                let entries = entries
363
                                    .into_iter()
364
                                    .map(|x| {
365
                                        let RecordValue { name, value } = x;
366
                                        {
367
                                            let value = match value {
368
                                                Literal::Bool(false) => Literal::Int(0),
369
                                                Literal::Bool(true) => Literal::Int(1),
370
                                                x => x,
371
                                            };
372
                                            RecordValue { name, value }
373
                                        }
374
                                    })
375
                                    .collect_vec();
376

            
377
                                AbstractLiteral::Record(entries)
378
                            }
379
                            x => x,
380
                        });
381
                        updates.push((k, Literal::AbstractLiteral(record)));
382
                    }
383
                    e => bug!("unexpected literal type: {e:?}"),
384
                }
385
            }
386
        }
387

            
388
        for (k, v) in updates {
389
            solset.insert(k, v);
390
        }
391
    }
392

            
393
    // Remove duplicates
394
    normalized = normalized.into_iter().unique().collect();
395
    normalized
396
}
397

            
398
fn maybe_truncate_serialised_json(serialised: String, test_stage: &str) -> String {
399
    if test_stage == "rewrite" {
400
        truncate_to_first_lines(&serialised, REWRITE_SERIALISED_JSON_MAX_LINES)
401
    } else {
402
        serialised
403
    }
404
}
405

            
406
/// Returns at most the first `max_lines` lines of `content`, joined with `\n`.
///
/// Line terminators are those recognised by `str::lines`; the result has no trailing newline.
fn truncate_to_first_lines(content: &str, max_lines: usize) -> String {
    let kept: Vec<&str> = content.lines().take(max_lines).collect();
    kept.join("\n")
}
409

            
410
/// Reads at most the first `n` lines of the file at `filename` and joins them with `\n`.
///
/// An empty file, or `n == 0`, yields an empty string. The result has no trailing newline.
///
/// # Errors
///
/// Returns the I/O error raised while opening the file or reading one of the requested lines.
fn read_first_n_lines<P: AsRef<Path>>(filename: P, n: usize) -> io::Result<String> {
    let reader = BufReader::new(File::open(&filename)?);
    // `take` stops after `n` lines and is well defined for empty files and for `n == 0`.
    let lines = reader
        .lines()
        .take(n)
        .collect::<io::Result<Vec<String>>>()?;
    Ok(lines.join("\n"))
}