diff --git a/pyiceberg/catalog/__init__.py b/pyiceberg/catalog/__init__.py
index a143ad09e5..44d8c6ec8b 100644
--- a/pyiceberg/catalog/__init__.py
+++ b/pyiceberg/catalog/__init__.py
@@ -26,6 +26,7 @@
 from enum import Enum
 from typing import (
     TYPE_CHECKING,
+    Any,
     Callable,
     Dict,
     List,
@@ -793,6 +794,33 @@ def _delete_old_metadata(io: FileIO, base: TableMetadata, metadata: TableMetadat
             removed_previous_metadata_files.difference_update(current_metadata_files)
             delete_files(io, removed_previous_metadata_files, METADATA)
 
+    def close(self) -> None:  # noqa: B027
+        """Close the catalog and release any resources.
+
+        This method should be called when the catalog is no longer needed to ensure
+        proper cleanup of resources like database connections, file handles, etc.
+
+        Default implementation does nothing. Override in subclasses that need cleanup.
+        """
+
+    def __enter__(self) -> "Catalog":
+        """Enter the context manager.
+
+        Returns:
+            Catalog: The catalog instance.
+        """
+        return self
+
+    def __exit__(self, exc_type: Optional[type], exc_val: Optional[BaseException], exc_tb: Optional[Any]) -> None:
+        """Exit the context manager and close the catalog.
+
+        Args:
+            exc_type: Exception type if an exception occurred.
+            exc_val: Exception value if an exception occurred.
+            exc_tb: Exception traceback if an exception occurred.
+        """
+        self.close()
+
     def __repr__(self) -> str:
         """Return the string representation of the Catalog class."""
         return f"{self.name} ({self.__class__})"
diff --git a/pyiceberg/catalog/sql.py b/pyiceberg/catalog/sql.py
index 0167b5a1c1..c0746dc983 100644
--- a/pyiceberg/catalog/sql.py
+++ b/pyiceberg/catalog/sql.py
@@ -733,3 +733,14 @@ def view_exists(self, identifier: Union[str, Identifier]) -> bool:
 
     def drop_view(self, identifier: Union[str, Identifier]) -> None:
         raise NotImplementedError
+
+    def close(self) -> None:
+        """Close the catalog and release database connections.
+
+        This method closes the SQLAlchemy engine and disposes of all connection pools.
+        This ensures that any cached connections are properly closed, which is especially
+        important for blobfuse scenarios where file handles need to be closed for
+        data to be flushed to persistent storage.
+        """
+        if hasattr(self, "engine"):
+            self.engine.dispose()
diff --git a/tests/catalog/test_base.py b/tests/catalog/test_base.py
index 01fea25dfa..80c01f70fa 100644
--- a/tests/catalog/test_base.py
+++ b/tests/catalog/test_base.py
@@ -48,7 +48,7 @@
 )
 from pyiceberg.transforms import IdentityTransform
 from pyiceberg.typedef import EMPTY_DICT, Properties
-from pyiceberg.types import IntegerType, LongType, NestedField
+from pyiceberg.types import IntegerType, LongType, NestedField, StringType
 
 
 @pytest.fixture
@@ -631,3 +631,26 @@ def test_table_metadata_writes_reflect_latest_path(catalog: InMemoryCatalog) ->
     table.transaction().set_properties({TableProperties.WRITE_METADATA_PATH: new_metadata_path}).commit_transaction()
 
     assert table.location_provider().new_metadata_location("metadata.json") == f"{new_metadata_path}/metadata.json"
+
+
+class TestCatalogClose:
+    """Test catalog close functionality."""
+
+    def test_in_memory_catalog_close(self, catalog: InMemoryCatalog) -> None:
+        """Test that InMemoryCatalog close method works."""
+        # Should not raise any exception
+        catalog.close()
+
+    def test_in_memory_catalog_context_manager(self, catalog: InMemoryCatalog) -> None:
+        """Test that InMemoryCatalog works as a context manager."""
+        with InMemoryCatalog("test") as cat:
+            assert cat.name == "test"
+            # Create a namespace and table to test functionality
+            cat.create_namespace("test_db")
+            schema = Schema(NestedField(1, "name", StringType(), required=True))
+            cat.create_table(("test_db", "test_table"), schema)
+            pool_before_close = cat.engine.pool
+
+        # InMemoryCatalog inherits close() from SqlCatalog. Engine.dispose() replaces the
+        # connection pool, so a new pool object shows that close() actually ran on exit.
+        assert cat.engine.pool is not pool_before_close
diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py
index 235951484f..27105e8004 100644
--- a/tests/catalog/test_sql.py
+++ b/tests/catalog/test_sql.py
@@ -60,7 +60,7 @@
 )
 from pyiceberg.transforms import IdentityTransform
 from pyiceberg.typedef import Identifier
-from pyiceberg.types import IntegerType, strtobool
+from pyiceberg.types import IntegerType, NestedField, StringType, strtobool
 
 CATALOG_TABLES = [c.__tablename__ for c in SqlCatalogBaseTable.__subclasses__()]
 
@@ -1704,3 +1704,60 @@ def test_delete_metadata_multiple(catalog: SqlCatalog, table_schema_nested: Sche
     assert not os.path.exists(original_metadata_location[len("file://") :])
     assert not os.path.exists(updated_metadata_1.metadata_file[len("file://") :])
     assert os.path.exists(updated_metadata_2.metadata_file[len("file://") :])
+
+
+class TestSqlCatalogClose:
+    """Test SqlCatalog close functionality."""
+
+    def test_sql_catalog_close(self, catalog_sqlite: SqlCatalog) -> None:
+        """Test that SqlCatalog close method properly disposes the engine."""
+        # Engine.dispose() discards the current connection pool and installs a new one,
+        # so a changed pool identity is observable proof that dispose() was called.
+        pool_before_close = catalog_sqlite.engine.pool
+
+        # Close the catalog
+        catalog_sqlite.close()
+
+        # Verify engine is disposed: it still exists, but with a fresh pool
+        assert hasattr(catalog_sqlite, "engine")
+        assert catalog_sqlite.engine.pool is not pool_before_close
+
+    def test_sql_catalog_context_manager(self, warehouse: Path) -> None:
+        """Test that SqlCatalog works as a context manager."""
+        with SqlCatalog("test", uri="sqlite:///:memory:", warehouse=str(warehouse)) as catalog:
+            # Verify engine exists
+            assert hasattr(catalog, "engine")
+
+            # Create a namespace and table to test functionality
+            catalog.create_namespace("test_db")
+            schema = Schema(NestedField(1, "name", StringType(), required=True))
+            catalog.create_table(("test_db", "test_table"), schema)
+            pool_before_close = catalog.engine.pool
+
+        # Verify engine is disposed after exiting context (dispose() replaces the pool)
+        assert catalog.engine.pool is not pool_before_close
+
+    def test_sql_catalog_context_manager_with_exception(self) -> None:
+        """Test that SqlCatalog context manager properly closes even with exceptions."""
+        catalog = None
+        pool_before_close = None
+        try:
+            with SqlCatalog("test", uri="sqlite:///:memory:") as cat:
+                catalog = cat
+                pool_before_close = cat.engine.pool
+                raise ValueError("Test exception")
+        except ValueError:
+            pass
+
+        # Verify engine is disposed even after exception (dispose() replaces the pool)
+        assert catalog is not None
+        assert pool_before_close is not None
+        assert catalog.engine.pool is not pool_before_close
+
+    def test_sql_catalog_multiple_close_calls(self, catalog_sqlite: SqlCatalog) -> None:
+        """Test that multiple close calls on SqlCatalog are safe."""
+        # First close
+        catalog_sqlite.close()
+
+        # Second close should not raise an exception
+        catalog_sqlite.close()