diff --git a/docs/spark_expressions_support.md b/docs/spark_expressions_support.md index 5474894108..752ee9d28e 100644 --- a/docs/spark_expressions_support.md +++ b/docs/spark_expressions_support.md @@ -353,7 +353,7 @@ ### misc_funcs -- [ ] aes_decrypt +- [x] aes_decrypt - [ ] aes_encrypt - [ ] assert_true - [x] current_catalog diff --git a/native/Cargo.lock b/native/Cargo.lock index 0977bb96dc..433c4db17d 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -17,6 +17,41 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "aead" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +dependencies = [ + "crypto-common", + "generic-array", +] + +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures 0.2.17", +] + +[[package]] +name = "aes-gcm" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" +dependencies = [ + "aead", + "aes", + "cipher", + "ctr", + "ghash", + "subtle", +] + [[package]] name = "ahash" version = "0.7.8" @@ -1073,6 +1108,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "block-padding" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8894febbff9f758034a5b8e12d87918f56dfc64a8e1fe757d65e29041538d93" +dependencies = [ + "generic-array", +] + [[package]] name = "blocking" version = "1.6.2" @@ -1245,6 +1289,15 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" +[[package]] +name = "cbc" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b52a9543ae338f279b96b0b9fed9c8093744685043739079ce85cd58f289a6" +dependencies = [ + "cipher", +] + [[package]] name = "cc" version = "1.2.56" @@ -1346,6 +1399,16 @@ dependencies = [ "half", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clang-sys" version = "1.8.1" @@ -1595,6 +1658,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" dependencies = [ "generic-array", + "rand_core 0.6.4", "typenum", ] @@ -1619,6 +1683,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "ctr" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" +dependencies = [ + "cipher", +] + [[package]] name = "darling" version = "0.20.11" @@ -1937,12 +2010,16 @@ dependencies = [ name = "datafusion-comet-spark-expr" version = "0.14.0" dependencies = [ + "aes", + "aes-gcm", "arrow", "base64", + "cbc", "chrono", "chrono-tz", "criterion", "datafusion", + "ecb", "futures", "hex", "num", @@ -2643,6 +2720,15 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" +[[package]] +name = "ecb" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a8bfa975b1aec2145850fcaa1c6fe269a16578c44705a532ae3edc92b8881c7" +dependencies = [ + "cipher", +] + [[package]] name = "either" version = "1.15.0" @@ -2979,6 +3065,16 @@ dependencies = [ "wasip3", ] 
+[[package]] +name = "ghash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1" +dependencies = [ + "opaque-debug", + "polyval", +] + [[package]] name = "gimli" version = "0.32.3" @@ -3493,6 +3589,16 @@ dependencies = [ "str_stack", ] +[[package]] +name = "inout" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" +dependencies = [ + "block-padding", + "generic-array", +] + [[package]] name = "integer-encoding" version = "3.0.4" @@ -4160,6 +4266,12 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name = "opaque-debug" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" + [[package]] name = "opendal" version = "0.55.0" @@ -4483,6 +4595,18 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "polyval" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" +dependencies = [ + "cfg-if", + "cpufeatures 0.2.17", + "opaque-debug", + "universal-hash", +] + [[package]] name = "portable-atomic" version = "1.13.0" @@ -6131,6 +6255,16 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "universal-hash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" +dependencies = [ + "crypto-common", + "subtle", +] + [[package]] name = "unsafe-any-ors" version = "1.0.0" 
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index 63e1c04762..8ac664377f 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -40,6 +40,10 @@ twox-hash = "2.1.2" rand = { workspace = true } hex = "0.4.3" base64 = "0.22.1" +aes = "0.8.4" +aes-gcm = "0.10.3" +cbc = "0.1.2" +ecb = "0.1.2" [dev-dependencies] arrow = {workspace = true} diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 4bfdef7096..3fd61f6b05 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -20,9 +20,9 @@ use crate::math_funcs::abs::abs; use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub}; use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ - spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan, - spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex, - spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkContains, SparkDateDiff, + spark_aes_decrypt, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, + spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, + spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkContains, SparkDateDiff, SparkDateTrunc, SparkMakeDate, SparkSizeFunc, SparkStringSpace, }; use arrow::datatypes::DataType; @@ -177,6 +177,10 @@ pub fn create_comet_physical_fun_with_eval_mode( let func = Arc::new(abs); make_comet_scalar_udf!("abs", func, without data_type) } + "aes_decrypt" => { + let func = Arc::new(spark_aes_decrypt); + make_comet_scalar_udf!("aes_decrypt", func, without data_type) + } "split" => { let func = Arc::new(crate::string_funcs::spark_split); make_comet_scalar_udf!("split", func, without data_type) diff --git a/native/spark-expr/src/downcast.rs b/native/spark-expr/src/downcast.rs new file mode 
100644 index 0000000000..ade2ef961b --- /dev/null +++ b/native/spark-expr/src/downcast.rs @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +macro_rules! opt_downcast_arg { + ($ARG:expr, $ARRAY_TYPE:ident) => {{ + $ARG.as_any().downcast_ref::<$ARRAY_TYPE>() + }}; +} + +macro_rules! downcast_named_arg { + ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{ + $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { + datafusion::common::internal_datafusion_err!( + "could not cast {} to {}", + $NAME, + std::any::type_name::<$ARRAY_TYPE>() + ) + })? 
+ }}; +} + +pub(crate) use {downcast_named_arg, opt_downcast_arg}; diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 40eb180ab8..ed59682323 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -19,7 +19,9 @@ // The lint makes easier for code reader/reviewer separate references clones from more heavyweight ones #![deny(clippy::clone_on_ref_ptr)] +mod downcast; mod error; +pub(crate) use downcast::{downcast_named_arg, opt_downcast_arg}; pub mod kernels; pub use kernels::temporal::date_trunc_dyn; @@ -58,6 +60,7 @@ pub use bloom_filter::{BloomFilterAgg, BloomFilterMightContain}; mod conditional_funcs; mod conversion_funcs; mod math_funcs; +mod misc_funcs; mod nondetermenistic_funcs; pub use array_funcs::*; @@ -83,6 +86,7 @@ pub use math_funcs::{ spark_decimal_integral_div, spark_floor, spark_make_decimal, spark_round, spark_unhex, spark_unscaled_value, CheckOverflow, NegativeExpr, NormalizeNaNAndZero, }; +pub use misc_funcs::spark_aes_decrypt; pub use string_funcs::*; /// Spark supports three evaluation modes when evaluating expressions, which affect diff --git a/native/spark-expr/src/misc_funcs/aes_decrypt.rs b/native/spark-expr/src/misc_funcs/aes_decrypt.rs new file mode 100644 index 0000000000..605fe50994 --- /dev/null +++ b/native/spark-expr/src/misc_funcs/aes_decrypt.rs @@ -0,0 +1,311 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use aes::cipher::consts::{U12, U16}; +use aes::{Aes128, Aes192, Aes256}; +use aes_gcm::aead::{Aead, Payload}; +use aes_gcm::{Aes128Gcm, Aes256Gcm, AesGcm, KeyInit, Nonce}; +use arrow::array::{Array, ArrayRef, BinaryArray, LargeBinaryArray, LargeStringArray, StringArray}; +use arrow::datatypes::DataType; +use cbc::cipher::{block_padding::Pkcs7, BlockDecryptMut, KeyIvInit}; +use datafusion::common::{exec_err, DataFusionError, ScalarValue}; +use datafusion::logical_expr::ColumnarValue; + +const GCM_IV_LEN: usize = 12; +const CBC_IV_LEN: usize = 16; + +#[derive(Clone, Copy)] +enum AesMode { + Ecb, + Cbc, + Gcm, +} + +impl AesMode { + fn from_mode_padding(mode: &str, padding: &str) -> Result { + let is_none = padding.eq_ignore_ascii_case("NONE"); + let is_pkcs = padding.eq_ignore_ascii_case("PKCS"); + let is_default = padding.eq_ignore_ascii_case("DEFAULT"); + + if mode.eq_ignore_ascii_case("ECB") && (is_pkcs || is_default) { + Ok(Self::Ecb) + } else if mode.eq_ignore_ascii_case("CBC") && (is_pkcs || is_default) { + Ok(Self::Cbc) + } else if mode.eq_ignore_ascii_case("GCM") && (is_none || is_default) { + Ok(Self::Gcm) + } else { + exec_err!("Unsupported AES mode/padding combination: {mode}/{padding}") + } + } +} + +enum BinaryArg<'a> { + Binary(&'a BinaryArray), + LargeBinary(&'a LargeBinaryArray), +} + +impl<'a> BinaryArg<'a> { + fn from(arg_name: &str, arr: &'a ArrayRef) -> Result { + match arr.data_type() { + DataType::Binary => Ok(Self::Binary( + crate::opt_downcast_arg!(arr, BinaryArray).ok_or_else(|| { + 
datafusion::common::internal_datafusion_err!( + "could not cast {} to {}", + arg_name, + std::any::type_name::() + ) + })?, + )), + DataType::LargeBinary => Ok(Self::LargeBinary(crate::downcast_named_arg!( + arr, + arg_name, + LargeBinaryArray + ))), + other => exec_err!("{arg_name} must be Binary/LargeBinary, got {other:?}"), + } + } + + fn value(&self, i: usize) -> Option<&'a [u8]> { + match self { + Self::Binary(arr) => (!arr.is_null(i)).then(|| arr.value(i)), + Self::LargeBinary(arr) => (!arr.is_null(i)).then(|| arr.value(i)), + } + } +} + +enum StringArg<'a> { + Utf8(&'a StringArray), + LargeUtf8(&'a LargeStringArray), +} + +impl<'a> StringArg<'a> { + fn from(arg_name: &str, arr: &'a ArrayRef) -> Result { + match arr.data_type() { + DataType::Utf8 => Ok(Self::Utf8( + crate::opt_downcast_arg!(arr, StringArray).ok_or_else(|| { + datafusion::common::internal_datafusion_err!( + "could not cast {} to {}", + arg_name, + std::any::type_name::() + ) + })?, + )), + DataType::LargeUtf8 => Ok(Self::LargeUtf8(crate::downcast_named_arg!( + arr, + arg_name, + LargeStringArray + ))), + other => exec_err!("{arg_name} must be Utf8/LargeUtf8, got {other:?}"), + } + } + + fn value(&self, i: usize) -> Option<&'a str> { + match self { + Self::Utf8(arr) => (!arr.is_null(i)).then(|| arr.value(i)), + Self::LargeUtf8(arr) => (!arr.is_null(i)).then(|| arr.value(i)), + } + } +} + +type Aes128CbcDec = cbc::Decryptor; +type Aes192CbcDec = cbc::Decryptor; +type Aes256CbcDec = cbc::Decryptor; +type Aes128EcbDec = ecb::Decryptor; +type Aes192EcbDec = ecb::Decryptor; +type Aes256EcbDec = ecb::Decryptor; +type Aes192Gcm = AesGcm; + +fn decrypt_pkcs_cbc(input: &[u8], key: &[u8]) -> Result, DataFusionError> { + if input.len() < CBC_IV_LEN { + return exec_err!("AES decryption input is too short for CBC"); + } + let (iv, ciphertext) = input.split_at(CBC_IV_LEN); + let mut buf = ciphertext.to_vec(); + + let out = match key.len() { + 16 => Aes128CbcDec::new_from_slices(key, iv) + .map_err(|e| 
DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt_padded_mut::(&mut buf) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?, + 24 => Aes192CbcDec::new_from_slices(key, iv) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt_padded_mut::(&mut buf) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?, + 32 => Aes256CbcDec::new_from_slices(key, iv) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt_padded_mut::(&mut buf) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?, + _ => return exec_err!("Invalid AES key length: {}", key.len()), + }; + + Ok(out.to_vec()) +} + +fn decrypt_pkcs_ecb(input: &[u8], key: &[u8]) -> Result, DataFusionError> { + let mut buf = input.to_vec(); + + let out = match key.len() { + 16 => Aes128EcbDec::new_from_slice(key) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt_padded_mut::(&mut buf) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?, + 24 => Aes192EcbDec::new_from_slice(key) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt_padded_mut::(&mut buf) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?, + 32 => Aes256EcbDec::new_from_slice(key) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? 
+ .decrypt_padded_mut::(&mut buf) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?, + _ => return exec_err!("Invalid AES key length: {}", key.len()), + }; + + Ok(out.to_vec()) +} + +fn decrypt_gcm(input: &[u8], key: &[u8], aad: &[u8]) -> Result, DataFusionError> { + if input.len() < GCM_IV_LEN { + return exec_err!("AES decryption input is too short for GCM"); + } + let (iv, ciphertext) = input.split_at(GCM_IV_LEN); + let nonce = Nonce::from_slice(iv); + let payload = Payload { + msg: ciphertext, + aad, + }; + + match key.len() { + 16 => Aes128Gcm::new_from_slice(key) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt(nonce, payload) + .map_err(|_| { + DataFusionError::Execution("AES crypto error: decrypt failed".to_string()) + }), + 24 => Aes192Gcm::new_from_slice(key) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt(nonce, payload) + .map_err(|_| { + DataFusionError::Execution("AES crypto error: decrypt failed".to_string()) + }), + 32 => Aes256Gcm::new_from_slice(key) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt(nonce, payload) + .map_err(|_| { + DataFusionError::Execution("AES crypto error: decrypt failed".to_string()) + }), + _ => exec_err!("Invalid AES key length: {}", key.len()), + } +} + +fn decrypt_one( + input: &[u8], + key: &[u8], + mode: &str, + padding: &str, + aad: &[u8], +) -> Result, DataFusionError> { + match AesMode::from_mode_padding(mode, padding)? 
{ + AesMode::Ecb => decrypt_pkcs_ecb(input, key), + AesMode::Cbc => decrypt_pkcs_cbc(input, key), + AesMode::Gcm => decrypt_gcm(input, key, aad), + } +} + +pub fn spark_aes_decrypt(args: &[ColumnarValue]) -> Result { + if !(2..=5).contains(&args.len()) { + return exec_err!("aes_decrypt expects 2 to 5 arguments, got {}", args.len()); + } + + let are_scalars = args + .iter() + .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); + let arrays = ColumnarValue::values_to_arrays(args)?; + let num_rows = arrays[0].len(); + + let input = BinaryArg::from("input", &arrays[0])?; + let key = BinaryArg::from("key", &arrays[1])?; + + let mode = if args.len() >= 3 { + Some(StringArg::from("mode", &arrays[2])?) + } else { + None + }; + let padding = if args.len() >= 4 { + Some(StringArg::from("padding", &arrays[3])?) + } else { + None + }; + let aad = if args.len() >= 5 { + Some(BinaryArg::from("aad", &arrays[4])?) + } else { + None + }; + + let values: Result, DataFusionError> = (0..num_rows) + .map(|row| { + let Some(input_value) = input.value(row) else { + return Ok(ScalarValue::Binary(None)); + }; + let Some(key_value) = key.value(row) else { + return Ok(ScalarValue::Binary(None)); + }; + + let mode_value = match mode.as_ref() { + Some(mode) => { + let Some(mode) = mode.value(row) else { + return Ok(ScalarValue::Binary(None)); + }; + mode + } + None => "GCM", + }; + + let padding_value = match padding.as_ref() { + Some(padding) => { + let Some(padding) = padding.value(row) else { + return Ok(ScalarValue::Binary(None)); + }; + padding + } + None => "DEFAULT", + }; + + let aad_value = match aad.as_ref() { + Some(aad) => { + let Some(aad) = aad.value(row) else { + return Ok(ScalarValue::Binary(None)); + }; + aad + } + None => &[], + }; + + let plaintext = + decrypt_one(input_value, key_value, mode_value, padding_value, aad_value)?; + Ok(ScalarValue::Binary(Some(plaintext))) + }) + .collect(); + + let array: ArrayRef = ScalarValue::iter_to_array(values?)?; + if are_scalars { + 
Ok(ColumnarValue::Scalar( + datafusion::common::ScalarValue::try_from_array(array.as_ref(), 0)?, + )) + } else { + Ok(ColumnarValue::Array(array)) + } +} diff --git a/native/spark-expr/src/misc_funcs/mod.rs b/native/spark-expr/src/misc_funcs/mod.rs new file mode 100644 index 0000000000..c55b82811d --- /dev/null +++ b/native/spark-expr/src/misc_funcs/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub(crate) mod aes_decrypt; + +pub use aes_decrypt::spark_aes_decrypt; diff --git a/native/spark-expr/tests/spark_expr_reg.rs b/native/spark-expr/tests/spark_expr_reg.rs index 633b226068..f381b77881 100644 --- a/native/spark-expr/tests/spark_expr_reg.rs +++ b/native/spark-expr/tests/spark_expr_reg.rs @@ -35,6 +35,12 @@ mod tests { &session_state, None, )?); + let _ = session_state.register_udf(create_comet_physical_fun( + "aes_decrypt", + DataType::Binary, + &session_state, + None, + )?); let ctx = SessionContext::new_with_state(session_state); // 2. 
Execute SQL with literal values diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 9d13ccd9ed..e244249545 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -220,6 +220,7 @@ object QueryPlanSerde extends Logging with CometExprShim { private val miscExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( // TODO PromotePrecision + classOf[AesDecrypt] -> CometAesDecrypt, classOf[Alias] -> CometAlias, classOf[AttributeReference] -> CometAttributeReference, classOf[BloomFilterMightContain] -> CometBloomFilterMightContain, diff --git a/spark/src/main/scala/org/apache/comet/serde/misc.scala b/spark/src/main/scala/org/apache/comet/serde/misc.scala new file mode 100644 index 0000000000..9bb8416579 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/misc.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.{AesDecrypt, Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke + +import org.apache.comet.serde.ExprOuterClass.Expr +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType} + +private object CometAesDecryptHelper { + def convertToAesDecryptExpr[T <: Expression]( + expr: T, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding)) + val optExpr = scalarFunctionExprToProtoWithReturnType( + "aes_decrypt", + expr.dataType, + failOnError = false, + childExpr: _*) + optExprWithInfo(optExpr, expr, expr.children: _*) + } +} + +object CometAesDecrypt extends CometExpressionSerde[AesDecrypt] { + override def convert( + expr: AesDecrypt, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + CometAesDecryptHelper.convertToAesDecryptExpr(expr, inputs, binding) + } +} + +object CometAesDecryptStaticInvoke extends CometExpressionSerde[StaticInvoke] { + override def convert( + expr: StaticInvoke, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + CometAesDecryptHelper.convertToAesDecryptExpr(expr, inputs, binding) + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/statics.scala b/spark/src/main/scala/org/apache/comet/serde/statics.scala index 0737644ab9..fa77369405 100644 --- a/spark/src/main/scala/org/apache/comet/serde/statics.scala +++ b/spark/src/main/scala/org/apache/comet/serde/statics.scala @@ -19,7 +19,7 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, ExpressionImplUtils} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils @@ -34,7 +34,8 @@ object 
CometStaticInvoke extends CometExpressionSerde[StaticInvoke] { : Map[(String, Class[_]), CometExpressionSerde[StaticInvoke]] = Map( ("readSidePadding", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction( - "read_side_padding")) + "read_side_padding"), + ("aesDecrypt", classOf[ExpressionImplUtils]) -> CometAesDecryptStaticInvoke) override def convert( expr: StaticInvoke, diff --git a/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql b/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql new file mode 100644 index 0000000000..5c93dd6a7a --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql @@ -0,0 +1,164 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
+ +-- MinSparkVersion: 3.5 + +statement +CREATE TABLE aes_tbl( + encrypted_default BINARY, + encrypted_with_aad BINARY, + `key` BINARY, + mode STRING, + padding STRING, + iv BINARY, + aad STRING +) USING parquet + +statement +INSERT INTO aes_tbl +SELECT + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8')), + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'GCM', + 'DEFAULT', + unhex('00112233445566778899AABB'), + 'Comet AAD'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'GCM', + 'DEFAULT', + unhex('00112233445566778899AABB'), + 'Comet AAD' + +query +SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl + +query +SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding, aad) AS STRING) FROM aes_tbl + +statement +CREATE TABLE aes_modes_tbl( + encrypted BINARY, + `key` BINARY, + mode STRING, + padding STRING, + label STRING +) USING parquet + +statement +INSERT INTO aes_modes_tbl +SELECT + aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop', 'UTF-8'), 'GCM', 'DEFAULT'), + encode('abcdefghijklmnop', 'UTF-8'), + 'GCM', + 'DEFAULT', + 'gcm_128' +UNION ALL +SELECT + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678', 'UTF-8'), + 'GCM', + 'DEFAULT'), + encode('abcdefghijklmnop12345678', 'UTF-8'), + 'GCM', + 'DEFAULT', + 'gcm_192' +UNION ALL +SELECT + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'GCM', + 'DEFAULT'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'GCM', + 'DEFAULT', + 'gcm_256' +UNION ALL +SELECT + aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop', 'UTF-8'), 'CBC', 'PKCS'), + encode('abcdefghijklmnop', 'UTF-8'), + 'CBC', + 'PKCS', + 'cbc_128' +UNION ALL +SELECT + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678', 'UTF-8'), + 'CBC', + 'PKCS'), + 
encode('abcdefghijklmnop12345678', 'UTF-8'), + 'CBC', + 'PKCS', + 'cbc_192' +UNION ALL +SELECT + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'CBC', + 'PKCS'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'CBC', + 'PKCS', + 'cbc_256' +UNION ALL +SELECT + aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop', 'UTF-8'), 'ECB', 'PKCS'), + encode('abcdefghijklmnop', 'UTF-8'), + 'ECB', + 'PKCS', + 'ecb_128' +UNION ALL +SELECT + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678', 'UTF-8'), + 'ECB', + 'PKCS'), + encode('abcdefghijklmnop12345678', 'UTF-8'), + 'ECB', + 'PKCS', + 'ecb_192' +UNION ALL +SELECT + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'ECB', + 'PKCS'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'ECB', + 'PKCS', + 'ecb_256' +UNION ALL +SELECT + cast(null AS binary), + encode('abcdefghijklmnop', 'UTF-8'), + 'GCM', + 'DEFAULT', + 'null_input' + +query +SELECT label, CAST(aes_decrypt(encrypted, `key`, mode, padding) AS STRING) +FROM aes_modes_tbl +ORDER BY label diff --git a/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala new file mode 100644 index 0000000000..2e6cc90cb8 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet
+
+import org.apache.spark.sql.CometTestBase
+
+import org.apache.comet.CometSparkSessionExtensions.isSpark35Plus
+
+/**
+ * Round-trip tests for `aes_decrypt`.
+ *
+ * Each test registers a temp view whose ciphertext columns are produced by `aes_encrypt`
+ * (the view is registered while `CometConf.COMET_ENABLED` is set to "false"), then runs an
+ * `aes_decrypt` query against that view through `checkSparkAnswerAndOperator`.
+ */
+class CometMiscExpressionSuite extends CometTestBase {
+
+  // SQL expression for the plaintext that every fixture row encrypts.
+  private val plaintextExpr = "encode('Spark SQL', 'UTF-8')"
+
+  // 16-, 24- and 32-character keys; each one is a prefix of the next.
+  private val key128 = "abcdefghijklmnop"
+  private val key192 = key128 + "12345678"
+  private val key256 = key192 + "ABCDEFGH"
+
+  /** SQL expression that UTF-8 encodes the literal `key`. */
+  private def keyExpr(key: String): String = s"encode('$key', 'UTF-8')"
+
+  /**
+   * Builds one `SELECT` arm of the mode/key-size fixture.
+   *
+   * @param encrypted
+   *   SQL expression for the `encrypted` column
+   * @param key
+   *   key literal; exposed UTF-8 encoded as the `key` column
+   * @param mode
+   *   value of the `mode` column
+   * @param padding
+   *   value of the `padding` column
+   * @param label
+   *   value of the `label` column, used to order and identify rows
+   */
+  private def fixtureRow(
+      encrypted: String,
+      key: String,
+      mode: String,
+      padding: String,
+      label: String): String =
+    s"""SELECT
+       |  $encrypted AS encrypted,
+       |  ${keyExpr(key)} AS `key`,
+       |  '$mode' AS mode,
+       |  '$padding' AS padding,
+       |  '$label' AS label""".stripMargin
+
+  // Decrypts with the 32-character key: once passing only the key, once passing explicit
+  // mode and padding (plus the AAD when isSpark35Plus).
+  test("aes_decrypt") {
+    withTempView("aes_tbl") {
+      val keySql = keyExpr(key256)
+
+      // Only the IV/AAD handling depends on the Spark version: the six-argument aes_encrypt
+      // is used when isSpark35Plus, otherwise the `iv` and `aad` columns are typed NULLs and
+      // `encrypted_with_aad` is encrypted with mode and padding only.
+      val (extraEncryptArgs, ivSql, aadSql) =
+        if (isSpark35Plus) {
+          val iv = "unhex('00112233445566778899AABB')"
+          val aad = "'Comet AAD'"
+          (s", $iv, $aad", iv, aad)
+        } else {
+          ("", "cast(null as binary)", "cast(null as string)")
+        }
+      val encryptExplicit =
+        s"aes_encrypt($plaintextExpr, $keySql, 'GCM', 'DEFAULT'$extraEncryptArgs)"
+
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        spark
+          .range(1)
+          .selectExpr(
+            s"aes_encrypt($plaintextExpr, $keySql) as encrypted_default",
+            s"$encryptExplicit as encrypted_with_aad",
+            s"$keySql as `key`",
+            "'GCM' as mode",
+            "'DEFAULT' as padding",
+            s"$ivSql as iv",
+            s"$aadSql as aad")
+          .createOrReplaceTempView("aes_tbl")
+      }
+
+      // The key-only form is the same on every Spark version.
+      checkSparkAnswerAndOperator(
+        "SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl")
+
+      // The AAD argument is only passed when the ciphertext was produced with one.
+      val decryptArgs =
+        if (isSpark35Plus) "encrypted_with_aad, `key`, mode, padding, aad"
+        else "encrypted_with_aad, `key`, mode, padding"
+      checkSparkAnswerAndOperator(
+        s"SELECT CAST(aes_decrypt($decryptArgs) AS STRING) FROM aes_tbl")
+    }
+  }
+
+  // Covers GCM, CBC and ECB with each of the three key sizes, plus one row whose ciphertext
+  // is NULL.
+  test("aes_decrypt mode and key-size combinations") {
+    // (mode, padding) pairs; each is combined with every key below.
+    val modes = Seq("GCM" -> "DEFAULT", "CBC" -> "PKCS", "ECB" -> "PKCS")
+    val keys = Seq(128 -> key128, 192 -> key192, 256 -> key256)
+
+    // One fixture row per combination, labelled "<mode>_<bits>" (for example "gcm_128").
+    val encryptedRows = for {
+      (mode, padding) <- modes
+      (bits, key) <- keys
+    } yield fixtureRow(
+      s"aes_encrypt($plaintextExpr, ${keyExpr(key)}, '$mode', '$padding')",
+      key,
+      mode,
+      padding,
+      s"${mode.toLowerCase(java.util.Locale.ROOT)}_$bits")
+
+    // NULL ciphertext row.
+    val nullInputRow =
+      fixtureRow("cast(null AS binary)", key128, "GCM", "DEFAULT", "null_input")
+
+    val fixtureSql = (encryptedRows :+ nullInputRow).mkString("\nUNION ALL\n")
+
+    withTempView("aes_modes_tbl") {
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        spark.sql(fixtureSql).createOrReplaceTempView("aes_modes_tbl")
+      }
+
+      // ORDER BY label keeps the row order deterministic.
+      checkSparkAnswerAndOperator("""
+          |SELECT
+          |  label,
+          |  CAST(aes_decrypt(encrypted, `key`, mode, padding) AS STRING) AS decrypted
+          |FROM aes_modes_tbl
+          |ORDER BY label
+          |""".stripMargin)
+    }
+  }
+
+}