Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 58 additions & 5 deletions src/iceberg/expression/binder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@

#include "iceberg/expression/binder.h"

#include "iceberg/result.h"
#include "iceberg/util/macros.h"

namespace iceberg {

Binder::Binder(const Schema& schema, bool case_sensitive)
Expand Down Expand Up @@ -54,30 +57,30 @@ Result<std::shared_ptr<Expression>> Binder::Or(

/// \brief Bind an unbound predicate against the binder's schema.
///
/// \param pred The unbound predicate to bind; must not be null.
/// \return The result of `pred->Bind(schema_, case_sensitive_)`.
Result<std::shared_ptr<Expression>> Binder::Predicate(
    const std::shared_ptr<UnboundPredicate>& pred) {
  // One precondition check only: the earlier ICEBERG_DCHECK on the same
  // condition was superseded by this line and is not repeated.
  // NOTE(review): ICEBERG_PRECHECK presumably returns an error Result instead
  // of aborting -- confirm in iceberg/util/macros.h.
  ICEBERG_PRECHECK(pred != nullptr, "Predicate cannot be null");
  return pred->Bind(schema_, case_sensitive_);
}

/// \brief Reject a predicate that has already been bound.
///
/// \param pred The bound predicate encountered while binding; must not be null.
/// \return Always an `InvalidExpression` error naming the offending predicate
///         (once the null precondition has passed).
Result<std::shared_ptr<Expression>> Binder::Predicate(
    const std::shared_ptr<BoundPredicate>& pred) {
  // One precondition check only: the earlier ICEBERG_DCHECK on the same
  // condition was superseded by this line and is not repeated.
  ICEBERG_PRECHECK(pred != nullptr, "Predicate cannot be null");
  return InvalidExpression("Found already bound predicate: {}", pred->ToString());
}

/// \brief Reject an aggregate that has already been bound.
///
/// \param aggregate The bound aggregate encountered while binding; must not be
///        null.
/// \return Always an `InvalidExpression` error naming the offending aggregate
///         (once the null precondition has passed).
Result<std::shared_ptr<Expression>> Binder::Aggregate(
    const std::shared_ptr<BoundAggregate>& aggregate) {
  // One precondition check only: the earlier ICEBERG_DCHECK on the same
  // condition was superseded by this line and is not repeated.
  ICEBERG_PRECHECK(aggregate != nullptr, "Aggregate cannot be null");
  return InvalidExpression("Found already bound aggregate: {}", aggregate->ToString());
}

/// \brief Bind an unbound aggregate against the binder's schema.
///
/// \param aggregate The unbound aggregate to bind; must not be null.
/// \return The result of `aggregate->Bind(schema_, case_sensitive_)`.
Result<std::shared_ptr<Expression>> Binder::Aggregate(
    const std::shared_ptr<UnboundAggregate>& aggregate) {
  // One precondition check only: the earlier ICEBERG_DCHECK on the same
  // condition was superseded by this line and is not repeated.
  ICEBERG_PRECHECK(aggregate != nullptr, "Aggregate cannot be null");
  return aggregate->Bind(schema_, case_sensitive_);
}

/// \brief Run an `IsBoundVisitor` over an expression tree.
///
/// \param expr Root of the expression to inspect; must not be null.
/// \return The `Result<bool>` produced by visiting `expr` with a fresh
///         `IsBoundVisitor`.
Result<bool> IsBoundVisitor::IsBound(const std::shared_ptr<Expression>& expr) {
  // One precondition check only: the earlier ICEBERG_DCHECK on the same
  // condition was superseded by this line and is not repeated.
  ICEBERG_PRECHECK(expr != nullptr, "Expression cannot be null");
  IsBoundVisitor visitor;
  return Visit<bool, IsBoundVisitor>(expr, visitor);
}
Expand Down Expand Up @@ -113,4 +116,54 @@ Result<bool> IsBoundVisitor::Aggregate(
return false;
}

/// \brief Collect the IDs of all fields referenced by an expression.
///
/// \param expr Root of the expression to walk; must not be null.
/// \return The set of referenced field IDs, or the error raised while
///         visiting (the unbound overloads report `InvalidExpression`).
Result<std::unordered_set<int32_t>> ReferenceVisitor::GetReferencedFieldIds(
    const std::shared_ptr<Expression>& expr) {
  ICEBERG_PRECHECK(expr != nullptr, "Expression cannot be null");

  // The visit yields a reference to the set owned by `collector`; the by-value
  // return type carries the IDs out of this function.
  ReferenceVisitor collector;
  return Visit<FieldIdsSetRef, ReferenceVisitor>(expr, collector);
}

/// \brief A constant `true` adds no field IDs; return the set gathered so far.
Result<FieldIdsSetRef> ReferenceVisitor::AlwaysTrue() {
  return FieldIdsSetRef(referenced_field_ids_);
}

/// \brief A constant `false` adds no field IDs; return the set gathered so far.
Result<FieldIdsSetRef> ReferenceVisitor::AlwaysFalse() {
  return FieldIdsSetRef(referenced_field_ids_);
}

/// \brief NOT adds no field IDs of its own.
///
/// The child result is ignored: every callback returns a reference to the same
/// member set, so whatever the child recorded is already in it.
Result<FieldIdsSetRef> ReferenceVisitor::Not(const FieldIdsSetRef& /*child_result*/) {
  return FieldIdsSetRef(referenced_field_ids_);
}

/// \brief AND adds no field IDs of its own.
///
/// Both operand results are ignored: every callback returns a reference to the
/// same member set, so no explicit union is performed here.
Result<FieldIdsSetRef> ReferenceVisitor::And(const FieldIdsSetRef& /*left_result*/,
                                             const FieldIdsSetRef& /*right_result*/) {
  return FieldIdsSetRef(referenced_field_ids_);
}

/// \brief OR adds no field IDs of its own.
///
/// Both operand results are ignored: every callback returns a reference to the
/// same member set, so no explicit union is performed here.
Result<FieldIdsSetRef> ReferenceVisitor::Or(const FieldIdsSetRef& /*left_result*/,
                                            const FieldIdsSetRef& /*right_result*/) {
  return FieldIdsSetRef(referenced_field_ids_);
}

/// \brief Record the field ID referenced by a bound predicate.
///
/// \param pred The bound predicate being visited; must not be null.
/// \return A reference to the accumulated set, now containing
///         `pred->reference()->field_id()`.
Result<FieldIdsSetRef> ReferenceVisitor::Predicate(
    const std::shared_ptr<BoundPredicate>& pred) {
  // Guard the dereference below, using the same null precondition and message
  // as Binder::Predicate in this file.
  ICEBERG_PRECHECK(pred != nullptr, "Predicate cannot be null");
  referenced_field_ids_.insert(pred->reference()->field_id());
  return referenced_field_ids_;
}

/// \brief Unbound predicates are rejected by this visitor.
///
/// \return Always an `InvalidExpression` error; the argument is not inspected.
Result<FieldIdsSetRef> ReferenceVisitor::Predicate(
    const std::shared_ptr<UnboundPredicate>& /*pred*/) {
  return InvalidExpression("Cannot get referenced field IDs from unbound predicate");
}

/// \brief Record the field ID referenced by a bound aggregate.
///
/// \param aggregate The bound aggregate being visited; must not be null.
/// \return A reference to the accumulated set, now containing
///         `aggregate->reference()->field_id()`.
Result<FieldIdsSetRef> ReferenceVisitor::Aggregate(
    const std::shared_ptr<BoundAggregate>& aggregate) {
  // Guard the dereference below, using the same null precondition and message
  // as Binder::Aggregate in this file.
  ICEBERG_PRECHECK(aggregate != nullptr, "Aggregate cannot be null");
  referenced_field_ids_.insert(aggregate->reference()->field_id());
  return referenced_field_ids_;
}

/// \brief Unbound aggregates are rejected by this visitor.
///
/// \return Always an `InvalidExpression` error; the argument is not inspected.
Result<FieldIdsSetRef> ReferenceVisitor::Aggregate(
    const std::shared_ptr<UnboundAggregate>& /*aggregate*/) {
  return InvalidExpression("Cannot get referenced field IDs from unbound aggregate");
}

} // namespace iceberg
30 changes: 29 additions & 1 deletion src/iceberg/expression/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
/// \file iceberg/expression/binder.h
/// Bind an expression to a schema.

#include <functional>
#include <unordered_set>

#include "iceberg/expression/expression_visitor.h"

namespace iceberg {
Expand Down Expand Up @@ -73,6 +76,31 @@ class ICEBERG_EXPORT IsBoundVisitor : public ExpressionVisitor<bool> {
Result<bool> Aggregate(const std::shared_ptr<UnboundAggregate>& aggregate) override;
};

/// \brief Copyable reference to a set of field IDs; the result type that
/// `ReferenceVisitor` passes between its visit callbacks.
using FieldIdsSetRef = std::reference_wrapper<std::unordered_set<int32_t>>;

/// \brief Visitor to collect referenced field IDs from an expression.
///
/// Every callback returns a reference to the same member set rather than a
/// fresh set, so IDs accumulate across the whole traversal. Only bound
/// predicates and bound aggregates add IDs; their unbound counterparts make
/// the visit fail with an `InvalidExpression` error.
class ICEBERG_EXPORT ReferenceVisitor : public ExpressionVisitor<FieldIdsSetRef> {
public:
/// \brief Visit `expr` with a fresh visitor and return the collected IDs.
///
/// \param expr Root of the expression to walk; must not be null.
/// \return The referenced field IDs by value, or the error from the visit.
static Result<std::unordered_set<int32_t>> GetReferencedFieldIds(
const std::shared_ptr<Expression>& expr);

/// \brief Constants reference no fields; returns the accumulated set as is.
Result<FieldIdsSetRef> AlwaysTrue() override;
/// \brief Constants reference no fields; returns the accumulated set as is.
Result<FieldIdsSetRef> AlwaysFalse() override;
/// \brief Ignores `child_result` and returns the accumulated set.
Result<FieldIdsSetRef> Not(const FieldIdsSetRef& child_result) override;
/// \brief Ignores both operand results and returns the accumulated set.
Result<FieldIdsSetRef> And(const FieldIdsSetRef& left_result,
const FieldIdsSetRef& right_result) override;
/// \brief Ignores both operand results and returns the accumulated set.
Result<FieldIdsSetRef> Or(const FieldIdsSetRef& left_result,
const FieldIdsSetRef& right_result) override;
/// \brief Adds `pred->reference()->field_id()` to the accumulated set.
Result<FieldIdsSetRef> Predicate(const std::shared_ptr<BoundPredicate>& pred) override;
/// \brief Always fails with `InvalidExpression`: unbound predicates are rejected.
Result<FieldIdsSetRef> Predicate(
const std::shared_ptr<UnboundPredicate>& pred) override;
/// \brief Adds `aggregate->reference()->field_id()` to the accumulated set.
Result<FieldIdsSetRef> Aggregate(
const std::shared_ptr<BoundAggregate>& aggregate) override;
/// \brief Always fails with `InvalidExpression`: unbound aggregates are rejected.
Result<FieldIdsSetRef> Aggregate(
const std::shared_ptr<UnboundAggregate>& aggregate) override;

private:
// Field IDs gathered so far; every callback returns a reference to this set.
std::unordered_set<int32_t> referenced_field_ids_;
};

} // namespace iceberg
204 changes: 204 additions & 0 deletions src/iceberg/test/expression_visitor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "iceberg/expression/binder.h"
#include "iceberg/expression/expressions.h"
#include "iceberg/expression/rewrite_not.h"
#include "iceberg/result.h"
#include "iceberg/schema.h"
#include "iceberg/test/matchers.h"
#include "iceberg/type.h"
Expand Down Expand Up @@ -505,4 +506,207 @@ TEST_F(RewriteNotTest, ComplexExpression) {
EXPECT_EQ(rewritten->op(), Expression::Operation::kOr);
}

/// \brief Fixture for the `ReferenceVisitor` tests; adds nothing on top of
/// `ExpressionVisitorTest`.
class ReferenceVisitorTest : public ExpressionVisitorTest {};

TEST_F(ReferenceVisitorTest, Constants) {
  // Neither constant names a column, so both must produce an empty ID set.
  auto always_true = Expressions::AlwaysTrue();
  ICEBERG_UNWRAP_OR_FAIL(auto ids_for_true,
                         ReferenceVisitor::GetReferencedFieldIds(always_true));
  EXPECT_THAT(ids_for_true, ::testing::IsEmpty());

  auto always_false = Expressions::AlwaysFalse();
  ICEBERG_UNWRAP_OR_FAIL(auto ids_for_false,
                         ReferenceVisitor::GetReferencedFieldIds(always_false));
  EXPECT_THAT(ids_for_false, ::testing::IsEmpty());
}

TEST_F(ReferenceVisitorTest, UnboundPredicate) {
  // An unbound predicate carries no field ID, so the visit must report an error.
  auto name_equals_alice = Expressions::Equal("name", Literal::String("Alice"));
  auto outcome = ReferenceVisitor::GetReferencedFieldIds(name_equals_alice);

  EXPECT_THAT(outcome, IsError(ErrorKind::kInvalidExpression));
  EXPECT_THAT(outcome,
              HasErrorMessage("Cannot get referenced field IDs from unbound predicate"));
}

TEST_F(ReferenceVisitorTest, BoundPredicate) {
  // Once bound, a single predicate yields exactly its own field ID.
  auto name_equals_alice = Expressions::Equal("name", Literal::String("Alice"));
  ICEBERG_UNWRAP_OR_FAIL(auto bound_expr, Bind(name_equals_alice));

  ICEBERG_UNWRAP_OR_FAIL(auto field_ids,
                         ReferenceVisitor::GetReferencedFieldIds(bound_expr));
  // name field has id=2
  EXPECT_THAT(field_ids, ::testing::UnorderedElementsAre(2));
}

TEST_F(ReferenceVisitorTest, MultiplePredicates) {
  // Independent predicates on different columns each report only their own ID.
  auto age_above_25 = Expressions::GreaterThan("age", Literal::Int(25));
  ICEBERG_UNWRAP_OR_FAIL(auto bound_age_expr, Bind(age_above_25));
  ICEBERG_UNWRAP_OR_FAIL(auto age_ids,
                         ReferenceVisitor::GetReferencedFieldIds(bound_age_expr));
  // age field has id=3
  EXPECT_THAT(age_ids, ::testing::UnorderedElementsAre(3));

  auto salary_below_50k = Expressions::LessThan("salary", Literal::Double(50000.0));
  ICEBERG_UNWRAP_OR_FAIL(auto bound_salary_expr, Bind(salary_below_50k));
  ICEBERG_UNWRAP_OR_FAIL(auto salary_ids,
                         ReferenceVisitor::GetReferencedFieldIds(bound_salary_expr));
  // salary field has id=4
  EXPECT_THAT(salary_ids, ::testing::UnorderedElementsAre(4));
}

TEST_F(ReferenceVisitorTest, UnaryPredicates) {
  // Unary predicates (no literal operand) still reference their column.
  auto name_is_null = Expressions::IsNull("name");
  ICEBERG_UNWRAP_OR_FAIL(auto bound_null_check, Bind(name_is_null));
  ICEBERG_UNWRAP_OR_FAIL(auto null_check_ids,
                         ReferenceVisitor::GetReferencedFieldIds(bound_null_check));
  EXPECT_THAT(null_check_ids, ::testing::UnorderedElementsAre(2));

  auto salary_is_nan = Expressions::IsNaN("salary");
  ICEBERG_UNWRAP_OR_FAIL(auto bound_nan_check, Bind(salary_is_nan));
  ICEBERG_UNWRAP_OR_FAIL(auto nan_check_ids,
                         ReferenceVisitor::GetReferencedFieldIds(bound_nan_check));
  EXPECT_THAT(nan_check_ids, ::testing::UnorderedElementsAre(4));
}

TEST_F(ReferenceVisitorTest, AndExpression) {
  // An AND must report the union of the field IDs of both operands.
  auto name_equals_alice = Expressions::Equal("name", Literal::String("Alice"));
  auto age_above_25 = Expressions::GreaterThan("age", Literal::Int(25));
  auto conjunction = Expressions::And(name_equals_alice, age_above_25);

  ICEBERG_UNWRAP_OR_FAIL(auto bound_conjunction, Bind(conjunction));
  ICEBERG_UNWRAP_OR_FAIL(auto field_ids,
                         ReferenceVisitor::GetReferencedFieldIds(bound_conjunction));

  // name field (2) and age field (3)
  EXPECT_THAT(field_ids, ::testing::UnorderedElementsAre(2, 3));
}

TEST_F(ReferenceVisitorTest, OrExpression) {
  // An OR must report the union of the field IDs of both operands.
  auto salary_is_null = Expressions::IsNull("salary");
  auto is_active = Expressions::Equal("active", Literal::Boolean(true));
  auto disjunction = Expressions::Or(salary_is_null, is_active);

  ICEBERG_UNWRAP_OR_FAIL(auto bound_disjunction, Bind(disjunction));
  ICEBERG_UNWRAP_OR_FAIL(auto field_ids,
                         ReferenceVisitor::GetReferencedFieldIds(bound_disjunction));

  // salary field (4) and active field (5)
  EXPECT_THAT(field_ids, ::testing::UnorderedElementsAre(4, 5));
}

TEST_F(ReferenceVisitorTest, NotExpression) {
  // A NOT must report the field IDs of the expression it wraps.
  auto name_equals_alice = Expressions::Equal("name", Literal::String("Alice"));
  auto negation = Expressions::Not(name_equals_alice);

  ICEBERG_UNWRAP_OR_FAIL(auto bound_negation, Bind(negation));
  ICEBERG_UNWRAP_OR_FAIL(auto field_ids,
                         ReferenceVisitor::GetReferencedFieldIds(bound_negation));

  // name field (2)
  EXPECT_THAT(field_ids, ::testing::UnorderedElementsAre(2));
}

TEST_F(ReferenceVisitorTest, ComplexNestedExpression) {
  // (name = 'Alice' AND age > 25) OR (salary < 30000 AND active = true)
  // Expected references: name(2), age(3), salary(4), active(5)
  auto name_equals_alice = Expressions::Equal("name", Literal::String("Alice"));
  auto age_above_25 = Expressions::GreaterThan("age", Literal::Int(25));
  auto salary_below_30k = Expressions::LessThan("salary", Literal::Double(30000.0));
  auto is_active = Expressions::Equal("active", Literal::Boolean(true));

  auto left_branch = Expressions::And(name_equals_alice, age_above_25);
  auto right_branch = Expressions::And(salary_below_30k, is_active);
  auto nested = Expressions::Or(left_branch, right_branch);

  ICEBERG_UNWRAP_OR_FAIL(auto bound_nested, Bind(nested));
  ICEBERG_UNWRAP_OR_FAIL(auto field_ids,
                         ReferenceVisitor::GetReferencedFieldIds(bound_nested));

  EXPECT_THAT(field_ids, ::testing::UnorderedElementsAre(2, 3, 4, 5));
}

TEST_F(ReferenceVisitorTest, DuplicateFieldReferences) {
  // Two predicates on the same column: age > 25 AND age < 50
  auto age_above_25 = Expressions::GreaterThan("age", Literal::Int(25));
  auto age_below_50 = Expressions::LessThan("age", Literal::Int(50));
  auto age_range = Expressions::And(age_above_25, age_below_50);

  ICEBERG_UNWRAP_OR_FAIL(auto bound_age_range, Bind(age_range));
  ICEBERG_UNWRAP_OR_FAIL(auto field_ids,
                         ReferenceVisitor::GetReferencedFieldIds(bound_age_range));

  // Set semantics: the age field (3) appears once even though it is used twice.
  EXPECT_THAT(field_ids, ::testing::UnorderedElementsAre(3));
}

TEST_F(ReferenceVisitorTest, SetPredicates) {
  // In predicate: several literals, still a single referenced column.
  auto age_in_list =
      Expressions::In("age", {Literal::Int(25), Literal::Int(30), Literal::Int(35)});
  ICEBERG_UNWRAP_OR_FAIL(auto bound_age_in, Bind(age_in_list));
  ICEBERG_UNWRAP_OR_FAIL(auto ids_for_in,
                         ReferenceVisitor::GetReferencedFieldIds(bound_age_in));
  // age field (3)
  EXPECT_THAT(ids_for_in, ::testing::UnorderedElementsAre(3));

  // NotIn predicate behaves the same way.
  auto name_not_in_list =
      Expressions::NotIn("name", {Literal::String("Alice"), Literal::String("Bob")});
  ICEBERG_UNWRAP_OR_FAIL(auto bound_name_not_in, Bind(name_not_in_list));
  ICEBERG_UNWRAP_OR_FAIL(auto ids_for_not_in,
                         ReferenceVisitor::GetReferencedFieldIds(bound_name_not_in));
  // name field (2)
  EXPECT_THAT(ids_for_not_in, ::testing::UnorderedElementsAre(2));
}

TEST_F(ReferenceVisitorTest, MixedBoundAndUnbound) {
  // One bound operand is not enough: a single unbound predicate anywhere in
  // the tree makes the whole lookup fail.
  auto name_equals_alice = Expressions::Equal("name", Literal::String("Alice"));
  ICEBERG_UNWRAP_OR_FAIL(auto bound_operand, Bind(name_equals_alice));
  auto unbound_operand = Expressions::GreaterThan("age", Literal::Int(25));
  auto half_bound = Expressions::And(bound_operand, unbound_operand);

  auto outcome = ReferenceVisitor::GetReferencedFieldIds(half_bound);
  EXPECT_THAT(outcome, IsError(ErrorKind::kInvalidExpression));
  EXPECT_THAT(outcome,
              HasErrorMessage("Cannot get referenced field IDs from unbound predicate"));
}

TEST_F(ReferenceVisitorTest, AllFields) {
  // Build an expression that mentions every column used by these tests.
  auto id_not_null = Expressions::NotNull("id");
  auto name_equals_test = Expressions::Equal("name", Literal::String("Test"));
  auto age_positive = Expressions::GreaterThan("age", Literal::Int(0));
  auto salary_below_100k = Expressions::LessThan("salary", Literal::Double(100000.0));
  auto is_active = Expressions::Equal("active", Literal::Boolean(true));

  auto id_and_name = Expressions::And(id_not_null, name_equals_test);
  auto age_and_salary = Expressions::And(age_positive, salary_below_100k);
  auto first_four = Expressions::And(id_and_name, age_and_salary);
  auto everything = Expressions::And(first_four, is_active);

  ICEBERG_UNWRAP_OR_FAIL(auto bound_everything, Bind(everything));
  ICEBERG_UNWRAP_OR_FAIL(auto field_ids,
                         ReferenceVisitor::GetReferencedFieldIds(bound_everything));

  // Only 4 of the 5 mentioned fields remain referenced: the id field (1) is
  // optimized out, leaving name (2), age (3), salary (4) and active (5).
  // NOTE(review): presumably NotNull("id") is folded away during binding
  // because `id` is a required column -- confirm against the fixture schema.
  EXPECT_THAT(field_ids, ::testing::UnorderedElementsAre(2, 3, 4, 5));
}

} // namespace iceberg
Loading