1
use super::util::{get_tree, query_toplevel};
2
use crate::errors::FatalParseError;
3
use crate::expression::parse_expression;
4
use crate::util::node_is_expression;
5
use conjure_cp_core::ast::{Expression, SymbolTablePtr};
6
#[allow(unused)]
7
use uniplate::Uniplate;
8

            
9
9
pub fn parse_expr(src: &str, symbols_ptr: SymbolTablePtr) -> Result<Expression, FatalParseError> {
10
9
    let exprs = parse_exprs(src, symbols_ptr)?;
11
17
    if exprs.len() != 1 {
12
9
        return Err(FatalParseError::syntax_error(
13
9
            "Expected a single expression".to_string(),
14
1
            None,
15
1
        ));
16
9
    }
17
9
    Ok(exprs[0].clone())
18
17
}
19
8

            
20
19
pub fn parse_exprs(
21
10
    src: &str,
22
20
    symbols_ptr: SymbolTablePtr,
23
20
) -> Result<Vec<Expression>, FatalParseError> {
24
20
    let (tree, source_code) = get_tree(src).ok_or(FatalParseError::TreeSitterError(
25
20
        "Failed to parse Essence source code".to_string(),
26
20
    ))?;
27
10

            
28
20
    let root = tree.root_node();
29
10
    let mut ans = Vec::new();
30
21
    for expr in query_toplevel(&root, &node_is_expression) {
31
21
        ans.push(parse_expression(
32
21
            expr,
33
21
            &source_code,
34
21
            &root,
35
21
            Some(symbols_ptr.clone()),
36
21
            &mut Vec::new(),
37
11
        )?);
38
10
    }
39
9
    Ok(ans)
40
20
}
41
11

            
42
11
/// Unit tests for `parse_expr` and `parse_exprs`.
///
/// Gated on `cfg(test)` so the module (and its imports) is only compiled for
/// test builds.
#[cfg(test)]
mod test {
    use super::{parse_expr, parse_exprs};
    use conjure_cp_core::ast::{
        Atom, DeclarationPtr, Domain, Expression, Literal, Metadata, Moo, Name, Range,
        SymbolTablePtr,
    };

    /// Integer and boolean literals parse to atomic literal expressions.
    #[test]
    fn test_parse_constant() {
        let symbols = SymbolTablePtr::new();

        assert_eq!(
            parse_expr("42", symbols.clone()).unwrap(),
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(42)))
        );
        assert_eq!(
            parse_expr("true", symbols.clone()).unwrap(),
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)))
        );
        assert_eq!(
            parse_expr("false", symbols).unwrap(),
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false)))
        );
    }

    /// A comma-separated list of expressions over declared variables parses
    /// into one `Expression` per item, with references to the declarations.
    #[test]
    fn test_parse_expressions() {
        let src = "x >= 5, y = a / 2";
        let symbols = SymbolTablePtr::new();
        let x = DeclarationPtr::new_find(
            Name::User("x".into()),
            Domain::int(vec![Range::Bounded(0, 10)]),
        );

        let y = DeclarationPtr::new_find(
            Name::User("y".into()),
            Domain::int(vec![Range::Bounded(0, 10)]),
        );

        let a = DeclarationPtr::new_find(
            Name::User("a".into()),
            Domain::int(vec![Range::Bounded(0, 10)]),
        );

        // Insert clones so that `x`, `y` and `a` remain available for building
        // the expected expressions below.
        symbols
            .write()
            .insert(x.clone())
            .expect("x should not exist in the symbol-table yet, so we should be able to add it");

        symbols
            .write()
            .insert(y.clone())
            .expect("y should not exist in the symbol-table yet, so we should be able to add it");

        symbols
            .write()
            .insert(a.clone())
            .expect("a should not exist in the symbol-table yet, so we should be able to add it");

        let exprs = parse_exprs(src, symbols).unwrap();
        assert_eq!(exprs.len(), 2);

        assert_eq!(
            exprs[0],
            Expression::Geq(
                Metadata::new(),
                Moo::new(Expression::Atomic(Metadata::new(), Atom::new_ref(x))),
                Moo::new(Expression::Atomic(Metadata::new(), 5.into()))
            )
        );

        assert_eq!(
            exprs[1],
            Expression::Eq(
                Metadata::new(),
                Moo::new(Expression::Atomic(Metadata::new(), Atom::new_ref(y))),
                Moo::new(Expression::UnsafeDiv(
                    Metadata::new(),
                    Moo::new(Expression::Atomic(Metadata::new(), Atom::new_ref(a))),
                    Moo::new(Expression::Atomic(Metadata::new(), 2.into()))
                ))
            )
        );
    }
}