1
//! conjure_oxide solve sub-command
2
#![allow(clippy::unwrap_used)]
3
#[cfg(feature = "smt")]
4
use std::time::Duration;
5
use std::{
6
    fs::File,
7
    io::Write as _,
8
    path::PathBuf,
9
    process::exit,
10
    sync::{Arc, RwLock},
11
};
12

            
13
use anyhow::{anyhow, ensure};
14
use clap::ValueHint;
15
use conjure_cp::defaults::DEFAULT_RULE_SETS;
16
use conjure_cp::{
17
    Model,
18
    context::Context,
19
    rule_engine::{resolve_rule_sets, rewrite_morph, rewrite_naive},
20
    settings::{
21
        Rewriter, set_comprehension_expander, set_current_parser, set_current_rewriter,
22
        set_current_solver_family, set_minion_discrete_threshold,
23
    },
24
    solver::Solver,
25
};
26
use conjure_cp::{
27
    parse::conjure_json::model_from_json, rule_engine::get_rules, settings::SolverFamily,
28
};
29
use conjure_cp::{parse::tree_sitter::parse_essence_file_native, solver::adaptors::*};
30
use conjure_cp_cli::find_conjure::conjure_executable;
31
use conjure_cp_cli::utils::conjure::{get_solutions, solutions_to_json};
32
use serde_json::to_string_pretty;
33

            
34
use crate::cli::{GlobalArgs, LOGGING_HELP_HEADING};
35

            
36
/// How many solutions the user asked for on the command line.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NumberOfSolutions {
    /// The user passed `all`.
    All,
    /// The user passed an explicit count.
    Limit(i32),
}

impl NumberOfSolutions {
    /// Converts the request into the integer limit handed to the solver.
    ///
    /// `All` becomes `0`; `Limit(n)` becomes `n` unchanged.
    fn as_solver_limit(self) -> i32 {
        if let NumberOfSolutions::Limit(requested) = self {
            requested
        } else {
            0
        }
    }
}
50

            
51
18
fn parse_number_of_solutions(input: &str) -> Result<NumberOfSolutions, String> {
52
18
    if input.eq_ignore_ascii_case("all") {
53
10
        return Ok(NumberOfSolutions::All);
54
8
    }
55

            
56
8
    let limit = input
57
8
        .parse::<i32>()
58
8
        .map_err(|_| "expected a positive integer or 'all'".to_string())?;
59

            
60
8
    if limit <= 0 {
61
        return Err("expected a positive integer or 'all'".to_string());
62
8
    }
63

            
64
8
    Ok(NumberOfSolutions::Limit(limit))
65
18
}
66

            
67
#[derive(Clone, Debug, clap::Args)]
68
pub struct Args {
69
    /// The input Essence file
70
    #[arg(value_name = "INPUT_ESSENCE", value_hint = ValueHint::FilePath)]
71
    pub input_file: PathBuf,
72

            
73
    /// Save execution info as JSON to the given filepath.
74
    #[arg(long ,value_hint=ValueHint::FilePath,help_heading=LOGGING_HELP_HEADING)]
75
    pub info_json_path: Option<PathBuf>,
76

            
77
    /// Do not run the solver.
78
    ///
79
    /// The rewritten model is printed to stdout in an Essence-style syntax
80
    /// (but is not necessarily valid Essence).
81
    #[arg(long, default_value_t = false)]
82
    pub no_run_solver: bool,
83

            
84
    /// Number of solutions to return. Use a positive integer, or `all`.
85
    #[arg(
86
        long,
87
        short = 'n',
88
        default_value = "1",
89
        value_name = "N|all",
90
        value_parser = parse_number_of_solutions
91
    )]
92
    pub number_of_solutions: NumberOfSolutions,
93

            
94
    /// Save solutions to the given JSON file
95
    #[arg(long, short = 'o', value_hint = ValueHint::FilePath,help_heading=LOGGING_HELP_HEADING)]
96
    pub output: Option<PathBuf>,
97
}
98

            
99
18
pub fn run_solve_command(global_args: GlobalArgs, solve_args: Args) -> anyhow::Result<()> {
100
18
    let input_file = solve_args.input_file.clone();
101

            
102
    // each step is in its own method so that similar commands
103
    // (e.g. testsolve) can reuse some of these steps.
104

            
105
18
    let context = init_context(&global_args, input_file)?;
106
18
    let model = parse(&global_args, Arc::clone(&context))?;
107
14
    let rewritten_model = rewrite(model, &global_args, Arc::clone(&context))?;
108
14
    let solver = init_solver(&global_args);
109

            
110
14
    if solve_args.no_run_solver {
111
4
        println!("{}", &rewritten_model);
112

            
113
4
        if let Some(path) = global_args.save_solver_input_file {
114
            let solver = solver.load_model(rewritten_model)?;
115
            eprintln!("Writing solver input file to {}", path.display());
116
            let mut file: Box<dyn std::io::Write> = Box::new(File::create(path)?);
117
            solver.write_solver_input_file(&mut file)?;
118
4
        }
119
    } else {
120
10
        run_solver(solver, &global_args, &solve_args, rewritten_model)?
121
    }
122

            
123
    // still do postamble even if we didn't run the solver
124
14
    if let Some(ref path) = solve_args.info_json_path {
125
        let context_obj = context.read().unwrap().clone();
126
        let generated_json = &serde_json::to_value(context_obj)?;
127
        let pretty_json = serde_json::to_string_pretty(&generated_json)?;
128
        File::create(path)?.write_all(pretty_json.as_bytes())?;
129
14
    }
130
14
    Ok(())
131
18
}
132

            
133
/// Returns a new Context and Solver for solving.
134
32
pub(crate) fn init_context(
135
32
    global_args: &GlobalArgs,
136
32
    input_file: PathBuf,
137
32
) -> anyhow::Result<Arc<RwLock<Context<'static>>>> {
138
32
    set_current_parser(global_args.parser);
139
32
    set_current_rewriter(global_args.rewriter);
140
32
    set_comprehension_expander(global_args.comprehension_expander);
141
32
    set_current_solver_family(global_args.solver);
142
32
    set_minion_discrete_threshold(global_args.minion_discrete_threshold);
143

            
144
32
    let target_family = global_args.solver;
145
32
    let mut extra_rule_sets: Vec<&str> = DEFAULT_RULE_SETS.to_vec();
146
32
    for rs in &global_args.extra_rule_sets {
147
        extra_rule_sets.push(rs.as_str());
148
    }
149

            
150
32
    if let SolverFamily::Sat(sat_encoding) = target_family {
151
        extra_rule_sets.push(sat_encoding.as_rule_set());
152
32
    }
153

            
154
32
    let rule_sets = match resolve_rule_sets(target_family, &extra_rule_sets) {
155
32
        Ok(rs) => rs,
156
        Err(e) => {
157
            tracing::error!("Error resolving rule sets: {}", e);
158
            exit(1);
159
        }
160
    };
161

            
162
32
    let pretty_rule_sets = rule_sets
163
32
        .iter()
164
32
        .map(|rule_set| rule_set.name)
165
32
        .collect::<Vec<_>>()
166
32
        .join(", ");
167

            
168
32
    tracing::info!("Enabled rule sets: [{}]", pretty_rule_sets);
169
32
    tracing::info!(
170
        target: "file",
171
        "Rule sets: {}",
172
        pretty_rule_sets
173
    );
174

            
175
32
    let rules = get_rules(&rule_sets)?.into_iter().collect::<Vec<_>>();
176
32
    tracing::info!(
177
        target: "file",
178
        "Rules: {}",
179
950
        rules.iter().map(|rd| format!("{rd}")).collect::<Vec<_>>().join("\n")
180
    );
181
32
    let context = Context::new_ptr(
182
32
        target_family,
183
96
        extra_rule_sets.iter().map(|rs| rs.to_string()).collect(),
184
32
        rules,
185
32
        rule_sets.clone(),
186
    );
187

            
188
32
    context.write().unwrap().file_name = Some(input_file.to_str().expect("").into());
189

            
190
32
    Ok(context)
191
32
}
192

            
193
14
pub(crate) fn init_solver(global_args: &GlobalArgs) -> Solver {
194
14
    let family = global_args.solver;
195
    #[cfg(feature = "smt")]
196
14
    let timeout_ms = global_args
197
14
        .solver_timeout
198
14
        .map(|dur| Duration::from(dur).as_millis())
199
14
        .map(|timeout_ms| u64::try_from(timeout_ms).expect("Timeout too large"));
200

            
201
14
    match family {
202
14
        SolverFamily::Minion => Solver::new(Minion::default()),
203
        SolverFamily::Sat(_) => Solver::new(Sat::default()),
204
        #[cfg(feature = "smt")]
205
        SolverFamily::Smt(theory_cfg) => Solver::new(Smt::new(timeout_ms, theory_cfg)),
206
    }
207
14
}
208

            
209
32
pub(crate) fn parse(
210
32
    global_args: &GlobalArgs,
211
32
    context: Arc<RwLock<Context<'static>>>,
212
32
) -> anyhow::Result<Model> {
213
32
    let input_file: String = context
214
32
        .read()
215
32
        .unwrap()
216
32
        .file_name
217
32
        .clone()
218
32
        .expect("context should contain the input file");
219

            
220
32
    tracing::info!(target: "file", "Input file: {}", input_file);
221
32
    match global_args.parser {
222
        conjure_cp::settings::Parser::TreeSitter => {
223
14
            parse_essence_file_native(input_file.as_str(), context.clone()).map_err(|e| e.into())
224
        }
225
        conjure_cp::settings::Parser::ViaConjure => {
226
18
            conjure_executable()
227
18
                .map_err(|e| anyhow!("Could not find correct conjure executable: {e}"))?;
228

            
229
18
            let mut cmd = std::process::Command::new("conjure");
230
18
            let output = cmd
231
18
                .arg("pretty")
232
18
                .arg("--output-format=astjson")
233
18
                .arg(input_file)
234
18
                .output()?;
235

            
236
18
            let conjure_stderr = String::from_utf8(output.stderr)?;
237

            
238
18
            ensure!(conjure_stderr.is_empty(), conjure_stderr);
239

            
240
18
            let astjson = String::from_utf8(output.stdout)?;
241

            
242
18
            if cfg!(feature = "extra-rule-checks") {
243
18
                tracing::info!("extra-rule-checks: enabled");
244
            } else {
245
                tracing::info!("extra-rule-checks: disabled");
246
            }
247

            
248
18
            model_from_json(&astjson, context.clone()).map_err(|e| anyhow!(e))
249
        }
250
    }
251
32
}
252

            
253
14
pub(crate) fn rewrite(
254
14
    model: Model,
255
14
    global_args: &GlobalArgs,
256
14
    context: Arc<RwLock<Context<'static>>>,
257
14
) -> anyhow::Result<Model> {
258
14
    tracing::info!("Initial model: \n{}\n", model);
259

            
260
14
    set_current_rewriter(global_args.rewriter);
261

            
262
14
    let comprehension_expander = global_args.comprehension_expander;
263
14
    set_comprehension_expander(comprehension_expander);
264
14
    tracing::info!("Comprehension expander: {}", comprehension_expander);
265

            
266
14
    let rule_sets = context.read().unwrap().rule_sets.clone();
267

            
268
14
    let new_model = match global_args.rewriter {
269
        Rewriter::Morph => {
270
            tracing::info!("Rewriting the model using the morph rewriter");
271
            rewrite_morph(
272
                model,
273
                &rule_sets,
274
                global_args.check_equally_applicable_rules,
275
            )
276
        }
277
        Rewriter::Naive => {
278
14
            tracing::info!("Rewriting the model using the default / naive rewriter");
279
14
            rewrite_naive(
280
14
                &model,
281
14
                &rule_sets,
282
14
                global_args.check_equally_applicable_rules,
283
            )?
284
        }
285
    };
286

            
287
14
    tracing::info!("Rewritten model: \n{}\n", new_model);
288
14
    Ok(new_model)
289
14
}
290

            
291
10
fn run_solver(
292
10
    solver: Solver,
293
10
    global_args: &GlobalArgs,
294
10
    cmd_args: &Args,
295
10
    model: Model,
296
10
) -> anyhow::Result<()> {
297
10
    let out_file: Option<File> = match &cmd_args.output {
298
10
        None => None,
299
        Some(pth) => Some(
300
            File::options()
301
                .create(true)
302
                .truncate(true)
303
                .write(true)
304
                .open(pth)?,
305
        ),
306
    };
307

            
308
10
    let solutions = get_solutions(
309
10
        solver,
310
10
        model,
311
10
        cmd_args.number_of_solutions.as_solver_limit(),
312
10
        &global_args.save_solver_input_file,
313
    )?;
314
10
    tracing::info!(target: "file", "Solutions: {}", solutions_to_json(&solutions));
315

            
316
10
    let solutions_json = solutions_to_json(&solutions);
317
10
    let solutions_str = to_string_pretty(&solutions_json)?;
318
10
    match out_file {
319
10
        None => {
320
10
            println!("Solutions:");
321
10
            println!("{solutions_str}");
322
10
        }
323
        Some(mut outf) => {
324
            outf.write_all(solutions_str.as_bytes())?;
325
            println!(
326
                "Solutions saved to {:?}",
327
                &cmd_args.output.clone().unwrap().canonicalize()?
328
            )
329
        }
330
    }
331
10
    Ok(())
332
10
}