1
use conjure_oxide::utils::testing::read_rule_trace;
2
use glob::glob;
3
use serde_json::json;
4
use serde_json::Value;
5
use std::collections::HashMap;
6
use std::env;
7
use std::error::Error;
8
use std::fs;
9
use std::fs::File;
10
use tracing::field::Field;
11
use tracing::field::Visit;
12
use tracing::Subscriber;
13
use tracing::{span, Level};
14
use tracing_subscriber::{filter::EnvFilter, fmt, fmt::FmtContext, layer::SubscriberExt, Registry};
15

            
16
use tracing_appender::non_blocking::WorkerGuard;
17

            
18
use std::io::BufWriter;
19
use std::path::Path;
20
use std::sync::Arc;
21
use std::sync::Mutex;
22
use std::sync::RwLock;
23

            
24
use conjure_core::ast::Atom;
25
use conjure_core::ast::{Expression, Literal, Name};
26
use conjure_core::context::Context;
27
use conjure_oxide::defaults::get_default_rule_sets;
28
use conjure_oxide::rule_engine::resolve_rule_sets;
29
use conjure_oxide::rule_engine::rewrite_model;
30
use conjure_oxide::utils::conjure::minion_solutions_to_json;
31
use conjure_oxide::utils::conjure::{
32
    get_minion_solutions, get_solutions_from_conjure, parse_essence_file,
33
};
34
use conjure_oxide::utils::testing::save_stats_json;
35
use conjure_oxide::utils::testing::{
36
    read_minion_solutions_json, read_model_json, save_minion_solutions_json, save_model_json,
37
};
38
use conjure_oxide::SolverFamily;
39
use serde::Deserialize;
40
use uniplate::Uniplate;
41

            
42
use pretty_assertions::assert_eq;
43
use tracing::Event;
44
use tracing_subscriber::fmt::format::Writer;
45
use tracing_subscriber::fmt::FormatEvent;
46
use tracing_subscriber::registry::LookupSpan;
47

            
48
6
#[derive(Deserialize, Default)]
49
struct TestConfig {
50
    extra_rewriter_asserts: Vec<String>,
51
}
52

            
53
fn main() {
54
    let _guard = create_scoped_subscriber("./logs", "test_log");
55

            
56
    // creating a span and log a message
57
    let test_span = span!(Level::TRACE, "test_span");
58
    let _enter: span::Entered<'_> = test_span.enter();
59

            
60
    for entry in glob("conjure_oxide/tests/integration/*").expect("Failed to read glob pattern") {
61
        match entry {
62
            Ok(path) => println!("File: {:?}", path),
63
            Err(e) => println!("Error: {:?}", e),
64
        }
65
    }
66

            
67
    let file_path = Path::new("conjure_oxide/tests/integration/*"); // using relative path
68

            
69
    let base_name = file_path.file_stem().and_then(|stem| stem.to_str());
70

            
71
    match base_name {
72
        Some(name) => println!("Base name: {}", name),
73
        None => println!("Could not extract the base name"),
74
    }
75
}
76

            
77
// run tests in sequence not parallel when verbose logging, to ensure the logs are ordered
78
// correctly
79
static GUARD: Mutex<()> = Mutex::new(());
80

            
81
// wrapper to conditionally enforce sequential execution
82
50
fn integration_test(path: &str, essence_base: &str, extension: &str) -> Result<(), Box<dyn Error>> {
83
50
    let verbose = env::var("VERBOSE").unwrap_or("false".to_string()) == "true";
84
50

            
85
50
    // Lock here to ensure sequential execution
86
50
    let _guard = GUARD.lock().unwrap();
87
50

            
88
50
    // run tests in sequence not parallel when verbose logging, to ensure the logs are ordered
89
50
    // correctly
90
50

            
91
50
    let (subscriber, _guard) = create_scoped_subscriber(path, essence_base);
92
50

            
93
50
    // set the subscriber as default
94
50
    tracing::subscriber::with_default(subscriber, || {
95
        // create a span for the trace
96
50
        let test_span = span!(target: "rule_engine", Level::TRACE, "test_span");
97
50
        let _enter = test_span.enter();
98
50

            
99
50
        // execute tests based on verbosity
100
50
        if verbose {
101
            #[allow(clippy::unwrap_used)]
102
            let _guard = GUARD.lock().unwrap();
103
            integration_test_inner(path, essence_base, extension)?
104
        } else {
105
50
            integration_test_inner(path, essence_base, extension)?
106
        }
107

            
108
50
        Ok(())
109
50
    })
110
50
}
111

            
112
/// Runs an integration test for a given Conjure model by:
113
/// 1. Parsing the model from an Essence file.
114
/// 2. Rewriting the model according to predefined rule sets.
115
/// 3. Solving the model using the Minion solver and validating the solutions.
116
///
117
/// This function operates in three main stages:
118
/// - **Parsing Stage**: Reads the Essence model file and verifies that it parses correctly.
119
/// - **Rewrite Stage**: Applies a set of rules to the parsed model and validates the result.
120
/// - **Solution Stage**: Uses Minion to solve the model and compares solutions with expected results.
121
///
122
/// # Arguments
123
///
124
/// * `path` - The file path where the Essence model and other resources are located.
125
/// * `essence_base` - The base name of the Essence model file.
126
/// * `extension` - The file extension for the Essence model.
127
///
128
/// # Errors
129
///
130
/// Returns an error if any stage fails due to a mismatch with expected results or file I/O issues.
131
#[allow(clippy::unwrap_used)]
132
50
fn integration_test_inner(
133
50
    path: &str,
134
50
    essence_base: &str,
135
50
    extension: &str,
136
50
) -> Result<(), Box<dyn Error>> {
137
50
    let context: Arc<RwLock<Context<'static>>> = Default::default();
138
50
    let accept = env::var("ACCEPT").unwrap_or("false".to_string()) == "true";
139
50
    let verbose = env::var("VERBOSE").unwrap_or("false".to_string()) == "true";
140
50

            
141
50
    if verbose {
142
        println!(
143
            "Running integration test for {}/{}, ACCEPT={}",
144
            path, essence_base, accept
145
        );
146
50
    }
147

            
148
50
    let config: TestConfig =
149
50
        if let Ok(config_contents) = fs::read_to_string(format!("{}/config.toml", path)) {
150
6
            toml::from_str(&config_contents).unwrap()
151
        } else {
152
44
            Default::default()
153
        };
154

            
155
    // Stage 1: Read the essence file and check that the model is parsed correctly
156
50
    let model = parse_essence_file(path, essence_base, extension, context.clone())?;
157
50
    if verbose {
158
        println!("Parsed model: {:#?}", model)
159
50
    }
160

            
161
50
    context.as_ref().write().unwrap().file_name =
162
50
        Some(format!("{path}/{essence_base}.{extension}"));
163
50

            
164
50
    save_model_json(&model, path, essence_base, "parse", accept)?;
165
50
    let expected_model = read_model_json(path, essence_base, "expected", "parse")?;
166
50
    if verbose {
167
        println!("Expected model: {:#?}", expected_model)
168
50
    }
169

            
170
50
    assert_eq!(model, expected_model);
171

            
172
    // Stage 2: Rewrite the model using the rule engine and check that the result is as expected
173
50
    let rule_sets = resolve_rule_sets(SolverFamily::Minion, &get_default_rule_sets())?;
174
50
    let model = rewrite_model(&model, &rule_sets)?;
175

            
176
50
    if verbose {
177
        println!("Rewritten model: {:#?}", model)
178
50
    }
179

            
180
50
    save_model_json(&model, path, essence_base, "rewrite", accept)?;
181

            
182
56
    for extra_assert in config.extra_rewriter_asserts {
183
6
        match extra_assert.as_str() {
184
6
            "vector_operators_have_partially_evaluated" => {
185
6
                assert_vector_operators_have_partially_evaluated(&model)
186
            }
187
            x => println!("Unrecognised extra assert: {}", x),
188
        };
189
    }
190

            
191
50
    let expected_model = read_model_json(path, essence_base, "expected", "rewrite")?;
192
50
    if verbose {
193
        println!("Expected model: {:#?}", expected_model)
194
50
    }
195

            
196
50
    assert_eq!(model, expected_model);
197

            
198
    // Stage 3: Run the model through the Minion solver and check that the solutions are as expected
199
50
    let solutions = get_minion_solutions(model)?;
200

            
201
50
    let solutions_json = save_minion_solutions_json(&solutions, path, essence_base, accept)?;
202
50
    if verbose {
203
        println!("Minion solutions: {:#?}", solutions_json)
204
50
    }
205

            
206
50
    let expected_rule_trace = read_rule_trace(path, essence_base, "expected")?;
207
50
    let generated_rule_trace = read_rule_trace(path, essence_base, "generated")?;
208

            
209
50
    assert_eq!(expected_rule_trace, generated_rule_trace);
210

            
211
    // test solutions against conjure before writing
212
50
    if accept {
213
        let mut conjure_solutions: Vec<HashMap<Name, Literal>> =
214
            get_solutions_from_conjure(&format!("{}/{}.{}", path, essence_base, extension))?;
215

            
216
        // Change bools to nums in both outputs, as we currently don't convert 0,1 back to
217
        // booleans for Minion.
218

            
219
        // remove machine names from Minion solutions, as the conjure solutions won't have these.
220
        let mut username_solutions = solutions.clone();
221
        for solset in &mut username_solutions {
222
            for (k, v) in solset.clone().into_iter() {
223
                match k {
224
                    conjure_core::ast::Name::MachineName(_) => {
225
                        solset.remove(&k);
226
                    }
227
                    conjure_core::ast::Name::UserName(_) => match v {
228
                        Literal::Bool(true) => {
229
                            solset.insert(k, Literal::Int(1));
230
                        }
231
                        Literal::Bool(false) => {
232
                            solset.insert(k, Literal::Int(0));
233
                        }
234
                        _ => {}
235
                    },
236
                }
237
            }
238
        }
239

            
240
        for solset in &mut conjure_solutions {
241
            for (k, v) in solset.clone().into_iter() {
242
                match v {
243
                    Literal::Bool(true) => {
244
                        solset.insert(k, Literal::Int(1));
245
                    }
246
                    Literal::Bool(false) => {
247
                        solset.insert(k, Literal::Int(0));
248
                    }
249
                    _ => {}
250
                }
251
            }
252
        }
253

            
254
        // I can't make these sets of hashmaps due to hashmaps not implementing hash; so, to
255
        // compare these, I make them both json and compare that.
256
        let mut conjure_solutions_json: serde_json::Value =
257
            minion_solutions_to_json(&conjure_solutions);
258
        let mut username_solutions_json: serde_json::Value =
259
            minion_solutions_to_json(&username_solutions);
260
        conjure_solutions_json.sort_all_objects();
261
        username_solutions_json.sort_all_objects();
262

            
263
        assert_eq!(
264
            username_solutions_json, conjure_solutions_json,
265
            "Solutions do not match conjure!"
266
        );
267
50
    }
268

            
269
50
    let expected_solutions_json = read_minion_solutions_json(path, essence_base, "expected")?;
270
50
    if verbose {
271
        println!("Expected solutions: {:#?}", expected_solutions_json)
272
50
    }
273

            
274
50
    assert_eq!(solutions_json, expected_solutions_json);
275

            
276
50
    save_stats_json(context, path, essence_base)?;
277

            
278
50
    Ok(())
279
50
}
280

            
281
6
fn assert_vector_operators_have_partially_evaluated(model: &conjure_core::Model) {
282
61
    model.constraints.transform(Arc::new(|x| {
283
        use conjure_core::ast::Expression::*;
284
61
        match &x {
285
            Bubble(_, _, _) => (),
286
24
            Atomic(_, _) => (),
287
            Sum(_, vec) => assert_constants_leq_one(&x, vec),
288
            Min(_, vec) => assert_constants_leq_one(&x, vec),
289
            Max(_, vec) => assert_constants_leq_one(&x, vec),
290
            Not(_, _) => (),
291
7
            Or(_, vec) => assert_constants_leq_one(&x, vec),
292
6
            And(_, vec) => assert_constants_leq_one(&x, vec),
293
2
            Eq(_, _, _) => (),
294
            Neq(_, _, _) => (),
295
            Geq(_, _, _) => (),
296
            Leq(_, _, _) => (),
297
            Gt(_, _, _) => (),
298
            Lt(_, _, _) => (),
299
            SafeDiv(_, _, _) => (),
300
            UnsafeDiv(_, _, _) => (),
301
            SumEq(_, vec, _) => assert_constants_leq_one(&x, vec),
302
2
            SumGeq(_, vec, _) => assert_constants_leq_one(&x, vec),
303
3
            SumLeq(_, vec, _) => assert_constants_leq_one(&x, vec),
304
            DivEqUndefZero(_, _, _, _) => (),
305
2
            Ineq(_, _, _, _) => (),
306
            // this is a vector operation, but we don't want to fold values into each-other in this
307
            // one
308
            AllDiff(_, _) => (),
309
15
            WatchedLiteral(_, _, _) => (),
310
            Reify(_, _, _) => (),
311
            AuxDeclaration(_, _, _) => (),
312
            UnsafeMod(_, _, _) => (),
313
            SafeMod(_, _, _) => (),
314
            ModuloEqUndefZero(_, _, _, _) => (),
315
        };
316
61
        x.clone()
317
61
    }));
318
6
}
319

            
320
18
fn assert_constants_leq_one(parent_expr: &Expression, exprs: &[Expression]) {
321
18
    let count = exprs.iter().fold(0, |i, x| match x {
322
        Expression::Atomic(_, Atom::Literal(_)) => i + 1,
323
40
        _ => i,
324
40
    });
325
18

            
326
18
    assert!(count <= 1, "assert_vector_operators_have_partially_evaluated: expression {} is not partially evaluated",parent_expr)
327
18
}
328

            
329
// using a custom formatter to omit the span name in the log
330
// and removing the identifier and application fields for assertions
331
struct JsonFormatter;
332

            
333
impl<S, N> FormatEvent<S, N> for JsonFormatter
334
where
335
    S: Subscriber + for<'span> LookupSpan<'span>,
336
    N: for<'a> tracing_subscriber::fmt::FormatFields<'a> + 'static,
337
{
338
312
    fn format_event(
339
312
        &self,
340
312
        _ctx: &FmtContext<'_, S, N>,
341
312
        mut writer: Writer<'_>,
342
312
        event: &Event<'_>,
343
312
    ) -> std::fmt::Result {
344
312
        // initialising the log object with level and target
345
312
        let mut log = json!({
346
312
            //"level": event.metadata().level().to_string(),
347
312
            "target": event.metadata().target(),
348
312
        });
349

            
350
        // creating a visitor to capture fields
351
        struct JsonVisitor {
352
            log: Value,
353
        }
354

            
355
        impl Visit for JsonVisitor {
356
            fn record_str(&mut self, field: &Field, value: &str) {
357
                self.log
358
                    .as_object_mut()
359
                    .map(|obj| obj.insert(field.name().to_string(), json!(value)));
360
            }
361

            
362
312
            fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) {
363
312
                self.log
364
312
                    .as_object_mut()
365
312
                    .map(|obj| obj.insert(field.name().to_string(), json!(format!("{:?}", value))));
366
312
            }
367
        }
368

            
369
        // using the visitor to record fields
370
312
        let mut visitor = JsonVisitor { log: log.clone() };
371
312
        event.record(&mut visitor);
372
312

            
373
312
        // merging the visitor's log into the main log object
374
312
        log.as_object_mut().map(|obj| {
375
312
            if let Some(visitor_obj) = visitor.log.as_object() {
376
936
                for (key, value) in visitor_obj {
377
624
                    obj.insert(key.clone(), value.clone());
378
624
                }
379
            }
380
312
        });
381
312

            
382
312
        // Write the JSON log
383
312
        write!(writer, "{}\n", log)
384
312
    }
385
}
386

            
387
50
pub fn create_scoped_subscriber(
388
50
    path: &str,
389
50
    test_name: &str,
390
50
) -> (impl tracing::Subscriber + Send + Sync, WorkerGuard) {
391
50
    let file = File::create(format!("{path}/{test_name}-generated-rule-trace.json"))
392
50
        .expect("Unable to create log file");
393
50
    let writer = BufWriter::new(file);
394
50
    let (non_blocking, guard) = tracing_appender::non_blocking(writer);
395
50

            
396
50
    // subscriber setup with the JSON formatter
397
50
    let subscriber = Registry::default()
398
50
        .with(EnvFilter::new("rule_engine=trace"))
399
50
        .with(
400
50
            fmt::layer()
401
50
                .with_writer(non_blocking)
402
50
                .json()
403
50
                .event_format(JsonFormatter),
404
50
        );
405
50

            
406
50
    // wrapping the subscriber in an Arc to share across multiple threads
407
50
    let subscriber = Arc::new(subscriber) as Arc<dyn tracing::Subscriber + Send + Sync>;
408
50

            
409
50
    // setting this subscriber as the default
410
50
    let _default = tracing::subscriber::set_default(subscriber.clone());
411
50

            
412
50
    (subscriber, guard)
413
50
}
414

            
415
/// Fails (via the `unwrap` panic) when the Conjure executable cannot be located.
#[test]
fn assert_conjure_present() {
    use conjure_oxide::find_conjure::conjure_executable;

    let lookup = conjure_executable();
    lookup.unwrap();
}
419

            
420
// Splice in the generated test functions from `$OUT_DIR/gen_tests.rs`.
// NOTE(review): presumably written by this crate's build script, with one `#[test]` per
// integration-test directory calling `integration_test` — confirm against build.rs.
include!(concat!(env!("OUT_DIR"), "/gen_tests.rs"));