diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md
index 1a273ad033..8bfb4fb2ea 100644
--- a/docs/source/user-guide/latest/configs.md
+++ b/docs/source/user-guide/latest/configs.md
@@ -278,6 +278,7 @@ These settings can be used to determine which parts of the plan are accelerated
 | `spark.comet.expression.MakeDecimal.enabled` | Enable Comet acceleration for `MakeDecimal` | true |
 | `spark.comet.expression.MapEntries.enabled` | Enable Comet acceleration for `MapEntries` | true |
 | `spark.comet.expression.MapFromArrays.enabled` | Enable Comet acceleration for `MapFromArrays` | true |
+| `spark.comet.expression.MapFromEntries.enabled` | Enable Comet acceleration for `MapFromEntries` | true |
 | `spark.comet.expression.MapKeys.enabled` | Enable Comet acceleration for `MapKeys` | true |
 | `spark.comet.expression.MapValues.enabled` | Enable Comet acceleration for `MapValues` | true |
 | `spark.comet.expression.Md5.enabled` | Enable Comet acceleration for `Md5` | true |
diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala
index f8d591be28..32cba46e94 100644
--- a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala
+++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala
@@ -50,6 +50,19 @@ case class Function(name: String, signatures: Seq[FunctionSignature])
 
 object Meta {
 
+  val primitiveSparkTypes: Seq[SparkType] = Seq(
+    SparkBooleanType,
+    SparkBinaryType,
+    SparkStringType,
+    SparkByteType,
+    SparkShortType,
+    SparkIntType,
+    SparkLongType,
+    SparkFloatType,
+    SparkDoubleType,
+    SparkDateType,
+    SparkTimestampType)
+
   val dataTypes: Seq[(DataType, Double)] = Seq(
     (DataTypes.BooleanType, 0.1),
     (DataTypes.ByteType, 0.2),
@@ -212,7 +225,14 @@ object Meta {
     createFunctionWithInputTypes("map_values", Seq(SparkMapType(SparkAnyType, SparkAnyType))),
     createFunctionWithInputTypes(
       "map_from_arrays",
-      Seq(SparkArrayType(SparkAnyType), SparkArrayType(SparkAnyType))))
+      Seq(SparkArrayType(SparkAnyType), SparkArrayType(SparkAnyType))),
+    createFunctionWithInputTypes(
+      "map_from_entries",
+      Seq(
+        SparkArrayType(
+          SparkStructType(Seq(
+            SparkTypeOneOf(primitiveSparkTypes.filterNot(_ == SparkBinaryType)),
+            SparkTypeOneOf(primitiveSparkTypes.filterNot(_ == SparkBinaryType))))))))
 
   // Predicate expressions (corresponds to predicateExpressions in QueryPlanSerde)
   val predicateScalarFunc: Seq[Function] = Seq(
diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index 04f7c809ae..f067264444 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -46,6 +46,7 @@ use datafusion_spark::function::datetime::date_add::SparkDateAdd;
 use datafusion_spark::function::datetime::date_sub::SparkDateSub;
 use datafusion_spark::function::hash::sha1::SparkSha1;
 use datafusion_spark::function::hash::sha2::SparkSha2;
+use datafusion_spark::function::map::map_from_entries::MapFromEntries;
 use datafusion_spark::function::math::expm1::SparkExpm1;
 use datafusion_spark::function::math::hex::SparkHex;
 use datafusion_spark::function::string::char::CharFunc;
@@ -342,6 +343,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) {
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkConcat::default()));
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseNot::default()));
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkHex::default()));
+    session_ctx.register_udf(ScalarUDF::new_from_impl(MapFromEntries::default()));
 }
 
 /// Prepares arrow arrays for output.
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index e50b1d80e6..86f610e08d 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -126,7 +126,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[MapKeys] -> CometMapKeys,
     classOf[MapEntries] -> CometMapEntries,
     classOf[MapValues] -> CometMapValues,
-    classOf[MapFromArrays] -> CometMapFromArrays)
+    classOf[MapFromArrays] -> CometMapFromArrays,
+    classOf[MapFromEntries] -> CometMapFromEntries)
 
   private val structExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
     classOf[CreateNamedStruct] -> CometCreateNamedStruct,
diff --git a/spark/src/main/scala/org/apache/comet/serde/maps.scala b/spark/src/main/scala/org/apache/comet/serde/maps.scala
index 2e217f6af0..78b2180756 100644
--- a/spark/src/main/scala/org/apache/comet/serde/maps.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/maps.scala
@@ -20,7 +20,7 @@
 package org.apache.comet.serde
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{ArrayType, MapType}
+import org.apache.spark.sql.types._
 
 import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType}
 
@@ -89,3 +89,27 @@ object CometMapFromArrays extends CometExpressionSerde[MapFromArrays] {
     optExprWithInfo(mapFromArraysExpr, expr, expr.children: _*)
   }
 }
+
+object CometMapFromEntries extends CometScalarFunction[MapFromEntries]("map_from_entries") {
+  val keyUnsupportedReason = "Using BinaryType as Map keys is not allowed in map_from_entries"
+  val valueUnsupportedReason = "Using BinaryType as Map values is not allowed in map_from_entries"
+
+  private def containsBinary(dataType: DataType): Boolean = {
+    dataType match {
+      case BinaryType => true
+      case StructType(fields) => fields.exists(field => containsBinary(field.dataType))
+      case ArrayType(elementType, _) => containsBinary(elementType)
+      case _ => false
+    }
+  }
+
+  override def getSupportLevel(expr: MapFromEntries): SupportLevel = {
+    if (containsBinary(expr.dataType.keyType)) {
+      return Incompatible(Some(keyUnsupportedReason))
+    }
+    if (containsBinary(expr.dataType.valueType)) {
+      return Incompatible(Some(valueUnsupportedReason))
+    }
+    Compatible(None)
+  }
+}
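
Reviewer note (not part of the patch): the `containsBinary` check above is recursive, so the Spark fallback also triggers when `BinaryType` is nested inside a struct or array key/value type. A minimal standalone sketch of the same decision logic, assuming only `spark-catalyst` on the classpath and pasted into a Spark shell; the helper simply mirrors the private method and is illustrative:

```scala
import org.apache.spark.sql.types._

// Re-statement of CometMapFromEntries' private recursive BinaryType check.
def containsBinary(dataType: DataType): Boolean = dataType match {
  case BinaryType => true
  case StructType(fields) => fields.exists(field => containsBinary(field.dataType))
  case ArrayType(elementType, _) => containsBinary(elementType)
  case _ => false
}

// A binary-bearing array key/value is rejected; a struct of strings is accepted.
assert(containsBinary(ArrayType(BinaryType)))
assert(!containsBinary(StructType(Seq(StructField("a", StringType)))))
```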
diff --git a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala
index 9276a20348..5a1c31a488 100644
--- a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala
@@ -25,7 +25,9 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.BinaryType
 
+import org.apache.comet.serde.CometMapFromEntries
 import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions}
 
 class CometMapExpressionSuite extends CometTestBase {
@@ -157,4 +159,61 @@ class CometMapExpressionSuite extends CometTestBase {
     }
   }
 
+  test("map_from_entries") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        val schemaGenOptions =
+          SchemaGenOptions(
+            generateArray = false,
+            generateStruct = false,
+            primitiveTypes = SchemaGenOptions.defaultPrimitiveTypes.filterNot(_ == BinaryType))
+        val dataGenOptions = DataGenOptions(allowNull = false, generateNegativeZero = false)
+        ParquetGenerator.makeParquetFile(
+          random,
+          spark,
+          filename,
+          100,
+          schemaGenOptions,
+          dataGenOptions)
+      }
+      withSQLConf(
+        CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false",
+        CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
+        CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true") {
+        val df = spark.read.parquet(filename)
+        df.createOrReplaceTempView("t1")
+        for (field <- df.schema.fieldNames) {
+          checkSparkAnswerAndOperator(
+            spark.sql(
+              s"SELECT map_from_entries(array(struct($field as a, $field as b))) FROM t1"))
+        }
+      }
+    }
+  }
+
+  test("map_from_entries - fallback for binary type") {
+    def fallbackReason(reason: String) = {
+      if (CometConf.COMET_NATIVE_SCAN_IMPL.get() == CometConf.SCAN_NATIVE_COMET || sys.env
+          .getOrElse("COMET_PARQUET_SCAN_IMPL", "") == CometConf.SCAN_NATIVE_COMET) {
+        "Unsupported schema"
+      } else {
+        reason
+      }
+    }
+    val table = "t2"
+    withTable(table) {
+      sql(
+        s"create table $table using parquet as select cast(array() as array<binary>) as c1 from range(10)")
+      checkSparkAnswerAndFallbackReason(
+        sql(s"select map_from_entries(array(struct(c1, 0))) from $table"),
+        fallbackReason(CometMapFromEntries.keyUnsupportedReason))
+      checkSparkAnswerAndFallbackReason(
+        sql(s"select map_from_entries(array(struct(0, c1))) from $table"),
+        fallbackReason(CometMapFromEntries.valueUnsupportedReason))
+    }
+  }
+
 }
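
Reviewer note (not part of the patch): a minimal end-to-end sketch of exercising the new path, assuming a local Spark session with the Comet jar on the classpath; the session configuration below is illustrative rather than taken from this change:

```scala
import org.apache.spark.sql.SparkSession

// Local session with Comet enabled (illustrative settings).
val spark = SparkSession
  .builder()
  .master("local[1]")
  .config("spark.plugins", "org.apache.spark.CometPlugin")
  .config("spark.comet.enabled", "true")
  .getOrCreate()

// map_from_entries turns array<struct<k, v>> into map<k, v>. With this change it is
// serialized by CometMapFromEntries and executed natively via DataFusion's
// map_from_entries UDF; keys or values containing BinaryType fall back to Spark.
spark.sql("SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b')))").show()
```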