diff --git a/datafusion/functions-nested/src/repeat.rs b/datafusion/functions-nested/src/repeat.rs index 28ec827cc5a0..5e78a4d0f601 100644 --- a/datafusion/functions-nested/src/repeat.rs +++ b/datafusion/functions-nested/src/repeat.rs @@ -19,21 +19,23 @@ use crate::utils::make_scalar_function; use arrow::array::{ - Array, ArrayRef, BooleanBufferBuilder, GenericListArray, OffsetSizeTrait, UInt64Array, + Array, ArrayRef, BooleanBufferBuilder, GenericListArray, Int64Array, OffsetSizeTrait, + UInt64Array, }; use arrow::buffer::{NullBuffer, OffsetBuffer}; use arrow::compute; -use arrow::compute::cast; use arrow::datatypes::DataType; use arrow::datatypes::{ DataType::{LargeList, List}, Field, }; -use datafusion_common::cast::{as_large_list_array, as_list_array, as_uint64_array}; -use datafusion_common::{Result, exec_err, utils::take_function_args}; +use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; +use datafusion_common::types::{NativeType, logical_int64}; +use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -88,7 +90,17 @@ impl Default for ArrayRepeat { impl ArrayRepeat { pub fn new() -> Self { Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Any), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Integer], + NativeType::Int64, + ), + ], + Volatility::Immutable, + ), aliases: vec![String::from("list_repeat")], } } @@ -132,23 +144,6 @@ impl ScalarUDFImpl for ArrayRepeat { &self.aliases } - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - let [first_type, second_type] = take_function_args(self.name(), arg_types)?; - - // Coerce the second 
argument to Int64/UInt64 if it's a numeric type - let second = match second_type { - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { - DataType::Int64 - } - DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { - DataType::UInt64 - } - _ => return exec_err!("count must be an integer type"), - }; - - Ok(vec![first_type.clone(), second]) - } - fn documentation(&self) -> Option<&Documentation> { self.doc() } @@ -156,15 +151,7 @@ impl ScalarUDFImpl for ArrayRepeat { fn array_repeat_inner(args: &[ArrayRef]) -> Result { let element = &args[0]; - let count_array = &args[1]; - - let count_array = match count_array.data_type() { - DataType::Int64 => &cast(count_array, &DataType::UInt64)?, - DataType::UInt64 => count_array, - _ => return exec_err!("count must be an integer type"), - }; - - let count_array = as_uint64_array(count_array)?; + let count_array = as_int64_array(&args[1])?; match element.data_type() { List(_) => { @@ -193,21 +180,31 @@ fn array_repeat_inner(args: &[ArrayRef]) -> Result { /// ``` fn general_repeat( array: &ArrayRef, - count_array: &UInt64Array, + count_array: &Int64Array, ) -> Result { - // Build offsets and take_indices - let total_repeated_values: usize = - count_array.values().iter().map(|&c| c as usize).sum(); + let total_repeated_values: usize = (0..count_array.len()) + .map(|i| get_count_with_validity(count_array, i)) + .sum(); + let mut take_indices = Vec::with_capacity(total_repeated_values); let mut offsets = Vec::with_capacity(count_array.len() + 1); offsets.push(O::zero()); let mut running_offset = 0usize; - for (idx, &count) in count_array.values().iter().enumerate() { - let count = count as usize; - running_offset += count; - offsets.push(O::from_usize(running_offset).unwrap()); - take_indices.extend(std::iter::repeat_n(idx as u64, count)) + for idx in 0..count_array.len() { + let count = get_count_with_validity(count_array, idx); + running_offset = 
running_offset.checked_add(count).ok_or_else(|| { + DataFusionError::Execution( + "array_repeat: running_offset overflowed usize".to_string(), + ) + })?; + let offset = O::from_usize(running_offset).ok_or_else(|| { + DataFusionError::Execution(format!( + "array_repeat: offset {running_offset} exceeds the maximum value for offset type" + )) + })?; + offsets.push(offset); + take_indices.extend(std::iter::repeat_n(idx as u64, count)); } // Build the flattened values @@ -222,7 +219,7 @@ fn general_repeat( Arc::new(Field::new_list_field(array.data_type().to_owned(), true)), OffsetBuffer::new(offsets.into()), repeated_values, - None, + count_array.nulls().cloned(), )?)) } @@ -238,23 +235,24 @@ fn general_repeat( /// ``` fn general_list_repeat( list_array: &GenericListArray, - count_array: &UInt64Array, + count_array: &Int64Array, ) -> Result { - let counts = count_array.values(); let list_offsets = list_array.value_offsets(); // calculate capacities for pre-allocation - let outer_total = counts.iter().map(|&c| c as usize).sum(); - let inner_total = counts - .iter() - .enumerate() - .filter(|&(i, _)| !list_array.is_null(i)) - .map(|(i, &c)| { - let len = list_offsets[i + 1].to_usize().unwrap() - - list_offsets[i].to_usize().unwrap(); - len * (c as usize) - }) - .sum(); + let mut outer_total = 0usize; + let mut inner_total = 0usize; + for i in 0..count_array.len() { + let count = get_count_with_validity(count_array, i); + if count > 0 { + outer_total += count; + if list_array.is_valid(i) { + let len = list_offsets[i + 1].to_usize().unwrap() + - list_offsets[i].to_usize().unwrap(); + inner_total += len * count; + } + } + } // Build inner structures let mut inner_offsets = Vec::with_capacity(outer_total + 1); @@ -263,17 +261,27 @@ fn general_list_repeat( let mut inner_running = 0usize; inner_offsets.push(O::zero()); - for (row_idx, &count) in counts.iter().enumerate() { - let is_valid = !list_array.is_null(row_idx); + for row_idx in 0..count_array.len() { + let count = 
get_count_with_validity(count_array, row_idx); + let list_is_valid = list_array.is_valid(row_idx); let start = list_offsets[row_idx].to_usize().unwrap(); let end = list_offsets[row_idx + 1].to_usize().unwrap(); let row_len = end - start; for _ in 0..count { - inner_running += row_len; - inner_offsets.push(O::from_usize(inner_running).unwrap()); - inner_nulls.append(is_valid); - if is_valid { + inner_running = inner_running.checked_add(row_len).ok_or_else(|| { + DataFusionError::Execution( + "array_repeat: inner offset overflowed usize".to_string(), + ) + })?; + let offset = O::from_usize(inner_running).ok_or_else(|| { + DataFusionError::Execution(format!( + "array_repeat: offset {inner_running} exceeds the maximum value for offset type" + )) + })?; + inner_offsets.push(offset); + inner_nulls.append(list_is_valid); + if list_is_valid { take_indices.extend(start as u64..end as u64); } } @@ -298,8 +306,24 @@ fn general_list_repeat( list_array.data_type().to_owned(), true, )), - OffsetBuffer::::from_lengths(counts.iter().map(|&c| c as usize)), + OffsetBuffer::::from_lengths( + count_array + .iter() + .map(|c| c.map(|v| if v > 0 { v as usize } else { 0 }).unwrap_or(0)), + ), Arc::new(inner_list), - None, + count_array.nulls().cloned(), )?)) } + +/// Returns the repeat count stored at index `idx` of `count_array`, converted to `usize`. +/// Yields 0 for null entries and non-positive counts. NOTE(review): `c as usize` can truncate where `usize` is narrower than 64 bits — confirm. 
+#[inline] +fn get_count_with_validity(count_array: &Int64Array, idx: usize) -> usize { + if count_array.is_null(idx) { + 0 + } else { + let c = count_array.value(idx); + if c > 0 { c as usize } else { 0 } + } +} diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index c27433e7efab..2b98ae14d298 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -3256,24 +3256,99 @@ drop table array_repeat_table; statement ok drop table large_array_repeat_table; - +# array_repeat: arrays with NULL counts statement ok create table array_repeat_null_count_table as values (1, 2), (2, null), -(3, 1); +(3, 1), +(4, -1), +(null, null); query I? select column1, array_repeat(column1, column2) from array_repeat_null_count_table; ---- 1 [1, 1] -2 [] +2 NULL 3 [3] +4 [] +NULL NULL statement ok drop table array_repeat_null_count_table +# array_repeat: nested arrays with NULL counts +statement ok +create table array_repeat_nested_null_count_table +as values +([[1, 2], [3, 4]], 2), +([[5, 6], [7, 8]], null), +([[null, null], [9, 10]], 1), +(null, 3), +([[11, 12]], -1); + +query ?? +select column1, array_repeat(column1, column2) from array_repeat_nested_null_count_table; +---- +[[1, 2], [3, 4]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] +[[5, 6], [7, 8]] NULL +[[NULL, NULL], [9, 10]] [[[NULL, NULL], [9, 10]]] +NULL [NULL, NULL, NULL] +[[11, 12]] [] + +statement ok +drop table array_repeat_nested_null_count_table + +# array_repeat edge cases: empty arrays +query ??? +select array_repeat([], 3), array_repeat([], 0), array_repeat([], null); +---- +[[], [], []] [] NULL + +query ?? 
+select array_repeat(null::int, 0), array_repeat(null::int, null); +---- +[] NULL + +# array_repeat LargeList with NULL count +statement ok +create table array_repeat_large_list_null_table +as values +(arrow_cast([1, 2, 3], 'LargeList(Int64)'), 2), +(arrow_cast([4, 5], 'LargeList(Int64)'), null), +(arrow_cast(null, 'LargeList(Int64)'), 3); + +query ?? +select column1, array_repeat(column1, column2) from array_repeat_large_list_null_table; +---- +[1, 2, 3] [[1, 2, 3], [1, 2, 3]] +[4, 5] NULL +NULL [NULL, NULL, NULL] + +statement ok +drop table array_repeat_large_list_null_table + +# array_repeat edge cases: LargeList nested with NULL count +statement ok +create table array_repeat_large_nested_null_table +as values +(arrow_cast([[1, 2], [3, 4]], 'LargeList(List(Int64))'), 2), +(arrow_cast([[5, 6], [7, 8]], 'LargeList(List(Int64))'), null), +(arrow_cast([[null, null]], 'LargeList(List(Int64))'), 1), +(null, 3); + +query ?? +select column1, array_repeat(column1, column2) from array_repeat_large_nested_null_table; +---- +[[1, 2], [3, 4]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] +[[5, 6], [7, 8]] NULL +[[NULL, NULL]] [[[NULL, NULL]]] +NULL [NULL, NULL, NULL] + +statement ok +drop table array_repeat_large_nested_null_table + ## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`) # test with empty array