1#![allow(missing_docs)]
2use std::cell::UnsafeCell;
5use std::marker::PhantomData;
6use std::mem::ManuallyDrop;
7use std::sync::atomic::{AtomicUsize, Ordering};
8
9use crate::impl_::pyclass::{
10 PyClassBaseType, PyClassDict, PyClassImpl, PyClassThreadChecker, PyClassWeakRef,
11};
12use crate::internal::get_slot::TP_FREE;
13use crate::type_object::{PyLayout, PySizedLayout};
14use crate::types::{PyType, PyTypeMethods};
15use crate::{ffi, PyClass, PyTypeInfo, Python};
16
17use super::{PyBorrowError, PyBorrowMutError};
18
19pub trait PyClassMutability {
20 type Storage: PyClassBorrowChecker;
23 type Checker: PyClassBorrowChecker;
27 type ImmutableChild: PyClassMutability;
28 type MutableChild: PyClassMutability;
29}
30
31pub struct ImmutableClass(());
32pub struct MutableClass(());
33pub struct ExtendsMutableAncestor<M: PyClassMutability>(PhantomData<M>);
34
35impl PyClassMutability for ImmutableClass {
36 type Storage = EmptySlot;
37 type Checker = EmptySlot;
38 type ImmutableChild = ImmutableClass;
39 type MutableChild = MutableClass;
40}
41
42impl PyClassMutability for MutableClass {
43 type Storage = BorrowChecker;
44 type Checker = BorrowChecker;
45 type ImmutableChild = ExtendsMutableAncestor<ImmutableClass>;
46 type MutableChild = ExtendsMutableAncestor<MutableClass>;
47}
48
49impl<M: PyClassMutability> PyClassMutability for ExtendsMutableAncestor<M> {
50 type Storage = EmptySlot;
51 type Checker = BorrowChecker;
52 type ImmutableChild = ExtendsMutableAncestor<ImmutableClass>;
53 type MutableChild = ExtendsMutableAncestor<MutableClass>;
54}
55
56#[derive(Debug)]
57struct BorrowFlag(AtomicUsize);
58
59impl BorrowFlag {
60 pub(crate) const UNUSED: usize = 0;
61 const HAS_MUTABLE_BORROW: usize = usize::MAX;
62 fn increment(&self) -> Result<(), PyBorrowError> {
63 let mut value = self.0.load(Ordering::Relaxed);
65 loop {
66 if value == BorrowFlag::HAS_MUTABLE_BORROW {
67 return Err(PyBorrowError { _private: () });
68 }
69 match self.0.compare_exchange(
70 value,
73 value + 1,
74 Ordering::AcqRel,
77 Ordering::Relaxed,
79 ) {
80 Ok(..) => {
81 break Ok(());
82 }
83 Err(changed_value) => {
84 value = changed_value;
86 }
87 }
88 }
89 }
90 fn decrement(&self) {
91 self.0.fetch_sub(1, Ordering::Release);
93 }
94}
95
96pub struct EmptySlot(());
97pub struct BorrowChecker(BorrowFlag);
98
99pub trait PyClassBorrowChecker {
100 fn new() -> Self;
102
103 fn try_borrow(&self) -> Result<(), PyBorrowError>;
105
106 fn release_borrow(&self);
108 fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError>;
110 fn release_borrow_mut(&self);
112}
113
114impl PyClassBorrowChecker for EmptySlot {
115 #[inline]
116 fn new() -> Self {
117 EmptySlot(())
118 }
119
120 #[inline]
121 fn try_borrow(&self) -> Result<(), PyBorrowError> {
122 Ok(())
123 }
124
125 #[inline]
126 fn release_borrow(&self) {}
127
128 #[inline]
129 fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError> {
130 unreachable!()
131 }
132
133 #[inline]
134 fn release_borrow_mut(&self) {
135 unreachable!()
136 }
137}
138
139impl PyClassBorrowChecker for BorrowChecker {
140 #[inline]
141 fn new() -> Self {
142 Self(BorrowFlag(AtomicUsize::new(BorrowFlag::UNUSED)))
143 }
144
145 fn try_borrow(&self) -> Result<(), PyBorrowError> {
146 self.0.increment()
147 }
148
149 fn release_borrow(&self) {
150 self.0.decrement();
151 }
152
153 fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError> {
154 let flag = &self.0;
155 match flag.0.compare_exchange(
156 BorrowFlag::UNUSED,
159 BorrowFlag::HAS_MUTABLE_BORROW,
160 Ordering::AcqRel,
163 Ordering::Relaxed,
166 ) {
167 Ok(..) => Ok(()),
168 Err(..) => Err(PyBorrowMutError { _private: () }),
169 }
170 }
171
172 fn release_borrow_mut(&self) {
173 self.0 .0.store(BorrowFlag::UNUSED, Ordering::Release)
174 }
175}
176
177pub trait GetBorrowChecker<T: PyClassImpl> {
178 fn borrow_checker(
179 class_object: &PyClassObject<T>,
180 ) -> &<T::PyClassMutability as PyClassMutability>::Checker;
181}
182
183impl<T: PyClassImpl<PyClassMutability = Self>> GetBorrowChecker<T> for MutableClass {
184 fn borrow_checker(class_object: &PyClassObject<T>) -> &BorrowChecker {
185 &class_object.contents.borrow_checker
186 }
187}
188
189impl<T: PyClassImpl<PyClassMutability = Self>> GetBorrowChecker<T> for ImmutableClass {
190 fn borrow_checker(class_object: &PyClassObject<T>) -> &EmptySlot {
191 &class_object.contents.borrow_checker
192 }
193}
194
195impl<T: PyClassImpl<PyClassMutability = Self>, M: PyClassMutability> GetBorrowChecker<T>
196 for ExtendsMutableAncestor<M>
197where
198 T::BaseType: PyClassImpl + PyClassBaseType<LayoutAsBase = PyClassObject<T::BaseType>>,
199 <T::BaseType as PyClassImpl>::PyClassMutability: PyClassMutability<Checker = BorrowChecker>,
200{
201 fn borrow_checker(class_object: &PyClassObject<T>) -> &BorrowChecker {
202 <<T::BaseType as PyClassImpl>::PyClassMutability as GetBorrowChecker<T::BaseType>>::borrow_checker(&class_object.ob_base)
203 }
204}
205
206#[doc(hidden)]
208#[repr(C)]
209pub struct PyClassObjectBase<T> {
210 ob_base: T,
211}
212
213unsafe impl<T, U> PyLayout<T> for PyClassObjectBase<U> where U: PySizedLayout<T> {}
214
215#[doc(hidden)]
216pub trait PyClassObjectLayout<T>: PyLayout<T> {
217 fn ensure_threadsafe(&self);
218 fn check_threadsafe(&self) -> Result<(), PyBorrowError>;
219 unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject);
224}
225
226impl<T, U> PyClassObjectLayout<T> for PyClassObjectBase<U>
227where
228 U: PySizedLayout<T>,
229 T: PyTypeInfo,
230{
231 fn ensure_threadsafe(&self) {}
232 fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
233 Ok(())
234 }
235 unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
236 let type_obj = T::type_object(py);
239 let type_ptr = type_obj.as_type_ptr();
240 let actual_type = PyType::from_borrowed_type_ptr(py, ffi::Py_TYPE(slf));
241
242 if type_ptr == std::ptr::addr_of_mut!(ffi::PyBaseObject_Type) {
244 let tp_free = actual_type
245 .get_slot(TP_FREE)
246 .expect("PyBaseObject_Type should have tp_free");
247 return tp_free(slf.cast());
248 }
249
250 #[cfg(not(Py_LIMITED_API))]
252 {
253 if let Some(dealloc) = (*type_ptr).tp_dealloc {
255 #[cfg(not(any(Py_3_11, PyPy)))]
259 if ffi::PyType_FastSubclass(type_ptr, ffi::Py_TPFLAGS_BASE_EXC_SUBCLASS) == 1 {
260 ffi::PyObject_GC_Track(slf.cast());
261 }
262 dealloc(slf);
263 } else {
264 (*actual_type.as_type_ptr())
265 .tp_free
266 .expect("type missing tp_free")(slf.cast());
267 }
268 }
269
270 #[cfg(Py_LIMITED_API)]
271 unreachable!("subclassing native types is not possible with the `abi3` feature");
272 }
273}
274
275#[repr(C)]
277pub struct PyClassObject<T: PyClassImpl> {
278 pub(crate) ob_base: <T::BaseType as PyClassBaseType>::LayoutAsBase,
279 pub(crate) contents: PyClassObjectContents<T>,
280}
281
282#[repr(C)]
283pub(crate) struct PyClassObjectContents<T: PyClassImpl> {
284 pub(crate) value: ManuallyDrop<UnsafeCell<T>>,
285 pub(crate) borrow_checker: <T::PyClassMutability as PyClassMutability>::Storage,
286 pub(crate) thread_checker: T::ThreadChecker,
287 pub(crate) dict: T::Dict,
288 pub(crate) weakref: T::WeakRef,
289}
290
291impl<T: PyClassImpl> PyClassObject<T> {
292 pub(crate) fn get_ptr(&self) -> *mut T {
293 self.contents.value.get()
294 }
295
296 pub(crate) fn dict_offset() -> ffi::Py_ssize_t {
298 use memoffset::offset_of;
299
300 let offset =
301 offset_of!(PyClassObject<T>, contents) + offset_of!(PyClassObjectContents<T>, dict);
302
303 #[allow(clippy::useless_conversion)]
305 offset.try_into().expect("offset should fit in Py_ssize_t")
306 }
307
308 pub(crate) fn weaklist_offset() -> ffi::Py_ssize_t {
310 use memoffset::offset_of;
311
312 let offset =
313 offset_of!(PyClassObject<T>, contents) + offset_of!(PyClassObjectContents<T>, weakref);
314
315 #[allow(clippy::useless_conversion)]
317 offset.try_into().expect("offset should fit in Py_ssize_t")
318 }
319}
320
321impl<T: PyClassImpl> PyClassObject<T> {
322 pub(crate) fn borrow_checker(&self) -> &<T::PyClassMutability as PyClassMutability>::Checker {
323 T::PyClassMutability::borrow_checker(self)
324 }
325}
326
327unsafe impl<T: PyClassImpl> PyLayout<T> for PyClassObject<T> {}
328impl<T: PyClass> PySizedLayout<T> for PyClassObject<T> {}
329
330impl<T: PyClassImpl> PyClassObjectLayout<T> for PyClassObject<T>
331where
332 <T::BaseType as PyClassBaseType>::LayoutAsBase: PyClassObjectLayout<T::BaseType>,
333{
334 fn ensure_threadsafe(&self) {
335 self.contents.thread_checker.ensure();
336 self.ob_base.ensure_threadsafe();
337 }
338 fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
339 if !self.contents.thread_checker.check() {
340 return Err(PyBorrowError { _private: () });
341 }
342 self.ob_base.check_threadsafe()
343 }
344 unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
345 let class_object = &mut *(slf.cast::<PyClassObject<T>>());
347 if class_object.contents.thread_checker.can_drop(py) {
348 ManuallyDrop::drop(&mut class_object.contents.value);
349 }
350 class_object.contents.dict.clear_dict(py);
351 class_object.contents.weakref.clear_weakrefs(slf, py);
352 <T::BaseType as PyClassBaseType>::LayoutAsBase::tp_dealloc(py, slf)
353 }
354}
355
#[cfg(all(test, feature = "macros"))]
mod tests {
    use super::*;

    use crate::prelude::*;
    use crate::pyclass::boolean_struct::{False, True};

    // Hierarchies rooted in a mutable base class.

    #[pyclass(crate = "crate", subclass)]
    struct MutableBase;

    #[pyclass(crate = "crate", extends = MutableBase, subclass)]
    struct MutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = MutableBase, frozen, subclass)]
    struct ImmutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfMutableBase)]
    struct MutableChildOfMutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfMutableBase)]
    struct MutableChildOfImmutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfMutableBase, frozen)]
    struct ImmutableChildOfMutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfMutableBase, frozen)]
    struct ImmutableChildOfImmutableChildOfMutableBase;

    // Hierarchies rooted in an immutable (frozen) base class.

    #[pyclass(crate = "crate", frozen, subclass)]
    struct ImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableBase, subclass)]
    struct MutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableBase, frozen, subclass)]
    struct ImmutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfImmutableBase)]
    struct MutableChildOfMutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfImmutableBase)]
    struct MutableChildOfImmutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfImmutableBase, frozen)]
    struct ImmutableChildOfMutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfImmutableBase, frozen)]
    struct ImmutableChildOfImmutableChildOfImmutableBase;

    // Short names for the three levels of the all-mutable hierarchy used by
    // the borrow tests below.
    type Root = MutableBase;
    type Middle = MutableChildOfMutableBase;
    type Leaf = MutableChildOfMutableChildOfMutableBase;

    // Compile-time assertions: each function only type-checks when `T` has
    // the stated `Frozen` flag and mutability marker.

    fn assert_mutable<T>()
    where
        T: PyClass<Frozen = False, PyClassMutability = MutableClass>,
    {
    }

    fn assert_immutable<T>()
    where
        T: PyClass<Frozen = True, PyClassMutability = ImmutableClass>,
    {
    }

    fn assert_mutable_with_mutable_ancestor<T>()
    where
        T: PyClass<Frozen = False, PyClassMutability = ExtendsMutableAncestor<MutableClass>>,
    {
    }

    fn assert_immutable_with_mutable_ancestor<T>()
    where
        T: PyClass<Frozen = True, PyClassMutability = ExtendsMutableAncestor<ImmutableClass>>,
    {
    }

    /// Builds an instance of the three-level all-mutable hierarchy.
    fn new_three_level_mutable(py: Python<'_>) -> Py<Leaf> {
        let initializer = PyClassInitializer::from(MutableBase)
            .add_subclass(MutableChildOfMutableBase)
            .add_subclass(MutableChildOfMutableChildOfMutableBase);
        Py::new(py, initializer).unwrap()
    }

    #[test]
    fn test_inherited_mutability() {
        // A mutable root owns the flag.
        assert_mutable::<MutableBase>();

        // Direct children of a mutable root.
        assert_mutable_with_mutable_ancestor::<MutableChildOfMutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfMutableBase>();

        // Grandchildren of a mutable root.
        assert_mutable_with_mutable_ancestor::<MutableChildOfMutableChildOfMutableBase>();
        assert_mutable_with_mutable_ancestor::<MutableChildOfImmutableChildOfMutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfMutableChildOfMutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfImmutableChildOfMutableBase>();

        // Fully immutable chains.
        assert_immutable::<ImmutableBase>();
        assert_immutable::<ImmutableChildOfImmutableBase>();
        assert_immutable::<ImmutableChildOfImmutableChildOfImmutableBase>();

        // The first mutable class below immutable ancestors owns the flag.
        assert_mutable::<MutableChildOfImmutableBase>();
        assert_mutable::<MutableChildOfImmutableChildOfImmutableBase>();

        // Anything below that first mutable class.
        assert_mutable_with_mutable_ancestor::<MutableChildOfMutableChildOfImmutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfMutableChildOfImmutableBase>();
    }

    #[test]
    fn test_mutable_borrow_prevents_further_borrows() {
        Python::with_gil(|py| {
            let owner = new_three_level_mutable(py);
            let leaf: &Bound<'_, Leaf> = owner.bind(py);

            let exclusive = leaf.borrow_mut();

            // While the exclusive borrow is alive, no shared borrow works at
            // any level of the hierarchy...
            assert!(leaf.extract::<PyRef<'_, Leaf>>().is_err());
            assert!(leaf.extract::<PyRef<'_, Middle>>().is_err());
            assert!(leaf.extract::<PyRef<'_, Root>>().is_err());
            // ...and neither does a second exclusive borrow.
            assert!(leaf.extract::<PyRefMut<'_, Leaf>>().is_err());
            assert!(leaf.extract::<PyRefMut<'_, Middle>>().is_err());
            assert!(leaf.extract::<PyRefMut<'_, Root>>().is_err());

            drop(exclusive);

            // Once it is released, every kind of borrow succeeds again.
            assert!(leaf.extract::<PyRef<'_, Leaf>>().is_ok());
            assert!(leaf.extract::<PyRef<'_, Middle>>().is_ok());
            assert!(leaf.extract::<PyRef<'_, Root>>().is_ok());
            assert!(leaf.extract::<PyRefMut<'_, Leaf>>().is_ok());
            assert!(leaf.extract::<PyRefMut<'_, Middle>>().is_ok());
            assert!(leaf.extract::<PyRefMut<'_, Root>>().is_ok());
        })
    }

    #[test]
    fn test_immutable_borrows_prevent_mutable_borrows() {
        Python::with_gil(|py| {
            let owner = new_three_level_mutable(py);
            let leaf: &Bound<'_, Leaf> = owner.bind(py);

            let shared = leaf.borrow();

            // Additional shared borrows are fine at every level...
            assert!(leaf.extract::<PyRef<'_, Leaf>>().is_ok());
            assert!(leaf.extract::<PyRef<'_, Middle>>().is_ok());
            assert!(leaf.extract::<PyRef<'_, Root>>().is_ok());

            // ...but exclusive borrows are rejected at every level.
            assert!(leaf.extract::<PyRefMut<'_, Leaf>>().is_err());
            assert!(leaf.extract::<PyRefMut<'_, Middle>>().is_err());
            assert!(leaf.extract::<PyRefMut<'_, Root>>().is_err());

            drop(shared);

            // With the shared borrow gone, exclusive borrows work again.
            assert!(leaf.extract::<PyRefMut<'_, Leaf>>().is_ok());
            assert!(leaf.extract::<PyRefMut<'_, Middle>>().is_ok());
            assert!(leaf.extract::<PyRefMut<'_, Root>>().is_ok());
        })
    }

    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_thread_safety() {
        #[crate::pyclass(crate = "crate")]
        struct MyClass {
            x: u64,
        }

        Python::with_gil(|py| {
            let inst = Py::new(py, MyClass { x: 0 }).unwrap();

            let total_modifications = py.allow_threads(|| {
                std::thread::scope(|s| {
                    // Start every worker before joining any of them.
                    let workers: Vec<_> = (0..10)
                        .map(|_| {
                            s.spawn(|| {
                                Python::with_gil(|py| {
                                    // Count only the increments that actually
                                    // obtained the exclusive borrow.
                                    let mut successes = 0;
                                    for _ in 0..100 {
                                        if let Ok(mut guard) = inst.try_borrow_mut(py) {
                                            guard.x += 1;
                                            successes += 1;
                                        }
                                    }
                                    successes
                                })
                            })
                        })
                        .collect();

                    workers
                        .into_iter()
                        .map(|worker| worker.join().unwrap())
                        .sum::<u64>()
                })
            });

            // Every successful borrow incremented `x` exactly once.
            assert_eq!(total_modifications, inst.borrow(py).x);
        });
    }

    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_thread_safety_2() {
        // An `UnsafeCell` that may be shared between threads; all access in
        // this test is guarded by the borrow checker under test.
        struct SyncUnsafeCell<T>(UnsafeCell<T>);
        unsafe impl<T> Sync for SyncUnsafeCell<T> {}

        impl<T> SyncUnsafeCell<T> {
            fn get(&self) -> *mut T {
                self.0.get()
            }
        }

        let data = SyncUnsafeCell(UnsafeCell::new(0));
        let data2 = SyncUnsafeCell(UnsafeCell::new(0));
        let borrow_checker = <BorrowChecker as PyClassBorrowChecker>::new();

        std::thread::scope(|s| {
            // Writer: bumps both cells while holding the exclusive borrow.
            s.spawn(|| {
                for _ in 0..1_000_000 {
                    if borrow_checker.try_borrow_mut().is_err() {
                        continue;
                    }
                    unsafe { *data.get() += 1 };
                    unsafe { *data2.get() += 1 };
                    borrow_checker.release_borrow_mut();
                }
            });

            // Reader: under a shared borrow both cells must agree.
            s.spawn(|| {
                for _ in 0..1_000_000 {
                    if borrow_checker.try_borrow().is_err() {
                        continue;
                    }
                    assert_eq!(unsafe { *data.get() }, unsafe { *data2.get() });
                    borrow_checker.release_borrow();
                }
            });
        });
    }
}