Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use crate::DecimalBytePartsVTable;

impl MaskKernel for DecimalBytePartsVTable {
fn mask(&self, array: &DecimalBytePartsArray, mask_array: &Mask) -> VortexResult<ArrayRef> {
DecimalBytePartsArray::try_new(mask(&array.msp, mask_array)?, *array.decimal_dtype())
.map(|a| a.to_array())
let masked = mask(&array.msp, mask_array)?;
DecimalBytePartsArray::try_new(masked, *array.decimal_dtype()).map(|a| a.to_array())
}
}

Expand Down
37 changes: 29 additions & 8 deletions vortex-array/src/arrays/chunked/compute/mask.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use itertools::Itertools as _;
use vortex_buffer::BitBuffer;
use vortex_buffer::BitBufferMut;
use vortex_dtype::DType;
use vortex_error::VortexResult;
use vortex_mask::AllOr;
Expand All @@ -12,18 +14,20 @@ use vortex_scalar::Scalar;
use super::filter::ChunkFilter;
use super::filter::chunk_filters;
use super::filter::find_chunk_idx;
use crate::Array;
use crate::ArrayRef;
use crate::IntoArray;
use crate::arrays::BoolArray;
use crate::arrays::ChunkedArray;
use crate::arrays::ChunkedVTable;
use crate::arrays::ConstantArray;
use crate::arrays::chunked::compute::filter::FILTER_SLICES_SELECTIVITY_THRESHOLD;
use crate::builtins::ArrayBuiltins;
use crate::compute::MaskKernel;
use crate::compute::MaskKernelAdapter;
use crate::compute::cast;
use crate::compute::mask;
use crate::register_kernel;
use crate::validity::Validity;

impl MaskKernel for ChunkedVTable {
fn mask(&self, array: &ChunkedArray, mask: &Mask) -> VortexResult<ArrayRef> {
Expand Down Expand Up @@ -54,24 +58,33 @@ fn mask_indices(
) -> VortexResult<Vec<ArrayRef>> {
let mut new_chunks = Vec::with_capacity(array.nchunks());
let mut current_chunk_id = 0;
let mut chunk_indices = Vec::new();
let mut chunk_indices = Vec::<usize>::new();

let chunk_offsets = array.chunk_offsets();

for &set_index in indices {
let (chunk_id, index) = find_chunk_idx(set_index, &chunk_offsets);
if chunk_id != current_chunk_id {
let chunk = array.chunk(current_chunk_id);
let masked_chunk = mask(chunk, &Mask::from_indices(chunk.len(), chunk_indices))?;
let chunk = array.chunk(current_chunk_id).clone();
let chunk_len = chunk.len();
// chunk_indices contains indices to null out, but chunk.mask() expects
// mask=true to mean "retain". So we create a mask with bits set at indices
// to null, then invert it to get mask=true at indices to retain.
let mask = BoolArray::new(
!BitBuffer::from_indices(chunk_len, &chunk_indices),
Validity::NonNullable,
)
.into_array();
let masked_chunk = chunk.mask(mask)?;
// Advance the chunk forward, reset the chunk indices buffer.
chunk_indices = Vec::new();
new_chunks.push(masked_chunk);
current_chunk_id += 1;

while current_chunk_id < chunk_id {
// Chunks that are not affected by the mask, must still be casted to the correct dtype.
let chunk = array.chunk(current_chunk_id);
new_chunks.push(cast(chunk, new_dtype)?);
let chunk = array.chunk(current_chunk_id).cast(new_dtype.clone())?;
new_chunks.push(chunk);
current_chunk_id += 1;
}
}
Expand All @@ -80,8 +93,16 @@ fn mask_indices(
}

if !chunk_indices.is_empty() {
let chunk = array.chunk(current_chunk_id);
let masked_chunk = mask(chunk, &Mask::from_indices(chunk.len(), chunk_indices))?;
let chunk = array.chunk(current_chunk_id).clone();
let chunk_len = chunk.len();
// Same inversion as above: invert the mask so mask=true means "retain"
let masked_chunk = chunk.mask(
BoolArray::new(
!BitBufferMut::from_indices(chunk_len, &chunk_indices).freeze(),
Validity::NonNullable,
)
.into_array(),
)?;
new_chunks.push(masked_chunk);
current_chunk_id += 1;
}
Expand Down
5 changes: 3 additions & 2 deletions vortex-array/src/arrays/extension/compute/mask.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@ use crate::arrays::ExtensionArray;
use crate::arrays::ExtensionVTable;
use crate::compute::MaskKernel;
use crate::compute::MaskKernelAdapter;
use crate::compute::{self};
use crate::compute::mask;
use crate::register_kernel;

impl MaskKernel for ExtensionVTable {
fn mask(&self, array: &ExtensionArray, mask_array: &Mask) -> VortexResult<ArrayRef> {
let masked_storage = compute::mask(array.storage(), mask_array)?;
// Use compute::mask directly since mask_array has compute::mask semantics (true=null)
let masked_storage = mask(array.storage(), mask_array)?;
if masked_storage.dtype().nullability() == array.ext_dtype().storage_dtype().nullability() {
Ok(ExtensionArray::new(array.ext_dtype().clone(), masked_storage).into_array())
} else {
Expand Down
6 changes: 3 additions & 3 deletions vortex-array/src/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ pub trait ArrayBuiltins: Sized {
/// Mask the array using the given boolean mask.
/// The resulting array's validity is the intersection of the original array's validity
/// and the mask's validity.
fn mask(&self, mask: &ArrayRef) -> VortexResult<ArrayRef>;
fn mask(self, mask: ArrayRef) -> VortexResult<ArrayRef>;

/// Boolean negation.
fn not(&self) -> VortexResult<ArrayRef>;
Expand All @@ -105,8 +105,8 @@ impl ArrayBuiltins for ArrayRef {
.optimize()
}

fn mask(&self, mask: &ArrayRef) -> VortexResult<ArrayRef> {
Mask.try_new_array(self.len(), EmptyOptions, [self.clone(), mask.clone()])?
fn mask(self, mask: ArrayRef) -> VortexResult<ArrayRef> {
Mask.try_new_array(self.len(), EmptyOptions, [self, mask])?
.optimize()
}

Expand Down
11 changes: 6 additions & 5 deletions vortex-array/src/expr/exprs/mask.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::fmt::Formatter;
use std::ops::Not;

use vortex_dtype::DType;
use vortex_dtype::Nullability;
Expand All @@ -17,8 +16,8 @@ use vortex_vector::ScalarOps;
use vortex_vector::VectorMutOps;
use vortex_vector::VectorOps;

use crate::Array;
use crate::ArrayRef;
use crate::ToCanonical;
use crate::expr::Arity;
use crate::expr::ChildName;
use crate::expr::EmptyOptions;
Expand Down Expand Up @@ -95,10 +94,12 @@ impl VTable for Mask {
) -> VortexResult<ArrayRef> {
let child = expr.child(0).evaluate(scope)?;

// Invert the validity mask - we want to set values to null where validity is false.
let inverted_mask = child.validity_mask().not();
// The expr::Mask semantics are: mask=true means retain, mask=false means null.
// But compute::mask has: mask=true means null, mask=false means retain.
// So we need to invert the mask before passing to compute::mask.
let mask = expr.child(1).evaluate(scope)?.to_bool().into_bit_buffer();

crate::compute::mask(&child, &inverted_mask)
crate::compute::mask(&child, &vortex_mask::Mask::from_buffer(!mask))
}

fn execute(&self, _options: &Self::Options, args: ExecutionArgs) -> VortexResult<Datum> {
Expand Down
5 changes: 5 additions & 0 deletions vortex-buffer/src/bit/buf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ impl BitBuffer {
}
}

/// Create a bit buffer of `len` with `indices` set as true.
pub fn from_indices(len: usize, indices: &[usize]) -> BitBuffer {
BitBufferMut::from_indices(len, indices).freeze()
}

/// Create a new empty `BitBuffer`.
pub fn empty() -> Self {
Self::new_set(0)
Expand Down
8 changes: 8 additions & 0 deletions vortex-buffer/src/bit/buf_mut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ impl BitBufferMut {
}
}

/// Create a bit buffer of `len` with `indices` set as true.
pub fn from_indices(len: usize, indices: &[usize]) -> BitBufferMut {
let mut buf = BitBufferMut::new_unset(len);
// TODO(ngates): for dense indices, we can do better by collecting into u64s.
indices.iter().for_each(|&idx| buf.set(idx));
buf
}

/// Invokes `f` with indexes `0..len` collecting the boolean results into a new `BitBufferMut`
pub fn collect_bool<F: FnMut(usize) -> bool>(len: usize, mut f: F) -> Self {
let mut buffer = BufferMut::with_capacity(len.div_ceil(64) * 8);
Expand Down
Loading