1use crate::sync::GILOnceCell;
2use crate::types::any::PyAnyMethods;
3use crate::types::PyCFunction;
4use crate::{intern, wrap_pyfunction, Bound, Py, PyAny, PyObject, PyResult, Python};
5use pyo3_macros::pyfunction;
6use std::sync::Arc;
7use std::task::Wake;
8
/// Lazily created asyncio future, used as a [`Wake`] implementation for a Rust coroutine.
///
/// The inner cell starts empty. [`AsyncioWaker::initialize_future`] fills it with the running
/// event loop and a freshly created future. If the waker fires while the cell is still empty,
/// `wake_by_ref` stores `None` in it instead, and a later `initialize_future` call then
/// returns `Ok(None)` rather than creating a future. [`AsyncioWaker::reset`] empties the
/// cell again.
pub struct AsyncioWaker(GILOnceCell<Option<LoopAndFuture>>);
18
19impl AsyncioWaker {
20 pub(super) fn new() -> Self {
21 Self(GILOnceCell::new())
22 }
23
24 pub(super) fn reset(&mut self) {
25 self.0.take();
26 }
27
28 pub(super) fn initialize_future<'py>(
29 &self,
30 py: Python<'py>,
31 ) -> PyResult<Option<&Bound<'py, PyAny>>> {
32 let init = || LoopAndFuture::new(py).map(Some);
33 let loop_and_future = self.0.get_or_try_init(py, init)?.as_ref();
34 Ok(loop_and_future.map(|LoopAndFuture { future, .. }| future.bind(py)))
35 }
36}
37
38impl Wake for AsyncioWaker {
39 fn wake(self: Arc<Self>) {
40 self.wake_by_ref()
41 }
42
43 fn wake_by_ref(self: &Arc<Self>) {
44 Python::with_gil(|gil| {
45 if let Some(loop_and_future) = self.0.get_or_init(gil, || None) {
46 loop_and_future
47 .set_result(gil)
48 .expect("unexpected error in coroutine waker");
49 }
50 });
51 }
52}
53
/// An asyncio event loop together with a future created on that loop.
struct LoopAndFuture {
    /// Result of `asyncio.get_running_loop()` at construction time; `set_result` schedules
    /// the wake-up on this loop.
    event_loop: PyObject,
    /// Future returned by `event_loop.create_future()`; `release_waiter` resolves it with
    /// `None` unless it is already done.
    future: PyObject,
}
58
59impl LoopAndFuture {
60 fn new(py: Python<'_>) -> PyResult<Self> {
61 static GET_RUNNING_LOOP: GILOnceCell<PyObject> = GILOnceCell::new();
62 let import = || -> PyResult<_> {
63 let module = py.import("asyncio")?;
64 Ok(module.getattr("get_running_loop")?.into())
65 };
66 let event_loop = GET_RUNNING_LOOP.get_or_try_init(py, import)?.call0(py)?;
67 let future = event_loop.call_method0(py, "create_future")?;
68 Ok(Self { event_loop, future })
69 }
70
71 fn set_result(&self, py: Python<'_>) -> PyResult<()> {
72 static RELEASE_WAITER: GILOnceCell<Py<PyCFunction>> = GILOnceCell::new();
73 let release_waiter = RELEASE_WAITER.get_or_try_init(py, || {
74 wrap_pyfunction!(release_waiter, py).map(Bound::unbind)
75 })?;
76 let call_soon_threadsafe = self.event_loop.call_method1(
79 py,
80 intern!(py, "call_soon_threadsafe"),
81 (release_waiter, self.future.bind(py)),
82 );
83 if let Err(err) = call_soon_threadsafe {
84 let is_closed = self.event_loop.call_method0(py, "is_closed")?;
87 if !is_closed.extract(py)? {
88 return Err(err);
89 }
90 }
91 Ok(())
92 }
93}
94
95#[pyfunction(crate = "crate")]
100fn release_waiter(future: &Bound<'_, PyAny>) -> PyResult<()> {
101 let done = future.call_method0(intern!(future.py(), "done"))?;
102 if !done.extract::<bool>()? {
103 future.call_method1(intern!(future.py(), "set_result"), (future.py().None(),))?;
104 }
105 Ok(())
106}