1use std::{
2 cell::UnsafeCell,
3 sync::{Mutex, Once},
4 thread::ThreadId,
5};
6
7use crate::{
8 exceptions::{PyBaseException, PyTypeError},
9 ffi,
10 ffi_ptr_ext::FfiPtrExt,
11 types::{PyAnyMethods, PyTraceback, PyType},
12 Bound, Py, PyAny, PyErrArguments, PyObject, PyTypeInfo, Python,
13};
14
/// State of a Python error, which is either still lazy or already normalized.
///
/// The state lives in an `UnsafeCell` so that it can be normalized in place through a shared
/// reference; `normalized` records when that has happened.
pub(crate) struct PyErrState {
    // Completed once `inner` holds the `Normalized` variant. `as_normalized` only reads `inner`
    // on its fast path when this reports completion.
    normalized: Once,
    // Id of the thread that ran the normalization closure. `make_normalized` compares it with
    // the current thread to detect re-entrant normalization and panic with a clear message.
    normalizing_thread: Mutex<Option<ThreadId>>,
    // The actual state. It is briefly `None` while `make_normalized` has taken the lazy state
    // out and has not yet written the normalized state back.
    inner: UnsafeCell<Option<PyErrStateInner>>,
}
23
// SAFETY (as far as this file shows): through `&self`, `inner` is only written inside the
// `Once` closure in `make_normalized`, and shared references into `inner` are only created
// after that `Once` has completed, after which `inner` is not written again. The boxed lazy
// closure is itself required to be `Send + Sync` by `PyErrStateLazyFn`.
// NOTE(review): this also relies on `Py<T>` being safe to move between threads — confirm
// against its definition.
unsafe impl Send for PyErrState {}
unsafe impl Sync for PyErrState {}
#[cfg(feature = "nightly")]
unsafe impl crate::marker::Ungil for PyErrState {}
30
impl PyErrState {
    /// Creates a lazy state from a boxed closure which produces the exception type and value
    /// when it is eventually called.
    pub(crate) fn lazy(f: Box<PyErrStateLazyFn>) -> Self {
        Self::from_inner(PyErrStateInner::Lazy(f))
    }

    /// Creates a lazy state from an exception type object and its arguments.
    ///
    /// `args.arguments(py)` is not evaluated here; it runs inside the stored closure.
    pub(crate) fn lazy_arguments(ptype: Py<PyAny>, args: impl PyErrArguments + 'static) -> Self {
        Self::from_inner(PyErrStateInner::Lazy(Box::new(move |py| {
            PyErrStateLazyFnOutput {
                ptype,
                pvalue: args.arguments(py),
            }
        })))
    }

    /// Creates a state which is already normalized.
    pub(crate) fn normalized(normalized: PyErrStateNormalized) -> Self {
        let state = Self::from_inner(PyErrStateInner::Normalized(normalized));
        // Complete the `Once` immediately so that `as_normalized` always takes its fast path
        // for this state and never reaches `make_normalized` (and its `py.allow_threads` call).
        state.normalized.call_once(|| {});
        state
    }

    /// Consumes the state and hands it to `PyErrStateInner::restore`, which sets it as the
    /// interpreter's current error.
    ///
    /// # Panics
    ///
    /// Panics if the inner state is `None`, which only happens while normalization is in
    /// progress.
    pub(crate) fn restore(self, py: Python<'_>) {
        self.inner
            .into_inner()
            .expect("PyErr state should never be invalid outside of normalization")
            .restore(py)
    }

    /// Wraps `inner` with a fresh, not-yet-completed `Once` and no recorded normalizing thread.
    fn from_inner(inner: PyErrStateInner) -> Self {
        Self {
            normalized: Once::new(),
            normalizing_thread: Mutex::new(None),
            inner: UnsafeCell::new(Some(inner)),
        }
    }

    /// Returns the normalized form of this error, normalizing it first if necessary.
    ///
    /// The fast path is a single `Once::is_completed` check; the slow path is kept out of
    /// line in `make_normalized`.
    #[inline]
    pub(crate) fn as_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
        if self.normalized.is_completed() {
            match unsafe {
                // SAFETY: the `Once` has completed, and the only write to `inner` through
                // `&self` happens inside that `Once`'s closure, so `inner` is no longer mutated.
                &*self.inner.get()
            } {
                Some(PyErrStateInner::Normalized(n)) => return n,
                _ => unreachable!(),
            }
        }

        self.make_normalized(py)
    }

    /// Slow path of `as_normalized`: runs normalization (at most once across all threads) and
    /// returns a reference to the stored normalized state.
    ///
    /// # Panics
    ///
    /// Panics if the thread recorded as the normalizing thread calls this again, i.e. when
    /// normalization re-enters itself.
    #[cold]
    fn make_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
        // `Once::call_once` must not be re-entered from within its own closure (std documents
        // that this may deadlock or panic), so detect that case up front and fail with a clear
        // message instead.
        if let Some(thread) = self.normalizing_thread.lock().unwrap().as_ref() {
            assert!(
                !(*thread == std::thread::current().id()),
                "Re-entrant normalization of PyErrState detected"
            );
        }

        // Wait on the `Once` inside `allow_threads` and re-enter Python with `with_gil` only in
        // the closure that actually normalizes.
        // NOTE(review): presumably `allow_threads` releases the GIL so that a thread blocked in
        // `call_once` cannot deadlock against the normalizing thread — confirm.
        py.allow_threads(|| {
            self.normalized.call_once(|| {
                // Record who is normalizing so the re-entrancy guard above can fire.
                self.normalizing_thread
                    .lock()
                    .unwrap()
                    .replace(std::thread::current().id());

                // SAFETY: this closure runs at most once and no reference into `inner` is
                // handed out before the `Once` completes, so nothing else accesses `inner` now.
                let state = unsafe {
                    (*self.inner.get())
                        .take()
                        .expect("Cannot normalize a PyErr while already normalizing it.")
                };

                let normalized_state =
                    Python::with_gil(|py| PyErrStateInner::Normalized(state.normalize(py)));

                // SAFETY: same as above; still inside the `Once` closure.
                unsafe {
                    *self.inner.get() = Some(normalized_state);
                }
            })
        });

        match unsafe {
            // SAFETY: `call_once` has returned, so the `Once` is complete and `inner` is not
            // written again through `&self`.
            &*self.inner.get()
        } {
            Some(PyErrStateInner::Normalized(n)) => n,
            _ => unreachable!(),
        }
    }
}
134
/// A normalized Python error: an exception instance plus, when the `Py_3_12` cfg is not set,
/// its separately stored type and traceback.
pub(crate) struct PyErrStateNormalized {
    // Exception type; only stored without `Py_3_12`. With `Py_3_12`, `ptype()` reads it from
    // `pvalue` instead.
    #[cfg(not(Py_3_12))]
    ptype: Py<PyType>,
    // The exception instance.
    pub pvalue: Py<PyBaseException>,
    // Traceback, if any; only stored without `Py_3_12`. With `Py_3_12`, `ptraceback()` queries
    // it from `pvalue` instead.
    #[cfg(not(Py_3_12))]
    ptraceback: Option<Py<PyTraceback>>,
}
142
143impl PyErrStateNormalized {
144 pub(crate) fn new(pvalue: Bound<'_, PyBaseException>) -> Self {
145 Self {
146 #[cfg(not(Py_3_12))]
147 ptype: pvalue.get_type().into(),
148 #[cfg(not(Py_3_12))]
149 ptraceback: unsafe {
150 Py::from_owned_ptr_or_opt(
151 pvalue.py(),
152 ffi::PyException_GetTraceback(pvalue.as_ptr()),
153 )
154 },
155 pvalue: pvalue.into(),
156 }
157 }
158
159 #[cfg(not(Py_3_12))]
160 pub(crate) fn ptype<'py>(&self, py: Python<'py>) -> Bound<'py, PyType> {
161 self.ptype.bind(py).clone()
162 }
163
164 #[cfg(Py_3_12)]
165 pub(crate) fn ptype<'py>(&self, py: Python<'py>) -> Bound<'py, PyType> {
166 self.pvalue.bind(py).get_type()
167 }
168
169 #[cfg(not(Py_3_12))]
170 pub(crate) fn ptraceback<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyTraceback>> {
171 self.ptraceback
172 .as_ref()
173 .map(|traceback| traceback.bind(py).clone())
174 }
175
176 #[cfg(Py_3_12)]
177 pub(crate) fn ptraceback<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyTraceback>> {
178 unsafe {
179 ffi::PyException_GetTraceback(self.pvalue.as_ptr())
180 .assume_owned_or_opt(py)
181 .map(|b| b.downcast_into_unchecked())
182 }
183 }
184
185 pub(crate) fn take(py: Python<'_>) -> Option<PyErrStateNormalized> {
186 #[cfg(Py_3_12)]
187 {
188 unsafe { ffi::PyErr_GetRaisedException().assume_owned_or_opt(py) }.map(|pvalue| {
191 PyErrStateNormalized {
192 pvalue: unsafe { pvalue.downcast_into_unchecked() }.unbind(),
194 }
195 })
196 }
197
198 #[cfg(not(Py_3_12))]
199 {
200 let (ptype, pvalue, ptraceback) = unsafe {
201 let mut ptype: *mut ffi::PyObject = std::ptr::null_mut();
202 let mut pvalue: *mut ffi::PyObject = std::ptr::null_mut();
203 let mut ptraceback: *mut ffi::PyObject = std::ptr::null_mut();
204
205 ffi::PyErr_Fetch(&mut ptype, &mut pvalue, &mut ptraceback);
206
207 if !ptype.is_null() {
209 ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback);
210 }
211
212 (
215 ptype
216 .assume_owned_or_opt(py)
217 .map(|b| b.downcast_into_unchecked()),
218 pvalue
219 .assume_owned_or_opt(py)
220 .map(|b| b.downcast_into_unchecked()),
221 ptraceback
222 .assume_owned_or_opt(py)
223 .map(|b| b.downcast_into_unchecked()),
224 )
225 };
226
227 ptype.map(|ptype| PyErrStateNormalized {
228 ptype: ptype.unbind(),
229 pvalue: pvalue.expect("normalized exception value missing").unbind(),
230 ptraceback: ptraceback.map(Bound::unbind),
231 })
232 }
233 }
234
235 #[cfg(not(Py_3_12))]
236 unsafe fn from_normalized_ffi_tuple(
237 py: Python<'_>,
238 ptype: *mut ffi::PyObject,
239 pvalue: *mut ffi::PyObject,
240 ptraceback: *mut ffi::PyObject,
241 ) -> Self {
242 PyErrStateNormalized {
243 ptype: Py::from_owned_ptr_or_opt(py, ptype).expect("Exception type missing"),
244 pvalue: Py::from_owned_ptr_or_opt(py, pvalue).expect("Exception value missing"),
245 ptraceback: Py::from_owned_ptr_or_opt(py, ptraceback),
246 }
247 }
248
249 pub fn clone_ref(&self, py: Python<'_>) -> Self {
250 Self {
251 #[cfg(not(Py_3_12))]
252 ptype: self.ptype.clone_ref(py),
253 pvalue: self.pvalue.clone_ref(py),
254 #[cfg(not(Py_3_12))]
255 ptraceback: self
256 .ptraceback
257 .as_ref()
258 .map(|ptraceback| ptraceback.clone_ref(py)),
259 }
260 }
261}
262
/// Result of running a lazy error closure.
///
/// `raise_lazy` checks `ptype` with `PyExceptionClass_Check` and, if it passes, hands both
/// fields to `PyErr_SetObject`.
pub(crate) struct PyErrStateLazyFnOutput {
    // Object to raise as the exception type.
    pub(crate) ptype: PyObject,
    // Object passed as the exception value.
    pub(crate) pvalue: PyObject,
}
267
/// Closure type stored (boxed) in a lazy error state.
///
/// It is `FnOnce`, so it runs at most once, and it must be `Send + Sync` because the state
/// holding it is declared `Send` and `Sync`.
pub(crate) type PyErrStateLazyFn =
    dyn for<'py> FnOnce(Python<'py>) -> PyErrStateLazyFnOutput + Send + Sync;
270
/// The two forms an error state can take.
enum PyErrStateInner {
    // Not yet evaluated: a closure which will produce the exception type and value.
    Lazy(Box<PyErrStateLazyFn>),
    // Fully evaluated exception.
    Normalized(PyErrStateNormalized),
}
275
276impl PyErrStateInner {
277 fn normalize(self, py: Python<'_>) -> PyErrStateNormalized {
278 match self {
279 #[cfg(not(Py_3_12))]
280 PyErrStateInner::Lazy(lazy) => {
281 let (ptype, pvalue, ptraceback) = lazy_into_normalized_ffi_tuple(py, lazy);
282 unsafe {
283 PyErrStateNormalized::from_normalized_ffi_tuple(py, ptype, pvalue, ptraceback)
284 }
285 }
286 #[cfg(Py_3_12)]
287 PyErrStateInner::Lazy(lazy) => {
288 raise_lazy(py, lazy);
291 PyErrStateNormalized::take(py)
292 .expect("exception missing after writing to the interpreter")
293 }
294 PyErrStateInner::Normalized(normalized) => normalized,
295 }
296 }
297
298 #[cfg(not(Py_3_12))]
299 fn restore(self, py: Python<'_>) {
300 let (ptype, pvalue, ptraceback) = match self {
301 PyErrStateInner::Lazy(lazy) => lazy_into_normalized_ffi_tuple(py, lazy),
302 PyErrStateInner::Normalized(PyErrStateNormalized {
303 ptype,
304 pvalue,
305 ptraceback,
306 }) => (
307 ptype.into_ptr(),
308 pvalue.into_ptr(),
309 ptraceback.map_or(std::ptr::null_mut(), Py::into_ptr),
310 ),
311 };
312 unsafe { ffi::PyErr_Restore(ptype, pvalue, ptraceback) }
313 }
314
315 #[cfg(Py_3_12)]
316 fn restore(self, py: Python<'_>) {
317 match self {
318 PyErrStateInner::Lazy(lazy) => raise_lazy(py, lazy),
319 PyErrStateInner::Normalized(PyErrStateNormalized { pvalue }) => unsafe {
320 ffi::PyErr_SetRaisedException(pvalue.into_ptr())
321 },
322 }
323 }
324}
325
326#[cfg(not(Py_3_12))]
327fn lazy_into_normalized_ffi_tuple(
328 py: Python<'_>,
329 lazy: Box<PyErrStateLazyFn>,
330) -> (*mut ffi::PyObject, *mut ffi::PyObject, *mut ffi::PyObject) {
331 raise_lazy(py, lazy);
334 let mut ptype = std::ptr::null_mut();
335 let mut pvalue = std::ptr::null_mut();
336 let mut ptraceback = std::ptr::null_mut();
337 unsafe {
338 ffi::PyErr_Fetch(&mut ptype, &mut pvalue, &mut ptraceback);
339 ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback);
340 }
341 (ptype, pvalue, ptraceback)
342}
343
344fn raise_lazy(py: Python<'_>, lazy: Box<PyErrStateLazyFn>) {
352 let PyErrStateLazyFnOutput { ptype, pvalue } = lazy(py);
353 unsafe {
354 if ffi::PyExceptionClass_Check(ptype.as_ptr()) == 0 {
355 ffi::PyErr_SetString(
356 PyTypeError::type_object_raw(py).cast(),
357 ffi::c_str!("exceptions must derive from BaseException").as_ptr(),
358 )
359 } else {
360 ffi::PyErr_SetObject(ptype.as_ptr(), pvalue.as_ptr())
361 }
362 }
363}
364
#[cfg(test)]
mod tests {

    use crate::{
        exceptions::PyValueError, sync::GILOnceCell, PyErr, PyErrArguments, PyObject, Python,
    };

    /// Normalizing an error whose arguments normalize that same error must panic with the
    /// dedicated message.
    #[test]
    #[should_panic(expected = "Re-entrant normalization of PyErrState detected")]
    fn test_reentrant_normalization() {
        static ERR: GILOnceCell<PyErr> = GILOnceCell::new();

        struct RecursiveArgs;

        impl PyErrArguments for RecursiveArgs {
            fn arguments(self, py: Python<'_>) -> PyObject {
                // Asking for the value of `ERR` from inside its own argument evaluation is the
                // re-entrant step under test.
                let err = ERR.get(py).expect("is set just below");
                let value = err.value(py).clone();
                value.into()
            }
        }

        Python::with_gil(|py| {
            let err = PyValueError::new_err(RecursiveArgs);
            ERR.set(py, err).unwrap();
            ERR.get(py).expect("is set just above").value(py);
        })
    }

    /// Several threads asking for the value of the same error at once must all finish, even
    /// though producing the arguments goes through `allow_threads` and sleeps.
    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_no_deadlock_thread_switch() {
        static ERR: GILOnceCell<PyErr> = GILOnceCell::new();

        struct GILSwitchArgs;

        impl PyErrArguments for GILSwitchArgs {
            fn arguments(self, py: Python<'_>) -> PyObject {
                py.allow_threads(|| {
                    std::thread::sleep(std::time::Duration::from_millis(10));
                });
                py.None()
            }
        }

        Python::with_gil(|py| {
            let err = PyValueError::new_err(GILSwitchArgs);
            ERR.set(py, err).unwrap()
        });

        // Start every worker before joining any of them.
        let workers: Vec<_> = (0..10)
            .map(|_| {
                std::thread::spawn(|| {
                    Python::with_gil(|py| {
                        ERR.get(py).expect("is set just above").value(py);
                    });
                })
            })
            .collect();

        workers
            .into_iter()
            .for_each(|worker| worker.join().unwrap());

        Python::with_gil(|py| {
            let err = ERR.get(py).expect("is set above");
            assert!(err.is_instance_of::<PyValueError>(py))
        });
    }
}