Skip to content

Commit a58a574

Browse files
committed
test: add test for select after deduplicated join on DataFrame
- Verify that selecting columns works correctly after a join with deduplicate=True
- Confirm the joined DataFrame contains only matching IDs
- Test selecting single and multiple columns post-join to ensure correct data retrieval
1 parent 42f5a72 commit a58a574

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

python/tests/test_dataframe.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2640,29 +2640,31 @@ def trigger_interrupt():
26402640
def test_join_deduplicate_select():
26412641
"""Test that select works correctly after a deduplicated join."""
26422642
ctx = SessionContext()
2643-
2643+
26442644
left_df = ctx.from_pydict({"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"]})
2645-
right_df = ctx.from_pydict({"id": [2, 3, 4], "city": ["New York", "London", "Paris"]})
2646-
2645+
right_df = ctx.from_pydict(
2646+
{"id": [2, 3, 4], "city": ["New York", "London", "Paris"]}
2647+
)
2648+
26472649
# Join and select the id column to confirm it works
2648-
joined_df = left_df.join(right_df, on="id")
2650+
joined_df = left_df.join(right_df, on="id", deduplicate=True)
26492651
selected_df = joined_df.select(column("id"))
26502652
result = selected_df.collect()[0]
2651-
2653+
26522654
# Should have only the matching ids (2, 3)
26532655
expected_ids = [2, 3]
26542656
assert result.column(0).to_pylist() == expected_ids
2655-
2657+
26562658
# Also test selecting multiple columns
26572659
multi_select_df = joined_df.select(column("id"), column("name"), column("city"))
26582660
multi_result = multi_select_df.collect()[0]
2659-
2661+
26602662
expected_data = {
26612663
"id": [2, 3],
2662-
"name": ["Bob", "Charlie"],
2663-
"city": ["New York", "London"]
2664+
"name": ["Bob", "Charlie"],
2665+
"city": ["New York", "London"],
26642666
}
2665-
2667+
26662668
assert multi_result.column(0).to_pylist() == expected_data["id"]
26672669
assert multi_result.column(1).to_pylist() == expected_data["name"]
26682670
assert multi_result.column(2).to_pylist() == expected_data["city"]

0 commit comments

Comments
 (0)