From 12010ecf4b408537cd240b1de3c6686563647722 Mon Sep 17 00:00:00 2001 From: Rossi Sun Date: Wed, 24 Dec 2025 09:27:55 +0800 Subject: [PATCH] Graceful error for decimal binary arithmatic and comparison instead of firing confusing assertion --- .../arrow/compute/kernels/codegen_internal.cc | 11 +++++++++ .../compute/kernels/codegen_internal_test.cc | 16 +++++++++++++ .../compute/kernels/scalar_arithmetic_test.cc | 23 +++++++++++++++++++ .../compute/kernels/scalar_compare_test.cc | 20 ++++++++++++++++ 4 files changed, 70 insertions(+) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index 10ed9344d97..99a9173041f 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -393,11 +393,22 @@ TypeHolder CommonBinary(const TypeHolder* begin, size_t count) { return large_binary(); } +bool CastableToDecimal(const DataType& type) { + return is_numeric(type.id()) || is_decimal(type.id()); +} + Status CastBinaryDecimalArgs(DecimalPromotion promotion, std::vector* types) { const DataType& left_type = *(*types)[0]; const DataType& right_type = *(*types)[1]; DCHECK(is_decimal(left_type.id()) || is_decimal(right_type.id())); + if ((is_decimal(left_type.id()) && !CastableToDecimal(right_type)) || + (is_decimal(right_type.id()) && !CastableToDecimal(left_type))) { + // If the other type is not castable to decimal, do not cast. The dispatch will + // gracefully fail by kernel selection. + return Status::OK(); + } + // decimal + float64 = float64 // decimal + float32 is roughly float64 + float32 so we choose float64 if (is_floating(left_type.id()) || is_floating(right_type.id())) { diff --git a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc index 8aa90823c1d..8596aa76809 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc @@ -51,6 +51,22 @@ TEST(TestDispatchBest, CastBinaryDecimalArgs) { EXPECT_RAISES_WITH_MESSAGE_THAT( NotImplemented, ::testing::HasSubstr("Decimals with negative scales not supported"), CastBinaryDecimalArgs(DecimalPromotion::kAdd, &args)); + + // Non-castable -> unchanged + for (const auto promotion : + {DecimalPromotion::kAdd, DecimalPromotion::kMultiply, DecimalPromotion::kDivide}) { + for (const auto& args : std::vector>{ + {decimal128(3, 2), boolean()}, + {boolean(), decimal128(3, 2)}, + {decimal128(3, 2), utf8()}, + {utf8(), decimal128(3, 2)}, + }) { + auto args_copy = args; + ASSERT_OK(CastBinaryDecimalArgs(promotion, &args_copy)); + AssertTypeEqual(*args_copy[0], *args[0]); + AssertTypeEqual(*args_copy[1], *args[1]); + } + } } TEST(TestDispatchBest, CastDecimalArgs) { diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index 9ff3a7fa18a..9367ad2c89d 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -2494,6 +2494,29 @@ TEST_F(TestBinaryArithmeticDecimal, Power) { } } +TEST_F(TestBinaryArithmeticDecimal, ErrorOnNonCastable) { + for (const auto& name : {"add", "subtract", "multiply", "divide"}) { + for (const auto& suffix : {"", "_checked"}) { + auto func = std::string(name) + suffix; + SCOPED_TRACE(func); + for (const auto& dec_ty : PositiveScaleTypes()) { + SCOPED_TRACE(dec_ty->ToString()); + auto dec_arr = ArrayFromJSON(dec_ty, R"([])"); + for (const auto& other_ty : {boolean(), fixed_size_binary(42), utf8()}) { + SCOPED_TRACE(other_ty->ToString()); + auto other_arr = ArrayFromJSON(other_ty, R"([])"); + EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, + ::testing::HasSubstr("has no kernel matching"), + CallFunction(func, {dec_arr, other_arr})); + EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, + ::testing::HasSubstr("has no kernel matching"), + CallFunction(func, {other_arr, dec_arr})); + } + } + } + } +} + TYPED_TEST(TestBinaryArithmeticIntegral, ShiftLeft) { for (auto check_overflow : {false, true}) { this->SetOverflowCheck(check_overflow); diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc index 9372955fb4c..23c7ab21bd2 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc @@ -681,6 +681,26 @@ TYPED_TEST(TestCompareDecimal, DifferentParameters) { } } +TYPED_TEST(TestCompareDecimal, ErrorOnNonCastable) { + auto dec_ty = std::make_shared(3, 2); + auto dec_arr = ArrayFromJSON(dec_ty, R"([])"); + + for (const auto& func : + {"equal", "not_equal", "less", "less_equal", "greater", "greater_equal"}) { + SCOPED_TRACE(func); + for (const auto& other_ty : {boolean(), fixed_size_binary(42), utf8()}) { + SCOPED_TRACE(other_ty->ToString()); + auto other_arr = ArrayFromJSON(other_ty, R"([])"); + EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, + ::testing::HasSubstr("has no kernel matching"), + CallFunction(func, {dec_arr, other_arr})); + EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, + ::testing::HasSubstr("has no kernel matching"), + CallFunction(func, {other_arr, dec_arr})); + } + } +} + // Helper to organize tests for fixed size binary comparisons struct CompareCase { std::shared_ptr lhs_type;