@@ -123,8 +123,8 @@ def __init__(
123123 self ,
124124 name : str ,
125125 func : Callable [..., _R ],
126- input_types : list [pa .Field ],
127- return_type : _R ,
126+ input_fields : list [pa .Field ],
127+ return_field : _R ,
128128 volatility : Volatility | str ,
129129 ) -> None :
130130 """Instantiate a scalar user-defined function (UDF).
@@ -134,10 +134,10 @@ def __init__(
134134 if hasattr (func , "__datafusion_scalar_udf__" ):
135135 self ._udf = df_internal .ScalarUDF .from_pycapsule (func )
136136 return
137- if isinstance (input_types , pa .DataType ):
138- input_types = [input_types ]
137+ if isinstance (input_fields , pa .DataType ):
138+ input_fields = [input_fields ]
139139 self ._udf = df_internal .ScalarUDF (
140- name , func , input_types , return_type , str (volatility )
140+ name , func , input_fields , return_field , str (volatility )
141141 )
142142
143143 def __repr__ (self ) -> str :
@@ -156,8 +156,8 @@ def __call__(self, *args: Expr) -> Expr:
156156 @overload
157157 @staticmethod
158158 def udf (
159- input_types : Sequence [pa .DataType | pa .Field ] | pa .DataType | pa .Field ,
160- return_type : pa .DataType | pa .Field ,
159+ input_fields : Sequence [pa .DataType | pa .Field ] | pa .DataType | pa .Field ,
160+ return_field : pa .DataType | pa .Field ,
161161 volatility : Volatility | str ,
162162 name : str | None = None ,
163163 ) -> Callable [..., ScalarUDF ]: ...
@@ -166,8 +166,8 @@ def udf(
166166 @staticmethod
167167 def udf (
168168 func : Callable [..., _R ],
169- input_types : Sequence [pa .DataType | pa .Field ] | pa .DataType | pa .Field ,
170- return_type : pa .DataType | pa .Field ,
169+ input_fields : Sequence [pa .DataType | pa .Field ] | pa .DataType | pa .Field ,
170+ return_field : pa .DataType | pa .Field ,
171171 volatility : Volatility | str ,
172172 name : str | None = None ,
173173 ) -> ScalarUDF : ...
@@ -193,10 +193,10 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417
193193 backed ScalarUDF within a PyCapsule, you can pass this parameter
194194 and ignore the rest. They will be determined directly from the
195195 underlying function. See the online documentation for more information.
196- input_types (list[pa.DataType]): The data types of the arguments
197- to ``func``. This list must be of the same length as the number of
198- arguments.
199- return_type (_R): The data type of the return value from the function.
196+ input_fields (list[pa.Field | pa. DataType]): The data types or Fields
197+ of the arguments to ``func``. This list must be of the same length
198+ as the number of arguments.
199+ return_field (_R): The field of the return value from the function.
200200 volatility (Volatility | str): See `Volatility` for allowed values.
201201 name (Optional[str]): A descriptive name for the function.
202202
@@ -220,8 +220,8 @@ def double_udf(x):
220220
221221 def _function (
222222 func : Callable [..., _R ],
223- input_types : Sequence [pa .DataType | pa .Field ] | pa .DataType | pa .Field ,
224- return_type : pa .DataType | pa .Field ,
223+ input_fields : Sequence [pa .DataType | pa .Field ] | pa .DataType | pa .Field ,
224+ return_field : pa .DataType | pa .Field ,
225225 volatility : Volatility | str ,
226226 name : str | None = None ,
227227 ) -> ScalarUDF :
@@ -233,25 +233,25 @@ def _function(
233233 name = func .__qualname__ .lower ()
234234 else :
235235 name = func .__class__ .__name__ .lower ()
236- input_types = data_types_or_fields_to_field_list (input_types )
237- return_type = data_type_or_field_to_field (return_type , "value" )
236+ input_fields = data_types_or_fields_to_field_list (input_fields )
237+ return_field = data_type_or_field_to_field (return_field , "value" )
238238 return ScalarUDF (
239239 name = name ,
240240 func = func ,
241- input_types = input_types ,
242- return_type = return_type ,
241+ input_fields = input_fields ,
242+ return_field = return_field ,
243243 volatility = volatility ,
244244 )
245245
246246 def _decorator (
247- input_types : list [pa .DataType ],
248- return_type : _R ,
247+ input_fields : list [pa .DataType ],
248+ return_field : _R ,
249249 volatility : Volatility | str ,
250250 name : str | None = None ,
251251 ) -> Callable :
252252 def decorator (func : Callable ) -> Callable :
253253 udf_caller = ScalarUDF .udf (
254- func , input_types , return_type , volatility , name
254+ func , input_fields , return_field , volatility , name
255255 )
256256
257257 @functools .wraps (func )
@@ -282,8 +282,8 @@ def from_pycapsule(func: ScalarUDFExportable) -> ScalarUDF:
282282 return ScalarUDF (
283283 name = name ,
284284 func = func ,
285- input_types = None ,
286- return_type = None ,
285+ input_fields = None ,
286+ return_field = None ,
287287 volatility = None ,
288288 )
289289
0 commit comments