From 68f74aecde2044f3e98a28561fe7bc099340a120 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 29 Nov 2025 21:17:49 -0800 Subject: [PATCH 1/9] prototype --- native/spark-expr/src/comet_scalar_funcs.rs | 4 +- native/spark-expr/src/string_funcs/mod.rs | 2 + .../src/string_funcs/regexp_extract.rs | 401 ++++++++++++++++++ .../apache/comet/serde/QueryPlanSerde.scala | 2 + .../org/apache/comet/serde/strings.scala | 100 ++++- .../comet/CometStringExpressionSuite.scala | 122 ++++++ 6 files changed, 629 insertions(+), 2 deletions(-) create mode 100644 native/spark-expr/src/string_funcs/regexp_extract.rs diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 021bb1c78f..45b4ca8aad 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -23,7 +23,7 @@ use crate::{ spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkBitwiseNot, - SparkDateTrunc, SparkStringSpace, + SparkDateTrunc, SparkRegExpExtract, SparkRegExpExtractAll, SparkStringSpace, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; @@ -199,6 +199,8 @@ fn all_scalar_functions() -> Vec> { Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())), Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())), Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())), + Arc::new(ScalarUDF::new_from_impl(SparkRegExpExtract::default())), + Arc::new(ScalarUDF::new_from_impl(SparkRegExpExtractAll::default())), ] } diff --git a/native/spark-expr/src/string_funcs/mod.rs b/native/spark-expr/src/string_funcs/mod.rs index aac8204e29..2026ec5fec 100644 --- a/native/spark-expr/src/string_funcs/mod.rs +++ b/native/spark-expr/src/string_funcs/mod.rs @@ -15,8 +15,10 @@ // specific language governing permissions and limitations // under the License. +mod regexp_extract; mod string_space; mod substring; +pub use regexp_extract::{SparkRegExpExtract, SparkRegExpExtractAll}; pub use string_space::SparkStringSpace; pub use substring::SubstringExpr; diff --git a/native/spark-expr/src/string_funcs/regexp_extract.rs b/native/spark-expr/src/string_funcs/regexp_extract.rs new file mode 100644 index 0000000000..2a9ce6b82c --- /dev/null +++ b/native/spark-expr/src/string_funcs/regexp_extract.rs @@ -0,0 +1,401 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, GenericStringArray}; +use arrow::datatypes::DataType; +use datafusion::common::{exec_err, internal_datafusion_err, Result, ScalarValue}; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use regex::Regex; +use std::sync::Arc; +use std::any::Any; + +/// Spark-compatible regexp_extract function +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkRegExpExtract { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkRegExpExtract { + fn default() -> Self { + Self::new() + } +} + +impl SparkRegExpExtract { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec![], + } + } +} + +impl ScalarUDFImpl for SparkRegExpExtract { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "regexp_extract" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + // regexp_extract always returns String + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + // regexp_extract(subject, pattern, idx) + if args.args.len() != 3 { + return exec_err!( + "regexp_extract expects 3 arguments, got {}", + args.args.len() + ); + } + + let subject = &args.args[0]; + let pattern = &args.args[1]; + let idx = &args.args[2]; + + // Pattern must be a literal string + let pattern_str = match pattern { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s.clone(), + _ => { + return exec_err!("regexp_extract pattern must be a string literal"); + } + }; + + // idx must be a literal int + let idx_val = match idx { + ColumnarValue::Scalar(ScalarValue::Int32(Some(i))) => *i as usize, + _ => { + return exec_err!("regexp_extract idx must be an integer literal"); + } + }; + + // Compile regex once + let regex = Regex::new(&pattern_str).map_err(|e| { + internal_datafusion_err!("Invalid regex pattern '{}': {}", pattern_str, e) + })?; + + match subject { + ColumnarValue::Array(array) => { + let result = regexp_extract_array(array, ®ex, idx_val)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(ScalarValue::Utf8(s)) => { + let result = match s { + Some(text) => Some(extract_group(text, ®ex, idx_val)), + None => None, // NULL input → NULL output + }; + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) + } + _ => exec_err!("regexp_extract expects string input"), + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Spark-compatible regexp_extract_all function +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkRegExpExtractAll { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkRegExpExtractAll { + fn default() -> Self { + Self::new() + } +} + +impl SparkRegExpExtractAll { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec![], + } + } +} + +impl ScalarUDFImpl for SparkRegExpExtractAll { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "regexp_extract_all" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + // regexp_extract_all returns Array + Ok(DataType::List(Arc::new( + arrow::datatypes::Field::new("item", DataType::Utf8, true), + ))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + // regexp_extract_all(subject, pattern) or regexp_extract_all(subject, pattern, idx) + if args.args.len() < 2 || args.args.len() > 3 { + return exec_err!( + "regexp_extract_all expects 2 or 3 arguments, got {}", + args.args.len() + ); + } + + let subject = &args.args[0]; + let pattern = &args.args[1]; + let idx_val = if args.args.len() == 3 { + match &args.args[2] { + ColumnarValue::Scalar(ScalarValue::Int32(Some(i))) => *i as usize, + _ => { + return exec_err!("regexp_extract_all idx must be an integer literal"); + } + } + } else { + 0 // default to group 0 (entire match) + }; + + // Pattern must be a literal string + let pattern_str = match pattern { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s.clone(), + _ => { + return exec_err!("regexp_extract_all pattern must be a string literal"); + } + }; + + // Compile regex once + let regex = Regex::new(&pattern_str).map_err(|e| { + internal_datafusion_err!("Invalid regex pattern '{}': {}", pattern_str, e) + })?; + + match subject { + ColumnarValue::Array(array) => { + let result = regexp_extract_all_array(array, ®ex, idx_val)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(ScalarValue::Utf8(s)) => { + match s { + Some(text) => { + let matches = extract_all_groups(text, ®ex, idx_val); + // Build a list array with a single element + let mut list_builder = + arrow::array::ListBuilder::new(arrow::array::StringBuilder::new()); + for m in matches { + list_builder.values().append_value(m); + } + list_builder.append(true); + let list_array = list_builder.finish(); + + Ok(ColumnarValue::Scalar(ScalarValue::List( + Arc::new(list_array), + ))) + } + None => { + // Return NULL list using try_into (same as planner.rs:424) + let null_list: ScalarValue = DataType::List(Arc::new( + arrow::datatypes::Field::new("item", DataType::Utf8, true) + )).try_into()?; + Ok(ColumnarValue::Scalar(null_list)) + } + } + } + _ => exec_err!("regexp_extract_all expects string input"), + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +// Helper functions + +fn extract_group(text: &str, regex: &Regex, idx: usize) -> String { + regex + .captures(text) + .and_then(|caps| caps.get(idx)) + .map(|m| m.as_str().to_string()) + // Spark behavior: return empty string "" if no match or group not found + .unwrap_or_else(|| String::new()) +} + +fn regexp_extract_array(array: &ArrayRef, regex: &Regex, idx: usize) -> Result { + let string_array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + internal_datafusion_err!("regexp_extract expects string array input") + })?; + + let result: GenericStringArray = string_array + .iter() + .map(|s| s.map(|text| extract_group(text, regex, idx))) // NULL → None, non-NULL → Some("") + .collect(); + + Ok(Arc::new(result)) +} + +fn extract_all_groups(text: &str, regex: &Regex, idx: usize) -> Vec { + regex + .captures_iter(text) + .filter_map(|caps| caps.get(idx).map(|m| m.as_str().to_string())) + .collect() +} + +fn regexp_extract_all_array(array: &ArrayRef, regex: &Regex, idx: usize) -> Result { + let string_array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + internal_datafusion_err!("regexp_extract_all expects string array input") + })?; + + let mut list_builder = + arrow::array::ListBuilder::new(arrow::array::StringBuilder::new()); + + for s in string_array.iter() { + match s { + Some(text) => { + let matches = extract_all_groups(text, regex, idx); + for m in matches { + list_builder.values().append_value(m); + } + list_builder.append(true); + } + None => { + list_builder.append(false); + } + } + } + + Ok(Arc::new(list_builder.finish())) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::StringArray; + + #[test] + fn test_regexp_extract_basic() { + let regex = Regex::new(r"(\d+)-(\w+)").unwrap(); + + // Spark behavior: return "" on no match, not None + assert_eq!(extract_group("123-abc", ®ex, 0), "123-abc"); + assert_eq!(extract_group("123-abc", ®ex, 1), "123"); + assert_eq!(extract_group("123-abc", ®ex, 2), "abc"); + assert_eq!(extract_group("123-abc", ®ex, 3), ""); // no such group → "" + assert_eq!(extract_group("no match", ®ex, 0), ""); // no match → "" + } + + #[test] + fn test_regexp_extract_all_basic() { + let regex = Regex::new(r"(\d+)").unwrap(); + + // Multiple matches + let matches = extract_all_groups("a1b2c3", ®ex, 0); + assert_eq!(matches, vec!["1", "2", "3"]); + + // Same with group index 1 + let matches = extract_all_groups("a1b2c3", ®ex, 1); + assert_eq!(matches, vec!["1", "2", "3"]); + + // No match + let matches = extract_all_groups("no digits", ®ex, 0); + assert!(matches.is_empty()); + assert_eq!(matches, Vec::::new()); + } + + #[test] + fn test_regexp_extract_all_array() -> Result<()> { + use datafusion::common::cast::as_list_array; + + let regex = Regex::new(r"(\d+)").unwrap(); + let array = Arc::new(StringArray::from(vec![ + Some("a1b2"), + Some("no digits"), + None, + Some("c3d4e5"), + ])) as ArrayRef; + + let result = regexp_extract_all_array(&array, ®ex, 0)?; + let list_array = as_list_array(&result)?; + + // Row 0: "a1b2" → ["1", "2"] + let row0 = list_array.value(0); + let row0_str = row0.as_any().downcast_ref::>().unwrap(); + assert_eq!(row0_str.len(), 2); + assert_eq!(row0_str.value(0), "1"); + assert_eq!(row0_str.value(1), "2"); + + // Row 1: "no digits" → [] (empty array, not NULL) + let row1 = list_array.value(1); + let row1_str = row1.as_any().downcast_ref::>().unwrap(); + assert_eq!(row1_str.len(), 0); // Empty array + assert!(!list_array.is_null(1)); // Not NULL, just empty + + // Row 2: NULL input → NULL output + assert!(list_array.is_null(2)); + + // Row 3: "c3d4e5" → ["3", "4", "5"] + let row3 = list_array.value(3); + let row3_str = row3.as_any().downcast_ref::>().unwrap(); + assert_eq!(row3_str.len(), 3); + assert_eq!(row3_str.value(0), "3"); + assert_eq!(row3_str.value(1), "4"); + assert_eq!(row3_str.value(2), "5"); + + Ok(()) + } + + #[test] + fn test_regexp_extract_array() -> Result<()> { + let regex = Regex::new(r"(\d+)-(\w+)").unwrap(); + let array = Arc::new(StringArray::from(vec![ + Some("123-abc"), + Some("456-def"), + None, + Some("no-match"), + ])) as ArrayRef; + + let result = regexp_extract_array(&array, ®ex, 1)?; + let result_array = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result_array.value(0), "123"); + assert_eq!(result_array.value(1), "456"); + assert!(result_array.is_null(2)); // NULL input → NULL output + assert_eq!(result_array.value(3), ""); // no match → "" (empty string) + + Ok(()) + } +} + diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 54df2f1688..d18e84ffab 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -153,6 +153,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[Like] -> CometLike, classOf[Lower] -> CometLower, classOf[OctetLength] -> CometScalarFunction("octet_length"), + classOf[RegExpExtract] -> CometRegExpExtract, + classOf[RegExpExtractAll] -> CometRegExpExtractAll, classOf[RegExpReplace] -> CometRegExpReplace, classOf[Reverse] -> CometReverse, classOf[RLike] -> CometRLike, diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index 15f4b238f2..0756615bd2 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, Expression, InitCap, Length, Like, Literal, Lower, RegExpReplace, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, Expression, InitCap, Length, Like, Literal, Lower, RegExpReplace, RegExpExtract, RegExpExtractAll, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper} import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType} import org.apache.comet.CometConf @@ -286,3 +286,101 @@ trait CommonStringExprs { } } } + +object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { + override def getSupportLevel(expr: RegExpExtract): SupportLevel = { + // Check if the pattern is compatible with Spark + expr.regexp match { + case Literal(pattern, DataTypes.StringType) => + if (!RegExp.isSupportedPattern(pattern.toString) && + !CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) { + withInfo( + expr, + s"Regexp pattern $pattern is not compatible with Spark. " + + s"Set ${CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key}=true " + + "to allow it anyway.") + return Incompatible() + } + case _ => + // Pattern must be a literal + return Unsupported(Some("Only literal regexp patterns are supported")) + } + + // Check if idx is a literal + expr.idx match { + case Literal(_, DataTypes.IntegerType) => Compatible() + case _ => + Unsupported(Some("Only literal group index is supported")) + } + } + + override def convert( + expr: RegExpExtract, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) + val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) + val idxExpr = exprToProtoInternal(expr.idx, inputs, binding) + + val optExpr = scalarFunctionExprToProto( + "regexp_extract", + subjectExpr, + patternExpr, + idxExpr) + optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.idx) + } +} + +object CometRegExpExtractAll extends CometExpressionSerde[RegExpExtractAll] { + override def getSupportLevel(expr: RegExpExtractAll): SupportLevel = { + // Check if the pattern is compatible with Spark + expr.regexp match { + case Literal(pattern, DataTypes.StringType) => + if (!RegExp.isSupportedPattern(pattern.toString) && + !CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) { + withInfo( + expr, + s"Regexp pattern $pattern is not compatible with Spark. " + + s"Set ${CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key}=true " + + "to allow it anyway.") + return Incompatible() + } + case _ => + // Pattern must be a literal + return Unsupported(Some("Only literal regexp patterns are supported")) + } + + // Check if idx is a literal if exists + if (expr.idx.isDefined) { + expr.idx.get match { + case Literal(_, DataTypes.IntegerType) => Compatible() + case _ => return Unsupported(Some("Only literal group index is supported")) + } + } + } + + override def convert(expr: RegExpExtractAll, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { + val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) + val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) + + val optExpr = if (expr.idx.isDefined) { + val idxExpr = exprToProtoInternal(expr.idx.get, inputs, binding) + scalarFunctionExprToProto( + "regexp_extract_all", + subjectExpr, + patternExpr, + idxExpr) + } else { + scalarFunctionExprToProto( + "regexp_extract_all", + subjectExpr, + patternExpr) + } + + if (expr.idx.isDefined) { + optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.idx.get) + } else { + optExprWithInfo(optExpr, expr, expr.subject, expr.regexp) + } + } +} \ No newline at end of file diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index f9882780c8..ffa609b8f1 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -391,4 +391,126 @@ class CometStringExpressionSuite extends CometTestBase { } } + test("regexp_extract basic") { + import org.apache.comet.CometConf + + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq( + ("100-200", 1), + ("300-400", 1), + (null, 1), // NULL input + ("no-match", 1), // no match → should return "" + ("abc123def456", 1), + ("", 1) // empty string + ) + + withParquetTable(data, "tbl") { + // Test basic extraction: group 0 (full match) + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 0) FROM tbl") + // Test group 1 + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 1) FROM tbl") + // Test group 2 + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 2) FROM tbl") + // Test non-existent group → should return "" + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 3) FROM tbl") + } + } + } + + test("regexp_extract edge cases") { + import org.apache.comet.CometConf + + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq( + ("email@example.com", 1), + ("phone: 123-456-7890", 1), + ("price: $99.99", 1), + (null, 1) + ) + + withParquetTable(data, "tbl") { + // Extract email domain + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '@([^.]+)', 1) FROM tbl") + // Extract phone number + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '(\\d{3}-\\d{3}-\\d{4})', 1) FROM tbl") + // Extract price + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '\\$(\\d+\\.\\d+)', 1) FROM tbl") + } + } + } + + test("regexp_extract_all basic") { + import org.apache.comet.CometConf + + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq( + ("a1b2c3", 1), + ("test123test456", 1), + (null, 1), // NULL input + ("no digits", 1), // no match → should return [] + ("", 1) // empty string + ) + + withParquetTable(data, "tbl") { + // Test default (group 0) + checkSparkAnswerAndOperator( + "SELECT regexp_extract_all(_1, '\\d+') FROM tbl") + // Test with explicit group 0 + checkSparkAnswerAndOperator( + "SELECT regexp_extract_all(_1, '(\\d+)', 0) FROM tbl") + // Test group 1 + checkSparkAnswerAndOperator( + "SELECT regexp_extract_all(_1, '(\\d+)', 1) FROM tbl") + } + } + } + + test("regexp_extract_all multiple matches") { + import org.apache.comet.CometConf + + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq( + ("The prices are $10, $20, and $30", 1), + ("colors: red, green, blue", 1), + ("words: hello world", 1), + (null, 1) + ) + + withParquetTable(data, "tbl") { + // Extract all prices + checkSparkAnswerAndOperator( + "SELECT regexp_extract_all(_1, '\\$(\\d+)', 1) FROM tbl") + // Extract all words + checkSparkAnswerAndOperator( + "SELECT regexp_extract_all(_1, '([a-z]+)', 1) FROM tbl") + } + } + } + + test("regexp_extract_all with dictionary encoding") { + import org.apache.comet.CometConf + + withSQLConf( + CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true", + "parquet.enable.dictionary" -> "true") { + // Use repeated values to trigger dictionary encoding + val data = (0 until 1000).map(i => { + val text = if (i % 3 == 0) "a1b2c3" else if (i % 3 == 1) "x5y6" else "no-match" + (text, 1) + }) + + withParquetTable(data, "tbl") { + checkSparkAnswerAndOperator( + "SELECT regexp_extract_all(_1, '\\d+') FROM tbl") + } + } + } + } From 4dbed777f3285af2d7a6c9e3cbc6e6ac1d84d5ed Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 29 Nov 2025 22:09:39 -0800 Subject: [PATCH 2/9] refactor strings.scala --- .../org/apache/comet/serde/strings.scala | 65 ++++++++----------- 1 file changed, 27 insertions(+), 38 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index 0756615bd2..a4124048ae 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, Expression, InitCap, Length, Like, Literal, Lower, RegExpReplace, RegExpExtract, RegExpExtractAll, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, Expression, InitCap, Length, Like, Literal, Lower, RegExpExtract, RegExpExtractAll, RegExpReplace, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper} import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType} import org.apache.comet.CometConf @@ -289,7 +289,7 @@ trait CommonStringExprs { object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { override def getSupportLevel(expr: RegExpExtract): SupportLevel = { - // Check if the pattern is compatible with Spark + // Check if the pattern is compatible with Spark or allow incompatible patterns expr.regexp match { case Literal(pattern, DataTypes.StringType) => if (!RegExp.isSupportedPattern(pattern.toString) && @@ -302,13 +302,13 @@ object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { return Incompatible() } case _ => - // Pattern must be a literal return Unsupported(Some("Only literal regexp patterns are supported")) } - + // Check if idx is a literal expr.idx match { - case Literal(_, DataTypes.IntegerType) => Compatible() + case Literal(_, DataTypes.IntegerType) => + Compatible() case _ => Unsupported(Some("Only literal group index is supported")) } @@ -321,7 +321,6 @@ object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) val idxExpr = exprToProtoInternal(expr.idx, inputs, binding) - val optExpr = scalarFunctionExprToProto( "regexp_extract", subjectExpr, @@ -333,7 +332,7 @@ object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { object CometRegExpExtractAll extends CometExpressionSerde[RegExpExtractAll] { override def getSupportLevel(expr: RegExpExtractAll): SupportLevel = { - // Check if the pattern is compatible with Spark + // Check if the pattern is compatible with Spark or allow incompatible patterns expr.regexp match { case Literal(pattern, DataTypes.StringType) => if (!RegExp.isSupportedPattern(pattern.toString) && @@ -346,41 +345,31 @@ object CometRegExpExtractAll extends CometExpressionSerde[RegExpExtractAll] { return Incompatible() } case _ => - // Pattern must be a literal return Unsupported(Some("Only literal regexp patterns are supported")) } - - // Check if idx is a literal if exists - if (expr.idx.isDefined) { - expr.idx.get match { - case Literal(_, DataTypes.IntegerType) => Compatible() - case _ => return Unsupported(Some("Only literal group index is supported")) - } + + // Check if idx is a literal + // For regexp_extract_all, idx will be default to 1 if not specified + expr.idx match { + case Literal(_, DataTypes.IntegerType) => + Compatible() + case _ => + Unsupported(Some("Only literal group index is supported")) } } - - override def convert(expr: RegExpExtractAll, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { + override def convert( + expr: RegExpExtractAll, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + // Check if the pattern is compatible with Spark or allow incompatible patterns val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) - - val optExpr = if (expr.idx.isDefined) { - val idxExpr = exprToProtoInternal(expr.idx.get, inputs, binding) - scalarFunctionExprToProto( - "regexp_extract_all", - subjectExpr, - patternExpr, - idxExpr) - } else { - scalarFunctionExprToProto( - "regexp_extract_all", - subjectExpr, - patternExpr) - } - - if (expr.idx.isDefined) { - optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.idx.get) - } else { - optExprWithInfo(optExpr, expr, expr.subject, expr.regexp) - } + val idxExpr = exprToProtoInternal(expr.idx, inputs, binding) + val optExpr = scalarFunctionExprToProto( + "regexp_extract_all", + subjectExpr, + patternExpr, + idxExpr) + optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.idx) } -} \ No newline at end of file +} From f1013628ea975c3bf8ec7fc1f2eefb412482fbaf Mon Sep 17 00:00:00 2001 From: Daniel Tu Date: Sun, 30 Nov 2025 21:35:49 -0800 Subject: [PATCH 3/9] test, format and configs --- docs/source/user-guide/latest/configs.md | 2 + .../org/apache/comet/serde/strings.scala | 15 +--- .../comet/CometStringExpressionSuite.scala | 90 +++++++++---------- 3 files changed, 47 insertions(+), 60 deletions(-) diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md index a1c3212c20..f5638d5cf4 100644 --- a/docs/source/user-guide/latest/configs.md +++ b/docs/source/user-guide/latest/configs.md @@ -291,6 +291,8 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.RLike.enabled` | Enable Comet acceleration for `RLike` | true | | `spark.comet.expression.Rand.enabled` | Enable Comet acceleration for `Rand` | true | | `spark.comet.expression.Randn.enabled` | Enable Comet acceleration for `Randn` | true | +| `spark.comet.expression.RegExpExtract.enabled` | Enable Comet acceleration for `RegExpExtract` | true | +| `spark.comet.expression.RegExpExtractAll.enabled` | Enable Comet acceleration for `RegExpExtractAll` | true | | `spark.comet.expression.RegExpReplace.enabled` | Enable Comet acceleration for `RegExpReplace` | true | | `spark.comet.expression.Remainder.enabled` | Enable Comet acceleration for `Remainder` | true | | `spark.comet.expression.Reverse.enabled` | Enable Comet acceleration for `Reverse` | true | diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index a4124048ae..733c25ec2b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -321,11 +321,7 @@ object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) val idxExpr = exprToProtoInternal(expr.idx, inputs, binding) - val optExpr = scalarFunctionExprToProto( - "regexp_extract", - subjectExpr, - patternExpr, - idxExpr) + val optExpr = scalarFunctionExprToProto("regexp_extract", subjectExpr, patternExpr, idxExpr) optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.idx) } } @@ -349,7 +345,7 @@ object CometRegExpExtractAll extends CometExpressionSerde[RegExpExtractAll] { } // Check if idx is a literal - // For regexp_extract_all, idx will be default to 1 if not specified + // For regexp_extract_all, idx will default to 0 (group 0, entire match) if not specified expr.idx match { case Literal(_, DataTypes.IntegerType) => Compatible() @@ -365,11 +361,8 @@ object CometRegExpExtractAll extends CometExpressionSerde[RegExpExtractAll] { val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) val idxExpr = exprToProtoInternal(expr.idx, inputs, binding) - val optExpr = scalarFunctionExprToProto( - "regexp_extract_all", - subjectExpr, - patternExpr, - idxExpr) + val optExpr = + scalarFunctionExprToProto("regexp_extract_all", subjectExpr, patternExpr, idxExpr) optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.idx) } } diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index ffa609b8f1..5214eb8215 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -393,110 +393,102 @@ class CometStringExpressionSuite extends CometTestBase { test("regexp_extract basic") { import org.apache.comet.CometConf - + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { val data = Seq( ("100-200", 1), ("300-400", 1), - (null, 1), // NULL input - ("no-match", 1), // no match → should return "" + (null, 1), // NULL input + ("no-match", 1), // no match → should return "" ("abc123def456", 1), - ("", 1) // empty string + ("", 1) // empty string ) - + withParquetTable(data, "tbl") { // Test basic extraction: group 0 (full match) - checkSparkAnswerAndOperator( - "SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 0) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 0) FROM tbl") // Test group 1 - checkSparkAnswerAndOperator( - "SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 1) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 1) FROM tbl") // Test group 2 - checkSparkAnswerAndOperator( - "SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 2) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 2) FROM tbl") // Test non-existent group → should return "" - checkSparkAnswerAndOperator( - "SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 3) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 3) FROM tbl") + // Test empty pattern + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '', 0) FROM tbl") + // Test null pattern + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, NULL, 0) FROM tbl") } } } test("regexp_extract edge cases") { import org.apache.comet.CometConf - + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { - val data = Seq( - ("email@example.com", 1), - ("phone: 123-456-7890", 1), - ("price: $99.99", 1), - (null, 1) - ) - + val data = + Seq(("email@example.com", 1), ("phone: 123-456-7890", 1), ("price: $99.99", 1), (null, 1)) + withParquetTable(data, "tbl") { // Extract email domain - checkSparkAnswerAndOperator( - "SELECT regexp_extract(_1, '@([^.]+)', 1) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '@([^.]+)', 1) FROM tbl") // Extract phone number checkSparkAnswerAndOperator( "SELECT regexp_extract(_1, '(\\d{3}-\\d{3}-\\d{4})', 1) FROM tbl") // Extract price - checkSparkAnswerAndOperator( - "SELECT regexp_extract(_1, '\\$(\\d+\\.\\d+)', 1) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '\\$(\\d+\\.\\d+)', 1) FROM tbl") } } } test("regexp_extract_all basic") { import org.apache.comet.CometConf - + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { val data = Seq( ("a1b2c3", 1), ("test123test456", 1), - (null, 1), // NULL input - ("no digits", 1), // no match → should return [] - ("", 1) // empty string + (null, 1), // NULL input + ("no digits", 1), // no match → should return [] + ("", 1) // empty string ) - + withParquetTable(data, "tbl") { - // Test default (group 0) - checkSparkAnswerAndOperator( - "SELECT regexp_extract_all(_1, '\\d+') FROM tbl") + // Test with explicit group 0 (full match on no-group pattern) + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '\\d+', 0) FROM tbl") // Test with explicit group 0 - checkSparkAnswerAndOperator( - "SELECT regexp_extract_all(_1, '(\\d+)', 0) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '(\\d+)', 0) FROM tbl") // Test group 1 - checkSparkAnswerAndOperator( - "SELECT regexp_extract_all(_1, '(\\d+)', 1) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '(\\d+)', 1) FROM tbl") + // Test empty pattern + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '', 0) FROM tbl") + // Test null pattern + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, NULL, 0) FROM tbl") } } } test("regexp_extract_all multiple matches") { import org.apache.comet.CometConf - + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { val data = Seq( ("The prices are $10, $20, and $30", 1), ("colors: red, green, blue", 1), ("words: hello world", 1), - (null, 1) - ) - + (null, 1)) + withParquetTable(data, "tbl") { // Extract all prices - checkSparkAnswerAndOperator( - "SELECT regexp_extract_all(_1, '\\$(\\d+)', 1) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '\\$(\\d+)', 1) FROM tbl") // Extract all words - checkSparkAnswerAndOperator( - "SELECT regexp_extract_all(_1, '([a-z]+)', 1) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '([a-z]+)', 1) FROM tbl") } } } test("regexp_extract_all with dictionary encoding") { import org.apache.comet.CometConf - + withSQLConf( CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true", "parquet.enable.dictionary" -> "true") { @@ -505,10 +497,10 @@ class CometStringExpressionSuite extends CometTestBase { val text = if (i % 3 == 0) "a1b2c3" else if (i % 3 == 1) "x5y6" else "no-match" (text, 1) }) - + withParquetTable(data, "tbl") { - checkSparkAnswerAndOperator( - "SELECT regexp_extract_all(_1, '\\d+') FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '\\d+') FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '\\d+', 0) FROM tbl") } } } From ff1ebd6b3bebe85c0f393b268660dc8031614bc1 Mon Sep 17 00:00:00 2001 From: Daniel Tu Date: Mon, 1 Dec 2025 00:05:28 -0800 Subject: [PATCH 4/9] make regexp_extract more align with spark's behavior --- .../src/string_funcs/regexp_extract.rs | 169 ++++++++++++------ 1 file changed, 110 insertions(+), 59 deletions(-) diff --git a/native/spark-expr/src/string_funcs/regexp_extract.rs b/native/spark-expr/src/string_funcs/regexp_extract.rs index 2a9ce6b82c..eba2e7993c 100644 --- a/native/spark-expr/src/string_funcs/regexp_extract.rs +++ b/native/spark-expr/src/string_funcs/regexp_extract.rs @@ -22,8 +22,8 @@ use datafusion::logical_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use regex::Regex; -use std::sync::Arc; use std::any::Any; +use std::sync::Arc; /// Spark-compatible regexp_extract function #[derive(Debug, PartialEq, Eq, Hash)] @@ -106,8 +106,8 @@ impl ScalarUDFImpl for SparkRegExpExtract { } ColumnarValue::Scalar(ScalarValue::Utf8(s)) => { let result = match s { - Some(text) => Some(extract_group(text, ®ex, idx_val)), - None => None, // NULL input → NULL output + Some(text) => Some(extract_group(text, ®ex, idx_val)?), + None => None, // NULL input → NULL output }; Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) } @@ -157,9 +157,11 @@ impl ScalarUDFImpl for SparkRegExpExtractAll { fn return_type(&self, _arg_types: &[DataType]) -> Result { // regexp_extract_all returns Array - Ok(DataType::List(Arc::new( - arrow::datatypes::Field::new("item", DataType::Utf8, true), - ))) + Ok(DataType::List(Arc::new(arrow::datatypes::Field::new( + "item", + DataType::Utf8, + false, + )))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -181,7 +183,8 @@ impl ScalarUDFImpl for SparkRegExpExtractAll { } } } else { - 0 // default to group 0 (entire match) + // Using 1 here to align with Spark's default behavior. + 1 }; // Pattern must be a literal string @@ -205,7 +208,7 @@ impl ScalarUDFImpl for SparkRegExpExtractAll { ColumnarValue::Scalar(ScalarValue::Utf8(s)) => { match s { Some(text) => { - let matches = extract_all_groups(text, ®ex, idx_val); + let matches = extract_all_groups(text, ®ex, idx_val)?; // Build a list array with a single element let mut list_builder = arrow::array::ListBuilder::new(arrow::array::StringBuilder::new()); @@ -214,16 +217,17 @@ impl ScalarUDFImpl for SparkRegExpExtractAll { } list_builder.append(true); let list_array = list_builder.finish(); - - Ok(ColumnarValue::Scalar(ScalarValue::List( - Arc::new(list_array), - ))) + + Ok(ColumnarValue::Scalar(ScalarValue::List(Arc::new( + list_array, + )))) } None => { // Return NULL list using try_into (same as planner.rs:424) let null_list: ScalarValue = DataType::List(Arc::new( - arrow::datatypes::Field::new("item", DataType::Utf8, true) - )).try_into()?; + arrow::datatypes::Field::new("item", DataType::Utf8, false), + )) + .try_into()?; Ok(ColumnarValue::Scalar(null_list)) } } @@ -239,53 +243,86 @@ impl ScalarUDFImpl for SparkRegExpExtractAll { // Helper functions -fn extract_group(text: &str, regex: &Regex, idx: usize) -> String { - regex - .captures(text) - .and_then(|caps| caps.get(idx)) - .map(|m| m.as_str().to_string()) - // Spark behavior: return empty string "" if no match or group not found - .unwrap_or_else(|| String::new()) +fn extract_group(text: &str, regex: &Regex, idx: usize) -> Result { + match regex.captures(text) { + Some(caps) => { + // Spark behavior: throw error if group index is out of bounds + if idx >= caps.len() { + return exec_err!( + "Regex group count is {}, but the specified group index is {}", + caps.len(), + idx + ); + } + Ok(caps + .get(idx) + .map(|m| m.as_str().to_string()) + .unwrap_or_default()) + } + None => { + // No match: return empty string (Spark behavior) + Ok(String::new()) + } + } } fn regexp_extract_array(array: &ArrayRef, regex: &Regex, idx: usize) -> Result { let string_array = array .as_any() .downcast_ref::>() - .ok_or_else(|| { - internal_datafusion_err!("regexp_extract expects string array input") - })?; + .ok_or_else(|| internal_datafusion_err!("regexp_extract expects string array input"))?; - let result: GenericStringArray = string_array - .iter() - .map(|s| s.map(|text| extract_group(text, regex, idx))) // NULL → None, non-NULL → Some("") - .collect(); + let mut builder = arrow::array::StringBuilder::new(); + for s in string_array.iter() { + match s { + Some(text) => { + let extracted = extract_group(text, regex, idx)?; + builder.append_value(extracted); + } + None => { + builder.append_null(); // NULL → None + } + } + } - Ok(Arc::new(result)) + Ok(Arc::new(builder.finish())) } -fn extract_all_groups(text: &str, regex: &Regex, idx: usize) -> Vec { - regex - .captures_iter(text) - .filter_map(|caps| caps.get(idx).map(|m| m.as_str().to_string())) - .collect() +fn extract_all_groups(text: &str, regex: &Regex, idx: usize) -> Result> { + let mut results = Vec::new(); + + for caps in regex.captures_iter(text) { + // Check bounds for each capture (matches Spark behavior) + if idx >= caps.len() { + return exec_err!( + "Regex group count is {}, but the specified group index is {}", + caps.len(), + idx + ); + } + + let matched = caps + .get(idx) + .map(|m| m.as_str().to_string()) + .unwrap_or_default(); + results.push(matched); + } + + Ok(results) } fn regexp_extract_all_array(array: &ArrayRef, regex: &Regex, idx: usize) -> Result { let string_array = array .as_any() .downcast_ref::>() - .ok_or_else(|| { - internal_datafusion_err!("regexp_extract_all expects string array input") - })?; + .ok_or_else(|| internal_datafusion_err!("regexp_extract_all expects string array input"))?; - let mut list_builder = - arrow::array::ListBuilder::new(arrow::array::StringBuilder::new()); + let mut list_builder = arrow::array::ListBuilder::new(arrow::array::StringBuilder::new()); for s in string_array.iter() { match s { Some(text) => { - let matches = extract_all_groups(text, regex, idx); + let matches = extract_all_groups(text, regex, idx)?; for m in matches { list_builder.values().append_value(m); } @@ -310,11 +347,14 @@ mod tests { let regex = Regex::new(r"(\d+)-(\w+)").unwrap(); // Spark behavior: return "" on no match, not None - assert_eq!(extract_group("123-abc", ®ex, 0), "123-abc"); - assert_eq!(extract_group("123-abc", ®ex, 1), "123"); - assert_eq!(extract_group("123-abc", ®ex, 2), "abc"); - assert_eq!(extract_group("123-abc", ®ex, 3), ""); // no such group → "" - assert_eq!(extract_group("no match", ®ex, 0), ""); // no match → "" + assert_eq!(extract_group("123-abc", ®ex, 0).unwrap(), "123-abc"); + assert_eq!(extract_group("123-abc", ®ex, 1).unwrap(), "123"); + assert_eq!(extract_group("123-abc", ®ex, 2).unwrap(), "abc"); + assert_eq!(extract_group("no match", ®ex, 0).unwrap(), ""); // no match → "" + + // Spark behavior: group index out of bounds → error + assert!(extract_group("123-abc", ®ex, 3).is_err()); + assert!(extract_group("123-abc", ®ex, 99).is_err()); } #[test] @@ -322,23 +362,26 @@ mod tests { let regex = Regex::new(r"(\d+)").unwrap(); // Multiple matches - let matches = extract_all_groups("a1b2c3", ®ex, 0); + let matches = extract_all_groups("a1b2c3", ®ex, 0).unwrap(); assert_eq!(matches, vec!["1", "2", "3"]); // Same with group index 1 - let matches = extract_all_groups("a1b2c3", ®ex, 1); + let matches = extract_all_groups("a1b2c3", ®ex, 1).unwrap(); assert_eq!(matches, vec!["1", "2", "3"]); - // No match - let matches = extract_all_groups("no digits", ®ex, 0); + // No match: returns empty vec, not error + let matches = extract_all_groups("no digits", ®ex, 0).unwrap(); assert!(matches.is_empty()); assert_eq!(matches, Vec::::new()); + + // Group index out of bounds → error + assert!(extract_all_groups("a1b2c3", ®ex, 2).is_err()); } - + #[test] fn test_regexp_extract_all_array() -> Result<()> { use datafusion::common::cast::as_list_array; - + let regex = Regex::new(r"(\d+)").unwrap(); let array = Arc::new(StringArray::from(vec![ Some("a1b2"), @@ -352,23 +395,32 @@ mod tests { // Row 0: "a1b2" → ["1", "2"] let row0 = list_array.value(0); - let row0_str = row0.as_any().downcast_ref::>().unwrap(); + let row0_str = row0 + .as_any() + .downcast_ref::>() + .unwrap(); assert_eq!(row0_str.len(), 2); assert_eq!(row0_str.value(0), "1"); assert_eq!(row0_str.value(1), "2"); // Row 1: "no digits" → [] (empty array, not NULL) let row1 = list_array.value(1); - let row1_str = row1.as_any().downcast_ref::>().unwrap(); - assert_eq!(row1_str.len(), 0); // Empty array - assert!(!list_array.is_null(1)); // Not NULL, just empty + let row1_str = row1 + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(row1_str.len(), 0); // Empty array + assert!(!list_array.is_null(1)); // Not NULL, just empty // Row 2: NULL input → NULL output assert!(list_array.is_null(2)); // Row 3: "c3d4e5" → ["3", "4", "5"] let row3 = list_array.value(3); - let row3_str = row3.as_any().downcast_ref::>().unwrap(); + let row3_str = row3 + .as_any() + .downcast_ref::>() + .unwrap(); assert_eq!(row3_str.len(), 3); assert_eq!(row3_str.value(0), "3"); assert_eq!(row3_str.value(1), "4"); @@ -392,10 +444,9 @@ mod tests { assert_eq!(result_array.value(0), "123"); assert_eq!(result_array.value(1), "456"); - assert!(result_array.is_null(2)); // NULL input → NULL output - assert_eq!(result_array.value(3), ""); // no match → "" (empty string) + assert!(result_array.is_null(2)); // NULL input → NULL output + assert_eq!(result_array.value(3), ""); // no match → "" (empty string) Ok(()) } } - From 87dfed42583c3a6d975e0adae697f44cca0c96db Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 2 Dec 2025 23:11:12 -0800 Subject: [PATCH 5/9] Solve comments (test not yet fixed) 1. more data type support in scala side 2. unify errors as execution ones 3. reduce code duplication 4. negative index check --- .../src/string_funcs/regexp_extract.rs | 151 ++++++++---------- .../org/apache/comet/serde/strings.scala | 16 +- 2 files changed, 77 insertions(+), 90 deletions(-) diff --git a/native/spark-expr/src/string_funcs/regexp_extract.rs b/native/spark-expr/src/string_funcs/regexp_extract.rs index eba2e7993c..d8e1cbf3b0 100644 --- a/native/spark-expr/src/string_funcs/regexp_extract.rs +++ b/native/spark-expr/src/string_funcs/regexp_extract.rs @@ -17,7 +17,7 @@ use arrow::array::{Array, ArrayRef, GenericStringArray}; use arrow::datatypes::DataType; -use datafusion::common::{exec_err, internal_datafusion_err, Result, ScalarValue}; +use datafusion::common::{exec_err, DataFusionError, Result, ScalarValue}; use datafusion::logical_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; @@ -61,53 +61,21 @@ impl ScalarUDFImpl for SparkRegExpExtract { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - // regexp_extract always returns String Ok(DataType::Utf8) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - // regexp_extract(subject, pattern, idx) - if args.args.len() != 3 { - return exec_err!( - "regexp_extract expects 3 arguments, got {}", - args.args.len() - ); - } - - let subject = &args.args[0]; - let pattern = &args.args[1]; - let idx = &args.args[2]; - - // Pattern must be a literal string - let pattern_str = match pattern { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s.clone(), - _ => { - return exec_err!("regexp_extract pattern must be a string literal"); - } - }; - - // idx must be a literal int - let idx_val = match idx { - ColumnarValue::Scalar(ScalarValue::Int32(Some(i))) => *i as usize, - _ => { - return exec_err!("regexp_extract idx must be an integer literal"); - } - }; - - // Compile regex once - let regex = Regex::new(&pattern_str).map_err(|e| { - internal_datafusion_err!("Invalid regex pattern '{}': {}", pattern_str, e) - })?; + let (subject, regex, idx) = parse_args(&args, self.name())?; match subject { ColumnarValue::Array(array) => { - let result = regexp_extract_array(array, ®ex, idx_val)?; + let result = regexp_extract_array(&array, ®ex, idx)?; Ok(ColumnarValue::Array(result)) } ColumnarValue::Scalar(ScalarValue::Utf8(s)) => { let result = match s { - Some(text) => Some(extract_group(text, ®ex, idx_val)?), - None => None, // NULL input → NULL output + Some(text) => Some(extract_group(&text, ®ex, idx)?), + None => None, }; Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) } @@ -165,50 +133,18 @@ impl ScalarUDFImpl for SparkRegExpExtractAll { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - // regexp_extract_all(subject, pattern) or regexp_extract_all(subject, pattern, idx) - if args.args.len() < 2 || args.args.len() > 3 { - return exec_err!( - "regexp_extract_all expects 2 or 3 arguments, got {}", - args.args.len() - ); - } - - let subject = &args.args[0]; - let pattern = &args.args[1]; - let idx_val = if args.args.len() == 3 { - match &args.args[2] { - ColumnarValue::Scalar(ScalarValue::Int32(Some(i))) => *i as usize, - _ => { - return exec_err!("regexp_extract_all idx must be an integer literal"); - } - } - } else { - // Using 1 here to align with Spark's default behavior. - 1 - }; - - // Pattern must be a literal string - let pattern_str = match pattern { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s.clone(), - _ => { - return exec_err!("regexp_extract_all pattern must be a string literal"); - } - }; - - // Compile regex once - let regex = Regex::new(&pattern_str).map_err(|e| { - internal_datafusion_err!("Invalid regex pattern '{}': {}", pattern_str, e) - })?; + // regexp_extract_all(subject, pattern, idx) + let (subject, regex, idx) = parse_args(&args, self.name())?; match subject { ColumnarValue::Array(array) => { - let result = regexp_extract_all_array(array, ®ex, idx_val)?; + let result = regexp_extract_all_array(&array, ®ex, idx)?; Ok(ColumnarValue::Array(result)) } ColumnarValue::Scalar(ScalarValue::Utf8(s)) => { match s { Some(text) => { - let matches = extract_all_groups(text, ®ex, idx_val)?; + let matches = extract_all_groups(&text, ®ex, idx)?; // Build a list array with a single element let mut list_builder = arrow::array::ListBuilder::new(arrow::array::StringBuilder::new()); @@ -223,7 +159,6 @@ impl ScalarUDFImpl for SparkRegExpExtractAll { )))) } None => { - // Return NULL list using try_into (same as planner.rs:424) let null_list: ScalarValue = DataType::List(Arc::new( arrow::datatypes::Field::new("item", DataType::Utf8, false), )) @@ -243,14 +178,53 @@ impl ScalarUDFImpl for SparkRegExpExtractAll { // Helper functions -fn extract_group(text: &str, regex: &Regex, idx: usize) -> Result { +fn parse_args<'a>(args: &'a ScalarFunctionArgs, fn_name: &str) -> Result<(&'a ColumnarValue, Regex, i32)> { + if args.args.len() != 3 { + return exec_err!( + "{} expects 3 arguments, got {}", + fn_name, + args.args.len() + ); + } + + let subject = &args.args[0]; + let idx = &args.args[2]; + let pattern = &args.args[1]; + + let pattern_str = match pattern { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s.clone(), + _ => { + return exec_err!("{} pattern must be a string literal", fn_name); + } + }; + + let idx_val = match idx { + ColumnarValue::Scalar(ScalarValue::Int32(Some(i))) => *i as i32, + _ => { + return exec_err!("{} idx must be an integer literal", fn_name); + } + }; + if idx_val < 0 { + return exec_err!("{fn_name} group index must be non-negative"); + } + + let regex = Regex::new(&pattern_str).map_err(|e| { + DataFusionError::Execution(format!("Invalid regex pattern '{}': {}", pattern_str, e)) + })?; + + Ok((subject, regex, idx_val)) +} + +fn extract_group(text: &str, regex: &Regex, idx: i32) -> Result { + let idx = idx as usize; match regex.captures(text) { Some(caps) => { // Spark behavior: throw error if group index is out of bounds - if idx >= caps.len() { + let group_cnt = caps.len() - 1; + if idx > group_cnt { return exec_err!( - "Regex group count is {}, but the specified group index is {}", - caps.len(), + "Regex group index out of bounds, group count: {}, index: {}", + group_cnt, idx ); } @@ -266,11 +240,11 @@ fn extract_group(text: &str, regex: &Regex, idx: usize) -> Result { } } -fn regexp_extract_array(array: &ArrayRef, regex: &Regex, idx: usize) -> Result { +fn regexp_extract_array(array: &ArrayRef, regex: &Regex, idx: i32) -> Result { let string_array = array .as_any() .downcast_ref::>() - .ok_or_else(|| internal_datafusion_err!("regexp_extract expects string array input"))?; + .ok_or_else(|| DataFusionError::Execution("regexp_extract expects string array input".to_string()))?; let mut builder = arrow::array::StringBuilder::new(); for s in string_array.iter() { @@ -280,7 +254,7 @@ fn regexp_extract_array(array: &ArrayRef, regex: &Regex, idx: usize) -> Result { - builder.append_null(); // NULL → None + builder.append_null(); } } } @@ -288,15 +262,17 @@ fn regexp_extract_array(array: &ArrayRef, regex: &Regex, idx: usize) -> Result Result> { +fn extract_all_groups(text: &str, regex: &Regex, idx: i32) -> Result> { + let idx = idx as usize; let mut results = Vec::new(); for caps in regex.captures_iter(text) { // Check bounds for each capture (matches Spark behavior) - if idx >= caps.len() { + let group_num = caps.len() - 1; + if idx > group_num { return exec_err!( - "Regex group count is {}, but the specified group index is {}", - caps.len(), + "Regex group index out of bounds, group count: {}, index: {}", + group_num, idx ); } @@ -311,11 +287,11 @@ fn extract_all_groups(text: &str, regex: &Regex, idx: usize) -> Result Result { +fn regexp_extract_all_array(array: &ArrayRef, regex: &Regex, idx: i32) -> Result { let string_array = array .as_any() .downcast_ref::>() - .ok_or_else(|| internal_datafusion_err!("regexp_extract_all expects string array input"))?; + .ok_or_else(|| DataFusionError::Execution("regexp_extract_all expects string array input".to_string()))?; let mut list_builder = arrow::array::ListBuilder::new(arrow::array::StringBuilder::new()); @@ -355,6 +331,7 @@ mod tests { // Spark behavior: group index out of bounds → error assert!(extract_group("123-abc", ®ex, 3).is_err()); assert!(extract_group("123-abc", ®ex, 99).is_err()); + assert!(extract_group("123-abc", ®ex, -1).is_err()); } #[test] diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index 733c25ec2b..6dfdfed385 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -309,11 +309,16 @@ object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { expr.idx match { case Literal(_, DataTypes.IntegerType) => Compatible() + case Literal(_, DataTypes.LongType) => + Compatible() + case Literal(_, DataTypes.ShortType) => + Compatible() + case Literal(_, DataTypes.ByteType) => + Compatible() case _ => Unsupported(Some("Only literal group index is supported")) } } - override def convert( expr: RegExpExtract, inputs: Seq[Attribute], @@ -345,10 +350,16 @@ object CometRegExpExtractAll extends CometExpressionSerde[RegExpExtractAll] { } // Check if idx is a literal - // For regexp_extract_all, idx will default to 0 (group 0, entire match) if not specified + // For regexp_extract_all, idx will default to 1 if not specified expr.idx match { case Literal(_, DataTypes.IntegerType) => Compatible() + case Literal(_, DataTypes.LongType) => + Compatible() + case Literal(_, DataTypes.ShortType) => + Compatible() + case Literal(_, DataTypes.ByteType) => + Compatible() case _ => Unsupported(Some("Only literal group index is supported")) } @@ -357,7 +368,6 @@ object CometRegExpExtractAll extends CometExpressionSerde[RegExpExtractAll] { expr: RegExpExtractAll, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { - // Check if the pattern is compatible with Spark or allow incompatible patterns val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) val idxExpr = exprToProtoInternal(expr.idx, inputs, binding) From d83cac50d4f68aec730b874622217139b687256a Mon Sep 17 00:00:00 2001 From: Daniel Tu Date: Tue, 2 Dec 2025 23:24:31 -0800 Subject: [PATCH 6/9] fix regexp_extract_all test failure --- .../src/string_funcs/regexp_extract.rs | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/native/spark-expr/src/string_funcs/regexp_extract.rs b/native/spark-expr/src/string_funcs/regexp_extract.rs index d8e1cbf3b0..3a1d4e4228 100644 --- a/native/spark-expr/src/string_funcs/regexp_extract.rs +++ b/native/spark-expr/src/string_funcs/regexp_extract.rs @@ -128,7 +128,7 @@ impl ScalarUDFImpl for SparkRegExpExtractAll { Ok(DataType::List(Arc::new(arrow::datatypes::Field::new( "item", DataType::Utf8, - false, + true, )))) } @@ -160,7 +160,7 @@ impl ScalarUDFImpl for SparkRegExpExtractAll { } None => { let null_list: ScalarValue = DataType::List(Arc::new( - arrow::datatypes::Field::new("item", DataType::Utf8, false), + arrow::datatypes::Field::new("item", DataType::Utf8, true), )) .try_into()?; Ok(ColumnarValue::Scalar(null_list)) @@ -178,13 +178,12 @@ impl ScalarUDFImpl for SparkRegExpExtractAll { // Helper functions -fn parse_args<'a>(args: &'a ScalarFunctionArgs, fn_name: &str) -> Result<(&'a ColumnarValue, Regex, i32)> { +fn parse_args<'a>( + args: &'a ScalarFunctionArgs, + fn_name: &str, +) -> Result<(&'a ColumnarValue, Regex, i32)> { if args.args.len() != 3 { - return exec_err!( - "{} expects 3 arguments, got {}", - fn_name, - args.args.len() - ); + return exec_err!("{} expects 3 arguments, got {}", fn_name, args.args.len()); } let subject = &args.args[0]; @@ -244,7 +243,9 @@ fn regexp_extract_array(array: &ArrayRef, regex: &Regex, idx: i32) -> Result>() - .ok_or_else(|| DataFusionError::Execution("regexp_extract expects string array input".to_string()))?; + .ok_or_else(|| { + DataFusionError::Execution("regexp_extract expects string array input".to_string()) + })?; let mut builder = arrow::array::StringBuilder::new(); for s in string_array.iter() { @@ -291,7 +292,9 @@ fn regexp_extract_all_array(array: &ArrayRef, regex: &Regex, idx: i32) -> Result let string_array = array .as_any() .downcast_ref::>() - .ok_or_else(|| DataFusionError::Execution("regexp_extract_all expects string array input".to_string()))?; + .ok_or_else(|| { + DataFusionError::Execution("regexp_extract_all expects string array input".to_string()) + })?; let mut list_builder = arrow::array::ListBuilder::new(arrow::array::StringBuilder::new()); From f2d82b3145a4840865d23af710ba57a38a9040a8 Mon Sep 17 00:00:00 2001 From: Daniel Tu Date: Sun, 7 Dec 2025 11:24:53 -0800 Subject: [PATCH 7/9] refactor udf impl --- .../src/string_funcs/regexp_extract.rs | 427 ++++++++++++------ .../comet/CometStringExpressionSuite.scala | 223 ++++++++- 2 files changed, 507 insertions(+), 143 deletions(-) diff --git a/native/spark-expr/src/string_funcs/regexp_extract.rs b/native/spark-expr/src/string_funcs/regexp_extract.rs index 3a1d4e4228..559fc31f4c 100644 --- a/native/spark-expr/src/string_funcs/regexp_extract.rs +++ b/native/spark-expr/src/string_funcs/regexp_extract.rs @@ -15,21 +15,33 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, ArrayRef, GenericStringArray}; -use arrow::datatypes::DataType; +use arrow::array::{ + Array, ArrayRef, GenericStringArray, GenericStringBuilder, ListArray, OffsetSizeTrait, +}; +use arrow::datatypes::{DataType, FieldRef}; use datafusion::common::{exec_err, DataFusionError, Result, ScalarValue}; use datafusion::logical_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; +use datafusion::logical_expr_common::signature::TypeSignature::Exact; use regex::Regex; use std::any::Any; use std::sync::Arc; /// Spark-compatible regexp_extract function +/// +/// Extracts a substring matching a [regular expression](https://docs.rs/regex/latest/regex/#syntax) +/// and returns the specified capture group. +/// +/// The function signature is: `regexp_extract(str, regexp, idx)` +/// where: +/// - `str`: The input string to search in +/// - `regexp`: The regular expression pattern (must be a literal) +/// - `idx`: The capture group index (0 for the entire match, must be a literal) +/// #[derive(Debug, PartialEq, Eq, Hash)] pub struct SparkRegExpExtract { signature: Signature, - aliases: Vec, } impl Default for SparkRegExpExtract { @@ -41,8 +53,13 @@ impl Default for SparkRegExpExtract { impl SparkRegExpExtract { pub fn new() -> Self { Self { - signature: Signature::user_defined(Volatility::Immutable), - aliases: vec![], + signature: Signature::one_of( + vec![ + Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Int32]), + Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Int32]), + ], + Volatility::Immutable, + ), } } } @@ -61,38 +78,61 @@ impl ScalarUDFImpl for SparkRegExpExtract { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Utf8) + Ok(match &_arg_types[0] { + DataType::Utf8 => DataType::Utf8, + DataType::LargeUtf8 => DataType::LargeUtf8, + _ => { + return exec_err!( + "regexp_extract expects utf8 or largeutf8 input but got {:?}", + _arg_types[0] + ) + } + }) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let (subject, regex, idx) = parse_args(&args, self.name())?; - - match subject { - ColumnarValue::Array(array) => { - let result = regexp_extract_array(&array, ®ex, idx)?; - Ok(ColumnarValue::Array(result)) - } - ColumnarValue::Scalar(ScalarValue::Utf8(s)) => { - let result = match s { - Some(text) => Some(extract_group(&text, ®ex, idx)?), - None => None, - }; - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) + let args = &args.args; + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + let is_scalar = len.is_none(); + let result = match args[0].data_type() { + DataType::Utf8 => regexp_extract_func::(args), + DataType::LargeUtf8 => regexp_extract_func::(args), + _ => { + return exec_err!( + "regexp_extract expects the data type of subject to be utf8 or largeutf8 but got {:?}", + args[0].data_type() + ); } - _ => exec_err!("regexp_extract expects string input"), + }; + if is_scalar { + result + .and_then(|arr| ScalarValue::try_from_array(&arr, 0)) + .map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) } } - - fn aliases(&self) -> &[String] { - &self.aliases - } } /// Spark-compatible regexp_extract_all function +/// +/// Extracts all substrings matching a [regular expression](https://docs.rs/regex/latest/regex/#syntax) +/// and returns them as an array. +/// +/// The function signature is: `regexp_extract_all(str, regexp, idx)` +/// where: +/// - `str`: The input string to search in +/// - `regexp`: The regular expression pattern (must be a literal) +/// - `idx`: The capture group index (0 for the entire match, must be a literal) +/// #[derive(Debug, PartialEq, Eq, Hash)] pub struct SparkRegExpExtractAll { signature: Signature, - aliases: Vec, } impl Default for SparkRegExpExtractAll { @@ -104,8 +144,13 @@ impl Default for SparkRegExpExtractAll { impl SparkRegExpExtractAll { pub fn new() -> Self { Self { - signature: Signature::user_defined(Volatility::Immutable), - aliases: vec![], + signature: Signature::one_of( + vec![ + Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Int32]), + Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Int32]), + ], + Volatility::Immutable, + ), } } } @@ -124,71 +169,90 @@ impl ScalarUDFImpl for SparkRegExpExtractAll { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - // regexp_extract_all returns Array - Ok(DataType::List(Arc::new(arrow::datatypes::Field::new( - "item", - DataType::Utf8, - true, - )))) + Ok(match &_arg_types[0] { + DataType::Utf8 => DataType::List(Arc::new(arrow::datatypes::Field::new( + "item", + DataType::Utf8, + false, + ))), + DataType::LargeUtf8 => DataType::List(Arc::new(arrow::datatypes::Field::new( + "item", + DataType::LargeUtf8, + false, + ))), + _ => { + return exec_err!( + "regexp_extract_all expects utf8 or largeutf8 input but got {:?}", + _arg_types[0] + ) + } + }) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - // regexp_extract_all(subject, pattern, idx) - let (subject, regex, idx) = parse_args(&args, self.name())?; - - match subject { - ColumnarValue::Array(array) => { - let result = regexp_extract_all_array(&array, ®ex, idx)?; - Ok(ColumnarValue::Array(result)) - } - ColumnarValue::Scalar(ScalarValue::Utf8(s)) => { - match s { - Some(text) => { - let matches = extract_all_groups(&text, ®ex, idx)?; - // Build a list array with a single element - let mut list_builder = - arrow::array::ListBuilder::new(arrow::array::StringBuilder::new()); - for m in matches { - list_builder.values().append_value(m); - } - list_builder.append(true); - let list_array = list_builder.finish(); - - Ok(ColumnarValue::Scalar(ScalarValue::List(Arc::new( - list_array, - )))) - } - None => { - let null_list: ScalarValue = DataType::List(Arc::new( - arrow::datatypes::Field::new("item", DataType::Utf8, true), - )) - .try_into()?; - Ok(ColumnarValue::Scalar(null_list)) - } - } + let args = &args.args; + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + let is_scalar = len.is_none(); + let result = match args[0].data_type() { + DataType::Utf8 => regexp_extract_all_func::(args), + DataType::LargeUtf8 => regexp_extract_all_func::(args), + _ => { + return exec_err!( + "regexp_extract_all expects the data type of subject to be utf8 or largeutf8 but got {:?}", + args[0].data_type() + ); } - _ => exec_err!("regexp_extract_all expects string input"), + }; + if is_scalar { + result + .and_then(|arr| ScalarValue::try_from_array(&arr, 0)) + .map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) } } - - fn aliases(&self) -> &[String] { - &self.aliases - } } // Helper functions +fn regexp_extract_func(args: &[ColumnarValue]) -> Result { + let (subject, regex, idx) = parse_args(args, "regexp_extract")?; + + let subject_array = match subject { + ColumnarValue::Array(array) => Arc::clone(array), + ColumnarValue::Scalar(scalar) => scalar.to_array()?, + }; + + regexp_extract_array::(&subject_array, ®ex, idx) +} + +fn regexp_extract_all_func(args: &[ColumnarValue]) -> Result { + let (subject, regex, idx) = parse_args(args, "regexp_extract_all")?; + + let subject_array = match subject { + ColumnarValue::Array(array) => Arc::clone(array), + ColumnarValue::Scalar(scalar) => scalar.to_array()?, + }; + + regexp_extract_all_array::(&subject_array, ®ex, idx) +} + fn parse_args<'a>( - args: &'a ScalarFunctionArgs, + args: &'a [ColumnarValue], fn_name: &str, ) -> Result<(&'a ColumnarValue, Regex, i32)> { - if args.args.len() != 3 { - return exec_err!("{} expects 3 arguments, got {}", fn_name, args.args.len()); + if args.len() != 3 { + return exec_err!("{} expects 3 arguments, got {}", fn_name, args.len()); } - let subject = &args.args[0]; - let idx = &args.args[2]; - let pattern = &args.args[1]; + let subject = &args[0]; + let pattern = &args[1]; + let idx = &args[2]; let pattern_str = match pattern { ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s.clone(), @@ -198,7 +262,7 @@ fn parse_args<'a>( }; let idx_val = match idx { - ColumnarValue::Scalar(ScalarValue::Int32(Some(i))) => *i as i32, + ColumnarValue::Scalar(ScalarValue::Int32(Some(i))) => *i, _ => { return exec_err!("{} idx must be an integer literal", fn_name); } @@ -214,7 +278,35 @@ fn parse_args<'a>( Ok((subject, regex, idx_val)) } -fn extract_group(text: &str, regex: &Regex, idx: i32) -> Result { +fn regexp_extract_array( + array: &ArrayRef, + regex: &Regex, + idx: i32, +) -> Result { + let string_array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Execution("regexp_extract expects string array input".to_string()) + })?; + + let mut builder = GenericStringBuilder::::new(); + for s in string_array.iter() { + match s { + Some(text) => { + let extracted = regexp_extract(text, regex, idx)?; + builder.append_value(extracted); + } + None => { + builder.append_null(); + } + } + } + + Ok(Arc::new(builder.finish())) +} + +fn regexp_extract(text: &str, regex: &Regex, idx: i32) -> Result { let idx = idx as usize; match regex.captures(text) { Some(caps) => { @@ -239,31 +331,62 @@ fn extract_group(text: &str, regex: &Regex, idx: i32) -> Result { } } -fn regexp_extract_array(array: &ArrayRef, regex: &Regex, idx: i32) -> Result { +fn regexp_extract_all_array( + array: &ArrayRef, + regex: &Regex, + idx: i32, +) -> Result { let string_array = array .as_any() - .downcast_ref::>() + .downcast_ref::>() .ok_or_else(|| { - DataFusionError::Execution("regexp_extract expects string array input".to_string()) + DataFusionError::Execution("regexp_extract_all expects string array input".to_string()) })?; - let mut builder = arrow::array::StringBuilder::new(); + let item_data_type = match array.data_type() { + DataType::Utf8 => DataType::Utf8, + DataType::LargeUtf8 => DataType::LargeUtf8, + _ => { + return exec_err!( + "regexp_extract_all expects utf8 or largeutf8 array but got {:?}", + array.data_type() + ); + } + }; + let item_field = Arc::new(arrow::datatypes::Field::new("item", item_data_type, false)); + + let string_builder = GenericStringBuilder::::new(); + let mut list_builder = + arrow::array::ListBuilder::new(string_builder).with_field(item_field.clone()); + for s in string_array.iter() { match s { Some(text) => { - let extracted = extract_group(text, regex, idx)?; - builder.append_value(extracted); + let matches = regexp_extract_all(text, regex, idx)?; + for m in matches { + list_builder.values().append_value(m); + } + list_builder.append(true); } None => { - builder.append_null(); + list_builder.append(false); } } } - Ok(Arc::new(builder.finish())) + let list_array = list_builder.finish(); + + // Manually create a new ListArray with the correct field schema to ensure nullable is false + // This ensures the schema matches what we declared in return_type + Ok(Arc::new(ListArray::new( + FieldRef::from(item_field.clone()), + list_array.offsets().clone(), + list_array.values().clone(), + list_array.nulls().cloned(), + ))) } -fn extract_all_groups(text: &str, regex: &Regex, idx: i32) -> Result> { +fn regexp_extract_all(text: &str, regex: &Regex, idx: i32) -> Result> { let idx = idx as usize; let mut results = Vec::new(); @@ -288,53 +411,25 @@ fn extract_all_groups(text: &str, regex: &Regex, idx: i32) -> Result Ok(results) } -fn regexp_extract_all_array(array: &ArrayRef, regex: &Regex, idx: i32) -> Result { - let string_array = array - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Execution("regexp_extract_all expects string array input".to_string()) - })?; - - let mut list_builder = arrow::array::ListBuilder::new(arrow::array::StringBuilder::new()); - - for s in string_array.iter() { - match s { - Some(text) => { - let matches = extract_all_groups(text, regex, idx)?; - for m in matches { - list_builder.values().append_value(m); - } - list_builder.append(true); - } - None => { - list_builder.append(false); - } - } - } - - Ok(Arc::new(list_builder.finish())) -} - #[cfg(test)] mod tests { use super::*; - use arrow::array::StringArray; + use arrow::array::{LargeStringArray, StringArray}; #[test] fn test_regexp_extract_basic() { let regex = Regex::new(r"(\d+)-(\w+)").unwrap(); // Spark behavior: return "" on no match, not None - assert_eq!(extract_group("123-abc", ®ex, 0).unwrap(), "123-abc"); - assert_eq!(extract_group("123-abc", ®ex, 1).unwrap(), "123"); - assert_eq!(extract_group("123-abc", ®ex, 2).unwrap(), "abc"); - assert_eq!(extract_group("no match", ®ex, 0).unwrap(), ""); // no match → "" + assert_eq!(regexp_extract("123-abc", ®ex, 0).unwrap(), "123-abc"); + assert_eq!(regexp_extract("123-abc", ®ex, 1).unwrap(), "123"); + assert_eq!(regexp_extract("123-abc", ®ex, 2).unwrap(), "abc"); + assert_eq!(regexp_extract("no match", ®ex, 0).unwrap(), ""); // no match → "" // Spark behavior: group index out of bounds → error - assert!(extract_group("123-abc", ®ex, 3).is_err()); - assert!(extract_group("123-abc", ®ex, 99).is_err()); - assert!(extract_group("123-abc", ®ex, -1).is_err()); + assert!(regexp_extract("123-abc", ®ex, 3).is_err()); + assert!(regexp_extract("123-abc", ®ex, 99).is_err()); + assert!(regexp_extract("123-abc", ®ex, -1).is_err()); } #[test] @@ -342,20 +437,20 @@ mod tests { let regex = Regex::new(r"(\d+)").unwrap(); // Multiple matches - let matches = extract_all_groups("a1b2c3", ®ex, 0).unwrap(); + let matches = regexp_extract_all("a1b2c3", ®ex, 0).unwrap(); assert_eq!(matches, vec!["1", "2", "3"]); // Same with group index 1 - let matches = extract_all_groups("a1b2c3", ®ex, 1).unwrap(); + let matches = regexp_extract_all("a1b2c3", ®ex, 1).unwrap(); assert_eq!(matches, vec!["1", "2", "3"]); // No match: returns empty vec, not error - let matches = extract_all_groups("no digits", ®ex, 0).unwrap(); + let matches = regexp_extract_all("no digits", ®ex, 0).unwrap(); assert!(matches.is_empty()); assert_eq!(matches, Vec::::new()); // Group index out of bounds → error - assert!(extract_all_groups("a1b2c3", ®ex, 2).is_err()); + assert!(regexp_extract_all("a1b2c3", ®ex, 2).is_err()); } #[test] @@ -370,7 +465,7 @@ mod tests { Some("c3d4e5"), ])) as ArrayRef; - let result = regexp_extract_all_array(&array, ®ex, 0)?; + let result = regexp_extract_all_array::(&array, ®ex, 0)?; let list_array = as_list_array(&result)?; // Row 0: "a1b2" → ["1", "2"] @@ -419,7 +514,7 @@ mod tests { Some("no-match"), ])) as ArrayRef; - let result = regexp_extract_array(&array, ®ex, 1)?; + let result = regexp_extract_array::(&array, ®ex, 1)?; let result_array = result.as_any().downcast_ref::().unwrap(); assert_eq!(result_array.value(0), "123"); @@ -429,4 +524,76 @@ mod tests { Ok(()) } + + #[test] + fn test_regexp_extract_largeutf8() -> Result<()> { + let regex = Regex::new(r"(\d+)").unwrap(); + let array = Arc::new(LargeStringArray::from(vec![ + Some("a1b2c3"), + Some("x5y6"), + None, + Some("no digits"), + ])) as ArrayRef; + + let result = regexp_extract_array::(&array, ®ex, 1)?; + let result_array = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result_array.value(0), "1"); // First digit from "a1b2c3" + assert_eq!(result_array.value(1), "5"); // First digit from "x5y6" + assert!(result_array.is_null(2)); // NULL input → NULL output + assert_eq!(result_array.value(3), ""); // no match → "" (empty string) + + Ok(()) + } + + #[test] + fn test_regexp_extract_all_largeutf8() -> Result<()> { + use datafusion::common::cast::as_list_array; + + let regex = Regex::new(r"(\d+)").unwrap(); + let array = Arc::new(LargeStringArray::from(vec![ + Some("a1b2c3"), + Some("x5y6"), + None, + Some("no digits"), + ])) as ArrayRef; + + let result = regexp_extract_all_array::(&array, ®ex, 0)?; + let list_array = as_list_array(&result)?; + + // Row 0: "a1b2c3" → ["1", "2", "3"] (all matches) + let row0 = list_array.value(0); + let row0_str = row0 + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(row0_str.len(), 3); + assert_eq!(row0_str.value(0), "1"); + assert_eq!(row0_str.value(1), "2"); + assert_eq!(row0_str.value(2), "3"); + + // Row 1: "x5y6" → ["5", "6"] (all matches) + let row1 = list_array.value(1); + let row1_str = row1 + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(row1_str.len(), 2); + assert_eq!(row1_str.value(0), "5"); + assert_eq!(row1_str.value(1), "6"); + + // Row 2: NULL input → NULL output + assert!(list_array.is_null(2)); + + // Row 3: "no digits" → [] (empty array, not NULL) + let row3 = list_array.value(3); + let row3_str = row3 + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(row3_str.len(), 0); // Empty array + assert!(!list_array.is_null(3)); // Not NULL, just empty + + Ok(()) + } } diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index 5214eb8215..01f6a24080 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -392,8 +392,6 @@ class CometStringExpressionSuite extends CometTestBase { } test("regexp_extract basic") { - import org.apache.comet.CometConf - withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { val data = Seq( ("100-200", 1), @@ -422,8 +420,6 @@ class CometStringExpressionSuite extends CometTestBase { } test("regexp_extract edge cases") { - import org.apache.comet.CometConf - withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { val data = Seq(("email@example.com", 1), ("phone: 123-456-7890", 1), ("price: $99.99", 1), (null, 1)) @@ -441,8 +437,6 @@ class CometStringExpressionSuite extends CometTestBase { } test("regexp_extract_all basic") { - import org.apache.comet.CometConf - withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { val data = Seq( ("a1b2c3", 1), @@ -468,8 +462,6 @@ class CometStringExpressionSuite extends CometTestBase { } test("regexp_extract_all multiple matches") { - import org.apache.comet.CometConf - withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { val data = Seq( ("The prices are $10, $20, and $30", 1), @@ -487,20 +479,225 @@ class CometStringExpressionSuite extends CometTestBase { } test("regexp_extract_all with dictionary encoding") { - import org.apache.comet.CometConf + withSQLConf( + CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true", + "parquet.enable.dictionary" -> "true") { + // Use repeated values to trigger dictionary encoding + // Mix short strings, long strings, and various patterns + val longString1 = "prefix" + ("abc" * 100) + "123" + ("xyz" * 100) + "456" + val longString2 = "start" + ("test" * 200) + "789" + ("end" * 150) + + val data = (0 until 2000).map(i => { + val text = i % 7 match { + case 0 => "a1b2c3" // Simple repeated pattern + case 1 => "x5y6" // Another simple pattern + case 2 => "no-match" // No digits + case 3 => longString1 // Long string with digits + case 4 => longString2 // Another long string + case 5 => "email@test.com-phone:123-456-7890" // Complex pattern + case 6 => "" // Empty string + } + (text, 1) + }) + + withParquetTable(data, "tbl") { + // Test simple pattern + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '(\\d+)') FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '(\\d+)', 0) FROM tbl") + + // Test complex patterns + checkSparkAnswerAndOperator( + "SELECT regexp_extract_all(_1, '(\\d{3}-\\d{3}-\\d{4})', 0) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '@([a-z]*)', 1) FROM tbl") + + // Test with multiple groups + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '([a-z])(\\d*)', 1) FROM tbl") + } + } + } + test("regexp_extract with dictionary encoding") { withSQLConf( CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true", "parquet.enable.dictionary" -> "true") { // Use repeated values to trigger dictionary encoding - val data = (0 until 1000).map(i => { - val text = if (i % 3 == 0) "a1b2c3" else if (i % 3 == 1) "x5y6" else "no-match" + // Mix short and long strings with various patterns + val longString1 = "data" + ("x" * 500) + "999" + ("y" * 500) + val longString2 = ("a" * 1000) + "777" + ("b" * 1000) + + val data = (0 until 2000).map(i => { + val text = i % 7 match { + case 0 => "a1b2c3" + case 1 => "x5y6" + case 2 => "no-match" + case 3 => longString1 + case 4 => longString2 + case 5 => "IP:192.168.1.100-PORT:8080" + case 6 => "" + } (text, 1) }) withParquetTable(data, "tbl") { - checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '\\d+') FROM tbl") - checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '\\d+', 0) FROM tbl") + // Test extracting first match with simple pattern + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)', 1) FROM tbl") + + // Test with complex patterns + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '(\\d+)\\.(\\d+)\\.(\\d+)\\.(\\d+)', 1) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, 'PORT:(\\d+)', 1) FROM tbl") + + // Test with multiple groups - extract second group + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '([a-z])(\\d+)', 2) FROM tbl") + } + } + } + + test("regexp_extract unicode and special characters") { + import org.apache.comet.CometConf + + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq( + ("测试123test", 1), // Chinese characters + ("日本語456にほんご", 1), // Japanese characters + ("한글789Korean", 1), // Korean characters + ("Привет999Hello", 1), // Cyrillic + ("line1\nline2", 1), // Newline + ("tab\there", 1), // Tab + ("special: $#@!%^&*", 1), // Special chars + ("mixed测试123test日本語", 1), // Mixed unicode + (null, 1)) + + withParquetTable(data, "tbl") { + // Extract digits from unicode text + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)', 1) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '(\\d+)', 1) FROM tbl") + + // Test word boundaries with unicode + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '([a-zA-Z]+)', 1) FROM tbl") + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '([a-zA-Z]+)', 1) FROM tbl") + } + } + } + + test("regexp_extract_all multiple groups") { + import org.apache.comet.CometConf + + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq( + ("a1b2c3", 1), + ("x5y6z7", 1), + ("test123demo456end789", 1), + (null, 1), + ("no match here", 1)) + + withParquetTable(data, "tbl") { + // Test extracting different groups - full match + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '([a-z])(\\d+)', 0) FROM tbl") + // Test extracting group 1 (letters) + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '([a-z])(\\d+)', 1) FROM tbl") + // Test extracting group 2 (digits) + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '([a-z])(\\d+)', 2) FROM tbl") + + // Test with three groups + checkSparkAnswerAndOperator( + "SELECT regexp_extract_all(_1, '([a-z]+)(\\d+)([a-z]+)', 1) FROM tbl") + checkSparkAnswerAndOperator( + "SELECT regexp_extract_all(_1, '([a-z]+)(\\d+)([a-z]+)', 2) FROM tbl") + checkSparkAnswerAndOperator( + "SELECT regexp_extract_all(_1, '([a-z]+)(\\d+)([a-z]+)', 3) FROM tbl") + } + } + } + + test("regexp_extract_all group index out of bounds") { + import org.apache.comet.CometConf + + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq(("a1b2c3", 1), ("test123", 1), (null, 1)) + + withParquetTable(data, "tbl") { + // Group index out of bounds - should match Spark's behavior (error or empty) + // Pattern has only 1 group, asking for group 2 + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '(\\d+)', 2) FROM tbl") + + // Pattern has no groups, asking for group 1 + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '\\d+', 1) FROM tbl") + } + } + } + + test("regexp_extract complex patterns") { + import org.apache.comet.CometConf + + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq( + ("2024-01-15", 1), // Date + ("192.168.1.1", 1), // IP address + ("user@domain.co.uk", 1), // Complex email + ("content", 1), // HTML-like + ("Time: 14:30:45.123", 1), // Timestamp + ("Version: 1.2.3-beta", 1), // Version string + ("RGB(255,128,0)", 1), // RGB color + (null, 1)) + + withParquetTable(data, "tbl") { + // Extract year from date + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '(\\d{4})-(\\d{2})-(\\d{2})', 1) FROM tbl") + + // Extract month from date + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '(\\d{4})-(\\d{2})-(\\d{2})', 2) FROM tbl") + + // Extract IP octets + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '(\\d+)\\.(\\d+)\\.(\\d+)\\.(\\d+)', 2) FROM tbl") + + // Extract email domain + checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '@([a-z.]+)', 1) FROM tbl") + + // Extract time components + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, '(\\d{2}):(\\d{2}):(\\d{2})', 1) FROM tbl") + + // Extract RGB values + checkSparkAnswerAndOperator( + "SELECT regexp_extract(_1, 'RGB\\((\\d+),(\\d+),(\\d+)\\)', 2) FROM tbl") + + // Test regexp_extract_all with complex patterns + checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '(\\d+)', 1) FROM tbl") + } + } + } + + test("regexp_extract vs regexp_extract_all comparison") { + import org.apache.comet.CometConf + + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq(("a1b2c3", 1), ("x5y6", 1), (null, 1), ("no digits", 1), ("single7match", 1)) + + withParquetTable(data, "tbl") { + // Compare single extraction vs all extractions in one query + checkSparkAnswerAndOperator("""SELECT + | regexp_extract(_1, '(\\d+)', 1) as first_match, + | regexp_extract_all(_1, '(\\d+)', 1) as all_matches + |FROM tbl""".stripMargin) + + // Verify regexp_extract returns first match only while regexp_extract_all returns all + checkSparkAnswerAndOperator("""SELECT + | _1, + | regexp_extract(_1, '(\\d+)', 1) as first_digit, + | regexp_extract_all(_1, '(\\d+)', 1) as all_digits + |FROM tbl""".stripMargin) + + // Test with multiple groups + checkSparkAnswerAndOperator("""SELECT + | regexp_extract(_1, '([a-z])(\\d+)', 1) as first_letter, + | regexp_extract_all(_1, '([a-z])(\\d+)', 1) as all_letters, + | regexp_extract(_1, '([a-z])(\\d+)', 2) as first_digit, + | regexp_extract_all(_1, '([a-z])(\\d+)', 2) as all_digits + |FROM tbl""".stripMargin) } } } From 84c0132e9e87203ce2b2f31dc33303dfc37b0228 Mon Sep 17 00:00:00 2001 From: Daniel Tu Date: Sun, 7 Dec 2025 11:28:55 -0800 Subject: [PATCH 8/9] fix rust lint --- native/spark-expr/src/string_funcs/regexp_extract.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/native/spark-expr/src/string_funcs/regexp_extract.rs b/native/spark-expr/src/string_funcs/regexp_extract.rs index 559fc31f4c..38c80d8129 100644 --- a/native/spark-expr/src/string_funcs/regexp_extract.rs +++ b/native/spark-expr/src/string_funcs/regexp_extract.rs @@ -357,7 +357,7 @@ fn regexp_extract_all_array( let string_builder = GenericStringBuilder::::new(); let mut list_builder = - arrow::array::ListBuilder::new(string_builder).with_field(item_field.clone()); + arrow::array::ListBuilder::new(string_builder).with_field(Arc::clone(&item_field)); for s in string_array.iter() { match s { @@ -379,9 +379,9 @@ fn regexp_extract_all_array( // Manually create a new ListArray with the correct field schema to ensure nullable is false // This ensures the schema matches what we declared in return_type Ok(Arc::new(ListArray::new( - FieldRef::from(item_field.clone()), + FieldRef::from(Arc::clone(&item_field)), list_array.offsets().clone(), - list_array.values().clone(), + Arc::clone(list_array.values()), list_array.nulls().cloned(), ))) } From a55263f2a5c6d2d96cfb0a1ba1e95aa595770e98 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sun, 7 Dec 2025 13:12:25 -0800 Subject: [PATCH 9/9] minor updates --- .../main/scala/org/apache/comet/serde/strings.scala | 12 ------------ .../apache/comet/CometStringExpressionSuite.scala | 4 ++-- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index 6dfdfed385..51b41e2593 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -309,12 +309,6 @@ object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { expr.idx match { case Literal(_, DataTypes.IntegerType) => Compatible() - case Literal(_, DataTypes.LongType) => - Compatible() - case Literal(_, DataTypes.ShortType) => - Compatible() - case Literal(_, DataTypes.ByteType) => - Compatible() case _ => Unsupported(Some("Only literal group index is supported")) } @@ -354,12 +348,6 @@ object CometRegExpExtractAll extends CometExpressionSerde[RegExpExtractAll] { expr.idx match { case Literal(_, DataTypes.IntegerType) => Compatible() - case Literal(_, DataTypes.LongType) => - Compatible() - case Literal(_, DataTypes.ShortType) => - Compatible() - case Literal(_, DataTypes.ByteType) => - Compatible() case _ => Unsupported(Some("Only literal group index is supported")) } diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index 01f6a24080..0bdaba62b1 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -409,7 +409,7 @@ class CometStringExpressionSuite extends CometTestBase { checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 1) FROM tbl") // Test group 2 checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 2) FROM tbl") - // Test non-existent group → should return "" + // Test non-existent group → should error checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '(\\d+)-(\\d+)', 3) FROM tbl") // Test empty pattern checkSparkAnswerAndOperator("SELECT regexp_extract(_1, '', 0) FROM tbl") @@ -617,7 +617,7 @@ class CometStringExpressionSuite extends CometTestBase { val data = Seq(("a1b2c3", 1), ("test123", 1), (null, 1)) withParquetTable(data, "tbl") { - // Group index out of bounds - should match Spark's behavior (error or empty) + // Group index out of bounds - should match Spark's behavior (error) // Pattern has only 1 group, asking for group 2 checkSparkAnswerAndOperator("SELECT regexp_extract_all(_1, '(\\d+)', 2) FROM tbl")