diff --git a/paimon-api/src/main/java/org/apache/paimon/utils/StringUtils.java b/paimon-api/src/main/java/org/apache/paimon/utils/StringUtils.java index 75f9cd2aabf7..c189de92e5d6 100644 --- a/paimon-api/src/main/java/org/apache/paimon/utils/StringUtils.java +++ b/paimon-api/src/main/java/org/apache/paimon/utils/StringUtils.java @@ -545,6 +545,33 @@ public static String trim(String value) { return value.trim(); } + public static String trim(String value, String charsToTrim) { + return rtrim(ltrim(value, charsToTrim), charsToTrim); + } + + public static String ltrim(String value, String charsToTrim) { + if (value == null || charsToTrim == null) { + return null; + } + StringBuilder sb = new StringBuilder(value); + while (sb.length() > 0 && charsToTrim.contains(sb.substring(0, 1))) { + sb.deleteCharAt(0); + } + return sb.toString(); + } + + public static String rtrim(String value, String charsToTrim) { + if (value == null || charsToTrim == null) { + return null; + } + StringBuilder sb = new StringBuilder(value); + while (sb.length() > 0 + && charsToTrim.contains(sb.substring(sb.length() - 1, sb.length()))) { + sb.deleteCharAt(sb.length() - 1); + } + return sb.toString(); + } + public static String toUpperCase(String value) { if (value == null) { return null; diff --git a/paimon-common/src/main/java/org/apache/paimon/predicate/TrimTransform.java b/paimon-common/src/main/java/org/apache/paimon/predicate/TrimTransform.java new file mode 100644 index 000000000000..6182335bb221 --- /dev/null +++ b/paimon-common/src/main/java/org/apache/paimon/predicate/TrimTransform.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.predicate; + +import org.apache.paimon.data.BinaryString; +import org.apache.paimon.utils.StringUtils; + +import java.util.List; + +import static org.apache.paimon.utils.Preconditions.checkArgument; + +/** TRIM/LTRIM/RTRIM {@link Transform}. */ +public class TrimTransform extends StringTransform { + + private static final long serialVersionUID = 1L; + + public static final String NAME = "TRIM"; + + private final Flag trimFlag; + + public TrimTransform(List inputs, Flag trimFlag) { + super(inputs); + this.trimFlag = trimFlag; + checkArgument(inputs.size() == 1 || inputs.size() == 2); + } + + @Override + public String name() { + return NAME; + } + + @Override + public BinaryString transform(List inputs) { + if (inputs.get(0) == null) { + return null; + } + String sourceString = inputs.get(0).toString(); + String charsToTrim = inputs.size() == 1 ? " " : inputs.get(1).toString(); + switch (trimFlag) { + case BOTH: + return BinaryString.fromString(StringUtils.trim(sourceString, charsToTrim)); + case LEADING: + return BinaryString.fromString(StringUtils.ltrim(sourceString, charsToTrim)); + case TRAILING: + return BinaryString.fromString(StringUtils.rtrim(sourceString, charsToTrim)); + default: + throw new IllegalArgumentException("Invalid trim way " + trimFlag.name()); + } + } + + @Override + public Transform copyWithNewInputs(List inputs) { + return new TrimTransform(inputs, this.trimFlag); + } + + /** Enum of trim functions. */ + public enum Flag { + LEADING, + TRAILING, + BOTH + } +} diff --git a/paimon-common/src/test/java/org/apache/paimon/predicate/TrimTransformTest.java b/paimon-common/src/test/java/org/apache/paimon/predicate/TrimTransformTest.java new file mode 100644 index 000000000000..b24fda78a7d7 --- /dev/null +++ b/paimon-common/src/test/java/org/apache/paimon/predicate/TrimTransformTest.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.predicate; + +import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.GenericRow; +import org.apache.paimon.types.DataTypes; + +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +class TrimTransformTest { + + @Test + public void testNullInputs() { + List inputs = new ArrayList<>(); + // test for single argument + inputs.add(null); + TrimTransform transform = new TrimTransform(inputs, TrimTransform.Flag.BOTH); + Object result = transform.transform(GenericRow.of()); + assertThat(result).isNull(); + + transform = new TrimTransform(inputs, TrimTransform.Flag.LEADING); + result = transform.transform(GenericRow.of()); + assertThat(result).isNull(); + + transform = new TrimTransform(inputs, TrimTransform.Flag.TRAILING); + result = transform.transform(GenericRow.of()); + assertThat(result).isNull(); + + // test for binary argument + inputs.add(null); + transform = new TrimTransform(inputs, TrimTransform.Flag.BOTH); + result = transform.transform(GenericRow.of()); + assertThat(result).isNull(); + + transform = new TrimTransform(inputs, TrimTransform.Flag.LEADING); + result = transform.transform(GenericRow.of()); + assertThat(result).isNull(); + + transform = new TrimTransform(inputs, TrimTransform.Flag.TRAILING); + result = transform.transform(GenericRow.of()); + assertThat(result).isNull(); + } + + @Test + public void testNormalInputs() { + // test trim('cd', 'cddcaadccd') + List inputs = new ArrayList<>(); + inputs.add(BinaryString.fromString("cddcaadccd")); + inputs.add(BinaryString.fromString("cd")); + TrimTransform transform = new TrimTransform(inputs, TrimTransform.Flag.BOTH); + Object result = transform.transform(GenericRow.of()); + assertThat(result).isEqualTo(BinaryString.fromString("aa")); + + // test ltrim('cd', 'cddcaadccd') + transform = new TrimTransform(inputs, TrimTransform.Flag.LEADING); + result = transform.transform(GenericRow.of()); + assertThat(result).isEqualTo(BinaryString.fromString("aadccd")); + + // test rtrim('cd', 'cddcaadccd') + transform = new TrimTransform(inputs, TrimTransform.Flag.TRAILING); + result = transform.transform(GenericRow.of()); + assertThat(result).isEqualTo(BinaryString.fromString("cddcaa")); + + // test trim(' aa ') + inputs.clear(); + inputs.add(BinaryString.fromString(" aa ")); + transform = new TrimTransform(inputs, TrimTransform.Flag.BOTH); + result = transform.transform(GenericRow.of()); + assertThat(result).isEqualTo(BinaryString.fromString("aa")); + + // test trim(' aa ') + transform = new TrimTransform(inputs, TrimTransform.Flag.LEADING); + result = transform.transform(GenericRow.of()); + assertThat(result).isEqualTo(BinaryString.fromString("aa ")); + + // test trim(' aa ') + transform = new TrimTransform(inputs, TrimTransform.Flag.TRAILING); + result = transform.transform(GenericRow.of()); + assertThat(result).isEqualTo(BinaryString.fromString(" aa")); + } + + @Test + public void testSubstringRefInputs() { + List inputs = new ArrayList<>(); + inputs.add(new FieldRef(1, "f1", DataTypes.STRING())); + inputs.add(new FieldRef(2, "f2", DataTypes.STRING())); + TrimTransform transform = new TrimTransform(inputs, TrimTransform.Flag.BOTH); + Object result = + transform.transform( + GenericRow.of( + BinaryString.fromString(""), + BinaryString.fromString("ahellob"), + BinaryString.fromString("ab"))); + assertThat(result).isEqualTo(BinaryString.fromString("hello")); + + transform = new TrimTransform(inputs, TrimTransform.Flag.LEADING); + result = + transform.transform( + GenericRow.of( + BinaryString.fromString(""), + BinaryString.fromString("ahellob"), + BinaryString.fromString("ab"))); + assertThat(result).isEqualTo(BinaryString.fromString("hellob")); + + transform = new TrimTransform(inputs, TrimTransform.Flag.TRAILING); + result = + transform.transform( + GenericRow.of( + BinaryString.fromString(""), + BinaryString.fromString("ahellob"), + BinaryString.fromString("ab"))); + assertThat(result).isEqualTo(BinaryString.fromString("ahello")); + } +} diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/util/SparkExpressionConverter.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/util/SparkExpressionConverter.scala index 347bde6513e9..a5ff3598fbd5 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/util/SparkExpressionConverter.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/util/SparkExpressionConverter.scala @@ -39,6 +39,9 @@ object SparkExpressionConverter { private val UPPER = "UPPER" private val LOWER = "LOWER" private val SUBSTRING = "SUBSTRING" + private val TRIM = "TRIM" + private val LTRIM = "LTRIM" + private val RTRIM = "RTRIM" /** Convert Spark [[Expression]] to Paimon [[Transform]], return None if not supported. */ def toPaimonTransform(exp: Expression, rowType: RowType): Option[Transform] = { @@ -64,6 +67,13 @@ object SparkExpressionConverter { case UPPER => convertChildren(s.children()).map(i => new UpperTransform(i)) case LOWER => convertChildren(s.children()).map(i => new LowerTransform(i)) case SUBSTRING => convertChildren(s.children()).map(i => new SubstringTransform(i)) + case TRIM => + convertChildren(s.children()).map(i => new TrimTransform(i, TrimTransform.Flag.BOTH)) + case LTRIM => + convertChildren(s.children()).map(i => new TrimTransform(i, TrimTransform.Flag.LEADING)) + case RTRIM => + convertChildren(s.children()).map( + i => new TrimTransform(i, TrimTransform.Flag.TRAILING)) case _ => None } case c: Cast => diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala index 747b140cb389..2f90e47c8515 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala @@ -130,6 +130,72 @@ abstract class PaimonPushDownTestBase extends PaimonSparkTestBase with AdaptiveS } } + test(s"Paimon push down: apply TRIM/LTRM/RTRIM") { + // Spark support push down TRIM/LTRM/RTRIM since Spark 3.4. + if (gteqSpark3_4) { + withTable("t") { + sql(""" + |CREATE TABLE t (id int, value int, dt STRING) + |using paimon + |PARTITIONED BY (dt) + |""".stripMargin) + + sql(""" + |INSERT INTO t values + |(1, 100, 'chelloc'), (1, 100, 'caa'), (1, 100, 'bbc') + |""".stripMargin) + + val q = + """ + |SELECT * FROM t + |WHERE TRIM('c', dt) = 'hello' + |""".stripMargin + assert(!checkFilterExists(q)) + + checkAnswer( + spark.sql(q), + Seq(Row(1, 100, "chelloc")) + ) + + val q1 = + """ + |SELECT * FROM t + |WHERE LTRIM('c', dt) = 'aa' + |""".stripMargin + assert(!checkFilterExists(q1)) + + checkAnswer( + spark.sql(q1), + Seq(Row(1, 100, "caa")) + ) + + val q2 = + """ + |SELECT * FROM t + |WHERE RTRIM('c', dt) = 'bb' + |""".stripMargin + assert(!checkFilterExists(q2)) + + checkAnswer( + spark.sql(q2), + Seq(Row(1, 100, "bbc")) + ) + + val q3 = + """ + |SELECT * FROM t + |WHERE TRIM(LEADING 'c' FROM dt) = 'aa' + |""".stripMargin + assert(!checkFilterExists(q2)) + + checkAnswer( + spark.sql(q3), + Seq(Row(1, 100, "caa")) + ) + } + } + } + test(s"Paimon push down: apply UPPER") { // Spark support push down UPPER since Spark 3.4. if (gteqSpark3_4) {