1
use super::util::{get_tree, query_toplevel};
2
use crate::errors::EssenceParseError;
3
use crate::expression::parse_expression;
4
use crate::util::node_is_expression;
5
use conjure_cp_core::ast::{Expression, SymbolTablePtr};
6
#[allow(unused)]
7
use uniplate::Uniplate;
8

            
9
9
pub fn parse_expr(src: &str, symbols_ptr: SymbolTablePtr) -> Result<Expression, EssenceParseError> {
10
9
    let exprs = parse_exprs(src, symbols_ptr)?;
11
8
    if exprs.len() != 1 {
12
        return Err(EssenceParseError::syntax_error(
13
            "Expected a single expression".to_string(),
14
            None,
15
        ));
16
8
    }
17
8
    Ok(exprs[0].clone())
18
9
}
19

            
20
10
pub fn parse_exprs(
21
10
    src: &str,
22
10
    symbols_ptr: SymbolTablePtr,
23
10
) -> Result<Vec<Expression>, EssenceParseError> {
24
10
    let (tree, source_code) = get_tree(src).ok_or(EssenceParseError::TreeSitterError(
25
10
        "Failed to parse Essence source code".to_string(),
26
10
    ))?;
27

            
28
10
    let root = tree.root_node();
29
10
    let mut ans = Vec::new();
30
11
    for expr in query_toplevel(&root, &node_is_expression) {
31
11
        ans.push(parse_expression(
32
11
            expr,
33
11
            &source_code,
34
11
            &root,
35
11
            Some(symbols_ptr.clone()),
36
1
        )?);
37
    }
38
9
    Ok(ans)
39
10
}

/// Unit tests for [`parse_expr`] and [`parse_exprs`]. Compiled only for test builds.
#[cfg(test)]
mod test {
    use super::{parse_expr, parse_exprs};
    use conjure_cp_core::ast::{
        Atom, DeclarationPtr, Domain, Expression, Literal, Metadata, Moo, Name, Range,
        SymbolTablePtr,
    };

    /// Builds a `find` declaration named `name` with an integer domain bounded by 0 and 10.
    fn int_find(name: &'static str) -> DeclarationPtr {
        DeclarationPtr::new_find(
            Name::User(name.into()),
            Domain::int(vec![Range::Bounded(0, 10)]),
        )
    }

    #[test]
    fn test_parse_constant() {
        let symbols = SymbolTablePtr::new();
        let cases = [
            ("42", Literal::Int(42)),
            ("true", Literal::Bool(true)),
            ("false", Literal::Bool(false)),
        ];

        for (src, literal) in cases {
            assert_eq!(
                parse_expr(src, symbols.clone()).unwrap(),
                Expression::Atomic(Metadata::new(), Atom::Literal(literal)),
                "unexpected parse of {src:?}"
            );
        }
    }

    #[test]
    fn test_parse_expressions() {
        let src = "x >= 5, y = a / 2";
        let symbols = SymbolTablePtr::new();
        let x = int_find("x");
        let y = int_find("y");
        let a = int_find("a");

        // Insert clones so the test keeps its own handles for building the expected trees.
        for decl in [&x, &y, &a] {
            symbols
                .write()
                .insert(decl.clone())
                .expect("each declaration is new, so inserting it should succeed");
        }

        let exprs = parse_exprs(src, symbols).unwrap();
        assert_eq!(exprs.len(), 2);

        // Wraps an atom as a boxed leaf expression with default metadata.
        let leaf = |atom: Atom| Moo::new(Expression::Atomic(Metadata::new(), atom));

        assert_eq!(
            exprs[0],
            Expression::Geq(Metadata::new(), leaf(Atom::new_ref(x)), leaf(5.into()))
        );

        assert_eq!(
            exprs[1],
            Expression::Eq(
                Metadata::new(),
                leaf(Atom::new_ref(y)),
                Moo::new(Expression::UnsafeDiv(
                    Metadata::new(),
                    leaf(Atom::new_ref(a)),
                    leaf(2.into())
                ))
            )
        );
    }
}