pyo3/conversions/
rust_decimal.rs

1#![cfg(feature = "rust_decimal")]
2//! Conversions to and from [rust_decimal](https://docs.rs/rust_decimal)'s [`Decimal`] type.
3//!
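//! This is useful for converting Python's `decimal.Decimal` to and from a native Rust type.
//! Apart from a fast path for integers, values are exchanged through their decimal string
//! representation rather than a binary float. `NaN` and `Infinity` have no `rust_decimal`
//! equivalent and fail to extract with a `ValueError`.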
5//!
6//! # Setup
7//!
8//! To use this feature, add to your **`Cargo.toml`**:
9//!
10//! ```toml
11//! [dependencies]
12#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"),  "\", features = [\"rust_decimal\"] }")]
13//! rust_decimal = "1.0"
14//! ```
15//!
16//! Note that you must use a compatible version of rust_decimal and PyO3.
17//! The required rust_decimal version may vary based on the version of PyO3.
18//!
19//! # Example
20//!
//! Rust code to create a function that adds one to a `Decimal`:
22//!
23//! ```rust
24//! use rust_decimal::Decimal;
25//! use pyo3::prelude::*;
26//!
27//! #[pyfunction]
28//! fn add_one(d: Decimal) -> Decimal {
29//!     d + Decimal::ONE
30//! }
31//!
32//! #[pymodule]
33//! fn my_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
34//!     m.add_function(wrap_pyfunction!(add_one, m)?)?;
35//!     Ok(())
36//! }
37//! ```
38//!
//! Python code that validates the functionality:
//!
42//! ```python
43//! from my_module import add_one
44//! from decimal import Decimal
45//!
46//! d = Decimal("2")
47//! value = add_one(d)
48//!
49//! assert d + 1 == value
50//! ```
51
52use crate::conversion::IntoPyObject;
53use crate::exceptions::PyValueError;
54use crate::sync::GILOnceCell;
55use crate::types::any::PyAnyMethods;
56use crate::types::string::PyStringMethods;
57use crate::types::PyType;
58use crate::{Bound, FromPyObject, Py, PyAny, PyErr, PyObject, PyResult, Python};
59#[allow(deprecated)]
60use crate::{IntoPy, ToPyObject};
61use rust_decimal::Decimal;
62use std::str::FromStr;
63
64impl FromPyObject<'_> for Decimal {
65    fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
66        // use the string representation to not be lossy
67        if let Ok(val) = obj.extract() {
68            Ok(Decimal::new(val, 0))
69        } else {
70            let py_str = &obj.str()?;
71            let rs_str = &py_str.to_cow()?;
72            Decimal::from_str(rs_str).or_else(|_| {
73                Decimal::from_scientific(rs_str).map_err(|e| PyValueError::new_err(e.to_string()))
74            })
75        }
76    }
77}
78
79static DECIMAL_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();
80
81fn get_decimal_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
82    DECIMAL_CLS.import(py, "decimal", "Decimal")
83}
84
85#[allow(deprecated)]
86impl ToPyObject for Decimal {
87    #[inline]
88    fn to_object(&self, py: Python<'_>) -> PyObject {
89        self.into_pyobject(py).unwrap().into_any().unbind()
90    }
91}
92
93#[allow(deprecated)]
94impl IntoPy<PyObject> for Decimal {
95    #[inline]
96    fn into_py(self, py: Python<'_>) -> PyObject {
97        self.into_pyobject(py).unwrap().into_any().unbind()
98    }
99}
100
101impl<'py> IntoPyObject<'py> for Decimal {
102    type Target = PyAny;
103    type Output = Bound<'py, Self::Target>;
104    type Error = PyErr;
105
106    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
107        let dec_cls = get_decimal_cls(py)?;
108        // now call the constructor with the Rust Decimal string-ified
109        // to not be lossy
110        dec_cls.call1((self.to_string(),))
111    }
112}
113
114impl<'py> IntoPyObject<'py> for &Decimal {
115    type Target = PyAny;
116    type Output = Bound<'py, Self::Target>;
117    type Error = PyErr;
118
119    #[inline]
120    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
121        (*self).into_pyobject(py)
122    }
123}
124
#[cfg(test)]
mod test_rust_decimal {
    use super::*;
    use crate::types::dict::PyDictMethods;
    use crate::types::PyDict;
    use std::ffi::CString;

    use crate::ffi;
    #[cfg(not(target_arch = "wasm32"))]
    use proptest::prelude::*;

    // Generates a test that converts the Rust value `$rs` to Python, checks it
    // equals `decimal.Decimal($py)`, then extracts that Python value back and
    // checks it equals the original Rust value.
    macro_rules! convert_constants {
        ($name:ident, $rs:expr, $py:literal) => {
            #[test]
            fn $name() {
                Python::with_gil(|py| {
                    let rs_orig = $rs;
                    let rs_dec = rs_orig.into_pyobject(py).unwrap();
                    let locals = PyDict::new(py);
                    locals.set_item("rs_dec", &rs_dec).unwrap();
                    // Checks if Rust Decimal -> Python Decimal conversion is correct
                    py.run(
                        &CString::new(format!(
                            "import decimal\npy_dec = decimal.Decimal({})\nassert py_dec == rs_dec",
                            $py
                        ))
                        .unwrap(),
                        None,
                        Some(&locals),
                    )
                    .unwrap();
                    // Checks if Python Decimal -> Rust Decimal conversion is correct
                    let py_dec = locals.get_item("py_dec").unwrap().unwrap();
                    let py_result: Decimal = py_dec.extract().unwrap();
                    assert_eq!(rs_orig, py_result);
                })
            }
        };
    }

    convert_constants!(convert_zero, Decimal::ZERO, "0");
    convert_constants!(convert_one, Decimal::ONE, "1");
    convert_constants!(convert_neg_one, Decimal::NEGATIVE_ONE, "-1");
    convert_constants!(convert_two, Decimal::TWO, "2");
    convert_constants!(convert_ten, Decimal::TEN, "10");
    convert_constants!(convert_one_hundred, Decimal::ONE_HUNDRED, "100");
    convert_constants!(convert_one_thousand, Decimal::ONE_THOUSAND, "1000");

    #[cfg(not(target_arch = "wasm32"))]
    proptest! {
        // Any Decimal built from raw parts must survive Rust -> Python -> Rust.
        // `0..=28` covers every scale rust_decimal supports, including the maximum.
        #[test]
        fn test_roundtrip(
            lo in any::<u32>(),
            mid in any::<u32>(),
            high in any::<u32>(),
            negative in any::<bool>(),
            scale in 0..=28u32
        ) {
            let num = Decimal::from_parts(lo, mid, high, negative, scale);
            Python::with_gil(|py| {
                let rs_dec = num.into_pyobject(py).unwrap();
                let locals = PyDict::new(py);
                locals.set_item("rs_dec", &rs_dec).unwrap();
                py.run(
                    &CString::new(format!(
                        "import decimal\npy_dec = decimal.Decimal(\"{}\")\nassert py_dec == rs_dec",
                        num
                    ))
                    .unwrap(),
                    None,
                    Some(&locals),
                )
                .unwrap();
                let roundtripped: Decimal = rs_dec.extract().unwrap();
                assert_eq!(num, roundtripped);
            })
        }

        // Python ints within the i64 range take the integer fast path.
        #[test]
        fn test_integers(num in any::<i64>()) {
            Python::with_gil(|py| {
                let py_num = num.into_pyobject(py).unwrap();
                let roundtripped: Decimal = py_num.extract().unwrap();
                let rs_dec = Decimal::new(num, 0);
                assert_eq!(rs_dec, roundtripped);
            })
        }
    }

    #[test]
    fn test_nan() {
        Python::with_gil(|py| {
            let locals = PyDict::new(py);
            py.run(
                ffi::c_str!("import decimal\npy_dec = decimal.Decimal(\"NaN\")"),
                None,
                Some(&locals),
            )
            .unwrap();
            let py_dec = locals.get_item("py_dec").unwrap().unwrap();
            let roundtripped: Result<Decimal, PyErr> = py_dec.extract();
            assert!(roundtripped.is_err());
        })
    }

    #[test]
    fn test_scientific_notation() {
        Python::with_gil(|py| {
            let locals = PyDict::new(py);
            py.run(
                ffi::c_str!("import decimal\npy_dec = decimal.Decimal(\"1e3\")"),
                None,
                Some(&locals),
            )
            .unwrap();
            let py_dec = locals.get_item("py_dec").unwrap().unwrap();
            let roundtripped: Decimal = py_dec.extract().unwrap();
            let rs_dec = Decimal::from_scientific("1e3").unwrap();
            assert_eq!(rs_dec, roundtripped);
        })
    }

    #[test]
    fn test_negative_exponent() {
        Python::with_gil(|py| {
            let locals = PyDict::new(py);
            py.run(
                ffi::c_str!("import decimal\npy_dec = decimal.Decimal(\"1e-10\")"),
                None,
                Some(&locals),
            )
            .unwrap();
            let py_dec = locals.get_item("py_dec").unwrap().unwrap();
            // `str()` renders this value as "1E-10", exercising the
            // scientific-notation fallback with a negative exponent.
            let roundtripped: Decimal = py_dec.extract().unwrap();
            assert_eq!(Decimal::new(1, 10), roundtripped);
        })
    }

    #[test]
    fn test_integer_beyond_i64() {
        Python::with_gil(|py| {
            let locals = PyDict::new(py);
            py.run(ffi::c_str!("py_int = 2 ** 70"), None, Some(&locals))
                .unwrap();
            let py_int = locals.get_item("py_int").unwrap().unwrap();
            // Too large for the i64 fast path, so extraction goes through `str()`.
            let extracted: Decimal = py_int.extract().unwrap();
            let expected: Decimal = "1180591620717411303424".parse().unwrap();
            assert_eq!(expected, extracted);
        })
    }

    #[test]
    fn test_infinity() {
        Python::with_gil(|py| {
            let locals = PyDict::new(py);
            py.run(
                ffi::c_str!("import decimal\npy_dec = decimal.Decimal(\"Infinity\")"),
                None,
                Some(&locals),
            )
            .unwrap();
            let py_dec = locals.get_item("py_dec").unwrap().unwrap();
            let roundtripped: Result<Decimal, PyErr> = py_dec.extract();
            assert!(roundtripped.is_err());
        })
    }
}