1use crate::bug;
6use crate::representation::{Representation, get_repr_rule};
7use std::any::TypeId;
8
9use std::collections::BTreeSet;
10use std::collections::btree_map::Entry;
11use std::collections::{BTreeMap, VecDeque};
12use std::hash::{Hash, Hasher};
13use std::sync::Arc;
14use std::sync::atomic::{AtomicU32, Ordering};
15
16use super::comprehension::Comprehension;
17use super::serde::{AsId, DefaultWithId, HasId, IdPtr, ObjId, PtrAsInner};
18use super::{
19 DeclarationPtr, DomainPtr, Expression, GroundDomain, Model, Moo, Name, ReturnType, Typeable,
20};
21use itertools::{Itertools as _, izip};
22use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
23use serde::{Deserialize, Serialize};
24use serde_with::serde_as;
25use tracing::trace;
26use uniplate::{Biplate, Tree, Uniplate};
27
/// Monotonic source of numeric ids.
///
/// Shared by the ids of [`SymbolTablePtr`]s and the ids printed in the trace emitted whenever a
/// [`SymbolTable`] is constructed.
static SYMBOL_TABLE_ID_COUNTER: AtomicU32 = AtomicU32::new(0);
33
/// A shared, lock-protected handle to a [`SymbolTable`].
///
/// Cloning a `SymbolTablePtr` clones the inner `Arc`, so every clone refers to the same table
/// and carries the same id. The derived `PartialEq`/`Hash` forward to the impls on
/// `SymbolTablePtrInner` (through `Arc`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SymbolTablePtr
where
    // Requires the handle to be `Send + Sync`; this acts as a compile-time assertion that the
    // pointer stays shareable across threads.
    Self: Send + Sync,
{
    inner: Arc<SymbolTablePtrInner>,
}
41
42impl SymbolTablePtr {
43 pub fn new() -> Self {
45 Self::new_with_data(SymbolTable::new())
46 }
47
48 pub fn with_parent(symbols: SymbolTablePtr) -> Self {
50 Self::new_with_data(SymbolTable::with_parent(symbols))
51 }
52
53 fn new_with_data(data: SymbolTable) -> Self {
54 let object_id = SYMBOL_TABLE_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
55 let id = ObjId {
56 object_id,
57 type_name: SymbolTablePtr::TYPE_NAME.into(),
58 };
59 Self::new_with_id_and_data(id, data)
60 }
61
62 fn new_with_id_and_data(id: ObjId, data: SymbolTable) -> Self {
63 Self {
64 inner: Arc::new(SymbolTablePtrInner {
65 id,
66 value: RwLock::new(data),
67 }),
68 }
69 }
70
71 pub fn read(&self) -> RwLockReadGuard<'_, SymbolTable> {
78 self.inner.value.read()
79 }
80
81 pub fn write(&self) -> RwLockWriteGuard<'_, SymbolTable> {
94 self.inner.value.write()
95 }
96
97 pub fn detach(&self) -> Self {
100 Self::new_with_data(self.read().clone())
101 }
102}
103
104impl Default for SymbolTablePtr {
105 fn default() -> Self {
106 Self::new()
107 }
108}
109
110impl HasId for SymbolTablePtr {
111 const TYPE_NAME: &'static str = "SymbolTable";
112
113 fn id(&self) -> ObjId {
114 self.inner.id.clone()
115 }
116}
117
118impl DefaultWithId for SymbolTablePtr {
119 fn default_with_id(id: ObjId) -> Self {
120 Self::new_with_id_and_data(id, SymbolTable::default())
121 }
122}
123
124impl IdPtr for SymbolTablePtr {
125 type Data = SymbolTable;
126
127 fn get_data(&self) -> Self::Data {
128 self.read().clone()
129 }
130
131 fn with_id_and_data(id: ObjId, data: Self::Data) -> Self {
132 Self::new_with_id_and_data(id, data)
133 }
134}
135
136impl Uniplate for SymbolTablePtr {
142 fn uniplate(&self) -> (Tree<Self>, Box<dyn Fn(Tree<Self>) -> Self>) {
143 let symtab = self.read();
144 let (tree, recons) = Biplate::<SymbolTablePtr>::biplate(&symtab as &SymbolTable);
145
146 let self2 = self.clone();
147 (
148 tree,
149 Box::new(move |x| {
150 let self3 = self2.clone();
151 *(self3.write()) = recons(x);
152 self3
153 }),
154 )
155 }
156}
157
/// Traverses into the [`SymbolTable`] behind this pointer for any target type `To`.
///
/// When `To` is `SymbolTablePtr` itself, the pointer is its own single child (`Tree::One`).
/// Otherwise the children are those of the pointed-to `SymbolTable`. In that case the
/// reconstruction function writes the rebuilt table back through a clone of this pointer.
/// NOTE(review): clones share the same `Arc`, so this updates the table for every holder of the
/// pointer. Confirm this in-place behaviour is intended.
impl<To> Biplate<To> for SymbolTablePtr
where
    SymbolTable: Biplate<To>,
    To: Uniplate,
{
    fn biplate(&self) -> (Tree<To>, Box<dyn Fn(Tree<To>) -> Self>) {
        if TypeId::of::<To>() == TypeId::of::<Self>() {
            // SAFETY: the `TypeId` comparison above proves `To` and `Self` are the same type, so
            // both transmutes below are identity casts between `&Self` and `&To` (same size,
            // same layout, same lifetime).
            unsafe {
                let self_as_to = std::mem::transmute::<&Self, &To>(self).clone();
                (
                    Tree::One(self_as_to),
                    Box::new(move |x| {
                        // The tree handed back must have the shape produced above.
                        let Tree::One(x) = x else { panic!() };

                        let x_as_self = std::mem::transmute::<&To, &Self>(&x);
                        x_as_self.clone()
                    }),
                )
            }
        } else {
            // Read guard on this pointer's symbol table (despite the name `decl`).
            let decl = self.read();
            let (tree, recons) = Biplate::<To>::biplate(&decl as &SymbolTable);

            let self2 = self.clone();
            (
                tree,
                Box::new(move |x| {
                    let self3 = self2.clone();
                    // `recons(x)` is evaluated before the write lock is taken.
                    *(self3.write()) = recons(x);
                    self3
                }),
            )
        }
    }
}
194
/// Shared state behind a [`SymbolTablePtr`]: its id plus the lock-protected table.
#[derive(Debug)]
struct SymbolTablePtrInner {
    /// Identity of the pointer, as returned by `HasId::id`.
    id: ObjId,
    /// The symbol table itself, guarded by a `parking_lot` read-write lock.
    value: RwLock<SymbolTable>,
}
200
201impl Hash for SymbolTablePtrInner {
202 fn hash<H: Hasher>(&self, state: &mut H) {
203 self.id.hash(state);
204 }
205}
206
207impl PartialEq for SymbolTablePtrInner {
208 fn eq(&self, other: &Self) -> bool {
209 self.value.read().eq(&other.value.read())
210 }
211}
212
213impl Eq for SymbolTablePtrInner {}
214
/// One scope of declarations, optionally nested inside a parent scope.
///
/// `table` holds only this scope's own declarations. Lookups that miss here continue in
/// `parent`. The derived `PartialEq` compares all three fields, including the parent.
#[serde_as]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct SymbolTable {
    /// Local declarations keyed by name (kept sorted by `BTreeMap`). Serialised as a list of
    /// `(name, declaration)` pairs, with the declaration presumably inlined via `PtrAsInner`.
    #[serde_as(as = "Vec<(_,PtrAsInner)>")]
    table: BTreeMap<Name, DeclarationPtr>,

    /// Enclosing scope, if any. Serialised via `AsId` (presumably as a reference by id rather
    /// than inline).
    #[serde_as(as = "Option<AsId>")]
    parent: Option<SymbolTablePtr>,

    /// Number `gensym` will use for its next `Name::Machine` name.
    next_machine_name: i32,
}
260
261impl SymbolTable {
262 pub fn new() -> SymbolTable {
264 SymbolTable::new_inner(None)
265 }
266
267 pub fn with_parent(parent: SymbolTablePtr) -> SymbolTable {
269 SymbolTable::new_inner(Some(parent))
270 }
271
272 fn new_inner(parent: Option<SymbolTablePtr>) -> SymbolTable {
273 let id = SYMBOL_TABLE_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
274 trace!(
275 "new symbol table: id = {id} parent_id = {}",
276 parent
277 .as_ref()
278 .map(|x| x.id().to_string())
279 .unwrap_or(String::from("none"))
280 );
281 SymbolTable {
282 table: BTreeMap::new(),
283 next_machine_name: 0,
284 parent,
285 }
286 }
287
288 pub fn lookup_local(&self, name: &Name) -> Option<DeclarationPtr> {
292 self.table.get(name).cloned()
293 }
294
295 pub fn lookup(&self, name: &Name) -> Option<DeclarationPtr> {
299 self.lookup_local(name).or_else(|| {
300 self.parent
301 .as_ref()
302 .and_then(|parent| parent.read().lookup(name))
303 })
304 }
305
306 pub fn insert(&mut self, declaration: DeclarationPtr) -> Option<()> {
310 let name = declaration.name().clone();
311 if let Entry::Vacant(e) = self.table.entry(name) {
312 e.insert(declaration);
313 Some(())
314 } else {
315 None
316 }
317 }
318
319 pub fn update_insert(&mut self, declaration: DeclarationPtr) {
321 let name = declaration.name().clone();
322 self.table.insert(name, declaration);
323 }
324
325 pub fn return_type(&self, name: &Name) -> Option<ReturnType> {
327 self.lookup(name).map(|x| x.return_type())
328 }
329
330 pub fn return_type_local(&self, name: &Name) -> Option<ReturnType> {
332 self.lookup_local(name).map(|x| x.return_type())
333 }
334
335 pub fn domain(&self, name: &Name) -> Option<DomainPtr> {
340 if let Name::WithRepresentation(name, _) = name {
341 self.lookup(name)?.domain()
342 } else {
343 self.lookup(name)?.domain()
344 }
345 }
346
347 pub fn resolve_domain(&self, name: &Name) -> Option<Moo<GroundDomain>> {
351 self.domain(name)?.resolve()
352 }
353
354 pub fn into_iter_local(self) -> impl Iterator<Item = (Name, DeclarationPtr)> {
356 self.table.into_iter()
357 }
358
359 pub fn iter_local(&self) -> impl Iterator<Item = (&Name, &DeclarationPtr)> {
361 self.table.iter()
362 }
363
364 pub fn iter_local_mut(&mut self) -> impl Iterator<Item = (&Name, &mut DeclarationPtr)> {
366 self.table.iter_mut()
367 }
368
369 pub fn extend(&mut self, other: SymbolTable) {
372 if other.table.keys().count() > self.table.keys().count() {
373 let new_vars = other.table.keys().collect::<BTreeSet<_>>();
374 let old_vars = self.table.keys().collect::<BTreeSet<_>>();
375
376 for added_var in new_vars.difference(&old_vars) {
377 let next_var = &mut self.next_machine_name;
378 if let Name::Machine(m) = *added_var
379 && *m >= *next_var
380 {
381 *next_var = *m + 1;
382 }
383 }
384 }
385
386 self.table.extend(other.table);
387 }
388
389 pub fn gensym(&mut self, domain: &DomainPtr) -> DeclarationPtr {
392 let num = self.next_machine_name;
393 self.next_machine_name += 1;
394 let decl = DeclarationPtr::new_find(Name::Machine(num), domain.clone());
395 self.insert(decl.clone());
396 decl
397 }
398
399 pub fn parent_mut_unchecked(&mut self) -> &mut Option<SymbolTablePtr> {
403 &mut self.parent
404 }
405
406 pub fn parent(&self) -> &Option<SymbolTablePtr> {
408 &self.parent
409 }
410
411 pub fn get_representation(
417 &self,
418 name: &Name,
419 representation: &[&str],
420 ) -> Option<Vec<Box<dyn Representation>>> {
421 let decl = self.lookup(name)?;
433 let var = &decl.as_find()?;
434
435 var.representations
436 .iter()
437 .find(|x| &x.iter().map(|r| r.repr_name()).collect_vec()[..] == representation)
438 .cloned()
439 }
440
441 pub fn representations_for(&self, name: &Name) -> Option<Vec<Vec<Box<dyn Representation>>>> {
447 let decl = self.lookup(name)?;
448 decl.as_find().map(|x| x.representations.clone())
449 }
450
451 pub fn get_or_add_representation(
468 &mut self,
469 name: &Name,
470 representation: &[&str],
471 ) -> Option<Vec<Box<dyn Representation>>> {
472 let mut decl = self.lookup(name)?;
474
475 if let Some(var) = decl.as_find()
476 && let Some(existing_reprs) = var
477 .representations
478 .iter()
479 .find(|x| &x.iter().map(|r| r.repr_name()).collect_vec()[..] == representation)
480 .cloned()
481 {
482 return Some(existing_reprs); }
484 if representation.len() != 1 {
488 bug!("nested representations not implemented")
489 }
490 let repr_name_str = representation[0];
491 let repr_init_fn = get_repr_rule(repr_name_str)?;
492
493 let reprs = vec![repr_init_fn(name, self)?];
494
495 let mut var = decl.as_find_mut()?;
497
498 for repr_instance in &reprs {
499 repr_instance
500 .declaration_down()
501 .ok()?
502 .into_iter()
503 .for_each(|x| self.update_insert(x));
504 }
505
506 var.representations.push(reprs.clone());
507
508 Some(reprs)
509 }
510}
511
512impl IntoIterator for SymbolTable {
513 type Item = (Name, DeclarationPtr);
514
515 type IntoIter = SymbolTableIter;
516
517 fn into_iter(self) -> Self::IntoIter {
519 SymbolTableIter {
520 inner: self.table.into_iter(),
521 parent: self.parent,
522 }
523 }
524}
525
/// Owning iterator over a symbol table's entries, followed by those of its ancestor scopes.
///
/// Produced by `SymbolTable::into_iter`.
pub struct SymbolTableIter {
    /// Remaining entries of the scope currently being iterated.
    inner: std::collections::btree_map::IntoIter<Name, DeclarationPtr>,

    /// The next scope to iterate once `inner` is exhausted, if any.
    parent: Option<SymbolTablePtr>,
}
534
535impl Iterator for SymbolTableIter {
536 type Item = (Name, DeclarationPtr);
537
538 fn next(&mut self) -> Option<Self::Item> {
539 let mut val = self.inner.next();
540
541 while val.is_none() {
545 let parent = self.parent.clone()?;
546
547 let guard = parent.read();
548 self.inner = guard.table.clone().into_iter();
549 self.parent.clone_from(&guard.parent);
550
551 val = self.inner.next();
552 }
553
554 val
555 }
556}
557
558impl Default for SymbolTable {
559 fn default() -> Self {
560 Self::new_inner(None)
561 }
562}
563
564impl Uniplate for SymbolTable {
567 fn uniplate(&self) -> (Tree<Self>, Box<dyn Fn(Tree<Self>) -> Self>) {
568 let self2 = self.clone();
570 (Tree::Zero, Box::new(move |_| self2.clone()))
571 }
572}
573
574impl Biplate<SymbolTablePtr> for SymbolTable {
575 fn biplate(
576 &self,
577 ) -> (
578 Tree<SymbolTablePtr>,
579 Box<dyn Fn(Tree<SymbolTablePtr>) -> Self>,
580 ) {
581 let self2 = self.clone();
582 (Tree::Zero, Box::new(move |_| self2.clone()))
583 }
584}
585
586impl Biplate<Expression> for SymbolTable {
587 fn biplate(&self) -> (Tree<Expression>, Box<dyn Fn(Tree<Expression>) -> Self>) {
588 let (child_trees, ctxs): (VecDeque<_>, Vec<_>) = self
589 .table
590 .values()
591 .map(Biplate::<Expression>::biplate)
592 .unzip();
593
594 let tree = Tree::Many(child_trees);
595
596 let self2 = self.clone();
597 let ctx = Box::new(move |tree| {
598 let Tree::Many(exprs) = tree else {
599 panic!("unexpected children structure");
600 };
601
602 let mut self3 = self2.clone();
603 let self3_iter = self3.table.iter_mut();
604 for (ctx, tree, (_, decl)) in izip!(&ctxs, exprs, self3_iter) {
605 *decl = ctx(tree)
608 }
609
610 self3
611 });
612
613 (tree, ctx)
614 }
615}
616
617impl Biplate<Comprehension> for SymbolTable {
618 fn biplate(
619 &self,
620 ) -> (
621 Tree<Comprehension>,
622 Box<dyn Fn(Tree<Comprehension>) -> Self>,
623 ) {
624 let (expr_tree, expr_ctx) = <SymbolTable as Biplate<Expression>>::biplate(self);
625
626 let (exprs, recons_expr_tree) = expr_tree.list();
627
628 let (comprehension_tree, comprehension_ctx) =
629 <VecDeque<Expression> as Biplate<Comprehension>>::biplate(&exprs);
630
631 let ctx = Box::new(move |x| {
632 let exprs = comprehension_ctx(x);
634
635 let expr_tree = recons_expr_tree(exprs);
637
638 expr_ctx(expr_tree)
640 });
641
642 (comprehension_tree, ctx)
643 }
644}
645
646impl Biplate<Model> for SymbolTable {
647 fn biplate(&self) -> (Tree<Model>, Box<dyn Fn(Tree<Model>) -> Self>) {
649 let (expr_tree, expr_ctx) = <SymbolTable as Biplate<Expression>>::biplate(self);
650
651 let (exprs, recons_expr_tree) = expr_tree.list();
652
653 let (submodel_tree, submodel_ctx) =
654 <VecDeque<Expression> as Biplate<Model>>::biplate(&exprs);
655
656 let ctx = Box::new(move |x| {
657 let exprs = submodel_ctx(x);
659
660 let expr_tree = recons_expr_tree(exprs);
662
663 expr_ctx(expr_tree)
665 });
666 (submodel_tree, ctx)
667 }
668}