pyo3/conversions/
num_complex.rs

1#![cfg(feature = "num-complex")]
2
//! Conversions to and from [num-complex](https://docs.rs/num-complex)'s
4//! [`Complex`]`<`[`f32`]`>` and [`Complex`]`<`[`f64`]`>`.
5//!
//! num-complex's [`Complex`] supports more operations than PyO3's [`PyComplex`]
7//! and can be used with the rest of the Rust ecosystem.
8//!
9//! # Setup
10//!
11//! To use this feature, add this to your **`Cargo.toml`**:
12//!
13//! ```toml
14//! [dependencies]
15//! # change * to the latest versions
16//! num-complex = "*"
17#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"),  "\", features = [\"num-complex\"] }")]
18//! ```
19//!
20//! Note that you must use compatible versions of num-complex and PyO3.
21//! The required num-complex version may vary based on the version of PyO3.
22//!
23//! # Examples
24//!
25//! Using [num-complex](https://docs.rs/num-complex) and [nalgebra](https://docs.rs/nalgebra)
26//! to create a pyfunction that calculates the eigenvalues of a 2x2 matrix.
27//! ```ignore
28//! # // not tested because nalgebra isn't supported on msrv
29//! # // please file an issue if it breaks!
30//! use nalgebra::base::{dimension::Const, Matrix};
31//! use num_complex::Complex;
32//! use pyo3::prelude::*;
33//!
34//! type T = Complex<f64>;
35//!
36//! #[pyfunction]
37//! fn get_eigenvalues(m11: T, m12: T, m21: T, m22: T) -> Vec<T> {
38//!     let mat = Matrix::<T, Const<2>, Const<2>, _>::new(m11, m12, m21, m22);
39//!
40//!     match mat.eigenvalues() {
41//!         Some(e) => e.data.as_slice().to_vec(),
42//!         None => vec![],
43//!     }
44//! }
45//!
46//! #[pymodule]
47//! fn my_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
48//!     m.add_function(wrap_pyfunction!(get_eigenvalues, m)?)?;
49//!     Ok(())
50//! }
51//! # // test
52//! # use assert_approx_eq::assert_approx_eq;
53//! # use nalgebra::ComplexField;
54//! # use pyo3::types::PyComplex;
55//! #
56//! # fn main() -> PyResult<()> {
57//! #     Python::with_gil(|py| -> PyResult<()> {
58//! #         let module = PyModule::new(py, "my_module")?;
59//! #
//! #         module.add_function(wrap_pyfunction!(get_eigenvalues, &module)?)?;
61//! #
62//! #         let m11 = PyComplex::from_doubles(py, 0_f64, -1_f64);
63//! #         let m12 = PyComplex::from_doubles(py, 1_f64, 0_f64);
64//! #         let m21 = PyComplex::from_doubles(py, 2_f64, -1_f64);
65//! #         let m22 = PyComplex::from_doubles(py, -1_f64, 0_f64);
66//! #
67//! #         let result = module
68//! #             .getattr("get_eigenvalues")?
69//! #             .call1((m11, m12, m21, m22))?;
70//! #         println!("eigenvalues: {:?}", result);
71//! #
72//! #         let result = result.extract::<Vec<T>>()?;
73//! #         let e0 = result[0];
74//! #         let e1 = result[1];
75//! #
76//! #         assert_approx_eq!(e0, Complex::new(1_f64, -1_f64));
77//! #         assert_approx_eq!(e1, Complex::new(-2_f64, 0_f64));
78//! #
79//! #         Ok(())
80//! #     })
81//! # }
82//! ```
83//!
84//! Python code:
85//! ```python
86//! from my_module import get_eigenvalues
87//!
88//! m11 = complex(0,-1)
89//! m12 = complex(1,0)
90//! m21 = complex(2,-1)
91//! m22 = complex(-1,0)
92//!
93//! result = get_eigenvalues(m11,m12,m21,m22)
94//! assert result == [complex(1,-1), complex(-2,0)]
95//! ```
96#[allow(deprecated)]
97use crate::ToPyObject;
98use crate::{
99    ffi,
100    ffi_ptr_ext::FfiPtrExt,
101    types::{any::PyAnyMethods, PyComplex},
102    Bound, FromPyObject, PyAny, PyErr, PyObject, PyResult, Python,
103};
104use num_complex::Complex;
105use std::os::raw::c_double;
106
107impl PyComplex {
108    /// Creates a new Python `PyComplex` object from `num_complex`'s [`Complex`].
109    pub fn from_complex_bound<F: Into<c_double>>(
110        py: Python<'_>,
111        complex: Complex<F>,
112    ) -> Bound<'_, PyComplex> {
113        unsafe {
114            ffi::PyComplex_FromDoubles(complex.re.into(), complex.im.into())
115                .assume_owned(py)
116                .downcast_into_unchecked()
117        }
118    }
119}
120
121macro_rules! complex_conversion {
122    ($float: ty) => {
123        #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
124        #[allow(deprecated)]
125        impl ToPyObject for Complex<$float> {
126            #[inline]
127            fn to_object(&self, py: Python<'_>) -> PyObject {
128                crate::IntoPy::<PyObject>::into_py(self.to_owned(), py)
129            }
130        }
131
132        #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
133        #[allow(deprecated)]
134        impl crate::IntoPy<PyObject> for Complex<$float> {
135            fn into_py(self, py: Python<'_>) -> PyObject {
136                unsafe {
137                    let raw_obj =
138                        ffi::PyComplex_FromDoubles(self.re as c_double, self.im as c_double);
139                    PyObject::from_owned_ptr(py, raw_obj)
140                }
141            }
142        }
143
144        #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
145        impl<'py> crate::conversion::IntoPyObject<'py> for Complex<$float> {
146            type Target = PyComplex;
147            type Output = Bound<'py, Self::Target>;
148            type Error = std::convert::Infallible;
149
150            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
151                unsafe {
152                    Ok(
153                        ffi::PyComplex_FromDoubles(self.re as c_double, self.im as c_double)
154                            .assume_owned(py)
155                            .downcast_into_unchecked(),
156                    )
157                }
158            }
159        }
160
161        #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
162        impl<'py> crate::conversion::IntoPyObject<'py> for &Complex<$float> {
163            type Target = PyComplex;
164            type Output = Bound<'py, Self::Target>;
165            type Error = std::convert::Infallible;
166
167            #[inline]
168            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
169                (*self).into_pyobject(py)
170            }
171        }
172
173        #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
174        impl FromPyObject<'_> for Complex<$float> {
175            fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Complex<$float>> {
176                #[cfg(not(any(Py_LIMITED_API, PyPy)))]
177                unsafe {
178                    let val = ffi::PyComplex_AsCComplex(obj.as_ptr());
179                    if val.real == -1.0 {
180                        if let Some(err) = PyErr::take(obj.py()) {
181                            return Err(err);
182                        }
183                    }
184                    Ok(Complex::new(val.real as $float, val.imag as $float))
185                }
186
187                #[cfg(any(Py_LIMITED_API, PyPy))]
188                unsafe {
189                    let complex;
190                    let obj = if obj.is_instance_of::<PyComplex>() {
191                        obj
192                    } else if let Some(method) =
193                        obj.lookup_special(crate::intern!(obj.py(), "__complex__"))?
194                    {
195                        complex = method.call0()?;
196                        &complex
197                    } else {
198                        // `obj` might still implement `__float__` or `__index__`, which will be
199                        // handled by `PyComplex_{Real,Imag}AsDouble`, including propagating any
200                        // errors if those methods don't exist / raise exceptions.
201                        obj
202                    };
203                    let ptr = obj.as_ptr();
204                    let real = ffi::PyComplex_RealAsDouble(ptr);
205                    if real == -1.0 {
206                        if let Some(err) = PyErr::take(obj.py()) {
207                            return Err(err);
208                        }
209                    }
210                    let imag = ffi::PyComplex_ImagAsDouble(ptr);
211                    Ok(Complex::new(real as $float, imag as $float))
212                }
213            }
214        }
215    };
216}
217complex_conversion!(f32);
218complex_conversion!(f64);
219
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::common::generate_unique_module_name;
    use crate::types::{complex::PyComplexMethods, PyModule};
    use crate::IntoPyObject;
    use pyo3_ffi::c_str;
    use std::ffi::CStr;

    /// Compiles `source` into a fresh Python module with a unique name.
    ///
    /// Panics if the source fails to compile or execute, which fails the calling test.
    fn load_module<'py>(py: Python<'py>, source: &CStr) -> Bound<'py, PyModule> {
        PyModule::from_code(
            py,
            source,
            c_str!("test.py"),
            &generate_unique_module_name("test"),
        )
        .unwrap()
    }

    /// Looks up the class `name` in `module` and calls it with no arguments.
    fn instantiate<'py>(module: &Bound<'py, PyModule>, name: &str) -> Bound<'py, PyAny> {
        module.getattr(name).unwrap().call0().unwrap()
    }

    #[test]
    fn from_complex() {
        Python::with_gil(|py| {
            let complex = Complex::new(3.0, 1.2);
            let py_c = PyComplex::from_complex_bound(py, complex);
            assert_eq!(py_c.real(), 3.0);
            assert_eq!(py_c.imag(), 1.2);
        });
    }

    #[test]
    fn to_from_complex() {
        Python::with_gil(|py| {
            let val = Complex::new(3.0f64, 1.2);
            let obj = val.into_pyobject(py).unwrap();
            assert_eq!(obj.extract::<Complex<f64>>().unwrap(), val);
        });
    }

    /// Round-trips through the `f32` instantiation of the conversion macro.
    #[test]
    fn to_from_complex_f32() {
        Python::with_gil(|py| {
            let val = Complex::new(3.0f32, 1.5);
            let obj = val.into_pyobject(py).unwrap();
            assert_eq!(obj.extract::<Complex<f32>>().unwrap(), val);
        });
    }

    #[test]
    fn from_complex_err() {
        Python::with_gil(|py| {
            let obj = vec![1i32].into_pyobject(py).unwrap();
            assert!(obj.extract::<Complex<f64>>().is_err());
        });
    }

    #[test]
    fn from_python_magic() {
        Python::with_gil(|py| {
            let module = load_module(
                py,
                c_str!(
                    r#"
class A:
    def __complex__(self): return 3.0+1.2j
class B:
    def __float__(self): return 3.0
class C:
    def __index__(self): return 3
                "#
                ),
            );
            let from_complex = instantiate(&module, "A");
            assert_eq!(
                from_complex.extract::<Complex<f64>>().unwrap(),
                Complex::new(3.0, 1.2)
            );
            let from_float = instantiate(&module, "B");
            assert_eq!(
                from_float.extract::<Complex<f64>>().unwrap(),
                Complex::new(3.0, 0.0)
            );
            // Before Python 3.8, `__index__` wasn't tried by `float`/`complex`.
            #[cfg(Py_3_8)]
            {
                let from_index = instantiate(&module, "C");
                assert_eq!(
                    from_index.extract::<Complex<f64>>().unwrap(),
                    Complex::new(3.0, 0.0)
                );
            }
        })
    }

    #[test]
    fn from_python_inherited_magic() {
        Python::with_gil(|py| {
            let module = load_module(
                py,
                c_str!(
                    r#"
class First: pass
class ComplexMixin:
    def __complex__(self): return 3.0+1.2j
class FloatMixin:
    def __float__(self): return 3.0
class IndexMixin:
    def __index__(self): return 3
class A(First, ComplexMixin): pass
class B(First, FloatMixin): pass
class C(First, IndexMixin): pass
                "#
                ),
            );
            let from_complex = instantiate(&module, "A");
            assert_eq!(
                from_complex.extract::<Complex<f64>>().unwrap(),
                Complex::new(3.0, 1.2)
            );
            let from_float = instantiate(&module, "B");
            assert_eq!(
                from_float.extract::<Complex<f64>>().unwrap(),
                Complex::new(3.0, 0.0)
            );
            // Before Python 3.8, `__index__` wasn't tried by `float`/`complex`.
            #[cfg(Py_3_8)]
            {
                let from_index = instantiate(&module, "C");
                assert_eq!(
                    from_index.extract::<Complex<f64>>().unwrap(),
                    Complex::new(3.0, 0.0)
                );
            }
        })
    }

    #[test]
    fn from_python_noncallable_descriptor_magic() {
        // Functions and lambdas implement the descriptor protocol in a way that makes
        // `type(inst).attr(inst)` equivalent to `inst.attr()` for methods, but this isn't the only
        // way the descriptor protocol might be implemented.
        Python::with_gil(|py| {
            let module = load_module(
                py,
                c_str!(
                    r#"
class A:
    @property
    def __complex__(self):
        return lambda: 3.0+1.2j
                "#
                ),
            );
            let obj = instantiate(&module, "A");
            assert_eq!(
                obj.extract::<Complex<f64>>().unwrap(),
                Complex::new(3.0, 1.2)
            );
        })
    }

    #[test]
    fn from_python_nondescriptor_magic() {
        // Magic methods don't need to implement the descriptor protocol, if they're callable.
        Python::with_gil(|py| {
            let module = load_module(
                py,
                c_str!(
                    r#"
class MyComplex:
    def __call__(self): return 3.0+1.2j
class A:
    __complex__ = MyComplex()
                "#
                ),
            );
            let obj = instantiate(&module, "A");
            assert_eq!(
                obj.extract::<Complex<f64>>().unwrap(),
                Complex::new(3.0, 1.2)
            );
        })
    }
}