1
use std::collections::{HashMap, HashSet};
2
use std::fmt::Debug;
3

            
4
use std::fs::File;
5
use std::fs::{read_to_string, OpenOptions};
6
use std::hash::Hash;
7
use std::io::Write;
8
use std::sync::{Arc, RwLock};
9

            
10
use conjure_core::context::Context;
11
use serde_json::{json, Error as JsonError, Value as JsonValue};
12

            
13
use conjure_core::error::Error;
14

            
15
use crate::ast::Name::UserName;
16
use crate::ast::{Literal, Name};
17
use crate::utils::conjure::minion_solutions_to_json;
18
use crate::utils::json::sort_json_object;
19
use crate::utils::misc::to_set;
20
use crate::Model as ConjureModel;
21

            
22
pub fn assert_eq_any_order<T: Eq + Hash + Debug + Clone>(a: &Vec<Vec<T>>, b: &Vec<Vec<T>>) {
23
    assert_eq!(a.len(), b.len());
24

            
25
    let mut a_rows: Vec<HashSet<T>> = Vec::new();
26
    for row in a {
27
        let hash_row = to_set(row);
28
        a_rows.push(hash_row);
29
    }
30

            
31
    let mut b_rows: Vec<HashSet<T>> = Vec::new();
32
    for row in b {
33
        let hash_row = to_set(row);
34
        b_rows.push(hash_row);
35
    }
36

            
37
    println!("{:?},{:?}", a_rows, b_rows);
38
    for row in a_rows {
39
        assert!(b_rows.contains(&row));
40
    }
41
}
42

            
43
318
/// Serialises `model` to a pretty-printed JSON string.
///
/// The JSON value is passed through `sort_json_object` so that object keys
/// come out in a consistent order. This is only needed for generated output,
/// since expected files are already sorted.
///
/// # Errors
///
/// Returns the `serde_json` error if the model cannot be converted to a JSON
/// value or the value cannot be rendered as a string.
pub fn serialise_model(model: &ConjureModel) -> Result<String, JsonError> {
    // `&T` implements `Serialize` whenever `T` does, so the model is serialised
    // through the reference instead of deep-cloning it first.
    let generated_json = sort_json_object(&serde_json::to_value(model)?, false);

    serde_json::to_string_pretty(&generated_json)
}
54

            
55
318
pub fn save_model_json(
56
318
    model: &ConjureModel,
57
318
    path: &str,
58
318
    test_name: &str,
59
318
    test_stage: &str,
60
318
    accept: bool,
61
318
) -> Result<(), std::io::Error> {
62
318
    let generated_json_str = serialise_model(model)?;
63

            
64
318
    File::create(format!(
65
318
        "{path}/{test_name}.generated-{test_stage}.serialised.json"
66
318
    ))?
67
318
    .write_all(generated_json_str.as_bytes())?;
68

            
69
318
    if accept {
70
        std::fs::copy(
71
            format!("{path}/{test_name}.generated-{test_stage}.serialised.json"),
72
            format!("{path}/{test_name}.expected-{test_stage}.serialised.json"),
73
        )?;
74
318
    }
75

            
76
318
    Ok(())
77
318
}
78

            
79
pub fn save_stats_json(
80
    context: Arc<RwLock<Context<'static>>>,
81
    path: &str,
82
    test_name: &str,
83
) -> Result<(), std::io::Error> {
84
    #[allow(clippy::unwrap_used)]
85
    let stats = context.read().unwrap().clone();
86
    let generated_json = sort_json_object(&serde_json::to_value(stats)?, false);
87

            
88
    // serialise to string
89
    let generated_json_str = serde_json::to_string_pretty(&generated_json)?;
90

            
91
    File::create(format!("{path}/{test_name}-stats.json"))?
92
        .write_all(generated_json_str.as_bytes())?;
93

            
94
    Ok(())
95
}
96

            
97
318
/// Loads and deserialises the model stored at
/// `{path}/{test_name}.{prefix}-{test_stage}.serialised.json`.
///
/// # Errors
///
/// Returns an `std::io::Error` if the file cannot be read or its contents do
/// not deserialise into a model.
pub fn read_model_json(
    path: &str,
    test_name: &str,
    prefix: &str,
    test_stage: &str,
) -> Result<ConjureModel, std::io::Error> {
    let file_path = format!("{path}/{test_name}.{prefix}-{test_stage}.serialised.json");
    let contents = std::fs::read_to_string(file_path)?;

    // The JSON error is converted into an `std::io::Error` by `?`.
    let model = serde_json::from_str::<ConjureModel>(&contents)?;
    Ok(model)
}
111

            
112
pub fn minion_solutions_from_json(
113
    serialized: &str,
114
) -> Result<Vec<HashMap<Name, Literal>>, anyhow::Error> {
115
    let json: JsonValue = serde_json::from_str(serialized)?;
116

            
117
    let json_array = json
118
        .as_array()
119
        .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
120

            
121
    let mut solutions = Vec::new();
122

            
123
    for solution in json_array {
124
        let mut sol = HashMap::new();
125
        let solution = solution
126
            .as_object()
127
            .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
128

            
129
        for (var_name, constant) in solution {
130
            let constant = match constant {
131
                JsonValue::Number(n) => {
132
                    let n = n
133
                        .as_i64()
134
                        .ok_or(Error::Parse("Invalid integer".to_owned()))?;
135
                    Literal::Int(n as i32)
136
                }
137
                JsonValue::Bool(b) => Literal::Bool(*b),
138
                _ => return Err(Error::Parse("Invalid constant".to_owned()).into()),
139
            };
140

            
141
            sol.insert(UserName(var_name.into()), constant);
142
        }
143

            
144
        solutions.push(sol);
145
    }
146

            
147
    Ok(solutions)
148
}
149

            
150
pub fn save_minion_solutions_json(
151
    solutions: &Vec<HashMap<Name, Literal>>,
152
    path: &str,
153
    test_name: &str,
154
    accept: bool,
155
) -> Result<JsonValue, std::io::Error> {
156
    let json_solutions = minion_solutions_to_json(solutions);
157

            
158
    let generated_json_str = serde_json::to_string_pretty(&json_solutions)?;
159

            
160
    File::create(format!(
161
        "{path}/{test_name}.generated-minion.solutions.json"
162
    ))?
163
    .write_all(generated_json_str.as_bytes())?;
164

            
165
    if accept {
166
        std::fs::copy(
167
            format!("{path}/{test_name}.generated-minion.solutions.json"),
168
            format!("{path}/{test_name}.expected-minion.solutions.json"),
169
        )?;
170
    }
171

            
172
    Ok(json_solutions)
173
}
174

            
175
/// Reads `{path}/{test_name}.{prefix}-minion.solutions.json`, parses it as
/// JSON, and returns the value after passing it through `sort_json_object`.
///
/// # Errors
///
/// Returns an error if the file cannot be read or does not contain valid JSON.
pub fn read_minion_solutions_json(
    path: &str,
    test_name: &str,
    prefix: &str,
) -> Result<JsonValue, anyhow::Error> {
    let file_path = format!("{path}/{test_name}.{prefix}-minion.solutions.json");
    let raw = std::fs::read_to_string(file_path)?;

    let parsed: JsonValue = serde_json::from_str(&raw)?;

    Ok(sort_json_object(&parsed, true))
}
188

            
189
pub fn read_rule_trace(
190
    path: &str,
191
    test_name: &str,
192
    prefix: &str,
193
    accept: bool,
194
) -> Result<Vec<String>, std::io::Error> {
195
    let filename = format!("{path}/{test_name}-{prefix}-rule-trace.json");
196
    let mut rules_trace: Vec<String> = read_to_string(&filename)
197
        .unwrap()
198
        .lines()
199
        .map(String::from)
200
        .collect();
201

            
202
    //only count the number of rule in generated file (assumming the expected version already has that line and it is correct)
203
    if prefix == "generated" {
204
        let rule_count = rules_trace.len();
205

            
206
        let count_message = json!({
207
            "message": " Number of rules applied",
208
            "count": rule_count
209
        });
210

            
211
        // Append the count message to the vector
212
        let count_message_string = serde_json::to_string(&count_message)?;
213
        rules_trace.push(count_message_string.clone());
214

            
215
        // Write the updated rules trace back to the file
216
        let mut file = OpenOptions::new()
217
            .write(true)
218
            .truncate(true) // Overwrite the file with updated content
219
            .open(&filename)?;
220

            
221
        writeln!(file, "{}", rules_trace.join("\n"))?;
222
    }
223

            
224
    if accept {
225
        std::fs::copy(
226
            format!("{path}/{test_name}-generated-rule-trace.json"),
227
            format!("{path}/{test_name}-expected-rule-trace.json"),
228
        )?;
229
    }
230

            
231
    Ok(rules_trace)
232
}