thiserror_impl/valid.rs

1use crate::ast::{Enum, Field, Input, Struct, Variant};
2use crate::attr::Attrs;
3use quote::ToTokens;
4use std::collections::BTreeSet as Set;
5use syn::{Error, GenericArgument, Member, PathArguments, Result, Type};
6
7impl Input<'_> {
8    pub(crate) fn validate(&self) -> Result<()> {
9        match self {
10            Input::Struct(input) => input.validate(),
11            Input::Enum(input) => input.validate(),
12        }
13    }
14}
15
16impl Struct<'_> {
17    fn validate(&self) -> Result<()> {
18        check_non_field_attrs(&self.attrs)?;
19        if let Some(transparent) = self.attrs.transparent {
20            if self.fields.len() != 1 {
21                return Err(Error::new_spanned(
22                    transparent.original,
23                    "#[error(transparent)] requires exactly one field",
24                ));
25            }
26            if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) {
27                return Err(Error::new_spanned(
28                    source,
29                    "transparent error struct can't contain #[source]",
30                ));
31            }
32        }
33        check_field_attrs(&self.fields)?;
34        for field in &self.fields {
35            field.validate()?;
36        }
37        Ok(())
38    }
39}
40
41impl Enum<'_> {
42    fn validate(&self) -> Result<()> {
43        check_non_field_attrs(&self.attrs)?;
44        let has_display = self.has_display();
45        for variant in &self.variants {
46            variant.validate()?;
47            if has_display && variant.attrs.display.is_none() && variant.attrs.transparent.is_none()
48            {
49                return Err(Error::new_spanned(
50                    variant.original,
51                    "missing #[error(\"...\")] display attribute",
52                ));
53            }
54        }
55        let mut from_types = Set::new();
56        for variant in &self.variants {
57            if let Some(from_field) = variant.from_field() {
58                let repr = from_field.ty.to_token_stream().to_string();
59                if !from_types.insert(repr) {
60                    return Err(Error::new_spanned(
61                        from_field.original,
62                        "cannot derive From because another variant has the same source type",
63                    ));
64                }
65            }
66        }
67        Ok(())
68    }
69}
70
71impl Variant<'_> {
72    fn validate(&self) -> Result<()> {
73        check_non_field_attrs(&self.attrs)?;
74        if self.attrs.transparent.is_some() {
75            if self.fields.len() != 1 {
76                return Err(Error::new_spanned(
77                    self.original,
78                    "#[error(transparent)] requires exactly one field",
79                ));
80            }
81            if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) {
82                return Err(Error::new_spanned(
83                    source,
84                    "transparent variant can't contain #[source]",
85                ));
86            }
87        }
88        check_field_attrs(&self.fields)?;
89        for field in &self.fields {
90            field.validate()?;
91        }
92        Ok(())
93    }
94}
95
96impl Field<'_> {
97    fn validate(&self) -> Result<()> {
98        if let Some(display) = &self.attrs.display {
99            return Err(Error::new_spanned(
100                display.original,
101                "not expected here; the #[error(...)] attribute belongs on top of a struct or an enum variant",
102            ));
103        }
104        Ok(())
105    }
106}
107
108fn check_non_field_attrs(attrs: &Attrs) -> Result<()> {
109    if let Some(from) = &attrs.from {
110        return Err(Error::new_spanned(
111            from,
112            "not expected here; the #[from] attribute belongs on a specific field",
113        ));
114    }
115    if let Some(source) = &attrs.source {
116        return Err(Error::new_spanned(
117            source,
118            "not expected here; the #[source] attribute belongs on a specific field",
119        ));
120    }
121    if let Some(backtrace) = &attrs.backtrace {
122        return Err(Error::new_spanned(
123            backtrace,
124            "not expected here; the #[backtrace] attribute belongs on a specific field",
125        ));
126    }
127    if let Some(display) = &attrs.display {
128        if attrs.transparent.is_some() {
129            return Err(Error::new_spanned(
130                display.original,
131                "cannot have both #[error(transparent)] and a display attribute",
132            ));
133        }
134    }
135    Ok(())
136}
137
138fn check_field_attrs(fields: &[Field]) -> Result<()> {
139    let mut from_field = None;
140    let mut source_field = None;
141    let mut backtrace_field = None;
142    let mut has_backtrace = false;
143    for field in fields {
144        if let Some(from) = field.attrs.from {
145            if from_field.is_some() {
146                return Err(Error::new_spanned(from, "duplicate #[from] attribute"));
147            }
148            from_field = Some(field);
149        }
150        if let Some(source) = field.attrs.source {
151            if source_field.is_some() {
152                return Err(Error::new_spanned(source, "duplicate #[source] attribute"));
153            }
154            source_field = Some(field);
155        }
156        if let Some(backtrace) = field.attrs.backtrace {
157            if backtrace_field.is_some() {
158                return Err(Error::new_spanned(
159                    backtrace,
160                    "duplicate #[backtrace] attribute",
161                ));
162            }
163            backtrace_field = Some(field);
164            has_backtrace = true;
165        }
166        if let Some(transparent) = field.attrs.transparent {
167            return Err(Error::new_spanned(
168                transparent.original,
169                "#[error(transparent)] needs to go outside the enum or struct, not on an individual field",
170            ));
171        }
172        has_backtrace |= field.is_backtrace();
173    }
174    if let (Some(from_field), Some(source_field)) = (from_field, source_field) {
175        if !same_member(from_field, source_field) {
176            return Err(Error::new_spanned(
177                from_field.attrs.from,
178                "#[from] is only supported on the source field, not any other field",
179            ));
180        }
181    }
182    if let Some(from_field) = from_field {
183        let max_expected_fields = match backtrace_field {
184            Some(backtrace_field) => 1 + !same_member(from_field, backtrace_field) as usize,
185            None => 1 + has_backtrace as usize,
186        };
187        if fields.len() > max_expected_fields {
188            return Err(Error::new_spanned(
189                from_field.attrs.from,
190                "deriving From requires no fields other than source and backtrace",
191            ));
192        }
193    }
194    if let Some(source_field) = source_field.or(from_field) {
195        if contains_non_static_lifetime(source_field.ty) {
196            return Err(Error::new_spanned(
197                &source_field.original.ty,
198                "non-static lifetimes are not allowed in the source of an error, because std::error::Error requires the source is dyn Error + 'static",
199            ));
200        }
201    }
202    Ok(())
203}
204
205fn same_member(one: &Field, two: &Field) -> bool {
206    match (&one.member, &two.member) {
207        (Member::Named(one), Member::Named(two)) => one == two,
208        (Member::Unnamed(one), Member::Unnamed(two)) => one.index == two.index,
209        _ => unreachable!(),
210    }
211}
212
213fn contains_non_static_lifetime(ty: &Type) -> bool {
214    match ty {
215        Type::Path(ty) => {
216            let bracketed = match &ty.path.segments.last().unwrap().arguments {
217                PathArguments::AngleBracketed(bracketed) => bracketed,
218                _ => return false,
219            };
220            for arg in &bracketed.args {
221                match arg {
222                    GenericArgument::Type(ty) if contains_non_static_lifetime(ty) => return true,
223                    GenericArgument::Lifetime(lifetime) if lifetime.ident != "static" => {
224                        return true
225                    }
226                    _ => {}
227                }
228            }
229            false
230        }
231        Type::Reference(ty) => ty
232            .lifetime
233            .as_ref()
234            .map_or(false, |lifetime| lifetime.ident != "static"),
235        _ => false, // maybe implement later if there are common other cases
236    }
237}