use std::{cell::Cell, fmt::Display, str::FromStr};

use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use strum_macros::{Display as StrumDisplay, EnumIter};

use crate::bug;

use crate::solver::adaptors::smt::{IntTheory, MatrixTheory, TheoryConfig};

/// Identifies which parser is in use.
///
/// The default is [`Parser::TreeSitter`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum Parser {
    /// The default parser.
    #[default]
    TreeSitter,
    /// The alternative parser.
    ViaConjure,
}

impl Display for Parser {
19
59424
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20
59424
        match self {
21
5784
            Parser::TreeSitter => write!(f, "tree-sitter"),
22
53640
            Parser::ViaConjure => write!(f, "via-conjure"),
23
        }
24
59424
    }
25
}
26

            
27
impl FromStr for Parser {
28
    type Err = String;
29

            
30
24226
    fn from_str(s: &str) -> Result<Self, Self::Err> {
31
24226
        match s.trim().to_ascii_lowercase().as_str() {
32
24226
            "tree-sitter" => Ok(Parser::TreeSitter),
33
20254
            "via-conjure" => Ok(Parser::ViaConjure),
34
            other => Err(format!(
35
                "unknown parser: {other}; expected one of: tree-sitter, via-conjure"
36
            )),
37
        }
38
24226
    }
39
}
40

            
41
thread_local! {
42
    /// Thread-local setting for which parser is currently active.
43
    ///
44
    /// Must be explicitly set before use.
45
    static CURRENT_PARSER: Cell<Option<Parser>> = const { Cell::new(None) };
46
}
47

            
48
26734
pub fn set_current_parser(parser: Parser) {
49
26734
    CURRENT_PARSER.with(|current| current.set(Some(parser)));
50
26734
}
51

            
52
pub fn current_parser() -> Parser {
53
    CURRENT_PARSER.with(|current| {
54
        current.get().unwrap_or_else(|| {
55
            // loud failure on purpose, so we don't end up using the default
56
            bug!("current parser not set for this thread; call set_current_parser first")
57
        })
58
    })
59
}
60

            
61
/// Identifies which rewriter is in use.
///
/// There is deliberately no `Default`: a rewriter has to be chosen explicitly.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Rewriter {
    /// The `naive` rewriter.
    Naive,
    /// The `morph` rewriter.
    Morph,
}

thread_local! {
68
    /// Thread-local setting for which rewriter is currently active.
69
    ///
70
    /// Must be explicitly set before use.
71
    static CURRENT_REWRITER: Cell<Option<Rewriter>> = const { Cell::new(None) };
72
}
73

            
74
impl Display for Rewriter {
75
53398
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76
53398
        match self {
77
53398
            Rewriter::Naive => write!(f, "naive"),
78
            Rewriter::Morph => write!(f, "morph"),
79
        }
80
53398
    }
81
}
82

            
83
impl FromStr for Rewriter {
84
    type Err = String;
85

            
86
16828
    fn from_str(s: &str) -> Result<Self, Self::Err> {
87
16828
        match s.trim().to_ascii_lowercase().as_str() {
88
16828
            "naive" => Ok(Rewriter::Naive),
89
            "morph" => Ok(Rewriter::Morph),
90
            other => Err(format!(
91
                "unknown rewriter: {other}; expected one of: naive, morph"
92
            )),
93
        }
94
16828
    }
95
}
96

            
97
60062
pub fn set_current_rewriter(rewriter: Rewriter) {
98
60062
    CURRENT_REWRITER.with(|current| current.set(Some(rewriter)));
99
60062
}
100

            
101
4605
pub fn current_rewriter() -> Rewriter {
102
4605
    CURRENT_REWRITER.with(|current| {
103
4605
        current.get().unwrap_or_else(|| {
104
            // loud failure on purpose, so we don't end up using the default
105
            bug!("current rewriter not set for this thread; call set_current_rewriter first")
106
        })
107
4605
    })
108
4605
}
109

            
110
/// Identifies which comprehension expansion strategy is in use.
///
/// There is deliberately no `Default`: a strategy has to be chosen explicitly.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum QuantifiedExpander {
    /// The `native` strategy.
    Native,
    /// The `via-solver` strategy.
    ViaSolver,
    /// The `via-solver-ac` strategy.
    ViaSolverAc,
}

impl Display for QuantifiedExpander {
118
53422
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119
53422
        match self {
120
3658
            QuantifiedExpander::Native => write!(f, "native"),
121
2088
            QuantifiedExpander::ViaSolver => write!(f, "via-solver"),
122
47676
            QuantifiedExpander::ViaSolverAc => write!(f, "via-solver-ac"),
123
        }
124
53422
    }
125
}
126

            
127
impl FromStr for QuantifiedExpander {
128
    type Err = String;
129

            
130
18452
    fn from_str(s: &str) -> Result<Self, Self::Err> {
131
18452
        match s.trim().to_ascii_lowercase().as_str() {
132
18452
            "native" => Ok(QuantifiedExpander::Native),
133
16712
            "via-solver" => Ok(QuantifiedExpander::ViaSolver),
134
15668
            "via-solver-ac" => Ok(QuantifiedExpander::ViaSolverAc),
135
            _ => Err(format!(
136
                "unknown comprehension expander: {s}; expected one of: \
137
                 native, via-solver, via-solver-ac"
138
            )),
139
        }
140
18452
    }
141
}
142

            
143
thread_local! {
144
    /// Thread-local setting for which comprehension expansion strategy is currently active.
145
    ///
146
    /// Must be explicitly set before use.
147
    static COMPREHENSION_EXPANDER: Cell<Option<QuantifiedExpander>> = const { Cell::new(None) };
148
}
149

            
150
26822
pub fn set_comprehension_expander(expander: QuantifiedExpander) {
151
26822
    COMPREHENSION_EXPANDER.with(|current| current.set(Some(expander)));
152
26822
}
153

            
154
1734300
pub fn comprehension_expander() -> QuantifiedExpander {
155
1734300
    COMPREHENSION_EXPANDER.with(|current| {
156
1734300
        current.get().unwrap_or_else(|| {
157
            // loud failure on purpose, so we don't end up using the default
158
            bug!(
159
                "comprehension expander not set for this thread; call set_comprehension_expander first"
160
            )
161
        })
162
1734300
    })
163
1734300
}
164

            
165
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default, Serialize, Deserialize, JsonSchema)]
166
pub enum SatEncoding {
167
    #[default]
168
    Log,
169
    Direct,
170
    Order,
171
}
172

            
173
impl SatEncoding {
174
46378
    pub const fn as_str(self) -> &'static str {
175
46378
        match self {
176
16334
            SatEncoding::Log => "log",
177
19082
            SatEncoding::Direct => "direct",
178
10962
            SatEncoding::Order => "order",
179
        }
180
46378
    }
181

            
182
6654
    pub const fn as_rule_set(self) -> &'static str {
183
6654
        match self {
184
2362
            SatEncoding::Log => "SAT_Log",
185
2726
            SatEncoding::Direct => "SAT_Direct",
186
1566
            SatEncoding::Order => "SAT_Order",
187
        }
188
6654
    }
189
}
190

            
191
impl Display for SatEncoding {
192
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193
        match self {
194
            SatEncoding::Log => write!(f, "log"),
195
            SatEncoding::Direct => write!(f, "direct"),
196
            SatEncoding::Order => write!(f, "order"),
197
        }
198
    }
199
}
200

            
201
impl FromStr for SatEncoding {
202
    type Err = String;
203

            
204
    fn from_str(s: &str) -> Result<Self, Self::Err> {
205
        match s.trim().to_ascii_lowercase().as_str() {
206
            "log" => Ok(SatEncoding::Log),
207
            "direct" => Ok(SatEncoding::Direct),
208
            "order" => Ok(SatEncoding::Order),
209
            other => Err(format!(
210
                "unknown sat-encoding: {other}; expected one of: log, direct, order"
211
            )),
212
        }
213
    }
214
}
215

            
216
#[derive(
217
    Debug,
218
    EnumIter,
219
    StrumDisplay,
220
    PartialEq,
221
    Eq,
222
    Hash,
223
    Clone,
224
    Copy,
225
    Serialize,
226
    Deserialize,
227
    JsonSchema,
228
)]
229
pub enum SolverFamily {
230
    Minion,
231
    Sat(SatEncoding),
232
    Smt(TheoryConfig),
233
}
234

            
235
thread_local! {
236
    /// Thread-local setting for which solver family is currently active.
237
    ///
238
    /// Must be explicitly set before use.
239
    static CURRENT_SOLVER_FAMILY: Cell<Option<SolverFamily>> = const { Cell::new(None) };
240
}
241

            
242
/// Initial value of the per-thread Minion `DISCRETE` threshold.
pub const DEFAULT_MINION_DISCRETE_THRESHOLD: usize = 10;

thread_local! {
    /// Thread-local setting controlling when Minion int domains are emitted as `DISCRETE`.
    ///
    /// If an int domain size is <= this threshold, the Minion adaptor uses `DISCRETE`; otherwise
    /// it uses `BOUND`, unless another constraint requires `DISCRETE`.
    ///
    /// Every thread starts with [`DEFAULT_MINION_DISCRETE_THRESHOLD`].
    static MINION_DISCRETE_THRESHOLD: Cell<usize> =
        const { Cell::new(DEFAULT_MINION_DISCRETE_THRESHOLD) };
}

pub fn set_current_solver_family(solver_family: SolverFamily) {
254
26734
    CURRENT_SOLVER_FAMILY.with(|current| current.set(Some(solver_family)));
255
26734
}
256

            
257
pub fn current_solver_family() -> SolverFamily {
258
    CURRENT_SOLVER_FAMILY.with(|current| {
259
        current.get().unwrap_or_else(|| {
260
            // loud failure on purpose, so we don't end up using the default
261
            bug!(
262
                "current solver family not set for this thread; call set_current_solver_family first"
263
            )
264
        })
265
    })
266
}
267

            
268
26734
pub fn set_minion_discrete_threshold(threshold: usize) {
269
26734
    MINION_DISCRETE_THRESHOLD.with(|current| current.set(threshold));
270
26734
}
271

            
272
229256
pub fn minion_discrete_threshold() -> usize {
273
229256
    MINION_DISCRETE_THRESHOLD.with(|current| current.get())
274
229256
}
275

            
276
impl FromStr for SolverFamily {
277
    type Err = String;
278

            
279
23846
    fn from_str(s: &str) -> Result<Self, Self::Err> {
280
23846
        let s = s.trim().to_ascii_lowercase();
281

            
282
23846
        match s.as_str() {
283
23846
            "minion" => Ok(SolverFamily::Minion),
284
9728
            "sat" | "sat-log" => Ok(SolverFamily::Sat(SatEncoding::Log)),
285
7366
            "sat-direct" => Ok(SolverFamily::Sat(SatEncoding::Direct)),
286
4640
            "sat-order" => Ok(SolverFamily::Sat(SatEncoding::Order)),
287
3074
            "smt" => Ok(SolverFamily::Smt(TheoryConfig::default())),
288
3074
            other => {
289
                // allow forms like `smt-bv-atomic` or `smt-lia-arrays`
290
3074
                if other.starts_with("smt-") {
291
3074
                    let parts = other.split('-').skip(1);
292
3074
                    let mut ints = IntTheory::default();
293
3074
                    let mut matrices = MatrixTheory::default();
294
3074
                    let mut unwrap_alldiff = false;
295

            
296
6380
                    for token in parts {
297
6380
                        match token {
298
6380
                            "" => {}
299
6380
                            "lia" => ints = IntTheory::Lia,
300
3538
                            "bv" => ints = IntTheory::Bv,
301
3306
                            "arrays" => matrices = MatrixTheory::Arrays,
302
464
                            "atomic" => matrices = MatrixTheory::Atomic,
303
232
                            "nodiscrete" => unwrap_alldiff = true,
304
                            other_token => {
305
                                return Err(format!(
306
                                    "unknown SMT theory option '{other_token}', must be one of bv|lia|arrays|atomic|nodiscrete"
307
                                ));
308
                            }
309
                        }
310
                    }
311

            
312
3074
                    return Ok(SolverFamily::Smt(TheoryConfig {
313
3074
                        ints,
314
3074
                        matrices,
315
3074
                        unwrap_alldiff,
316
3074
                    }));
317
                }
318
                Err(format!(
319
                    "unknown solver family '{other}', expected one of: minion, sat-log, sat-direct, sat-order, smt[(bv|lia)-(arrays|atomic)][-nodiscrete]"
320
                ))
321
            }
322
        }
323
23846
    }
324
}
325

            
326
impl SolverFamily {
327
186448
    pub fn as_str(&self) -> String {
328
186448
        match self {
329
118552
            SolverFamily::Minion => "minion".to_owned(),
330
46378
            SolverFamily::Sat(encoding) => format!("sat-{}", encoding.as_str()),
331
21518
            SolverFamily::Smt(theory_config) => format!("smt-{}", theory_config.as_str()),
332
        }
333
186448
    }
334
}
335

            
336
/// Arguments passed along to a solver.
///
/// The default value leaves every option unset.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct SolverArgs {
    /// Optional timeout; `None` by default. The field name indicates milliseconds.
    pub timeout_ms: Option<u64>,
}