1
//! conjure_oxide solve sub-command
2
#![allow(clippy::unwrap_used)]
3
use std::time::Duration;
4
use std::{
5
    fs::File,
6
    io::Write as _,
7
    path::PathBuf,
8
    process::exit,
9
    sync::{Arc, RwLock},
10
};
11

            
12
use anyhow::{anyhow, ensure};
13
use clap::ValueHint;
14
use conjure_cp::defaults::DEFAULT_RULE_SETS;
15
use conjure_cp::{
16
    Model,
17
    context::Context,
18
    rule_engine::{resolve_rule_sets, rewrite_morph, rewrite_naive},
19
    settings::{
20
        Rewriter, set_comprehension_expander, set_current_parser, set_current_rewriter,
21
        set_current_solver_family, set_minion_discrete_threshold,
22
    },
23
    solver::Solver,
24
};
25
use conjure_cp::{
26
    parse::conjure_json::model_from_json, rule_engine::get_rules, settings::SolverFamily,
27
};
28
use conjure_cp::{parse::tree_sitter::parse_essence_file_native, solver::adaptors::*};
29
use conjure_cp_cli::find_conjure::conjure_executable;
30
use conjure_cp_cli::utils::conjure::{get_solutions, solutions_to_json};
31
use serde_json::to_string_pretty;
32

            
33
use crate::cli::{GlobalArgs, LOGGING_HELP_HEADING};
34

            
35
/// How many solutions the `solve` command asks the solver for.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NumberOfSolutions {
    /// Ask for every solution.
    All,
    /// Ask for at most this many solutions. Values produced by
    /// [`parse_number_of_solutions`] are always strictly positive.
    Limit(i32),
}

impl NumberOfSolutions {
    /// Encodes this request as the integer limit handed to `get_solutions`.
    ///
    /// `All` is encoded as `0` (presumably "no limit" on the solver side; confirm in
    /// `get_solutions`), and `Limit(n)` is passed through as `n`.
    fn as_solver_limit(self) -> i32 {
        if let NumberOfSolutions::Limit(n) = self {
            n
        } else {
            0
        }
    }
}

/// clap value parser for `--number-of-solutions` / `-n`.
///
/// Accepts `all` (ASCII case-insensitive) or an integer greater than zero that fits in
/// an `i32`. Any other input yields the same human-readable error message.
fn parse_number_of_solutions(input: &str) -> Result<NumberOfSolutions, String> {
    const EXPECTED: &str = "expected a positive integer or 'all'";

    if input.eq_ignore_ascii_case("all") {
        Ok(NumberOfSolutions::All)
    } else {
        match input.parse::<i32>() {
            Ok(count) if count > 0 => Ok(NumberOfSolutions::Limit(count)),
            // Unparseable, zero, and negative inputs are all rejected identically.
            _ => Err(EXPECTED.to_string()),
        }
    }
}
65

            
66
// Command-line arguments accepted by the `solve` sub-command.
//
// The `///` comments on each field are turned into that option's `--help` text by
// clap's derive macro, so editing them changes user-visible output. Use `//` comments
// like this one for maintainer notes.
#[derive(Clone, Debug, clap::Args)]
pub struct Args {
    /// The input Essence file
    #[arg(value_name = "INPUT_ESSENCE", value_hint = ValueHint::FilePath)]
    pub input_file: PathBuf,

    // When set, `run_solve_command` serialises the shared context to this path,
    // whether or not the solver was run.
    /// Save execution info as JSON to the given filepath.
    #[arg(long ,value_hint=ValueHint::FilePath,help_heading=LOGGING_HELP_HEADING)]
    pub info_json_path: Option<PathBuf>,

    // Defaults to `false`; when set, `run_solve_command` prints the rewritten model
    // instead of invoking the solver.
    /// Do not run the solver.
    ///
    /// The rewritten model is printed to stdout in an Essence-style syntax
    /// (but is not necessarily valid Essence).
    #[arg(long, default_value_t = false)]
    pub no_run_solver: bool,

    // Parsed by `parse_number_of_solutions`: `all` (any ASCII case) or an integer > 0.
    // Defaults to a single solution.
    /// Number of solutions to return. Use a positive integer, or `all`.
    #[arg(
        long,
        short = 'n',
        default_value = "1",
        value_name = "N|all",
        value_parser = parse_number_of_solutions
    )]
    pub number_of_solutions: NumberOfSolutions,

    // When absent, solutions are printed to stdout instead.
    /// Save solutions to the given JSON file
    #[arg(long, short = 'o', value_hint = ValueHint::FilePath,help_heading=LOGGING_HELP_HEADING)]
    pub output: Option<PathBuf>,
}
97

            
98
36
pub fn run_solve_command(global_args: GlobalArgs, solve_args: Args) -> anyhow::Result<()> {
99
36
    let input_file = solve_args.input_file.clone();
100

            
101
    // each step is in its own method so that similar commands
102
    // (e.g. testsolve) can reuse some of these steps.
103

            
104
36
    let context = init_context(&global_args, input_file)?;
105
36
    let model = parse(&global_args, Arc::clone(&context))?;
106
28
    let rewritten_model = rewrite(model, &global_args, Arc::clone(&context))?;
107
28
    let solver = init_solver(&global_args);
108

            
109
28
    if solve_args.no_run_solver {
110
8
        println!("{}", &rewritten_model);
111

            
112
8
        if let Some(path) = global_args.save_solver_input_file {
113
            let solver = solver.load_model(rewritten_model)?;
114
            eprintln!("Writing solver input file to {}", path.display());
115
            let mut file: Box<dyn std::io::Write> = Box::new(File::create(path)?);
116
            solver.write_solver_input_file(&mut file)?;
117
8
        }
118
    } else {
119
20
        run_solver(solver, &global_args, &solve_args, rewritten_model)?
120
    }
121

            
122
    // still do postamble even if we didn't run the solver
123
28
    if let Some(ref path) = solve_args.info_json_path {
124
        let context_obj = context.read().unwrap().clone();
125
        let generated_json = &serde_json::to_value(context_obj)?;
126
        let pretty_json = serde_json::to_string_pretty(&generated_json)?;
127
        File::create(path)?.write_all(pretty_json.as_bytes())?;
128
28
    }
129
28
    Ok(())
130
36
}
131

            
132
/// Returns a new Context and Solver for solving.
133
64
pub(crate) fn init_context(
134
64
    global_args: &GlobalArgs,
135
64
    input_file: PathBuf,
136
64
) -> anyhow::Result<Arc<RwLock<Context<'static>>>> {
137
64
    set_current_parser(global_args.parser);
138
64
    set_current_rewriter(global_args.rewriter);
139
64
    set_comprehension_expander(global_args.comprehension_expander);
140
64
    set_current_solver_family(global_args.solver);
141
64
    set_minion_discrete_threshold(global_args.minion_discrete_threshold);
142

            
143
64
    let target_family = global_args.solver;
144
64
    let mut extra_rule_sets: Vec<&str> = DEFAULT_RULE_SETS.to_vec();
145
64
    for rs in &global_args.extra_rule_sets {
146
        extra_rule_sets.push(rs.as_str());
147
    }
148

            
149
64
    if let SolverFamily::Sat(sat_encoding) = target_family {
150
        extra_rule_sets.push(sat_encoding.as_rule_set());
151
64
    }
152

            
153
64
    let rule_sets = match resolve_rule_sets(target_family, &extra_rule_sets) {
154
64
        Ok(rs) => rs,
155
        Err(e) => {
156
            tracing::error!("Error resolving rule sets: {}", e);
157
            exit(1);
158
        }
159
    };
160

            
161
64
    let pretty_rule_sets = rule_sets
162
64
        .iter()
163
64
        .map(|rule_set| rule_set.name)
164
64
        .collect::<Vec<_>>()
165
64
        .join(", ");
166

            
167
64
    tracing::info!("Enabled rule sets: [{}]", pretty_rule_sets);
168
64
    tracing::info!(
169
        target: "file",
170
        "Rule sets: {}",
171
        pretty_rule_sets
172
    );
173

            
174
64
    let rules = get_rules(&rule_sets)?.into_iter().collect::<Vec<_>>();
175
64
    tracing::info!(
176
        target: "file",
177
        "Rules: {}",
178
1900
        rules.iter().map(|rd| format!("{rd}")).collect::<Vec<_>>().join("\n")
179
    );
180
64
    let context = Context::new_ptr(
181
64
        target_family,
182
192
        extra_rule_sets.iter().map(|rs| rs.to_string()).collect(),
183
64
        rules,
184
64
        rule_sets.clone(),
185
    );
186

            
187
64
    context.write().unwrap().file_name = Some(input_file.to_str().expect("").into());
188

            
189
64
    Ok(context)
190
64
}
191

            
192
28
pub(crate) fn init_solver(global_args: &GlobalArgs) -> Solver {
193
28
    let family = global_args.solver;
194
28
    let timeout_ms = global_args
195
28
        .solver_timeout
196
28
        .map(|dur| Duration::from(dur).as_millis())
197
28
        .map(|timeout_ms| u64::try_from(timeout_ms).expect("Timeout too large"));
198

            
199
28
    match family {
200
28
        SolverFamily::Minion => Solver::new(Minion::default()),
201
        SolverFamily::Sat(_) => Solver::new(Sat::default()),
202
        SolverFamily::Smt(theory_cfg) => Solver::new(Smt::new(timeout_ms, theory_cfg)),
203
    }
204
28
}
205

            
206
64
pub(crate) fn parse(
207
64
    global_args: &GlobalArgs,
208
64
    context: Arc<RwLock<Context<'static>>>,
209
64
) -> anyhow::Result<Model> {
210
64
    let input_file: String = context
211
64
        .read()
212
64
        .unwrap()
213
64
        .file_name
214
64
        .clone()
215
64
        .expect("context should contain the input file");
216

            
217
64
    tracing::info!(target: "file", "Input file: {}", input_file);
218
64
    match global_args.parser {
219
        conjure_cp::settings::Parser::TreeSitter => {
220
28
            parse_essence_file_native(input_file.as_str(), context.clone()).map_err(|e| e.into())
221
        }
222
        conjure_cp::settings::Parser::ViaConjure => {
223
36
            conjure_executable()
224
36
                .map_err(|e| anyhow!("Could not find correct conjure executable: {e}"))?;
225

            
226
36
            let mut cmd = std::process::Command::new("conjure");
227
36
            let output = cmd
228
36
                .arg("pretty")
229
36
                .arg("--output-format=astjson")
230
36
                .arg(input_file)
231
36
                .output()?;
232

            
233
36
            let conjure_stderr = String::from_utf8(output.stderr)?;
234

            
235
36
            ensure!(conjure_stderr.is_empty(), conjure_stderr);
236

            
237
36
            let astjson = String::from_utf8(output.stdout)?;
238

            
239
36
            if cfg!(feature = "extra-rule-checks") {
240
36
                tracing::info!("extra-rule-checks: enabled");
241
            } else {
242
                tracing::info!("extra-rule-checks: disabled");
243
            }
244

            
245
36
            model_from_json(&astjson, context.clone()).map_err(|e| anyhow!(e))
246
        }
247
    }
248
64
}
249

            
250
28
pub(crate) fn rewrite(
251
28
    model: Model,
252
28
    global_args: &GlobalArgs,
253
28
    context: Arc<RwLock<Context<'static>>>,
254
28
) -> anyhow::Result<Model> {
255
28
    tracing::info!("Initial model: \n{}\n", model);
256

            
257
28
    set_current_rewriter(global_args.rewriter);
258

            
259
28
    let comprehension_expander = global_args.comprehension_expander;
260
28
    set_comprehension_expander(comprehension_expander);
261
28
    tracing::info!("Comprehension expander: {}", comprehension_expander);
262

            
263
28
    let rule_sets = context.read().unwrap().rule_sets.clone();
264

            
265
28
    let new_model = match global_args.rewriter {
266
        Rewriter::Morph => {
267
            tracing::info!("Rewriting the model using the morph rewriter");
268
            rewrite_morph(
269
                model,
270
                &rule_sets,
271
                global_args.check_equally_applicable_rules,
272
            )
273
        }
274
        Rewriter::Naive => {
275
28
            tracing::info!("Rewriting the model using the default / naive rewriter");
276
28
            rewrite_naive(
277
28
                &model,
278
28
                &rule_sets,
279
28
                global_args.check_equally_applicable_rules,
280
            )?
281
        }
282
    };
283

            
284
28
    tracing::info!("Rewritten model: \n{}\n", new_model);
285
28
    Ok(new_model)
286
28
}
287

            
288
20
fn run_solver(
289
20
    solver: Solver,
290
20
    global_args: &GlobalArgs,
291
20
    cmd_args: &Args,
292
20
    model: Model,
293
20
) -> anyhow::Result<()> {
294
20
    let out_file: Option<File> = match &cmd_args.output {
295
20
        None => None,
296
        Some(pth) => Some(
297
            File::options()
298
                .create(true)
299
                .truncate(true)
300
                .write(true)
301
                .open(pth)?,
302
        ),
303
    };
304

            
305
20
    let solutions = get_solutions(
306
20
        solver,
307
20
        model,
308
20
        cmd_args.number_of_solutions.as_solver_limit(),
309
20
        &global_args.save_solver_input_file,
310
    )?;
311
20
    tracing::info!(target: "file", "Solutions: {}", solutions_to_json(&solutions));
312

            
313
20
    let solutions_json = solutions_to_json(&solutions);
314
20
    let solutions_str = to_string_pretty(&solutions_json)?;
315
20
    match out_file {
316
20
        None => {
317
20
            println!("Solutions:");
318
20
            println!("{solutions_str}");
319
20
        }
320
        Some(mut outf) => {
321
            outf.write_all(solutions_str.as_bytes())?;
322
            println!(
323
                "Solutions saved to {:?}",
324
                &cmd_args.output.clone().unwrap().canonicalize()?
325
            )
326
        }
327
    }
328
20
    Ok(())
329
20
}