diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mask.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mask.rs index 1ad0d959dff..04a91c5de17 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mask.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mask.rs @@ -14,8 +14,8 @@ use crate::DecimalBytePartsVTable; impl MaskKernel for DecimalBytePartsVTable { fn mask(&self, array: &DecimalBytePartsArray, mask_array: &Mask) -> VortexResult { - DecimalBytePartsArray::try_new(mask(&array.msp, mask_array)?, *array.decimal_dtype()) - .map(|a| a.to_array()) + let masked = mask(&array.msp, mask_array)?; + DecimalBytePartsArray::try_new(masked, *array.decimal_dtype()).map(|a| a.to_array()) } } diff --git a/vortex-array/src/arrays/chunked/compute/mask.rs b/vortex-array/src/arrays/chunked/compute/mask.rs index ce75689ac98..034c3f84297 100644 --- a/vortex-array/src/arrays/chunked/compute/mask.rs +++ b/vortex-array/src/arrays/chunked/compute/mask.rs @@ -2,6 +2,8 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use itertools::Itertools as _; +use vortex_buffer::BitBuffer; +use vortex_buffer::BitBufferMut; use vortex_dtype::DType; use vortex_error::VortexResult; use vortex_mask::AllOr; @@ -12,18 +14,20 @@ use vortex_scalar::Scalar; use super::filter::ChunkFilter; use super::filter::chunk_filters; use super::filter::find_chunk_idx; -use crate::Array; use crate::ArrayRef; use crate::IntoArray; +use crate::arrays::BoolArray; use crate::arrays::ChunkedArray; use crate::arrays::ChunkedVTable; use crate::arrays::ConstantArray; use crate::arrays::chunked::compute::filter::FILTER_SLICES_SELECTIVITY_THRESHOLD; +use crate::builtins::ArrayBuiltins; use crate::compute::MaskKernel; use crate::compute::MaskKernelAdapter; use crate::compute::cast; use crate::compute::mask; use crate::register_kernel; +use crate::validity::Validity; impl MaskKernel for ChunkedVTable { fn mask(&self, array: &ChunkedArray, mask: &Mask) -> VortexResult { @@ -54,15 +58,24 @@ fn mask_indices( ) -> VortexResult> { let mut new_chunks = Vec::with_capacity(array.nchunks()); let mut current_chunk_id = 0; - let mut chunk_indices = Vec::new(); + let mut chunk_indices = Vec::::new(); let chunk_offsets = array.chunk_offsets(); for &set_index in indices { let (chunk_id, index) = find_chunk_idx(set_index, &chunk_offsets); if chunk_id != current_chunk_id { - let chunk = array.chunk(current_chunk_id); - let masked_chunk = mask(chunk, &Mask::from_indices(chunk.len(), chunk_indices))?; + let chunk = array.chunk(current_chunk_id).clone(); + let chunk_len = chunk.len(); + // chunk_indices contains indices to null out, but chunk.mask() expects + // mask=true to mean "retain". So we create a mask with bits set at indices + // to null, then invert it to get mask=true at indices to retain. + let mask = BoolArray::new( + !BitBuffer::from_indices(chunk_len, &chunk_indices), + Validity::NonNullable, + ) + .into_array(); + let masked_chunk = chunk.mask(mask)?; // Advance the chunk forward, reset the chunk indices buffer. chunk_indices = Vec::new(); new_chunks.push(masked_chunk); @@ -70,8 +83,8 @@ fn mask_indices( while current_chunk_id < chunk_id { // Chunks that are not affected by the mask, must still be casted to the correct dtype. - let chunk = array.chunk(current_chunk_id); - new_chunks.push(cast(chunk, new_dtype)?); + let chunk = array.chunk(current_chunk_id).cast(new_dtype.clone())?; + new_chunks.push(chunk); current_chunk_id += 1; } } @@ -80,8 +93,16 @@ fn mask_indices( } if !chunk_indices.is_empty() { - let chunk = array.chunk(current_chunk_id); - let masked_chunk = mask(chunk, &Mask::from_indices(chunk.len(), chunk_indices))?; + let chunk = array.chunk(current_chunk_id).clone(); + let chunk_len = chunk.len(); + // Same inversion as above: invert the mask so mask=true means "retain" + let masked_chunk = chunk.mask( + BoolArray::new( + !BitBufferMut::from_indices(chunk_len, &chunk_indices).freeze(), + Validity::NonNullable, + ) + .into_array(), + )?; new_chunks.push(masked_chunk); current_chunk_id += 1; } diff --git a/vortex-array/src/arrays/extension/compute/mask.rs b/vortex-array/src/arrays/extension/compute/mask.rs index 2260c2d6dbb..9c2b2607ec0 100644 --- a/vortex-array/src/arrays/extension/compute/mask.rs +++ b/vortex-array/src/arrays/extension/compute/mask.rs @@ -13,12 +13,13 @@ use crate::arrays::ExtensionArray; use crate::arrays::ExtensionVTable; use crate::compute::MaskKernel; use crate::compute::MaskKernelAdapter; -use crate::compute::{self}; +use crate::compute::mask; use crate::register_kernel; impl MaskKernel for ExtensionVTable { fn mask(&self, array: &ExtensionArray, mask_array: &Mask) -> VortexResult { - let masked_storage = compute::mask(array.storage(), mask_array)?; + // Use compute::mask directly since mask_array has compute::mask semantics (true=null) + let masked_storage = mask(array.storage(), mask_array)?; if masked_storage.dtype().nullability() == array.ext_dtype().storage_dtype().nullability() { Ok(ExtensionArray::new(array.ext_dtype().clone(), masked_storage).into_array()) } else { diff --git a/vortex-array/src/builtins.rs b/vortex-array/src/builtins.rs index c6f9cd0fa47..e0744dbf151 100644 --- a/vortex-array/src/builtins.rs +++ b/vortex-array/src/builtins.rs @@ -81,7 +81,7 @@ pub trait ArrayBuiltins: Sized { /// Mask the array using the given boolean mask. /// The resulting array's validity is the intersection of the original array's validity /// and the mask's validity. - fn mask(&self, mask: &ArrayRef) -> VortexResult; + fn mask(self, mask: ArrayRef) -> VortexResult; /// Boolean negation. fn not(&self) -> VortexResult; @@ -105,8 +105,8 @@ impl ArrayBuiltins for ArrayRef { .optimize() } - fn mask(&self, mask: &ArrayRef) -> VortexResult { - Mask.try_new_array(self.len(), EmptyOptions, [self.clone(), mask.clone()])? + fn mask(self, mask: ArrayRef) -> VortexResult { + Mask.try_new_array(self.len(), EmptyOptions, [self, mask])? .optimize() } diff --git a/vortex-array/src/expr/exprs/mask.rs b/vortex-array/src/expr/exprs/mask.rs index 150064a3b36..f24e2347ae9 100644 --- a/vortex-array/src/expr/exprs/mask.rs +++ b/vortex-array/src/expr/exprs/mask.rs @@ -2,7 +2,6 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use std::fmt::Formatter; -use std::ops::Not; use vortex_dtype::DType; use vortex_dtype::Nullability; @@ -17,8 +16,8 @@ use vortex_vector::ScalarOps; use vortex_vector::VectorMutOps; use vortex_vector::VectorOps; -use crate::Array; use crate::ArrayRef; +use crate::ToCanonical; use crate::expr::Arity; use crate::expr::ChildName; use crate::expr::EmptyOptions; @@ -95,10 +94,12 @@ impl VTable for Mask { ) -> VortexResult { let child = expr.child(0).evaluate(scope)?; - // Invert the validity mask - we want to set values to null where validity is false. - let inverted_mask = child.validity_mask().not(); + // The expr::Mask semantics are: mask=true means retain, mask=false means null. + // But compute::mask has: mask=true means null, mask=false means retain. + // So we need to invert the mask before passing to compute::mask. + let mask = expr.child(1).evaluate(scope)?.to_bool().into_bit_buffer(); - crate::compute::mask(&child, &inverted_mask) + crate::compute::mask(&child, &vortex_mask::Mask::from_buffer(!mask)) } fn execute(&self, _options: &Self::Options, args: ExecutionArgs) -> VortexResult { diff --git a/vortex-buffer/src/bit/buf.rs b/vortex-buffer/src/bit/buf.rs index c1472b6d1d5..558c5ccf34a 100644 --- a/vortex-buffer/src/bit/buf.rs +++ b/vortex-buffer/src/bit/buf.rs @@ -118,6 +118,11 @@ impl BitBuffer { } } + /// Create a bit buffer of `len` with `indices` set as true. + pub fn from_indices(len: usize, indices: &[usize]) -> BitBuffer { + BitBufferMut::from_indices(len, indices).freeze() + } + /// Create a new empty `BitBuffer`. pub fn empty() -> Self { Self::new_set(0) diff --git a/vortex-buffer/src/bit/buf_mut.rs b/vortex-buffer/src/bit/buf_mut.rs index c94d3027399..b9e1f63a78d 100644 --- a/vortex-buffer/src/bit/buf_mut.rs +++ b/vortex-buffer/src/bit/buf_mut.rs @@ -124,6 +124,14 @@ impl BitBufferMut { } } + /// Create a bit buffer of `len` with `indices` set as true. + pub fn from_indices(len: usize, indices: &[usize]) -> BitBufferMut { + let mut buf = BitBufferMut::new_unset(len); + // TODO(ngates): for dense indices, we can do better by collecting into u64s. + indices.iter().for_each(|&idx| buf.set(idx)); + buf + } + /// Invokes `f` with indexes `0..len` collecting the boolean results into a new `BitBufferMut` pub fn collect_bool bool>(len: usize, mut f: F) -> Self { let mut buffer = BufferMut::with_capacity(len.div_ceil(64) * 8);