@compile_node.register
def compile_isin(self, node: nodes.InNode):
    """Compile an ``InNode`` as a left join that appends a boolean indicator.

    All columns of the compiled left child are passed through, and one extra
    column named ``node.indicator_col.sql`` is appended. That column is
    ``True`` for left rows whose key matches a key of the right child and
    ``False`` for rows without a match.

    Args:
        node: The node to compile. This method reads ``left_child``,
            ``right_child``, ``left_col``, ``right_col``, ``indicator_col``
            and ``joins_nulls`` from it.

    Returns:
        The joined frame, restricted to the left child's columns followed by
        the indicator column.
    """
    left = self.compile_node(node.left_child)
    right = self.compile_node(node.right_child)

    # The join compares the expressions returned by _coerce_comparables, not
    # the raw columns (presumably so that both sides share a comparable type
    # -- confirm against the lowering module).
    left_ex, right_ex = lowering._coerce_comparables(node.left_col, node.right_col)
    left_key = self.expr_compiler.compile_expression(left_ex)
    right_key = self.expr_compiler.compile_expression(right_ex)

    right_key_name = node.right_col.id.sql
    indicator_name = node.indicator_col.sql

    # De-duplicate on the key that is actually joined on. De-duplicating the
    # raw column is not enough: if coercion maps two distinct raw values to
    # the same key, a left row would match several right rows and be emitted
    # more than once. Only the key column is kept; no other right-side column
    # is needed for a membership test.
    lookup = (
        right.select(right_key.alias(right_key_name))
        .unique()
        .with_columns(pl.lit(True).alias(indicator_name))
    )

    joined = left.join(
        lookup,
        how="left",
        left_on=left_key,
        right_on=pl.col(right_key_name),
        # Note: join_nulls renamed to nulls_equal for polars 1.24
        join_nulls=node.joins_nulls,  # type: ignore
        coalesce=False,
    )
    passthrough = [pl.col(name) for name in left.columns]
    # Left rows without a match get a null indicator from the left join;
    # report those as False.
    indicator = pl.col(indicator_name).fill_null(False)
    return joined.select((*passthrough, indicator))
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize(
    "left_key,right_key",
    [
        ("int64_col", "float64_col"),
        ("float64_col", "int64_col"),
        ("int64_too", "int64_col"),
    ],
)
def test_engines_isin(
    scalars_array_value: array_value.ArrayValue, engine, left_key, right_key
):
    """Run ``isin`` of the scalars value against itself on each engine.

    For every (left_key, right_key) pair, the node produced by
    ``ArrayValue.isin`` is handed to ``assert_equivalence_execution`` together
    with ``REFERENCE_ENGINE`` and the engine under test. Judging by the column
    names, the pairs cover int/float, float/int and int/int keys.
    """
    isin_value, _ = scalars_array_value.isin(
        scalars_array_value,
        lcol=left_key,
        rcol=right_key,
    )
    isin_node = isin_value.node

    assert_equivalence_execution(isin_node, REFERENCE_ENGINE, engine)