pyo3/conversions/
rust_decimal.rs1#![cfg(feature = "rust_decimal")]
2#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"), "\", features = [\"rust_decimal\"] }")]
13use crate::conversion::IntoPyObject;
53use crate::exceptions::PyValueError;
54use crate::sync::GILOnceCell;
55use crate::types::any::PyAnyMethods;
56use crate::types::string::PyStringMethods;
57use crate::types::PyType;
58use crate::{Bound, FromPyObject, Py, PyAny, PyErr, PyObject, PyResult, Python};
59#[allow(deprecated)]
60use crate::{IntoPy, ToPyObject};
61use rust_decimal::Decimal;
62use std::str::FromStr;
63
64impl FromPyObject<'_> for Decimal {
65 fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
66 if let Ok(val) = obj.extract() {
68 Ok(Decimal::new(val, 0))
69 } else {
70 let py_str = &obj.str()?;
71 let rs_str = &py_str.to_cow()?;
72 Decimal::from_str(rs_str).or_else(|_| {
73 Decimal::from_scientific(rs_str).map_err(|e| PyValueError::new_err(e.to_string()))
74 })
75 }
76 }
77}
78
79static DECIMAL_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();
80
81fn get_decimal_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
82 DECIMAL_CLS.import(py, "decimal", "Decimal")
83}
84
85#[allow(deprecated)]
86impl ToPyObject for Decimal {
87 #[inline]
88 fn to_object(&self, py: Python<'_>) -> PyObject {
89 self.into_pyobject(py).unwrap().into_any().unbind()
90 }
91}
92
93#[allow(deprecated)]
94impl IntoPy<PyObject> for Decimal {
95 #[inline]
96 fn into_py(self, py: Python<'_>) -> PyObject {
97 self.into_pyobject(py).unwrap().into_any().unbind()
98 }
99}
100
101impl<'py> IntoPyObject<'py> for Decimal {
102 type Target = PyAny;
103 type Output = Bound<'py, Self::Target>;
104 type Error = PyErr;
105
106 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
107 let dec_cls = get_decimal_cls(py)?;
108 dec_cls.call1((self.to_string(),))
111 }
112}
113
114impl<'py> IntoPyObject<'py> for &Decimal {
115 type Target = PyAny;
116 type Output = Bound<'py, Self::Target>;
117 type Error = PyErr;
118
119 #[inline]
120 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
121 (*self).into_pyobject(py)
122 }
123}
124
#[cfg(test)]
mod test_rust_decimal {
    use super::*;
    use crate::types::dict::PyDictMethods;
    use crate::types::PyDict;
    use std::ffi::CString;

    use crate::ffi;
    #[cfg(not(target_arch = "wasm32"))]
    use proptest::prelude::*;

    /// Runs `code` against a fresh locals dict and returns the `py_dec`
    /// binding the snippet is expected to create.
    fn eval_py_dec<'py>(py: Python<'py>, code: &std::ffi::CStr) -> Bound<'py, PyAny> {
        let scope = PyDict::new(py);
        py.run(code, None, Some(&scope)).unwrap();
        scope.get_item("py_dec").unwrap().unwrap()
    }

    /// Generates a test that converts the Rust constant `$rs` to Python,
    /// checks it equals `decimal.Decimal($py)` on the Python side, and then
    /// extracts the Python value back into an equal Rust `Decimal`.
    macro_rules! convert_constants {
        ($name:ident, $rs:expr, $py:literal) => {
            #[test]
            fn $name() {
                Python::with_gil(|py| {
                    let expected = $rs;
                    let converted = expected.into_pyobject(py).unwrap();

                    // Expose the converted value to the Python snippet.
                    let scope = PyDict::new(py);
                    scope.set_item("rs_dec", &converted).unwrap();

                    let script = CString::new(format!(
                        "import decimal\npy_dec = decimal.Decimal({})\nassert py_dec == rs_dec",
                        $py
                    ))
                    .unwrap();
                    py.run(&script, None, Some(&scope)).unwrap();

                    // Pull the Python-constructed value back into Rust.
                    let py_side = scope.get_item("py_dec").unwrap().unwrap();
                    let extracted: Decimal = py_side.extract().unwrap();
                    assert_eq!(expected, extracted);
                })
            }
        };
    }

    convert_constants!(convert_zero, Decimal::ZERO, "0");
    convert_constants!(convert_one, Decimal::ONE, "1");
    convert_constants!(convert_neg_one, Decimal::NEGATIVE_ONE, "-1");
    convert_constants!(convert_two, Decimal::TWO, "2");
    convert_constants!(convert_ten, Decimal::TEN, "10");
    convert_constants!(convert_one_hundred, Decimal::ONE_HUNDRED, "100");
    convert_constants!(convert_one_thousand, Decimal::ONE_THOUSAND, "1000");

    #[cfg(not(target_arch = "wasm32"))]
    proptest! {
        // Arbitrary decimals must survive Rust -> Python -> Rust unchanged and
        // compare equal to a Python Decimal built from the same string.
        #[test]
        fn test_roundtrip(
            lo in any::<u32>(),
            mid in any::<u32>(),
            high in any::<u32>(),
            negative in any::<bool>(),
            scale in 0..28u32
        ) {
            let num = Decimal::from_parts(lo, mid, high, negative, scale);
            Python::with_gil(|py| {
                let converted = num.into_pyobject(py).unwrap();
                let scope = PyDict::new(py);
                scope.set_item("rs_dec", &converted).unwrap();

                let script = CString::new(format!(
                    "import decimal\npy_dec = decimal.Decimal(\"{}\")\nassert py_dec == rs_dec",
                    num
                ))
                .unwrap();
                py.run(&script, None, Some(&scope)).unwrap();

                let back: Decimal = converted.extract().unwrap();
                assert_eq!(num, back);
            })
        }

        // Python integers must extract to the equivalent scale-0 decimal.
        #[test]
        fn test_integers(num in any::<i64>()) {
            Python::with_gil(|py| {
                let py_int = num.into_pyobject(py).unwrap();
                let extracted: Decimal = py_int.extract().unwrap();
                let expected = Decimal::new(num, 0);
                assert_eq!(expected, extracted);
            })
        }
    }

    /// A Python `Decimal("NaN")` must be rejected by extraction.
    #[test]
    fn test_nan() {
        Python::with_gil(|py| {
            let py_dec = eval_py_dec(
                py,
                ffi::c_str!("import decimal\npy_dec = decimal.Decimal(\"NaN\")"),
            );
            let outcome: Result<Decimal, PyErr> = py_dec.extract();
            assert!(outcome.is_err());
        })
    }

    /// A Python `Decimal("1e3")` must extract to the same value that
    /// `Decimal::from_scientific("1e3")` produces.
    #[test]
    fn test_scientific_notation() {
        Python::with_gil(|py| {
            let py_dec = eval_py_dec(
                py,
                ffi::c_str!("import decimal\npy_dec = decimal.Decimal(\"1e3\")"),
            );
            let extracted: Decimal = py_dec.extract().unwrap();
            let expected = Decimal::from_scientific("1e3").unwrap();
            assert_eq!(expected, extracted);
        })
    }

    /// A Python `Decimal("Infinity")` must be rejected by extraction.
    #[test]
    fn test_infinity() {
        Python::with_gil(|py| {
            let py_dec = eval_py_dec(
                py,
                ffi::c_str!("import decimal\npy_dec = decimal.Decimal(\"Infinity\")"),
            );
            let outcome: Result<Decimal, PyErr> = py_dec.extract();
            assert!(outcome.is_err());
        })
    }
}