Skip to content

Commit d546f7a

Browse files
committed
refactor: modify py_obj_to_scalar_value to return ScalarValue directly and streamline error handling
1 parent b5d87b0 commit d546f7a

File tree

2 files changed

+21
-21
lines changed

2 files changed

+21
-21
lines changed

src/dataframe.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,7 @@ impl PyDataFrame {
720720
columns: Option<Vec<PyBackedStr>>,
721721
py: Python,
722722
) -> PyDataFusionResult<Self> {
723-
let scalar_value = py_obj_to_scalar_value(py, value)?;
723+
let scalar_value = py_obj_to_scalar_value(py, value);
724724

725725
let cols = match columns {
726726
Some(col_names) => col_names.iter().map(|c| c.to_string()).collect(),

src/utils.rs

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -101,30 +101,30 @@ pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyRe
101101
///
102102
/// This function handles basic Python types directly and uses PyArrow
103103
/// for complex types like datetime.
104-
pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> PyResult<ScalarValue> {
105-
// Try basic types first for efficiency
104+
pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> ScalarValue {
106105
if let Ok(value) = obj.extract::<bool>(py) {
107-
return Ok(ScalarValue::Boolean(Some(value)));
106+
ScalarValue::Boolean(Some(value))
108107
} else if let Ok(value) = obj.extract::<i64>(py) {
109-
return Ok(ScalarValue::Int64(Some(value)));
108+
ScalarValue::Int64(Some(value))
110109
} else if let Ok(value) = obj.extract::<u64>(py) {
111-
return Ok(ScalarValue::UInt64(Some(value)));
110+
ScalarValue::UInt64(Some(value))
112111
} else if let Ok(value) = obj.extract::<f64>(py) {
113-
return Ok(ScalarValue::Float64(Some(value)));
112+
ScalarValue::Float64(Some(value))
114113
} else if let Ok(value) = obj.extract::<String>(py) {
115-
return Ok(ScalarValue::Utf8(Some(value)));
114+
ScalarValue::Utf8(Some(value))
115+
} else {
116+
// For datetime and other complex types, convert via PyArrow
117+
let pa = py.import("pyarrow");
118+
let pa = pa.expect("Failed to import PyArrow");
119+
// Convert Python object to PyArrow scalar
120+
// This handles datetime objects by converting to PyArrow timestamp type
121+
let scalar = pa.call_method1("scalar", (obj,));
122+
let scalar = scalar.expect("Failed to convert Python object to PyArrow scalar");
123+
// Convert PyArrow scalar to PyScalarValue
124+
let py_scalar = PyScalarValue::extract_bound(scalar.as_ref());
125+
// Unwrap the result - this will panic if extraction failed
126+
let py_scalar = py_scalar.expect("Failed to extract PyScalarValue from PyArrow scalar");
127+
// Convert PyScalarValue to ScalarValue
128+
py_scalar.into()
116129
}
117-
118-
// For datetime and other complex types, convert via PyArrow
119-
let pa = py.import("pyarrow")?;
120-
121-
// Convert Python object to PyArrow scalar
122-
// This handles datetime objects by converting to PyArrow timestamp type
123-
let scalar = pa.call_method1("scalar", (obj,))?;
124-
125-
// Convert PyArrow scalar to PyScalarValue
126-
let py_scalar = PyScalarValue::extract_bound(scalar.as_ref())?;
127-
128-
// Convert PyScalarValue to ScalarValue
129-
Ok(py_scalar.into())
130130
}

0 commit comments

Comments
 (0)