1
use proc_macro::TokenStream;
2

            
3
use proc_macro2::Span;
4
use quote::quote;
5
use syn::token::Comma;
6
use syn::{
7
    ExprClosure, Ident, ItemFn, LitInt, LitStr, Result, bracketed, parenthesized, parse::Parse,
8
    parse::ParseStream, parse_macro_input,
9
};
10

            
11
struct RegisterRuleArgs {
12
    rule_set: LitStr,
13
    priority: LitInt,
14
    /// Expression variant names this rule applies to (e.g. `Add`, `Sub`).
15
    /// Empty means applicable to all variants (universal rule).
16
    applicable_variants: Vec<Ident>,
17
}
18

            
19
impl Parse for RegisterRuleArgs {
20
703
    fn parse(input: ParseStream) -> Result<Self> {
21
703
        if input.is_empty() {
22
2
            return Ok(RegisterRuleArgs {
23
2
                rule_set: LitStr::new("", Span::call_site()),
24
2
                priority: LitInt::new("0", Span::call_site()),
25
2
                applicable_variants: Vec::new(),
26
2
            });
27
701
        }
28

            
29
701
        let rule_set: LitStr = input.parse()?;
30
701
        let _: Comma = input.parse()?;
31
701
        let priority: LitInt = input.parse()?;
32

            
33
        // Parse optional variant names in brackets: "Minion", 4200, [Add, Sub]
34
701
        let mut applicable_variants = Vec::new();
35
701
        if input.peek(Comma) {
36
650
            let _: Comma = input.parse()?;
37
            let content;
38
650
            bracketed!(content in input);
39
860
            while !content.is_empty() {
40
860
                let variant: Ident = content.parse()?;
41
860
                applicable_variants.push(variant);
42
860
                if content.is_empty() {
43
650
                    break;
44
210
                }
45
210
                let _: Comma = content.parse()?;
46
            }
47
51
        }
48

            
49
701
        Ok(RegisterRuleArgs {
50
701
            rule_set,
51
701
            priority,
52
701
            applicable_variants,
53
701
        })
54
703
    }
55
}
56

            
57
/// Register a rule with the given rule sets and priorities.
58
#[proc_macro_attribute]
59
703
pub fn register_rule(arg_tokens: TokenStream, item: TokenStream) -> TokenStream {
60
703
    let func = parse_macro_input!(item as ItemFn);
61
703
    let rule_ident = &func.sig.ident;
62
703
    let static_name = format!("CONJURE_GEN_RULE_{rule_ident}").to_uppercase();
63
703
    let static_ident = Ident::new(&static_name, rule_ident.span());
64

            
65
703
    let args = parse_macro_input!(arg_tokens as RegisterRuleArgs);
66

            
67
703
    let rule_sets_token = if args.rule_set.value().is_empty() {
68
2
        quote! { &[] }
69
    } else {
70
701
        let rule_set_name = &args.rule_set;
71
701
        let priority = &args.priority;
72
701
        quote! { &[(#rule_set_name, #priority as u16)] }
73
    };
74

            
75
703
    let applicable_to = if args.applicable_variants.is_empty() {
76
53
        quote! { None }
77
    } else {
78
650
        let variants = &args.applicable_variants;
79
650
        quote! {
80
            Some(&[#(::conjure_cp::discriminant_from_name!(#variants)),*])
81
        }
82
    };
83

            
84
703
    let expanded = quote! {
85
        #func
86

            
87
        use ::conjure_cp::rule_engine::_dependencies::*; // ToDo idk if we need to explicitly do that?
88

            
89
        #[::conjure_cp::rule_engine::_dependencies::distributed_slice(::conjure_cp::rule_engine::RULES_DISTRIBUTED_SLICE)]
90
        pub static #static_ident: ::conjure_cp::rule_engine::Rule<'static> = ::conjure_cp::rule_engine::Rule {
91
            name: stringify!(#rule_ident),
92
            application: #rule_ident,
93
            rule_sets: #rule_sets_token,
94
            applicable_to: #applicable_to,
95
        };
96
    };
97

            
98
703
    TokenStream::from(expanded)
99
703
}
100

            
101
67
/// Parses a parenthesised, comma-separated list of `T`, e.g. `("A", "B")`.
///
/// Both the empty list `()` and a trailing comma (`("A",)`) are accepted.
///
/// # Errors
/// Fails if the next token is not a parenthesised group, if an element does not parse as `T`,
/// or if two elements are not separated by a comma.
fn parse_parenthesized<T: Parse>(input: ParseStream) -> Result<Vec<T>> {
    let inner;
    parenthesized!(inner in input);

    let mut items: Vec<T> = Vec::new();
    while !inner.is_empty() {
        items.push(inner.parse::<T>()?);
        // A separator is only required when more tokens follow the element.
        if !inner.is_empty() {
            inner.parse::<Comma>()?;
        }
    }

    Ok(items)
}
117

            
118
struct RuleSetArgs {
119
    name: LitStr,
120
    dependencies: Vec<LitStr>,
121
    applies_fn: Option<ExprClosure>,
122
}
123

            
124
impl Parse for RuleSetArgs {
125
68
    fn parse(input: ParseStream) -> Result<Self> {
126
68
        let name = input.parse()?;
127

            
128
68
        if input.is_empty() {
129
1
            return Ok(Self {
130
1
                name,
131
1
                dependencies: Vec::new(),
132
1
                applies_fn: None,
133
1
            });
134
67
        }
135

            
136
67
        input.parse::<Comma>()?;
137
67
        let dependencies = parse_parenthesized::<LitStr>(input)?;
138

            
139
67
        if input.is_empty() {
140
34
            return Ok(Self {
141
34
                name,
142
34
                dependencies,
143
34
                applies_fn: None,
144
34
            });
145
33
        }
146

            
147
33
        input.parse::<Comma>()?;
148
33
        let applies_fn = input.parse::<ExprClosure>()?;
149

            
150
33
        Ok(Self {
151
33
            name,
152
33
            dependencies,
153
33
            applies_fn: Some(applies_fn),
154
33
        })
155
68
    }
156
}
157

            
158
/**
159
* Register a rule set with the given name, dependencies, and metadata.
160
*
161
* # Example
162
* ```rust
163
 * use conjure_cp_rule_macros::register_rule_set;
164
 * register_rule_set!("MyRuleSet", ("DependencyRuleSet", "AnotherRuleSet"));
165
* ```
166
 */
167
#[proc_macro]
168
68
pub fn register_rule_set(args: TokenStream) -> TokenStream {
169
    let RuleSetArgs {
170
68
        name,
171
68
        dependencies,
172
68
        applies_fn,
173
68
    } = parse_macro_input!(args as RuleSetArgs);
174

            
175
68
    let static_name = format!("CONJURE_GEN_RULE_SET_{}", name.value()).to_uppercase();
176
68
    let static_ident = Ident::new(&static_name, Span::call_site());
177

            
178
68
    let dependencies = quote! {
179
        #(#dependencies),*
180
    };
181

            
182
68
    let applies_to_family = match applies_fn {
183
        // Does not apply by default, e.g. only used as a dependency
184
35
        None => quote! { |_: &::conjure_cp::settings::SolverFamily| false },
185
33
        Some(func) => quote! { #func },
186
    };
187

            
188
68
    let expanded = quote! {
189
        use ::conjure_cp::rule_engine::_dependencies::*; // ToDo idk if we need to explicitly do that?
190
        #[::conjure_cp::rule_engine::_dependencies::distributed_slice(::conjure_cp::rule_engine::RULE_SETS_DISTRIBUTED_SLICE)]
191
        pub static #static_ident: ::conjure_cp::rule_engine::RuleSet<'static> =
192
            ::conjure_cp::rule_engine::RuleSet::new(#name, &[#dependencies], #applies_to_family);
193
    };
194

            
195
68
    TokenStream::from(expanded)
196
68
}