|
15 | 15 | // specific language governing permissions and limitations |
16 | 16 | // under the License. |
17 | 17 |
|
| 18 | +use crate::common::data_type::PyScalarValue; |
18 | 19 | use crate::errors::{PyDataFusionError, PyDataFusionResult}; |
19 | 20 | use crate::TokioRuntime; |
20 | 21 | use datafusion::common::ScalarValue; |
@@ -88,144 +89,42 @@ pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyRe |
88 | 89 |
|
89 | 90 | Ok(()) |
90 | 91 | } |
91 | | -/// Convert a python object to a ScalarValue |
92 | | -pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> ScalarValue { |
93 | | - // Try extracting primitive types first |
94 | | - if let Some(scalar) = try_extract_primitive(py, &obj) { |
95 | | - return scalar; |
96 | | - } |
97 | | - |
98 | | - // Try extracting datetime types |
99 | | - if let Some(scalar) = try_extract_datetime(py, &obj) { |
100 | | - return scalar; |
101 | | - } |
102 | | - |
103 | | - // Try extracting date type |
104 | | - if let Some(scalar) = try_extract_date(py, &obj) { |
105 | | - return scalar; |
106 | | - } |
107 | 92 |
|
108 | | - // If we reach here, the type is unsupported |
109 | | - panic!("Unsupported value type") |
110 | | -} |
111 | | - |
112 | | -/// Try to extract primitive types (bool, numbers, string) |
113 | | -fn try_extract_primitive(py: Python, obj: &PyObject) -> Option<ScalarValue> { |
| 93 | +/// Convert a Python object to ScalarValue using PyArrow |
| 94 | +/// |
| 95 | +/// Args: |
| 96 | +/// py: Python interpreter |
| 97 | +/// obj: Python object to convert |
| 98 | +/// |
| 99 | +/// Returns: |
| 100 | +/// ScalarValue representation of the Python object |
| 101 | +/// |
| 102 | +/// This function handles basic Python types directly and uses PyArrow |
| 103 | +/// for complex types like datetime. |
| 104 | +pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> PyResult<ScalarValue> { |
| 105 | + // Try basic types first for efficiency |
114 | 106 | if let Ok(value) = obj.extract::<bool>(py) { |
115 | | - Some(ScalarValue::Boolean(Some(value))) |
| 107 | + return Ok(ScalarValue::Boolean(Some(value))); |
116 | 108 | } else if let Ok(value) = obj.extract::<i64>(py) { |
117 | | - Some(ScalarValue::Int64(Some(value))) |
| 109 | + return Ok(ScalarValue::Int64(Some(value))); |
118 | 110 | } else if let Ok(value) = obj.extract::<u64>(py) { |
119 | | - Some(ScalarValue::UInt64(Some(value))) |
| 111 | + return Ok(ScalarValue::UInt64(Some(value))); |
120 | 112 | } else if let Ok(value) = obj.extract::<f64>(py) { |
121 | | - Some(ScalarValue::Float64(Some(value))) |
| 113 | + return Ok(ScalarValue::Float64(Some(value))); |
122 | 114 | } else if let Ok(value) = obj.extract::<String>(py) { |
123 | | - Some(ScalarValue::Utf8(Some(value))) |
124 | | - } else { |
125 | | - None |
126 | | - } |
127 | | -} |
128 | | - |
129 | | -/// Try to extract datetime object to TimestampNanosecond |
130 | | -fn try_extract_datetime(py: Python, obj: &PyObject) -> Option<ScalarValue> { |
131 | | - let datetime_module = py.import("datetime").ok()?; |
132 | | - let datetime_class = datetime_module.getattr("datetime").ok()?; |
133 | | - |
134 | | - if obj.is_instance(py, datetime_class).ok()? { |
135 | | - // Extract timestamp as nanoseconds |
136 | | - let timestamp = obj.call_method0(py, "timestamp").ok()?; |
137 | | - let seconds_f64 = timestamp.extract::<f64>(py).ok()?; |
138 | | - |
139 | | - // Convert seconds to nanoseconds |
140 | | - let nanos = (seconds_f64 * 1_000_000_000.0) as i64; |
141 | | - return Some(ScalarValue::TimestampNanosecond(Some(nanos), None)); |
| 115 | + return Ok(ScalarValue::Utf8(Some(value))); |
142 | 116 | } |
143 | 117 |
|
144 | | - None |
145 | | -} |
146 | | - |
147 | | -/// Try to extract date object to Date64 |
148 | | -fn try_extract_date(py: Python, obj: &PyObject) -> Option<ScalarValue> { |
149 | | - let datetime_module = py.import("datetime").ok()?; |
150 | | - let date_class = datetime_module.getattr("date").ok()?; |
| 118 | + // For datetime and other complex types, convert via PyArrow |
| 119 | + let pa = py.import("pyarrow")?; |
151 | 120 |
|
152 | | - let any = PyAny::from(obj); |
153 | | - let py_any: Bound<_, PyAny> = Bound::new(py, any).ok()?; |
154 | | - if any.is_instance_of(py, date_class).ok()? { |
155 | | - // Check if it's actually a datetime (which also is an instance of date) |
156 | | - let datetime_class = datetime_module.getattr("datetime").ok()?; |
157 | | - if any.is_instance(py, datetime_class).ok()? { |
158 | | - return None; // Let the datetime handler take care of it |
159 | | - } |
160 | | - |
161 | | - // Extract date components |
162 | | - let year = obj.getattr(py, "year").ok()?.extract::<i32>(py).ok()?; |
163 | | - let month = obj.getattr(py, "month").ok()?.extract::<u8>(py).ok()?; |
164 | | - let day = obj.getattr(py, "day").ok()?.extract::<u8>(py).ok()?; |
165 | | - |
166 | | - // Calculate milliseconds since epoch (1970-01-01) |
167 | | - let millis = date_to_millis(year, month as u32, day as u32)?; |
168 | | - return Some(ScalarValue::Date64(Some(millis))); |
169 | | - } |
| 121 | + // Convert Python object to PyArrow scalar |
| 122 | + // This handles datetime objects by converting to PyArrow timestamp type |
| 123 | + let scalar = pa.call_method1("scalar", (obj,))?; |
170 | 124 |
|
171 | | - None |
172 | | -} |
173 | | - |
174 | | -/// Convert year, month, day to milliseconds since Unix epoch |
175 | | -fn date_to_millis(year: i32, month: u32, day: u32) -> Option<i64> { |
176 | | - // Validate inputs |
177 | | - if month < 1 || month > 12 || day < 1 || day > 31 { |
178 | | - return None; |
179 | | - } |
180 | | - |
181 | | - // Days in each month (non-leap year) |
182 | | - let days_in_month = [0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]; |
183 | | - |
184 | | - // Check if the day is valid for the given month |
185 | | - let max_day = if month == 2 && is_leap_year(year) { |
186 | | - 29 |
187 | | - } else { |
188 | | - days_in_month[month as usize] |
189 | | - }; |
190 | | - |
191 | | - if day > max_day { |
192 | | - return None; |
193 | | - } |
194 | | - |
195 | | - // Calculate days since epoch |
196 | | - let mut total_days: i64 = 0; |
197 | | - |
198 | | - // Handle years |
199 | | - let year_diff = year - 1970; |
200 | | - if year_diff >= 0 { |
201 | | - // Years from 1970 to year-1 |
202 | | - for y in 1970..year { |
203 | | - total_days += if is_leap_year(y) { 366 } else { 365 }; |
204 | | - } |
205 | | - } else { |
206 | | - // Years from year to 1969 |
207 | | - for y in year..1970 { |
208 | | - total_days -= if is_leap_year(y) { 366 } else { 365 }; |
209 | | - } |
210 | | - } |
211 | | - |
212 | | - // Add days for the months in the current year |
213 | | - for m in 1..month { |
214 | | - total_days += if m == 2 && is_leap_year(year) { |
215 | | - 29 |
216 | | - } else { |
217 | | - days_in_month[m as usize] |
218 | | - } as i64; |
219 | | - } |
220 | | - |
221 | | - // Add days in the current month |
222 | | - total_days += (day as i64) - 1; |
223 | | - |
224 | | - // Convert days to milliseconds (86,400,000 ms per day) |
225 | | - Some(total_days * 86_400_000) |
226 | | -} |
| 125 | + // Convert PyArrow scalar to PyScalarValue |
| 126 | + let py_scalar = PyScalarValue::extract(scalar.as_ref())?; |
227 | 127 |
|
228 | | -/// Check if a year is a leap year |
229 | | -fn is_leap_year(year: i32) -> bool { |
230 | | - (year % 4 == 0 && year % 100 != 0) || (year % 400 == 0) |
| 128 | + // Convert PyScalarValue to ScalarValue |
| 129 | + Ok(py_scalar.into()) |
231 | 130 | } |
0 commit comments