
Commit 24549bd

Implement metadata-aware PySimpleScalarUDF
Enhance scalar UDF definitions to retain Arrow Field information, including extension metadata, in DataFusion. Normalize Python UDF signatures to accept pyarrow.Field objects, ensuring metadata survives the roundtrip through the Rust bindings. Add a regression test that chains two UUID-backed UDFs and verifies the second UDF receives a pyarrow.ExtensionArray, guarding against the metadata loss seen previously.
1 parent 3f0338b commit 24549bd
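
To see the new surface at a glance, here is a minimal sketch (assuming a SessionContext named `ctx`; it mirrors the regression test below). Declaring a UDF with pyarrow.Field objects instead of bare DataTypes is what keeps extension metadata attached through the bindings:

    import pyarrow as pa
    from datafusion import SessionContext, udf

    ctx = SessionContext()  # assumed context, not part of this commit

    # Fields (not bare DataTypes) carry the "ARROW:extension:name"
    # metadata that identifies pa.uuid() as an extension type.
    uuid_field = pa.field("uuid_col", pa.uuid())

    uuid_identity = udf(
        lambda values: values,  # receives a pyarrow.ExtensionArray
        [uuid_field],           # input fields, metadata included
        uuid_field,             # return field, metadata included
        volatility="immutable",
        name="uuid_identity",
    )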

File tree (3 files changed: +185 −27 lines changed)

python/datafusion/user_defined.py
python/tests/test_udf.py
src/udf.rs

python/datafusion/user_defined.py

Lines changed: 51 additions & 20 deletions
@@ -22,17 +22,13 @@
 import functools
 from abc import ABCMeta, abstractmethod
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, TypeVar, overload
+from typing import Any, Callable, Optional, Protocol, Sequence, overload

 import pyarrow as pa

 import datafusion._internal as df_internal
 from datafusion.expr import Expr

-if TYPE_CHECKING:
-    _R = TypeVar("_R", bound=pa.DataType)
-
-
 class Volatility(Enum):
     """Defines how stable or volatile a function is.

@@ -77,6 +73,40 @@ def __str__(self) -> str:
         return self.name.lower()


+def _normalize_field(value: pa.DataType | pa.Field, *, default_name: str) -> pa.Field:
+    if isinstance(value, pa.Field):
+        return value
+    if isinstance(value, pa.DataType):
+        return pa.field(default_name, value)
+    msg = "Expected a pyarrow.DataType or pyarrow.Field"
+    raise TypeError(msg)
+
+
+def _normalize_input_fields(
+    values: pa.DataType | pa.Field | Sequence[pa.DataType | pa.Field],
+) -> list[pa.Field]:
+    if isinstance(values, (pa.DataType, pa.Field)):
+        sequence: Sequence[pa.DataType | pa.Field] = [values]
+    elif isinstance(values, Sequence) and not isinstance(values, (str, bytes)):
+        sequence = values
+    else:
+        msg = "input_types must be a DataType, Field, or a sequence of them"
+        raise TypeError(msg)
+
+    return [
+        _normalize_field(value, default_name=f"arg_{idx}") for idx, value in enumerate(sequence)
+    ]
+
+
+def _normalize_return_field(
+    value: pa.DataType | pa.Field,
+    *,
+    name: str,
+) -> pa.Field:
+    default_name = f"{name}_result" if name else "result"
+    return _normalize_field(value, default_name=default_name)
+
+
 class ScalarUDFExportable(Protocol):
     """Type hint for object that has __datafusion_scalar_udf__ PyCapsule."""

@@ -93,9 +123,9 @@ class ScalarUDF:
     def __init__(
         self,
         name: str,
-        func: Callable[..., _R],
-        input_types: pa.DataType | list[pa.DataType],
-        return_type: _R,
+        func: Callable[..., Any],
+        input_types: pa.DataType | pa.Field | Sequence[pa.DataType | pa.Field],
+        return_type: pa.DataType | pa.Field,
         volatility: Volatility | str,
     ) -> None:
         """Instantiate a scalar user-defined function (UDF).
@@ -105,10 +135,10 @@ def __init__(
         if hasattr(func, "__datafusion_scalar_udf__"):
             self._udf = df_internal.ScalarUDF.from_pycapsule(func)
             return
-        if isinstance(input_types, pa.DataType):
-            input_types = [input_types]
+        normalized_inputs = _normalize_input_fields(input_types)
+        normalized_return = _normalize_return_field(return_type, name=name)
         self._udf = df_internal.ScalarUDF(
-            name, func, input_types, return_type, str(volatility)
+            name, func, normalized_inputs, normalized_return, str(volatility)
         )

     def __repr__(self) -> str:
@@ -127,18 +157,18 @@ def __call__(self, *args: Expr) -> Expr:
     @overload
     @staticmethod
     def udf(
-        input_types: list[pa.DataType],
-        return_type: _R,
+        input_types: list[pa.DataType | pa.Field],
+        return_type: pa.DataType | pa.Field,
         volatility: Volatility | str,
         name: Optional[str] = None,
     ) -> Callable[..., ScalarUDF]: ...

     @overload
     @staticmethod
     def udf(
-        func: Callable[..., _R],
-        input_types: list[pa.DataType],
-        return_type: _R,
+        func: Callable[..., Any],
+        input_types: list[pa.DataType | pa.Field],
+        return_type: pa.DataType | pa.Field,
         volatility: Volatility | str,
         name: Optional[str] = None,
     ) -> ScalarUDF: ...
@@ -164,10 +194,11 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417
                 backed ScalarUDF within a PyCapsule, you can pass this parameter
                 and ignore the rest. They will be determined directly from the
                 underlying function. See the online documentation for more information.
-            input_types (list[pa.DataType]): The data types of the arguments
-                to ``func``. This list must be of the same length as the number of
-                arguments.
-            return_type (_R): The data type of the return value from the function.
+            input_types (list[pa.DataType | pa.Field]): The argument types for ``func``.
+                This list must be of the same length as the number of arguments. Pass
+                :class:`pyarrow.Field` instances to preserve extension metadata.
+            return_type (pa.DataType | pa.Field): The return type of the function. Use a
+                :class:`pyarrow.Field` to preserve metadata on extension arrays.
             volatility (Volatility | str): See `Volatility` for allowed values.
             name (Optional[str]): A descriptive name for the function.

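To make the normalization rules concrete, a short sketch of how the helpers above behave (calling the private helpers directly is for illustration only; ScalarUDF.__init__ invokes them internally):

    import pyarrow as pa
    from datafusion.user_defined import _normalize_input_fields, _normalize_return_field

    # A bare DataType is wrapped in a positional placeholder field.
    _normalize_input_fields(pa.int64())
    # -> [pyarrow.Field<arg_0: int64>]

    # An explicit Field, along with any metadata on it, passes through unchanged.
    _normalize_input_fields([pa.field("id", pa.uuid())])

    # A bare DataType return value is named after the UDF; a Field keeps its name.
    _normalize_return_field(pa.float64(), name="my_udf")
    # -> pyarrow.Field<my_udf_result: double>
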
python/tests/test_udf.py

Lines changed: 55 additions & 0 deletions
@@ -124,3 +124,58 @@ def udf_with_param(values: pa.Array) -> pa.Array:
     result = df2.collect()[0].column(0)

     assert result == pa.array([False, True, True])
+
+
+def test_uuid_extension_chain(ctx) -> None:
+    uuid_type = pa.uuid()
+    uuid_field = pa.field("uuid_col", uuid_type)
+
+    first = udf(
+        lambda values: values,
+        [uuid_field],
+        uuid_field,
+        volatility="immutable",
+        name="uuid_identity",
+    )
+
+    def ensure_extension(values: pa.Array) -> pa.Array:
+        assert isinstance(values, pa.ExtensionArray)
+        return values
+
+    second = udf(
+        ensure_extension,
+        [uuid_field],
+        uuid_field,
+        volatility="immutable",
+        name="uuid_assert",
+    )
+
+    batch = pa.RecordBatch.from_arrays(
+        [
+            pa.array(
+                [
+                    "00000000-0000-0000-0000-000000000000",
+                    "00000000-0000-0000-0000-000000000001",
+                ],
+                type=uuid_type,
+            )
+        ],
+        names=["uuid_col"],
+    )
+
+    df = ctx.create_dataframe([[batch]])
+    result = (
+        df.select(second(first(column("uuid_col"))))
+        .collect()[0]
+        .column(0)
+    )
+
+    expected = pa.array(
+        [
+            "00000000-0000-0000-0000-000000000000",
+            "00000000-0000-0000-0000-000000000001",
+        ],
+        type=uuid_type,
+    )
+
+    assert result.equals(expected)

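Background for the test's key assertion (a pyarrow fact rather than something this commit adds): pa.uuid() is a canonical extension type stored as fixed_size_binary(16). If the first UDF's result field dropped its metadata, the chained UDF would see only the storage array, and the isinstance check would fail:

    import pyarrow as pa

    arr = pa.array(["00000000-0000-0000-0000-000000000000"], type=pa.uuid())
    assert isinstance(arr, pa.ExtensionArray)                # metadata intact
    assert isinstance(arr.storage, pa.FixedSizeBinaryArray)  # the metadata-less fallback
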
src/udf.rs

Lines changed: 79 additions & 7 deletions
@@ -22,13 +22,16 @@ use pyo3::types::PyCapsule;
 use pyo3::{prelude::*, types::PyTuple};

 use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef};
-use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::datatypes::{DataType, Field};
 use datafusion::arrow::pyarrow::FromPyArrow;
 use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
 use datafusion::error::DataFusionError;
 use datafusion::logical_expr::function::ScalarFunctionImplementation;
-use datafusion::logical_expr::ScalarUDF;
-use datafusion::logical_expr::{create_udf, ColumnarValue};
+use datafusion::logical_expr::ptr_eq::PtrEq;
+use datafusion::logical_expr::{
+    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
+    Volatility,
+};

 use crate::errors::to_datafusion_err;
 use crate::errors::{py_datafusion_err, PyDataFusionResult};
@@ -80,6 +83,73 @@ fn to_scalar_function_impl(func: PyObject) -> ScalarFunctionImplementation {
     })
 }

+#[derive(Debug, PartialEq, Eq, Hash)]
+struct PySimpleScalarUDF {
+    name: String,
+    signature: Signature,
+    return_field: Arc<Field>,
+    fun: PtrEq<ScalarFunctionImplementation>,
+}
+
+impl PySimpleScalarUDF {
+    fn new(
+        name: impl Into<String>,
+        input_fields: Vec<Field>,
+        return_field: Field,
+        volatility: Volatility,
+        fun: ScalarFunctionImplementation,
+    ) -> Self {
+        let signature_types = input_fields
+            .into_iter()
+            .map(|field| field.data_type().clone())
+            .collect();
+        let signature = Signature::exact(signature_types, volatility);
+        Self {
+            name: name.into(),
+            signature,
+            return_field: Arc::new(return_field),
+            fun: fun.into(),
+        }
+    }
+}
+
+impl ScalarUDFImpl for PySimpleScalarUDF {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result<DataType> {
+        Ok(self.return_field.data_type().clone())
+    }
+
+    fn return_field_from_args(
+        &self,
+        _args: ReturnFieldArgs,
+    ) -> datafusion::error::Result<Arc<Field>> {
+        Ok(Arc::new(
+            self.return_field
+                .as_ref()
+                .clone()
+                .with_name(self.name.clone()),
+        ))
+    }
+
+    fn invoke_with_args(
+        &self,
+        args: ScalarFunctionArgs,
+    ) -> datafusion::error::Result<ColumnarValue> {
+        (self.fun)(&args.args)
+    }
+}
+
 /// Represents a PyScalarUDF
 #[pyclass(frozen, name = "ScalarUDF", module = "datafusion", subclass)]
 #[derive(Debug, Clone)]
@@ -94,17 +164,19 @@ impl PyScalarUDF {
     fn new(
         name: &str,
         func: PyObject,
-        input_types: PyArrowType<Vec<DataType>>,
-        return_type: PyArrowType<DataType>,
+        input_types: PyArrowType<Vec<Field>>,
+        return_type: PyArrowType<Field>,
         volatility: &str,
     ) -> PyResult<Self> {
-        let function = create_udf(
+        let volatility = parse_volatility(volatility)?;
+        let scalar_impl = PySimpleScalarUDF::new(
            name,
            input_types.0,
            return_type.0,
-            parse_volatility(volatility)?,
+            volatility,
            to_scalar_function_impl(func),
        );
+        let function = ScalarUDF::new_from_impl(scalar_impl);
         Ok(Self { function })
     }

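One design consequence worth noting from the sketch above: Signature::exact is built solely from each input field's DataType, so argument matching still happens at the storage/data-type level, while the retained Fields are what carry the extension identity into return_field_from_args. A pyarrow-only illustration of that split (no DataFusion session required):

    import pyarrow as pa

    uuid_field = pa.field("uuid_col", pa.uuid())

    # What Signature::exact keys on: the data type of each input field.
    print(uuid_field.type)               # extension<arrow.uuid>

    # The storage-level layout underneath the extension type.
    print(uuid_field.type.storage_type)  # fixed_size_binary[16]
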