rtic_macros/syntax/parse/
util.rs

1use syn::{
2    bracketed,
3    parse::{self, ParseStream},
4    punctuated::Punctuated,
5    spanned::Spanned,
6    Abi, AttrStyle, Attribute, Expr, ExprPath, FnArg, ForeignItemFn, Ident, ItemFn, Pat, PatType,
7    Path, PathArguments, ReturnType, Token, Type, Visibility,
8};
9
10use crate::syntax::{
11    ast::{Access, Local, LocalResources, SharedResources, TaskLocal},
12    Map,
13};
14
15pub fn abi_is_rust(abi: &Abi) -> bool {
16    match &abi.name {
17        None => true,
18        Some(s) => s.value() == "Rust",
19    }
20}
21
22pub fn attr_eq(attr: &Attribute, name: &str) -> bool {
23    attr.style == AttrStyle::Outer && attr.path().segments.len() == 1 && {
24        let segment = attr.path().segments.first().unwrap();
25        segment.arguments == PathArguments::None && *segment.ident.to_string() == *name
26    }
27}
28
29/// checks that a function signature
30///
31/// - has no bounds (like where clauses)
32/// - is not `async`
33/// - is not `const`
34/// - is not `unsafe`
35/// - is not generic (has no type parameters)
36/// - is not variadic
37/// - uses the Rust ABI (and not e.g. "C")
38pub fn check_fn_signature(item: &ItemFn, allow_async: bool) -> bool {
39    item.vis == Visibility::Inherited
40        && item.sig.constness.is_none()
41        && (item.sig.asyncness.is_none() || allow_async)
42        && item.sig.abi.is_none()
43        && item.sig.unsafety.is_none()
44        && item.sig.generics.params.is_empty()
45        && item.sig.generics.where_clause.is_none()
46        && item.sig.variadic.is_none()
47}
48
49#[allow(dead_code)]
50pub fn check_foreign_fn_signature(item: &ForeignItemFn, allow_async: bool) -> bool {
51    item.vis == Visibility::Inherited
52        && item.sig.constness.is_none()
53        && (item.sig.asyncness.is_none() || allow_async)
54        && item.sig.abi.is_none()
55        && item.sig.unsafety.is_none()
56        && item.sig.generics.params.is_empty()
57        && item.sig.generics.where_clause.is_none()
58        && item.sig.variadic.is_none()
59}
60
61pub struct FilterAttrs {
62    pub cfgs: Vec<Attribute>,
63    pub docs: Vec<Attribute>,
64    pub attrs: Vec<Attribute>,
65}
66
67pub fn filter_attributes(input_attrs: Vec<Attribute>) -> FilterAttrs {
68    let mut cfgs = vec![];
69    let mut docs = vec![];
70    let mut attrs = vec![];
71
72    for attr in input_attrs {
73        if attr_eq(&attr, "cfg") {
74            cfgs.push(attr);
75        } else if attr_eq(&attr, "doc") {
76            docs.push(attr);
77        } else {
78            attrs.push(attr);
79        }
80    }
81
82    FilterAttrs { cfgs, docs, attrs }
83}
84
85pub fn extract_lock_free(attrs: &mut Vec<Attribute>) -> parse::Result<bool> {
86    if let Some(pos) = attrs.iter().position(|attr| attr_eq(attr, "lock_free")) {
87        attrs.remove(pos);
88        Ok(true)
89    } else {
90        Ok(false)
91    }
92}
93
94pub fn parse_shared_resources(content: ParseStream<'_>) -> parse::Result<SharedResources> {
95    let inner;
96    bracketed!(inner in content);
97
98    let mut resources = Map::new();
99    for e in inner.call(Punctuated::<Expr, Token![,]>::parse_terminated)? {
100        let err = Err(parse::Error::new(
101            e.span(),
102            "identifier appears more than once in list",
103        ));
104        let (access, path) = match e {
105            Expr::Path(e) => (Access::Exclusive, e.path),
106
107            Expr::Reference(ref r) if r.mutability.is_none() => match &*r.expr {
108                Expr::Path(e) => (Access::Shared, e.path.clone()),
109
110                _ => return err,
111            },
112
113            _ => return err,
114        };
115
116        let ident = extract_resource_name_ident(path)?;
117
118        if resources.contains_key(&ident) {
119            return Err(parse::Error::new(
120                ident.span(),
121                "resource appears more than once in list",
122            ));
123        }
124
125        resources.insert(ident, access);
126    }
127
128    Ok(resources)
129}
130
131fn extract_resource_name_ident(path: Path) -> parse::Result<Ident> {
132    if path.leading_colon.is_some()
133        || path.segments.len() != 1
134        || path.segments[0].arguments != PathArguments::None
135    {
136        Err(parse::Error::new(
137            path.span(),
138            "resource must be an identifier, not a path",
139        ))
140    } else {
141        Ok(path.segments[0].ident.clone())
142    }
143}
144
145pub fn parse_local_resources(content: ParseStream<'_>) -> parse::Result<LocalResources> {
146    let input;
147    bracketed!(input in content);
148
149    let mut resources = Map::new();
150
151    let error_msg_no_local_resources =
152        "malformed, expected 'local = [EXPRPATH: TYPE = EXPR]', or 'local = [EXPRPATH, ...]'";
153
154    loop {
155        if input.is_empty() {
156            break;
157        }
158        // Type ascription is de-RFCd from Rust in
159        // https://github.com/rust-lang/rfcs/pull/3307
160        // Manually pull out the tokens
161
162        // Two acceptable variants:
163        //
164        // Task local and declared (initialized in place)
165        // local = [EXPRPATH: TYPE = EXPR, ...]
166        //          ~~~~~~~~~~~~~~~~~~~~~~
167        // or
168        // Task local but not initialized
169        // local = [EXPRPATH, ...],
170        //          ~~~~~~~~~
171
172        // Common: grab first identifier EXPRPATH
173        // local = [EXPRPATH: TYPE = EXPR, ...]
174        //          ~~~~~~~~
175        let exprpath: ExprPath = input.parse()?;
176
177        let name = extract_resource_name_ident(exprpath.path)?;
178
179        // Extract attributes
180        let ExprPath { attrs, .. } = exprpath;
181        let (cfgs, attrs) = {
182            let FilterAttrs { cfgs, attrs, .. } = filter_attributes(attrs);
183            (cfgs, attrs)
184        };
185
186        let local;
187
188        // Declared requries type ascription
189        if input.peek(Token![:]) {
190            // Handle colon
191            let _: Token![:] = input.parse()?;
192
193            // Extract the type
194            let ty: Box<Type> = input.parse()?;
195
196            if input.peek(Token![=]) {
197                // Handle equal sign
198                let _: Token![=] = input.parse()?;
199            } else {
200                return Err(parse::Error::new(
201                    name.span(),
202                    "malformed, expected 'IDENT: TYPE = EXPR'",
203                ));
204            }
205
206            // Grab the final expression right of equal
207            let expr: Box<Expr> = input.parse()?;
208
209            // We got a trailing colon, remove it
210            if input.peek(Token![,]) {
211                let _: Token![,] = input.parse()?;
212            }
213
214            // Error check
215            match &*ty {
216                Type::Array(_) => {}
217                Type::Path(_) => {}
218                Type::Ptr(_) => {}
219                Type::Tuple(_) => {}
220                _ => {
221                    return Err(parse::Error::new(
222                        ty.span(),
223                        "unsupported type, must be an array, tuple, pointer or type path",
224                    ))
225                }
226            };
227
228            local = TaskLocal::Declared(Local {
229                attrs,
230                cfgs,
231                ty,
232                expr,
233            });
234        } else if input.peek(Token![=]) {
235            // Missing type ascription is not valid
236            return Err(parse::Error::new(name.span(), "malformed, expected a type"));
237        } else if input.peek(Token![,]) {
238            // Attributes not supported on non-initialized local resources!
239
240            if !attrs.is_empty() {
241                return Err(parse::Error::new(
242                    name.span(),
243                    "attributes are not supported here",
244                ));
245            }
246
247            // Remove comma
248            let _: Token![,] = input.parse()?;
249
250            // Expected when multiple local resources
251            local = TaskLocal::External;
252        } else if input.is_empty() {
253            // There was only one single local resource
254            // Task local but not initialized
255            // local = [EXPRPATH],
256            //          ~~~~~~~~
257            local = TaskLocal::External;
258        } else {
259            // Specifying local without any resources is invalid
260            return Err(parse::Error::new(name.span(), error_msg_no_local_resources));
261        };
262
263        if resources.contains_key(&name) {
264            return Err(parse::Error::new(
265                name.span(),
266                "resource appears more than once in list",
267            ));
268        }
269
270        resources.insert(name, local);
271    }
272
273    if resources.is_empty() {
274        return Err(parse::Error::new(
275            input.span(),
276            error_msg_no_local_resources,
277        ));
278    }
279
280    Ok(resources)
281}
282
283type ParseInputResult = Option<(Box<Pat>, Result<Vec<PatType>, FnArg>)>;
284
285pub fn parse_inputs(inputs: Punctuated<FnArg, Token![,]>, name: &str) -> ParseInputResult {
286    let mut inputs = inputs.into_iter();
287
288    match inputs.next() {
289        Some(FnArg::Typed(first)) => {
290            if type_is_path(&first.ty, &[name, "Context"]) {
291                let rest = inputs
292                    .map(|arg| match arg {
293                        FnArg::Typed(arg) => Ok(arg),
294                        _ => Err(arg),
295                    })
296                    .collect::<Result<Vec<_>, _>>();
297
298                Some((first.pat, rest))
299            } else {
300                None
301            }
302        }
303
304        _ => None,
305    }
306}
307
308pub fn type_is_bottom(ty: &ReturnType) -> bool {
309    if let ReturnType::Type(_, ty) = ty {
310        matches!(**ty, Type::Never(_))
311    } else {
312        false
313    }
314}
315
316fn extract_init_resource_name_ident(ty: Type) -> Result<Ident, ()> {
317    match ty {
318        Type::Path(path) => {
319            let path = path.path;
320
321            if path.leading_colon.is_some()
322                || path.segments.len() != 1
323                || path.segments[0].arguments != PathArguments::None
324            {
325                Err(())
326            } else {
327                Ok(path.segments[0].ident.clone())
328            }
329        }
330        _ => Err(()),
331    }
332}
333
334/// Checks Init's return type, return the user provided types for analysis
335pub fn type_is_init_return(ty: &ReturnType) -> Result<(Ident, Ident), ()> {
336    match ty {
337        ReturnType::Default => Err(()),
338
339        ReturnType::Type(_, ty) => match &**ty {
340            Type::Tuple(t) => {
341                // return should be:
342                // fn -> (User's #[shared] struct, User's #[local] struct)
343                //
344                // We check the length and the last one here, analysis checks that the user
345                // provided structs are correct.
346                if t.elems.len() == 2 {
347                    return Ok((
348                        extract_init_resource_name_ident(t.elems[0].clone())?,
349                        extract_init_resource_name_ident(t.elems[1].clone())?,
350                    ));
351                }
352
353                Err(())
354            }
355
356            _ => Err(()),
357        },
358    }
359}
360
361pub fn type_is_path(ty: &Type, segments: &[&str]) -> bool {
362    match ty {
363        Type::Path(tpath) if tpath.qself.is_none() => {
364            tpath.path.segments.len() == segments.len()
365                && tpath
366                    .path
367                    .segments
368                    .iter()
369                    .zip(segments)
370                    .all(|(lhs, rhs)| lhs.ident == **rhs)
371        }
372
373        _ => false,
374    }
375}
376
377pub fn type_is_unit(ty: &ReturnType) -> bool {
378    if let ReturnType::Type(_, ty) = ty {
379        if let Type::Tuple(ref tuple) = **ty {
380            tuple.elems.is_empty()
381        } else {
382            false
383        }
384    } else {
385        true
386    }
387}