pyo3_macros_backend/
attributes.rs

1use proc_macro2::TokenStream;
2use quote::{quote, ToTokens};
3use syn::parse::Parser;
4use syn::{
5    ext::IdentExt,
6    parse::{Parse, ParseStream},
7    punctuated::Punctuated,
8    spanned::Spanned,
9    token::Comma,
10    Attribute, Expr, ExprPath, Ident, Index, LitBool, LitStr, Member, Path, Result, Token,
11};
12
13pub mod kw {
14    syn::custom_keyword!(annotation);
15    syn::custom_keyword!(attribute);
16    syn::custom_keyword!(cancel_handle);
17    syn::custom_keyword!(constructor);
18    syn::custom_keyword!(dict);
19    syn::custom_keyword!(eq);
20    syn::custom_keyword!(eq_int);
21    syn::custom_keyword!(extends);
22    syn::custom_keyword!(freelist);
23    syn::custom_keyword!(from_py_with);
24    syn::custom_keyword!(frozen);
25    syn::custom_keyword!(get);
26    syn::custom_keyword!(get_all);
27    syn::custom_keyword!(hash);
28    syn::custom_keyword!(into_py_with);
29    syn::custom_keyword!(item);
30    syn::custom_keyword!(from_item_all);
31    syn::custom_keyword!(mapping);
32    syn::custom_keyword!(module);
33    syn::custom_keyword!(name);
34    syn::custom_keyword!(ord);
35    syn::custom_keyword!(pass_module);
36    syn::custom_keyword!(rename_all);
37    syn::custom_keyword!(sequence);
38    syn::custom_keyword!(set);
39    syn::custom_keyword!(set_all);
40    syn::custom_keyword!(signature);
41    syn::custom_keyword!(str);
42    syn::custom_keyword!(subclass);
43    syn::custom_keyword!(submodule);
44    syn::custom_keyword!(text_signature);
45    syn::custom_keyword!(transparent);
46    syn::custom_keyword!(unsendable);
47    syn::custom_keyword!(weakref);
48    syn::custom_keyword!(gil_used);
49}
50
/// Consumes the leading ASCII digits of `read` and returns them as a `String`.
///
/// `read` is always advanced past the consumed digits. If the whole input is digits, it
/// becomes the empty string. `tracker` is incremented by the number of bytes consumed,
/// which equals the number of digits because they are ASCII.
fn take_int(read: &mut &str, tracker: &mut usize) -> String {
    let input = *read;
    // Length of the all-digit prefix; the whole input if no non-digit follows.
    let len = input
        .find(|ch: char| !ch.is_ascii_digit())
        .unwrap_or(input.len());
    *tracker += len;
    *read = &input[len..];
    input[..len].to_owned()
}
67
68fn take_ident(read: &mut &str, tracker: &mut usize) -> Ident {
69    let mut ident = String::new();
70    if read.starts_with("r#") {
71        ident.push_str("r#");
72        *tracker += 2;
73        *read = &read[2..];
74    }
75    for (i, ch) in read.char_indices() {
76        match ch {
77            'a'..='z' | 'A'..='Z' | '0'..='9' | '_' => {
78                *tracker += 1;
79                ident.push(ch)
80            }
81            _ => {
82                *read = &read[i..];
83                break;
84            }
85        }
86    }
87    Ident::parse_any.parse_str(&ident).unwrap()
88}
89
90// shorthand parsing logic inspiration taken from https://github.com/dtolnay/thiserror/blob/master/impl/src/fmt.rs
91fn parse_shorthand_format(fmt: LitStr) -> Result<(LitStr, Vec<Member>)> {
92    let span = fmt.span();
93    let token = fmt.token();
94    let value = fmt.value();
95    let mut read = value.as_str();
96    let mut out = String::new();
97    let mut members = Vec::new();
98    let mut tracker = 1;
99    while let Some(brace) = read.find('{') {
100        tracker += brace;
101        out += &read[..brace + 1];
102        read = &read[brace + 1..];
103        if read.starts_with('{') {
104            out.push('{');
105            read = &read[1..];
106            tracker += 2;
107            continue;
108        }
109        let next = match read.chars().next() {
110            Some(next) => next,
111            None => break,
112        };
113        tracker += 1;
114        let member = match next {
115            '0'..='9' => {
116                let start = tracker;
117                let index = take_int(&mut read, &mut tracker).parse::<u32>().unwrap();
118                let end = tracker;
119                let subspan = token.subspan(start..end).unwrap_or(span);
120                let idx = Index {
121                    index,
122                    span: subspan,
123                };
124                Member::Unnamed(idx)
125            }
126            'a'..='z' | 'A'..='Z' | '_' => {
127                let start = tracker;
128                let mut ident = take_ident(&mut read, &mut tracker);
129                let end = tracker;
130                let subspan = token.subspan(start..end).unwrap_or(span);
131                ident.set_span(subspan);
132                Member::Named(ident)
133            }
134            '}' | ':' => {
135                let start = tracker;
136                tracker += 1;
137                let end = tracker;
138                let subspan = token.subspan(start..end).unwrap_or(span);
139                // we found a closing bracket or formatting ':' without finding a member, we assume the user wants the instance formatted here
140                bail_spanned!(subspan.span() => "No member found, you must provide a named or positionally specified member.")
141            }
142            _ => continue,
143        };
144        members.push(member);
145    }
146    out += read;
147    Ok((LitStr::new(&out, span), members))
148}
149
150#[derive(Clone, Debug)]
151pub struct StringFormatter {
152    pub fmt: LitStr,
153    pub args: Vec<Member>,
154}
155
156impl Parse for crate::attributes::StringFormatter {
157    fn parse(input: ParseStream<'_>) -> Result<Self> {
158        let (fmt, args) = parse_shorthand_format(input.parse()?)?;
159        Ok(Self { fmt, args })
160    }
161}
162
163impl ToTokens for crate::attributes::StringFormatter {
164    fn to_tokens(&self, tokens: &mut TokenStream) {
165        self.fmt.to_tokens(tokens);
166        tokens.extend(quote! {self.args})
167    }
168}
169
/// An attribute argument of the form `keyword = value`.
#[derive(Clone, Debug)]
pub struct KeywordAttribute<K, V> {
    /// The keyword token.
    pub kw: K,
    /// The value that follows the `=`.
    pub value: V,
}

/// An attribute argument written either as a bare `keyword` or as `keyword = value`.
#[derive(Clone, Debug)]
pub struct OptionalKeywordAttribute<K, V> {
    /// The keyword token.
    pub kw: K,
    /// The value after `=`, or `None` when the keyword appeared on its own.
    pub value: Option<V>,
}
181
182/// A helper type which parses the inner type via a literal string
183/// e.g. `LitStrValue<Path>` -> parses "some::path" in quotes.
184#[derive(Clone, Debug, PartialEq, Eq)]
185pub struct LitStrValue<T>(pub T);
186
187impl<T: Parse> Parse for LitStrValue<T> {
188    fn parse(input: ParseStream<'_>) -> Result<Self> {
189        let lit_str: LitStr = input.parse()?;
190        lit_str.parse().map(LitStrValue)
191    }
192}
193
194impl<T: ToTokens> ToTokens for LitStrValue<T> {
195    fn to_tokens(&self, tokens: &mut TokenStream) {
196        self.0.to_tokens(tokens)
197    }
198}
199
200/// A helper type which parses a name via a literal string
201#[derive(Clone, Debug, PartialEq, Eq)]
202pub struct NameLitStr(pub Ident);
203
204impl Parse for NameLitStr {
205    fn parse(input: ParseStream<'_>) -> Result<Self> {
206        let string_literal: LitStr = input.parse()?;
207        if let Ok(ident) = string_literal.parse_with(Ident::parse_any) {
208            Ok(NameLitStr(ident))
209        } else {
210            bail_spanned!(string_literal.span() => "expected a single identifier in double quotes")
211        }
212    }
213}
214
215impl ToTokens for NameLitStr {
216    fn to_tokens(&self, tokens: &mut TokenStream) {
217        self.0.to_tokens(tokens)
218    }
219}
220
/// The renaming rules accepted by `rename_all`.
///
/// The spelling of each rule is given in the string-literal parser for `RenamingRuleLitStr`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RenamingRule {
    /// `"camelCase"`
    CamelCase,
    /// `"kebab-case"`
    KebabCase,
    /// `"lowercase"`
    Lowercase,
    /// `"PascalCase"`
    PascalCase,
    /// `"SCREAMING-KEBAB-CASE"`
    ScreamingKebabCase,
    /// `"SCREAMING_SNAKE_CASE"`
    ScreamingSnakeCase,
    /// `"snake_case"`
    SnakeCase,
    /// `"UPPERCASE"`
    Uppercase,
}
233
234/// A helper type which parses a renaming rule via a literal string
235#[derive(Clone, Debug, PartialEq, Eq)]
236pub struct RenamingRuleLitStr {
237    pub lit: LitStr,
238    pub rule: RenamingRule,
239}
240
241impl Parse for RenamingRuleLitStr {
242    fn parse(input: ParseStream<'_>) -> Result<Self> {
243        let string_literal: LitStr = input.parse()?;
244        let rule = match string_literal.value().as_ref() {
245            "camelCase" => RenamingRule::CamelCase,
246            "kebab-case" => RenamingRule::KebabCase,
247            "lowercase" => RenamingRule::Lowercase,
248            "PascalCase" => RenamingRule::PascalCase,
249            "SCREAMING-KEBAB-CASE" => RenamingRule::ScreamingKebabCase,
250            "SCREAMING_SNAKE_CASE" => RenamingRule::ScreamingSnakeCase,
251            "snake_case" => RenamingRule::SnakeCase,
252            "UPPERCASE" => RenamingRule::Uppercase,
253            _ => {
254                bail_spanned!(string_literal.span() => "expected a valid renaming rule, possible values are: \"camelCase\", \"kebab-case\", \"lowercase\", \"PascalCase\", \"SCREAMING-KEBAB-CASE\", \"SCREAMING_SNAKE_CASE\", \"snake_case\", \"UPPERCASE\"")
255            }
256        };
257        Ok(Self {
258            lit: string_literal,
259            rule,
260        })
261    }
262}
263
264impl ToTokens for RenamingRuleLitStr {
265    fn to_tokens(&self, tokens: &mut TokenStream) {
266        self.lit.to_tokens(tokens)
267    }
268}
269
270/// Text signatue can be either a literal string or opt-in/out
271#[derive(Clone, Debug, PartialEq, Eq)]
272pub enum TextSignatureAttributeValue {
273    Str(LitStr),
274    // `None` ident to disable automatic text signature generation
275    Disabled(Ident),
276}
277
278impl Parse for TextSignatureAttributeValue {
279    fn parse(input: ParseStream<'_>) -> Result<Self> {
280        if let Ok(lit_str) = input.parse::<LitStr>() {
281            return Ok(TextSignatureAttributeValue::Str(lit_str));
282        }
283
284        let err_span = match input.parse::<Ident>() {
285            Ok(ident) if ident == "None" => {
286                return Ok(TextSignatureAttributeValue::Disabled(ident));
287            }
288            Ok(other_ident) => other_ident.span(),
289            Err(e) => e.span(),
290        };
291
292        Err(err_spanned!(err_span => "expected a string literal or `None`"))
293    }
294}
295
296impl ToTokens for TextSignatureAttributeValue {
297    fn to_tokens(&self, tokens: &mut TokenStream) {
298        match self {
299            TextSignatureAttributeValue::Str(s) => s.to_tokens(tokens),
300            TextSignatureAttributeValue::Disabled(b) => b.to_tokens(tokens),
301        }
302    }
303}
304
305pub type ExtendsAttribute = KeywordAttribute<kw::extends, Path>;
306pub type FreelistAttribute = KeywordAttribute<kw::freelist, Box<Expr>>;
307pub type ModuleAttribute = KeywordAttribute<kw::module, LitStr>;
308pub type NameAttribute = KeywordAttribute<kw::name, NameLitStr>;
309pub type RenameAllAttribute = KeywordAttribute<kw::rename_all, RenamingRuleLitStr>;
310pub type StrFormatterAttribute = OptionalKeywordAttribute<kw::str, StringFormatter>;
311pub type TextSignatureAttribute = KeywordAttribute<kw::text_signature, TextSignatureAttributeValue>;
312pub type SubmoduleAttribute = kw::submodule;
313pub type GILUsedAttribute = KeywordAttribute<kw::gil_used, LitBool>;
314
315impl<K: Parse + std::fmt::Debug, V: Parse> Parse for KeywordAttribute<K, V> {
316    fn parse(input: ParseStream<'_>) -> Result<Self> {
317        let kw: K = input.parse()?;
318        let _: Token![=] = input.parse()?;
319        let value = input.parse()?;
320        Ok(KeywordAttribute { kw, value })
321    }
322}
323
324impl<K: ToTokens, V: ToTokens> ToTokens for KeywordAttribute<K, V> {
325    fn to_tokens(&self, tokens: &mut TokenStream) {
326        self.kw.to_tokens(tokens);
327        Token![=](self.kw.span()).to_tokens(tokens);
328        self.value.to_tokens(tokens);
329    }
330}
331
332impl<K: Parse + std::fmt::Debug, V: Parse> Parse for OptionalKeywordAttribute<K, V> {
333    fn parse(input: ParseStream<'_>) -> Result<Self> {
334        let kw: K = input.parse()?;
335        let value = match input.parse::<Token![=]>() {
336            Ok(_) => Some(input.parse()?),
337            Err(_) => None,
338        };
339        Ok(OptionalKeywordAttribute { kw, value })
340    }
341}
342
343impl<K: ToTokens, V: ToTokens> ToTokens for OptionalKeywordAttribute<K, V> {
344    fn to_tokens(&self, tokens: &mut TokenStream) {
345        self.kw.to_tokens(tokens);
346        if self.value.is_some() {
347            Token![=](self.kw.span()).to_tokens(tokens);
348            self.value.to_tokens(tokens);
349        }
350    }
351}
352
353#[derive(Debug, Clone)]
354pub struct ExprPathWrap {
355    pub from_lit_str: bool,
356    pub expr_path: ExprPath,
357}
358
359impl Parse for ExprPathWrap {
360    fn parse(input: ParseStream<'_>) -> Result<Self> {
361        match input.parse::<ExprPath>() {
362            Ok(expr_path) => Ok(ExprPathWrap {
363                from_lit_str: false,
364                expr_path,
365            }),
366            Err(e) => match input.parse::<LitStrValue<ExprPath>>() {
367                Ok(LitStrValue(expr_path)) => Ok(ExprPathWrap {
368                    from_lit_str: true,
369                    expr_path,
370                }),
371                Err(_) => Err(e),
372            },
373        }
374    }
375}
376
377impl ToTokens for ExprPathWrap {
378    fn to_tokens(&self, tokens: &mut TokenStream) {
379        self.expr_path.to_tokens(tokens)
380    }
381}
382
383pub type FromPyWithAttribute = KeywordAttribute<kw::from_py_with, ExprPathWrap>;
384pub type IntoPyWithAttribute = KeywordAttribute<kw::into_py_with, ExprPath>;
385
386pub type DefaultAttribute = OptionalKeywordAttribute<Token![default], Expr>;
387
388/// For specifying the path to the pyo3 crate.
389pub type CrateAttribute = KeywordAttribute<Token![crate], LitStrValue<Path>>;
390
391pub fn get_pyo3_options<T: Parse>(attr: &syn::Attribute) -> Result<Option<Punctuated<T, Comma>>> {
392    if attr.path().is_ident("pyo3") {
393        attr.parse_args_with(Punctuated::parse_terminated).map(Some)
394    } else {
395        Ok(None)
396    }
397}
398
399/// Takes attributes from an attribute vector.
400///
401/// For each attribute in `attrs`, `extractor` is called. If `extractor` returns `Ok(true)`, then
402/// the attribute will be removed from the vector.
403///
404/// This is similar to `Vec::retain` except the closure is fallible and the condition is reversed.
405/// (In `retain`, returning `true` keeps the element, here it removes it.)
406pub fn take_attributes(
407    attrs: &mut Vec<Attribute>,
408    mut extractor: impl FnMut(&Attribute) -> Result<bool>,
409) -> Result<()> {
410    *attrs = attrs
411        .drain(..)
412        .filter_map(|attr| {
413            extractor(&attr)
414                .map(move |attribute_handled| if attribute_handled { None } else { Some(attr) })
415                .transpose()
416        })
417        .collect::<Result<_>>()?;
418    Ok(())
419}
420
421pub fn take_pyo3_options<T: Parse>(attrs: &mut Vec<syn::Attribute>) -> Result<Vec<T>> {
422    let mut out = Vec::new();
423    let mut all_errors = ErrorCombiner(None);
424    take_attributes(attrs, |attr| match get_pyo3_options(attr) {
425        Ok(result) => {
426            if let Some(options) = result {
427                out.extend(options);
428                Ok(true)
429            } else {
430                Ok(false)
431            }
432        }
433        Err(err) => {
434            all_errors.combine(err);
435            Ok(true)
436        }
437    })?;
438    all_errors.ensure_empty()?;
439    Ok(out)
440}
441
442pub struct ErrorCombiner(pub Option<syn::Error>);
443
444impl ErrorCombiner {
445    pub fn combine(&mut self, error: syn::Error) {
446        if let Some(existing) = &mut self.0 {
447            existing.combine(error);
448        } else {
449            self.0 = Some(error);
450        }
451    }
452
453    pub fn ensure_empty(self) -> Result<()> {
454        if let Some(error) = self.0 {
455            Err(error)
456        } else {
457            Ok(())
458        }
459    }
460}
⚠️ Internal Docs ⚠️ Not Public API 👉 Official Docs Here