diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs index 97703937f39f2..101291ac5f66e 100644 --- a/datafusion/spark/src/function/math/abs.rs +++ b/datafusion/spark/src/function/math/abs.rs @@ -16,10 +16,11 @@ // under the License. use arrow::array::*; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{DataFusionError, Result, ScalarValue, internal_err}; use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_functions::{ downcast_named_arg, make_abs_function, make_wrapping_abs_function, @@ -69,8 +70,18 @@ impl ScalarUDFImpl for SparkAbs { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!( + "SparkAbs: return_type() is not used; return_field_from_args() is implemented" + ) + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let input_field = &args.arg_fields[0]; + let out_dt = input_field.data_type().clone(); + let out_nullable = input_field.is_nullable(); + + Ok(Arc::new(Field::new(self.name(), out_dt, out_nullable))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -375,4 +386,63 @@ mod tests { as_decimal256_array ); } + + #[test] + fn test_abs_nullability() { + use arrow::datatypes::{DataType, Field}; + use datafusion_expr::ReturnFieldArgs; + use std::sync::Arc; + + let abs = SparkAbs::new(); + + // --- non-nullable Int32 input --- + let non_nullable_i32 = Arc::new(Field::new("c", DataType::Int32, false)); + let out_non_null = abs + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&non_nullable_i32)], + scalar_arguments: &[None], + }) + .unwrap(); + + // result should be non-nullable and the same DataType as input + assert!(!out_non_null.is_nullable()); + assert_eq!(out_non_null.data_type(), &DataType::Int32); + + // --- nullable Int32 input --- + let nullable_i32 = Arc::new(Field::new("c", DataType::Int32, true)); + let out_nullable = abs + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&nullable_i32)], + scalar_arguments: &[None], + }) + .unwrap(); + + // result should be nullable and the same DataType as input + assert!(out_nullable.is_nullable()); + assert_eq!(out_nullable.data_type(), &DataType::Int32); + + // --- non-nullable Float64 input --- + let non_nullable_f64 = Arc::new(Field::new("c", DataType::Float64, false)); + let out_f64 = abs + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&non_nullable_f64)], + scalar_arguments: &[None], + }) + .unwrap(); + + assert!(!out_f64.is_nullable()); + assert_eq!(out_f64.data_type(), &DataType::Float64); + + // --- nullable Float64 input --- + let nullable_f64 = Arc::new(Field::new("c", DataType::Float64, true)); + let out_f64_null = abs + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&nullable_f64)], + scalar_arguments: &[None], + }) + .unwrap(); + + assert!(out_f64_null.is_nullable()); + assert_eq!(out_f64_null.data_type(), &DataType::Float64); + } }