diff --git a/src/iceberg/expression/binder.cc b/src/iceberg/expression/binder.cc index 43c3ebcdf..650dc730d 100644 --- a/src/iceberg/expression/binder.cc +++ b/src/iceberg/expression/binder.cc @@ -19,6 +19,9 @@ #include "iceberg/expression/binder.h" +#include "iceberg/result.h" +#include "iceberg/util/macros.h" + namespace iceberg { Binder::Binder(const Schema& schema, bool case_sensitive) @@ -54,30 +57,30 @@ Result> Binder::Or( Result> Binder::Predicate( const std::shared_ptr& pred) { - ICEBERG_DCHECK(pred != nullptr, "Predicate cannot be null"); + ICEBERG_PRECHECK(pred != nullptr, "Predicate cannot be null"); return pred->Bind(schema_, case_sensitive_); } Result> Binder::Predicate( const std::shared_ptr& pred) { - ICEBERG_DCHECK(pred != nullptr, "Predicate cannot be null"); + ICEBERG_PRECHECK(pred != nullptr, "Predicate cannot be null"); return InvalidExpression("Found already bound predicate: {}", pred->ToString()); } Result> Binder::Aggregate( const std::shared_ptr& aggregate) { - ICEBERG_DCHECK(aggregate != nullptr, "Aggregate cannot be null"); + ICEBERG_PRECHECK(aggregate != nullptr, "Aggregate cannot be null"); return InvalidExpression("Found already bound aggregate: {}", aggregate->ToString()); } Result> Binder::Aggregate( const std::shared_ptr& aggregate) { - ICEBERG_DCHECK(aggregate != nullptr, "Aggregate cannot be null"); + ICEBERG_PRECHECK(aggregate != nullptr, "Aggregate cannot be null"); return aggregate->Bind(schema_, case_sensitive_); } Result IsBoundVisitor::IsBound(const std::shared_ptr& expr) { - ICEBERG_DCHECK(expr != nullptr, "Expression cannot be null"); + ICEBERG_PRECHECK(expr != nullptr, "Expression cannot be null"); IsBoundVisitor visitor; return Visit(expr, visitor); } @@ -113,4 +116,54 @@ Result IsBoundVisitor::Aggregate( return false; } +Result> ReferenceVisitor::GetReferencedFieldIds( + const std::shared_ptr& expr) { + ICEBERG_PRECHECK(expr != nullptr, "Expression cannot be null"); + ReferenceVisitor visitor; + return Visit(expr, 
visitor); +} + +Result ReferenceVisitor::AlwaysTrue() { return referenced_field_ids_; } + +Result ReferenceVisitor::AlwaysFalse() { return referenced_field_ids_; } + +Result ReferenceVisitor::Not( + [[maybe_unused]] const FieldIdsSetRef& child_result) { + return referenced_field_ids_; // all results alias the one member set; nothing to merge +} + +Result ReferenceVisitor::And( + [[maybe_unused]] const FieldIdsSetRef& left_result, + [[maybe_unused]] const FieldIdsSetRef& right_result) { + return referenced_field_ids_; +} + +Result ReferenceVisitor::Or( + [[maybe_unused]] const FieldIdsSetRef& left_result, + [[maybe_unused]] const FieldIdsSetRef& right_result) { + return referenced_field_ids_; +} + +Result ReferenceVisitor::Predicate( + const std::shared_ptr& pred) { + referenced_field_ids_.insert(pred->reference()->field_id()); // set ignores repeats + return referenced_field_ids_; +} + +Result ReferenceVisitor::Predicate( + [[maybe_unused]] const std::shared_ptr& pred) { + return InvalidExpression("Cannot get referenced field IDs from unbound predicate"); +} + +Result ReferenceVisitor::Aggregate( + const std::shared_ptr& aggregate) { + referenced_field_ids_.insert(aggregate->reference()->field_id()); // NOTE(review): assumes reference() is non-null; confirm + return referenced_field_ids_; +} + +Result ReferenceVisitor::Aggregate( + [[maybe_unused]] const std::shared_ptr& aggregate) { + return InvalidExpression("Cannot get referenced field IDs from unbound aggregate"); +} + } // namespace iceberg diff --git a/src/iceberg/expression/binder.h b/src/iceberg/expression/binder.h index a78b7a4bb..276ab0760 100644 --- a/src/iceberg/expression/binder.h +++ b/src/iceberg/expression/binder.h @@ -22,6 +22,9 @@ /// \file iceberg/expression/binder.h /// Bind an expression to a schema. 
+#include +#include + #include "iceberg/expression/expression_visitor.h" namespace iceberg { @@ -73,6 +76,31 @@ class ICEBERG_EXPORT IsBoundVisitor : public ExpressionVisitor { Result Aggregate(const std::shared_ptr& aggregate) override; }; -// TODO(gangwu): add the Java parity `ReferenceVisitor` +using FieldIdsSetRef = std::reference_wrapper>; + +/// \brief Visitor to collect referenced field IDs from a bound expression. +class ICEBERG_EXPORT ReferenceVisitor : public ExpressionVisitor { + public: + static Result> GetReferencedFieldIds( + const std::shared_ptr& expr); + + Result AlwaysTrue() override; + Result AlwaysFalse() override; + Result Not(const FieldIdsSetRef& child_result) override; + Result And(const FieldIdsSetRef& left_result, + const FieldIdsSetRef& right_result) override; + Result Or(const FieldIdsSetRef& left_result, + const FieldIdsSetRef& right_result) override; + Result Predicate(const std::shared_ptr& pred) override; + Result Predicate( + const std::shared_ptr& pred) override; + Result Aggregate( + const std::shared_ptr& aggregate) override; + Result Aggregate( + const std::shared_ptr& aggregate) override; + + private: + std::unordered_set referenced_field_ids_; +}; } // namespace iceberg diff --git a/src/iceberg/test/expression_visitor_test.cc b/src/iceberg/test/expression_visitor_test.cc index f2bbe70ea..697c0096a 100644 --- a/src/iceberg/test/expression_visitor_test.cc +++ b/src/iceberg/test/expression_visitor_test.cc @@ -22,6 +22,7 @@ #include "iceberg/expression/binder.h" #include "iceberg/expression/expressions.h" #include "iceberg/expression/rewrite_not.h" +#include "iceberg/result.h" #include "iceberg/schema.h" #include "iceberg/test/matchers.h" #include "iceberg/type.h" @@ -505,4 +506,207 @@ TEST_F(RewriteNotTest, ComplexExpression) { EXPECT_EQ(rewritten->op(), Expression::Operation::kOr); } +class ReferenceVisitorTest : public ExpressionVisitorTest {}; + +TEST_F(ReferenceVisitorTest, Constants) { + // Constants should have no referenced 
fields + auto true_expr = Expressions::AlwaysTrue(); + ICEBERG_UNWRAP_OR_FAIL(auto refs_true, + ReferenceVisitor::GetReferencedFieldIds(true_expr)); + EXPECT_TRUE(refs_true.empty()); + + auto false_expr = Expressions::AlwaysFalse(); + ICEBERG_UNWRAP_OR_FAIL(auto refs_false, + ReferenceVisitor::GetReferencedFieldIds(false_expr)); + EXPECT_TRUE(refs_false.empty()); +} + +TEST_F(ReferenceVisitorTest, UnboundPredicate) { + auto unbound_pred = Expressions::Equal("name", Literal::String("Alice")); + auto result = ReferenceVisitor::GetReferencedFieldIds(unbound_pred); + EXPECT_THAT(result, IsError(ErrorKind::kInvalidExpression)); + EXPECT_THAT(result, + HasErrorMessage("Cannot get referenced field IDs from unbound predicate")); +} + +TEST_F(ReferenceVisitorTest, BoundPredicate) { + // Bound predicate should return the field ID + auto unbound_pred = Expressions::Equal("name", Literal::String("Alice")); + ICEBERG_UNWRAP_OR_FAIL(auto bound_pred, Bind(unbound_pred)); + + ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_pred)); + EXPECT_EQ(refs.size(), 1); + EXPECT_EQ(refs.count(2), 1); // name field has id=2 +} + +TEST_F(ReferenceVisitorTest, MultiplePredicates) { + // Test various predicates with different fields + auto pred_age = Expressions::GreaterThan("age", Literal::Int(25)); + ICEBERG_UNWRAP_OR_FAIL(auto bound_age, Bind(pred_age)); + ICEBERG_UNWRAP_OR_FAIL(auto refs_age, + ReferenceVisitor::GetReferencedFieldIds(bound_age)); + EXPECT_EQ(refs_age.size(), 1); + EXPECT_EQ(refs_age.count(3), 1); // age field has id=3 + + auto pred_salary = Expressions::LessThan("salary", Literal::Double(50000.0)); + ICEBERG_UNWRAP_OR_FAIL(auto bound_salary, Bind(pred_salary)); + ICEBERG_UNWRAP_OR_FAIL(auto refs_salary, + ReferenceVisitor::GetReferencedFieldIds(bound_salary)); + EXPECT_EQ(refs_salary.size(), 1); + EXPECT_EQ(refs_salary.count(4), 1); // salary field has id=4 +} + +TEST_F(ReferenceVisitorTest, UnaryPredicates) { + // Test unary predicates + auto 
pred_is_null = Expressions::IsNull("name"); + ICEBERG_UNWRAP_OR_FAIL(auto bound_is_null, Bind(pred_is_null)); + ICEBERG_UNWRAP_OR_FAIL(auto refs, + ReferenceVisitor::GetReferencedFieldIds(bound_is_null)); + EXPECT_EQ(refs.size(), 1); + EXPECT_EQ(refs.count(2), 1); + + auto pred_is_nan = Expressions::IsNaN("salary"); + ICEBERG_UNWRAP_OR_FAIL(auto bound_is_nan, Bind(pred_is_nan)); + ICEBERG_UNWRAP_OR_FAIL(auto refs_nan, + ReferenceVisitor::GetReferencedFieldIds(bound_is_nan)); + EXPECT_EQ(refs_nan.size(), 1); + EXPECT_EQ(refs_nan.count(4), 1); +} + +TEST_F(ReferenceVisitorTest, AndExpression) { + // AND expression should return union of field IDs from both sides + auto pred1 = Expressions::Equal("name", Literal::String("Alice")); + auto pred2 = Expressions::GreaterThan("age", Literal::Int(25)); + auto and_expr = Expressions::And(pred1, pred2); + + ICEBERG_UNWRAP_OR_FAIL(auto bound_and, Bind(and_expr)); + ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_and)); + + EXPECT_EQ(refs.size(), 2); + EXPECT_EQ(refs.count(2), 1); // name field + EXPECT_EQ(refs.count(3), 1); // age field +} + +TEST_F(ReferenceVisitorTest, OrExpression) { + // OR expression should return union of field IDs from both sides + auto pred1 = Expressions::IsNull("salary"); + auto pred2 = Expressions::Equal("active", Literal::Boolean(true)); + auto or_expr = Expressions::Or(pred1, pred2); + + ICEBERG_UNWRAP_OR_FAIL(auto bound_or, Bind(or_expr)); + ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_or)); + + EXPECT_EQ(refs.size(), 2); + EXPECT_EQ(refs.count(4), 1); // salary field + EXPECT_EQ(refs.count(5), 1); // active field +} + +TEST_F(ReferenceVisitorTest, NotExpression) { + // NOT expression should return field IDs from its child + auto pred = Expressions::Equal("name", Literal::String("Alice")); + auto not_expr = Expressions::Not(pred); + + ICEBERG_UNWRAP_OR_FAIL(auto bound_not, Bind(not_expr)); + ICEBERG_UNWRAP_OR_FAIL(auto refs, 
ReferenceVisitor::GetReferencedFieldIds(bound_not)); + + EXPECT_EQ(refs.size(), 1); + EXPECT_EQ(refs.count(2), 1); // name field +} + +TEST_F(ReferenceVisitorTest, ComplexNestedExpression) { + // (name = 'Alice' AND age > 25) OR (salary < 30000 AND active = true) + // Should reference fields: name(2), age(3), salary(4), active(5) + auto pred1 = Expressions::Equal("name", Literal::String("Alice")); + auto pred2 = Expressions::GreaterThan("age", Literal::Int(25)); + auto pred3 = Expressions::LessThan("salary", Literal::Double(30000.0)); + auto pred4 = Expressions::Equal("active", Literal::Boolean(true)); + + auto and1 = Expressions::And(pred1, pred2); + auto and2 = Expressions::And(pred3, pred4); + auto complex_or = Expressions::Or(and1, and2); + + ICEBERG_UNWRAP_OR_FAIL(auto bound_complex, Bind(complex_or)); + ICEBERG_UNWRAP_OR_FAIL(auto refs, + ReferenceVisitor::GetReferencedFieldIds(bound_complex)); + + EXPECT_EQ(refs.size(), 4); + EXPECT_EQ(refs.count(2), 1); // name field + EXPECT_EQ(refs.count(3), 1); // age field + EXPECT_EQ(refs.count(4), 1); // salary field + EXPECT_EQ(refs.count(5), 1); // active field +} + +TEST_F(ReferenceVisitorTest, DuplicateFieldReferences) { + // Multiple predicates referencing the same field + // age > 25 AND age < 50 + auto pred1 = Expressions::GreaterThan("age", Literal::Int(25)); + auto pred2 = Expressions::LessThan("age", Literal::Int(50)); + auto and_expr = Expressions::And(pred1, pred2); + + ICEBERG_UNWRAP_OR_FAIL(auto bound_and, Bind(and_expr)); + ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_and)); + + // Should only contain the field ID once (set semantics) + EXPECT_EQ(refs.size(), 1); + EXPECT_EQ(refs.count(3), 1); // age field +} + +TEST_F(ReferenceVisitorTest, SetPredicates) { + // Test In predicate + auto pred_in = + Expressions::In("age", {Literal::Int(25), Literal::Int(30), Literal::Int(35)}); + ICEBERG_UNWRAP_OR_FAIL(auto bound_in, Bind(pred_in)); + ICEBERG_UNWRAP_OR_FAIL(auto 
refs_in, ReferenceVisitor::GetReferencedFieldIds(bound_in)); + + EXPECT_EQ(refs_in.size(), 1); + EXPECT_EQ(refs_in.count(3), 1); // age field + + // Test NotIn predicate + auto pred_not_in = + Expressions::NotIn("name", {Literal::String("Alice"), Literal::String("Bob")}); + ICEBERG_UNWRAP_OR_FAIL(auto bound_not_in, Bind(pred_not_in)); + ICEBERG_UNWRAP_OR_FAIL(auto refs_not_in, + ReferenceVisitor::GetReferencedFieldIds(bound_not_in)); + + EXPECT_EQ(refs_not_in.size(), 1); + EXPECT_EQ(refs_not_in.count(2), 1); // name field +} + +TEST_F(ReferenceVisitorTest, MixedBoundAndUnbound) { + auto bound_pred = Expressions::Equal("name", Literal::String("Alice")); + ICEBERG_UNWRAP_OR_FAIL(auto pred1, Bind(bound_pred)); + auto unbound_pred = Expressions::GreaterThan("age", Literal::Int(25)); + auto mixed_and = Expressions::And(pred1, unbound_pred); + + auto result = ReferenceVisitor::GetReferencedFieldIds(mixed_and); + EXPECT_THAT(result, IsError(ErrorKind::kInvalidExpression)); + EXPECT_THAT(result, + HasErrorMessage("Cannot get referenced field IDs from unbound predicate")); +} + +TEST_F(ReferenceVisitorTest, AllFields) { + // Create expression referencing all fields in the schema + auto pred1 = Expressions::NotNull("id"); + auto pred2 = Expressions::Equal("name", Literal::String("Test")); + auto pred3 = Expressions::GreaterThan("age", Literal::Int(0)); + auto pred4 = Expressions::LessThan("salary", Literal::Double(100000.0)); + auto pred5 = Expressions::Equal("active", Literal::Boolean(true)); + + auto and1 = Expressions::And(pred1, pred2); + auto and2 = Expressions::And(pred3, pred4); + auto and3 = Expressions::And(and1, and2); + auto all_fields = Expressions::And(and3, pred5); + + ICEBERG_UNWRAP_OR_FAIL(auto bound_all, Bind(all_fields)); + ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_all)); + + // Only 4 of the 5 fields remain referenced after binding + EXPECT_EQ(refs.size(), 4); + EXPECT_EQ(refs.count(1), 0); // id field is optimized out + 
EXPECT_EQ(refs.count(2), 1); // name field + EXPECT_EQ(refs.count(3), 1); // age field + EXPECT_EQ(refs.count(4), 1); // salary field + EXPECT_EQ(refs.count(5), 1); // active field +} + } // namespace iceberg