1use std::{cell::Cell, fmt::Display, str::FromStr};
2
3use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5use strum_macros::{Display as StrumDisplay, EnumIter};
6
7use crate::bug;
8
9use crate::solver::adaptors::smt::{IntTheory, MatrixTheory, TheoryConfig};
10
/// Parser backend, selectable by name through [`FromStr`] and printed by [`Display`].
///
/// Names are `tree-sitter` and `via-conjure`; `TreeSitter` is the [`Default`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum Parser {
    /// Spelled `tree-sitter`; the default variant.
    #[default]
    TreeSitter,
    /// Spelled `via-conjure`.
    ViaConjure,
}

impl Display for Parser {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::TreeSitter => "tree-sitter",
            Self::ViaConjure => "via-conjure",
        };
        f.write_str(name)
    }
}

impl FromStr for Parser {
    type Err = String;

    /// Parses a parser name, ignoring surrounding whitespace and ASCII case.
    ///
    /// # Errors
    /// Returns a message naming the (normalised) input and the accepted names.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.trim().to_ascii_lowercase();
        match normalized.as_str() {
            "tree-sitter" => Ok(Self::TreeSitter),
            "via-conjure" => Ok(Self::ViaConjure),
            unknown => Err(format!(
                "unknown parser: {unknown}; expected one of: tree-sitter, via-conjure"
            )),
        }
    }
}
40
41thread_local! {
42 static CURRENT_PARSER: Cell<Option<Parser>> = const { Cell::new(None) };
46}
47
48pub fn set_current_parser(parser: Parser) {
49 CURRENT_PARSER.with(|current| current.set(Some(parser)));
50}
51
52pub fn current_parser() -> Parser {
53 CURRENT_PARSER.with(|current| {
54 current.get().unwrap_or_else(|| {
55 bug!("current parser not set for this thread; call set_current_parser first")
57 })
58 })
59}
60
/// Rewriter backend, selectable by name through [`FromStr`] (`naive` or `morph`).
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum Rewriter {
    /// Spelled `naive`.
    Naive,
    /// Spelled `morph`.
    Morph,
}

impl Display for Rewriter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Naive => "naive",
            Self::Morph => "morph",
        })
    }
}

impl FromStr for Rewriter {
    type Err = String;

    /// Parses a rewriter name, ignoring surrounding whitespace and ASCII case.
    ///
    /// # Errors
    /// Returns a message naming the (normalised) input and the accepted names.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let key = s.trim().to_ascii_lowercase();
        if key == "naive" {
            Ok(Self::Naive)
        } else if key == "morph" {
            Ok(Self::Morph)
        } else {
            Err(format!("unknown rewriter: {key}; expected one of: naive, morph"))
        }
    }
}

thread_local! {
    /// Rewriter chosen for the current thread; `None` until `set_current_rewriter` is called.
    static CURRENT_REWRITER: Cell<Option<Rewriter>> = const { Cell::new(None) };
}
96
97pub fn set_current_rewriter(rewriter: Rewriter) {
98 CURRENT_REWRITER.with(|current| current.set(Some(rewriter)));
99}
100
101pub fn current_rewriter() -> Rewriter {
102 CURRENT_REWRITER.with(|current| {
103 current.get().unwrap_or_else(|| {
104 bug!("current rewriter not set for this thread; call set_current_rewriter first")
106 })
107 })
108}
109
/// Strategy for expanding comprehensions, selectable by name through [`FromStr`].
///
/// Names are `native`, `via-solver` and `via-solver-ac`; [`Display`] prints the same names.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum QuantifiedExpander {
    /// Spelled `native`.
    Native,
    /// Spelled `via-solver`.
    ViaSolver,
    /// Spelled `via-solver-ac`.
    ViaSolverAc,
}

impl Display for QuantifiedExpander {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            QuantifiedExpander::Native => write!(f, "native"),
            QuantifiedExpander::ViaSolver => write!(f, "via-solver"),
            QuantifiedExpander::ViaSolverAc => write!(f, "via-solver-ac"),
        }
    }
}

impl FromStr for QuantifiedExpander {
    type Err = String;

    /// Parses an expander name, ignoring surrounding whitespace and ASCII case.
    ///
    /// # Errors
    /// Returns a message naming the (normalised) input and the accepted names.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.trim().to_ascii_lowercase().as_str() {
            "native" => Ok(QuantifiedExpander::Native),
            "via-solver" => Ok(QuantifiedExpander::ViaSolver),
            "via-solver-ac" => Ok(QuantifiedExpander::ViaSolverAc),
            // Report the normalised input, matching the other `FromStr` impls in this module
            // (previously the raw, untrimmed `s` was echoed back).
            other => Err(format!(
                "unknown comprehension expander: {other}; expected one of: \
                 native, via-solver, via-solver-ac"
            )),
        }
    }
}
142
143thread_local! {
144 static COMPREHENSION_EXPANDER: Cell<Option<QuantifiedExpander>> = const { Cell::new(None) };
148}
149
150pub fn set_comprehension_expander(expander: QuantifiedExpander) {
151 COMPREHENSION_EXPANDER.with(|current| current.set(Some(expander)));
152}
153
154pub fn comprehension_expander() -> QuantifiedExpander {
155 COMPREHENSION_EXPANDER.with(|current| {
156 current.get().unwrap_or_else(|| {
157 bug!(
159 "comprehension expander not set for this thread; call set_comprehension_expander first"
160 )
161 })
162 })
163}
164
165#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default, Serialize, Deserialize, JsonSchema)]
166pub enum SatEncoding {
167 #[default]
168 Log,
169 Direct,
170 Order,
171}
172
173impl SatEncoding {
174 pub const fn as_str(self) -> &'static str {
175 match self {
176 SatEncoding::Log => "log",
177 SatEncoding::Direct => "direct",
178 SatEncoding::Order => "order",
179 }
180 }
181
182 pub const fn as_rule_set(self) -> &'static str {
183 match self {
184 SatEncoding::Log => "SAT_Log",
185 SatEncoding::Direct => "SAT_Direct",
186 SatEncoding::Order => "SAT_Order",
187 }
188 }
189}
190
191impl Display for SatEncoding {
192 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193 match self {
194 SatEncoding::Log => write!(f, "log"),
195 SatEncoding::Direct => write!(f, "direct"),
196 SatEncoding::Order => write!(f, "order"),
197 }
198 }
199}
200
201impl FromStr for SatEncoding {
202 type Err = String;
203
204 fn from_str(s: &str) -> Result<Self, Self::Err> {
205 match s.trim().to_ascii_lowercase().as_str() {
206 "log" => Ok(SatEncoding::Log),
207 "direct" => Ok(SatEncoding::Direct),
208 "order" => Ok(SatEncoding::Order),
209 other => Err(format!(
210 "unknown sat-encoding: {other}; expected one of: log, direct, order"
211 )),
212 }
213 }
214}
215
216#[derive(
217 Debug,
218 EnumIter,
219 StrumDisplay,
220 PartialEq,
221 Eq,
222 Hash,
223 Clone,
224 Copy,
225 Serialize,
226 Deserialize,
227 JsonSchema,
228)]
229pub enum SolverFamily {
230 Minion,
231 Sat(SatEncoding),
232 Smt(TheoryConfig),
233}
234
235thread_local! {
236 static CURRENT_SOLVER_FAMILY: Cell<Option<SolverFamily>> = const { Cell::new(None) };
240}
241
242pub const DEFAULT_MINION_DISCRETE_THRESHOLD: usize = 10;
243
244thread_local! {
245 static MINION_DISCRETE_THRESHOLD: Cell<usize> =
250 const { Cell::new(DEFAULT_MINION_DISCRETE_THRESHOLD) };
251}
252
253pub fn set_current_solver_family(solver_family: SolverFamily) {
254 CURRENT_SOLVER_FAMILY.with(|current| current.set(Some(solver_family)));
255}
256
257pub fn current_solver_family() -> SolverFamily {
258 CURRENT_SOLVER_FAMILY.with(|current| {
259 current.get().unwrap_or_else(|| {
260 bug!(
262 "current solver family not set for this thread; call set_current_solver_family first"
263 )
264 })
265 })
266}
267
268pub fn set_minion_discrete_threshold(threshold: usize) {
269 MINION_DISCRETE_THRESHOLD.with(|current| current.set(threshold));
270}
271
272pub fn minion_discrete_threshold() -> usize {
273 MINION_DISCRETE_THRESHOLD.with(|current| current.get())
274}
275
276impl FromStr for SolverFamily {
277 type Err = String;
278
279 fn from_str(s: &str) -> Result<Self, Self::Err> {
280 let s = s.trim().to_ascii_lowercase();
281
282 match s.as_str() {
283 "minion" => Ok(SolverFamily::Minion),
284 "sat" | "sat-log" => Ok(SolverFamily::Sat(SatEncoding::Log)),
285 "sat-direct" => Ok(SolverFamily::Sat(SatEncoding::Direct)),
286 "sat-order" => Ok(SolverFamily::Sat(SatEncoding::Order)),
287 "smt" => Ok(SolverFamily::Smt(TheoryConfig::default())),
288 other => {
289 if other.starts_with("smt-") {
291 let parts = other.split('-').skip(1);
292 let mut ints = IntTheory::default();
293 let mut matrices = MatrixTheory::default();
294 let mut unwrap_alldiff = false;
295
296 for token in parts {
297 match token {
298 "" => {}
299 "lia" => ints = IntTheory::Lia,
300 "bv" => ints = IntTheory::Bv,
301 "arrays" => matrices = MatrixTheory::Arrays,
302 "atomic" => matrices = MatrixTheory::Atomic,
303 "nodiscrete" => unwrap_alldiff = true,
304 other_token => {
305 return Err(format!(
306 "unknown SMT theory option '{other_token}', must be one of bv|lia|arrays|atomic|nodiscrete"
307 ));
308 }
309 }
310 }
311
312 return Ok(SolverFamily::Smt(TheoryConfig {
313 ints,
314 matrices,
315 unwrap_alldiff,
316 }));
317 }
318 Err(format!(
319 "unknown solver family '{other}', expected one of: minion, sat-log, sat-direct, sat-order, smt[(bv|lia)-(arrays|atomic)][-nodiscrete]"
320 ))
321 }
322 }
323 }
324}
325
326impl SolverFamily {
327 pub fn as_str(&self) -> String {
328 match self {
329 SolverFamily::Minion => "minion".to_owned(),
330 SolverFamily::Sat(encoding) => format!("sat-{}", encoding.as_str()),
331 SolverFamily::Smt(theory_config) => format!("smt-{}", theory_config.as_str()),
332 }
333 }
334}
335
/// Options passed alongside a solver run.
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Hash)]
pub struct SolverArgs {
    /// Optional timeout, in milliseconds per the field name; `None` when not set.
    pub timeout_ms: Option<u64>,
}