1 change: 1 addition & 0 deletions docs/source/user-guide/latest/configs.md
@@ -278,6 +278,7 @@ These settings can be used to determine which parts of the plan are accelerated
| `spark.comet.expression.MakeDecimal.enabled` | Enable Comet acceleration for `MakeDecimal` | true |
| `spark.comet.expression.MapEntries.enabled` | Enable Comet acceleration for `MapEntries` | true |
| `spark.comet.expression.MapFromArrays.enabled` | Enable Comet acceleration for `MapFromArrays` | true |
| `spark.comet.expression.MapFromEntries.enabled` | Enable Comet acceleration for `MapFromEntries` | true |
| `spark.comet.expression.MapKeys.enabled` | Enable Comet acceleration for `MapKeys` | true |
| `spark.comet.expression.MapValues.enabled` | Enable Comet acceleration for `MapValues` | true |
| `spark.comet.expression.Md5.enabled` | Enable Comet acceleration for `Md5` | true |
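Like the other expression-level flags in this table, the new entry can be disabled per session if the native implementation needs to be bypassed. A minimal sketch, assuming an active SparkSession named spark with the Comet plugin loaded:

// Sketch: fall back to Spark's own map_from_entries for this session.
spark.conf.set("spark.comet.expression.MapFromEntries.enabled", "false")

Queries using map_from_entries then evaluate the expression in Spark rather than natively, the same effect as the fallback path exercised in the test suite below.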
22 changes: 21 additions & 1 deletion fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala
@@ -50,6 +50,19 @@ case class Function(name: String, signatures: Seq[FunctionSignature])

object Meta {

val primitiveSparkTypes: Seq[SparkType] = Seq(
SparkBooleanType,
SparkBinaryType,
SparkStringType,
SparkByteType,
SparkShortType,
SparkIntType,
SparkLongType,
SparkFloatType,
SparkDoubleType,
SparkDateType,
SparkTimestampType)

val dataTypes: Seq[(DataType, Double)] = Seq(
(DataTypes.BooleanType, 0.1),
(DataTypes.ByteType, 0.2),
@@ -212,7 +225,14 @@ object Meta {
createFunctionWithInputTypes("map_values", Seq(SparkMapType(SparkAnyType, SparkAnyType))),
createFunctionWithInputTypes(
"map_from_arrays",
Seq(SparkArrayType(SparkAnyType), SparkArrayType(SparkAnyType))))
Seq(SparkArrayType(SparkAnyType), SparkArrayType(SparkAnyType))),
createFunctionWithInputTypes(
"map_from_entries",
Seq(
SparkArrayType(
SparkStructType(Seq(
SparkTypeOneOf(primitiveSparkTypes.filterNot(_ == SparkBinaryType)),
SparkTypeOneOf(primitiveSparkTypes.filterNot(_ == SparkBinaryType))))))))
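The signature registered above limits map_from_entries to an array of two-field structs whose key and value types are non-binary primitives, mirroring the support level declared in the serde layer. A hand-written call of the same shape, for illustration only (not fuzzer output):

// Sketch of a call matching the registered signature: an array of
// (key, value) structs with primitive, non-binary fields.
val m = spark.sql("SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'))) AS m")
m.show(truncate = false) // one row containing the map {1 -> a, 2 -> b}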

// Predicate expressions (corresponds to predicateExpressions in QueryPlanSerde)
val predicateScalarFunc: Seq[Function] = Seq(
2 changes: 2 additions & 0 deletions native/core/src/execution/jni_api.rs
@@ -46,6 +46,7 @@ use datafusion_spark::function::datetime::date_add::SparkDateAdd;
use datafusion_spark::function::datetime::date_sub::SparkDateSub;
use datafusion_spark::function::hash::sha1::SparkSha1;
use datafusion_spark::function::hash::sha2::SparkSha2;
use datafusion_spark::function::map::map_from_entries::MapFromEntries;
use datafusion_spark::function::math::expm1::SparkExpm1;
use datafusion_spark::function::math::hex::SparkHex;
use datafusion_spark::function::string::char::CharFunc;
@@ -342,6 +343,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) {
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkConcat::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseNot::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkHex::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(MapFromEntries::default()));
}

/// Prepares arrow arrays for output.
spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -126,7 +126,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
classOf[MapKeys] -> CometMapKeys,
classOf[MapEntries] -> CometMapEntries,
classOf[MapValues] -> CometMapValues,
classOf[MapFromArrays] -> CometMapFromArrays)
classOf[MapFromArrays] -> CometMapFromArrays,
classOf[MapFromEntries] -> CometMapFromEntries)

private val structExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
classOf[CreateNamedStruct] -> CometCreateNamedStruct,
26 changes: 25 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/maps.scala
@@ -20,7 +20,7 @@
package org.apache.comet.serde

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{ArrayType, MapType}
import org.apache.spark.sql.types._

import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType}

@@ -89,3 +89,27 @@ object CometMapFromArrays extends CometExpressionSerde[MapFromArrays] {
optExprWithInfo(mapFromArraysExpr, expr, expr.children: _*)
}
}

object CometMapFromEntries extends CometScalarFunction[MapFromEntries]("map_from_entries") {
val keyUnsupportedReason = "Using BinaryType as Map keys is not allowed in map_from_entries"
val valueUnsupportedReason = "Using BinaryType as Map values is not allowed in map_from_entries"

private def containsBinary(dataType: DataType): Boolean = {
dataType match {
case BinaryType => true
case StructType(fields) => fields.exists(field => containsBinary(field.dataType))
case ArrayType(elementType, _) => containsBinary(elementType)
case _ => false
}
}

override def getSupportLevel(expr: MapFromEntries): SupportLevel = {
if (containsBinary(expr.dataType.keyType)) {
return Incompatible(Some(keyUnsupportedReason))
}
if (containsBinary(expr.dataType.valueType)) {
return Incompatible(Some(valueUnsupportedReason))
}
Compatible(None)
}
}
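Because containsBinary is private, a standalone copy makes the recursion and the resulting support decision easier to follow; this is an illustrative sketch only, not part of the change:

import org.apache.spark.sql.types._

// Standalone copy of the private helper above, for illustration.
def containsBinary(dataType: DataType): Boolean = dataType match {
  case BinaryType => true
  case StructType(fields) => fields.exists(field => containsBinary(field.dataType))
  case ArrayType(elementType, _) => containsBinary(elementType)
  case _ => false
}

containsBinary(BinaryType) // true  -> Incompatible(Some(keyUnsupportedReason))
containsBinary(ArrayType(StructType(Seq(StructField("b", BinaryType))))) // true
containsBinary(StringType) // false -> Compatible(None)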
@@ -25,7 +25,9 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.sql.CometTestBase
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.BinaryType

import org.apache.comet.serde.CometMapFromEntries
import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions}

class CometMapExpressionSuite extends CometTestBase {
@@ -157,4 +159,61 @@ class CometMapExpressionSuite extends CometTestBase {
}
}

test("map_from_entries") {
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
val filename = path.toString
val random = new Random(42)
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
val schemaGenOptions =
SchemaGenOptions(
generateArray = false,
generateStruct = false,
primitiveTypes = SchemaGenOptions.defaultPrimitiveTypes.filterNot(_ == BinaryType))
val dataGenOptions = DataGenOptions(allowNull = false, generateNegativeZero = false)
ParquetGenerator.makeParquetFile(
random,
spark,
filename,
100,
schemaGenOptions,
dataGenOptions)
}
withSQLConf(
CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false",
CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true") {
val df = spark.read.parquet(filename)
df.createOrReplaceTempView("t1")
for (field <- df.schema.fieldNames) {
checkSparkAnswerAndOperator(
spark.sql(
s"SELECT map_from_entries(array(struct($field as a, $field as b))) FROM t1"))
}
}
}
}

test("map_from_entries - fallback for binary type") {
def fallbackReason(reason: String) = {
if (CometConf.COMET_NATIVE_SCAN_IMPL.get() == CometConf.SCAN_NATIVE_COMET || sys.env
.getOrElse("COMET_PARQUET_SCAN_IMPL", "") == CometConf.SCAN_NATIVE_COMET) {
"Unsupported schema"
} else {
reason
}
}
val table = "t2"
withTable(table) {
sql(
s"create table $table using parquet as select cast(array() as array<binary>) as c1 from range(10)")
checkSparkAnswerAndFallbackReason(
sql(s"select map_from_entries(array(struct(c1, 0))) from $table"),
fallbackReason(CometMapFromEntries.keyUnsupportedReason))
checkSparkAnswerAndFallbackReason(
sql(s"select map_from_entries(array(struct(0, c1))) from $table"),
fallbackReason(CometMapFromEntries.valueUnsupportedReason))
}
}

}