diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index c47e4a696a1d8..eebe626bad64b 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -55,6 +55,7 @@ pub mod repeat; pub mod replace; pub mod resize; pub mod reverse; +pub mod transform; pub mod set_ops; pub mod sort; pub mod string; @@ -104,6 +105,8 @@ pub mod expr_fn { pub use super::replace::array_replace_n; pub use super::resize::array_resize; pub use super::reverse::array_reverse; + pub use super::transform::array_transform; + pub use super::transform::array_reduce; pub use super::set_ops::array_distinct; pub use super::set_ops::array_intersect; pub use super::set_ops::array_union; @@ -143,6 +146,8 @@ pub fn all_default_nested_functions() -> Vec> { repeat::array_repeat_udf(), resize::array_resize_udf(), reverse::array_reverse_udf(), + transform::array_transform_udf(), + transform::array_reduce_udf(), set_ops::array_distinct_udf(), set_ops::array_intersect_udf(), set_ops::array_union_udf(), diff --git a/datafusion/functions-nested/src/transform.rs b/datafusion/functions-nested/src/transform.rs new file mode 100644 index 0000000000000..b2cd2b0478f45 --- /dev/null +++ b/datafusion/functions-nested/src/transform.rs @@ -0,0 +1,362 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_transform and array_reduce functions. + +use std::any::Any; +use std::sync::{Arc, OnceLock}; + +use arrow::array::{Int64Array, Int64Builder, ListArray}; +use arrow_array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait}; +use arrow_schema::{DataType, Field}; +use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; +use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; + +use crate::utils::make_scalar_function; + +make_udf_expr_and_func!( + ArrayTransform, + array_transform, + array func, + "applies a scalar function to each element of the array.", + array_transform_udf +); +make_udf_expr_and_func!( + ArrayReduce, + array_reduce, + array func, + "reduces array elements using the specified aggregate function.", + array_reduce_udf +); + +#[derive(Debug)] +pub struct ArrayTransform { + signature: Signature, + aliases: Vec, +} + +impl ArrayTransform { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + aliases: vec![String::from("list_transform")], + } + } +} + +impl Default for ArrayTransform { + fn default() -> Self { + Self::new() + } +} + +impl ScalarUDFImpl for ArrayTransform { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_transform" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_transform_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_transform_doc()) + } +} + +static DOC_TRANSFORM: OnceLock = OnceLock::new(); + +fn get_array_transform_doc() -> &'static Documentation { + DOC_TRANSFORM.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Applies a scalar function to each element of the array.", + ) + .with_syntax_example("array_transform(array, func)") + .with_sql_example( + r#"```sql +> select array_transform([1, -2, 3], 'abs'); ++---------------------------------------------+ +| array_transform(List([1,-2,3]),Utf8("abs")) | ++---------------------------------------------+ +| [1, 2, 3] | ++---------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "func", + "Name of a scalar function (e.g. 'abs').", + ) + .build() + .unwrap() + }) +} + +#[derive(Debug)] +pub struct ArrayReduce { + signature: Signature, + aliases: Vec, +} + +impl ArrayReduce { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + aliases: vec![String::from("list_reduce")], + } + } +} + +impl Default for ArrayReduce { + fn default() -> Self { + Self::new() + } +} + +impl ScalarUDFImpl for ArrayReduce { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_reduce" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_reduce_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_reduce_doc()) + } +} + +static DOC_REDUCE: OnceLock = OnceLock::new(); + +fn get_array_reduce_doc() -> &'static Documentation { + DOC_REDUCE.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Aggregates the array elements using the specified aggregate function.", + ) + .with_syntax_example("array_reduce(array, func)") + .with_sql_example( + r#"```sql +> select array_reduce([1,2,3], 'sum'); ++------------------------------------------+ +| array_reduce(List([1,2,3]),Utf8("sum")) | ++------------------------------------------+ +| 6 | ++------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "func", + "Name of an aggregate function (e.g. 'sum').", + ) + .build() + .unwrap() + }) +} + +fn array_transform_inner(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_transform expects two arguments"); + } + let func_name = ScalarValue::try_from_array(&args[1], 0)?; + let func_name = match func_name { + ScalarValue::Utf8(Some(s)) => s, + ScalarValue::LargeUtf8(Some(s)) => s, + _ => return exec_err!("function name must be a string"), + }; + match &args[0].data_type() { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + general_array_transform::(array, &func_name) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + general_array_transform::(array, &func_name) + } + dt => exec_err!("array_transform does not support type '{dt:?}'"), + } +} + +fn general_array_transform( + array: &GenericListArray, + func: &str, +) -> Result { + let values = array.values(); + let transformed_values = apply_scalar_function(values.clone(), func)?; + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new("item", transformed_values.data_type().clone(), true)), + array.offsets().clone(), + transformed_values, + array.nulls().cloned(), + )?)) +} + +fn apply_scalar_function(values: ArrayRef, func: &str) -> Result { + use datafusion_functions::math; + let udf = match func.to_ascii_lowercase().as_str() { + "abs" => math::abs(), + _ => return exec_err!("unsupported function '{func}'"), + }; + let res = udf.invoke(&[ColumnarValue::Array(values.clone())])?; + res.into_array(values.len()) +} + +fn array_reduce_inner(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_reduce expects two arguments"); + } + let func_name = ScalarValue::try_from_array(&args[1], 0)?; + let func_name = match func_name { + ScalarValue::Utf8(Some(s)) => s, + ScalarValue::LargeUtf8(Some(s)) => s, + _ => return exec_err!("function name must be a string"), + }; + match &args[0].data_type() { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + general_array_reduce::(array, &func_name) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + general_array_reduce::(array, &func_name) + } + dt => exec_err!("array_reduce does not support type '{dt:?}'"), + } +} + +fn general_array_reduce( + array: &GenericListArray, + func: &str, +) -> Result { + match func.to_ascii_lowercase().as_str() { + "sum" => array_reduce_sum(array), + _ => exec_err!("unsupported aggregate '{func}'"), + } +} + +fn array_reduce_sum(array: &GenericListArray) -> Result { + let values = as_int64_array(array.values())?; + let mut builder = Int64Builder::with_capacity(array.len()); + for (i, window) in array.offsets().windows(2).enumerate() { + if array.is_null(i) { + builder.append_null(); + continue; + } + let start = window[0].to_usize().unwrap(); + let end = window[1].to_usize().unwrap(); + let mut acc = 0i64; + let mut has_value = false; + for idx in start..end { + if values.is_valid(idx) { + acc += values.value(idx); + has_value = true; + } + } + if has_value { + builder.append_value(acc); + } else { + builder.append_null(); + } + } + Ok(Arc::new(builder.finish()) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::Int64Type; + + #[test] + fn test_array_transform_abs() -> Result<()> { + let array = Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(-1), Some(2), Some(-3)]), + ])); + let result = array_transform_udf().invoke(&[ + ColumnarValue::Array(array as ArrayRef), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("abs".to_string()))), + ])?; + let result = result.into_array(1)?; + let expected = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + ]); + assert_eq!(result.as_ref(), &expected); + Ok(()) + } + + #[test] + fn test_array_reduce_sum() -> Result<()> { + let array = Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + ])); + let result = array_reduce_udf().invoke(&[ + ColumnarValue::Array(array as ArrayRef), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("sum".to_string()))), + ])?; + let result = result.into_array(1)?; + let expected = Int64Array::from(vec![Some(6)]); + assert_eq!(result.as_ref(), &expected); + Ok(()) + } +} + diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index a379dfc9ec29a..dfd7b2851063e 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -2497,6 +2497,8 @@ _Alias of [current_date](#current_date)._ - [array_reverse](#array_reverse) - [array_slice](#array_slice) - [array_sort](#array_sort) +- [array_transform](#array_transform) +- [array_reduce](#array_reduce) - [array_to_string](#array_to_string) - [array_union](#array_union) - [cardinality](#cardinality) @@ -3464,6 +3466,54 @@ array_sort(array, desc, nulls_first) - list_sort +### `array_transform` + +Applies a scalar function to each element of the array. + +``` +array_transform(array, func) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **func**: Name of a scalar function (e.g. `abs`). + +#### Example + +```sql +> select array_transform([1, -2, 3], 'abs'); ++---------------------------------------------+ +| array_transform(List([1,-2,3]),Utf8("abs")) | ++---------------------------------------------+ +| [1, 2, 3] | ++---------------------------------------------+ +``` + +### `array_reduce` + +Aggregates the array elements using the specified aggregate function. + +``` +array_reduce(array, func) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **func**: Name of an aggregate function (e.g. `sum`). + +#### Example + +```sql +> select array_reduce([1, 2, 3], 'sum'); ++------------------------------------------+ +| array_reduce(List([1,2,3]),Utf8("sum")) | ++------------------------------------------+ +| 6 | ++------------------------------------------+ +``` + ### `array_to_string` Converts each element to its text representation.