pyo3/coroutine/
waker.rs

1use crate::sync::GILOnceCell;
2use crate::types::any::PyAnyMethods;
3use crate::types::PyCFunction;
4use crate::{intern, wrap_pyfunction, Bound, Py, PyAny, PyObject, PyResult, Python};
5use pyo3_macros::pyfunction;
6use std::sync::Arc;
7use std::task::Wake;
8
9/// Lazy `asyncio.Future` wrapper, implementing [`Wake`] by calling `Future.set_result`.
10///
11/// asyncio future is let uninitialized until [`initialize_future`][1] is called.
12/// If [`wake`][2] is called before future initialization (during Rust future polling),
13/// [`initialize_future`][1] will return `None` (it is roughly equivalent to `asyncio.sleep(0)`)
14///
15/// [1]: AsyncioWaker::initialize_future
16/// [2]: AsyncioWaker::wake
17pub struct AsyncioWaker(GILOnceCell<Option<LoopAndFuture>>);
18
19impl AsyncioWaker {
20    pub(super) fn new() -> Self {
21        Self(GILOnceCell::new())
22    }
23
24    pub(super) fn reset(&mut self) {
25        self.0.take();
26    }
27
28    pub(super) fn initialize_future<'py>(
29        &self,
30        py: Python<'py>,
31    ) -> PyResult<Option<&Bound<'py, PyAny>>> {
32        let init = || LoopAndFuture::new(py).map(Some);
33        let loop_and_future = self.0.get_or_try_init(py, init)?.as_ref();
34        Ok(loop_and_future.map(|LoopAndFuture { future, .. }| future.bind(py)))
35    }
36}
37
38impl Wake for AsyncioWaker {
39    fn wake(self: Arc<Self>) {
40        self.wake_by_ref()
41    }
42
43    fn wake_by_ref(self: &Arc<Self>) {
44        Python::with_gil(|gil| {
45            if let Some(loop_and_future) = self.0.get_or_init(gil, || None) {
46                loop_and_future
47                    .set_result(gil)
48                    .expect("unexpected error in coroutine waker");
49            }
50        });
51    }
52}
53
54struct LoopAndFuture {
55    event_loop: PyObject,
56    future: PyObject,
57}
58
59impl LoopAndFuture {
60    fn new(py: Python<'_>) -> PyResult<Self> {
61        static GET_RUNNING_LOOP: GILOnceCell<PyObject> = GILOnceCell::new();
62        let import = || -> PyResult<_> {
63            let module = py.import("asyncio")?;
64            Ok(module.getattr("get_running_loop")?.into())
65        };
66        let event_loop = GET_RUNNING_LOOP.get_or_try_init(py, import)?.call0(py)?;
67        let future = event_loop.call_method0(py, "create_future")?;
68        Ok(Self { event_loop, future })
69    }
70
71    fn set_result(&self, py: Python<'_>) -> PyResult<()> {
72        static RELEASE_WAITER: GILOnceCell<Py<PyCFunction>> = GILOnceCell::new();
73        let release_waiter = RELEASE_WAITER.get_or_try_init(py, || {
74            wrap_pyfunction!(release_waiter, py).map(Bound::unbind)
75        })?;
76        // `Future.set_result` must be called in event loop thread,
77        // so it requires `call_soon_threadsafe`
78        let call_soon_threadsafe = self.event_loop.call_method1(
79            py,
80            intern!(py, "call_soon_threadsafe"),
81            (release_waiter, self.future.bind(py)),
82        );
83        if let Err(err) = call_soon_threadsafe {
84            // `call_soon_threadsafe` will raise if the event loop is closed;
85            // instead of catching an unspecific `RuntimeError`, check directly if it's closed.
86            let is_closed = self.event_loop.call_method0(py, "is_closed")?;
87            if !is_closed.extract(py)? {
88                return Err(err);
89            }
90        }
91        Ok(())
92    }
93}
94
95/// Call `future.set_result` if the future is not done.
96///
97/// Future can be cancelled by the event loop before being waken.
98/// See <https://github.com/python/cpython/blob/main/Lib/asyncio/tasks.py#L452C5-L452C5>
99#[pyfunction(crate = "crate")]
100fn release_waiter(future: &Bound<'_, PyAny>) -> PyResult<()> {
101    let done = future.call_method0(intern!(future.py(), "done"))?;
102    if !done.extract::<bool>()? {
103        future.call_method1(intern!(future.py(), "set_result"), (future.py().None(),))?;
104    }
105    Ok(())
106}
⚠️ Internal Docs ⚠️ Not Public API 👉 Official Docs Here