Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions vortex-cuda/src/canonical.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use async_trait::async_trait;
use vortex_array::Canonical;
use vortex_array::arrays::BoolArray;
use vortex_array::arrays::BoolArrayParts;
use vortex_array::arrays::DecimalArray;
use vortex_array::arrays::DecimalArrayParts;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::PrimitiveArrayParts;
use vortex_array::buffer::BufferHandle;
use vortex_error::VortexResult;

/// Move all canonical data from to_host from device.
#[async_trait]
pub trait CanonicalCudaExt {
async fn to_host(self) -> VortexResult<Self>
where
Self: Sized;
}

#[async_trait]
impl CanonicalCudaExt for Canonical {
async fn to_host(self) -> VortexResult<Self> {
match self {
n @ Canonical::Null(_) => Ok(n),
Canonical::Bool(bool) => {
// NOTE: update to copy to host when adding buffer handle.
// Also update other method to copy validity to host.
let BoolArrayParts { bits, validity, .. } = bool.into_parts();
Ok(Canonical::Bool(BoolArray::from_bit_buffer(bits, validity)))
}
Canonical::Primitive(prim) => {
let PrimitiveArrayParts {
ptype,
buffer,
validity,
..
} = prim.into_parts();
Ok(Canonical::Primitive(PrimitiveArray::from_byte_buffer(
buffer.try_into_host()?.await?,
ptype,
validity,
)))
}
Canonical::Decimal(decimal) => {
let DecimalArrayParts {
decimal_dtype,
values,
values_type,
validity,
..
} = decimal.into_parts();
Ok(Canonical::Decimal(unsafe {
DecimalArray::new_unchecked_handle(
BufferHandle::new_host(values.try_into_host()?.await?),
values_type,
decimal_dtype,
validity,
)
}))
}
_ => todo!(),
}
}
}
2 changes: 2 additions & 0 deletions vortex-cuda/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@

//! CUDA support for Vortex arrays.

mod canonical;
mod device_buffer;
pub mod executor;
mod kernel;
mod session;
mod stream;

pub use canonical::CanonicalCudaExt;
pub use device_buffer::CudaBufferExt;
pub use device_buffer::CudaDeviceBuffer;
pub use executor::CudaExecutionCtx;
Expand Down
Loading