Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 15 additions & 17 deletions vortex-array/src/expr/exprs/between.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@ use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_err;
use vortex_proto::expr as pb;
use vortex_vector::Datum;

use crate::ArrayRef;
use crate::IntoArray;
use crate::compute::BetweenOptions;
use crate::compute::between as between_compute;
use crate::expr::Arity;
use crate::expr::ChildName;
use crate::expr::ExecutionArgs;
use crate::expr::ExecutionResult;
use crate::expr::ExprId;
use crate::expr::StatsCatalog;
use crate::expr::VTable;
Expand Down Expand Up @@ -150,38 +151,35 @@ impl VTable for Between {
between_compute(&arr, &lower, &upper, options)
}

fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult<Datum> {
let [arr, lower, upper]: [Datum; _] = args
.datums
fn execute(
&self,
options: &Self::Options,
args: ExecutionArgs,
) -> VortexResult<ExecutionResult> {
let [arr, lower, upper]: [ArrayRef; _] = args
.inputs
.try_into()
.map_err(|_| vortex_err!("Expected 3 arguments for Between expression",))?;
let [arr_dt, lower_dt, upper_dt]: [DType; _] = args
.dtypes
.try_into()
.map_err(|_| vortex_err!("Expected 3 dtypes for Between expression",))?;

let lower_bound = Binary
.bind(options.lower_strict.to_operator().into())
.execute(ExecutionArgs {
datums: vec![lower, arr.clone()],
dtypes: vec![lower_dt, arr_dt.clone()],
inputs: vec![lower.clone(), arr.clone()],
row_count: args.row_count,
return_dtype: args.return_dtype.clone(),
ctx: args.ctx,
})?;
let upper_bound = Binary
.bind(options.upper_strict.to_operator().into())
.execute(ExecutionArgs {
datums: vec![arr, upper],
dtypes: vec![arr_dt, upper_dt],
inputs: vec![arr, upper],
row_count: args.row_count,
return_dtype: args.return_dtype.clone(),
ctx: args.ctx,
})?;

Binary.bind(Operator::And).execute(ExecutionArgs {
datums: vec![lower_bound, upper_bound],
dtypes: vec![args.return_dtype.clone(), args.return_dtype.clone()],
inputs: vec![lower_bound.into_array(), upper_bound.into_array()],
row_count: args.row_count,
return_dtype: args.return_dtype,
ctx: args.ctx,
})
}

Expand Down
94 changes: 18 additions & 76 deletions vortex-array/src/expr/exprs/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use crate::compute::sub;
use crate::expr::Arity;
use crate::expr::ChildName;
use crate::expr::ExecutionArgs;
use crate::expr::ExecutionResult;
use crate::expr::ExprId;
use crate::expr::StatsCatalog;
use crate::expr::VTable;
Expand Down Expand Up @@ -131,86 +132,27 @@ impl VTable for Binary {
}
}

fn execute(&self, op: &Operator, args: ExecutionArgs) -> VortexResult<Datum> {
let [lhs, rhs]: [Datum; _] = args
.datums
fn execute(&self, op: &Operator, args: ExecutionArgs) -> VortexResult<ExecutionResult> {
let [lhs, rhs]: [ArrayRef; _] = args
.inputs
.try_into()
.map_err(|_| vortex_err!("Wrong arg count"))?;

// Handle logical operators.
match op {
Operator::And => {
return Ok(LogicalAndKleene::and_kleene(&lhs.into_bool(), &rhs.into_bool()).into());
}
Operator::Or => {
return Ok(LogicalOrKleene::or_kleene(&lhs.into_bool(), &rhs.into_bool()).into());
}
_ => {}
}

// Arrow's vectorized comparison kernels (`cmp::eq`, etc.) don't support nested types
// (Struct, List, FixedSizeList). For those, we use `compare_nested_arrow_arrays` which does
// element-wise comparison via `make_comparator`.
if let Some(cmp_op) = op.maybe_cmp_operator()
&& (lhs.is_nested() || rhs.is_nested())
{
// Treat scalars as 1-element arrow arrays.
let lhs_arr = lhs.into_arrow()?;
let rhs_arr = rhs.into_arrow()?;

let bool_array = compare_nested_arrow_arrays(lhs_arr.get().0, rhs_arr.get().0, cmp_op)?;
let vector = bool_array.into_vector()?;

let both_are_scalar = lhs_arr.get().1 && rhs_arr.get().1;

return Ok(if both_are_scalar {
Datum::Scalar(vortex_vector::Scalar::Bool(vector.scalar_at(0)))
} else {
Datum::Vector(vortex_vector::Vector::Bool(vector))
});
}

let lhs = lhs.into_arrow()?;
let rhs = rhs.into_arrow()?;

let vector = match op {
// Handle comparison operators.
Operator::Eq => cmp::eq(lhs.as_ref(), rhs.as_ref())?.into_vector()?.into(),
Operator::NotEq => cmp::neq(lhs.as_ref(), rhs.as_ref())?.into_vector()?.into(),
Operator::Gt => cmp::gt(lhs.as_ref(), rhs.as_ref())?.into_vector()?.into(),
Operator::Gte => cmp::gt_eq(lhs.as_ref(), rhs.as_ref())?
.into_vector()?
.into(),
Operator::Lt => cmp::lt(lhs.as_ref(), rhs.as_ref())?.into_vector()?.into(),
Operator::Lte => cmp::lt_eq(lhs.as_ref(), rhs.as_ref())?
.into_vector()?
.into(),

// Handle arithmetic operators.
Operator::Add => {
arrow_arith::numeric::add(lhs.as_ref(), rhs.as_ref())?.into_vector()?
}
Operator::Sub => {
arrow_arith::numeric::sub(lhs.as_ref(), rhs.as_ref())?.into_vector()?
}
Operator::Mul => {
arrow_arith::numeric::mul(lhs.as_ref(), rhs.as_ref())?.into_vector()?
}
Operator::Div => {
arrow_arith::numeric::div(lhs.as_ref(), rhs.as_ref())?.into_vector()?
}

// Logical operators were handled above.
Operator::And | Operator::Or => unreachable!("Already dealt with above"),
};

let both_are_scalar = lhs.get().1 && rhs.get().1;

Ok(if both_are_scalar {
Datum::Scalar(vector.scalar_at(0))
} else {
Datum::Vector(vector)
})
Operator::Eq => compare(&lhs, &rhs, compute::Operator::Eq),
Operator::NotEq => compare(&lhs, &rhs, compute::Operator::NotEq),
Operator::Lt => compare(&lhs, &rhs, compute::Operator::Lt),
Operator::Lte => compare(&lhs, &rhs, compute::Operator::Lte),
Operator::Gt => compare(&lhs, &rhs, compute::Operator::Gt),
Operator::Gte => compare(&lhs, &rhs, compute::Operator::Gte),
Operator::And => and_kleene(&lhs, &rhs),
Operator::Or => or_kleene(&lhs, &rhs),
Operator::Add => add(&lhs, &rhs),
Operator::Sub => sub(&lhs, &rhs),
Operator::Mul => mul(&lhs, &rhs),
Operator::Div => div(&lhs, &rhs),
}?
.execute(args.ctx)
}

fn stat_falsification(
Expand Down
12 changes: 8 additions & 4 deletions vortex-array/src/expr/exprs/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_err;
use vortex_proto::expr as pb;
use vortex_vector::Datum;

use crate::ArrayRef;
use crate::compute::cast as compute_cast;
use crate::expr::Arity;
use crate::expr::ChildName;
use crate::expr::ExecutionArgs;
use crate::expr::ExecutionResult;
use crate::expr::ExprId;
use crate::expr::ReduceCtx;
use crate::expr::ReduceNode;
Expand Down Expand Up @@ -93,12 +93,16 @@ impl VTable for Cast {
})
}

fn execute(&self, target_dtype: &DType, mut args: ExecutionArgs) -> VortexResult<Datum> {
fn execute(
&self,
target_dtype: &DType,
mut args: ExecutionArgs,
) -> VortexResult<ExecutionResult> {
let input = args
.datums
.inputs
.pop()
.vortex_expect("missing input for Cast expression");
vortex_compute::cast::Cast::cast(&input, target_dtype)
compute_cast(input.as_ref(), target_dtype)?.execute(args.ctx)
}

fn reduce(
Expand Down
25 changes: 11 additions & 14 deletions vortex-array/src/expr/exprs/dynamic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_scalar::Scalar;
use vortex_scalar::ScalarValue;
use vortex_vector::Datum;
use vortex_vector::Scalar as VectorScalar;
use vortex_vector::bool::BoolScalar;

use crate::Array;
use crate::ArrayRef;
Expand All @@ -29,6 +26,7 @@ use crate::expr::Arity;
use crate::expr::Binary;
use crate::expr::ChildName;
use crate::expr::ExecutionArgs;
use crate::expr::ExecutionResult;
use crate::expr::ExprId;
use crate::expr::Expression;
use crate::expr::StatsCatalog;
Expand Down Expand Up @@ -118,26 +116,25 @@ impl VTable for DynamicComparison {
.into_array())
}

fn execute(&self, data: &Self::Options, args: ExecutionArgs) -> VortexResult<Datum> {
fn execute(&self, data: &Self::Options, args: ExecutionArgs) -> VortexResult<ExecutionResult> {
if let Some(scalar) = data.rhs.scalar() {
let [lhs]: [Datum; _] = args
.datums
let [lhs]: [ArrayRef; _] = args
.inputs
.try_into()
.map_err(|_| vortex_error::vortex_err!("Wrong arg count for DynamicComparison"))?;
let rhs_vector_scalar = scalar.to_vector_scalar();
let rhs = Datum::Scalar(rhs_vector_scalar);
let rhs = ConstantArray::new(scalar.clone(), args.row_count).into_array();

return Binary.bind(data.operator.into()).execute(ExecutionArgs {
datums: vec![lhs, rhs],
dtypes: args.dtypes,
inputs: vec![lhs, rhs],
row_count: args.row_count,
return_dtype: args.return_dtype,
ctx: args.ctx,
});
}

Ok(Datum::Scalar(VectorScalar::Bool(BoolScalar::new(Some(
data.default,
)))))
Ok(ExecutionResult::Scalar(ConstantArray::new(
false,
args.row_count,
)))
}

fn stat_falsification(
Expand Down
42 changes: 19 additions & 23 deletions vortex-array/src/expr/exprs/get_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,17 @@ use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_err;
use vortex_proto::expr as pb;
use vortex_vector::Datum;
use vortex_vector::ScalarOps;
use vortex_vector::VectorOps;

use crate::ArrayRef;
use crate::ToCanonical;
use crate::arrays::StructArray;
use crate::builtins::ExprBuiltins;
use crate::compute::mask;
use crate::expr::Arity;
use crate::expr::ChildName;
use crate::expr::EmptyOptions;
use crate::expr::ExecutionArgs;
use crate::expr::ExecutionResult;
use crate::expr::ExprId;
use crate::expr::Expression;
use crate::expr::Literal;
Expand Down Expand Up @@ -119,26 +118,23 @@ impl VTable for GetItem {
}
}

fn execute(&self, field_name: &FieldName, mut args: ExecutionArgs) -> VortexResult<Datum> {
let struct_dtype = args.dtypes[0]
.as_struct_fields_opt()
.ok_or_else(|| vortex_err!("Expected struct dtype for child of GetItem expression"))?;
let field_idx = struct_dtype
.find(field_name)
.ok_or_else(|| vortex_err!("Field {} not found in struct dtype", field_name))?;

match args.datums.pop().vortex_expect("missing input") {
Datum::Scalar(s) => {
let mut field = s.as_struct().field(field_idx);
field.mask_validity(s.is_valid());
Ok(Datum::Scalar(field))
}
Datum::Vector(v) => {
let mut field = v.as_struct().fields()[field_idx].clone();
field.mask_validity(v.validity());
Ok(Datum::Vector(field))
}
}
fn execute(
&self,
field_name: &FieldName,
mut args: ExecutionArgs,
) -> VortexResult<ExecutionResult> {
let input = args
.inputs
.pop()
.vortex_expect("missing input for GetItem expression")
.execute::<StructArray>(args.ctx)?;
let field = input.field_by_name(field_name).cloned()?;

match input.dtype().nullability() {
Nullability::NonNullable => Ok(field),
Nullability::Nullable => mask(&field, &input.validity_mask().not()),
}?
.execute(args.ctx)
}

fn reduce(
Expand Down
31 changes: 17 additions & 14 deletions vortex-array/src/expr/exprs/is_null.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,11 @@
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::fmt::Formatter;
use std::ops::Not;

use vortex_dtype::DType;
use vortex_dtype::Nullability;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_mask::Mask;
use vortex_vector::Datum;
use vortex_vector::ScalarOps;
use vortex_vector::VectorOps;
use vortex_vector::bool::BoolScalar;
use vortex_vector::bool::BoolVector;

use crate::Array;
use crate::ArrayRef;
Expand All @@ -24,6 +17,7 @@ use crate::expr::Arity;
use crate::expr::ChildName;
use crate::expr::EmptyOptions;
use crate::expr::ExecutionArgs;
use crate::expr::ExecutionResult;
use crate::expr::ExprId;
use crate::expr::Expression;
use crate::expr::StatsCatalog;
Expand Down Expand Up @@ -94,13 +88,22 @@ impl VTable for IsNull {
})
}

fn execute(&self, _data: &Self::Options, mut args: ExecutionArgs) -> VortexResult<Datum> {
let child = args.datums.pop().vortex_expect("Missing input child");
Ok(match child {
Datum::Scalar(s) => Datum::Scalar(BoolScalar::new(Some(s.is_null())).into()),
Datum::Vector(v) => Datum::Vector(
BoolVector::new(v.validity().to_bit_buffer().not(), Mask::new_true(v.len())).into(),
),
fn execute(
&self,
_data: &Self::Options,
mut args: ExecutionArgs,
) -> VortexResult<ExecutionResult> {
let child = args.inputs.pop().vortex_expect("Missing input child");
if let Some(scalar) = child.as_constant() {
return Ok(ExecutionResult::constant(scalar.is_null(), args.row_count));
}

Ok(match child.validity()? {
Validity::NonNullable | Validity::AllValid => {
ExecutionResult::constant(false, args.row_count)
}
Validity::AllInvalid => ExecutionResult::constant(true, args.row_count),
Validity::Array(a) => a.not()?.execute(args.ctx)?,
})
}

Expand Down
Loading
Loading