1
//! conjure_oxide solve sub-command
2
#![allow(clippy::unwrap_used)]
3
use std::time::Duration;
4
use std::{
5
    fs::File,
6
    io::Write as _,
7
    path::PathBuf,
8
    process::exit,
9
    sync::{Arc, RwLock},
10
};
11

            
12
use anyhow::{Context as _, anyhow, ensure};
13
use clap::ValueHint;
14
use conjure_cp::defaults::DEFAULT_RULE_SETS;
15
use conjure_cp::{
16
    Model,
17
    context::Context,
18
    rule_engine::{resolve_rule_sets, rewrite_morph, rewrite_naive},
19
    settings::{
20
        Rewriter, set_comprehension_expander, set_current_parser, set_current_rewriter,
21
        set_current_solver_family, set_minion_discrete_threshold,
22
    },
23
    solver::Solver,
24
};
25
use conjure_cp::{
26
    parse::conjure_json::model_from_json, rule_engine::get_rules, settings::SolverFamily,
27
};
28
use conjure_cp::{parse::tree_sitter::parse_essence_file_native, solver::adaptors::*};
29
use conjure_cp_cli::find_conjure::conjure_executable;
30
use conjure_cp_cli::utils::conjure::{get_solutions, solutions_to_json};
31
use serde_json::to_string_pretty;
32

            
33
use crate::cli::{GlobalArgs, LOGGING_HELP_HEADING};
34

            
35
/// How many solutions the solver should search for.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NumberOfSolutions {
    /// Search for every solution.
    All,
    /// Stop after this many solutions (strictly positive when produced by
    /// `parse_number_of_solutions`).
    Limit(i32),
}

impl NumberOfSolutions {
    /// Converts to the integer limit passed to the solver, where `All` is encoded as `0`.
    fn as_solver_limit(self) -> i32 {
        if let NumberOfSolutions::Limit(count) = self {
            count
        } else {
            0
        }
    }
}

/// Clap value parser for `--number-of-solutions`.
///
/// Accepts `all` (in any ASCII letter case) or a strictly positive `i32`; any other input
/// yields the same human-readable error string.
fn parse_number_of_solutions(input: &str) -> Result<NumberOfSolutions, String> {
    const EXPECTED: &str = "expected a positive integer or 'all'";

    if input.eq_ignore_ascii_case("all") {
        return Ok(NumberOfSolutions::All);
    }

    // Unparseable text and non-positive numbers are rejected with the same message.
    match input.parse::<i32>() {
        Ok(count) if count > 0 => Ok(NumberOfSolutions::Limit(count)),
        _ => Err(EXPECTED.to_string()),
    }
}
65
13

            
66
#[derive(Clone, Debug, clap::Args)]
67
pub struct Args {
68
    /// The input Essence file
69
    #[arg(value_name = "INPUT_ESSENCE", value_hint = ValueHint::FilePath)]
70
    pub input_file: PathBuf,
71

            
72
    /// Save execution info as JSON to the given filepath.
73
    #[arg(long ,value_hint=ValueHint::FilePath,help_heading=LOGGING_HELP_HEADING)]
74
    pub info_json_path: Option<PathBuf>,
75

            
76
    /// Do not run the solver.
77
    ///
78
    /// The rewritten model is printed to stdout in an Essence-style syntax
79
    /// (but is not necessarily valid Essence).
80
    #[arg(long, default_value_t = false)]
81
    pub no_run_solver: bool,
82

            
83
    /// Number of solutions to return. Use a positive integer, or `all`.
84
    #[arg(
85
        long,
86
        short = 'n',
87
        default_value = "1",
88
        value_name = "N|all",
89
        value_parser = parse_number_of_solutions
90
    )]
91
    pub number_of_solutions: NumberOfSolutions,
92

            
93
    /// Save solutions to the given JSON file
94
    #[arg(long, short = 'o', value_hint = ValueHint::FilePath,help_heading=LOGGING_HELP_HEADING)]
95
    pub output: Option<PathBuf>,
96
}
97

            
98
18
pub fn run_solve_command(global_args: GlobalArgs, solve_args: Args) -> anyhow::Result<()> {
99
18
    let input_file = solve_args.input_file.clone();
100

            
101
    // each step is in its own method so that similar commands
102
    // (e.g. testsolve) can reuse some of these steps.
103
13

            
104
31
    let context = init_context(&global_args, input_file)?;
105
31
    let model = parse(&global_args, Arc::clone(&context))?;
106
14
    let rewritten_model = rewrite(model, &global_args, Arc::clone(&context))?;
107
14
    let solver = init_solver(&global_args);
108

            
109
14
    if solve_args.no_run_solver {
110
17
        println!("{}", &rewritten_model);
111

            
112
17
        if let Some(path) = global_args.save_solver_input_file {
113
13
            let solver = solver.load_model(rewritten_model)?;
114
13
            eprintln!("Writing solver input file to {}", path.display());
115
13
            let mut file: Box<dyn std::io::Write> = Box::new(File::create(path)?);
116
13
            solver.write_solver_input_file(&mut file)?;
117
17
        }
118
    } else {
119
10
        run_solver(solver, &global_args, &solve_args, rewritten_model)?
120
13
    }
121

            
122
    // still do postamble even if we didn't run the solver
123
25
    if let Some(ref path) = solve_args.info_json_path {
124
4
        let context_obj = context.read().unwrap().clone();
125
4
        let generated_json = &serde_json::to_value(context_obj)?;
126
4
        let pretty_json = serde_json::to_string_pretty(&generated_json)?;
127
        File::create(path)?.write_all(pretty_json.as_bytes())?;
128
21
    }
129
14
    Ok(())
130
26
}
131

            
132
/// Returns a new Context and Solver for solving.
133
32
pub(crate) fn init_context(
134
40
    global_args: &GlobalArgs,
135
32
    input_file: PathBuf,
136
40
) -> anyhow::Result<Arc<RwLock<Context<'static>>>> {
137
34
    set_current_parser(global_args.parser);
138
32
    set_current_rewriter(global_args.rewriter);
139
34
    set_comprehension_expander(global_args.comprehension_expander);
140
32
    set_current_solver_family(global_args.solver);
141
32
    set_minion_discrete_threshold(global_args.minion_discrete_threshold);
142

            
143
32
    let target_family = global_args.solver;
144
34
    let mut extra_rule_sets: Vec<&str> = DEFAULT_RULE_SETS.to_vec();
145
32
    for rs in &global_args.extra_rule_sets {
146
6
        extra_rule_sets.push(rs.as_str());
147
    }
148

            
149
32
    if let SolverFamily::Sat(sat_encoding) = target_family {
150
8
        extra_rule_sets.push(sat_encoding.as_rule_set());
151
32
    }
152

            
153
32
    let rule_sets = match resolve_rule_sets(target_family, &extra_rule_sets) {
154
32
        Ok(rs) => rs,
155
8
        Err(e) => {
156
8
            tracing::error!("Error resolving rule sets: {}", e);
157
13
            exit(1);
158
        }
159
6
    };
160
6

            
161
38
    let pretty_rule_sets = rule_sets
162
32
        .iter()
163
44
        .map(|rule_set| rule_set.name)
164
44
        .collect::<Vec<_>>()
165
38
        .join(", ");
166

            
167
32
    tracing::info!("Enabled rule sets: [{}]", pretty_rule_sets);
168
32
    tracing::info!(
169
6
        target: "file",
170
6
        "Rule sets: {}",
171
6
        pretty_rule_sets
172
6
    );
173
6

            
174
32
    let rules = get_rules(&rule_sets)?.into_iter().collect::<Vec<_>>();
175
34
    tracing::info!(
176
        target: "file",
177
        "Rules: {}",
178
954
        rules.iter().map(|rd| format!("{rd}")).collect::<Vec<_>>().join("\n")
179
4
    );
180
32
    let context = Context::new_ptr(
181
32
        target_family,
182
100
        extra_rule_sets.iter().map(|rs| rs.to_string()).collect(),
183
36
        rules,
184
36
        rule_sets.clone(),
185
    );
186

            
187
36
    context.write().unwrap().file_name = Some(input_file.to_str().expect("").into());
188
2

            
189
34
    Ok(context)
190
34
}
191
2

            
192
14
pub(crate) fn init_solver(global_args: &GlobalArgs) -> Solver {
193
14
    let family = global_args.solver;
194
16
    let timeout_ms = global_args
195
16
        .solver_timeout
196
16
        .map(|dur| Duration::from(dur).as_millis())
197
14
        .map(|timeout_ms| u64::try_from(timeout_ms).expect("Timeout too large"));
198
2

            
199
16
    match family {
200
14
        SolverFamily::Minion => Solver::new(Minion::default()),
201
2
        SolverFamily::Sat(_) => Solver::new(Sat::default()),
202
        SolverFamily::Smt(theory_cfg) => Solver::new(Smt::new(timeout_ms, theory_cfg)),
203
    }
204
16
}
205
2

            
206
38
pub(crate) fn parse(
207
32
    global_args: &GlobalArgs,
208
32
    context: Arc<RwLock<Context<'static>>>,
209
54
) -> anyhow::Result<Model> {
210
54
    let input_file: String = context
211
54
        .read()
212
54
        .unwrap()
213
54
        .file_name
214
54
        .clone()
215
54
        .expect("context should contain the input file");
216
22

            
217
54
    tracing::info!(target: "file", "Input file: {}", input_file);
218
54
    match global_args.parser {
219
        conjure_cp::settings::Parser::TreeSitter => {
220
36
            parse_essence_file_native(input_file.as_str(), context.clone()).map_err(|e| e.into())
221
22
        }
222
22
        conjure_cp::settings::Parser::ViaConjure => {
223
18
            conjure_executable()
224
18
                .map_err(|e| anyhow!("Could not find correct conjure executable: {e}"))?;
225

            
226
40
            let mut cmd = std::process::Command::new("conjure");
227
18
            let output = cmd
228
40
                .arg("pretty")
229
18
                .arg("--output-format=astjson")
230
40
                .arg(input_file)
231
40
                .output()?;
232

            
233
18
            let conjure_stderr = String::from_utf8(output.stderr)?;
234

            
235
18
            ensure!(conjure_stderr.is_empty(), conjure_stderr);
236

            
237
18
            let astjson = String::from_utf8(output.stdout)?;
238
22

            
239
40
            if cfg!(feature = "extra-rule-checks") {
240
40
                tracing::info!("extra-rule-checks: enabled");
241
22
            } else {
242
22
                tracing::info!("extra-rule-checks: disabled");
243
            }
244
22

            
245
40
            model_from_json(&astjson, context.clone()).map_err(|e| anyhow!(e))
246
        }
247
    }
248
32
}
249

            
250
14
pub(crate) fn rewrite(
251
36
    model: Model,
252
36
    global_args: &GlobalArgs,
253
14
    context: Arc<RwLock<Context<'static>>>,
254
14
) -> anyhow::Result<Model> {
255
489
    tracing::info!("Initial model: \n{}\n", model);
256

            
257
36
    set_current_rewriter(global_args.rewriter);
258
22

            
259
80
    let comprehension_expander = global_args.comprehension_expander;
260
36
    set_comprehension_expander(comprehension_expander);
261
36
    tracing::info!("Comprehension expander: {}", comprehension_expander);
262

            
263
14
    let rule_sets = context.read().unwrap().rule_sets.clone();
264
22

            
265
36
    let new_model = match global_args.rewriter {
266
6
        Rewriter::Morph => {
267
16
            tracing::info!("Rewriting the model using the morph rewriter");
268
            rewrite_morph(
269
22
                model,
270
22
                &rule_sets,
271
                global_args.check_equally_applicable_rules,
272
9
            )
273
9
        }
274
9
        Rewriter::Naive => {
275
23
            tracing::info!("Rewriting the model using the default / naive rewriter");
276
23
            rewrite_naive(
277
23
                &model,
278
14
                &rule_sets,
279
23
                global_args.check_equally_applicable_rules,
280
9
            )?
281
        }
282
    };
283

            
284
23
    tracing::info!("Rewritten model: \n{}\n", new_model);
285
14
    Ok(new_model)
286
42
}
287
28

            
288
38
fn run_solver(
289
38
    solver: Solver,
290
38
    global_args: &GlobalArgs,
291
38
    cmd_args: &Args,
292
10
    model: Model,
293
38
) -> anyhow::Result<()> {
294
10
    let out_file: Option<File> = match &cmd_args.output {
295
17
        None => None,
296
        Some(pth) => Some(
297
21
            File::options()
298
                .create(true)
299
28
                .truncate(true)
300
                .write(true)
301
21
                .open(pth)?,
302
21
        ),
303
21
    };
304
21

            
305
31
    let solutions = get_solutions(
306
10
        solver,
307
31
        model,
308
31
        cmd_args.number_of_solutions.as_solver_limit(),
309
31
        &global_args.save_solver_input_file,
310
21
    )?;
311
31
    tracing::info!(target: "file", "Solutions: {}", solutions_to_json(&solutions));
312
21

            
313
10
    let solutions_json = solutions_to_json(&solutions);
314
31
    let solutions_str = to_string_pretty(&solutions_json)?;
315
10
    match out_file {
316
31
        None => {
317
10
            println!("Solutions:");
318
31
            println!("{solutions_str}");
319
10
        }
320
21
        Some(mut outf) => {
321
21
            outf.write_all(solutions_str.as_bytes())?;
322
            println!(
323
                "Solutions saved to {:?}",
324
                &cmd_args.output.clone().unwrap().canonicalize()?
325
            )
326
21
        }
327
21
    }
328
10
    Ok(())
329
19
}