pyo3/err/
err_state.rs

1use std::{
2    cell::UnsafeCell,
3    sync::{Mutex, Once},
4    thread::ThreadId,
5};
6
7use crate::{
8    exceptions::{PyBaseException, PyTypeError},
9    ffi,
10    ffi_ptr_ext::FfiPtrExt,
11    types::{PyAnyMethods, PyTraceback, PyType},
12    Bound, Py, PyAny, PyErrArguments, PyObject, PyTypeInfo, Python,
13};
14
15pub(crate) struct PyErrState {
16    // Safety: can only hand out references when in the "normalized" state. Will never change
17    // after normalization.
18    normalized: Once,
19    // Guard against re-entrancy when normalizing the exception state.
20    normalizing_thread: Mutex<Option<ThreadId>>,
21    inner: UnsafeCell<Option<PyErrStateInner>>,
22}
23
24// Safety: The inner value is protected by locking to ensure that only the normalized state is
25// handed out as a reference.
26unsafe impl Send for PyErrState {}
27unsafe impl Sync for PyErrState {}
28#[cfg(feature = "nightly")]
29unsafe impl crate::marker::Ungil for PyErrState {}
30
31impl PyErrState {
32    pub(crate) fn lazy(f: Box<PyErrStateLazyFn>) -> Self {
33        Self::from_inner(PyErrStateInner::Lazy(f))
34    }
35
36    pub(crate) fn lazy_arguments(ptype: Py<PyAny>, args: impl PyErrArguments + 'static) -> Self {
37        Self::from_inner(PyErrStateInner::Lazy(Box::new(move |py| {
38            PyErrStateLazyFnOutput {
39                ptype,
40                pvalue: args.arguments(py),
41            }
42        })))
43    }
44
45    pub(crate) fn normalized(normalized: PyErrStateNormalized) -> Self {
46        let state = Self::from_inner(PyErrStateInner::Normalized(normalized));
47        // This state is already normalized, by completing the Once immediately we avoid
48        // reaching the `py.allow_threads` in `make_normalized` which is less efficient
49        // and introduces a GIL switch which could deadlock.
50        // See https://github.com/PyO3/pyo3/issues/4764
51        state.normalized.call_once(|| {});
52        state
53    }
54
55    pub(crate) fn restore(self, py: Python<'_>) {
56        self.inner
57            .into_inner()
58            .expect("PyErr state should never be invalid outside of normalization")
59            .restore(py)
60    }
61
62    fn from_inner(inner: PyErrStateInner) -> Self {
63        Self {
64            normalized: Once::new(),
65            normalizing_thread: Mutex::new(None),
66            inner: UnsafeCell::new(Some(inner)),
67        }
68    }
69
70    #[inline]
71    pub(crate) fn as_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
72        if self.normalized.is_completed() {
73            match unsafe {
74                // Safety: self.inner will never be written again once normalized.
75                &*self.inner.get()
76            } {
77                Some(PyErrStateInner::Normalized(n)) => return n,
78                _ => unreachable!(),
79            }
80        }
81
82        self.make_normalized(py)
83    }
84
85    #[cold]
86    fn make_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
87        // This process is safe because:
88        // - Access is guaranteed not to be concurrent thanks to `Python` GIL token
89        // - Write happens only once, and then never will change again.
90
91        // Guard against re-entrant normalization, because `Once` does not provide
92        // re-entrancy guarantees.
93        if let Some(thread) = self.normalizing_thread.lock().unwrap().as_ref() {
94            assert!(
95                !(*thread == std::thread::current().id()),
96                "Re-entrant normalization of PyErrState detected"
97            );
98        }
99
100        // avoid deadlock of `.call_once` with the GIL
101        py.allow_threads(|| {
102            self.normalized.call_once(|| {
103                self.normalizing_thread
104                    .lock()
105                    .unwrap()
106                    .replace(std::thread::current().id());
107
108                // Safety: no other thread can access the inner value while we are normalizing it.
109                let state = unsafe {
110                    (*self.inner.get())
111                        .take()
112                        .expect("Cannot normalize a PyErr while already normalizing it.")
113                };
114
115                let normalized_state =
116                    Python::with_gil(|py| PyErrStateInner::Normalized(state.normalize(py)));
117
118                // Safety: no other thread can access the inner value while we are normalizing it.
119                unsafe {
120                    *self.inner.get() = Some(normalized_state);
121                }
122            })
123        });
124
125        match unsafe {
126            // Safety: self.inner will never be written again once normalized.
127            &*self.inner.get()
128        } {
129            Some(PyErrStateInner::Normalized(n)) => n,
130            _ => unreachable!(),
131        }
132    }
133}
134
135pub(crate) struct PyErrStateNormalized {
136    #[cfg(not(Py_3_12))]
137    ptype: Py<PyType>,
138    pub pvalue: Py<PyBaseException>,
139    #[cfg(not(Py_3_12))]
140    ptraceback: Option<Py<PyTraceback>>,
141}
142
143impl PyErrStateNormalized {
144    pub(crate) fn new(pvalue: Bound<'_, PyBaseException>) -> Self {
145        Self {
146            #[cfg(not(Py_3_12))]
147            ptype: pvalue.get_type().into(),
148            #[cfg(not(Py_3_12))]
149            ptraceback: unsafe {
150                Py::from_owned_ptr_or_opt(
151                    pvalue.py(),
152                    ffi::PyException_GetTraceback(pvalue.as_ptr()),
153                )
154            },
155            pvalue: pvalue.into(),
156        }
157    }
158
159    #[cfg(not(Py_3_12))]
160    pub(crate) fn ptype<'py>(&self, py: Python<'py>) -> Bound<'py, PyType> {
161        self.ptype.bind(py).clone()
162    }
163
164    #[cfg(Py_3_12)]
165    pub(crate) fn ptype<'py>(&self, py: Python<'py>) -> Bound<'py, PyType> {
166        self.pvalue.bind(py).get_type()
167    }
168
169    #[cfg(not(Py_3_12))]
170    pub(crate) fn ptraceback<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyTraceback>> {
171        self.ptraceback
172            .as_ref()
173            .map(|traceback| traceback.bind(py).clone())
174    }
175
176    #[cfg(Py_3_12)]
177    pub(crate) fn ptraceback<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyTraceback>> {
178        unsafe {
179            ffi::PyException_GetTraceback(self.pvalue.as_ptr())
180                .assume_owned_or_opt(py)
181                .map(|b| b.downcast_into_unchecked())
182        }
183    }
184
185    pub(crate) fn take(py: Python<'_>) -> Option<PyErrStateNormalized> {
186        #[cfg(Py_3_12)]
187        {
188            // Safety: PyErr_GetRaisedException can be called when attached to Python and
189            // returns either NULL or an owned reference.
190            unsafe { ffi::PyErr_GetRaisedException().assume_owned_or_opt(py) }.map(|pvalue| {
191                PyErrStateNormalized {
192                    // Safety: PyErr_GetRaisedException returns a valid exception type.
193                    pvalue: unsafe { pvalue.downcast_into_unchecked() }.unbind(),
194                }
195            })
196        }
197
198        #[cfg(not(Py_3_12))]
199        {
200            let (ptype, pvalue, ptraceback) = unsafe {
201                let mut ptype: *mut ffi::PyObject = std::ptr::null_mut();
202                let mut pvalue: *mut ffi::PyObject = std::ptr::null_mut();
203                let mut ptraceback: *mut ffi::PyObject = std::ptr::null_mut();
204
205                ffi::PyErr_Fetch(&mut ptype, &mut pvalue, &mut ptraceback);
206
207                // Ensure that the exception coming from the interpreter is normalized.
208                if !ptype.is_null() {
209                    ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback);
210                }
211
212                // Safety: PyErr_NormalizeException will have produced up to three owned
213                // references of the correct types.
214                (
215                    ptype
216                        .assume_owned_or_opt(py)
217                        .map(|b| b.downcast_into_unchecked()),
218                    pvalue
219                        .assume_owned_or_opt(py)
220                        .map(|b| b.downcast_into_unchecked()),
221                    ptraceback
222                        .assume_owned_or_opt(py)
223                        .map(|b| b.downcast_into_unchecked()),
224                )
225            };
226
227            ptype.map(|ptype| PyErrStateNormalized {
228                ptype: ptype.unbind(),
229                pvalue: pvalue.expect("normalized exception value missing").unbind(),
230                ptraceback: ptraceback.map(Bound::unbind),
231            })
232        }
233    }
234
235    #[cfg(not(Py_3_12))]
236    unsafe fn from_normalized_ffi_tuple(
237        py: Python<'_>,
238        ptype: *mut ffi::PyObject,
239        pvalue: *mut ffi::PyObject,
240        ptraceback: *mut ffi::PyObject,
241    ) -> Self {
242        PyErrStateNormalized {
243            ptype: Py::from_owned_ptr_or_opt(py, ptype).expect("Exception type missing"),
244            pvalue: Py::from_owned_ptr_or_opt(py, pvalue).expect("Exception value missing"),
245            ptraceback: Py::from_owned_ptr_or_opt(py, ptraceback),
246        }
247    }
248
249    pub fn clone_ref(&self, py: Python<'_>) -> Self {
250        Self {
251            #[cfg(not(Py_3_12))]
252            ptype: self.ptype.clone_ref(py),
253            pvalue: self.pvalue.clone_ref(py),
254            #[cfg(not(Py_3_12))]
255            ptraceback: self
256                .ptraceback
257                .as_ref()
258                .map(|ptraceback| ptraceback.clone_ref(py)),
259        }
260    }
261}
262
263pub(crate) struct PyErrStateLazyFnOutput {
264    pub(crate) ptype: PyObject,
265    pub(crate) pvalue: PyObject,
266}
267
268pub(crate) type PyErrStateLazyFn =
269    dyn for<'py> FnOnce(Python<'py>) -> PyErrStateLazyFnOutput + Send + Sync;
270
271enum PyErrStateInner {
272    Lazy(Box<PyErrStateLazyFn>),
273    Normalized(PyErrStateNormalized),
274}
275
276impl PyErrStateInner {
277    fn normalize(self, py: Python<'_>) -> PyErrStateNormalized {
278        match self {
279            #[cfg(not(Py_3_12))]
280            PyErrStateInner::Lazy(lazy) => {
281                let (ptype, pvalue, ptraceback) = lazy_into_normalized_ffi_tuple(py, lazy);
282                unsafe {
283                    PyErrStateNormalized::from_normalized_ffi_tuple(py, ptype, pvalue, ptraceback)
284                }
285            }
286            #[cfg(Py_3_12)]
287            PyErrStateInner::Lazy(lazy) => {
288                // To keep the implementation simple, just write the exception into the interpreter,
289                // which will cause it to be normalized
290                raise_lazy(py, lazy);
291                PyErrStateNormalized::take(py)
292                    .expect("exception missing after writing to the interpreter")
293            }
294            PyErrStateInner::Normalized(normalized) => normalized,
295        }
296    }
297
298    #[cfg(not(Py_3_12))]
299    fn restore(self, py: Python<'_>) {
300        let (ptype, pvalue, ptraceback) = match self {
301            PyErrStateInner::Lazy(lazy) => lazy_into_normalized_ffi_tuple(py, lazy),
302            PyErrStateInner::Normalized(PyErrStateNormalized {
303                ptype,
304                pvalue,
305                ptraceback,
306            }) => (
307                ptype.into_ptr(),
308                pvalue.into_ptr(),
309                ptraceback.map_or(std::ptr::null_mut(), Py::into_ptr),
310            ),
311        };
312        unsafe { ffi::PyErr_Restore(ptype, pvalue, ptraceback) }
313    }
314
315    #[cfg(Py_3_12)]
316    fn restore(self, py: Python<'_>) {
317        match self {
318            PyErrStateInner::Lazy(lazy) => raise_lazy(py, lazy),
319            PyErrStateInner::Normalized(PyErrStateNormalized { pvalue }) => unsafe {
320                ffi::PyErr_SetRaisedException(pvalue.into_ptr())
321            },
322        }
323    }
324}
325
326#[cfg(not(Py_3_12))]
327fn lazy_into_normalized_ffi_tuple(
328    py: Python<'_>,
329    lazy: Box<PyErrStateLazyFn>,
330) -> (*mut ffi::PyObject, *mut ffi::PyObject, *mut ffi::PyObject) {
331    // To be consistent with 3.12 logic, go via raise_lazy, but also then normalize
332    // the resulting exception
333    raise_lazy(py, lazy);
334    let mut ptype = std::ptr::null_mut();
335    let mut pvalue = std::ptr::null_mut();
336    let mut ptraceback = std::ptr::null_mut();
337    unsafe {
338        ffi::PyErr_Fetch(&mut ptype, &mut pvalue, &mut ptraceback);
339        ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback);
340    }
341    (ptype, pvalue, ptraceback)
342}
343
344/// Raises a "lazy" exception state into the Python interpreter.
345///
346/// In principle this could be split in two; first a function to create an exception
347/// in a normalized state, and then a call to `PyErr_SetRaisedException` to raise it.
348///
349/// This would require either moving some logic from C to Rust, or requesting a new
350/// API in CPython.
351fn raise_lazy(py: Python<'_>, lazy: Box<PyErrStateLazyFn>) {
352    let PyErrStateLazyFnOutput { ptype, pvalue } = lazy(py);
353    unsafe {
354        if ffi::PyExceptionClass_Check(ptype.as_ptr()) == 0 {
355            ffi::PyErr_SetString(
356                PyTypeError::type_object_raw(py).cast(),
357                ffi::c_str!("exceptions must derive from BaseException").as_ptr(),
358            )
359        } else {
360            ffi::PyErr_SetObject(ptype.as_ptr(), pvalue.as_ptr())
361        }
362    }
363}
364
#[cfg(test)]
mod tests {

    use crate::{
        exceptions::PyValueError, sync::GILOnceCell, PyErr, PyErrArguments, PyObject, Python,
    };

    // Normalizing an error whose arguments normalize that same error again must panic
    // rather than hang or corrupt the state.
    #[test]
    #[should_panic(expected = "Re-entrant normalization of PyErrState detected")]
    fn test_reentrant_normalization() {
        static ERR: GILOnceCell<PyErr> = GILOnceCell::new();

        struct RecursiveArgs;

        impl PyErrArguments for RecursiveArgs {
            fn arguments(self, py: Python<'_>) -> PyObject {
                // .value(py) triggers normalization
                ERR.get(py)
                    .expect("is set just below")
                    .value(py)
                    .clone()
                    .into()
            }
        }

        Python::with_gil(|py| {
            ERR.set(py, PyValueError::new_err(RecursiveArgs)).unwrap();
            ERR.get(py).expect("is set just above").value(py);
        })
    }

    // Many threads normalizing the same error, while the arguments release the GIL, must
    // all finish.
    #[test]
    #[cfg(not(target_arch = "wasm32"))] // We are building wasm Python with pthreads disabled
    fn test_no_deadlock_thread_switch() {
        static ERR: GILOnceCell<PyErr> = GILOnceCell::new();

        struct GILSwitchArgs;

        impl PyErrArguments for GILSwitchArgs {
            fn arguments(self, py: Python<'_>) -> PyObject {
                // releasing the GIL potentially allows for other threads to deadlock
                // with the normalization going on here
                py.allow_threads(|| {
                    std::thread::sleep(std::time::Duration::from_millis(10));
                });
                py.None()
            }
        }

        Python::with_gil(|py| ERR.set(py, PyValueError::new_err(GILSwitchArgs)).unwrap());

        // Let many threads attempt to read the normalized value at the same time
        let handles = (0..10)
            .map(|_| {
                std::thread::spawn(|| {
                    Python::with_gil(|py| {
                        ERR.get(py).expect("is set just above").value(py);
                    });
                })
            })
            .collect::<Vec<_>>();

        for handle in handles {
            handle.join().unwrap();
        }

        // We should never have deadlocked, and should be able to run
        // this assertion
        Python::with_gil(|py| {
            assert!(ERR
                .get(py)
                .expect("is set above")
                .is_instance_of::<PyValueError>(py))
        });
    }
}