1
//! conjure_oxide solve sub-command
2
#![allow(clippy::unwrap_used)]
3
#[cfg(feature = "smt")]
4
use std::time::Duration;
5
use std::{
6
    fs::File,
7
    io::Write as _,
8
    path::PathBuf,
9
    process::exit,
10
    sync::{Arc, RwLock},
11
};
12

            
13
use anyhow::{anyhow, ensure};
14
use clap::ValueHint;
15
use conjure_cp::defaults::DEFAULT_RULE_SETS;
16
use conjure_cp::{
17
    Model,
18
    ast::comprehension::{
19
        USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS, set_quantified_expander_for_comprehensions,
20
    },
21
    context::Context,
22
    rule_engine::{resolve_rule_sets, rewrite_morph, rewrite_naive},
23
    settings::Rewriter,
24
    solver::Solver,
25
};
26
use conjure_cp::{
27
    parse::conjure_json::model_from_json, rule_engine::get_rules, settings::SolverFamily,
28
};
29
use conjure_cp::{parse::tree_sitter::parse_essence_file_native, solver::adaptors::*};
30
use conjure_cp_cli::find_conjure::conjure_executable;
31
use conjure_cp_cli::utils::conjure::{get_solutions, solutions_to_json};
32
use serde_json::to_string_pretty;
33

            
34
use crate::cli::{GlobalArgs, LOGGING_HELP_HEADING};
35

            
36
#[derive(Clone, Debug, clap::Args)]
37
pub struct Args {
38
    /// The input Essence file
39
    #[arg(value_name = "INPUT_ESSENCE", value_hint = ValueHint::FilePath)]
40
    pub input_file: PathBuf,
41

            
42
    /// Save execution info as JSON to the given filepath.
43
    #[arg(long ,value_hint=ValueHint::FilePath,help_heading=LOGGING_HELP_HEADING)]
44
    pub info_json_path: Option<PathBuf>,
45

            
46
    /// Do not run the solver.
47
    ///
48
    /// The rewritten model is printed to stdout in an Essence-style syntax
49
    /// (but is not necessarily valid Essence).
50
    #[arg(long, default_value_t = false)]
51
    pub no_run_solver: bool,
52

            
53
    /// Number of solutions to return. 0 returns all solutions
54
    #[arg(long, default_value_t = 0, short = 'n')]
55
    pub number_of_solutions: i32,
56

            
57
    /// Save solutions to the given JSON file
58
    #[arg(long, short = 'o', value_hint = ValueHint::FilePath,help_heading=LOGGING_HELP_HEADING)]
59
    pub output: Option<PathBuf>,
60
}
61

            
62
14
pub fn run_solve_command(global_args: GlobalArgs, solve_args: Args) -> anyhow::Result<()> {
63
14
    let input_file = solve_args.input_file.clone();
64

            
65
    // each step is in its own method so that similar commands
66
    // (e.g. testsolve) can reuse some of these steps.
67

            
68
14
    let context = init_context(&global_args, input_file)?;
69
14
    let model = parse(&global_args, Arc::clone(&context))?;
70
10
    let rewritten_model = rewrite(model, &global_args, Arc::clone(&context))?;
71
10
    let solver = init_solver(&global_args);
72

            
73
10
    if solve_args.no_run_solver {
74
        println!("{}", &rewritten_model);
75

            
76
        if let Some(path) = global_args.save_solver_input_file {
77
            let solver = solver.load_model(rewritten_model)?;
78
            eprintln!("Writing solver input file to {}", path.display());
79
            let mut file: Box<dyn std::io::Write> = Box::new(File::create(path)?);
80
            solver.write_solver_input_file(&mut file)?;
81
        }
82
    } else {
83
10
        run_solver(solver, &global_args, &solve_args, rewritten_model)?
84
    }
85

            
86
    // still do postamble even if we didn't run the solver
87
10
    if let Some(ref path) = solve_args.info_json_path {
88
        let context_obj = context.read().unwrap().clone();
89
        let generated_json = &serde_json::to_value(context_obj)?;
90
        let pretty_json = serde_json::to_string_pretty(&generated_json)?;
91
        File::create(path)?.write_all(pretty_json.as_bytes())?;
92
10
    }
93
10
    Ok(())
94
14
}
95

            
96
/// Returns a new Context and Solver for solving.
97
16
pub(crate) fn init_context(
98
16
    global_args: &GlobalArgs,
99
16
    input_file: PathBuf,
100
16
) -> anyhow::Result<Arc<RwLock<Context<'static>>>> {
101
16
    let target_family = global_args.solver;
102
16
    let mut extra_rule_sets: Vec<&str> = DEFAULT_RULE_SETS.to_vec();
103
16
    for rs in &global_args.extra_rule_sets {
104
        extra_rule_sets.push(rs.as_str());
105
    }
106

            
107
16
    if let SolverFamily::Sat(sat_encoding) = target_family {
108
        extra_rule_sets.push(sat_encoding.as_rule_set());
109
16
    }
110

            
111
16
    let rule_sets = match resolve_rule_sets(target_family, &extra_rule_sets) {
112
16
        Ok(rs) => rs,
113
        Err(e) => {
114
            tracing::error!("Error resolving rule sets: {}", e);
115
            exit(1);
116
        }
117
    };
118

            
119
16
    let pretty_rule_sets = rule_sets
120
16
        .iter()
121
16
        .map(|rule_set| rule_set.name)
122
16
        .collect::<Vec<_>>()
123
16
        .join(", ");
124

            
125
16
    tracing::info!("Enabled rule sets: [{}]", pretty_rule_sets);
126
16
    tracing::info!(
127
        target: "file",
128
        "Rule sets: {}",
129
        pretty_rule_sets
130
    );
131

            
132
16
    let rules = get_rules(&rule_sets)?.into_iter().collect::<Vec<_>>();
133
16
    tracing::info!(
134
        target: "file",
135
        "Rules: {}",
136
552
        rules.iter().map(|rd| format!("{rd}")).collect::<Vec<_>>().join("\n")
137
    );
138
16
    let context = Context::new_ptr(
139
16
        target_family,
140
48
        extra_rule_sets.iter().map(|rs| rs.to_string()).collect(),
141
16
        rules,
142
16
        rule_sets.clone(),
143
    );
144

            
145
16
    context.write().unwrap().file_name = Some(input_file.to_str().expect("").into());
146

            
147
16
    Ok(context)
148
16
}
149

            
150
10
pub(crate) fn init_solver(global_args: &GlobalArgs) -> Solver {
151
10
    let family = global_args.solver;
152
    #[cfg(feature = "smt")]
153
10
    let timeout_ms = global_args
154
10
        .solver_timeout
155
10
        .map(|dur| Duration::from(dur).as_millis())
156
10
        .map(|timeout_ms| u64::try_from(timeout_ms).expect("Timeout too large"));
157

            
158
10
    match family {
159
10
        SolverFamily::Minion => Solver::new(Minion::default()),
160
        SolverFamily::Sat(_) => Solver::new(Sat::default()),
161
        #[cfg(feature = "smt")]
162
        SolverFamily::Smt(theory_cfg) => Solver::new(Smt::new(timeout_ms, theory_cfg)),
163
    }
164
10
}
165

            
166
16
pub(crate) fn parse(
167
16
    global_args: &GlobalArgs,
168
16
    context: Arc<RwLock<Context<'static>>>,
169
16
) -> anyhow::Result<Model> {
170
16
    let input_file: String = context
171
16
        .read()
172
16
        .unwrap()
173
16
        .file_name
174
16
        .clone()
175
16
        .expect("context should contain the input file");
176

            
177
16
    tracing::info!(target: "file", "Input file: {}", input_file);
178
16
    match global_args.parser {
179
        conjure_cp::settings::Parser::TreeSitter => {
180
16
            parse_essence_file_native(input_file.as_str(), context.clone()).map_err(|e| e.into())
181
        }
182
        conjure_cp::settings::Parser::ViaConjure => {
183
            conjure_executable()
184
                .map_err(|e| anyhow!("Could not find correct conjure executable: {e}"))?;
185

            
186
            let mut cmd = std::process::Command::new("conjure");
187
            let output = cmd
188
                .arg("pretty")
189
                .arg("--output-format=astjson")
190
                .arg(input_file)
191
                .output()?;
192

            
193
            let conjure_stderr = String::from_utf8(output.stderr)?;
194

            
195
            ensure!(conjure_stderr.is_empty(), conjure_stderr);
196

            
197
            let astjson = String::from_utf8(output.stdout)?;
198

            
199
            if cfg!(feature = "extra-rule-checks") {
200
                tracing::info!("extra-rule-checks: enabled");
201
            } else {
202
                tracing::info!("extra-rule-checks: disabled");
203
            }
204

            
205
            model_from_json(&astjson, context.clone()).map_err(|e| anyhow!(e))
206
        }
207
    }
208
16
}
209

            
210
10
pub(crate) fn rewrite(
211
10
    model: Model,
212
10
    global_args: &GlobalArgs,
213
10
    context: Arc<RwLock<Context<'static>>>,
214
10
) -> anyhow::Result<Model> {
215
10
    tracing::info!("Initial model: \n{}\n", model);
216

            
217
10
    let quantified_expander = global_args.quantified_expander;
218
10
    set_quantified_expander_for_comprehensions(quantified_expander);
219
10
    tracing::info!("Quantified expander: {}", quantified_expander);
220

            
221
10
    let rule_sets = context.read().unwrap().rule_sets.clone();
222

            
223
10
    let new_model = match global_args.rewriter {
224
        Rewriter::Morph => {
225
            USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS
226
                .store(true, std::sync::atomic::Ordering::Relaxed);
227
            tracing::info!("Rewriting the model using the morph rewriter");
228
            rewrite_morph(
229
                model,
230
                &rule_sets,
231
                global_args.check_equally_applicable_rules,
232
            )
233
        }
234
        Rewriter::Naive => {
235
10
            USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS
236
10
                .store(false, std::sync::atomic::Ordering::Relaxed);
237
10
            tracing::info!("Rewriting the model using the default / naive rewriter");
238
10
            if global_args.exit_after_unrolling {
239
                tracing::info!("Exiting after unrolling");
240
10
            }
241
10
            rewrite_naive(
242
10
                &model,
243
10
                &rule_sets,
244
10
                global_args.check_equally_applicable_rules,
245
10
                global_args.exit_after_unrolling,
246
            )?
247
        }
248
    };
249

            
250
10
    tracing::info!("Rewritten model: \n{}\n", new_model);
251
10
    Ok(new_model)
252
10
}
253

            
254
10
fn run_solver(
255
10
    solver: Solver,
256
10
    global_args: &GlobalArgs,
257
10
    cmd_args: &Args,
258
10
    model: Model,
259
10
) -> anyhow::Result<()> {
260
10
    let out_file: Option<File> = match &cmd_args.output {
261
10
        None => None,
262
        Some(pth) => Some(
263
            File::options()
264
                .create(true)
265
                .truncate(true)
266
                .write(true)
267
                .open(pth)?,
268
        ),
269
    };
270

            
271
10
    let solutions = get_solutions(
272
10
        solver,
273
10
        model,
274
10
        cmd_args.number_of_solutions,
275
10
        &global_args.save_solver_input_file,
276
    )?;
277
10
    tracing::info!(target: "file", "Solutions: {}", solutions_to_json(&solutions));
278

            
279
10
    let solutions_json = solutions_to_json(&solutions);
280
10
    let solutions_str = to_string_pretty(&solutions_json)?;
281
10
    match out_file {
282
10
        None => {
283
10
            println!("Solutions:");
284
10
            println!("{solutions_str}");
285
10
        }
286
        Some(mut outf) => {
287
            outf.write_all(solutions_str.as_bytes())?;
288
            println!(
289
                "Solutions saved to {:?}",
290
                &cmd_args.output.clone().unwrap().canonicalize()?
291
            )
292
        }
293
    }
294
10
    Ok(())
295
10
}