1
//! conjure_oxide solve sub-command
2
#![allow(clippy::unwrap_used)]
3
use std::time::Duration;
4
use std::{
5
    fs::File,
6
    io::Write as _,
7
    path::PathBuf,
8
    process::exit,
9
    sync::{Arc, RwLock},
10
};
11

            
12
use anyhow::anyhow;
13
use clap::ValueHint;
14
use conjure_cp::{
15
    Model,
16
    ast::{DeclarationPtr, declaration::Declaration, eval_constant},
17
    context::Context,
18
    rule_engine::{resolve_rule_sets, rewrite_morph, rewrite_naive},
19
    settings::{
20
        Rewriter, set_comprehension_expander, set_current_parser, set_current_rewriter,
21
        set_current_solver_family, set_minion_discrete_threshold,
22
    },
23
    solver::Solver,
24
};
25
use conjure_cp::{ast::DeclarationKind, defaults::DEFAULT_RULE_SETS};
26
use conjure_cp::{
27
    parse::conjure_json::model_from_json, rule_engine::get_rules, settings::SolverFamily,
28
};
29
use conjure_cp::{parse::tree_sitter::parse_essence_file_native, solver::adaptors::*};
30
use conjure_cp_cli::find_conjure::conjure_executable;
31
use conjure_cp_cli::utils::conjure::{get_solutions, solutions_to_json};
32
use serde_json::to_string_pretty;
33

            
34
use crate::cli::{GlobalArgs, LOGGING_HELP_HEADING};
35

            
36
/// How many solutions the solver should be asked for.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NumberOfSolutions {
    /// Ask for every solution.
    All,
    /// Ask for at most this many solutions (strictly positive when produced by
    /// [`parse_number_of_solutions`]).
    Limit(i32),
}

impl NumberOfSolutions {
    /// Encodes the request as the integer limit handed to the solver: `All` becomes `0`,
    /// `Limit(n)` becomes `n`.
    fn as_solver_limit(self) -> i32 {
        if let NumberOfSolutions::Limit(n) = self {
            n
        } else {
            0
        }
    }
}

/// Clap value parser for `--number-of-solutions`.
///
/// Accepts `all` (compared ASCII case-insensitively) or a strictly positive `i32`;
/// anything else is rejected with a short message.
fn parse_number_of_solutions(input: &str) -> Result<NumberOfSolutions, String> {
    const INVALID: &str = "expected a positive integer or 'all'";

    if input.eq_ignore_ascii_case("all") {
        Ok(NumberOfSolutions::All)
    } else {
        match input.parse::<i32>() {
            Ok(count) if count > 0 => Ok(NumberOfSolutions::Limit(count)),
            _ => Err(INVALID.to_string()),
        }
    }
}
66

            
67
// Command-line arguments of the `solve` sub-command (clap derive API).
//
// NOTE: the `///` comments on the fields below are rendered by clap as the `--help`
// text, so they are user-facing. Maintainer notes use plain `//` comments, which clap
// ignores.
#[derive(Clone, Debug, clap::Args)]
pub struct Args {
    // Positional and required: it has no `long`/`short` and is not an `Option`.
    /// The input Essence problem file
    #[arg(value_name = "INPUT_ESSENCE", value_hint = ValueHint::FilePath)]
    pub essence_file: PathBuf,

    // Optional second positional argument. When present, `run_solve_command` parses it
    // and substitutes its lettings for the problem's givens (see `instantiate_model`).
    /// The input Essence parameter file
    #[arg(value_name = "PARAM_ESSENCE", value_hint = ValueHint::FilePath)]
    pub param_file: Option<PathBuf>,

    // `--info-json-path <PATH>`: the context is serialised here after solving, and also
    // when `--no-run-solver` is given.
    /// Save execution info as JSON to the given filepath.
    #[arg(long ,value_hint=ValueHint::FilePath,help_heading=LOGGING_HELP_HEADING)]
    pub info_json_path: Option<PathBuf>,

    // `--no-run-solver`: stop after rewriting. If a solver input file path is set in the
    // global args, the solver input is still written.
    /// Do not run the solver.
    ///
    /// The rewritten model is printed to stdout in an Essence-style syntax
    /// (but is not necessarily valid Essence).
    #[arg(long, default_value_t = false)]
    pub no_run_solver: bool,

    // `-n` / `--number-of-solutions`: validated by `parse_number_of_solutions`. The
    // default "1" yields `NumberOfSolutions::Limit(1)`.
    /// Number of solutions to return. Use a positive integer, or `all`.
    #[arg(
        long,
        short = 'n',
        default_value = "1",
        value_name = "N|all",
        value_parser = parse_number_of_solutions
    )]
    pub number_of_solutions: NumberOfSolutions,

    // `-o` / `--output`: when unset, solutions are printed to stdout instead.
    /// Save solutions to the given JSON file
    #[arg(long, short = 'o', value_hint = ValueHint::FilePath,help_heading=LOGGING_HELP_HEADING)]
    pub output: Option<PathBuf>,
}
102

            
103
26
pub fn run_solve_command(global_args: GlobalArgs, solve_args: Args) -> anyhow::Result<()> {
104
26
    let essence_file = solve_args.essence_file.clone();
105
26
    let param_file = solve_args.param_file.clone();
106

            
107
    // each step is in its own method so that similar commands
108
    // (e.g. testsolve) can reuse some of these steps.
109

            
110
26
    let context = init_context(&global_args, essence_file, param_file)?;
111

            
112
26
    let ctx_lock = context.read().unwrap();
113
26
    let essence_file_name = ctx_lock
114
26
        .essence_file_name
115
26
        .as_ref()
116
26
        .expect("context should contain the problem input file");
117
26
    let param_file_name = ctx_lock.param_file_name.as_ref();
118

            
119
    // parse models
120
26
    let problem_model = parse(&global_args, Arc::clone(&context), essence_file_name)?;
121

            
122
    // unify models
123
22
    let unified_model = match param_file_name {
124
8
        Some(param_file_name) => {
125
8
            let param_model = parse(&global_args, Arc::clone(&context), param_file_name)?;
126
8
            instantiate_model(problem_model, param_model)?
127
        }
128
14
        None => problem_model,
129
    };
130
16
    drop(ctx_lock);
131

            
132
16
    let rewritten_model = rewrite(unified_model, &global_args, Arc::clone(&context))?;
133

            
134
16
    let solver = init_solver(&global_args);
135

            
136
16
    if solve_args.no_run_solver {
137
4
        println!("{}", &rewritten_model);
138

            
139
4
        if let Some(path) = global_args.save_solver_input_file {
140
            let solver = solver.load_model(rewritten_model)?;
141
            eprintln!("Writing solver input file to {}", path.display());
142
            let mut file: Box<dyn std::io::Write> = Box::new(File::create(path)?);
143
            solver.write_solver_input_file(&mut file)?;
144
4
        }
145
    } else {
146
12
        run_solver(solver, &global_args, &solve_args, rewritten_model)?
147
    }
148

            
149
    // still do postamble even if we didn't run the solver
150
16
    if let Some(ref path) = solve_args.info_json_path {
151
        let context_obj = context.read().unwrap().clone();
152
        let generated_json = &serde_json::to_value(context_obj)?;
153
        let pretty_json = serde_json::to_string_pretty(&generated_json)?;
154
        File::create(path)?.write_all(pretty_json.as_bytes())?;
155
16
    }
156
16
    Ok(())
157
26
}
158

            
159
12
pub(crate) fn instantiate_model(problem_model: Model, param_model: Model) -> anyhow::Result<Model> {
160
12
    let mut symbol_table = problem_model.symbols_ptr_unchecked().write();
161
12
    let param_table = param_model.symbols_ptr_unchecked().write();
162

            
163
24
    for (name, decl) in symbol_table.iter_local_mut() {
164
24
        let Some(domain) = decl.as_given() else {
165
12
            continue;
166
        };
167

            
168
        // Find corresponding letting in param file
169
12
        let param_decl = param_table.lookup(name);
170
12
        let expr = param_decl
171
12
                .as_ref()
172
12
                .and_then(DeclarationPtr::as_value_letting)
173
12
                .ok_or_else(|| anyhow!(
174
                    "Given declaration `{name}` does not have corresponding letting in parameter file"
175
4
                ))?;
176

            
177
        // Evaluate the letting expresison to a literal
178
8
        let expr_value = eval_constant(&expr)
179
8
            .ok_or_else(|| anyhow!("Letting expression `{expr}` cannot be evaluated"))?;
180

            
181
        // Resolve the given's domain
182
8
        let ground_domain = domain
183
8
            .resolve()
184
8
            .ok_or_else(|| anyhow!("Domain of given statement `{name}` cannot be resolved"))?;
185

            
186
        // Ensure the letting value is contained within the given expression's domain
187
8
        if !ground_domain.contains(&expr_value).unwrap() {
188
4
            return Err(anyhow!(
189
4
                "Domain of given statement `{name}` does not contain letting value"
190
4
            ));
191
4
        }
192

            
193
        // Replace the given statement in the model with the statement
194
4
        let new_decl = Declaration::new(
195
4
            name.clone(),
196
4
            DeclarationKind::ValueLetting(expr.clone(), Some(domain.clone())),
197
        );
198
4
        drop(domain);
199
4
        decl.replace(new_decl);
200

            
201
4
        tracing::info!("Replaced {name} given with letting.");
202
    }
203

            
204
4
    drop(symbol_table);
205
4
    Ok(problem_model)
206
12
}
207

            
208
/// Returns a new Context and Solver for solving.
209
44
pub(crate) fn init_context(
210
44
    global_args: &GlobalArgs,
211
44
    essence_file: PathBuf,
212
44
    param_file: Option<PathBuf>,
213
44
) -> anyhow::Result<Arc<RwLock<Context<'static>>>> {
214
44
    set_current_parser(global_args.parser);
215
44
    set_current_rewriter(global_args.rewriter);
216
44
    set_comprehension_expander(global_args.comprehension_expander);
217
44
    set_current_solver_family(global_args.solver);
218
44
    set_minion_discrete_threshold(global_args.minion_discrete_threshold);
219

            
220
44
    let target_family = global_args.solver;
221
44
    let mut extra_rule_sets: Vec<&str> = DEFAULT_RULE_SETS.to_vec();
222
44
    for rs in &global_args.extra_rule_sets {
223
        extra_rule_sets.push(rs.as_str());
224
    }
225

            
226
44
    if let SolverFamily::Sat(sat_encoding) = target_family {
227
        extra_rule_sets.push(sat_encoding.as_rule_set());
228
44
    }
229

            
230
44
    let rule_sets = match resolve_rule_sets(target_family, &extra_rule_sets) {
231
44
        Ok(rs) => rs,
232
        Err(e) => {
233
            tracing::error!("Error resolving rule sets: {}", e);
234
            exit(1);
235
        }
236
    };
237

            
238
44
    let pretty_rule_sets = rule_sets
239
44
        .iter()
240
44
        .map(|rule_set| rule_set.name)
241
44
        .collect::<Vec<_>>()
242
44
        .join(", ");
243

            
244
44
    tracing::info!("Enabled rule sets: [{}]", pretty_rule_sets);
245
44
    tracing::info!(
246
        target: "file",
247
        "Rule sets: {}",
248
        pretty_rule_sets
249
    );
250

            
251
44
    let rules = get_rules(&rule_sets)?.into_iter().collect::<Vec<_>>();
252
44
    tracing::info!(
253
        target: "file",
254
        "Rules: {}",
255
950
        rules.iter().map(|rd| format!("{rd}")).collect::<Vec<_>>().join("\n")
256
    );
257
44
    let context = Context::new_ptr(
258
44
        target_family,
259
132
        extra_rule_sets.iter().map(|rs| rs.to_string()).collect(),
260
44
        rules,
261
44
        rule_sets.clone(),
262
    );
263

            
264
44
    context.write().unwrap().essence_file_name = Some(essence_file.to_str().expect("").into());
265
44
    if let Some(param_file) = param_file {
266
12
        context.write().unwrap().param_file_name = Some(param_file.to_str().expect("").into());
267
32
    }
268

            
269
44
    Ok(context)
270
44
}
271

            
272
18
pub(crate) fn init_solver(global_args: &GlobalArgs) -> Solver {
273
18
    let family = global_args.solver;
274
18
    let timeout_ms = global_args
275
18
        .solver_timeout
276
18
        .map(|dur| Duration::from(dur).as_millis())
277
18
        .map(|timeout_ms| u64::try_from(timeout_ms).expect("Timeout too large"));
278

            
279
18
    match family {
280
18
        SolverFamily::Minion => Solver::new(Minion::default()),
281
        SolverFamily::Sat(_) => Solver::new(Sat::default()),
282
        SolverFamily::Smt(theory_cfg) => Solver::new(Smt::new(timeout_ms, theory_cfg)),
283
    }
284
18
}
285

            
286
56
pub(crate) fn parse(
287
56
    global_args: &GlobalArgs,
288
56
    context: Arc<RwLock<Context<'static>>>,
289
56
    file_path: &str,
290
56
) -> anyhow::Result<Model> {
291
56
    tracing::info!(target: "file", "Input file: {}", file_path);
292

            
293
56
    match global_args.parser {
294
        conjure_cp::settings::Parser::TreeSitter => {
295
14
            parse_essence_file_native(file_path, context.clone()).map_err(|e| e.into())
296
        }
297
42
        conjure_cp::settings::Parser::ViaConjure => parse_with_conjure(file_path, context.clone()),
298
    }
299
56
}
300

            
301
42
pub(crate) fn parse_with_conjure(
302
42
    input_file: &str,
303
42
    context: Arc<RwLock<Context<'static>>>,
304
42
) -> anyhow::Result<Model> {
305
42
    conjure_executable().map_err(|e| anyhow!("Could not find correct conjure executable: {e}"))?;
306

            
307
42
    let mut cmd = std::process::Command::new("conjure");
308
42
    let output = cmd
309
42
        .arg("pretty")
310
42
        .arg("--output-format=astjson")
311
42
        .arg(input_file)
312
42
        .output()?;
313

            
314
42
    if !output.status.success() {
315
        println!("Parsing error: {}", String::from_utf8(output.stderr)?);
316
42
    }
317

            
318
42
    let astjson = String::from_utf8(output.stdout)?;
319

            
320
42
    if cfg!(feature = "extra-rule-checks") {
321
42
        tracing::info!("extra-rule-checks: enabled");
322
    } else {
323
        tracing::info!("extra-rule-checks: disabled");
324
    }
325

            
326
42
    model_from_json(&astjson, context.clone()).map_err(|e| anyhow!(e))
327
42
}
328

            
329
18
pub(crate) fn rewrite(
330
18
    model: Model,
331
18
    global_args: &GlobalArgs,
332
18
    context: Arc<RwLock<Context<'static>>>,
333
18
) -> anyhow::Result<Model> {
334
18
    tracing::info!("Initial model: \n{}\n", model);
335

            
336
18
    set_current_rewriter(global_args.rewriter);
337

            
338
18
    let comprehension_expander = global_args.comprehension_expander;
339
18
    set_comprehension_expander(comprehension_expander);
340
18
    tracing::info!("Comprehension expander: {}", comprehension_expander);
341

            
342
18
    let rule_sets = context.read().unwrap().rule_sets.clone();
343

            
344
18
    let new_model = match global_args.rewriter {
345
        Rewriter::Morph => {
346
            tracing::info!("Rewriting the model using the morph rewriter");
347
            rewrite_morph(
348
                model,
349
                &rule_sets,
350
                global_args.check_equally_applicable_rules,
351
            )
352
        }
353
        Rewriter::Naive => {
354
18
            tracing::info!("Rewriting the model using the default / naive rewriter");
355
18
            rewrite_naive(
356
18
                &model,
357
18
                &rule_sets,
358
18
                global_args.check_equally_applicable_rules,
359
            )?
360
        }
361
    };
362

            
363
18
    tracing::info!("Rewritten model: \n{}\n", new_model);
364
18
    Ok(new_model)
365
18
}
366

            
367
12
fn run_solver(
368
12
    solver: Solver,
369
12
    global_args: &GlobalArgs,
370
12
    cmd_args: &Args,
371
12
    model: Model,
372
12
) -> anyhow::Result<()> {
373
12
    let out_file: Option<File> = match &cmd_args.output {
374
12
        None => None,
375
        Some(pth) => Some(
376
            File::options()
377
                .create(true)
378
                .truncate(true)
379
                .write(true)
380
                .open(pth)?,
381
        ),
382
    };
383

            
384
12
    let solutions = get_solutions(
385
12
        solver,
386
12
        model,
387
12
        cmd_args.number_of_solutions.as_solver_limit(),
388
12
        &global_args.save_solver_input_file,
389
    )?;
390
12
    tracing::info!(target: "file", "Solutions: {}", solutions_to_json(&solutions));
391

            
392
12
    let solutions_json = solutions_to_json(&solutions);
393
12
    let solutions_str = to_string_pretty(&solutions_json)?;
394
12
    match out_file {
395
12
        None => {
396
12
            println!("Solutions:");
397
12
            println!("{solutions_str}");
398
12
        }
399
        Some(mut outf) => {
400
            outf.write_all(solutions_str.as_bytes())?;
401
            println!(
402
                "Solutions saved to {:?}",
403
                &cmd_args.output.clone().unwrap().canonicalize()?
404
            )
405
        }
406
    }
407
12
    Ok(())
408
12
}