1
//! conjure_oxide solve sub-command
2
#![allow(clippy::unwrap_used)]
3
#[cfg(feature = "smt")]
4
use std::time::Duration;
5
use std::{
6
    fs::File,
7
    io::Write as _,
8
    path::PathBuf,
9
    process::exit,
10
    sync::{Arc, RwLock},
11
};
12

            
13
use anyhow::{anyhow, ensure};
14
use clap::ValueHint;
15
use conjure_cp::defaults::DEFAULT_RULE_SETS;
16
use conjure_cp::{
17
    Model,
18
    context::Context,
19
    rule_engine::{resolve_rule_sets, rewrite_morph, rewrite_naive},
20
    settings::{
21
        Rewriter, set_comprehension_expander, set_current_parser, set_current_rewriter,
22
        set_current_solver_family, set_minion_discrete_threshold,
23
    },
24
    solver::Solver,
25
};
26
use conjure_cp::{
27
    parse::conjure_json::model_from_json, rule_engine::get_rules, settings::SolverFamily,
28
};
29
use conjure_cp::{parse::tree_sitter::parse_essence_file_native, solver::adaptors::*};
30
use conjure_cp_cli::find_conjure::conjure_executable;
31
use conjure_cp_cli::utils::conjure::{get_solutions, solutions_to_json};
32
use serde_json::to_string_pretty;
33

            
34
use crate::cli::{GlobalArgs, LOGGING_HELP_HEADING};
35

            
36
#[derive(Clone, Debug, clap::Args)]
37
pub struct Args {
38
    /// The input Essence file
39
    #[arg(value_name = "INPUT_ESSENCE", value_hint = ValueHint::FilePath)]
40
    pub input_file: PathBuf,
41

            
42
    /// Save execution info as JSON to the given filepath.
43
    #[arg(long ,value_hint=ValueHint::FilePath,help_heading=LOGGING_HELP_HEADING)]
44
    pub info_json_path: Option<PathBuf>,
45

            
46
    /// Do not run the solver.
47
    ///
48
    /// The rewritten model is printed to stdout in an Essence-style syntax
49
    /// (but is not necessarily valid Essence).
50
    #[arg(long, default_value_t = false)]
51
    pub no_run_solver: bool,
52

            
53
    /// Number of solutions to return. 0 returns all solutions
54
    #[arg(long, default_value_t = 0, short = 'n')]
55
    pub number_of_solutions: i32,
56

            
57
    /// Save solutions to the given JSON file
58
    #[arg(long, short = 'o', value_hint = ValueHint::FilePath,help_heading=LOGGING_HELP_HEADING)]
59
    pub output: Option<PathBuf>,
60
}
61

            
62
42
pub fn run_solve_command(global_args: GlobalArgs, solve_args: Args) -> anyhow::Result<()> {
63
42
    let input_file = solve_args.input_file.clone();
64

            
65
    // each step is in its own method so that similar commands
66
    // (e.g. testsolve) can reuse some of these steps.
67

            
68
42
    let context = init_context(&global_args, input_file)?;
69
42
    let model = parse(&global_args, Arc::clone(&context))?;
70
30
    let rewritten_model = rewrite(model, &global_args, Arc::clone(&context))?;
71
30
    let solver = init_solver(&global_args);
72

            
73
30
    if solve_args.no_run_solver {
74
        println!("{}", &rewritten_model);
75

            
76
        if let Some(path) = global_args.save_solver_input_file {
77
            let solver = solver.load_model(rewritten_model)?;
78
            eprintln!("Writing solver input file to {}", path.display());
79
            let mut file: Box<dyn std::io::Write> = Box::new(File::create(path)?);
80
            solver.write_solver_input_file(&mut file)?;
81
        }
82
    } else {
83
30
        run_solver(solver, &global_args, &solve_args, rewritten_model)?
84
    }
85

            
86
    // still do postamble even if we didn't run the solver
87
30
    if let Some(ref path) = solve_args.info_json_path {
88
        let context_obj = context.read().unwrap().clone();
89
        let generated_json = &serde_json::to_value(context_obj)?;
90
        let pretty_json = serde_json::to_string_pretty(&generated_json)?;
91
        File::create(path)?.write_all(pretty_json.as_bytes())?;
92
30
    }
93
30
    Ok(())
94
42
}
95

            
96
/// Returns a new Context and Solver for solving.
97
48
pub(crate) fn init_context(
98
48
    global_args: &GlobalArgs,
99
48
    input_file: PathBuf,
100
48
) -> anyhow::Result<Arc<RwLock<Context<'static>>>> {
101
48
    set_current_parser(global_args.parser);
102
48
    set_current_rewriter(global_args.rewriter);
103
48
    set_comprehension_expander(global_args.comprehension_expander);
104
48
    set_current_solver_family(global_args.solver);
105
48
    set_minion_discrete_threshold(global_args.minion_discrete_threshold);
106

            
107
48
    let target_family = global_args.solver;
108
48
    let mut extra_rule_sets: Vec<&str> = DEFAULT_RULE_SETS.to_vec();
109
48
    for rs in &global_args.extra_rule_sets {
110
        extra_rule_sets.push(rs.as_str());
111
    }
112

            
113
48
    if let SolverFamily::Sat(sat_encoding) = target_family {
114
        extra_rule_sets.push(sat_encoding.as_rule_set());
115
48
    }
116

            
117
48
    let rule_sets = match resolve_rule_sets(target_family, &extra_rule_sets) {
118
48
        Ok(rs) => rs,
119
        Err(e) => {
120
            tracing::error!("Error resolving rule sets: {}", e);
121
            exit(1);
122
        }
123
    };
124

            
125
48
    let pretty_rule_sets = rule_sets
126
48
        .iter()
127
48
        .map(|rule_set| rule_set.name)
128
48
        .collect::<Vec<_>>()
129
48
        .join(", ");
130

            
131
48
    tracing::info!("Enabled rule sets: [{}]", pretty_rule_sets);
132
48
    tracing::info!(
133
        target: "file",
134
        "Rule sets: {}",
135
        pretty_rule_sets
136
    );
137

            
138
48
    let rules = get_rules(&rule_sets)?.into_iter().collect::<Vec<_>>();
139
48
    tracing::info!(
140
        target: "file",
141
        "Rules: {}",
142
1710
        rules.iter().map(|rd| format!("{rd}")).collect::<Vec<_>>().join("\n")
143
    );
144
48
    let context = Context::new_ptr(
145
48
        target_family,
146
144
        extra_rule_sets.iter().map(|rs| rs.to_string()).collect(),
147
48
        rules,
148
48
        rule_sets.clone(),
149
    );
150

            
151
48
    context.write().unwrap().file_name = Some(input_file.to_str().expect("").into());
152

            
153
48
    Ok(context)
154
48
}
155

            
156
30
pub(crate) fn init_solver(global_args: &GlobalArgs) -> Solver {
157
30
    let family = global_args.solver;
158
    #[cfg(feature = "smt")]
159
30
    let timeout_ms = global_args
160
30
        .solver_timeout
161
30
        .map(|dur| Duration::from(dur).as_millis())
162
30
        .map(|timeout_ms| u64::try_from(timeout_ms).expect("Timeout too large"));
163

            
164
30
    match family {
165
30
        SolverFamily::Minion => Solver::new(Minion::default()),
166
        SolverFamily::Sat(_) => Solver::new(Sat::default()),
167
        #[cfg(feature = "smt")]
168
        SolverFamily::Smt(theory_cfg) => Solver::new(Smt::new(timeout_ms, theory_cfg)),
169
    }
170
30
}
171

            
172
48
pub(crate) fn parse(
173
48
    global_args: &GlobalArgs,
174
48
    context: Arc<RwLock<Context<'static>>>,
175
48
) -> anyhow::Result<Model> {
176
48
    let input_file: String = context
177
48
        .read()
178
48
        .unwrap()
179
48
        .file_name
180
48
        .clone()
181
48
        .expect("context should contain the input file");
182

            
183
48
    tracing::info!(target: "file", "Input file: {}", input_file);
184
48
    match global_args.parser {
185
        conjure_cp::settings::Parser::TreeSitter => {
186
42
            parse_essence_file_native(input_file.as_str(), context.clone()).map_err(|e| e.into())
187
        }
188
        conjure_cp::settings::Parser::ViaConjure => {
189
6
            conjure_executable()
190
6
                .map_err(|e| anyhow!("Could not find correct conjure executable: {e}"))?;
191

            
192
6
            let mut cmd = std::process::Command::new("conjure");
193
6
            let output = cmd
194
6
                .arg("pretty")
195
6
                .arg("--output-format=astjson")
196
6
                .arg(input_file)
197
6
                .output()?;
198

            
199
6
            let conjure_stderr = String::from_utf8(output.stderr)?;
200

            
201
6
            ensure!(conjure_stderr.is_empty(), conjure_stderr);
202

            
203
6
            let astjson = String::from_utf8(output.stdout)?;
204

            
205
6
            if cfg!(feature = "extra-rule-checks") {
206
6
                tracing::info!("extra-rule-checks: enabled");
207
            } else {
208
                tracing::info!("extra-rule-checks: disabled");
209
            }
210

            
211
6
            model_from_json(&astjson, context.clone()).map_err(|e| anyhow!(e))
212
        }
213
    }
214
48
}
215

            
216
30
pub(crate) fn rewrite(
217
30
    model: Model,
218
30
    global_args: &GlobalArgs,
219
30
    context: Arc<RwLock<Context<'static>>>,
220
30
) -> anyhow::Result<Model> {
221
30
    tracing::info!("Initial model: \n{}\n", model);
222

            
223
30
    set_current_rewriter(global_args.rewriter);
224

            
225
30
    let comprehension_expander = global_args.comprehension_expander;
226
30
    set_comprehension_expander(comprehension_expander);
227
30
    tracing::info!("Comprehension expander: {}", comprehension_expander);
228

            
229
30
    let rule_sets = context.read().unwrap().rule_sets.clone();
230

            
231
30
    let new_model = match global_args.rewriter {
232
        Rewriter::Morph => {
233
            tracing::info!("Rewriting the model using the morph rewriter");
234
            rewrite_morph(
235
                model,
236
                &rule_sets,
237
                global_args.check_equally_applicable_rules,
238
            )
239
        }
240
        Rewriter::Naive => {
241
30
            tracing::info!("Rewriting the model using the default / naive rewriter");
242
30
            rewrite_naive(
243
30
                &model,
244
30
                &rule_sets,
245
30
                global_args.check_equally_applicable_rules,
246
            )?
247
        }
248
    };
249

            
250
30
    tracing::info!("Rewritten model: \n{}\n", new_model);
251
30
    Ok(new_model)
252
30
}
253

            
254
30
fn run_solver(
255
30
    solver: Solver,
256
30
    global_args: &GlobalArgs,
257
30
    cmd_args: &Args,
258
30
    model: Model,
259
30
) -> anyhow::Result<()> {
260
30
    let out_file: Option<File> = match &cmd_args.output {
261
30
        None => None,
262
        Some(pth) => Some(
263
            File::options()
264
                .create(true)
265
                .truncate(true)
266
                .write(true)
267
                .open(pth)?,
268
        ),
269
    };
270

            
271
30
    let solutions = get_solutions(
272
30
        solver,
273
30
        model,
274
30
        cmd_args.number_of_solutions,
275
30
        &global_args.save_solver_input_file,
276
    )?;
277
30
    tracing::info!(target: "file", "Solutions: {}", solutions_to_json(&solutions));
278

            
279
30
    let solutions_json = solutions_to_json(&solutions);
280
30
    let solutions_str = to_string_pretty(&solutions_json)?;
281
30
    match out_file {
282
30
        None => {
283
30
            println!("Solutions:");
284
30
            println!("{solutions_str}");
285
30
        }
286
        Some(mut outf) => {
287
            outf.write_all(solutions_str.as_bytes())?;
288
            println!(
289
                "Solutions saved to {:?}",
290
                &cmd_args.output.clone().unwrap().canonicalize()?
291
            )
292
        }
293
    }
294
30
    Ok(())
295
30
}