1
use super::ParseContext;
2
use super::util::{get_expr_tree, query_toplevel};
3
use crate::diagnostics::source_map::SourceMap;
4
use crate::errors::FatalParseError;
5
use crate::expression::parse_expression;
6
use crate::util::node_is_expression;
7
use conjure_cp_core::ast::{Expression, SymbolTablePtr};
8
use std::collections::BTreeMap;
9
#[allow(unused)]
10
use uniplate::Uniplate;
11

            
12
9
pub fn parse_expr(
13
9
    src: &str,
14
9
    symbols_ptr: SymbolTablePtr,
15
9
) -> Result<Option<Expression>, FatalParseError> {
16
9
    let exprs = parse_exprs(src, symbols_ptr)?;
17
9
    if exprs.len() != 1 {
18
1
        return Ok(None);
19
8
    }
20
8
    Ok(Some(exprs[0].clone()))
21
9
}
22

            
23
10
pub fn parse_exprs(
24
10
    src: &str,
25
10
    symbols_ptr: SymbolTablePtr,
26
10
) -> Result<Vec<Expression>, FatalParseError> {
27
10
    let Some((tree, source_code)) = get_expr_tree(src) else {
28
        return Ok(Vec::new());
29
    };
30

            
31
10
    let root = tree.root_node();
32
10
    let mut source_map = SourceMap::default();
33
10
    let mut decl_spans = BTreeMap::new();
34
10
    let mut errors = Vec::new();
35
10
    let mut ctx = ParseContext::new(
36
10
        &source_code,
37
10
        &root,
38
10
        Some(symbols_ptr),
39
10
        &mut errors,
40
10
        &mut source_map,
41
10
        &mut decl_spans,
42
    );
43
10
    let mut ans = Vec::new();
44
11
    for expr in query_toplevel(&root, &node_is_expression) {
45
11
        let Some(expr) = parse_expression(&mut ctx, expr)? else {
46
1
            continue;
47
        };
48
10
        ans.push(expr);
49
    }
50
10
    Ok(ans)
51
10
}
52

            
53
/// Unit tests for `parse_expr` / `parse_exprs`.
///
/// Gated behind `cfg(test)` so the module (and its imports) only exist in test builds. This
/// removes the need for blanket `#[allow(unused)]` attributes on the imports.
#[cfg(test)]
mod test {
    use super::{parse_expr, parse_exprs};
    use conjure_cp_core::ast::{
        Atom, DeclarationPtr, Domain, Expression, Literal, Metadata, Moo, Name, Range,
        SymbolTablePtr,
    };

    /// Creates a `find` declaration called `name` with integer domain `Bounded(0, 10)`,
    /// inserts a clone of it into `symbols`, and returns the declaration for use in
    /// expected expressions.
    fn declare_int_find(symbols: &SymbolTablePtr, name: &str) -> DeclarationPtr {
        let decl = DeclarationPtr::new_find(
            Name::User(name.into()),
            Domain::int(vec![Range::Bounded(0, 10)]),
        );

        // Insert a clone so the caller keeps its own handle to the declaration.
        symbols
            .write()
            .insert(decl.clone())
            .expect("declaration should not exist in the symbol-table yet, so we should be able to add it");

        decl
    }

    #[test]
    fn test_parse_constant() {
        let symbols = SymbolTablePtr::new();

        assert_eq!(
            parse_expr("42", symbols.clone()).unwrap().unwrap(),
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(42)))
        );
        assert_eq!(
            parse_expr("true", symbols.clone()).unwrap().unwrap(),
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)))
        );
        assert_eq!(
            parse_expr("false", symbols).unwrap().unwrap(),
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false)))
        );
    }

    #[test]
    fn test_parse_expressions() {
        let src = "x >= 5, y = a / 2";
        let symbols = SymbolTablePtr::new();
        let x = declare_int_find(&symbols, "x");
        let y = declare_int_find(&symbols, "y");
        let a = declare_int_find(&symbols, "a");

        let exprs = parse_exprs(src, symbols.clone()).unwrap();
        assert_eq!(exprs.len(), 2);

        assert_eq!(
            exprs[0],
            Expression::Geq(
                Metadata::new(),
                Moo::new(Expression::Atomic(Metadata::new(), Atom::new_ref(x))),
                Moo::new(Expression::Atomic(Metadata::new(), 5.into()))
            )
        );

        assert_eq!(
            exprs[1],
            Expression::Eq(
                Metadata::new(),
                Moo::new(Expression::Atomic(Metadata::new(), Atom::new_ref(y))),
                Moo::new(Expression::UnsafeDiv(
                    Metadata::new(),
                    Moo::new(Expression::Atomic(Metadata::new(), Atom::new_ref(a))),
                    Moo::new(Expression::Atomic(Metadata::new(), 2.into()))
                ))
            )
        );

        // The same input holds two top-level expressions, so the single-expression entry
        // point must decline it.
        assert!(parse_expr(src, symbols).unwrap().is_none());
    }
}