From 85ae6cfa356ef3e58fbe1bfd1a52ee582af4bdaf Mon Sep 17 00:00:00 2001 From: Rafael Fernandez Date: Thu, 12 Feb 2026 12:03:34 +0100 Subject: [PATCH 1/9] add support aes_decrypt --- docs/spark_expressions_support.md | 2 +- .../apache/comet/serde/QueryPlanSerde.scala | 1 + .../comet/CometStringExpressionSuite.scala | 40 +++++++++++++++++++ 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/docs/spark_expressions_support.md b/docs/spark_expressions_support.md index 2c18cbd08d..3961a91852 100644 --- a/docs/spark_expressions_support.md +++ b/docs/spark_expressions_support.md @@ -353,7 +353,7 @@ ### misc_funcs -- [ ] aes_decrypt +- [x] aes_decrypt - [ ] aes_encrypt - [ ] assert_true - [x] current_catalog diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 960aff8702..b00e7ce2cf 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -146,6 +146,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[Sha1] -> CometSha1) private val stringExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( + classOf[AesDecrypt] -> CometScalarFunction("aes_decrypt"), classOf[Ascii] -> CometScalarFunction("ascii"), classOf[BitLength] -> CometScalarFunction("bit_length"), classOf[Chr] -> CometScalarFunction("char"), diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index 2a2932c643..197cf20194 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -378,4 +378,44 @@ class CometStringExpressionSuite extends CometTestBase { } } + test("aes_decrypt") { + withTable("aes_tbl") { + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + sql(""" + 
|CREATE TABLE aes_tbl( + | encrypted_default BINARY, + | encrypted_with_aad BINARY, + | `key` BINARY, + | mode STRING, + | padding STRING, + | aad BINARY + |) USING parquet + |""".stripMargin) + + sql(""" + |INSERT INTO aes_tbl + |SELECT + | aes_encrypt( + | encode('Spark SQL', 'UTF-8'), + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8')), + | aes_encrypt( + | encode('Spark SQL', 'UTF-8'), + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + | 'GCM', + | 'DEFAULT', + | unhex('00112233445566778899AABB')), + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + | 'GCM', + | 'DEFAULT', + | unhex('00112233445566778899AABB') + |""".stripMargin) + } + + checkSparkAnswerAndOperator( + "SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl") + checkSparkAnswerAndOperator( + "SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding, aad) AS STRING) FROM aes_tbl") + } + } + } From 85489b0329d1b44b0c43f73aed075d75e064345e Mon Sep 17 00:00:00 2001 From: Rafael Fernandez Date: Thu, 12 Feb 2026 12:33:38 +0100 Subject: [PATCH 2/9] add cargo clean before build with miri --- .github/workflows/miri.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/miri.yml b/.github/workflows/miri.yml index ea36e1359a..7c193d6d12 100644 --- a/.github/workflows/miri.yml +++ b/.github/workflows/miri.yml @@ -60,4 +60,5 @@ jobs: - name: Test with Miri run: | cd native + cargo clean --target-dir target/miri MIRIFLAGS="-Zmiri-disable-isolation" cargo miri test --lib --bins --tests --examples From eb318d13e4e9e4b32d0d88b57605a96541442b75 Mon Sep 17 00:00:00 2001 From: Rafael Fernandez Date: Fri, 13 Feb 2026 14:09:07 +0100 Subject: [PATCH 3/9] tests --- .github/workflows/miri.yml | 1 - .../expressions/misc/aes_decrypt.sql | 54 ++++++++++ .../comet/CometMiscExpressionSuite.scala | 102 ++++++++++++++++++ .../comet/CometStringExpressionSuite.scala | 40 ------- 4 files changed, 156 insertions(+), 41 deletions(-) create mode 100644 
spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql create mode 100644 spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala diff --git a/.github/workflows/miri.yml b/.github/workflows/miri.yml index 1cb3d22367..c9ee6abdd9 100644 --- a/.github/workflows/miri.yml +++ b/.github/workflows/miri.yml @@ -62,5 +62,4 @@ jobs: - name: Test with Miri run: | cd native - cargo clean --target-dir target/miri MIRIFLAGS="-Zmiri-disable-isolation" cargo miri test --lib --bins --tests --examples diff --git a/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql b/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql new file mode 100644 index 0000000000..f89f24d1ee --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql @@ -0,0 +1,54 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
+ +-- MinSparkVersion: 3.5 + +statement +CREATE TABLE aes_tbl( + encrypted_default BINARY, + encrypted_with_aad BINARY, + `key` BINARY, + mode STRING, + padding STRING, + iv BINARY, + aad STRING +) USING parquet + +statement +INSERT INTO aes_tbl +SELECT + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8')), + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'GCM', + 'DEFAULT', + unhex('00112233445566778899AABB'), + 'Comet AAD'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'GCM', + 'DEFAULT', + unhex('00112233445566778899AABB'), + 'Comet AAD' + +query expect_fallback(Static invoke expression: aesDecrypt is not supported) +SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl + +query expect_fallback(Static invoke expression: aesDecrypt is not supported) +SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding, aad) AS STRING) FROM aes_tbl diff --git a/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala new file mode 100644 index 0000000000..d87a42d328 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import org.apache.spark.sql.CometTestBase + +import org.apache.comet.CometSparkSessionExtensions.isSpark35Plus + +class CometMiscExpressionSuite extends CometTestBase { + + test("aes_decrypt") { + withTable("aes_tbl") { + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + sql(""" + |CREATE TABLE aes_tbl( + | encrypted_default BINARY, + | encrypted_with_aad BINARY, + | `key` BINARY, + | mode STRING, + | padding STRING, + | iv BINARY, + | aad STRING + |) USING parquet + |""".stripMargin) + + if (isSpark35Plus) { + sql(""" + |INSERT INTO aes_tbl + |SELECT + | aes_encrypt( + | encode('Spark SQL', 'UTF-8'), + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8')), + | aes_encrypt( + | encode('Spark SQL', 'UTF-8'), + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + | 'GCM', + | 'DEFAULT', + | unhex('00112233445566778899AABB'), + | 'Comet AAD'), + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + | 'GCM', + | 'DEFAULT', + | unhex('00112233445566778899AABB'), + | 'Comet AAD' + |""".stripMargin) + } else { + sql(""" + |INSERT INTO aes_tbl + |SELECT + | aes_encrypt( + | encode('Spark SQL', 'UTF-8'), + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8')), + | aes_encrypt( + | encode('Spark SQL', 'UTF-8'), + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + | 'GCM', + | 'DEFAULT'), + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + | 'GCM', + | 'DEFAULT', + | cast(null as binary), + | cast(null as string) + |""".stripMargin) + } + } + + if (isSpark35Plus) { + 
checkSparkAnswerAndFallbackReason( + "SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl", + "Static invoke expression: aesDecrypt is not supported") + checkSparkAnswerAndFallbackReason( + "SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding, aad) AS STRING) FROM aes_tbl", + "Static invoke expression: aesDecrypt is not supported") + } else { + checkSparkAnswerAndFallbackReason( + "SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl", + "Static invoke expression: aesDecrypt is not supported") + checkSparkAnswerAndFallbackReason( + "SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding) AS STRING) FROM aes_tbl", + "Static invoke expression: aesDecrypt is not supported") + } + } + } + +} diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index d071c595ff..121d7f7d5a 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -478,44 +478,4 @@ class CometStringExpressionSuite extends CometTestBase { } } - test("aes_decrypt") { - withTable("aes_tbl") { - withSQLConf(CometConf.COMET_ENABLED.key -> "false") { - sql(""" - |CREATE TABLE aes_tbl( - | encrypted_default BINARY, - | encrypted_with_aad BINARY, - | `key` BINARY, - | mode STRING, - | padding STRING, - | aad BINARY - |) USING parquet - |""".stripMargin) - - sql(""" - |INSERT INTO aes_tbl - |SELECT - | aes_encrypt( - | encode('Spark SQL', 'UTF-8'), - | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8')), - | aes_encrypt( - | encode('Spark SQL', 'UTF-8'), - | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), - | 'GCM', - | 'DEFAULT', - | unhex('00112233445566778899AABB')), - | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), - | 'GCM', - | 'DEFAULT', - | unhex('00112233445566778899AABB') - |""".stripMargin) - } - - 
checkSparkAnswerAndOperator( - "SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl") - checkSparkAnswerAndOperator( - "SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding, aad) AS STRING) FROM aes_tbl") - } - } - } From 6e1391a1763addb42aee99937f52ad81c8c6260d Mon Sep 17 00:00:00 2001 From: Rafael Fernandez Date: Fri, 13 Feb 2026 16:21:34 +0100 Subject: [PATCH 4/9] add support --- .../apache/comet/serde/QueryPlanSerde.scala | 2 +- .../org/apache/comet/serde/statics.scala | 40 +++++++- .../expressions/misc/aes_decrypt.sql | 4 +- .../comet/CometMiscExpressionSuite.scala | 93 +++++++------------ 4 files changed, 72 insertions(+), 67 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 12e7426459..574b9efb63 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -146,7 +146,6 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[Sha1] -> CometSha1) private val stringExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( - classOf[AesDecrypt] -> CometScalarFunction("aes_decrypt"), classOf[Ascii] -> CometScalarFunction("ascii"), classOf[BitLength] -> CometScalarFunction("bit_length"), classOf[Chr] -> CometScalarFunction("char"), @@ -220,6 +219,7 @@ object QueryPlanSerde extends Logging with CometExprShim { private val miscExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( // TODO PromotePrecision + classOf[AesDecrypt] -> CometAesDecrypt, classOf[Alias] -> CometAlias, classOf[AttributeReference] -> CometAttributeReference, classOf[BloomFilterMightContain] -> CometBloomFilterMightContain, diff --git a/spark/src/main/scala/org/apache/comet/serde/statics.scala b/spark/src/main/scala/org/apache/comet/serde/statics.scala index 0737644ab9..e984750565 100644 --- 
a/spark/src/main/scala/org/apache/comet/serde/statics.scala +++ b/spark/src/main/scala/org/apache/comet/serde/statics.scala @@ -19,11 +19,46 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{AesDecrypt, Attribute, Expression, ExpressionImplUtils} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.ExprOuterClass.Expr +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType} + +private object CometAesDecryptHelper { + def convertToAesDecryptExpr[T <: Expression]( + expr: T, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding)) + val optExpr = scalarFunctionExprToProtoWithReturnType( + "aes_decrypt", + expr.dataType, + failOnError = false, + childExpr: _*) + optExprWithInfo(optExpr, expr, expr.children: _*) + } +} + +object CometAesDecrypt extends CometExpressionSerde[AesDecrypt] { + override def convert( + expr: AesDecrypt, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + CometAesDecryptHelper.convertToAesDecryptExpr(expr, inputs, binding) + } +} + +object CometAesDecryptStaticInvoke extends CometExpressionSerde[StaticInvoke] { + override def convert( + expr: StaticInvoke, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + CometAesDecryptHelper.convertToAesDecryptExpr(expr, inputs, binding) + } +} object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] { @@ -34,7 +69,8 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] { : Map[(String, Class[_]), CometExpressionSerde[StaticInvoke]] = Map( ("readSidePadding", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction( - 
"read_side_padding")) + "read_side_padding"), + ("aesDecrypt", classOf[ExpressionImplUtils]) -> CometAesDecryptStaticInvoke) override def convert( expr: StaticInvoke, diff --git a/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql b/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql index f89f24d1ee..cca41c83d7 100644 --- a/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql +++ b/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql @@ -47,8 +47,8 @@ SELECT unhex('00112233445566778899AABB'), 'Comet AAD' -query expect_fallback(Static invoke expression: aesDecrypt is not supported) +query SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl -query expect_fallback(Static invoke expression: aesDecrypt is not supported) +query SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding, aad) AS STRING) FROM aes_tbl diff --git a/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala index d87a42d328..b6f7e921b4 100644 --- a/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala @@ -26,75 +26,44 @@ import org.apache.comet.CometSparkSessionExtensions.isSpark35Plus class CometMiscExpressionSuite extends CometTestBase { test("aes_decrypt") { - withTable("aes_tbl") { + withTempView("aes_tbl") { withSQLConf(CometConf.COMET_ENABLED.key -> "false") { - sql(""" - |CREATE TABLE aes_tbl( - | encrypted_default BINARY, - | encrypted_with_aad BINARY, - | `key` BINARY, - | mode STRING, - | padding STRING, - | iv BINARY, - | aad STRING - |) USING parquet - |""".stripMargin) - - if (isSpark35Plus) { - sql(""" - |INSERT INTO aes_tbl - |SELECT - | aes_encrypt( - | encode('Spark SQL', 'UTF-8'), - | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8')), - | aes_encrypt( - | encode('Spark SQL', 'UTF-8'), - | 
encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), - | 'GCM', - | 'DEFAULT', - | unhex('00112233445566778899AABB'), - | 'Comet AAD'), - | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), - | 'GCM', - | 'DEFAULT', - | unhex('00112233445566778899AABB'), - | 'Comet AAD' - |""".stripMargin) + val aesDf = if (isSpark35Plus) { + spark + .range(1) + .selectExpr( + "aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8')) as encrypted_default", + "aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), 'GCM', 'DEFAULT', unhex('00112233445566778899AABB'), 'Comet AAD') as encrypted_with_aad", + "encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8') as `key`", + "'GCM' as mode", + "'DEFAULT' as padding", + "unhex('00112233445566778899AABB') as iv", + "'Comet AAD' as aad") } else { - sql(""" - |INSERT INTO aes_tbl - |SELECT - | aes_encrypt( - | encode('Spark SQL', 'UTF-8'), - | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8')), - | aes_encrypt( - | encode('Spark SQL', 'UTF-8'), - | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), - | 'GCM', - | 'DEFAULT'), - | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), - | 'GCM', - | 'DEFAULT', - | cast(null as binary), - | cast(null as string) - |""".stripMargin) + spark + .range(1) + .selectExpr( + "aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8')) as encrypted_default", + "aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), 'GCM', 'DEFAULT') as encrypted_with_aad", + "encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8') as `key`", + "'GCM' as mode", + "'DEFAULT' as padding", + "cast(null as binary) as iv", + "cast(null as string) as aad") } + aesDf.createOrReplaceTempView("aes_tbl") } if (isSpark35Plus) { - checkSparkAnswerAndFallbackReason( - "SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl", - "Static invoke expression: aesDecrypt 
is not supported") - checkSparkAnswerAndFallbackReason( - "SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding, aad) AS STRING) FROM aes_tbl", - "Static invoke expression: aesDecrypt is not supported") + checkSparkAnswerAndOperator( + "SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl") + checkSparkAnswerAndOperator( + "SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding, aad) AS STRING) FROM aes_tbl") } else { - checkSparkAnswerAndFallbackReason( - "SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl", - "Static invoke expression: aesDecrypt is not supported") - checkSparkAnswerAndFallbackReason( - "SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding) AS STRING) FROM aes_tbl", - "Static invoke expression: aesDecrypt is not supported") + checkSparkAnswerAndOperator( + "SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl") + checkSparkAnswerAndOperator( + "SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding) AS STRING) FROM aes_tbl") } } } From 44073a5e551cf3835af7debea80a50901fe29579 Mon Sep 17 00:00:00 2001 From: Rafael Fernandez Date: Fri, 13 Feb 2026 16:23:18 +0100 Subject: [PATCH 5/9] add native support --- native/Cargo.lock | 134 ++++++++ native/spark-expr/Cargo.toml | 4 + native/spark-expr/src/comet_scalar_funcs.rs | 13 +- native/spark-expr/src/lib.rs | 2 +- .../spark-expr/src/math_funcs/aes_decrypt.rs | 323 ++++++++++++++++++ native/spark-expr/src/math_funcs/mod.rs | 2 + native/spark-expr/tests/spark_expr_reg.rs | 6 + 7 files changed, 479 insertions(+), 5 deletions(-) create mode 100644 native/spark-expr/src/math_funcs/aes_decrypt.rs diff --git a/native/Cargo.lock b/native/Cargo.lock index c1224c2a06..3c789281d6 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -17,6 +17,41 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "aead" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +dependencies = [ + "crypto-common", + "generic-array", +] + +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures 0.2.17", +] + +[[package]] +name = "aes-gcm" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" +dependencies = [ + "aead", + "aes", + "cipher", + "ctr", + "ghash", + "subtle", +] + [[package]] name = "ahash" version = "0.7.8" @@ -1072,6 +1107,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "block-padding" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8894febbff9f758034a5b8e12d87918f56dfc64a8e1fe757d65e29041538d93" +dependencies = [ + "generic-array", +] + [[package]] name = "blocking" version = "1.6.2" @@ -1244,6 +1288,15 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" +[[package]] +name = "cbc" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b52a9543ae338f279b96b0b9fed9c8093744685043739079ce85cd58f289a6" +dependencies = [ + "cipher", +] + [[package]] name = "cc" version = "1.2.55" @@ -1345,6 +1398,16 @@ dependencies = [ "half", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + 
"inout", +] + [[package]] name = "clang-sys" version = "1.8.1" @@ -1594,6 +1657,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" dependencies = [ "generic-array", + "rand_core 0.6.4", "typenum", ] @@ -1618,6 +1682,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "ctr" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" +dependencies = [ + "cipher", +] + [[package]] name = "darling" version = "0.20.11" @@ -1936,12 +2009,16 @@ dependencies = [ name = "datafusion-comet-spark-expr" version = "0.14.0" dependencies = [ + "aes", + "aes-gcm", "arrow", "base64", + "cbc", "chrono", "chrono-tz", "criterion", "datafusion", + "ecb", "futures", "hex", "num", @@ -2642,6 +2719,15 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" +[[package]] +name = "ecb" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a8bfa975b1aec2145850fcaa1c6fe269a16578c44705a532ae3edc92b8881c7" +dependencies = [ + "cipher", +] + [[package]] name = "either" version = "1.15.0" @@ -2978,6 +3064,16 @@ dependencies = [ "wasip3", ] +[[package]] +name = "ghash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1" +dependencies = [ + "opaque-debug", + "polyval", +] + [[package]] name = "gimli" version = "0.32.3" @@ -3492,6 +3588,16 @@ dependencies = [ "str_stack", ] +[[package]] +name = "inout" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" +dependencies = [ + "block-padding", + "generic-array", +] + [[package]] name 
= "integer-encoding" version = "3.0.4" @@ -4159,6 +4265,12 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name = "opaque-debug" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" + [[package]] name = "opendal" version = "0.55.0" @@ -4482,6 +4594,18 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "polyval" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" +dependencies = [ + "cfg-if", + "cpufeatures 0.2.17", + "opaque-debug", + "universal-hash", +] + [[package]] name = "portable-atomic" version = "1.13.0" @@ -6130,6 +6254,16 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "universal-hash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" +dependencies = [ + "crypto-common", + "subtle", +] + [[package]] name = "unsafe-any-ors" version = "1.0.0" diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index fd0a211b29..6bad894cbc 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -40,6 +40,10 @@ twox-hash = "2.1.2" rand = { workspace = true } hex = "0.4.3" base64 = "0.22.1" +aes = "0.8.4" +aes-gcm = "0.10.3" +cbc = "0.1.2" +ecb = "0.1.2" [dev-dependencies] arrow = {workspace = true} diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 323a483171..ad49425c47 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ 
b/native/spark-expr/src/comet_scalar_funcs.rs @@ -20,10 +20,11 @@ use crate::math_funcs::abs::abs; use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub}; use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ - spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, - spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, - spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkContains, SparkDateDiff, - SparkDateTrunc, SparkMakeDate, SparkSizeFunc, SparkStringSpace, + spark_aes_decrypt, spark_array_repeat, spark_ceil, spark_decimal_div, + spark_decimal_integral_div, spark_floor, spark_isnan, spark_lpad, spark_make_decimal, + spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, EvalMode, + SparkBitwiseCount, SparkContains, SparkDateDiff, SparkDateTrunc, SparkMakeDate, SparkSizeFunc, + SparkStringSpace, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; @@ -181,6 +182,10 @@ pub fn create_comet_physical_fun_with_eval_mode( let func = Arc::new(abs); make_comet_scalar_udf!("abs", func, without data_type) } + "aes_decrypt" => { + let func = Arc::new(spark_aes_decrypt); + make_comet_scalar_udf!("aes_decrypt", func, without data_type) + } "split" => { let func = Arc::new(crate::string_funcs::spark_split); make_comet_scalar_udf!("split", func, without data_type) diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 40eb180ab8..aef76ea7dd 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -79,7 +79,7 @@ pub use error::{SparkError, SparkResult}; pub use hash_funcs::*; pub use json_funcs::{FromJson, ToJson}; pub use math_funcs::{ - create_modulo_expr, create_negate_expr, spark_ceil, spark_decimal_div, + create_modulo_expr, create_negate_expr, spark_aes_decrypt, spark_ceil, spark_decimal_div, 
spark_decimal_integral_div, spark_floor, spark_make_decimal, spark_round, spark_unhex, spark_unscaled_value, CheckOverflow, NegativeExpr, NormalizeNaNAndZero, }; diff --git a/native/spark-expr/src/math_funcs/aes_decrypt.rs b/native/spark-expr/src/math_funcs/aes_decrypt.rs new file mode 100644 index 0000000000..7b6a9fc3a5 --- /dev/null +++ b/native/spark-expr/src/math_funcs/aes_decrypt.rs @@ -0,0 +1,323 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+use std::sync::Arc;
+
+use aes::cipher::consts::{U12, U16};
+use aes::{Aes128, Aes192, Aes256};
+use aes_gcm::aead::{Aead, Payload};
+use aes_gcm::{Aes128Gcm, Aes256Gcm, AesGcm, KeyInit, Nonce};
+use arrow::array::{
+    Array, ArrayRef, BinaryArray, BinaryBuilder, LargeBinaryArray, LargeStringArray, StringArray,
+};
+use arrow::datatypes::DataType;
+use cbc::cipher::{block_padding::Pkcs7, BlockDecryptMut, KeyIvInit};
+use datafusion::common::{exec_err, DataFusionError};
+use datafusion::logical_expr::ColumnarValue;
+
+const GCM_IV_LEN: usize = 12;
+const CBC_IV_LEN: usize = 16;
+
+#[derive(Clone, Copy)]
+enum AesMode {
+    Ecb,
+    Cbc,
+    Gcm,
+}
+
+impl AesMode {
+    fn from_mode_padding(mode: &str, padding: &str) -> Result<Self, DataFusionError> {
+        let is_none = padding.eq_ignore_ascii_case("NONE");
+        let is_pkcs = padding.eq_ignore_ascii_case("PKCS");
+        let is_default = padding.eq_ignore_ascii_case("DEFAULT");
+
+        if mode.eq_ignore_ascii_case("ECB") && (is_pkcs || is_default) {
+            Ok(Self::Ecb)
+        } else if mode.eq_ignore_ascii_case("CBC") && (is_pkcs || is_default) {
+            Ok(Self::Cbc)
+        } else if mode.eq_ignore_ascii_case("GCM") && (is_none || is_default) {
+            Ok(Self::Gcm)
+        } else {
+            exec_err!("Unsupported AES mode/padding combination: {mode}/{padding}")
+        }
+    }
+}
+
+enum BinaryArg<'a> {
+    Binary(&'a BinaryArray),
+    LargeBinary(&'a LargeBinaryArray),
+}
+
+impl<'a> BinaryArg<'a> {
+    fn from(arg_name: &str, arr: &'a ArrayRef) -> Result<Self, DataFusionError> {
+        match arr.data_type() {
+            DataType::Binary => Ok(Self::Binary(
+                arr.as_any().downcast_ref::<BinaryArray>().ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "Failed to downcast {arg_name} to BinaryArray"
+                    ))
+                })?,
+            )),
+            DataType::LargeBinary => Ok(Self::LargeBinary(
+                arr.as_any()
+                    .downcast_ref::<LargeBinaryArray>()
+                    .ok_or_else(|| {
+                        DataFusionError::Internal(format!(
+                            "Failed to downcast {arg_name} to LargeBinaryArray"
+                        ))
+                    })?,
+            )),
+            other => exec_err!("{arg_name} must be Binary/LargeBinary, got {other:?}"),
+        }
+    }
+
+    fn value(&self, i: usize) -> Option<&'a [u8]> {
+        match self {
+            Self::Binary(arr) => (!arr.is_null(i)).then(|| arr.value(i)),
+            Self::LargeBinary(arr) => (!arr.is_null(i)).then(|| arr.value(i)),
+        }
+    }
+}
+
+enum StringArg<'a> {
+    Utf8(&'a StringArray),
+    LargeUtf8(&'a LargeStringArray),
+}
+
+impl<'a> StringArg<'a> {
+    fn from(arg_name: &str, arr: &'a ArrayRef) -> Result<Self, DataFusionError> {
+        match arr.data_type() {
+            DataType::Utf8 => Ok(Self::Utf8(
+                arr.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "Failed to downcast {arg_name} to StringArray"
+                    ))
+                })?,
+            )),
+            DataType::LargeUtf8 => Ok(Self::LargeUtf8(
+                arr.as_any()
+                    .downcast_ref::<LargeStringArray>()
+                    .ok_or_else(|| {
+                        DataFusionError::Internal(format!(
+                            "Failed to downcast {arg_name} to LargeStringArray"
+                        ))
+                    })?,
+            )),
+            other => exec_err!("{arg_name} must be Utf8/LargeUtf8, got {other:?}"),
+        }
+    }
+
+    fn value(&self, i: usize) -> Option<&'a str> {
+        match self {
+            Self::Utf8(arr) => (!arr.is_null(i)).then(|| arr.value(i)),
+            Self::LargeUtf8(arr) => (!arr.is_null(i)).then(|| arr.value(i)),
+        }
+    }
+}
+
+type Aes128CbcDec = cbc::Decryptor<Aes128>;
+type Aes192CbcDec = cbc::Decryptor<Aes192>;
+type Aes256CbcDec = cbc::Decryptor<Aes256>;
+type Aes128EcbDec = ecb::Decryptor<Aes128>;
+type Aes192EcbDec = ecb::Decryptor<Aes192>;
+type Aes256EcbDec = ecb::Decryptor<Aes256>;
+type Aes192Gcm = AesGcm<Aes192, U12, U16>;
+
+fn decrypt_pkcs_cbc(input: &[u8], key: &[u8]) -> Result<Vec<u8>, DataFusionError> {
+    if input.len() < CBC_IV_LEN {
+        return exec_err!("AES decryption input is too short for CBC");
+    }
+    let (iv, ciphertext) = input.split_at(CBC_IV_LEN);
+    let mut buf = ciphertext.to_vec();
+
+    let out = match key.len() {
+        16 => Aes128CbcDec::new_from_slices(key, iv)
+            .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?
+            .decrypt_padded_mut::<Pkcs7>(&mut buf)
+            .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?,
+        24 => Aes192CbcDec::new_from_slices(key, iv)
+            .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?
+            .decrypt_padded_mut::<Pkcs7>(&mut buf)
+            .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?,
+        32 => Aes256CbcDec::new_from_slices(key, iv)
+            .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?
+            .decrypt_padded_mut::<Pkcs7>(&mut buf)
+            .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?,
+        _ => return exec_err!("Invalid AES key length: {}", key.len()),
+    };
+
+    Ok(out.to_vec())
+}
+
+fn decrypt_pkcs_ecb(input: &[u8], key: &[u8]) -> Result<Vec<u8>, DataFusionError> {
+    let mut buf = input.to_vec();
+
+    let out = match key.len() {
+        16 => Aes128EcbDec::new_from_slice(key)
+            .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?
+            .decrypt_padded_mut::<Pkcs7>(&mut buf)
+            .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?,
+        24 => Aes192EcbDec::new_from_slice(key)
+            .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?
+            .decrypt_padded_mut::<Pkcs7>(&mut buf)
+            .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?,
+        32 => Aes256EcbDec::new_from_slice(key)
+            .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?
+            .decrypt_padded_mut::<Pkcs7>(&mut buf)
+            .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?,
+        _ => return exec_err!("Invalid AES key length: {}", key.len()),
+    };
+
+    Ok(out.to_vec())
+}
+
+fn decrypt_gcm(input: &[u8], key: &[u8], aad: &[u8]) -> Result<Vec<u8>, DataFusionError> {
+    if input.len() < GCM_IV_LEN {
+        return exec_err!("AES decryption input is too short for GCM");
+    }
+    let (iv, ciphertext) = input.split_at(GCM_IV_LEN);
+    let nonce = Nonce::from_slice(iv);
+    let payload = Payload {
+        msg: ciphertext,
+        aad,
+    };
+
+    match key.len() {
+        16 => Aes128Gcm::new_from_slice(key)
+            .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?
+            .decrypt(nonce, payload)
+            .map_err(|_| {
+                DataFusionError::Execution("AES crypto error: decrypt failed".to_string())
+            }),
+        24 => Aes192Gcm::new_from_slice(key)
+            .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?
+            .decrypt(nonce, payload)
+            .map_err(|_| {
+                DataFusionError::Execution("AES crypto error: decrypt failed".to_string())
+            }),
+        32 => Aes256Gcm::new_from_slice(key)
+            .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?
+            .decrypt(nonce, payload)
+            .map_err(|_| {
+                DataFusionError::Execution("AES crypto error: decrypt failed".to_string())
+            }),
+        _ => exec_err!("Invalid AES key length: {}", key.len()),
+    }
+}
+
+fn decrypt_one(
+    input: &[u8],
+    key: &[u8],
+    mode: &str,
+    padding: &str,
+    aad: &[u8],
+) -> Result<Vec<u8>, DataFusionError> {
+    match AesMode::from_mode_padding(mode, padding)? {
+        AesMode::Ecb => decrypt_pkcs_ecb(input, key),
+        AesMode::Cbc => decrypt_pkcs_cbc(input, key),
+        AesMode::Gcm => decrypt_gcm(input, key, aad),
+    }
+}
+
+pub fn spark_aes_decrypt(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    if !(2..=5).contains(&args.len()) {
+        return exec_err!("aes_decrypt expects 2 to 5 arguments, got {}", args.len());
+    }
+
+    let are_scalars = args
+        .iter()
+        .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
+    let arrays = ColumnarValue::values_to_arrays(args)?;
+    let num_rows = arrays[0].len();
+
+    let input = BinaryArg::from("input", &arrays[0])?;
+    let key = BinaryArg::from("key", &arrays[1])?;
+
+    let mode = if args.len() >= 3 {
+        Some(StringArg::from("mode", &arrays[2])?)
+    } else {
+        None
+    };
+    let padding = if args.len() >= 4 {
+        Some(StringArg::from("padding", &arrays[3])?)
+    } else {
+        None
+    };
+    let aad = if args.len() >= 5 {
+        Some(BinaryArg::from("aad", &arrays[4])?)
+ } else { + None + }; + + let mut builder = BinaryBuilder::new(); + + for row in 0..num_rows { + let Some(input_value) = input.value(row) else { + builder.append_null(); + continue; + }; + let Some(key_value) = key.value(row) else { + builder.append_null(); + continue; + }; + + let mode_value = match mode.as_ref() { + Some(mode) => { + let Some(mode) = mode.value(row) else { + builder.append_null(); + continue; + }; + mode + } + None => "GCM", + }; + + let padding_value = match padding.as_ref() { + Some(padding) => { + let Some(padding) = padding.value(row) else { + builder.append_null(); + continue; + }; + padding + } + None => "DEFAULT", + }; + + let aad_value = match aad.as_ref() { + Some(aad) => { + let Some(aad) = aad.value(row) else { + builder.append_null(); + continue; + }; + aad + } + None => &[], + }; + + let plaintext = decrypt_one(input_value, key_value, mode_value, padding_value, aad_value)?; + builder.append_value(plaintext); + } + + let array = Arc::new(builder.finish()); + if are_scalars { + Ok(ColumnarValue::Scalar( + datafusion::common::ScalarValue::try_from_array(array.as_ref(), 0)?, + )) + } else { + Ok(ColumnarValue::Array(array)) + } +} diff --git a/native/spark-expr/src/math_funcs/mod.rs b/native/spark-expr/src/math_funcs/mod.rs index 35c1dc6504..f5fde060d3 100644 --- a/native/spark-expr/src/math_funcs/mod.rs +++ b/native/spark-expr/src/math_funcs/mod.rs @@ -16,6 +16,7 @@ // under the License. 
pub(crate) mod abs; +pub(crate) mod aes_decrypt; mod ceil; pub(crate) mod checked_arithmetic; mod div; @@ -27,6 +28,7 @@ mod round; pub(crate) mod unhex; mod utils; +pub use aes_decrypt::spark_aes_decrypt; pub use ceil::spark_ceil; pub use div::spark_decimal_div; pub use div::spark_decimal_integral_div; diff --git a/native/spark-expr/tests/spark_expr_reg.rs b/native/spark-expr/tests/spark_expr_reg.rs index 633b226068..f381b77881 100644 --- a/native/spark-expr/tests/spark_expr_reg.rs +++ b/native/spark-expr/tests/spark_expr_reg.rs @@ -35,6 +35,12 @@ mod tests { &session_state, None, )?); + let _ = session_state.register_udf(create_comet_physical_fun( + "aes_decrypt", + DataType::Binary, + &session_state, + None, + )?); let ctx = SessionContext::new_with_state(session_state); // 2. Execute SQL with literal values From 0836838b4ecf8c1f7818540351b2afd963d64bdf Mon Sep 17 00:00:00 2001 From: Rafael Fernandez Date: Tue, 17 Feb 2026 17:25:14 +0100 Subject: [PATCH 6/9] format --- native/spark-expr/src/comet_scalar_funcs.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 688f46b2e1..3fd61f6b05 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -20,9 +20,9 @@ use crate::math_funcs::abs::abs; use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub}; use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ - spark_aes_decrypt, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan, - spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex, - spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkContains, SparkDateDiff, + spark_aes_decrypt, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, + spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, 
spark_round, spark_rpad, + spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkContains, SparkDateDiff, SparkDateTrunc, SparkMakeDate, SparkSizeFunc, SparkStringSpace, }; use arrow::datatypes::DataType; From 9ca9839d2df5b89e7a0a299efb5eda50d0732719 Mon Sep 17 00:00:00 2001 From: Rafael Fernandez Date: Tue, 17 Feb 2026 23:12:10 +0100 Subject: [PATCH 7/9] add downcast macros --- native/spark-expr/src/downcast.rs | 36 +++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 native/spark-expr/src/downcast.rs diff --git a/native/spark-expr/src/downcast.rs b/native/spark-expr/src/downcast.rs new file mode 100644 index 0000000000..ade2ef961b --- /dev/null +++ b/native/spark-expr/src/downcast.rs @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +macro_rules! opt_downcast_arg { + ($ARG:expr, $ARRAY_TYPE:ident) => {{ + $ARG.as_any().downcast_ref::<$ARRAY_TYPE>() + }}; +} + +macro_rules! downcast_named_arg { + ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{ + $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { + datafusion::common::internal_datafusion_err!( + "could not cast {} to {}", + $NAME, + std::any::type_name::<$ARRAY_TYPE>() + ) + })? 
+ }}; +} + +pub(crate) use {downcast_named_arg, opt_downcast_arg}; From 18bd8e3f9c10da7ec20d1a4911123a35d85b3024 Mon Sep 17 00:00:00 2001 From: Rafael Fernandez Date: Tue, 17 Feb 2026 23:12:21 +0100 Subject: [PATCH 8/9] edit implementation and move to misc reference --- native/spark-expr/src/lib.rs | 6 +- native/spark-expr/src/math_funcs/mod.rs | 2 - .../{math_funcs => misc_funcs}/aes_decrypt.rs | 152 ++++++++---------- native/spark-expr/src/misc_funcs/mod.rs | 20 +++ .../scala/org/apache/comet/serde/misc.scala | 59 +++++++ .../org/apache/comet/serde/statics.scala | 37 +---- 6 files changed, 155 insertions(+), 121 deletions(-) rename native/spark-expr/src/{math_funcs => misc_funcs}/aes_decrypt.rs (75%) create mode 100644 native/spark-expr/src/misc_funcs/mod.rs create mode 100644 spark/src/main/scala/org/apache/comet/serde/misc.scala diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index aef76ea7dd..ed59682323 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -19,7 +19,9 @@ // The lint makes easier for code reader/reviewer separate references clones from more heavyweight ones #![deny(clippy::clone_on_ref_ptr)] +mod downcast; mod error; +pub(crate) use downcast::{downcast_named_arg, opt_downcast_arg}; pub mod kernels; pub use kernels::temporal::date_trunc_dyn; @@ -58,6 +60,7 @@ pub use bloom_filter::{BloomFilterAgg, BloomFilterMightContain}; mod conditional_funcs; mod conversion_funcs; mod math_funcs; +mod misc_funcs; mod nondetermenistic_funcs; pub use array_funcs::*; @@ -79,10 +82,11 @@ pub use error::{SparkError, SparkResult}; pub use hash_funcs::*; pub use json_funcs::{FromJson, ToJson}; pub use math_funcs::{ - create_modulo_expr, create_negate_expr, spark_aes_decrypt, spark_ceil, spark_decimal_div, + create_modulo_expr, create_negate_expr, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_make_decimal, spark_round, spark_unhex, spark_unscaled_value, CheckOverflow, NegativeExpr, 
NormalizeNaNAndZero, }; +pub use misc_funcs::spark_aes_decrypt; pub use string_funcs::*; /// Spark supports three evaluation modes when evaluating expressions, which affect diff --git a/native/spark-expr/src/math_funcs/mod.rs b/native/spark-expr/src/math_funcs/mod.rs index f5fde060d3..35c1dc6504 100644 --- a/native/spark-expr/src/math_funcs/mod.rs +++ b/native/spark-expr/src/math_funcs/mod.rs @@ -16,7 +16,6 @@ // under the License. pub(crate) mod abs; -pub(crate) mod aes_decrypt; mod ceil; pub(crate) mod checked_arithmetic; mod div; @@ -28,7 +27,6 @@ mod round; pub(crate) mod unhex; mod utils; -pub use aes_decrypt::spark_aes_decrypt; pub use ceil::spark_ceil; pub use div::spark_decimal_div; pub use div::spark_decimal_integral_div; diff --git a/native/spark-expr/src/math_funcs/aes_decrypt.rs b/native/spark-expr/src/misc_funcs/aes_decrypt.rs similarity index 75% rename from native/spark-expr/src/math_funcs/aes_decrypt.rs rename to native/spark-expr/src/misc_funcs/aes_decrypt.rs index 7b6a9fc3a5..605fe50994 100644 --- a/native/spark-expr/src/math_funcs/aes_decrypt.rs +++ b/native/spark-expr/src/misc_funcs/aes_decrypt.rs @@ -15,18 +15,14 @@ // specific language governing permissions and limitations // under the License. 
-use std::sync::Arc; - use aes::cipher::consts::{U12, U16}; use aes::{Aes128, Aes192, Aes256}; use aes_gcm::aead::{Aead, Payload}; use aes_gcm::{Aes128Gcm, Aes256Gcm, AesGcm, KeyInit, Nonce}; -use arrow::array::{ - Array, ArrayRef, BinaryArray, BinaryBuilder, LargeBinaryArray, LargeStringArray, StringArray, -}; +use arrow::array::{Array, ArrayRef, BinaryArray, LargeBinaryArray, LargeStringArray, StringArray}; use arrow::datatypes::DataType; use cbc::cipher::{block_padding::Pkcs7, BlockDecryptMut, KeyIvInit}; -use datafusion::common::{exec_err, DataFusionError}; +use datafusion::common::{exec_err, DataFusionError, ScalarValue}; use datafusion::logical_expr::ColumnarValue; const GCM_IV_LEN: usize = 12; @@ -66,21 +62,19 @@ impl<'a> BinaryArg<'a> { fn from(arg_name: &str, arr: &'a ArrayRef) -> Result { match arr.data_type() { DataType::Binary => Ok(Self::Binary( - arr.as_any().downcast_ref::().ok_or_else(|| { - DataFusionError::Internal(format!( - "Failed to downcast {arg_name} to BinaryArray" - )) + crate::opt_downcast_arg!(arr, BinaryArray).ok_or_else(|| { + datafusion::common::internal_datafusion_err!( + "could not cast {} to {}", + arg_name, + std::any::type_name::() + ) })?, )), - DataType::LargeBinary => Ok(Self::LargeBinary( - arr.as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal(format!( - "Failed to downcast {arg_name} to LargeBinaryArray" - )) - })?, - )), + DataType::LargeBinary => Ok(Self::LargeBinary(crate::downcast_named_arg!( + arr, + arg_name, + LargeBinaryArray + ))), other => exec_err!("{arg_name} must be Binary/LargeBinary, got {other:?}"), } } @@ -102,21 +96,19 @@ impl<'a> StringArg<'a> { fn from(arg_name: &str, arr: &'a ArrayRef) -> Result { match arr.data_type() { DataType::Utf8 => Ok(Self::Utf8( - arr.as_any().downcast_ref::().ok_or_else(|| { - DataFusionError::Internal(format!( - "Failed to downcast {arg_name} to StringArray" - )) + crate::opt_downcast_arg!(arr, StringArray).ok_or_else(|| { + 
datafusion::common::internal_datafusion_err!( + "could not cast {} to {}", + arg_name, + std::any::type_name::() + ) })?, )), - DataType::LargeUtf8 => Ok(Self::LargeUtf8( - arr.as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal(format!( - "Failed to downcast {arg_name} to LargeStringArray" - )) - })?, - )), + DataType::LargeUtf8 => Ok(Self::LargeUtf8(crate::downcast_named_arg!( + arr, + arg_name, + LargeStringArray + ))), other => exec_err!("{arg_name} must be Utf8/LargeUtf8, got {other:?}"), } } @@ -263,56 +255,52 @@ pub fn spark_aes_decrypt(args: &[ColumnarValue]) -> Result { - let Some(mode) = mode.value(row) else { - builder.append_null(); - continue; - }; - mode - } - None => "GCM", - }; - - let padding_value = match padding.as_ref() { - Some(padding) => { - let Some(padding) = padding.value(row) else { - builder.append_null(); - continue; - }; - padding - } - None => "DEFAULT", - }; - - let aad_value = match aad.as_ref() { - Some(aad) => { - let Some(aad) = aad.value(row) else { - builder.append_null(); - continue; - }; - aad - } - None => &[], - }; - - let plaintext = decrypt_one(input_value, key_value, mode_value, padding_value, aad_value)?; - builder.append_value(plaintext); - } - - let array = Arc::new(builder.finish()); + let values: Result, DataFusionError> = (0..num_rows) + .map(|row| { + let Some(input_value) = input.value(row) else { + return Ok(ScalarValue::Binary(None)); + }; + let Some(key_value) = key.value(row) else { + return Ok(ScalarValue::Binary(None)); + }; + + let mode_value = match mode.as_ref() { + Some(mode) => { + let Some(mode) = mode.value(row) else { + return Ok(ScalarValue::Binary(None)); + }; + mode + } + None => "GCM", + }; + + let padding_value = match padding.as_ref() { + Some(padding) => { + let Some(padding) = padding.value(row) else { + return Ok(ScalarValue::Binary(None)); + }; + padding + } + None => "DEFAULT", + }; + + let aad_value = match aad.as_ref() { + Some(aad) => { + let Some(aad) = 
aad.value(row) else { + return Ok(ScalarValue::Binary(None)); + }; + aad + } + None => &[], + }; + + let plaintext = + decrypt_one(input_value, key_value, mode_value, padding_value, aad_value)?; + Ok(ScalarValue::Binary(Some(plaintext))) + }) + .collect(); + + let array: ArrayRef = ScalarValue::iter_to_array(values?)?; if are_scalars { Ok(ColumnarValue::Scalar( datafusion::common::ScalarValue::try_from_array(array.as_ref(), 0)?, diff --git a/native/spark-expr/src/misc_funcs/mod.rs b/native/spark-expr/src/misc_funcs/mod.rs new file mode 100644 index 0000000000..c55b82811d --- /dev/null +++ b/native/spark-expr/src/misc_funcs/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub(crate) mod aes_decrypt; + +pub use aes_decrypt::spark_aes_decrypt; diff --git a/spark/src/main/scala/org/apache/comet/serde/misc.scala b/spark/src/main/scala/org/apache/comet/serde/misc.scala new file mode 100644 index 0000000000..9bb8416579 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/misc.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.{AesDecrypt, Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke + +import org.apache.comet.serde.ExprOuterClass.Expr +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType} + +private object CometAesDecryptHelper { + def convertToAesDecryptExpr[T <: Expression]( + expr: T, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding)) + val optExpr = scalarFunctionExprToProtoWithReturnType( + "aes_decrypt", + expr.dataType, + failOnError = false, + childExpr: _*) + optExprWithInfo(optExpr, expr, expr.children: _*) + } +} + +object CometAesDecrypt extends CometExpressionSerde[AesDecrypt] { + override def convert( + expr: AesDecrypt, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + CometAesDecryptHelper.convertToAesDecryptExpr(expr, inputs, binding) + } +} + +object CometAesDecryptStaticInvoke extends CometExpressionSerde[StaticInvoke] { + override def convert( + expr: StaticInvoke, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + 
CometAesDecryptHelper.convertToAesDecryptExpr(expr, inputs, binding) + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/statics.scala b/spark/src/main/scala/org/apache/comet/serde/statics.scala index e984750565..fa77369405 100644 --- a/spark/src/main/scala/org/apache/comet/serde/statics.scala +++ b/spark/src/main/scala/org/apache/comet/serde/statics.scala @@ -19,46 +19,11 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.{AesDecrypt, Attribute, Expression, ExpressionImplUtils} +import org.apache.spark.sql.catalyst.expressions.{Attribute, ExpressionImplUtils} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils import org.apache.comet.CometSparkSessionExtensions.withInfo -import org.apache.comet.serde.ExprOuterClass.Expr -import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType} - -private object CometAesDecryptHelper { - def convertToAesDecryptExpr[T <: Expression]( - expr: T, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding)) - val optExpr = scalarFunctionExprToProtoWithReturnType( - "aes_decrypt", - expr.dataType, - failOnError = false, - childExpr: _*) - optExprWithInfo(optExpr, expr, expr.children: _*) - } -} - -object CometAesDecrypt extends CometExpressionSerde[AesDecrypt] { - override def convert( - expr: AesDecrypt, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - CometAesDecryptHelper.convertToAesDecryptExpr(expr, inputs, binding) - } -} - -object CometAesDecryptStaticInvoke extends CometExpressionSerde[StaticInvoke] { - override def convert( - expr: StaticInvoke, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - CometAesDecryptHelper.convertToAesDecryptExpr(expr, inputs, binding) - } -} object CometStaticInvoke extends 
CometExpressionSerde[StaticInvoke] { From 730d6c7c74747b8ac255b1cf1518e8e20e30619a Mon Sep 17 00:00:00 2001 From: Rafael Fernandez Date: Thu, 19 Feb 2026 21:50:29 +0100 Subject: [PATCH 9/9] suggestion - tests --- .../expressions/misc/aes_decrypt.sql | 110 ++++++++++++++++++ .../comet/CometMiscExpressionSuite.scala | 88 ++++++++++++++ 2 files changed, 198 insertions(+) diff --git a/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql b/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql index cca41c83d7..5c93dd6a7a 100644 --- a/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql +++ b/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql @@ -52,3 +52,113 @@ SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl query SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding, aad) AS STRING) FROM aes_tbl + +statement +CREATE TABLE aes_modes_tbl( + encrypted BINARY, + `key` BINARY, + mode STRING, + padding STRING, + label STRING +) USING parquet + +statement +INSERT INTO aes_modes_tbl +SELECT + aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop', 'UTF-8'), 'GCM', 'DEFAULT'), + encode('abcdefghijklmnop', 'UTF-8'), + 'GCM', + 'DEFAULT', + 'gcm_128' +UNION ALL +SELECT + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678', 'UTF-8'), + 'GCM', + 'DEFAULT'), + encode('abcdefghijklmnop12345678', 'UTF-8'), + 'GCM', + 'DEFAULT', + 'gcm_192' +UNION ALL +SELECT + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'GCM', + 'DEFAULT'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'GCM', + 'DEFAULT', + 'gcm_256' +UNION ALL +SELECT + aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop', 'UTF-8'), 'CBC', 'PKCS'), + encode('abcdefghijklmnop', 'UTF-8'), + 'CBC', + 'PKCS', + 'cbc_128' +UNION ALL +SELECT + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + 
encode('abcdefghijklmnop12345678', 'UTF-8'), + 'CBC', + 'PKCS'), + encode('abcdefghijklmnop12345678', 'UTF-8'), + 'CBC', + 'PKCS', + 'cbc_192' +UNION ALL +SELECT + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'CBC', + 'PKCS'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'CBC', + 'PKCS', + 'cbc_256' +UNION ALL +SELECT + aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop', 'UTF-8'), 'ECB', 'PKCS'), + encode('abcdefghijklmnop', 'UTF-8'), + 'ECB', + 'PKCS', + 'ecb_128' +UNION ALL +SELECT + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678', 'UTF-8'), + 'ECB', + 'PKCS'), + encode('abcdefghijklmnop12345678', 'UTF-8'), + 'ECB', + 'PKCS', + 'ecb_192' +UNION ALL +SELECT + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'ECB', + 'PKCS'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'ECB', + 'PKCS', + 'ecb_256' +UNION ALL +SELECT + cast(null AS binary), + encode('abcdefghijklmnop', 'UTF-8'), + 'GCM', + 'DEFAULT', + 'null_input' + +query +SELECT label, CAST(aes_decrypt(encrypted, `key`, mode, padding) AS STRING) +FROM aes_modes_tbl +ORDER BY label diff --git a/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala index b6f7e921b4..2e6cc90cb8 100644 --- a/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala @@ -68,4 +68,92 @@ class CometMiscExpressionSuite extends CometTestBase { } } + test("aes_decrypt mode and key-size combinations") { + withTempView("aes_modes_tbl") { + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + spark + .sql(""" + |SELECT + | aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop', 'UTF-8'), 'GCM', 'DEFAULT') AS encrypted, + | encode('abcdefghijklmnop', 'UTF-8') AS `key`, + | 'GCM' 
AS mode, + | 'DEFAULT' AS padding, + | 'gcm_128' AS label + |UNION ALL + |SELECT + | aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678', 'UTF-8'), 'GCM', 'DEFAULT') AS encrypted, + | encode('abcdefghijklmnop12345678', 'UTF-8') AS `key`, + | 'GCM' AS mode, + | 'DEFAULT' AS padding, + | 'gcm_192' AS label + |UNION ALL + |SELECT + | aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), 'GCM', 'DEFAULT') AS encrypted, + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8') AS `key`, + | 'GCM' AS mode, + | 'DEFAULT' AS padding, + | 'gcm_256' AS label + |UNION ALL + |SELECT + | aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop', 'UTF-8'), 'CBC', 'PKCS') AS encrypted, + | encode('abcdefghijklmnop', 'UTF-8') AS `key`, + | 'CBC' AS mode, + | 'PKCS' AS padding, + | 'cbc_128' AS label + |UNION ALL + |SELECT + | aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678', 'UTF-8'), 'CBC', 'PKCS') AS encrypted, + | encode('abcdefghijklmnop12345678', 'UTF-8') AS `key`, + | 'CBC' AS mode, + | 'PKCS' AS padding, + | 'cbc_192' AS label + |UNION ALL + |SELECT + | aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), 'CBC', 'PKCS') AS encrypted, + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8') AS `key`, + | 'CBC' AS mode, + | 'PKCS' AS padding, + | 'cbc_256' AS label + |UNION ALL + |SELECT + | aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop', 'UTF-8'), 'ECB', 'PKCS') AS encrypted, + | encode('abcdefghijklmnop', 'UTF-8') AS `key`, + | 'ECB' AS mode, + | 'PKCS' AS padding, + | 'ecb_128' AS label + |UNION ALL + |SELECT + | aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678', 'UTF-8'), 'ECB', 'PKCS') AS encrypted, + | encode('abcdefghijklmnop12345678', 'UTF-8') AS `key`, + | 'ECB' AS mode, + | 'PKCS' AS padding, + | 'ecb_192' AS label + |UNION ALL + |SELECT + | aes_encrypt(encode('Spark SQL', 
'UTF-8'), encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), 'ECB', 'PKCS') AS encrypted, + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8') AS `key`, + | 'ECB' AS mode, + | 'PKCS' AS padding, + | 'ecb_256' AS label + |UNION ALL + |SELECT + | cast(null AS binary) AS encrypted, + | encode('abcdefghijklmnop', 'UTF-8') AS `key`, + | 'GCM' AS mode, + | 'DEFAULT' AS padding, + | 'null_input' AS label + |""".stripMargin) + .createOrReplaceTempView("aes_modes_tbl") + } + + checkSparkAnswerAndOperator(""" + |SELECT + | label, + | CAST(aes_decrypt(encrypted, `key`, mode, padding) AS STRING) AS decrypted + |FROM aes_modes_tbl + |ORDER BY label + |""".stripMargin) + } + } + }