1
use std::collections::{HashMap, HashSet};
2
use std::fmt::Debug;
3

            
4
use std::fs::File;
5
use std::fs::{read_to_string, OpenOptions};
6
use std::hash::Hash;
7
use std::io::Write;
8
use std::sync::{Arc, RwLock};
9

            
10
use conjure_core::context::Context;
11
use serde_json::{json, Error as JsonError, Value as JsonValue};
12

            
13
use conjure_core::error::Error;
14

            
15
use crate::ast::Name::UserName;
16
use crate::ast::{Literal, Name};
17
use crate::utils::conjure::minion_solutions_to_json;
18
use crate::utils::json::sort_json_object;
19
use crate::utils::misc::to_set;
20
use crate::Model as ConjureModel;
21

            
22
pub fn assert_eq_any_order<T: Eq + Hash + Debug + Clone>(a: &Vec<Vec<T>>, b: &Vec<Vec<T>>) {
23
    assert_eq!(a.len(), b.len());
24

            
25
    let mut a_rows: Vec<HashSet<T>> = Vec::new();
26
    for row in a {
27
        let hash_row = to_set(row);
28
        a_rows.push(hash_row);
29
    }
30

            
31
    let mut b_rows: Vec<HashSet<T>> = Vec::new();
32
    for row in b {
33
        let hash_row = to_set(row);
34
        b_rows.push(hash_row);
35
    }
36

            
37
    println!("{:?},{:?}", a_rows, b_rows);
38
    for row in a_rows {
39
        assert!(b_rows.contains(&row));
40
    }
41
}
42

            
43
600
/// Serialises `model` to a pretty-printed JSON string with a consistent ordering of object keys.
///
/// Key sorting (via `sort_json_object`) is only required for the generated version, since the
/// expected version on disk will already be sorted.
///
/// # Errors
/// Returns a `serde_json` error if the model cannot be converted to JSON.
pub fn serialise_model(model: &ConjureModel) -> Result<String, JsonError> {
    // Serialise through the reference; deep-cloning the whole model first is unnecessary.
    let generated_json = sort_json_object(&serde_json::to_value(model)?, false);

    serde_json::to_string_pretty(&generated_json)
}
54

            
55
600
pub fn save_model_json(
56
600
    model: &ConjureModel,
57
600
    path: &str,
58
600
    test_name: &str,
59
600
    test_stage: &str,
60
600
    accept: bool,
61
600
) -> Result<(), std::io::Error> {
62
600
    let generated_json_str = serialise_model(model)?;
63

            
64
600
    File::create(format!(
65
600
        "{path}/{test_name}.generated-{test_stage}.serialised.json"
66
600
    ))?
67
600
    .write_all(generated_json_str.as_bytes())?;
68

            
69
600
    if accept {
70
        std::fs::copy(
71
            format!("{path}/{test_name}.generated-{test_stage}.serialised.json"),
72
            format!("{path}/{test_name}.expected-{test_stage}.serialised.json"),
73
        )?;
74
600
    }
75

            
76
600
    Ok(())
77
600
}
78

            
79
348
pub fn save_stats_json(
80
348
    context: Arc<RwLock<Context<'static>>>,
81
348
    path: &str,
82
348
    test_name: &str,
83
348
) -> Result<(), std::io::Error> {
84
348
    #[allow(clippy::unwrap_used)]
85
348
    let stats = context.read().unwrap().clone();
86
348
    let generated_json = sort_json_object(&serde_json::to_value(stats)?, false);
87

            
88
    // serialise to string
89
348
    let generated_json_str = serde_json::to_string_pretty(&generated_json)?;
90

            
91
348
    File::create(format!("{path}/{test_name}-stats.json"))?
92
348
        .write_all(generated_json_str.as_bytes())?;
93

            
94
348
    Ok(())
95
348
}
96

            
97
600
pub fn read_model_json(
98
600
    path: &str,
99
600
    test_name: &str,
100
600
    prefix: &str,
101
600
    test_stage: &str,
102
600
) -> Result<ConjureModel, std::io::Error> {
103
600
    let expected_json_str = std::fs::read_to_string(format!(
104
600
        "{path}/{test_name}.{prefix}-{test_stage}.serialised.json"
105
600
    ))?;
106

            
107
600
    let expected_model: ConjureModel = serde_json::from_str(&expected_json_str)?;
108

            
109
600
    Ok(expected_model)
110
600
}
111

            
112
pub fn minion_solutions_from_json(
113
    serialized: &str,
114
) -> Result<Vec<HashMap<Name, Literal>>, anyhow::Error> {
115
    let json: JsonValue = serde_json::from_str(serialized)?;
116

            
117
    let json_array = json
118
        .as_array()
119
        .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
120

            
121
    let mut solutions = Vec::new();
122

            
123
    for solution in json_array {
124
        let mut sol = HashMap::new();
125
        let solution = solution
126
            .as_object()
127
            .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
128

            
129
        for (var_name, constant) in solution {
130
            let constant = match constant {
131
                JsonValue::Number(n) => {
132
                    let n = n
133
                        .as_i64()
134
                        .ok_or(Error::Parse("Invalid integer".to_owned()))?;
135
                    Literal::Int(n as i32)
136
                }
137
                JsonValue::Bool(b) => Literal::Bool(*b),
138
                _ => return Err(Error::Parse("Invalid constant".to_owned()).into()),
139
            };
140

            
141
            sol.insert(UserName(var_name.into()), constant);
142
        }
143

            
144
        solutions.push(sol);
145
    }
146

            
147
    Ok(solutions)
148
}
149

            
150
300
pub fn save_minion_solutions_json(
151
300
    solutions: &Vec<HashMap<Name, Literal>>,
152
300
    path: &str,
153
300
    test_name: &str,
154
300
    accept: bool,
155
300
) -> Result<JsonValue, std::io::Error> {
156
300
    let json_solutions = minion_solutions_to_json(solutions);
157

            
158
300
    let generated_json_str = serde_json::to_string_pretty(&json_solutions)?;
159

            
160
300
    File::create(format!(
161
300
        "{path}/{test_name}.generated-minion.solutions.json"
162
300
    ))?
163
300
    .write_all(generated_json_str.as_bytes())?;
164

            
165
300
    if accept {
166
        std::fs::copy(
167
            format!("{path}/{test_name}.generated-minion.solutions.json"),
168
            format!("{path}/{test_name}.expected-minion.solutions.json"),
169
        )?;
170
300
    }
171

            
172
300
    Ok(json_solutions)
173
300
}
174

            
175
300
/// Reads `{path}/{test_name}.{prefix}-minion.solutions.json`, parses it as JSON and returns it
/// after passing it through `sort_json_object` (with its flag set to `true`).
///
/// # Errors
/// Returns an error if the file cannot be read or does not contain valid JSON.
pub fn read_minion_solutions_json(
    path: &str,
    test_name: &str,
    prefix: &str,
) -> Result<JsonValue, anyhow::Error> {
    let filename = format!("{path}/{test_name}.{prefix}-minion.solutions.json");
    let raw = std::fs::read_to_string(filename)?;
    let parsed: JsonValue = serde_json::from_str(&raw)?;

    Ok(sort_json_object(&parsed, true))
}
188

            
189
600
pub fn read_rule_trace(
190
600
    path: &str,
191
600
    test_name: &str,
192
600
    prefix: &str,
193
600
) -> Result<Vec<String>, std::io::Error> {
194
600
    let filename = format!("{path}/{test_name}-{prefix}-rule-trace.json");
195
600
    let mut rules_trace: Vec<String> = read_to_string(&filename)
196
600
        .unwrap()
197
600
        .lines()
198
600
        .map(String::from)
199
600
        .collect();
200
600

            
201
600
    //only count the number of rule in generated file (assumming the expected version already has that line and it is correct)
202
600
    if prefix == "generated" {
203
300
        let rule_count = rules_trace.len();
204
300

            
205
300
        let count_message = json!({
206
300
            "message": " Number of rules applied",
207
300
            "count": rule_count
208
300
        });
209

            
210
        // Append the count message to the vector
211
300
        let count_message_string = serde_json::to_string(&count_message)?;
212
300
        rules_trace.push(count_message_string.clone());
213

            
214
        // Write the updated rules trace back to the file
215
300
        let mut file = OpenOptions::new()
216
300
            .write(true)
217
300
            .truncate(true) // Overwrite the file with updated content
218
300
            .open(&filename)?;
219

            
220
300
        writeln!(file, "{}", rules_trace.join("\n"))?;
221
300
    }
222

            
223
600
    Ok(rules_trace)
224
600
}