Skip to content

Commit 72df613

Browse files
committed
Update naming from type to field where appropriate
1 parent ecaf49b commit 72df613

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

python/datafusion/user_defined.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ def __init__(
123123
self,
124124
name: str,
125125
func: Callable[..., _R],
126-
input_types: list[pa.Field],
127-
return_type: _R,
126+
input_fields: list[pa.Field],
127+
return_field: _R,
128128
volatility: Volatility | str,
129129
) -> None:
130130
"""Instantiate a scalar user-defined function (UDF).
@@ -134,10 +134,10 @@ def __init__(
134134
if hasattr(func, "__datafusion_scalar_udf__"):
135135
self._udf = df_internal.ScalarUDF.from_pycapsule(func)
136136
return
137-
if isinstance(input_types, pa.DataType):
138-
input_types = [input_types]
137+
if isinstance(input_fields, pa.DataType):
138+
input_fields = [input_fields]
139139
self._udf = df_internal.ScalarUDF(
140-
name, func, input_types, return_type, str(volatility)
140+
name, func, input_fields, return_field, str(volatility)
141141
)
142142

143143
def __repr__(self) -> str:
@@ -156,8 +156,8 @@ def __call__(self, *args: Expr) -> Expr:
156156
@overload
157157
@staticmethod
158158
def udf(
159-
input_types: Sequence[pa.DataType | pa.Field] | pa.DataType | pa.Field,
160-
return_type: pa.DataType | pa.Field,
159+
input_fields: Sequence[pa.DataType | pa.Field] | pa.DataType | pa.Field,
160+
return_field: pa.DataType | pa.Field,
161161
volatility: Volatility | str,
162162
name: str | None = None,
163163
) -> Callable[..., ScalarUDF]: ...
@@ -166,8 +166,8 @@ def udf(
166166
@staticmethod
167167
def udf(
168168
func: Callable[..., _R],
169-
input_types: Sequence[pa.DataType | pa.Field] | pa.DataType | pa.Field,
170-
return_type: pa.DataType | pa.Field,
169+
input_fields: Sequence[pa.DataType | pa.Field] | pa.DataType | pa.Field,
170+
return_field: pa.DataType | pa.Field,
171171
volatility: Volatility | str,
172172
name: str | None = None,
173173
) -> ScalarUDF: ...
@@ -193,10 +193,10 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417
193193
backed ScalarUDF within a PyCapsule, you can pass this parameter
194194
and ignore the rest. They will be determined directly from the
195195
underlying function. See the online documentation for more information.
196-
input_types (list[pa.DataType]): The data types of the arguments
197-
to ``func``. This list must be of the same length as the number of
198-
arguments.
199-
return_type (_R): The data type of the return value from the function.
196+
input_fields (list[pa.Field | pa.DataType]): The data types or Fields
197+
of the arguments to ``func``. This list must be of the same length
198+
as the number of arguments.
199+
return_field (_R): The field of the return value from the function.
200200
volatility (Volatility | str): See `Volatility` for allowed values.
201201
name (Optional[str]): A descriptive name for the function.
202202
@@ -220,8 +220,8 @@ def double_udf(x):
220220

221221
def _function(
222222
func: Callable[..., _R],
223-
input_types: Sequence[pa.DataType | pa.Field] | pa.DataType | pa.Field,
224-
return_type: pa.DataType | pa.Field,
223+
input_fields: Sequence[pa.DataType | pa.Field] | pa.DataType | pa.Field,
224+
return_field: pa.DataType | pa.Field,
225225
volatility: Volatility | str,
226226
name: str | None = None,
227227
) -> ScalarUDF:
@@ -233,25 +233,25 @@ def _function(
233233
name = func.__qualname__.lower()
234234
else:
235235
name = func.__class__.__name__.lower()
236-
input_types = data_types_or_fields_to_field_list(input_types)
237-
return_type = data_type_or_field_to_field(return_type, "value")
236+
input_fields = data_types_or_fields_to_field_list(input_fields)
237+
return_field = data_type_or_field_to_field(return_field, "value")
238238
return ScalarUDF(
239239
name=name,
240240
func=func,
241-
input_types=input_types,
242-
return_type=return_type,
241+
input_fields=input_fields,
242+
return_field=return_field,
243243
volatility=volatility,
244244
)
245245

246246
def _decorator(
247-
input_types: list[pa.DataType],
248-
return_type: _R,
247+
input_fields: list[pa.DataType],
248+
return_field: _R,
249249
volatility: Volatility | str,
250250
name: str | None = None,
251251
) -> Callable:
252252
def decorator(func: Callable) -> Callable:
253253
udf_caller = ScalarUDF.udf(
254-
func, input_types, return_type, volatility, name
254+
func, input_fields, return_field, volatility, name
255255
)
256256

257257
@functools.wraps(func)
@@ -282,8 +282,8 @@ def from_pycapsule(func: ScalarUDFExportable) -> ScalarUDF:
282282
return ScalarUDF(
283283
name=name,
284284
func=func,
285-
input_types=None,
286-
return_type=None,
285+
input_fields=None,
286+
return_field=None,
287287
volatility=None,
288288
)
289289

0 commit comments

Comments
 (0)