1
use std::{cell::Cell, fmt::Display, str::FromStr};
2

            
3
use schemars::JsonSchema;
4
use serde::{Deserialize, Serialize};
5
use strum_macros::{Display as StrumDisplay, EnumIter};
6

            
7
use crate::bug;
8

            
9
#[cfg(feature = "smt")]
10
use crate::solver::adaptors::smt::{IntTheory, MatrixTheory, TheoryConfig};
11

            
12
/// Selects which parser is used.
///
/// [`Parser::TreeSitter`] is the [`Default`] variant.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
pub enum Parser {
    /// Spelled `tree-sitter` in text form.
    #[default]
    TreeSitter,
    /// Spelled `via-conjure` in text form.
    ViaConjure,
}
18

            
19
impl Display for Parser {
20
53168
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21
53168
        match self {
22
3540
            Parser::TreeSitter => write!(f, "tree-sitter"),
23
49628
            Parser::ViaConjure => write!(f, "via-conjure"),
24
        }
25
53168
    }
26
}
27

            
28
impl FromStr for Parser {
29
    type Err = String;
30

            
31
19916
    fn from_str(s: &str) -> Result<Self, Self::Err> {
32
19916
        match s.trim().to_ascii_lowercase().as_str() {
33
19916
            "tree-sitter" => Ok(Parser::TreeSitter),
34
17974
            "via-conjure" => Ok(Parser::ViaConjure),
35
            other => Err(format!(
36
                "unknown parser: {other}; expected one of: tree-sitter, via-conjure"
37
            )),
38
        }
39
19916
    }
40
}
41

            
42
thread_local! {
43
    /// Thread-local setting for which parser is currently active.
44
    ///
45
    /// Must be explicitly set before use.
46
    static CURRENT_PARSER: Cell<Option<Parser>> = const { Cell::new(None) };
47
}
48

            
49
25868
pub fn set_current_parser(parser: Parser) {
50
25868
    CURRENT_PARSER.with(|current| current.set(Some(parser)));
51
25868
}
52

            
53
pub fn current_parser() -> Parser {
54
    CURRENT_PARSER.with(|current| {
55
        current.get().unwrap_or_else(|| {
56
            // loud failure on purpose, so we don't end up using the default
57
            bug!("current parser not set for this thread; call set_current_parser first")
58
        })
59
    })
60
}
61

            
62
/// Selects which rewriter is used.
///
/// There is deliberately no `Default` impl here; the per-thread setting must be set explicitly.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Rewriter {
    /// Spelled `naive` in text form.
    Naive,
    /// Spelled `morph` in text form.
    Morph,
}
67

            
68
thread_local! {
69
    /// Thread-local setting for which rewriter is currently active.
70
    ///
71
    /// Must be explicitly set before use.
72
    static CURRENT_REWRITER: Cell<Option<Rewriter>> = const { Cell::new(None) };
73
}
74

            
75
impl Display for Rewriter {
76
51688
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77
51688
        match self {
78
51688
            Rewriter::Naive => write!(f, "naive"),
79
            Rewriter::Morph => write!(f, "morph"),
80
        }
81
51688
    }
82
}
83

            
84
impl FromStr for Rewriter {
85
    type Err = String;
86

            
87
17176
    fn from_str(s: &str) -> Result<Self, Self::Err> {
88
17176
        match s.trim().to_ascii_lowercase().as_str() {
89
17176
            "naive" => Ok(Rewriter::Naive),
90
            "morph" => Ok(Rewriter::Morph),
91
            other => Err(format!(
92
                "unknown rewriter: {other}; expected one of: naive, morph"
93
            )),
94
        }
95
17176
    }
96
}
97

            
98
57848
pub fn set_current_rewriter(rewriter: Rewriter) {
99
57848
    CURRENT_REWRITER.with(|current| current.set(Some(rewriter)));
100
57848
}
101

            
102
2040
pub fn current_rewriter() -> Rewriter {
103
2040
    CURRENT_REWRITER.with(|current| {
104
2040
        current.get().unwrap_or_else(|| {
105
            // loud failure on purpose, so we don't end up using the default
106
            bug!("current rewriter not set for this thread; call set_current_rewriter first")
107
        })
108
2040
    })
109
2040
}
110

            
111
/// Selects the comprehension expansion strategy.
///
/// There is deliberately no `Default` impl here; the per-thread setting must be set explicitly.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum QuantifiedExpander {
    /// Spelled `native` in text form.
    Native,
    /// Spelled `via-solver` in text form.
    ViaSolver,
    /// Spelled `via-solver-ac` in text form.
    ViaSolverAc,
}
117

            
118
impl Display for QuantifiedExpander {
119
51724
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120
51724
        match self {
121
3124
            QuantifiedExpander::Native => write!(f, "native"),
122
2160
            QuantifiedExpander::ViaSolver => write!(f, "via-solver"),
123
46440
            QuantifiedExpander::ViaSolverAc => write!(f, "via-solver-ac"),
124
        }
125
51724
    }
126
}
127

            
128
impl FromStr for QuantifiedExpander {
129
    type Err = String;
130

            
131
18856
    fn from_str(s: &str) -> Result<Self, Self::Err> {
132
18856
        match s.trim().to_ascii_lowercase().as_str() {
133
18856
            "native" => Ok(QuantifiedExpander::Native),
134
17240
            "via-solver" => Ok(QuantifiedExpander::ViaSolver),
135
16160
            "via-solver-ac" => Ok(QuantifiedExpander::ViaSolverAc),
136
            _ => Err(format!(
137
                "unknown comprehension expander: {s}; expected one of: \
138
                 native, via-solver, via-solver-ac"
139
            )),
140
        }
141
18856
    }
142
}
143

            
144
thread_local! {
145
    /// Thread-local setting for which comprehension expansion strategy is currently active.
146
    ///
147
    /// Must be explicitly set before use.
148
    static COMPREHENSION_EXPANDER: Cell<Option<QuantifiedExpander>> = const { Cell::new(None) };
149
}
150

            
151
25958
pub fn set_comprehension_expander(expander: QuantifiedExpander) {
152
25958
    COMPREHENSION_EXPANDER.with(|current| current.set(Some(expander)));
153
25958
}
154

            
155
876780
pub fn comprehension_expander() -> QuantifiedExpander {
156
876780
    COMPREHENSION_EXPANDER.with(|current| {
157
876780
        current.get().unwrap_or_else(|| {
158
            // loud failure on purpose, so we don't end up using the default
159
            bug!(
160
                "comprehension expander not set for this thread; call set_comprehension_expander first"
161
            )
162
        })
163
876780
    })
164
876780
}
165

            
166
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default, Serialize, Deserialize, JsonSchema)]
167
pub enum SatEncoding {
168
    #[default]
169
    Log,
170
    Direct,
171
    Order,
172
}
173

            
174
impl SatEncoding {
175
41320
    pub const fn as_str(self) -> &'static str {
176
41320
        match self {
177
16260
            SatEncoding::Log => "log",
178
14420
            SatEncoding::Direct => "direct",
179
10640
            SatEncoding::Order => "order",
180
        }
181
41320
    }
182

            
183
5920
    pub const fn as_rule_set(self) -> &'static str {
184
5920
        match self {
185
2340
            SatEncoding::Log => "SAT_Log",
186
2060
            SatEncoding::Direct => "SAT_Direct",
187
1520
            SatEncoding::Order => "SAT_Order",
188
        }
189
5920
    }
190
}
191

            
192
impl Display for SatEncoding {
193
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194
        match self {
195
            SatEncoding::Log => write!(f, "log"),
196
            SatEncoding::Direct => write!(f, "direct"),
197
            SatEncoding::Order => write!(f, "order"),
198
        }
199
    }
200
}
201

            
202
impl FromStr for SatEncoding {
203
    type Err = String;
204

            
205
    fn from_str(s: &str) -> Result<Self, Self::Err> {
206
        match s.trim().to_ascii_lowercase().as_str() {
207
            "log" => Ok(SatEncoding::Log),
208
            "direct" => Ok(SatEncoding::Direct),
209
            "order" => Ok(SatEncoding::Order),
210
            other => Err(format!(
211
                "unknown sat-encoding: {other}; expected one of: log, direct, order"
212
            )),
213
        }
214
    }
215
}
216

            
217
#[derive(
218
    Debug,
219
    EnumIter,
220
    StrumDisplay,
221
    PartialEq,
222
    Eq,
223
    Hash,
224
    Clone,
225
    Copy,
226
    Serialize,
227
    Deserialize,
228
    JsonSchema,
229
)]
230
pub enum SolverFamily {
231
    Minion,
232
    Sat(SatEncoding),
233
    #[cfg(feature = "smt")]
234
    Smt(TheoryConfig),
235
}
236

            
237
thread_local! {
238
    /// Thread-local setting for which solver family is currently active.
239
    ///
240
    /// Must be explicitly set before use.
241
    static CURRENT_SOLVER_FAMILY: Cell<Option<SolverFamily>> = const { Cell::new(None) };
242
}
243

            
244
/// Initial value of the per-thread Minion `DISCRETE` threshold (`MINION_DISCRETE_THRESHOLD`).
pub const DEFAULT_MINION_DISCRETE_THRESHOLD: usize = 10;
245

            
246
thread_local! {
247
    /// Thread-local setting controlling when Minion int domains are emitted as `DISCRETE`.
248
    ///
249
    /// If an int domain size is <= this threshold, the Minion adaptor uses `DISCRETE`; otherwise
250
    /// it uses `BOUND`, unless another constraint requires `DISCRETE`.
251
    static MINION_DISCRETE_THRESHOLD: Cell<usize> =
252
        const { Cell::new(DEFAULT_MINION_DISCRETE_THRESHOLD) };
253
}
254

            
255
25868
pub fn set_current_solver_family(solver_family: SolverFamily) {
256
25868
    CURRENT_SOLVER_FAMILY.with(|current| current.set(Some(solver_family)));
257
25868
}
258

            
259
pub fn current_solver_family() -> SolverFamily {
260
    CURRENT_SOLVER_FAMILY.with(|current| {
261
        current.get().unwrap_or_else(|| {
262
            // loud failure on purpose, so we don't end up using the default
263
            bug!(
264
                "current solver family not set for this thread; call set_current_solver_family first"
265
            )
266
        })
267
    })
268
}
269

            
270
25868
pub fn set_minion_discrete_threshold(threshold: usize) {
271
25868
    MINION_DISCRETE_THRESHOLD.with(|current| current.set(threshold));
272
25868
}
273

            
274
236080
pub fn minion_discrete_threshold() -> usize {
275
236080
    MINION_DISCRETE_THRESHOLD.with(|current| current.get())
276
236080
}
277

            
278
impl FromStr for SolverFamily {
279
    type Err = String;
280

            
281
23076
    fn from_str(s: &str) -> Result<Self, Self::Err> {
282
23076
        let s = s.trim().to_ascii_lowercase();
283

            
284
23076
        match s.as_str() {
285
23076
            "minion" => Ok(SolverFamily::Minion),
286
8660
            "sat" | "sat-log" => Ok(SolverFamily::Sat(SatEncoding::Log)),
287
6320
            "sat-direct" => Ok(SolverFamily::Sat(SatEncoding::Direct)),
288
4220
            "sat-order" => Ok(SolverFamily::Sat(SatEncoding::Order)),
289
            #[cfg(feature = "smt")]
290
2700
            "smt" => Ok(SolverFamily::Smt(TheoryConfig::default())),
291
2700
            other => {
292
                // allow forms like `smt-bv-atomic` or `smt-lia-arrays`
293
                #[cfg(feature = "smt")]
294
2700
                if other.starts_with("smt-") {
295
2700
                    let parts = other.split('-').skip(1);
296
2700
                    let mut ints = IntTheory::default();
297
2700
                    let mut matrices = MatrixTheory::default();
298
2700
                    let mut unwrap_alldiff = false;
299

            
300
5400
                    for token in parts {
301
5400
                        match token {
302
5400
                            "" => {}
303
5400
                            "lia" => ints = IntTheory::Lia,
304
2700
                            "bv" => ints = IntTheory::Bv,
305
2700
                            "arrays" => matrices = MatrixTheory::Arrays,
306
                            "atomic" => matrices = MatrixTheory::Atomic,
307
                            "nodiscrete" => unwrap_alldiff = true,
308
                            other_token => {
309
                                return Err(format!(
310
                                    "unknown SMT theory option '{other_token}', must be one of bv|lia|arrays|atomic|nodiscrete"
311
                                ));
312
                            }
313
                        }
314
                    }
315

            
316
2700
                    return Ok(SolverFamily::Smt(TheoryConfig {
317
2700
                        ints,
318
2700
                        matrices,
319
2700
                        unwrap_alldiff,
320
2700
                    }));
321
                }
322
                Err(format!(
323
                    "unknown solver family '{other}', expected one of: minion, sat-log, sat-direct, sat-order, smt[(bv|lia)-(arrays|atomic)][-nodiscrete]"
324
                ))
325
            }
326
        }
327
23076
    }
328
}
329

            
330
impl SolverFamily {
331
180220
    pub fn as_str(&self) -> String {
332
180220
        match self {
333
120000
            SolverFamily::Minion => "minion".to_owned(),
334
41320
            SolverFamily::Sat(encoding) => format!("sat-{}", encoding.as_str()),
335
            #[cfg(feature = "smt")]
336
18900
            SolverFamily::Smt(theory_config) => format!("smt-{}", theory_config.as_str()),
337
        }
338
180220
    }
339
}
340

            
341
/// Arguments passed to a solver.
///
/// The [`Default`] value has no timeout set.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct SolverArgs {
    /// Optional timeout; going by the field name, the unit is milliseconds.
    pub timeout_ms: Option<u64>,
}