diff --git a/datafusion/functions/src/math/iszero.rs b/datafusion/functions/src/math/iszero.rs index 6349551ca0a43..ba4afc5622eb3 100644 --- a/datafusion/functions/src/math/iszero.rs +++ b/datafusion/functions/src/math/iszero.rs @@ -18,12 +18,13 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, AsArray, BooleanArray}; -use arrow::datatypes::DataType::{Boolean, Float32, Float64}; -use arrow::datatypes::{DataType, Float32Type, Float64Type}; +use arrow::array::{ArrayRef, ArrowNativeTypeOp, AsArray, BooleanArray}; +use arrow::datatypes::DataType::{Boolean, Float16, Float32, Float64}; +use arrow::datatypes::{DataType, Float16Type, Float32Type, Float64Type}; -use datafusion_common::{Result, exec_err}; -use datafusion_expr::TypeSignature::Exact; +use datafusion_common::types::NativeType; +use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_expr::{Coercion, TypeSignatureClass}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -59,12 +60,14 @@ impl Default for IsZeroFunc { impl IsZeroFunc { pub fn new() -> Self { - use DataType::*; + // Accept any numeric type and coerce to float + let float = Coercion::new_implicit( + TypeSignatureClass::Float, + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ); Self { - signature: Signature::one_of( - vec![Exact(vec![Float32]), Exact(vec![Float64])], - Volatility::Immutable, - ), + signature: Signature::coercible(vec![float], Volatility::Immutable), } } } @@ -87,6 +90,10 @@ impl ScalarUDFImpl for IsZeroFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + // Handle NULL input + if args.args[0].data_type().is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); + } make_scalar_function(iszero, vec![])(&args.args) } @@ -108,6 +115,11 @@ fn iszero(args: &[ArrayRef]) -> Result { |x| x == 0.0, )) as ArrayRef), + Float16 => Ok(Arc::new(BooleanArray::from_unary( + args[0].as_primitive::(), + |x| x.is_zero(), + )) as ArrayRef), + other => exec_err!("Unsupported data type {other:?} for function iszero"), } } diff --git a/datafusion/functions/src/math/nans.rs b/datafusion/functions/src/math/nans.rs index be21cfde0aa6c..03f246c28be19 100644 --- a/datafusion/functions/src/math/nans.rs +++ b/datafusion/functions/src/math/nans.rs @@ -17,9 +17,10 @@ //! Math function: `isnan()`. -use arrow::datatypes::{DataType, Float32Type, Float64Type}; -use datafusion_common::{Result, exec_err}; -use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, TypeSignature}; +use arrow::datatypes::{DataType, Float16Type, Float32Type, Float64Type}; +use datafusion_common::types::NativeType; +use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_expr::{Coercion, ColumnarValue, ScalarFunctionArgs, TypeSignatureClass}; use arrow::array::{ArrayRef, AsArray, BooleanArray}; use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility}; @@ -54,15 +55,14 @@ impl Default for IsNanFunc { impl IsNanFunc { pub fn new() -> Self { - use DataType::*; + // Accept any numeric type and coerce to float + let float = Coercion::new_implicit( + TypeSignatureClass::Float, + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ); Self { - signature: Signature::one_of( - vec![ - TypeSignature::Exact(vec![Float32]), - TypeSignature::Exact(vec![Float64]), - ], - Volatility::Immutable, - ), + signature: Signature::coercible(vec![float], Volatility::Immutable), } } } @@ -84,6 +84,11 @@ impl ScalarUDFImpl for IsNanFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + // Handle NULL input + if args.args[0].data_type().is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); + } + let args = ColumnarValue::values_to_arrays(&args.args)?; let arr: ArrayRef = match args[0].data_type() { @@ -96,6 +101,11 @@ impl ScalarUDFImpl for IsNanFunc { args[0].as_primitive::(), f32::is_nan, )) as ArrayRef, + + DataType::Float16 => Arc::new(BooleanArray::from_unary( + args[0].as_primitive::(), + |x| x.is_nan(), + )) as ArrayRef, other => { return exec_err!( "Unsupported data type {other:?} for function {}", diff --git a/datafusion/functions/src/math/nanvl.rs b/datafusion/functions/src/math/nanvl.rs index 345b1a5b71aef..6daf476e250d3 100644 --- a/datafusion/functions/src/math/nanvl.rs +++ b/datafusion/functions/src/math/nanvl.rs @@ -20,9 +20,9 @@ use std::sync::Arc; use crate::utils::make_scalar_function; -use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array}; -use arrow::datatypes::DataType::{Float32, Float64}; -use arrow::datatypes::{DataType, Float32Type, Float64Type}; +use arrow::array::{ArrayRef, AsArray, Float16Array, Float32Array, Float64Array}; +use arrow::datatypes::DataType::{Float16, Float32, Float64}; +use arrow::datatypes::{DataType, Float16Type, Float32Type, Float64Type}; use datafusion_common::{DataFusionError, Result, exec_err}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ @@ -66,10 +66,13 @@ impl Default for NanvlFunc { impl NanvlFunc { pub fn new() -> Self { - use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], + vec![ + Exact(vec![Float16, Float16]), + Exact(vec![Float32, Float32]), + Exact(vec![Float64, Float64]), + ], Volatility::Immutable, ), } @@ -91,6 +94,7 @@ impl ScalarUDFImpl for NanvlFunc { fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[0] { + Float16 => Ok(Float16), Float32 => Ok(Float32), _ => Ok(Float64), } @@ -130,6 +134,19 @@ fn nanvl(args: &[ArrayRef]) -> Result { .map(|res| Arc::new(res) as _) .map_err(DataFusionError::from) } + Float16 => { + let compute_nanvl = + |x: ::Native, + y: ::Native| { + if x.is_nan() { y } else { x } + }; + + let x = args[0].as_primitive() as &Float16Array; + let y = args[1].as_primitive() as &Float16Array; + arrow::compute::binary::<_, _, _, Float16Type>(x, y, compute_nanvl) + .map(|res| Arc::new(res) as _) + .map_err(DataFusionError::from) + } other => exec_err!("Unsupported data type {other:?} for function nanvl"), } }