Skip to content

Commit 0dfbdfa

Browse files
committed
refactor: improve type handling in python_value_to_scalar_value function
1 parent aa87a8e commit 0dfbdfa

File tree

1 file changed

+31
-13
lines changed

1 file changed

+31
-13
lines changed

src/dataframe.rs

Lines changed: 31 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ use arrow::ffi_stream::FFI_ArrowArrayStream;
2626
use datafusion::arrow::datatypes::Schema;
2727
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
2828
use datafusion::arrow::util::pretty;
29-
use datafusion::common::UnnestOptions;
29+
use datafusion::common::{ScalarValue, UnnestOptions};
3030
use datafusion::config::{CsvOptions, TableParquetOptions};
3131
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
3232
use datafusion::datasource::TableProvider;
@@ -714,7 +714,7 @@ impl PyDataFrame {
714714
#[pyo3(signature = (value, columns=None))]
715715
fn fill_null(
716716
&self,
717-
value: PyObject,
717+
value: PyAny,
718718
columns: Option<Vec<PyBackedStr>>,
719719
py: Python,
720720
) -> PyDataFusionResult<Self> {
@@ -890,6 +890,9 @@ fn python_value_to_scalar_value(value: &PyObject, py: Python) -> PyDataFusionRes
890890
return Err(PyDataFusionError::Common(msg.to_string()));
891891
}
892892

893+
// Convert PyObject to PyAny for easier extraction
894+
let py_any: &PyAny = value.as_ref(py);
895+
893896
// Try extracting different types in sequence
894897
if let Some(scalar) = try_extract_numeric(value, py) {
895898
return Ok(scalar);
@@ -951,13 +954,15 @@ fn try_extract_numeric(value: &PyObject, py: Python) -> Option<ScalarValue> {
951954

952955
/// Try to extract datetime from a Python object
953956
fn try_extract_datetime(value: &PyObject, py: Python) -> Option<ScalarValue> {
954-
let datetime_result = py
957+
let datetime_cls = py
955958
.import("datetime")
956959
.and_then(|m| m.getattr("datetime"))
957960
.ok()?;
958961

959-
if value.is_instance(datetime_result).ok()? {
960-
let dt = value.cast_as::<pyo3::types::PyDateTime>(py).ok()?;
962+
let any: PyAny = value.extract(py).ok()?;
963+
964+
if any.is_instance(datetime_cls).ok()? {
965+
let dt = any.cast_as::<pyo3::types::PyDateTime>(py).ok()?;
961966

962967
// Extract datetime components
963968
let year = dt.get_year() as i32;
@@ -978,17 +983,30 @@ fn try_extract_datetime(value: &PyObject, py: Python) -> Option<ScalarValue> {
978983

979984
/// Try to extract date from a Python object
980985
fn try_extract_date(value: &PyObject, py: Python) -> Option<ScalarValue> {
981-
let date_result = py.import("datetime").and_then(|m| m.getattr("date")).ok()?;
986+
// Import datetime module once
987+
let datetime_mod = py.import("datetime").ok()?;
988+
let date_cls = datetime_mod.getattr("date").ok()?;
989+
let datetime_cls = datetime_mod.getattr("datetime").ok()?;
990+
991+
// Extract the PyObject as a PyAny so the instance checks below can be called on it
992+
let any: PyAny = value.extract(py).ok()?;
993+
994+
// Is it a date?
995+
if any.is_instance(date_cls).ok()? {
996+
// Python's datetime is a subclass of date, so reject datetimes here; they are expected to be handled by try_extract_datetime
997+
if any.is_instance(datetime_cls).ok()? {
998+
return None;
999+
}
9821000

983-
if value.is_instance(date_result).ok()? {
984-
let date = value.cast_as::<pyo3::types::PyDate>(py).ok()?;
1001+
// Downcast into the PyDate type
1002+
let dt: &PyDate = any.downcast().ok()?;
9851003

986-
// Extract date components
987-
let year = date.get_year() as i32;
988-
let month = date.get_month() as u8;
989-
let day = date.get_day() as u8;
1004+
// Pull out year/month/day
1005+
let year = dt.get_year() as i32;
1006+
let month = dt.get_month() as u8;
1007+
let day = dt.get_day() as u8;
9901008

991-
// Convert to days since epoch
1009+
// Convert to days since epoch, the representation stored in ScalarValue::Date32
9921010
let days = date_to_days_since_epoch(year, month, day).ok()?;
9931011
return Some(ScalarValue::Date32(Some(days)));
9941012
}

0 commit comments

Comments
 (0)