Skip to content

Commit 82bf6f4

Browse files
committed
refactor: enhance py_obj_to_scalar_value to utilize PyArrow for complex type conversion
1 parent 4c40b85 commit 82bf6f4

File tree

1 file changed

+28
-129
lines changed

1 file changed

+28
-129
lines changed

src/utils.rs

Lines changed: 28 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::common::data_type::PyScalarValue;
1819
use crate::errors::{PyDataFusionError, PyDataFusionResult};
1920
use crate::TokioRuntime;
2021
use datafusion::common::ScalarValue;
@@ -88,144 +89,42 @@ pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyRe
8889

8990
Ok(())
9091
}
91-
/// Convert a python object to a ScalarValue
92-
pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> ScalarValue {
93-
// Try extracting primitive types first
94-
if let Some(scalar) = try_extract_primitive(py, &obj) {
95-
return scalar;
96-
}
97-
98-
// Try extracting datetime types
99-
if let Some(scalar) = try_extract_datetime(py, &obj) {
100-
return scalar;
101-
}
102-
103-
// Try extracting date type
104-
if let Some(scalar) = try_extract_date(py, &obj) {
105-
return scalar;
106-
}
10792

108-
// If we reach here, the type is unsupported
109-
panic!("Unsupported value type")
110-
}
111-
112-
/// Try to extract primitive types (bool, numbers, string)
113-
fn try_extract_primitive(py: Python, obj: &PyObject) -> Option<ScalarValue> {
93+
/// Convert a Python object to ScalarValue using PyArrow
94+
///
95+
/// Args:
96+
/// py: Python interpreter
97+
/// obj: Python object to convert
98+
///
99+
/// Returns:
100+
/// ScalarValue representation of the Python object
101+
///
102+
/// This function handles basic Python types directly and uses PyArrow
103+
/// for complex types like datetime.
104+
pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> PyResult<ScalarValue> {
105+
// Try basic types first for efficiency
114106
if let Ok(value) = obj.extract::<bool>(py) {
115-
Some(ScalarValue::Boolean(Some(value)))
107+
return Ok(ScalarValue::Boolean(Some(value)));
116108
} else if let Ok(value) = obj.extract::<i64>(py) {
117-
Some(ScalarValue::Int64(Some(value)))
109+
return Ok(ScalarValue::Int64(Some(value)));
118110
} else if let Ok(value) = obj.extract::<u64>(py) {
119-
Some(ScalarValue::UInt64(Some(value)))
111+
return Ok(ScalarValue::UInt64(Some(value)));
120112
} else if let Ok(value) = obj.extract::<f64>(py) {
121-
Some(ScalarValue::Float64(Some(value)))
113+
return Ok(ScalarValue::Float64(Some(value)));
122114
} else if let Ok(value) = obj.extract::<String>(py) {
123-
Some(ScalarValue::Utf8(Some(value)))
124-
} else {
125-
None
126-
}
127-
}
128-
129-
/// Try to extract datetime object to TimestampNanosecond
130-
fn try_extract_datetime(py: Python, obj: &PyObject) -> Option<ScalarValue> {
131-
let datetime_module = py.import("datetime").ok()?;
132-
let datetime_class = datetime_module.getattr("datetime").ok()?;
133-
134-
if obj.is_instance(py, datetime_class).ok()? {
135-
// Extract timestamp as nanoseconds
136-
let timestamp = obj.call_method0(py, "timestamp").ok()?;
137-
let seconds_f64 = timestamp.extract::<f64>(py).ok()?;
138-
139-
// Convert seconds to nanoseconds
140-
let nanos = (seconds_f64 * 1_000_000_000.0) as i64;
141-
return Some(ScalarValue::TimestampNanosecond(Some(nanos), None));
115+
return Ok(ScalarValue::Utf8(Some(value)));
142116
}
143117

144-
None
145-
}
146-
147-
/// Try to extract date object to Date64
148-
fn try_extract_date(py: Python, obj: &PyObject) -> Option<ScalarValue> {
149-
let datetime_module = py.import("datetime").ok()?;
150-
let date_class = datetime_module.getattr("date").ok()?;
118+
// For datetime and other complex types, convert via PyArrow
119+
let pa = py.import("pyarrow")?;
151120

152-
let any = PyAny::from(obj);
153-
let py_any: Bound<_, PyAny> = Bound::new(py, any).ok()?;
154-
if any.is_instance_of(py, date_class).ok()? {
155-
// Check if it's actually a datetime (which also is an instance of date)
156-
let datetime_class = datetime_module.getattr("datetime").ok()?;
157-
if any.is_instance(py, datetime_class).ok()? {
158-
return None; // Let the datetime handler take care of it
159-
}
160-
161-
// Extract date components
162-
let year = obj.getattr(py, "year").ok()?.extract::<i32>(py).ok()?;
163-
let month = obj.getattr(py, "month").ok()?.extract::<u8>(py).ok()?;
164-
let day = obj.getattr(py, "day").ok()?.extract::<u8>(py).ok()?;
165-
166-
// Calculate milliseconds since epoch (1970-01-01)
167-
let millis = date_to_millis(year, month as u32, day as u32)?;
168-
return Some(ScalarValue::Date64(Some(millis)));
169-
}
121+
// Convert Python object to PyArrow scalar
122+
// This handles datetime objects by converting to PyArrow timestamp type
123+
let scalar = pa.call_method1("scalar", (obj,))?;
170124

171-
None
172-
}
173-
174-
/// Convert year, month, day to milliseconds since Unix epoch
175-
fn date_to_millis(year: i32, month: u32, day: u32) -> Option<i64> {
176-
// Validate inputs
177-
if month < 1 || month > 12 || day < 1 || day > 31 {
178-
return None;
179-
}
180-
181-
// Days in each month (non-leap year)
182-
let days_in_month = [0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
183-
184-
// Check if the day is valid for the given month
185-
let max_day = if month == 2 && is_leap_year(year) {
186-
29
187-
} else {
188-
days_in_month[month as usize]
189-
};
190-
191-
if day > max_day {
192-
return None;
193-
}
194-
195-
// Calculate days since epoch
196-
let mut total_days: i64 = 0;
197-
198-
// Handle years
199-
let year_diff = year - 1970;
200-
if year_diff >= 0 {
201-
// Years from 1970 to year-1
202-
for y in 1970..year {
203-
total_days += if is_leap_year(y) { 366 } else { 365 };
204-
}
205-
} else {
206-
// Years from year to 1969
207-
for y in year..1970 {
208-
total_days -= if is_leap_year(y) { 366 } else { 365 };
209-
}
210-
}
211-
212-
// Add days for the months in the current year
213-
for m in 1..month {
214-
total_days += if m == 2 && is_leap_year(year) {
215-
29
216-
} else {
217-
days_in_month[m as usize]
218-
} as i64;
219-
}
220-
221-
// Add days in the current month
222-
total_days += (day as i64) - 1;
223-
224-
// Convert days to milliseconds (86,400,000 ms per day)
225-
Some(total_days * 86_400_000)
226-
}
125+
// Convert PyArrow scalar to PyScalarValue
126+
let py_scalar = PyScalarValue::extract(scalar.as_ref())?;
227127

228-
/// Check if a year is a leap year
229-
fn is_leap_year(year: i32) -> bool {
230-
(year % 4 == 0 && year % 100 != 0) || (year % 400 == 0)
128+
// Convert PyScalarValue to ScalarValue
129+
Ok(py_scalar.into())
231130
}

0 commit comments

Comments
 (0)