pyo3/
coroutine.rs

1//! Python coroutine implementation, used notably when wrapping `async fn`
2//! with `#[pyfunction]`/`#[pymethods]`.
3use std::{
4    future::Future,
5    panic,
6    pin::Pin,
7    sync::Arc,
8    task::{Context, Poll, Waker},
9};
10
11use pyo3_macros::{pyclass, pymethods};
12
13use crate::{
14    coroutine::{cancel::ThrowCallback, waker::AsyncioWaker},
15    exceptions::{PyAttributeError, PyRuntimeError, PyStopIteration},
16    panic::PanicException,
17    types::{string::PyStringMethods, PyIterator, PyString},
18    Bound, IntoPyObject, IntoPyObjectExt, Py, PyAny, PyErr, PyObject, PyResult, Python,
19};
20
21pub(crate) mod cancel;
22mod waker;
23
24pub use cancel::CancelHandle;
25
/// Message of the `RuntimeError` raised when a coroutine is polled again after its
/// wrapped future has already been dropped (completed, panicked, or closed).
const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";
27
28/// Python coroutine wrapping a [`Future`].
29#[pyclass(crate = "crate")]
30pub struct Coroutine {
31    name: Option<Py<PyString>>,
32    qualname_prefix: Option<&'static str>,
33    throw_callback: Option<ThrowCallback>,
34    future: Option<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>>,
35    waker: Option<Arc<AsyncioWaker>>,
36}
37
// SAFETY: `Coroutine` is allowed to be `Sync` even though the future is not (it is only
// `Send`), because the future is only ever touched through a `&mut self` receiver
// (`poll`, `send`, `throw`, `close`, `__next__`), so a shared `&Coroutine` can never reach it
// from two threads at once. The only `&self` methods, the `__name__`/`__qualname__` getters,
// read just `name` and `qualname_prefix`.
unsafe impl Sync for Coroutine {}
41
42impl Coroutine {
43    ///  Wrap a future into a Python coroutine.
44    ///
45    /// Coroutine `send` polls the wrapped future, ignoring the value passed
46    /// (should always be `None` anyway).
47    ///
48    /// `Coroutine `throw` drop the wrapped future and reraise the exception passed
49    pub(crate) fn new<'py, F, T, E>(
50        name: Option<Bound<'py, PyString>>,
51        qualname_prefix: Option<&'static str>,
52        throw_callback: Option<ThrowCallback>,
53        future: F,
54    ) -> Self
55    where
56        F: Future<Output = Result<T, E>> + Send + 'static,
57        T: IntoPyObject<'py>,
58        E: Into<PyErr>,
59    {
60        let wrap = async move {
61            let obj = future.await.map_err(Into::into)?;
62            // SAFETY: GIL is acquired when future is polled (see `Coroutine::poll`)
63            obj.into_py_any(unsafe { Python::assume_gil_acquired() })
64        };
65        Self {
66            name: name.map(Bound::unbind),
67            qualname_prefix,
68            throw_callback,
69            future: Some(Box::pin(wrap)),
70            waker: None,
71        }
72    }
73
74    fn poll(&mut self, py: Python<'_>, throw: Option<PyObject>) -> PyResult<PyObject> {
75        // raise if the coroutine has already been run to completion
76        let future_rs = match self.future {
77            Some(ref mut fut) => fut,
78            None => return Err(PyRuntimeError::new_err(COROUTINE_REUSED_ERROR)),
79        };
80        // reraise thrown exception it
81        match (throw, &self.throw_callback) {
82            (Some(exc), Some(cb)) => cb.throw(exc),
83            (Some(exc), None) => {
84                self.close();
85                return Err(PyErr::from_value(exc.into_bound(py)));
86            }
87            (None, _) => {}
88        }
89        // create a new waker, or try to reset it in place
90        if let Some(waker) = self.waker.as_mut().and_then(Arc::get_mut) {
91            waker.reset();
92        } else {
93            self.waker = Some(Arc::new(AsyncioWaker::new()));
94        }
95        let waker = Waker::from(self.waker.clone().unwrap());
96        // poll the Rust future and forward its results if ready
97        // polling is UnwindSafe because the future is dropped in case of panic
98        let poll = || future_rs.as_mut().poll(&mut Context::from_waker(&waker));
99        match panic::catch_unwind(panic::AssertUnwindSafe(poll)) {
100            Ok(Poll::Ready(res)) => {
101                self.close();
102                return Err(PyStopIteration::new_err((res?,)));
103            }
104            Err(err) => {
105                self.close();
106                return Err(PanicException::from_panic_payload(err));
107            }
108            _ => {}
109        }
110        // otherwise, initialize the waker `asyncio.Future`
111        if let Some(future) = self.waker.as_ref().unwrap().initialize_future(py)? {
112            // `asyncio.Future` must be awaited; fortunately, it implements `__iter__ = __await__`
113            // and will yield itself if its result has not been set in polling above
114            if let Some(future) = PyIterator::from_object(future).unwrap().next() {
115                // future has not been leaked into Python for now, and Rust code can only call
116                // `set_result(None)` in `Wake` implementation, so it's safe to unwrap
117                return Ok(future.unwrap().into());
118            }
119        }
120        // if waker has been waken during future polling, this is roughly equivalent to
121        // `await asyncio.sleep(0)`, so just yield `None`.
122        Ok(py.None())
123    }
124}
125
126#[pymethods(crate = "crate")]
127impl Coroutine {
128    #[getter]
129    fn __name__(&self, py: Python<'_>) -> PyResult<Py<PyString>> {
130        match &self.name {
131            Some(name) => Ok(name.clone_ref(py)),
132            None => Err(PyAttributeError::new_err("__name__")),
133        }
134    }
135
136    #[getter]
137    fn __qualname__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyString>> {
138        match (&self.name, &self.qualname_prefix) {
139            (Some(name), Some(prefix)) => Ok(PyString::new(
140                py,
141                &format!("{}.{}", prefix, name.bind(py).to_cow()?),
142            )),
143            (Some(name), None) => Ok(name.bind(py).clone()),
144            (None, _) => Err(PyAttributeError::new_err("__qualname__")),
145        }
146    }
147
148    fn send(&mut self, py: Python<'_>, _value: &Bound<'_, PyAny>) -> PyResult<PyObject> {
149        self.poll(py, None)
150    }
151
152    fn throw(&mut self, py: Python<'_>, exc: PyObject) -> PyResult<PyObject> {
153        self.poll(py, Some(exc))
154    }
155
156    fn close(&mut self) {
157        // the Rust future is dropped, and the field set to `None`
158        // to indicate the coroutine has been run to completion
159        drop(self.future.take());
160    }
161
162    fn __await__(self_: Py<Self>) -> Py<Self> {
163        self_
164    }
165
166    fn __next__(&mut self, py: Python<'_>) -> PyResult<PyObject> {
167        self.poll(py, None)
168    }
169}
⚠️ Internal Docs ⚠️ Not Public API 👉 Official Docs Here