Skip to content

Commit 2bd6ceb

Browse files
committed
Move actual implementation of upsert from Table to Transaction
1 parent c06e320 commit 2bd6ceb

File tree

1 file changed

+113
-66
lines changed

1 file changed

+113
-66
lines changed

pyiceberg/table/__init__.py

Lines changed: 113 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,115 @@ def delete(
685685
if not delete_snapshot.files_affected and not delete_snapshot.rewrites_needed:
686686
warnings.warn("Delete operation did not match any records")
687687

688+
def upsert(
    self,
    df: pa.Table,
    join_cols: Optional[List[str]] = None,
    when_matched_update_all: bool = True,
    when_not_matched_insert_all: bool = True,
    case_sensitive: bool = True,
) -> UpsertResult:
    """Shorthand API for performing an upsert to an iceberg table.

    Args:

        df: The input dataframe to upsert with the table's data.
        join_cols: Columns to join on, if not provided, it will use the identifier-field-ids.
        when_matched_update_all: Bool indicating to update rows that are matched but require an update due to a value in a non-key column changing
        when_not_matched_insert_all: Bool indicating new rows to be inserted that do not match any existing rows in the table
        case_sensitive: Bool indicating if the match should be case-sensitive

    To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids

    Example Use Cases:
        Case 1: Both Parameters = True (Full Upsert)
        Existing row found → Update it
        New row found → Insert it

        Case 2: when_matched_update_all = False, when_not_matched_insert_all = True
        Existing row found → Do nothing (no updates)
        New row found → Insert it

        Case 3: when_matched_update_all = True, when_not_matched_insert_all = False
        Existing row found → Update it
        New row found → Do nothing (no inserts)

        Case 4: Both Parameters = False (No Merge Effect)
        Existing row found → Do nothing
        New row found → Do nothing
        (Function effectively does nothing)


    Raises:
        ValueError: If no join columns can be determined, both upsert options are
            disabled, or the source dataframe contains duplicate key rows.

    Returns:
        An UpsertResult class (contains details of rows updated and inserted)
    """
    try:
        import pyarrow as pa  # noqa: F401
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

    from pyiceberg.io.pyarrow import expression_to_pyarrow
    from pyiceberg.table import upsert_util

    if join_cols is None:
        join_cols = []
        # Fix from the Table→Transaction move: the identifier fields live on the
        # *table* schema, not on the incoming pyarrow dataframe (pa.Schema has no
        # identifier_field_ids / find_column_name).
        for field_id in self.schema().identifier_field_ids:
            col = self.schema().find_column_name(field_id)
            if col is not None:
                join_cols.append(col)
            else:
                # Report the field-id that failed to resolve, not the list of
                # columns accumulated so far.
                raise ValueError(f"Field-ID could not be found: {field_id}")

    if len(join_cols) == 0:
        raise ValueError("Join columns could not be found, please set identifier-field-ids or pass in explicitly.")

    if not when_matched_update_all and not when_not_matched_insert_all:
        raise ValueError("no upsert options selected...exiting")

    if upsert_util.has_duplicate_rows(df, join_cols):
        raise ValueError("Duplicate rows found in source dataset based on the key columns. No upsert executed")

    from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible

    downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
    # Fix: validate the incoming dataframe against the *table* schema. The moved
    # code passed df.schema for both arguments, comparing the dataframe with
    # itself — a check that can never fail.
    _check_pyarrow_schema_compatible(
        self.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
    )

    # get list of rows that exist so we don't have to load the entire target table
    matched_predicate = upsert_util.create_match_filter(df, join_cols)
    # Fix: scan the underlying target table, not the input dataframe
    # (pa.Table has no .scan()). NOTE(review): reads the table's committed
    # state; uncommitted writes staged in this transaction are not visible.
    matched_iceberg_table = self._table.scan(row_filter=matched_predicate, case_sensitive=case_sensitive).to_arrow()

    update_row_cnt = 0
    insert_row_cnt = 0

    if when_matched_update_all:
        # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed
        # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed
        # this extra step avoids unnecessary IO and writes
        rows_to_update = upsert_util.get_rows_to_update(df, matched_iceberg_table, join_cols)

        update_row_cnt = len(rows_to_update)

        if len(rows_to_update) > 0:
            # build the match predicate filter
            overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols)

            self.overwrite(rows_to_update, overwrite_filter=overwrite_mask_predicate)

    if when_not_matched_insert_all:
        expr_match = upsert_util.create_match_filter(matched_iceberg_table, join_cols)
        # Fix: bind against the table schema (as the pre-move Table.upsert did),
        # not the pyarrow dataframe schema.
        expr_match_bound = bind(self.schema(), expr_match, case_sensitive=case_sensitive)
        expr_match_arrow = expression_to_pyarrow(expr_match_bound)
        rows_to_insert = df.filter(~expr_match_arrow)

        insert_row_cnt = len(rows_to_insert)

        if insert_row_cnt > 0:
            self.append(rows_to_insert)

    return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt)
796+
688797
def add_files(
689798
self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True
690799
) -> None:
@@ -1149,73 +1258,11 @@ def upsert(
11491258
Returns:
11501259
An UpsertResult class (contains details of rows updated and inserted)
11511260
"""
1152-
try:
1153-
import pyarrow as pa # noqa: F401
1154-
except ModuleNotFoundError as e:
1155-
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
1156-
1157-
from pyiceberg.io.pyarrow import expression_to_pyarrow
1158-
from pyiceberg.table import upsert_util
1159-
1160-
if join_cols is None:
1161-
join_cols = []
1162-
for field_id in self.schema().identifier_field_ids:
1163-
col = self.schema().find_column_name(field_id)
1164-
if col is not None:
1165-
join_cols.append(col)
1166-
else:
1167-
raise ValueError(f"Field-ID could not be found: {join_cols}")
1168-
1169-
if len(join_cols) == 0:
1170-
raise ValueError("Join columns could not be found, please set identifier-field-ids or pass in explicitly.")
1171-
1172-
if not when_matched_update_all and not when_not_matched_insert_all:
1173-
raise ValueError("no upsert options selected...exiting")
1174-
1175-
if upsert_util.has_duplicate_rows(df, join_cols):
1176-
raise ValueError("Duplicate rows found in source dataset based on the key columns. No upsert executed")
1177-
1178-
from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible
1179-
1180-
downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
1181-
_check_pyarrow_schema_compatible(
1182-
self.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
1183-
)
1184-
1185-
# get list of rows that exist so we don't have to load the entire target table
1186-
matched_predicate = upsert_util.create_match_filter(df, join_cols)
1187-
matched_iceberg_table = self.scan(row_filter=matched_predicate, case_sensitive=case_sensitive).to_arrow()
1188-
1189-
update_row_cnt = 0
1190-
insert_row_cnt = 0
1191-
11921261
with self.transaction() as tx:
1193-
if when_matched_update_all:
1194-
# function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed
1195-
# we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed
1196-
# this extra step avoids unnecessary IO and writes
1197-
rows_to_update = upsert_util.get_rows_to_update(df, matched_iceberg_table, join_cols)
1198-
1199-
update_row_cnt = len(rows_to_update)
1200-
1201-
if len(rows_to_update) > 0:
1202-
# build the match predicate filter
1203-
overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols)
1204-
1205-
tx.overwrite(rows_to_update, overwrite_filter=overwrite_mask_predicate)
1206-
1207-
if when_not_matched_insert_all:
1208-
expr_match = upsert_util.create_match_filter(matched_iceberg_table, join_cols)
1209-
expr_match_bound = bind(self.schema(), expr_match, case_sensitive=case_sensitive)
1210-
expr_match_arrow = expression_to_pyarrow(expr_match_bound)
1211-
rows_to_insert = df.filter(~expr_match_arrow)
1212-
1213-
insert_row_cnt = len(rows_to_insert)
1214-
1215-
if insert_row_cnt > 0:
1216-
tx.append(rows_to_insert)
1217-
1218-
return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt)
1262+
return tx.upsert(
1263+
df=df, join_cols=join_cols, when_matched_update_all=when_matched_update_all, when_not_matched_insert_all=when_not_matched_insert_all,
1264+
case_sensitive=case_sensitive
1265+
)
12191266

12201267
def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
12211268
"""

0 commit comments

Comments
 (0)