@@ -27,10 +27,10 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarBatch}
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarArray, ColumnarBatch}
 
 import org.apache.comet.CometArrowAllocator
-import org.apache.comet.vector.NativeUtil
+import org.apache.comet.vector.{CometVector, NativeUtil}
 
 object CometArrowConverters extends Logging {
   // This is similar how Spark converts internal row to Arrow format except that it is transforming
@@ -185,6 +185,29 @@ object CometArrowConverters extends Logging {
     }
   }
 
+  /**
+   * Attempts zero-copy conversion of a ColumnarBatch whose columns are all ArrowColumnVector
+   * instances. Returns Some(iterator) if successful, None if the batch is not Arrow-backed.
+   */
+  def tryZeroCopyConvert(batch: ColumnarBatch): Option[Iterator[ColumnarBatch]] = {
+    val numCols = batch.numCols()
+    if (numCols == 0) return None
+
+    // Check that every column is an ArrowColumnVector
+    var i = 0
+    while (i < numCols) {
+      if (!batch.column(i).isInstanceOf[ArrowColumnVector]) return None
+      i += 1
+    }
+
+    // All columns are Arrow-backed; wrap their ValueVectors as CometVectors (zero-copy)
+    val cometVectors = (0 until numCols).map { idx =>
+      val valueVector = batch.column(idx).asInstanceOf[ArrowColumnVector].getValueVector
+      CometVector.getVector(valueVector, true, null)
+    }
+    Some(Iterator(new ColumnarBatch(cometVectors.toArray, batch.numRows())))
+  }
+
   def columnarBatchToArrowBatchIter(
       colBatch: ColumnarBatch,
       schema: StructType,
@@ -106,13 +106,16 @@ case class CometSparkToColumnarExec(child: SparkPlan)
       .mapPartitionsInternal { sparkBatches =>
         val arrowBatches =
           sparkBatches.flatMap { sparkBatch =>
-            val context = TaskContext.get()
-            CometArrowConverters.columnarBatchToArrowBatchIter(
-              sparkBatch,
-              schema,
-              maxRecordsPerBatch,
-              timeZoneId,
-              context)
+            CometArrowConverters.tryZeroCopyConvert(sparkBatch).getOrElse {
+              // Fallback: element-by-element copy via ArrowWriter
+              val context = TaskContext.get()
+              CometArrowConverters.columnarBatchToArrowBatchIter(
+                sparkBatch,
+                schema,
+                maxRecordsPerBatch,
+                timeZoneId,
+                context)
+            }
           }
         createTimingIter(arrowBatches, numInputRows, numOutputBatches, conversionTime)
       }
spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala (72 additions, 0 deletions)
@@ -2159,6 +2159,78 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("SparkToColumnar zero-copy for ArrowColumnVector input") {
+    import org.apache.arrow.memory.RootAllocator
+    import org.apache.arrow.vector.{IntVector, VarCharVector}
+    import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
+    import org.apache.spark.sql.comet.execution.arrow.CometArrowConverters
+    import org.apache.comet.vector.CometVector
+
+    val allocator = new RootAllocator(Long.MaxValue)
+    try {
+      // Create Arrow vectors with test data
+      val intVector = new IntVector("intCol", allocator)
+      intVector.allocateNew(3)
+      intVector.set(0, 10)
+      intVector.set(1, 20)
+      intVector.setNull(2)
+      intVector.setValueCount(3)
+
+      val varcharVector = new VarCharVector("strCol", allocator)
+      varcharVector.allocateNew()
+      varcharVector.setSafe(0, "hello".getBytes)
+      varcharVector.setSafe(1, "world".getBytes)
+      varcharVector.setNull(2)
+      varcharVector.setValueCount(3)
+
+      // Wrap in Spark's ArrowColumnVector
+      val arrowCol0 = new ArrowColumnVector(intVector)
+      val arrowCol1 = new ArrowColumnVector(varcharVector)
+      val inputBatch = new ColumnarBatch(Array(arrowCol0, arrowCol1), 3)
+
+      // Zero-copy conversion should succeed
+      val result = CometArrowConverters.tryZeroCopyConvert(inputBatch)
+      assert(result.isDefined, "Should detect ArrowColumnVector and return Some")
+
+      val outputBatch = result.get.next()
+      assert(outputBatch.numRows() == 3)
+      assert(outputBatch.numCols() == 2)
+
+      // Verify columns are CometVectors wrapping the same underlying ValueVectors (zero-copy)
+      val outCol0 = outputBatch.column(0).asInstanceOf[CometVector]
+      val outCol1 = outputBatch.column(1).asInstanceOf[CometVector]
+      assert(outCol0.getValueVector eq intVector, "Should be the same ValueVector instance")
+      assert(outCol1.getValueVector eq varcharVector, "Should be the same ValueVector instance")
+
+      // Verify data is accessible through the CometVector wrappers
+      assert(outCol0.getInt(0) == 10)
+      assert(outCol0.getInt(1) == 20)
+      assert(outCol0.isNullAt(2))
+      assert(outCol1.getUTF8String(0).toString == "hello")
+      assert(outCol1.getUTF8String(1).toString == "world")
+      assert(outCol1.isNullAt(2))
+
+      inputBatch.close()
+    } finally {
+      allocator.close()
+    }
+  }

test("SparkToColumnar tryZeroCopyConvert returns None for non-Arrow batches") {
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.sql.comet.execution.arrow.CometArrowConverters
import org.apache.spark.sql.types.IntegerType

val sparkCol = new OnHeapColumnVector(10, IntegerType)
val batch = new ColumnarBatch(Array(sparkCol), 10)

val result = CometArrowConverters.tryZeroCopyConvert(batch)
assert(result.isEmpty, "Should return None for non-ArrowColumnVector batches")

batch.close()
}

test("LocalTableScanExec spark fallback") {
withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "false") {
val df = Seq.range(0, 10).toDF("id")
Expand Down