pyo3/coroutine/
cancel.rs

1use crate::{Py, PyAny, PyObject};
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll, Waker};
6
7#[derive(Debug, Default)]
8struct Inner {
9    exception: Option<PyObject>,
10    waker: Option<Waker>,
11}
12
13/// Helper used to wait and retrieve exception thrown in [`Coroutine`](super::Coroutine).
14///
15/// Only the last exception thrown can be retrieved.
16#[derive(Debug, Default)]
17pub struct CancelHandle(Arc<Mutex<Inner>>);
18
19impl CancelHandle {
20    /// Create a new `CoroutineCancel`.
21    pub fn new() -> Self {
22        Default::default()
23    }
24
25    /// Returns whether the associated coroutine has been cancelled.
26    pub fn is_cancelled(&self) -> bool {
27        self.0.lock().unwrap().exception.is_some()
28    }
29
30    /// Poll to retrieve the exception thrown in the associated coroutine.
31    pub fn poll_cancelled(&mut self, cx: &mut Context<'_>) -> Poll<PyObject> {
32        let mut inner = self.0.lock().unwrap();
33        if let Some(exc) = inner.exception.take() {
34            return Poll::Ready(exc);
35        }
36        if let Some(ref waker) = inner.waker {
37            if cx.waker().will_wake(waker) {
38                return Poll::Pending;
39            }
40        }
41        inner.waker = Some(cx.waker().clone());
42        Poll::Pending
43    }
44
45    /// Retrieve the exception thrown in the associated coroutine.
46    pub async fn cancelled(&mut self) -> PyObject {
47        Cancelled(self).await
48    }
49
50    #[doc(hidden)]
51    pub fn throw_callback(&self) -> ThrowCallback {
52        ThrowCallback(self.0.clone())
53    }
54}
55
56// Because `poll_fn` is not available in MSRV
57struct Cancelled<'a>(&'a mut CancelHandle);
58
59impl Future for Cancelled<'_> {
60    type Output = PyObject;
61    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
62        self.0.poll_cancelled(cx)
63    }
64}
65
66#[doc(hidden)]
67pub struct ThrowCallback(Arc<Mutex<Inner>>);
68
69impl ThrowCallback {
70    pub(super) fn throw(&self, exc: Py<PyAny>) {
71        let mut inner = self.0.lock().unwrap();
72        inner.exception = Some(exc);
73        if let Some(waker) = inner.waker.take() {
74            waker.wake();
75        }
76    }
77}
⚠️ Internal Docs ⚠️ Not Public API 👉 Official Docs Here