Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 13 additions & 27 deletions vortex-array/src/arrays/scalar_fn/vtable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use std::hash::Hasher;
use std::marker::PhantomData;
use std::ops::Deref;
use std::ops::Range;
use std::sync::Arc;

use itertools::Itertools;
use vortex_dtype::DType;
Expand All @@ -35,11 +34,12 @@ use crate::executor::ExecutionCtx;
use crate::expr;
use crate::expr::Arity;
use crate::expr::ChildName;
use crate::expr::ExecutionArgs;
use crate::expr::ExecutionResult;
use crate::expr::ExprId;
use crate::expr::Expression;
use crate::expr::ScalarFn;
use crate::expr::VTableExt;
use crate::expr::lit;
use crate::matchers::Matcher;
use crate::serde::ArrayChildren;
use crate::vtable;
Expand Down Expand Up @@ -127,29 +127,16 @@ impl VTable for ScalarFnVTable {
}

fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<Canonical> {
let inputs: Arc<[_]> = array
.children
.iter()
.map(|child| {
if let Some(scalar) = child.as_constant() {
return Ok(lit(scalar));
}
Expression::try_new(
ScalarFn::new(
ArrayExpr,
FakeEq(child.clone().execute::<Canonical>(ctx)?.into_array()),
),
[],
)
})
.collect::<VortexResult<_>>()?;
let args = ExecutionArgs {
inputs: array.children.clone(),
row_count: array.len,
ctx,
};

array
.scalar_fn
.evaluate(
&Expression::try_new(array.scalar_fn.clone(), inputs)?,
&array.to_array(),
)?
.execute(args)?
.into_array()
.execute::<Canonical>(ctx)
}

Expand Down Expand Up @@ -340,13 +327,12 @@ impl expr::VTable for ArrayExpr {
Ok(options.0.dtype().clone())
}

fn evaluate(
fn execute(
&self,
options: &Self::Options,
_expr: &Expression,
_scope: &ArrayRef,
) -> VortexResult<ArrayRef> {
Ok(options.0.clone())
args: ExecutionArgs,
) -> VortexResult<ExecutionResult> {
crate::Executable::execute(options.0.clone(), args.ctx)
}

fn validity(
Expand Down
47 changes: 27 additions & 20 deletions vortex-array/src/arrays/scalar_fn/vtable/operations.rs
Original file line number Diff line number Diff line change
@@ -1,42 +1,49 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::sync::Arc;

use vortex_error::VortexExpect;
use vortex_scalar::Scalar;

use crate::Array;
use crate::IntoArray;
use crate::LEGACY_SESSION;
use crate::VortexSessionExecute;
use crate::arrays::ConstantArray;
use crate::arrays::scalar_fn::array::ScalarFnArray;
use crate::arrays::scalar_fn::vtable::ScalarFnVTable;
use crate::expr::Expression;
use crate::expr::lit;
use crate::expr::ExecutionArgs;
use crate::expr::ExecutionResult;
use crate::vtable::OperationsVTable;

impl OperationsVTable<ScalarFnVTable> for ScalarFnVTable {
fn scalar_at(array: &ScalarFnArray, index: usize) -> Scalar {
// TODO(ngates): we should evaluate the scalar function over the scalar inputs.
let inputs: Arc<[_]> = array
let inputs: Vec<_> = array
.children
.iter()
.map(|child| lit(child.scalar_at(index)))
.map(|child| ConstantArray::new(child.scalar_at(index), 1).into_array())
.collect::<_>();

let mut ctx = LEGACY_SESSION.create_execution_ctx();
let args = ExecutionArgs {
inputs,
row_count: 1,
ctx: &mut ctx,
};

let result = array
.scalar_fn
.evaluate(
&Expression::try_new(array.scalar_fn.clone(), inputs)
.vortex_expect("create expr must not fail"),
&array.to_array(),
)
.vortex_expect("execute cannot fail");
.execute(args)
.vortex_expect("todo vortex result return");

result.as_constant().unwrap_or_else(|| {
tracing::info!(
"Scalar function {} returned non-constant array from execution over all scalar inputs",
array.scalar_fn,
);
result.scalar_at(0)
})
match result {
ExecutionResult::Array(arr) => {
tracing::info!(
"Scalar function {} returned non-constant array from execution over all scalar inputs",
array.scalar_fn,
);
arr.as_ref().scalar_at(0)
}
ExecutionResult::Scalar(scalar) => scalar.scalar().clone(),
}
}
}
45 changes: 39 additions & 6 deletions vortex-array/src/arrays/scalar_fn/vtable/validity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,57 @@ use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_mask::Mask;

use crate::ArrayRef;
use crate::IntoArray;
use crate::LEGACY_SESSION;
use crate::VortexSessionExecute;
use crate::arrays::NullArray;
use crate::arrays::scalar_fn::array::ScalarFnArray;
use crate::arrays::scalar_fn::vtable::ArrayExpr;
use crate::arrays::scalar_fn::vtable::FakeEq;
use crate::arrays::scalar_fn::vtable::ScalarFnVTable;
use crate::executor::CanonicalOutput;
use crate::expr::ExecutionArgs;
use crate::expr::Expression;
use crate::expr::Literal;
use crate::expr::Root;
use crate::expr::ScalarFn;
use crate::expr::lit;
use crate::validity::Validity;
use crate::vtable::ValidityVTable;

/// Execute an expression tree recursively.
///
/// This assumes all leaf expressions are either ArrayExpr (wrapping actual arrays) or Literals.
fn execute_expr(expr: &Expression, row_count: usize) -> VortexResult<ArrayRef> {
let mut ctx = LEGACY_SESSION.create_execution_ctx();

// Handle Root expression - this should not happen in validity expressions
if expr.is::<Root>() {
vortex_error::vortex_bail!("Root expression cannot be executed in validity context");
}

// Handle Literal expression - create a constant array
if expr.is::<Literal>() {
let scalar = expr.as_::<Literal>();
return Ok(crate::arrays::ConstantArray::new(scalar.clone(), row_count).into_array());
}

// Recursively execute child expressions to get input arrays
let inputs: Vec<ArrayRef> = expr
.children()
.iter()
.map(|child| execute_expr(child, row_count))
.collect::<VortexResult<_>>()?;

let args = ExecutionArgs {
inputs,
row_count,
ctx: &mut ctx,
};

Ok(expr.scalar_fn().execute(args)?.into_array())
}

impl ValidityVTable<ScalarFnVTable> for ScalarFnVTable {
fn validity(array: &ScalarFnArray) -> VortexResult<Validity> {
let inputs: Arc<[_]> = array
Expand All @@ -38,11 +74,8 @@ impl ValidityVTable<ScalarFnVTable> for ScalarFnVTable {
let expr = Expression::try_new(array.scalar_fn.clone(), inputs)?;
let validity_expr = array.scalar_fn().validity(&expr)?;

// We can evaluate the validity expression against an empty scope because we know all
// leaves are ArrayExpr.
Ok(Validity::Array(
validity_expr.evaluate(&NullArray::new(array.len()).into_array())?,
))
// Execute the validity expression. All leaves are ArrayExpr nodes.
Ok(Validity::Array(execute_expr(&validity_expr, array.len())?))
}

fn validity_mask(array: &ScalarFnArray) -> Mask {
Expand Down
14 changes: 14 additions & 0 deletions vortex-array/src/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,20 @@ impl Canonical {
}
}

pub fn dtype(&self) -> &DType {
match self {
Canonical::Null(c) => c.dtype(),
Canonical::Bool(c) => c.dtype(),
Canonical::Primitive(c) => c.dtype(),
Canonical::Decimal(c) => c.dtype(),
Canonical::VarBinView(c) => c.dtype(),
Canonical::List(c) => c.dtype(),
Canonical::FixedSizeList(c) => c.dtype(),
Canonical::Struct(c) => c.dtype(),
Canonical::Extension(c) => c.dtype(),
}
}

pub fn is_empty(&self) -> bool {
match self {
Canonical::Null(c) => c.is_empty(),
Expand Down
2 changes: 2 additions & 0 deletions vortex-array/src/compute/list_contains.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ pub(crate) fn warm_up_vtable() -> usize {
///
/// assert_eq!(matches.to_bool().bit_buffer(), &bitbuffer![false, true, false]);
/// ```
// TODO(joe): ensure that list_contains_scalar from (548303761b4270b583ef34f6ca6e3c2b134a242a)
// is implemented here.
pub fn list_contains(array: &dyn Array, value: &dyn Array) -> VortexResult<ArrayRef> {
LIST_CONTAINS_FN
.invoke(&InvocationArgs {
Expand Down
32 changes: 31 additions & 1 deletion vortex-array/src/expr/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,41 @@ impl Expression {
}

/// Evaluates the expression in the given scope, returning an array.
///
/// This is a convenience method that recursively evaluates child expressions
/// and calls the scalar function's execute method.
pub fn evaluate(&self, scope: &ArrayRef) -> VortexResult<ArrayRef> {
use crate::IntoArray;
use crate::LEGACY_SESSION;
use crate::VortexSessionExecute;
use crate::arrays::ConstantArray;
use crate::expr::ExecutionArgs;
use crate::expr::Literal;

if self.is::<Root>() {
return Ok(scope.clone());
}
self.scalar_fn.evaluate(self, scope)

if self.is::<Literal>() {
let scalar = self.as_::<Literal>();
return Ok(ConstantArray::new(scalar.clone(), scope.len()).into_array());
}

// Recursively evaluate child expressions to get input arrays
let inputs: Vec<ArrayRef> = self
.children
.iter()
.map(|child| child.evaluate(scope))
.try_collect()?;

let mut ctx = LEGACY_SESSION.create_execution_ctx();
let args = ExecutionArgs {
inputs,
row_count: scope.len(),
ctx: &mut ctx,
};

Ok(self.scalar_fn.execute(args)?.into_array())
}

/// Returns a new expression representing the validity mask output of this expression.
Expand Down
52 changes: 12 additions & 40 deletions vortex-array/src/expr/exprs/between.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@ use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_err;
use vortex_proto::expr as pb;
use vortex_vector::Datum;

use crate::ArrayRef;
use crate::Canonical;
use crate::arrays::ConstantVTable;
use crate::compute::BetweenOptions;
use crate::compute::between as between_compute;
use crate::expr::Arity;
use crate::expr::ChildName;
use crate::expr::ExecutionArgs;
use crate::expr::ExecutionResult;
use crate::expr::ExprId;
use crate::expr::StatsCatalog;
use crate::expr::VTable;
Expand Down Expand Up @@ -138,50 +140,20 @@ impl VTable for Between {
))
}

fn evaluate(
fn execute(
&self,
options: &Self::Options,
expr: &Expression,
scope: &ArrayRef,
) -> VortexResult<ArrayRef> {
let arr = expr.child(0).evaluate(scope)?;
let lower = expr.child(1).evaluate(scope)?;
let upper = expr.child(2).evaluate(scope)?;
between_compute(&arr, &lower, &upper, options)
}

fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult<Datum> {
let [arr, lower, upper]: [Datum; _] = args
.datums
args: ExecutionArgs,
) -> VortexResult<ExecutionResult> {
let [arr, lower, upper]: [ArrayRef; _] = args
.inputs
.try_into()
.map_err(|_| vortex_err!("Expected 3 arguments for Between expression",))?;
let [arr_dt, lower_dt, upper_dt]: [DType; _] = args
.dtypes
.try_into()
.map_err(|_| vortex_err!("Expected 3 dtypes for Between expression",))?;

let lower_bound = Binary
.bind(options.lower_strict.to_operator().into())
.execute(ExecutionArgs {
datums: vec![lower, arr.clone()],
dtypes: vec![lower_dt, arr_dt.clone()],
row_count: args.row_count,
return_dtype: args.return_dtype.clone(),
})?;
let upper_bound = Binary
.bind(options.upper_strict.to_operator().into())
.execute(ExecutionArgs {
datums: vec![arr, upper],
dtypes: vec![arr_dt, upper_dt],
row_count: args.row_count,
return_dtype: args.return_dtype.clone(),
})?;

Binary.bind(Operator::And).execute(ExecutionArgs {
datums: vec![lower_bound, upper_bound],
dtypes: vec![args.return_dtype.clone(), args.return_dtype.clone()],
row_count: args.row_count,
return_dtype: args.return_dtype,
let result = between_compute(arr.as_ref(), lower.as_ref(), upper.as_ref(), options)?;
Ok(match result.try_into::<ConstantVTable>() {
Ok(constant) => ExecutionResult::Scalar(constant),
Err(arr) => ExecutionResult::Array(arr.execute::<Canonical>(args.ctx)?),
})
}

Expand Down
Loading
Loading