-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Fix array_repeat handling of null count values
#20102
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -19,21 +19,23 @@ | |||||
|
|
||||||
| use crate::utils::make_scalar_function; | ||||||
| use arrow::array::{ | ||||||
| Array, ArrayRef, BooleanBufferBuilder, GenericListArray, OffsetSizeTrait, UInt64Array, | ||||||
| Array, ArrayRef, BooleanBufferBuilder, GenericListArray, Int64Array, OffsetSizeTrait, | ||||||
| UInt64Array, | ||||||
| }; | ||||||
| use arrow::buffer::{NullBuffer, OffsetBuffer}; | ||||||
| use arrow::compute; | ||||||
| use arrow::compute::cast; | ||||||
| use arrow::datatypes::DataType; | ||||||
| use arrow::datatypes::{ | ||||||
| DataType::{LargeList, List}, | ||||||
| Field, | ||||||
| }; | ||||||
| use datafusion_common::cast::{as_large_list_array, as_list_array, as_uint64_array}; | ||||||
| use datafusion_common::{Result, exec_err, utils::take_function_args}; | ||||||
| use datafusion_common::Result; | ||||||
| use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; | ||||||
| use datafusion_common::types::{NativeType, logical_int64}; | ||||||
| use datafusion_expr::{ | ||||||
| ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, | ||||||
| }; | ||||||
| use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; | ||||||
| use datafusion_macros::user_doc; | ||||||
| use std::any::Any; | ||||||
| use std::sync::Arc; | ||||||
|
|
@@ -88,7 +90,17 @@ impl Default for ArrayRepeat { | |||||
| impl ArrayRepeat { | ||||||
| pub fn new() -> Self { | ||||||
| Self { | ||||||
| signature: Signature::user_defined(Volatility::Immutable), | ||||||
| signature: Signature::coercible( | ||||||
| vec![ | ||||||
| Coercion::new_exact(TypeSignatureClass::Any), | ||||||
| Coercion::new_implicit( | ||||||
| TypeSignatureClass::Native(logical_int64()), | ||||||
| vec![TypeSignatureClass::Integer], | ||||||
| NativeType::Int64, | ||||||
| ), | ||||||
| ], | ||||||
| Volatility::Immutable, | ||||||
| ), | ||||||
| aliases: vec![String::from("list_repeat")], | ||||||
| } | ||||||
| } | ||||||
|
|
@@ -132,39 +144,14 @@ impl ScalarUDFImpl for ArrayRepeat { | |||||
| &self.aliases | ||||||
| } | ||||||
|
|
||||||
| fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { | ||||||
| let [first_type, second_type] = take_function_args(self.name(), arg_types)?; | ||||||
|
|
||||||
| // Coerce the second argument to Int64/UInt64 if it's a numeric type | ||||||
| let second = match second_type { | ||||||
| DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { | ||||||
| DataType::Int64 | ||||||
| } | ||||||
| DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { | ||||||
| DataType::UInt64 | ||||||
| } | ||||||
| _ => return exec_err!("count must be an integer type"), | ||||||
| }; | ||||||
|
|
||||||
| Ok(vec![first_type.clone(), second]) | ||||||
| } | ||||||
|
|
||||||
| fn documentation(&self) -> Option<&Documentation> { | ||||||
| self.doc() | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| fn array_repeat_inner(args: &[ArrayRef]) -> Result<ArrayRef> { | ||||||
| let element = &args[0]; | ||||||
| let count_array = &args[1]; | ||||||
|
|
||||||
| let count_array = match count_array.data_type() { | ||||||
| DataType::Int64 => &cast(count_array, &DataType::UInt64)?, | ||||||
| DataType::UInt64 => count_array, | ||||||
| _ => return exec_err!("count must be an integer type"), | ||||||
| }; | ||||||
|
|
||||||
| let count_array = as_uint64_array(count_array)?; | ||||||
| let count_array = as_int64_array(&args[1])?; | ||||||
|
|
||||||
| match element.data_type() { | ||||||
| List(_) => { | ||||||
|
|
@@ -193,21 +180,27 @@ fn array_repeat_inner(args: &[ArrayRef]) -> Result<ArrayRef> { | |||||
| /// ``` | ||||||
| fn general_repeat<O: OffsetSizeTrait>( | ||||||
| array: &ArrayRef, | ||||||
| count_array: &UInt64Array, | ||||||
| count_array: &Int64Array, | ||||||
| ) -> Result<ArrayRef> { | ||||||
| // Build offsets and take_indices | ||||||
| let total_repeated_values: usize = | ||||||
| count_array.values().iter().map(|&c| c as usize).sum(); | ||||||
| let total_repeated_values: usize = count_array | ||||||
| .values() | ||||||
| .iter() | ||||||
| .map(|&c| if c > 0 { c as usize } else { 0 }) | ||||||
| .sum(); | ||||||
|
|
||||||
| let mut take_indices = Vec::with_capacity(total_repeated_values); | ||||||
| let mut nulls = BooleanBufferBuilder::new(count_array.len()); | ||||||
| let mut offsets = Vec::with_capacity(count_array.len() + 1); | ||||||
| offsets.push(O::zero()); | ||||||
| let mut running_offset = 0usize; | ||||||
|
|
||||||
| for (idx, &count) in count_array.values().iter().enumerate() { | ||||||
| let count = count as usize; | ||||||
| for idx in 0..count_array.len() { | ||||||
| let (count, is_valid) = get_count_with_validity(count_array, idx); | ||||||
|
|
||||||
| running_offset += count; | ||||||
| offsets.push(O::from_usize(running_offset).unwrap()); | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Extreme case:
|
||||||
| take_indices.extend(std::iter::repeat_n(idx as u64, count)) | ||||||
| nulls.append(is_valid); | ||||||
| take_indices.extend(std::iter::repeat_n(idx as u64, count)); | ||||||
| } | ||||||
|
|
||||||
| // Build the flattened values | ||||||
|
|
@@ -222,7 +215,7 @@ fn general_repeat<O: OffsetSizeTrait>( | |||||
| Arc::new(Field::new_list_field(array.data_type().to_owned(), true)), | ||||||
| OffsetBuffer::new(offsets.into()), | ||||||
| repeated_values, | ||||||
| None, | ||||||
| Some(NullBuffer::new(nulls.finish())), | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can copy the nulls buffer from the count array instead of using a builder.
Suggested change
|
||||||
| )?)) | ||||||
| } | ||||||
|
|
||||||
|
|
@@ -238,42 +231,46 @@ fn general_repeat<O: OffsetSizeTrait>( | |||||
| /// ``` | ||||||
| fn general_list_repeat<O: OffsetSizeTrait>( | ||||||
| list_array: &GenericListArray<O>, | ||||||
| count_array: &UInt64Array, | ||||||
| count_array: &Int64Array, | ||||||
| ) -> Result<ArrayRef> { | ||||||
| let counts = count_array.values(); | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This also does not check for NULLs in the count_array and may lead to overestimates. You need to use |
||||||
| let list_offsets = list_array.value_offsets(); | ||||||
|
|
||||||
| // calculate capacities for pre-allocation | ||||||
| let outer_total = counts.iter().map(|&c| c as usize).sum(); | ||||||
| let inner_total = counts | ||||||
| .iter() | ||||||
| .enumerate() | ||||||
| .filter(|&(i, _)| !list_array.is_null(i)) | ||||||
| .map(|(i, &c)| { | ||||||
| let len = list_offsets[i + 1].to_usize().unwrap() | ||||||
| - list_offsets[i].to_usize().unwrap(); | ||||||
| len * (c as usize) | ||||||
| }) | ||||||
| .sum(); | ||||||
| let mut outer_total = 0usize; | ||||||
| let mut inner_total = 0usize; | ||||||
| for (i, &c) in counts.iter().enumerate() { | ||||||
| if c > 0 { | ||||||
| outer_total += c as usize; | ||||||
| if !list_array.is_null(i) { | ||||||
| let len = list_offsets[i + 1].to_usize().unwrap() | ||||||
| - list_offsets[i].to_usize().unwrap(); | ||||||
| inner_total += len * (c as usize); | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| // Build inner structures | ||||||
| let mut inner_offsets = Vec::with_capacity(outer_total + 1); | ||||||
| let mut take_indices = Vec::with_capacity(inner_total); | ||||||
| let mut inner_nulls = BooleanBufferBuilder::new(outer_total); | ||||||
| let mut outer_nulls = BooleanBufferBuilder::new(count_array.len()); | ||||||
| let mut inner_running = 0usize; | ||||||
| inner_offsets.push(O::zero()); | ||||||
|
|
||||||
| for (row_idx, &count) in counts.iter().enumerate() { | ||||||
| let is_valid = !list_array.is_null(row_idx); | ||||||
| for row_idx in 0..count_array.len() { | ||||||
| let (count, count_is_valid) = get_count_with_validity(count_array, row_idx); | ||||||
| outer_nulls.append(count_is_valid); | ||||||
| let list_is_valid = !list_array.is_null(row_idx); | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| let start = list_offsets[row_idx].to_usize().unwrap(); | ||||||
| let end = list_offsets[row_idx + 1].to_usize().unwrap(); | ||||||
| let row_len = end - start; | ||||||
|
|
||||||
| for _ in 0..count { | ||||||
| inner_running += row_len; | ||||||
| inner_offsets.push(O::from_usize(inner_running).unwrap()); | ||||||
| inner_nulls.append(is_valid); | ||||||
| if is_valid { | ||||||
| inner_nulls.append(list_is_valid); | ||||||
| if list_is_valid { | ||||||
| take_indices.extend(start as u64..end as u64); | ||||||
| } | ||||||
| } | ||||||
|
|
@@ -298,8 +295,22 @@ fn general_list_repeat<O: OffsetSizeTrait>( | |||||
| list_array.data_type().to_owned(), | ||||||
| true, | ||||||
| )), | ||||||
| OffsetBuffer::<O>::from_lengths(counts.iter().map(|&c| c as usize)), | ||||||
| OffsetBuffer::<O>::from_lengths( | ||||||
| count_array | ||||||
| .iter() | ||||||
| .map(|c| c.map(|v| if v > 0 { v as usize } else { 0 }).unwrap_or(0)), | ||||||
| ), | ||||||
| Arc::new(inner_list), | ||||||
| None, | ||||||
| Some(NullBuffer::new(outer_nulls.finish())), | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same note here about reusing the input null buffer |
||||||
| )?)) | ||||||
| } | ||||||
|
|
||||||
| /// Helper function to get count and validity from count_array at given index | ||||||
| fn get_count_with_validity(count_array: &Int64Array, idx: usize) -> (usize, bool) { | ||||||
| if count_array.is_null(idx) { | ||||||
| (0, false) | ||||||
| } else { | ||||||
| let c = count_array.value(idx); | ||||||
| if c > 0 { (c as usize, true) } else { (0, true) } | ||||||
| } | ||||||
| } | ||||||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -3256,24 +3256,97 @@ drop table array_repeat_table; | |||||||
| statement ok | ||||||||
| drop table large_array_repeat_table; | ||||||||
|
|
||||||||
|
|
||||||||
| # array_repeat: arrays with NULL counts | ||||||||
| statement ok | ||||||||
| create table array_repeat_null_count_table | ||||||||
| as values | ||||||||
| (1, 2), | ||||||||
| (2, null), | ||||||||
| (3, 1); | ||||||||
| (3, 1), | ||||||||
| (4, -1); | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
|
||||||||
| query I? | ||||||||
| select column1, array_repeat(column1, column2) from array_repeat_null_count_table; | ||||||||
| ---- | ||||||||
| 1 [1, 1] | ||||||||
| 2 [] | ||||||||
| 2 NULL | ||||||||
| 3 [3] | ||||||||
| 4 [] | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
|
||||||||
| statement ok | ||||||||
| drop table array_repeat_null_count_table | ||||||||
|
|
||||||||
| # array_repeat: nested arrays with NULL counts | ||||||||
| statement ok | ||||||||
| create table array_repeat_nested_null_count_table | ||||||||
| as values | ||||||||
| ([[1, 2], [3, 4]], 2), | ||||||||
| ([[5, 6], [7, 8]], null), | ||||||||
| ([[null, null], [9, 10]], 1), | ||||||||
| (null, 3), | ||||||||
| ([[11, 12]], -1); | ||||||||
|
|
||||||||
| query ?? | ||||||||
| select column1, array_repeat(column1, column2) from array_repeat_nested_null_count_table; | ||||||||
| ---- | ||||||||
| [[1, 2], [3, 4]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] | ||||||||
| [[5, 6], [7, 8]] NULL | ||||||||
| [[NULL, NULL], [9, 10]] [[[NULL, NULL], [9, 10]]] | ||||||||
| NULL [NULL, NULL, NULL] | ||||||||
| [[11, 12]] [] | ||||||||
|
|
||||||||
| statement ok | ||||||||
| drop table array_repeat_nested_null_count_table | ||||||||
|
|
||||||||
| # array_repeat edge cases: empty arrays | ||||||||
| query ??? | ||||||||
| select array_repeat([], 3), array_repeat([], 0), array_repeat([], null); | ||||||||
| ---- | ||||||||
| [[], [], []] [] NULL | ||||||||
|
|
||||||||
| query ?? | ||||||||
| select array_repeat(null::int, 0), array_repeat(null::int, null); | ||||||||
| ---- | ||||||||
| [] NULL | ||||||||
|
|
||||||||
| # array_repeat LargeList with NULL count | ||||||||
| statement ok | ||||||||
| create table array_repeat_large_list_null_table | ||||||||
| as values | ||||||||
| (arrow_cast([1, 2, 3], 'LargeList(Int64)'), 2), | ||||||||
| (arrow_cast([4, 5], 'LargeList(Int64)'), null), | ||||||||
| (arrow_cast(null, 'LargeList(Int64)'), 3); | ||||||||
|
|
||||||||
| query ?? | ||||||||
| select column1, array_repeat(column1, column2) from array_repeat_large_list_null_table; | ||||||||
| ---- | ||||||||
| [1, 2, 3] [[1, 2, 3], [1, 2, 3]] | ||||||||
| [4, 5] NULL | ||||||||
| NULL [NULL, NULL, NULL] | ||||||||
|
|
||||||||
| statement ok | ||||||||
| drop table array_repeat_large_list_null_table | ||||||||
|
|
||||||||
| # array_repeat edge cases: LargeList nested with NULL count | ||||||||
| statement ok | ||||||||
| create table array_repeat_large_nested_null_table | ||||||||
| as values | ||||||||
| (arrow_cast([[1, 2], [3, 4]], 'LargeList(List(Int64))'), 2), | ||||||||
| (arrow_cast([[5, 6], [7, 8]], 'LargeList(List(Int64))'), null), | ||||||||
| (arrow_cast([[null, null]], 'LargeList(List(Int64))'), 1), | ||||||||
| (null, 3); | ||||||||
|
|
||||||||
| query ?? | ||||||||
| select column1, array_repeat(column1, column2) from array_repeat_large_nested_null_table; | ||||||||
| ---- | ||||||||
| [[1, 2], [3, 4]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] | ||||||||
| [[5, 6], [7, 8]] NULL | ||||||||
| [[NULL, NULL]] [[[NULL, NULL]]] | ||||||||
| NULL [NULL, NULL, NULL] | ||||||||
|
|
||||||||
| statement ok | ||||||||
| drop table array_repeat_large_nested_null_table | ||||||||
|
|
||||||||
| ## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`) | ||||||||
|
|
||||||||
| # test with empty array | ||||||||
|
|
||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: technically the spec allows null slots to hold non-zero values, so this could overestimate.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Possible fix: