Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 68 additions & 57 deletions datafusion/functions-nested/src/repeat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,23 @@

use crate::utils::make_scalar_function;
use arrow::array::{
Array, ArrayRef, BooleanBufferBuilder, GenericListArray, OffsetSizeTrait, UInt64Array,
Array, ArrayRef, BooleanBufferBuilder, GenericListArray, Int64Array, OffsetSizeTrait,
UInt64Array,
};
use arrow::buffer::{NullBuffer, OffsetBuffer};
use arrow::compute;
use arrow::compute::cast;
use arrow::datatypes::DataType;
use arrow::datatypes::{
DataType::{LargeList, List},
Field,
};
use datafusion_common::cast::{as_large_list_array, as_list_array, as_uint64_array};
use datafusion_common::{Result, exec_err, utils::take_function_args};
use datafusion_common::Result;
use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array};
use datafusion_common::types::{NativeType, logical_int64};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
use datafusion_macros::user_doc;
use std::any::Any;
use std::sync::Arc;
Expand Down Expand Up @@ -88,7 +90,17 @@ impl Default for ArrayRepeat {
impl ArrayRepeat {
pub fn new() -> Self {
Self {
signature: Signature::user_defined(Volatility::Immutable),
signature: Signature::coercible(
vec![
Coercion::new_exact(TypeSignatureClass::Any),
Coercion::new_implicit(
TypeSignatureClass::Native(logical_int64()),
vec![TypeSignatureClass::Integer],
NativeType::Int64,
),
],
Volatility::Immutable,
),
aliases: vec![String::from("list_repeat")],
}
}
Expand Down Expand Up @@ -132,39 +144,14 @@ impl ScalarUDFImpl for ArrayRepeat {
&self.aliases
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
let [first_type, second_type] = take_function_args(self.name(), arg_types)?;

// Coerce the second argument to Int64/UInt64 if it's a numeric type
let second = match second_type {
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
DataType::Int64
}
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
DataType::UInt64
}
_ => return exec_err!("count must be an integer type"),
};

Ok(vec![first_type.clone(), second])
}

fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}

fn array_repeat_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let element = &args[0];
let count_array = &args[1];

let count_array = match count_array.data_type() {
DataType::Int64 => &cast(count_array, &DataType::UInt64)?,
DataType::UInt64 => count_array,
_ => return exec_err!("count must be an integer type"),
};

let count_array = as_uint64_array(count_array)?;
let count_array = as_int64_array(&args[1])?;

match element.data_type() {
List(_) => {
Expand Down Expand Up @@ -193,21 +180,27 @@ fn array_repeat_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
/// ```
fn general_repeat<O: OffsetSizeTrait>(
array: &ArrayRef,
count_array: &UInt64Array,
count_array: &Int64Array,
) -> Result<ArrayRef> {
// Build offsets and take_indices
let total_repeated_values: usize =
count_array.values().iter().map(|&c| c as usize).sum();
let total_repeated_values: usize = count_array
.values()
.iter()
.map(|&c| if c > 0 { c as usize } else { 0 })
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: technically the spec allows null slots to have non-0 values, so this could overestimate

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possible fix:

let total_repeated_values: usize = (0..count_array.len())
    .map(|i| get_count_with_validity(count_array, i).0)
    .sum();

.sum();

let mut take_indices = Vec::with_capacity(total_repeated_values);
let mut nulls = BooleanBufferBuilder::new(count_array.len());
let mut offsets = Vec::with_capacity(count_array.len() + 1);
offsets.push(O::zero());
let mut running_offset = 0usize;

for (idx, &count) in count_array.values().iter().enumerate() {
let count = count as usize;
for idx in 0..count_array.len() {
let (count, is_valid) = get_count_with_validity(count_array, idx);

running_offset += count;
offsets.push(O::from_usize(running_offset).unwrap());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extreme case:

  1. if the input values are a ListArray, then the offset type will be i32
  2. and if the count value is bigger than i32::MAX
  3. then i32::from_usize(i32::MAX + 1) will return None and the following unwrap() will panic

take_indices.extend(std::iter::repeat_n(idx as u64, count))
nulls.append(is_valid);
take_indices.extend(std::iter::repeat_n(idx as u64, count));
}

// Build the flattened values
Expand All @@ -222,7 +215,7 @@ fn general_repeat<O: OffsetSizeTrait>(
Arc::new(Field::new_list_field(array.data_type().to_owned(), true)),
OffsetBuffer::new(offsets.into()),
repeated_values,
None,
Some(NullBuffer::new(nulls.finish())),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can copy the nulls buffer from the count array instead of using a builder

Suggested change
Some(NullBuffer::new(nulls.finish())),
count_array.nulls().cloned(),

)?))
}

Expand All @@ -238,42 +231,46 @@ fn general_repeat<O: OffsetSizeTrait>(
/// ```
fn general_list_repeat<O: OffsetSizeTrait>(
list_array: &GenericListArray<O>,
count_array: &UInt64Array,
count_array: &Int64Array,
) -> Result<ArrayRef> {
let counts = count_array.values();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also does not check for NULLs in the count_array and may lead to overestimates. You need to use get_count_with_validity() at line 245 too

let list_offsets = list_array.value_offsets();

// calculate capacities for pre-allocation
let outer_total = counts.iter().map(|&c| c as usize).sum();
let inner_total = counts
.iter()
.enumerate()
.filter(|&(i, _)| !list_array.is_null(i))
.map(|(i, &c)| {
let len = list_offsets[i + 1].to_usize().unwrap()
- list_offsets[i].to_usize().unwrap();
len * (c as usize)
})
.sum();
let mut outer_total = 0usize;
let mut inner_total = 0usize;
for (i, &c) in counts.iter().enumerate() {
if c > 0 {
outer_total += c as usize;
if !list_array.is_null(i) {
let len = list_offsets[i + 1].to_usize().unwrap()
- list_offsets[i].to_usize().unwrap();
inner_total += len * (c as usize);
}
}
}

// Build inner structures
let mut inner_offsets = Vec::with_capacity(outer_total + 1);
let mut take_indices = Vec::with_capacity(inner_total);
let mut inner_nulls = BooleanBufferBuilder::new(outer_total);
let mut outer_nulls = BooleanBufferBuilder::new(count_array.len());
let mut inner_running = 0usize;
inner_offsets.push(O::zero());

for (row_idx, &count) in counts.iter().enumerate() {
let is_valid = !list_array.is_null(row_idx);
for row_idx in 0..count_array.len() {
let (count, count_is_valid) = get_count_with_validity(count_array, row_idx);
outer_nulls.append(count_is_valid);
let list_is_valid = !list_array.is_null(row_idx);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let list_is_valid = !list_array.is_null(row_idx);
let list_is_valid = list_array.is_valid(row_idx);

let start = list_offsets[row_idx].to_usize().unwrap();
let end = list_offsets[row_idx + 1].to_usize().unwrap();
let row_len = end - start;

for _ in 0..count {
inner_running += row_len;
inner_offsets.push(O::from_usize(inner_running).unwrap());
inner_nulls.append(is_valid);
if is_valid {
inner_nulls.append(list_is_valid);
if list_is_valid {
take_indices.extend(start as u64..end as u64);
}
}
Expand All @@ -298,8 +295,22 @@ fn general_list_repeat<O: OffsetSizeTrait>(
list_array.data_type().to_owned(),
true,
)),
OffsetBuffer::<O>::from_lengths(counts.iter().map(|&c| c as usize)),
OffsetBuffer::<O>::from_lengths(
count_array
.iter()
.map(|c| c.map(|v| if v > 0 { v as usize } else { 0 }).unwrap_or(0)),
),
Arc::new(inner_list),
None,
Some(NullBuffer::new(outer_nulls.finish())),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same note here about reusing the input null buffer

)?))
}

/// Looks up the repeat count stored at `idx` in `count_array`.
///
/// Returns a `(count, is_valid)` pair:
/// - a null slot yields `(0, false)`;
/// - a non-null slot yields `(n, true)`, where `n` is the stored value
///   clamped so that zero and negative counts become `0`.
fn get_count_with_validity(count_array: &Int64Array, idx: usize) -> (usize, bool) {
    // Null slots contribute no repetitions and are reported as invalid.
    if count_array.is_null(idx) {
        return (0, false);
    }

    // Clamp non-positive counts to zero before converting to `usize`.
    let repeats = count_array.value(idx).max(0) as usize;
    (repeats, true)
}
79 changes: 76 additions & 3 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -3256,24 +3256,97 @@ drop table array_repeat_table;
statement ok
drop table large_array_repeat_table;


# array_repeat: arrays with NULL counts
statement ok
create table array_repeat_null_count_table
as values
(1, 2),
(2, null),
(3, 1);
(3, 1),
(4, -1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
(4, -1);
(4, -1),
(null, null);


query I?
select column1, array_repeat(column1, column2) from array_repeat_null_count_table;
----
1 [1, 1]
2 []
2 NULL
3 [3]
4 []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
4 []
4 []
NULL NULL


statement ok
drop table array_repeat_null_count_table

# array_repeat: nested arrays with NULL counts
statement ok
create table array_repeat_nested_null_count_table
as values
([[1, 2], [3, 4]], 2),
([[5, 6], [7, 8]], null),
([[null, null], [9, 10]], 1),
(null, 3),
([[11, 12]], -1);

query ??
select column1, array_repeat(column1, column2) from array_repeat_nested_null_count_table;
----
[[1, 2], [3, 4]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]
[[5, 6], [7, 8]] NULL
[[NULL, NULL], [9, 10]] [[[NULL, NULL], [9, 10]]]
NULL [NULL, NULL, NULL]
[[11, 12]] []

statement ok
drop table array_repeat_nested_null_count_table

# array_repeat edge cases: empty arrays
query ???
select array_repeat([], 3), array_repeat([], 0), array_repeat([], null);
----
[[], [], []] [] NULL

query ??
select array_repeat(null::int, 0), array_repeat(null::int, null);
----
[] NULL

# array_repeat LargeList with NULL count
statement ok
create table array_repeat_large_list_null_table
as values
(arrow_cast([1, 2, 3], 'LargeList(Int64)'), 2),
(arrow_cast([4, 5], 'LargeList(Int64)'), null),
(arrow_cast(null, 'LargeList(Int64)'), 3);

query ??
select column1, array_repeat(column1, column2) from array_repeat_large_list_null_table;
----
[1, 2, 3] [[1, 2, 3], [1, 2, 3]]
[4, 5] NULL
NULL [NULL, NULL, NULL]

statement ok
drop table array_repeat_large_list_null_table

# array_repeat edge cases: LargeList nested with NULL count
statement ok
create table array_repeat_large_nested_null_table
as values
(arrow_cast([[1, 2], [3, 4]], 'LargeList(List(Int64))'), 2),
(arrow_cast([[5, 6], [7, 8]], 'LargeList(List(Int64))'), null),
(arrow_cast([[null, null]], 'LargeList(List(Int64))'), 1),
(null, 3);

query ??
select column1, array_repeat(column1, column2) from array_repeat_large_nested_null_table;
----
[[1, 2], [3, 4]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]
[[5, 6], [7, 8]] NULL
[[NULL, NULL]] [[[NULL, NULL]]]
NULL [NULL, NULL, NULL]

statement ok
drop table array_repeat_large_nested_null_table

## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`)

# test with empty array
Expand Down