Skip to main content

conjure_oxide/
solve.rs

1//! conjure_oxide solve sub-command
2#![allow(clippy::unwrap_used)]
3use std::time::Duration;
4use std::{
5    fs::File,
6    io::Write as _,
7    path::PathBuf,
8    process::exit,
9    sync::{Arc, RwLock},
10};
11
12use anyhow::anyhow;
13use clap::ValueHint;
14use conjure_cp::instantiate::instantiate_model;
15use conjure_cp::{
16    Model,
17    context::Context,
18    defaults::DEFAULT_RULE_SETS,
19    rule_engine::{resolve_rule_sets, rewrite_morph, rewrite_naive},
20    settings::{
21        Rewriter, set_comprehension_expander, set_current_parser, set_current_rewriter,
22        set_current_solver_family, set_default_rule_trace_enabled, set_minion_discrete_threshold,
23        set_rule_trace_aggregates_enabled, set_rule_trace_enabled, set_rule_trace_verbose_enabled,
24    },
25    solver::Solver,
26};
27use conjure_cp::{
28    parse::conjure_json::model_from_json, rule_engine::get_rules, settings::SolverFamily,
29};
30use conjure_cp::{parse::tree_sitter::parse_essence_file_native, solver::adaptors::*};
31use conjure_cp_cli::find_conjure::conjure_executable;
32use conjure_cp_cli::utils::conjure::{get_solutions, solutions_to_json};
33use serde_json::to_string_pretty;
34
35use crate::cli::{GlobalArgs, LOGGING_HELP_HEADING};
36
/// How many solutions the `solve` command asks the solver for.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NumberOfSolutions {
    /// Ask for every solution.
    All,
    /// Ask for at most this many solutions (always positive when produced by
    /// `parse_number_of_solutions`).
    Limit(i32),
}

impl NumberOfSolutions {
    /// Encodes this setting as the integer limit passed to the solver:
    /// `All` becomes `0`, `Limit(n)` becomes `n`.
    fn as_solver_limit(self) -> i32 {
        if let NumberOfSolutions::Limit(n) = self {
            n
        } else {
            0
        }
    }
}

/// Value parser for `--number-of-solutions`.
///
/// Accepts `all` (ASCII case-insensitive) or a strictly positive `i32`.
///
/// # Errors
/// Returns a human-readable message for anything else, including zero,
/// negative numbers, non-numeric text and values that overflow `i32`.
fn parse_number_of_solutions(input: &str) -> Result<NumberOfSolutions, String> {
    const EXPECTED: &str = "expected a positive integer or 'all'";

    if input.eq_ignore_ascii_case("all") {
        Ok(NumberOfSolutions::All)
    } else {
        match input.parse::<i32>() {
            Ok(n) if n > 0 => Ok(NumberOfSolutions::Limit(n)),
            _ => Err(EXPECTED.to_string()),
        }
    }
}
67
68#[derive(Clone, Debug, clap::Args)]
69pub struct Args {
70    /// The input Essence problem file
71    #[arg(value_name = "INPUT_ESSENCE", value_hint = ValueHint::FilePath)]
72    pub essence_file: PathBuf,
73
74    /// The input Essence parameter file
75    #[arg(value_name = "PARAM_ESSENCE", value_hint = ValueHint::FilePath)]
76    pub param_file: Option<PathBuf>,
77
78    /// Save execution info as JSON to the given filepath.
79    #[arg(long ,value_hint=ValueHint::FilePath,help_heading=LOGGING_HELP_HEADING)]
80    pub info_json_path: Option<PathBuf>,
81
82    /// Do not run the solver.
83    ///
84    /// The rewritten model is printed to stdout in an Essence-style syntax
85    /// (but is not necessarily valid Essence).
86    #[arg(long, default_value_t = false)]
87    pub no_run_solver: bool,
88
89    /// Number of solutions to return. Use a positive integer, or `all`.
90    #[arg(
91        long,
92        short = 'n',
93        default_value = "1",
94        value_name = "N|all",
95        value_parser = parse_number_of_solutions
96    )]
97    pub number_of_solutions: NumberOfSolutions,
98
99    /// Save solutions to the given JSON file
100    #[arg(long, short = 'o', value_hint = ValueHint::FilePath,help_heading=LOGGING_HELP_HEADING)]
101    pub output: Option<PathBuf>,
102}
103
104pub fn run_solve_command(global_args: GlobalArgs, solve_args: Args) -> anyhow::Result<()> {
105    let essence_file = solve_args.essence_file.clone();
106    let param_file = solve_args.param_file.clone();
107
108    // each step is in its own method so that similar commands
109    // (e.g. testsolve) can reuse some of these steps.
110
111    let context = init_context(&global_args, essence_file, param_file)?;
112
113    let ctx_lock = context.read().unwrap();
114    let essence_file_name = ctx_lock
115        .essence_file_name
116        .as_ref()
117        .expect("context should contain the problem input file");
118    let param_file_name = ctx_lock.param_file_name.as_ref();
119
120    // parse models
121    let problem_model = parse(&global_args, Arc::clone(&context), essence_file_name)?;
122
123    // unify models
124    let unified_model = match param_file_name {
125        Some(param_file_name) => {
126            let param_model = parse(&global_args, Arc::clone(&context), param_file_name)?;
127            instantiate_model(problem_model, param_model)?
128        }
129        None => problem_model,
130    };
131    drop(ctx_lock);
132
133    let rewritten_model = rewrite(unified_model, &global_args, Arc::clone(&context))?;
134
135    let solver = init_solver(&global_args);
136
137    if solve_args.no_run_solver {
138        println!("{}", &rewritten_model);
139
140        if let Some(path) = global_args.save_solver_input_file {
141            let solver = solver.load_model(rewritten_model)?;
142            eprintln!("Writing solver input file to {}", path.display());
143            let mut file: Box<dyn std::io::Write> = Box::new(File::create(path)?);
144            solver.write_solver_input_file(&mut file)?;
145        }
146    } else {
147        run_solver(solver, &global_args, &solve_args, rewritten_model)?
148    }
149
150    // still do postamble even if we didn't run the solver
151    if let Some(ref path) = solve_args.info_json_path {
152        let context_obj = context.read().unwrap().clone();
153        let generated_json = &serde_json::to_value(context_obj)?;
154        let pretty_json = serde_json::to_string_pretty(&generated_json)?;
155        File::create(path)?.write_all(pretty_json.as_bytes())?;
156    }
157    Ok(())
158}
159
160/// Returns a new Context and Solver for solving.
161pub(crate) fn init_context(
162    global_args: &GlobalArgs,
163    essence_file: PathBuf,
164    param_file: Option<PathBuf>,
165) -> anyhow::Result<Arc<RwLock<Context<'static>>>> {
166    let default_rule_trace_enabled = global_args.rule_trace.is_some();
167    let verbose_rule_trace_enabled = global_args.rule_trace_verbose.is_some();
168    let rule_trace_aggregates_enabled = global_args.rule_trace_aggregates.is_some();
169    let rule_trace_enabled =
170        default_rule_trace_enabled || verbose_rule_trace_enabled || rule_trace_aggregates_enabled;
171
172    set_current_parser(global_args.parser);
173    set_current_rewriter(global_args.rewriter);
174    set_comprehension_expander(global_args.comprehension_expander);
175    set_current_solver_family(global_args.solver);
176    set_minion_discrete_threshold(global_args.minion_discrete_threshold);
177    set_rule_trace_enabled(rule_trace_enabled);
178    set_default_rule_trace_enabled(default_rule_trace_enabled);
179    set_rule_trace_verbose_enabled(verbose_rule_trace_enabled);
180    set_rule_trace_aggregates_enabled(rule_trace_aggregates_enabled);
181
182    let target_family = global_args.solver;
183    let mut extra_rule_sets: Vec<&str> = DEFAULT_RULE_SETS.to_vec();
184    for rs in &global_args.extra_rule_sets {
185        extra_rule_sets.push(rs.as_str());
186    }
187
188    if let SolverFamily::Sat(sat_encoding) = target_family {
189        extra_rule_sets.push(sat_encoding.as_rule_set());
190    }
191
192    let rule_sets = match resolve_rule_sets(target_family, &extra_rule_sets) {
193        Ok(rs) => rs,
194        Err(e) => {
195            tracing::error!("Error resolving rule sets: {}", e);
196            exit(1);
197        }
198    };
199
200    let pretty_rule_sets = rule_sets
201        .iter()
202        .map(|rule_set| rule_set.name)
203        .collect::<Vec<_>>()
204        .join(", ");
205
206    tracing::info!("Enabled rule sets: [{}]", pretty_rule_sets);
207    tracing::info!(
208        target: "file",
209        "Rule sets: {}",
210        pretty_rule_sets
211    );
212
213    let rules = get_rules(&rule_sets)?.into_iter().collect::<Vec<_>>();
214    tracing::info!(
215        target: "file",
216        "Rules: {}",
217        rules.iter().map(|rd| format!("{rd}")).collect::<Vec<_>>().join("\n")
218    );
219    let context = Context::new_ptr(
220        target_family,
221        extra_rule_sets.iter().map(|rs| rs.to_string()).collect(),
222        rules,
223        rule_sets.clone(),
224    );
225
226    context.write().unwrap().essence_file_name = Some(essence_file.to_str().expect("").into());
227    if let Some(param_file) = param_file {
228        context.write().unwrap().param_file_name = Some(param_file.to_str().expect("").into());
229    }
230
231    Ok(context)
232}
233
234pub(crate) fn init_solver(global_args: &GlobalArgs) -> Solver {
235    let family = global_args.solver;
236    let timeout_ms = global_args
237        .solver_timeout
238        .map(|dur| Duration::from(dur).as_millis())
239        .map(|timeout_ms| u64::try_from(timeout_ms).expect("Timeout too large"));
240
241    match family {
242        SolverFamily::Minion => Solver::new(Minion::default()),
243        SolverFamily::Sat(_) => Solver::new(Sat::default()),
244        SolverFamily::Smt(theory_cfg) => Solver::new(Smt::new(timeout_ms, theory_cfg)),
245    }
246}
247
248pub(crate) fn parse(
249    global_args: &GlobalArgs,
250    context: Arc<RwLock<Context<'static>>>,
251    file_path: &str,
252) -> anyhow::Result<Model> {
253    tracing::info!(target: "file", "Input file: {}", file_path);
254
255    match global_args.parser {
256        conjure_cp::settings::Parser::TreeSitter => {
257            parse_essence_file_native(file_path, context.clone()).map_err(|e| e.into())
258        }
259        conjure_cp::settings::Parser::ViaConjure => parse_with_conjure(file_path, context.clone()),
260    }
261}
262
263pub(crate) fn parse_with_conjure(
264    input_file: &str,
265    context: Arc<RwLock<Context<'static>>>,
266) -> anyhow::Result<Model> {
267    conjure_executable().map_err(|e| anyhow!("Could not find correct conjure executable: {e}"))?;
268
269    let mut cmd = std::process::Command::new("conjure");
270    let output = cmd
271        .arg("pretty")
272        .arg("--output-format=astjson")
273        .arg(input_file)
274        .output()?;
275
276    if !output.status.success() {
277        println!("Parsing error: {}", String::from_utf8(output.stderr)?);
278    }
279
280    let astjson = String::from_utf8(output.stdout)?;
281
282    if cfg!(feature = "extra-rule-checks") {
283        tracing::info!("extra-rule-checks: enabled");
284    } else {
285        tracing::info!("extra-rule-checks: disabled");
286    }
287
288    model_from_json(&astjson, context.clone()).map_err(|e| anyhow!(e))
289}
290
291pub(crate) fn rewrite(
292    model: Model,
293    global_args: &GlobalArgs,
294    context: Arc<RwLock<Context<'static>>>,
295) -> anyhow::Result<Model> {
296    tracing::info!("Initial model: \n{}\n", model);
297
298    let rewriter = global_args.rewriter;
299    set_current_rewriter(rewriter);
300
301    let comprehension_expander = global_args.comprehension_expander;
302    set_comprehension_expander(comprehension_expander);
303    tracing::info!("Comprehension expander: {}", comprehension_expander);
304
305    let rule_sets = context.read().unwrap().rule_sets.clone();
306
307    let new_model = match rewriter {
308        Rewriter::Morph(config) => {
309            tracing::info!("Rewriting the model using the morph rewriter ({})", config);
310            rewrite_morph(
311                model,
312                &rule_sets,
313                global_args.check_equally_applicable_rules,
314                config,
315            )
316        }
317        Rewriter::Naive => {
318            tracing::info!("Rewriting the model using the default / naive rewriter");
319            rewrite_naive(
320                &model,
321                &rule_sets,
322                global_args.check_equally_applicable_rules,
323            )?
324        }
325    };
326
327    tracing::info!("Rewritten model: \n{}\n", new_model);
328    Ok(new_model)
329}
330
331fn run_solver(
332    solver: Solver,
333    global_args: &GlobalArgs,
334    cmd_args: &Args,
335    model: Model,
336) -> anyhow::Result<()> {
337    let out_file: Option<File> = match &cmd_args.output {
338        None => None,
339        Some(pth) => Some(
340            File::options()
341                .create(true)
342                .truncate(true)
343                .write(true)
344                .open(pth)?,
345        ),
346    };
347
348    let solutions = get_solutions(
349        solver,
350        model,
351        cmd_args.number_of_solutions.as_solver_limit(),
352        &global_args.save_solver_input_file,
353        global_args.rule_trace_cdp,
354    )?;
355    tracing::info!(target: "file", "Solutions: {}", solutions_to_json(&solutions));
356
357    let solutions_json = solutions_to_json(&solutions);
358    let solutions_str = to_string_pretty(&solutions_json)?;
359    match out_file {
360        None => {
361            println!("Solutions:");
362            println!("{solutions_str}");
363        }
364        Some(mut outf) => {
365            outf.write_all(solutions_str.as_bytes())?;
366            println!(
367                "Solutions saved to {:?}",
368                &cmd_args.output.clone().unwrap().canonicalize()?
369            )
370        }
371    }
372    Ok(())
373}