1
use crate::ast::serde::{AsId, HasId};
2
use crate::{ast::DeclarationPtr, bug};
3
use derivative::Derivative;
4
use parking_lot::MappedRwLockReadGuard;
5
use serde::{Deserialize, Serialize};
6
use serde_with::serde_as;
7
use std::fmt::{Display, Formatter};
8
use uniplate::Uniplate;
9

            
10
use super::{
11
    Atom, DeclarationKind, DomainPtr, Expression, GroundDomain, Literal, Metadata, Moo, Name,
12
    categories::{Category, CategoryOf},
13
    domains::HasDomain,
14
};
15

            
16
/// A reference to a declaration (variable, parameter, etc.)
17
///
18
/// This is a thin wrapper around [`DeclarationPtr`] with two main purposes:
19
/// 1. Encapsulate the serde pragmas (e.g., serializing as IDs rather than full objects)
20
/// 2. Enable type-directed traversals of references via uniplate
21
#[serde_as]
22
#[derive(
23
    Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Uniplate, Derivative,
24
)]
25
#[derivative(Hash)]
26
#[uniplate()]
27
#[biplate(to=DeclarationPtr)]
28
#[biplate(to=Name)]
29
pub struct Reference {
30
    #[serde_as(as = "AsId")]
31
    pub ptr: DeclarationPtr,
32
}
33

            
34
impl Reference {
35
4120311
    pub fn new(ptr: DeclarationPtr) -> Self {
36
4120311
        Reference { ptr }
37
4120311
    }
38

            
39
44115956
    pub fn ptr(&self) -> &DeclarationPtr {
40
44115956
        &self.ptr
41
44115956
    }
42

            
43
1040
    pub fn into_ptr(self) -> DeclarationPtr {
44
1040
        self.ptr
45
1040
    }
46

            
47
282162623
    pub fn name(&self) -> MappedRwLockReadGuard<'_, Name> {
48
282162623
        self.ptr.name()
49
282162623
    }
50

            
51
15872
    pub fn id(&self) -> crate::ast::serde::ObjId {
52
15872
        self.ptr.id()
53
15872
    }
54

            
55
9005276
    pub fn domain(&self) -> Option<DomainPtr> {
56
9005276
        self.ptr.domain()
57
9005276
    }
58

            
59
8925576
    pub fn resolved_domain(&self) -> Option<Moo<GroundDomain>> {
60
8925576
        self.domain()?.resolve()
61
8925576
    }
62

            
63
    /// Returns the expression behind a value-letting reference, if this is one.
64
41165772
    pub fn resolve_expression(&self) -> Option<Expression> {
65
41165772
        if let Some(expr) = self.ptr().as_value_letting() {
66
2317574
            return Some(expr.clone());
67
38848198
        }
68

            
69
38848198
        let generator = {
70
38848198
            let kind = self.ptr.kind();
71
38848198
            if let DeclarationKind::Quantified(inner) = &*kind {
72
439706
                inner.generator().cloned()
73
            } else {
74
38408492
                None
75
            }
76
        };
77

            
78
38848198
        if let Some(generator) = generator
79
            && let Some(expr) = generator.as_value_letting()
80
        {
81
            return Some(expr.clone());
82
38848198
        }
83

            
84
38848198
        None
85
41165772
    }
86

            
87
    /// Evaluates this reference to a literal if it resolves to a constant.
88
22571238
    pub fn resolve_constant(&self) -> Option<Literal> {
89
22571238
        self.resolve_expression()
90
22571238
            .and_then(|expr| super::eval::eval_constant(&expr))
91
22571238
    }
92

            
93
    /// Resolves this reference to an atomic expression, if possible.
94
261834
    pub fn resolve_atomic(&self) -> Option<Atom> {
95
261834
        self.resolve_expression().and_then(|expr| match expr {
96
            Expression::Atomic(_, atom) => Some(atom),
97
            _ => None,
98
        })
99
261834
    }
100
}
101

            
102
impl From<Reference> for Expression {
103
446
    fn from(value: Reference) -> Self {
104
446
        Expression::Atomic(Metadata::new(), value.into())
105
446
    }
106
}
107

            
108
impl From<DeclarationPtr> for Reference {
    /// Builds a reference to `ptr`; equivalent to [`Reference::new`].
    fn from(ptr: DeclarationPtr) -> Self {
        Self { ptr }
    }
}
113

            
114
impl CategoryOf for Reference {
115
2831266
    fn category_of(&self) -> Category {
116
2831266
        self.ptr.category_of()
117
2831266
    }
118
}
119

            
120
impl HasDomain for Reference {
121
4991082
    fn domain_of(&self) -> DomainPtr {
122
4991082
        self.ptr
123
4991082
            .domain()
124
4991082
            .or_else(|| self.resolve_constant().map(|literal| literal.domain_of()))
125
4991082
            .unwrap_or_else(|| {
126
                bug!(
127
                    "reference ({name}) should have a domain",
128
                    name = self.ptr.name()
129
                )
130
            })
131
4991082
    }
132
}
133

            
134
impl Display for Reference {
135
73826196
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
136
73826196
        self.ptr.name().fmt(f)
137
73826196
    }
138
}