1
use super::ParseContext;
2
use super::util::{get_expr_tree, query_toplevel};
3
use crate::diagnostics::source_map::SourceMap;
4
use crate::errors::FatalParseError;
5
use crate::expression::parse_expression;
6
use crate::util::TypecheckingContext;
7
use crate::util::node_is_expression;
8
use conjure_cp_core::ast::{Expression, SymbolTablePtr};
9
use std::collections::BTreeMap;
10
#[allow(unused)]
11
use uniplate::Uniplate;
12

            
13
9
pub fn parse_expr(
14
9
    src: &str,
15
9
    symbols_ptr: SymbolTablePtr,
16
9
) -> Result<Option<Expression>, FatalParseError> {
17
9
    let exprs = parse_exprs(src, symbols_ptr)?;
18
9
    if exprs.len() != 1 {
19
1
        return Ok(None);
20
8
    }
21
8
    Ok(Some(exprs[0].clone()))
22
9
}
23

            
24
10
pub fn parse_exprs(
25
10
    src: &str,
26
10
    symbols_ptr: SymbolTablePtr,
27
10
) -> Result<Vec<Expression>, FatalParseError> {
28
10
    let Some((tree, source_code)) = get_expr_tree(src) else {
29
        return Ok(Vec::new());
30
    };
31

            
32
10
    let root = tree.root_node();
33
10
    let mut source_map = SourceMap::default();
34
10
    let mut decl_spans = BTreeMap::new();
35
10
    let mut errors = Vec::new();
36
10
    let mut ctx = ParseContext::new(
37
10
        &source_code,
38
10
        &root,
39
10
        Some(symbols_ptr),
40
10
        &mut errors,
41
10
        &mut source_map,
42
10
        &mut decl_spans,
43
    );
44
10
    let mut ans = Vec::new();
45
11
    for expr in query_toplevel(&root, &node_is_expression) {
46
11
        ctx.typechecking_context = TypecheckingContext::Unknown;
47
11
        ctx.inner_typechecking_context = TypecheckingContext::Unknown;
48
11
        let Some(expr) = parse_expression(&mut ctx, expr)? else {
49
1
            continue;
50
        };
51
10
        ans.push(expr);
52
    }
53
10
    Ok(ans)
54
10
}
55

            
56
// Unit tests for `parse_expr` and `parse_exprs`.
//
// The module is gated on `cfg(test)` so that it (and its imports) are only
// compiled for test builds, which removes the need for `#[allow(unused)]`.
#[cfg(test)]
mod test {
    use super::{parse_expr, parse_exprs};
    use conjure_cp_core::ast::SymbolTablePtr;
    use conjure_cp_core::ast::{
        Atom, DeclarationPtr, Domain, Expression, Literal, Metadata, Moo, Name,
    };

    /// Integer and boolean literals parse to the matching `Atom::Literal`.
    #[test]
    fn test_parse_constant() {
        let symbols = SymbolTablePtr::new();

        assert_eq!(
            parse_expr("42", symbols.clone()).unwrap().unwrap(),
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(42)))
        );
        assert_eq!(
            parse_expr("true", symbols.clone()).unwrap().unwrap(),
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)))
        );
        assert_eq!(
            parse_expr("false", symbols).unwrap().unwrap(),
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false)))
        );
    }

    /// A comma-separated source yields one expression per constraint, with
    /// identifiers resolved against the supplied symbol table.
    #[test]
    fn test_parse_expressions() {
        let src = "x >= 5, y = a / 2";
        let symbols = SymbolTablePtr::new();
        let x = DeclarationPtr::new_find(
            Name::User("x".into()),
            Domain::int(vec![conjure_cp_core::ast::Range::Bounded(0, 10)]),
        );

        let y = DeclarationPtr::new_find(
            Name::User("y".into()),
            Domain::int(vec![conjure_cp_core::ast::Range::Bounded(0, 10)]),
        );

        let a = DeclarationPtr::new_find(
            Name::User("a".into()),
            Domain::int(vec![conjure_cp_core::ast::Range::Bounded(0, 10)]),
        );

        // Insert clones so that `x`, `y` and `a` remain available for building
        // the expected expressions in the assertions below.
        symbols
            .write()
            .insert(x.clone())
            .expect("x should not exist in the symbol-table yet, so we should be able to add it");

        symbols
            .write()
            .insert(y.clone())
            .expect("y should not exist in the symbol-table yet, so we should be able to add it");

        symbols
            .write()
            .insert(a.clone())
            .expect("a should not exist in the symbol-table yet, so we should be able to add it");

        let exprs = parse_exprs(src, symbols).unwrap();
        assert_eq!(exprs.len(), 2);

        assert_eq!(
            exprs[0],
            Expression::Geq(
                Metadata::new(),
                Moo::new(Expression::Atomic(Metadata::new(), Atom::new_ref(x))),
                Moo::new(Expression::Atomic(Metadata::new(), 5.into()))
            )
        );

        assert_eq!(
            exprs[1],
            Expression::Eq(
                Metadata::new(),
                Moo::new(Expression::Atomic(Metadata::new(), Atom::new_ref(y))),
                Moo::new(Expression::UnsafeDiv(
                    Metadata::new(),
                    Moo::new(Expression::Atomic(Metadata::new(), Atom::new_ref(a))),
                    Moo::new(Expression::Atomic(Metadata::new(), 2.into()))
                ))
            )
        );
    }
}