pyo3/conversions/std/
array.rs

1use crate::conversion::IntoPyObject;
2use crate::instance::Bound;
3use crate::types::any::PyAnyMethods;
4use crate::types::PySequence;
5use crate::{err::DowncastError, ffi, FromPyObject, Py, PyAny, PyObject, PyResult, Python};
6use crate::{exceptions, PyErr};
7#[allow(deprecated)]
8use crate::{IntoPy, ToPyObject};
9
/// Converts an owned array into a Python `list` through the deprecated `IntoPy` API.
///
/// Each element is converted with its own `IntoPy<PyObject>` impl and placed at the
/// matching index of a freshly allocated list of length `N`.
#[allow(deprecated)]
impl<T, const N: usize> IntoPy<PyObject> for [T; N]
where
    T: IntoPy<PyObject>,
{
    fn into_py(self, py: Python<'_>) -> PyObject {
        unsafe {
            let len = N as ffi::Py_ssize_t;

            // Allocate the list up front with exactly `len` slots; they are filled in below.
            let ptr = ffi::PyList_New(len);

            // We create the  `Py` pointer here for two reasons:
            // - panics if the ptr is null
            // - its Drop cleans up the list if user code panics.
            let list: Py<PyAny> = Py::from_owned_ptr(py, ptr);

            // `zip` pairs each list index with the element moved out of the array.
            for (i, obj) in (0..len).zip(self) {
                // `into_ptr` yields an owned raw reference which is handed to the list below.
                let obj = obj.into_py(py).into_ptr();

                // NOTE: relies on the CPython list API taking ownership of `obj`
                // (`PyList_SET_ITEM` / `PyList_SetItem` steal the reference). The unchecked
                // `PyList_SET_ITEM` form is only used when not building for the limited API.
                #[cfg(not(Py_LIMITED_API))]
                ffi::PyList_SET_ITEM(ptr, i, obj);
                #[cfg(Py_LIMITED_API)]
                ffi::PyList_SetItem(ptr, i, obj);
            }

            list
        }
    }
}
39
40impl<'py, T, const N: usize> IntoPyObject<'py> for [T; N]
41where
42    T: IntoPyObject<'py>,
43{
44    type Target = PyAny;
45    type Output = Bound<'py, Self::Target>;
46    type Error = PyErr;
47
48    /// Turns [`[u8; N]`](std::array) into [`PyBytes`], all other `T`s will be turned into a [`PyList`]
49    ///
50    /// [`PyBytes`]: crate::types::PyBytes
51    /// [`PyList`]: crate::types::PyList
52    #[inline]
53    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
54        T::owned_sequence_into_pyobject(self, py, crate::conversion::private::Token)
55    }
56}
57
58impl<'a, 'py, T, const N: usize> IntoPyObject<'py> for &'a [T; N]
59where
60    &'a T: IntoPyObject<'py>,
61{
62    type Target = PyAny;
63    type Output = Bound<'py, Self::Target>;
64    type Error = PyErr;
65
66    #[inline]
67    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
68        self.as_slice().into_pyobject(py)
69    }
70}
71
72#[allow(deprecated)]
73impl<T, const N: usize> ToPyObject for [T; N]
74where
75    T: ToPyObject,
76{
77    fn to_object(&self, py: Python<'_>) -> PyObject {
78        self.as_ref().to_object(py)
79    }
80}
81
82impl<'py, T, const N: usize> FromPyObject<'py> for [T; N]
83where
84    T: FromPyObject<'py>,
85{
86    fn extract_bound(obj: &Bound<'py, PyAny>) -> PyResult<Self> {
87        create_array_from_obj(obj)
88    }
89}
90
91fn create_array_from_obj<'py, T, const N: usize>(obj: &Bound<'py, PyAny>) -> PyResult<[T; N]>
92where
93    T: FromPyObject<'py>,
94{
95    // Types that pass `PySequence_Check` usually implement enough of the sequence protocol
96    // to support this function and if not, we will only fail extraction safely.
97    let seq = unsafe {
98        if ffi::PySequence_Check(obj.as_ptr()) != 0 {
99            obj.downcast_unchecked::<PySequence>()
100        } else {
101            return Err(DowncastError::new(obj, "Sequence").into());
102        }
103    };
104    let seq_len = seq.len()?;
105    if seq_len != N {
106        return Err(invalid_sequence_length(N, seq_len));
107    }
108    array_try_from_fn(|idx| seq.get_item(idx).and_then(|any| any.extract()))
109}
110
// TODO use std::array::try_from_fn, if that stabilises:
// (https://github.com/rust-lang/rust/issues/89379)
/// Builds `[T; N]` by calling `cb(0)`, `cb(1)`, …, `cb(N - 1)` in order.
///
/// Returns the first `Err` produced by `cb`. If `cb` errors or panics part-way through,
/// every element produced so far is dropped, so nothing is leaked and nothing uninitialized
/// is ever dropped.
fn array_try_from_fn<E, F, T, const N: usize>(mut cb: F) -> Result<[T; N], E>
where
    F: FnMut(usize) -> Result<T, E>,
{
    /// Drop guard owning the initialized prefix `dst[..initialized]` of a partially built
    /// array. It is `mem::forget`-ed once the array is complete.
    struct ArrayGuard<T, const N: usize> {
        dst: *mut T,
        initialized: usize,
    }

    impl<T, const N: usize> Drop for ArrayGuard<T, N> {
        fn drop(&mut self) {
            debug_assert!(self.initialized <= N);
            let initialized_part = core::ptr::slice_from_raw_parts_mut(self.dst, self.initialized);
            // SAFETY: exactly the first `initialized` slots starting at `dst` were written
            // and none of them has been moved out, so they are valid to drop in place.
            unsafe {
                core::ptr::drop_in_place(initialized_part);
            }
        }
    }

    // [MaybeUninit<T>; N] would be "nicer" but is actually difficult to create - there are nightly
    // APIs which would make this easier.
    let mut array: core::mem::MaybeUninit<[T; N]> = core::mem::MaybeUninit::uninit();

    // Take a single raw pointer into `array` and use it for BOTH writing and cleanup.
    // Calling `array.as_mut_ptr()` a second time would create a fresh `&mut array`, which
    // (under the Stacked Borrows aliasing model) invalidates the guard's earlier pointer and
    // makes its `Drop` on the error/panic path undefined behaviour.
    let mut guard: ArrayGuard<T, N> = ArrayGuard {
        dst: array.as_mut_ptr().cast::<T>(),
        initialized: 0,
    };

    for i in 0..N {
        // `cb` may return early via `?` or panic; either way `guard` drops the prefix.
        let value = cb(i)?;
        // SAFETY: `i < N`, so `dst.add(i)` stays inside the `[T; N]` allocation, and slot `i`
        // is still uninitialized, so writing does not overwrite (leak) a live value.
        unsafe { guard.dst.add(i).write(value) };
        guard.initialized += 1;
    }

    // Every slot is initialized: ownership passes to the returned array, not the guard.
    core::mem::forget(guard);
    // SAFETY: the loop above initialized all `N` elements.
    Ok(unsafe { array.assume_init() })
}
152
153fn invalid_sequence_length(expected: usize, actual: usize) -> PyErr {
154    exceptions::PyValueError::new_err(format!(
155        "expected a sequence of length {} (got {})",
156        expected, actual
157    ))
158}
159
#[cfg(test)]
mod tests {
    use std::{
        panic,
        sync::atomic::{AtomicUsize, Ordering},
    };

    use crate::{
        conversion::IntoPyObject,
        ffi,
        types::{any::PyAnyMethods, PyBytes, PyBytesMethods},
    };
    use crate::{types::PyList, PyResult, Python};

    /// A panic at index 2 must drop exactly the two elements built before it.
    #[test]
    fn array_try_from_fn() {
        static DROPPED: AtomicUsize = AtomicUsize::new(0);

        struct CountDrop;

        impl Drop for CountDrop {
            fn drop(&mut self) {
                DROPPED.fetch_add(1, Ordering::SeqCst);
            }
        }

        let build = move || {
            let _: Result<[CountDrop; 4], ()> = super::array_try_from_fn(|slot| {
                #[allow(clippy::manual_assert)]
                if slot == 2 {
                    panic!("peek a boo");
                }
                Ok(CountDrop)
            });
        };
        let _ = catch_unwind_silent(build);

        assert_eq!(DROPPED.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_extract_bytearray_to_array() {
        Python::with_gil(|py| {
            let bytearray = py
                .eval(
                    ffi::c_str!("bytearray(b'abcabcabcabcabcabcabcabcabcabcabc')"),
                    None,
                    None,
                )
                .unwrap();
            let extracted: [u8; 33] = bytearray.extract().unwrap();
            assert_eq!(&extracted, b"abcabcabcabcabcabcabcabcabcabcabc");
        })
    }

    #[test]
    fn test_extract_small_bytearray_to_array() {
        Python::with_gil(|py| {
            let bytearray = py
                .eval(ffi::c_str!("bytearray(b'abc')"), None, None)
                .unwrap();
            let extracted: [u8; 3] = bytearray.extract().unwrap();
            assert_eq!(&extracted, b"abc");
        });
    }

    #[test]
    fn test_into_pyobject_array_conversion() {
        Python::with_gil(|py| {
            let values: [f32; 4] = [0.0, -16.0, 16.0, 42.0];
            let object = values.into_pyobject(py).unwrap();
            let list = object.downcast::<PyList>().unwrap();
            for (index, &expected) in values.iter().enumerate() {
                let item = list.get_item(index).unwrap();
                assert_eq!(item.extract::<f32>().unwrap(), expected);
            }
        });
    }

    #[test]
    fn test_extract_invalid_sequence_length() {
        Python::with_gil(|py| {
            let too_long = py
                .eval(ffi::c_str!("bytearray(b'abcdefg')"), None, None)
                .unwrap();
            let result: PyResult<[u8; 3]> = too_long.extract();
            let message = result.unwrap_err().to_string();
            assert_eq!(
                message,
                "ValueError: expected a sequence of length 3 (got 7)"
            );
        })
    }

    #[test]
    fn test_intopyobject_array_conversion() {
        Python::with_gil(|py| {
            let values: [f32; 4] = [0.0, -16.0, 16.0, 42.0];
            let list = values
                .into_pyobject(py)
                .unwrap()
                .downcast_into::<PyList>()
                .unwrap();

            for (index, &expected) in values.iter().enumerate() {
                let item = list.get_item(index).unwrap();
                assert_eq!(item.extract::<f32>().unwrap(), expected);
            }
        });
    }

    #[test]
    fn test_array_intopyobject_impl() {
        Python::with_gil(|py| {
            // A `[u8; N]` becomes `bytes` ...
            let raw: [u8; 6] = *b"foobar";
            let bytes_obj = raw.into_pyobject(py).unwrap();
            assert!(bytes_obj.is_instance_of::<PyBytes>());
            let bytes_obj = bytes_obj.downcast_into::<PyBytes>().unwrap();
            assert_eq!(bytes_obj.as_bytes(), &raw);

            // ... while any other element type becomes a `list`.
            let numbers: [u16; 4] = [0, 1, 2, 3];
            let list_obj = numbers.into_pyobject(py).unwrap();
            assert!(list_obj.is_instance_of::<PyList>());
        });
    }

    #[test]
    fn test_extract_non_iterable_to_array() {
        Python::with_gil(|py| {
            let number = py.eval(ffi::c_str!("42"), None, None).unwrap();
            number.extract::<i32>().unwrap();
            number.extract::<[i32; 1]>().unwrap_err();
        });
    }

    #[cfg(feature = "macros")]
    #[test]
    fn test_pyclass_intopy_array_conversion() {
        #[crate::pyclass(crate = "crate")]
        struct Foo;

        Python::with_gil(|py| {
            let foos: [Foo; 8] = [Foo, Foo, Foo, Foo, Foo, Foo, Foo, Foo];
            let list = foos
                .into_pyobject(py)
                .unwrap()
                .downcast_into::<PyList>()
                .unwrap();
            let item = list.get_item(4).unwrap();
            let _bound = item.downcast::<Foo>().unwrap();
        });
    }

    /// Runs `f`, catching any panic while temporarily silencing the panic hook so the
    /// expected panic does not spam the test output.
    // https://stackoverflow.com/a/59211505
    fn catch_unwind_silent<F, R>(f: F) -> std::thread::Result<R>
    where
        F: FnOnce() -> R + panic::UnwindSafe,
    {
        let saved_hook = panic::take_hook();
        panic::set_hook(Box::new(|_| {}));
        let outcome = panic::catch_unwind(f);
        panic::set_hook(saved_hook);
        outcome
    }
}