Skip to content

Commit bd9a6fb

Browse files
committed
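Fix Hive client: fetch tables via get_table_req / get_table_objects_by_name_req instead of get_table_objects_by_name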
1 parent 5cf26a9 commit bd9a6fb

File tree

2 files changed

+51
-26
lines changed

2 files changed

+51
-26
lines changed

pyiceberg/catalog/hive.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@
3838
CheckLockRequest,
3939
EnvironmentContext,
4040
FieldSchema,
41+
GetTableRequest,
42+
GetTableResult,
43+
GetTablesRequest,
44+
GetTablesResult,
4145
InvalidOperationException,
4246
LockComponent,
4347
LockLevel,
@@ -389,7 +393,10 @@ def _create_hive_table(self, open_client: Client, hive_table: HiveTable) -> None
389393

390394
def _get_hive_table(self, open_client: Client, database_name: str, table_name: str) -> HiveTable:
391395
try:
392-
return open_client.get_table_objects_by_name(dbname=database_name, tbl_names=[table_name]).pop()
396+
get_table_result: GetTableResult = open_client.get_table_req(
397+
req=GetTableRequest(dbName=database_name, tblName=table_name)
398+
)
399+
return get_table_result.table
393400
except IndexError as e:
394401
raise NoSuchTableError(f"Table does not exists: {table_name}") from e
395402

@@ -436,7 +443,7 @@ def create_table(
436443
with self._client as open_client:
437444
self._create_hive_table(open_client, tbl)
438445
try:
439-
hive_table = open_client.get_table_objects_by_name(dbname=database_name, tbl_names=[table_name]).pop()
446+
hive_table = self._get_hive_table(open_client, database_name, table_name)
440447
except IndexError as e:
441448
raise NoSuchObjectException("get_table failed: unknown result") from e
442449

@@ -469,7 +476,7 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location:
469476
with self._client as open_client:
470477
self._create_hive_table(open_client, tbl)
471478
try:
472-
hive_table = open_client.get_table_objects_by_name(dbname=database_name, tbl_names=[table_name]).pop()
479+
hive_table = self._get_hive_table(open_client, database_name, table_name)
473480
except IndexError as e:
474481
raise NoSuchObjectException("get_table failed: unknown result") from e
475482

@@ -663,7 +670,7 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U
663670
try:
664671
with self._client as open_client:
665672
try:
666-
tbl = open_client.get_table_objects_by_name(dbname=from_database_name, tbl_names=[from_table_name]).pop()
673+
tbl = self._get_hive_table(open_client, from_database_name, from_table_name)
667674
except IndexError as e:
668675
raise NoSuchObjectException("get_table failed: unknown result") from e
669676
tbl.dbName = to_database_name
@@ -735,11 +742,13 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
735742
"""
736743
database_name = self.identifier_to_database(namespace, NoSuchNamespaceError)
737744
with self._client as open_client:
745+
all_table_names = open_client.get_all_tables(db_name=database_name)
746+
get_tables_result: GetTablesResult = open_client.get_table_objects_by_name_req(
747+
req=GetTablesRequest(dbName=database_name, tblNames=all_table_names)
748+
)
738749
return [
739750
(database_name, table.tableName)
740-
for table in open_client.get_table_objects_by_name(
741-
dbname=database_name, tbl_names=open_client.get_all_tables(db_name=database_name)
742-
)
751+
for table in get_tables_result.tables
743752
if table.parameters.get(TABLE_TYPE, "").lower() == ICEBERG
744753
]
745754

tests/catalog/test_hive.py

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131
AlreadyExistsException,
3232
EnvironmentContext,
3333
FieldSchema,
34+
GetTableRequest,
35+
GetTableResult,
36+
GetTablesRequest,
37+
GetTablesResult,
3438
InvalidOperationException,
3539
LockResponse,
3640
LockState,
@@ -280,7 +284,7 @@ def test_create_table(
280284

281285
catalog._client = MagicMock()
282286
catalog._client.__enter__().create_table.return_value = None
283-
catalog._client.__enter__().get_table_objects_by_name.return_value = [hive_table]
287+
catalog._client.__enter__().get_table_req.return_value = GetTableResult(table=hive_table)
284288
catalog._client.__enter__().get_database.return_value = hive_database
285289
catalog.create_table(("default", "table"), schema=table_schema_with_all_types, properties={"owner": "javaberg"})
286290

@@ -459,7 +463,7 @@ def test_create_table_with_given_location_removes_trailing_slash(
459463

460464
catalog._client = MagicMock()
461465
catalog._client.__enter__().create_table.return_value = None
462-
catalog._client.__enter__().get_table_objects_by_name.return_value = [hive_table]
466+
catalog._client.__enter__().get_table_req.return_value = GetTableResult(table=hive_table)
463467
catalog._client.__enter__().get_database.return_value = hive_database
464468
catalog.create_table(
465469
("default", "table"), schema=table_schema_with_all_types, properties={"owner": "javaberg"}, location=f"{location}/"
@@ -633,7 +637,7 @@ def test_create_v1_table(table_schema_simple: Schema, hive_database: HiveDatabas
633637

634638
catalog._client = MagicMock()
635639
catalog._client.__enter__().create_table.return_value = None
636-
catalog._client.__enter__().get_table_objects_by_name.return_value = [hive_table]
640+
catalog._client.__enter__().get_table_req.return_value = GetTableResult(table=hive_table)
637641
catalog._client.__enter__().get_database.return_value = hive_database
638642
catalog.create_table(
639643
("default", "table"), schema=table_schema_simple, properties={"owner": "javaberg", "format-version": "1"}
@@ -684,10 +688,10 @@ def test_load_table(hive_table: HiveTable) -> None:
684688
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
685689

686690
catalog._client = MagicMock()
687-
catalog._client.__enter__().get_table_objects_by_name.return_value = [hive_table]
691+
catalog._client.__enter__().get_table_req.return_value = GetTableResult(table=hive_table)
688692
table = catalog.load_table(("default", "new_tabl2e"))
689693

690-
catalog._client.__enter__().get_table_objects_by_name.assert_called_with(dbname="default", tbl_names=["new_tabl2e"])
694+
catalog._client.__enter__().get_table_req.assert_called_with(req=GetTableRequest(dbName="default", tblName="new_tabl2e"))
691695

692696
expected = TableMetadataV2(
693697
location="s3://bucket/test/location",
@@ -784,11 +788,11 @@ def test_load_table_from_self_identifier(hive_table: HiveTable) -> None:
784788
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
785789

786790
catalog._client = MagicMock()
787-
catalog._client.__enter__().get_table_objects_by_name.side_effect = lambda dbname, tbl_names: [hive_table]
791+
catalog._client.__enter__().get_table_req.return_value = GetTableResult(table=hive_table)
788792
intermediate = catalog.load_table(("default", "new_tabl2e"))
789793
table = catalog.load_table(intermediate.name())
790794

791-
catalog._client.__enter__().get_table_objects_by_name.assert_called_with(dbname="default", tbl_names=["new_tabl2e"])
795+
catalog._client.__enter__().get_table_req.assert_called_with(req=GetTableRequest(dbName="default", tblName="new_tabl2e"))
792796

793797
expected = TableMetadataV2(
794798
location="s3://bucket/test/location",
@@ -889,7 +893,10 @@ def test_rename_table(hive_table: HiveTable) -> None:
889893
renamed_table.tableName = "new_tabl3e"
890894

891895
catalog._client = MagicMock()
892-
catalog._client.__enter__().get_table_objects_by_name.side_effect = [[hive_table], [renamed_table]]
896+
catalog._client.__enter__().get_table_req.side_effect = [
897+
GetTableResult(table=hive_table),
898+
GetTableResult(table=renamed_table),
899+
]
893900
catalog._client.__enter__().alter_table_with_environment_context.return_value = None
894901

895902
from_identifier = ("default", "new_tabl2e")
@@ -898,8 +905,11 @@ def test_rename_table(hive_table: HiveTable) -> None:
898905

899906
assert table.name() == to_identifier
900907

901-
calls = [call(dbname="default", tbl_names=["new_tabl2e"]), call(dbname="default", tbl_names=["new_tabl3e"])]
902-
catalog._client.__enter__().get_table_objects_by_name.assert_has_calls(calls)
908+
expected_calls = [
909+
call(req=GetTableRequest(dbName="default", tblName="new_tabl2e")),
910+
call(req=GetTableRequest(dbName="default", tblName="new_tabl3e")),
911+
]
912+
catalog._client.__enter__().get_table_req.assert_has_calls(expected_calls)
903913
catalog._client.__enter__().alter_table_with_environment_context.assert_called_with(
904914
dbname="default",
905915
tbl_name="new_tabl2e",
@@ -912,25 +922,31 @@ def test_rename_table_from_self_identifier(hive_table: HiveTable) -> None:
912922
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
913923

914924
catalog._client = MagicMock()
915-
catalog._client.__enter__().get_table_objects_by_name.return_value = [hive_table]
925+
catalog._client.__enter__().get_table_req.return_value = GetTableResult(table=hive_table)
916926

917927
from_identifier = ("default", "new_tabl2e")
918928
from_table = catalog.load_table(from_identifier)
919-
catalog._client.__enter__().get_table_objects_by_name.assert_called_with(dbname="default", tbl_names=["new_tabl2e"])
929+
catalog._client.__enter__().get_table_req.assert_called_with(req=GetTableRequest(dbName="default", tblName="new_tabl2e"))
920930

921931
renamed_table = copy.deepcopy(hive_table)
922932
renamed_table.dbName = "default"
923933
renamed_table.tableName = "new_tabl3e"
924934

925-
catalog._client.__enter__().get_table_objects_by_name.side_effect = [[hive_table], [renamed_table]]
935+
catalog._client.__enter__().get_table_req.side_effect = [
936+
GetTableResult(table=hive_table),
937+
GetTableResult(table=renamed_table),
938+
]
926939
catalog._client.__enter__().alter_table_with_environment_context.return_value = None
927940
to_identifier = ("default", "new_tabl3e")
928941
table = catalog.rename_table(from_table.name(), to_identifier)
929942

930943
assert table.name() == to_identifier
931944

932-
calls = [call(dbname="default", tbl_names=["new_tabl2e"]), call(dbname="default", tbl_names=["new_tabl3e"])]
933-
catalog._client.__enter__().get_table_objects_by_name.assert_has_calls(calls)
945+
expected_calls = [
946+
call(req=GetTableRequest(dbName="default", tblName="new_tabl2e")),
947+
call(req=GetTableRequest(dbName="default", tblName="new_tabl3e")),
948+
]
949+
catalog._client.__enter__().get_table_req.assert_has_calls(expected_calls)
934950
catalog._client.__enter__().alter_table_with_environment_context.assert_called_with(
935951
dbname="default",
936952
tbl_name="new_tabl2e",
@@ -1013,13 +1029,13 @@ def test_list_tables(hive_table: HiveTable) -> None:
10131029

10141030
catalog._client = MagicMock()
10151031
catalog._client.__enter__().get_all_tables.return_value = ["table1", "table2", "table3", "table4"]
1016-
catalog._client.__enter__().get_table_objects_by_name.return_value = [tbl1, tbl2, tbl3, tbl4]
1032+
catalog._client.__enter__().get_table_objects_by_name_req.return_value = GetTablesResult(tables=[tbl1, tbl2, tbl3, tbl4])
10171033

10181034
got_tables = catalog.list_tables("database")
10191035
assert got_tables == [("database", "table1"), ("database", "table2")]
10201036
catalog._client.__enter__().get_all_tables.assert_called_with(db_name="database")
1021-
catalog._client.__enter__().get_table_objects_by_name.assert_called_with(
1022-
dbname="database", tbl_names=["table1", "table2", "table3", "table4"]
1037+
catalog._client.__enter__().get_table_objects_by_name_req.assert_called_with(
1038+
req=GetTablesRequest(dbName="database", tblNames=["table1", "table2", "table3", "table4"])
10231039
)
10241040

10251041

@@ -1049,7 +1065,7 @@ def test_drop_table_from_self_identifier(hive_table: HiveTable) -> None:
10491065
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
10501066

10511067
catalog._client = MagicMock()
1052-
catalog._client.__enter__().get_table_objects_by_name.return_value = [hive_table]
1068+
catalog._client.__enter__().get_table_req.return_value = GetTableResult(table=hive_table)
10531069
table = catalog.load_table(("default", "new_tabl2e"))
10541070

10551071
catalog._client.__enter__().get_all_databases.return_value = ["namespace1", "namespace2"]

0 commit comments

Comments
 (0)