diff --git a/README.md b/README.md index 3ee9120..31074e1 100644 --- a/README.md +++ b/README.md @@ -10,15 +10,16 @@ A Snakemake storage plugin for downloading files via HTTP with local caching, ch **Supported sources:** - **zenodo.org** - Zenodo data repository (checksum from API) - **data.pypsa.org** - PyPSA data repository (checksum from manifest.yaml) +- **storage.googleapis.com** - Google Cloud Storage (checksum from GCS JSON API) ## Features - **Local caching**: Downloads are cached to avoid redundant transfers (can be disabled) -- **Checksum verification**: Automatically verifies checksums (from Zenodo API or data.pypsa.org manifests) +- **Checksum verification**: Automatically verifies checksums (from Zenodo API, data.pypsa.org manifests, or GCS object metadata) - **Rate limit handling**: Automatically respects Zenodo's rate limits using `X-RateLimit-*` headers with exponential backoff retry -- **Concurrent download control**: Limits simultaneous downloads to prevent overwhelming Zenodo +- **Concurrent download control**: Limits simultaneous downloads to prevent overwhelming servers - **Progress bars**: Shows download progress with tqdm -- **Immutable URLs**: Returns mtime=0 since Zenodo URLs are persistent +- **Immutable URLs**: Returns mtime=0 for Zenodo and data.pypsa.org (persistent URLs); uses actual mtime for GCS - **Environment variable support**: Configure via environment variables for CI/CD workflows ## Installation @@ -66,7 +67,7 @@ If you don't explicitly configure it, the plugin will use default settings autom ## Usage -Use Zenodo or data.pypsa.org URLs directly in your rules. Snakemake automatically detects supported URLs and routes them to this plugin: +Use Zenodo, data.pypsa.org, or Google Cloud Storage URLs directly in your rules. Snakemake automatically detects supported URLs and routes them to this plugin: ```python rule download_zenodo: @@ -84,6 +85,14 @@ rule download_pypsa: "resources/eez.zip" shell: "cp {input} {output}" + +rule download_gcs: + input: + storage("https://storage.googleapis.com/open-tyndp-data-store/CBA_projects.zip"), + output: + "resources/cba_projects.zip" + shell: + "cp {input} {output}" ``` Or if you configured a tagged storage entity: @@ -107,7 +116,7 @@ The plugin will: - Progress bar showing download status - Automatic rate limit handling with exponential backoff retry - Concurrent download limiting - - Checksum verification (from Zenodo API or data.pypsa.org manifest) + - Checksum verification (from Zenodo API, data.pypsa.org manifest, or GCS metadata) 4. Store in cache for future use (if caching is enabled) ### Example: CI/CD Configuration @@ -139,19 +148,19 @@ The plugin automatically: ## URL Handling -- Handles URLs from `zenodo.org`, `sandbox.zenodo.org`, and `data.pypsa.org` +- Handles URLs from `zenodo.org`, `sandbox.zenodo.org`, `data.pypsa.org`, and `storage.googleapis.com` - Other HTTP(S) URLs are handled by the standard `snakemake-storage-plugin-http` - Both plugins can coexist in the same workflow ### Plugin Priority When using `storage()` without specifying a plugin name, Snakemake checks all installed plugins: -- **Cached HTTP plugin**: Only accepts zenodo.org and data.pypsa.org URLs +- **Cached HTTP plugin**: Only accepts zenodo.org, data.pypsa.org, and storage.googleapis.com URLs - **HTTP plugin**: Accepts all HTTP/HTTPS URLs (including zenodo.org) If both plugins are installed, supported URLs would be ambiguous - both plugins accept them. 
Typically snakemake would raise an error: **"Multiple suitable storage providers found"** if you try to use `storage()` without specifying which plugin to use, ie. one needs to explicitly call the Cached HTTP provider using `storage.cached_http(url)` instead of `storage(url)`, -but we monkey-patch the http plugin to refuse zenodo.org and data.pypsa.org URLs. +but we monkey-patch the http plugin to refuse zenodo.org, data.pypsa.org, and storage.googleapis.com URLs. ## License diff --git a/src/snakemake_storage_plugin_cached_http/__init__.py b/src/snakemake_storage_plugin_cached_http/__init__.py index b4dd009..52fec1c 100644 --- a/src/snakemake_storage_plugin_cached_http/__init__.py +++ b/src/snakemake_storage_plugin_cached_http/__init__.py @@ -2,16 +2,19 @@ # # SPDX-License-Identifier: MIT +import asyncio +import base64 import hashlib import json import shutil import time from contextlib import asynccontextmanager from dataclasses import dataclass, field +from datetime import datetime from logging import Logger from pathlib import Path from posixpath import basename, dirname, join, normpath, relpath -from urllib.parse import urlparse +from urllib.parse import quote, urlparse import httpx import platformdirs @@ -81,7 +84,7 @@ class StorageProviderSettings(SettingsBase): "env_var": True, }, ) - max_concurrent_downloads: int | None = field( + max_concurrent_downloads: int = field( default=3, metadata={ "help": "Maximum number of concurrent downloads.", @@ -92,10 +95,11 @@ class StorageProviderSettings(SettingsBase): @dataclass class FileMetadata: - """Metadata for a file in a Zenodo or data.pypsa.org record.""" + """Metadata for a file in a Zenodo, data.pypsa.org, or GCS record.""" checksum: str | None size: int + mtime: float = 0 # modification time (Unix timestamp), used for GCS redirect: str | None = None # used to indicate data.pypsa.org redirection @@ -144,6 +148,7 @@ def __post_init__(self): # Cache for record metadata to avoid repeated API calls self._zenodo_record_cache: dict[str, dict[str, FileMetadata]] = {} self._pypsa_manifest_cache: dict[str, dict[str, FileMetadata]] = {} + self._gcs_metadata_cache: dict[str, FileMetadata] = {} @override def use_rate_limiter(self) -> bool: @@ -173,6 +178,11 @@ def example_queries(cls) -> list[ExampleQuery]: description="A data pypsa file URL", type=QueryType.INPUT, ), + ExampleQuery( + query="https://storage.googleapis.com/open-tyndp-data-store/CBA_projects.zip", + description="A Google Cloud Storage file URL", + type=QueryType.INPUT, + ), ] @override @@ -185,7 +195,7 @@ def is_valid_query(cls, query: str) -> StorageQueryValidationResult: return StorageQueryValidationResult( query=query, valid=False, - reason="Only zenodo.org and data.pypsa.org URLs are handled by this plugin", + reason="Only zenodo.org, data.pypsa.org, and storage.googleapis.com URLs are handled by this plugin", ) @override @@ -288,9 +298,24 @@ async def get_metadata(self, path: str, netloc: str) -> FileMetadata | None: return await self.get_zenodo_metadata(path, netloc) elif netloc == "data.pypsa.org": return await self.get_pypsa_metadata(path, netloc) + elif netloc == "storage.googleapis.com": + return await self.get_gcs_metadata(path, netloc) raise WorkflowError( - "Cached-http storage plugin is only implemented for zenodo.org and data.pypsa.org urls" + "Cached-http storage plugin is only implemented for zenodo.org, data.pypsa.org, and storage.googleapis.com urls" + ) + + @staticmethod + def is_immutable(netloc: str): + if netloc in ("zenodo.org", "sandbox.zenodo.org"): + 
return True + elif netloc == "data.pypsa.org": + return True + elif netloc == "storage.googleapis.com": + return False + + raise WorkflowError( + "Cached-http storage plugin is only implemented for zenodo.org, data.pypsa.org, and storage.googleapis.com urls" ) async def get_zenodo_metadata(self, path: str, netloc: str) -> FileMetadata | None: @@ -407,6 +432,73 @@ async def get_pypsa_metadata(self, path: str, netloc: str) -> FileMetadata | Non filename = relpath(path, base_path) return metadata.get(filename) + async def get_gcs_metadata(self, path: str, netloc: str) -> FileMetadata | None: + """ + Retrieve and cache file metadata from Google Cloud Storage. + + Uses the GCS JSON API to fetch object metadata including MD5 hash. + URL format: https://storage.googleapis.com/{bucket}/{object-path} + API endpoint: https://storage.googleapis.com/storage/v1/b/{bucket}/o/{encoded-object} + + Args: + path: Server path (bucket/object-path) + netloc: Network location (storage.googleapis.com) + + Returns: + FileMetadata for the requested file, or None if not found + """ + # Check cache first + if path in self._gcs_metadata_cache: + return self._gcs_metadata_cache[path] + + # Parse bucket and object path from the URL path + # Path format: /{bucket}/{object-path} + parts = path.split("/", maxsplit=1) + if len(parts) < 2: + raise WorkflowError( + f"Invalid GCS URL format: http(s)://{netloc}/{path}. " + f"Expected format: https://storage.googleapis.com/{{bucket}}/{{object-path}}" + ) + + bucket, object_path = parts + + # URL-encode the object path for the API request (slashes must be encoded) + encoded_object = quote(object_path, safe="") + + # GCS JSON API endpoint for object metadata + api_url = f"https://{netloc}/storage/v1/b/{bucket}/o/{encoded_object}" + + async with self.httpr("get", api_url) as response: + if response.status_code == 404: + return None + if response.status_code != 200: + raise WorkflowError( + f"Failed to fetch GCS object metadata: HTTP {response.status_code} ({api_url})" + ) + + content = await response.aread() + data = json.loads(content) + + # GCS returns MD5 as base64-encoded bytes + md5_base64: str | None = data.get("md5Hash") + checksum: str | None = None + if md5_base64: + # Convert base64 to hex digest + md5_bytes = base64.b64decode(md5_base64) + checksum = f"md5:{md5_bytes.hex()}" + + size: int = int(data.get("size", 0)) + + updated: str | None = data.get("updated") + mtime: float = datetime.fromisoformat(updated).timestamp() if updated else 0 + + metadata = FileMetadata(checksum=checksum, size=size, mtime=mtime) + + # Store in cache + self._gcs_metadata_cache[path] = metadata + + return metadata + # Implementation of storage object class StorageObject(StorageObjectRead): @@ -441,7 +533,7 @@ async def managed_exists(self) -> bool: if self.provider.cache: cached = self.provider.cache.get(str(self.query)) - if cached is not None: + if cached is not None and self.provider.is_immutable(self.netloc): return True metadata = await self.provider.get_metadata(self.path, self.netloc) @@ -449,7 +541,11 @@ async def managed_exists(self) -> bool: @override async def managed_mtime(self) -> float: - return 0 + if self.provider.settings.skip_remote_checks: + return 0 + + metadata = await self.provider.get_metadata(self.path, self.netloc) + return metadata.mtime if metadata is not None else 0 @override async def managed_size(self) -> int: @@ -458,11 +554,20 @@ async def managed_size(self) -> int: if self.provider.cache: cached = self.provider.cache.get(str(self.query)) - if cached is not 
None: + if cached is not None and self.provider.is_immutable(self.netloc): return cached.stat().st_size + else: + cached = None metadata = await self.provider.get_metadata(self.path, self.netloc) - return metadata.size if metadata is not None else 0 + if metadata is None: + return 0 + + if cached is not None: + if cached.stat().st_mtime >= metadata.mtime: + return cached.stat().st_size + + return metadata.size @override async def inventory(self, cache: IOCacheStorageInterface) -> None: @@ -483,17 +588,31 @@ async def inventory(self, cache: IOCacheStorageInterface) -> None: if self.provider.cache: cached = self.provider.cache.get(str(self.query)) - if cached is not None: + if cached is not None and self.provider.is_immutable(self.netloc): cache.exists_in_storage[key] = True - cache.mtime[key] = Mtime(storage=0) + cache.mtime[key] = Mtime(storage=cached.stat().st_mtime) cache.size[key] = cached.stat().st_size return + else: + cached = None metadata = await self.provider.get_metadata(self.path, self.netloc) - exists = metadata is not None - cache.exists_in_storage[key] = exists - cache.mtime[key] = Mtime(storage=0) - cache.size[key] = metadata.size if exists else 0 + if metadata is None: + cache.exists_in_storage[key] = False + cache.mtime[key] = Mtime(storage=0) + cache.size[key] = 0 + return + + if cached is not None: + if cached.stat().st_mtime >= metadata.mtime: + cache.exists_in_storage[key] = True + cache.mtime[key] = Mtime(storage=cached.stat().st_mtime) + cache.size[key] = cached.stat().st_size + return + + cache.exists_in_storage[key] = True + cache.mtime[key] = Mtime(storage=metadata.mtime) + cache.size[key] = metadata.size @override def cleanup(self): @@ -558,17 +677,20 @@ async def managed_retrieve(self): if metadata is not None and metadata.redirect is not None: query = f"https://{self.netloc}/{metadata.redirect}" - # If already in cache, just copy + # If already in cache, check if still valid if self.provider.cache: cached = self.provider.cache.get(query) if cached is not None: - logger.info(f"Retrieved {filename} from cache ({query})") - shutil.copy2(cached, local_path) - return + if self.provider.is_immutable(self.netloc) or ( + metadata is not None and cached.stat().st_mtime >= metadata.mtime + ): + logger.info(f"Retrieved {filename} from cache ({query})") + shutil.copy2(cached, local_path) + return try: - # Download from Zenodo or data.pypsa.org using a get request, rate limit errors are detected and - # raise WorkflowError to trigger a retry + # Download using a get request, rate limit errors are detected and raise + # WorkflowError to trigger a retry async with self.provider.httpr("get", query) as response: if response.status_code != 200: raise WorkflowError( diff --git a/src/snakemake_storage_plugin_cached_http/monkeypatch.py b/src/snakemake_storage_plugin_cached_http/monkeypatch.py index 2036624..c808bb6 100644 --- a/src/snakemake_storage_plugin_cached_http/monkeypatch.py +++ b/src/snakemake_storage_plugin_cached_http/monkeypatch.py @@ -23,6 +23,7 @@ def is_pypsa_or_zenodo_url(url: str) -> bool: "zenodo.org", "sandbox.zenodo.org", "data.pypsa.org", + "storage.googleapis.com", ) and parsed.scheme in ( "http", "https", diff --git a/tests/test_download.py b/tests/test_download.py index 6690542..8ac131d 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -6,6 +6,8 @@ import json import logging +import os +import time import pytest @@ -24,18 +26,33 @@ "path": "records/16810901/files/attributed_ports.json", "netloc": "zenodo.org", "has_size": True, + 
"has_mtime": False, # Zenodo records are immutable }, "pypsa": { "url": "https://data.pypsa.org/workflows/eur/attributed_ports/2020-07-10/attributed_ports.json", "path": "workflows/eur/attributed_ports/2020-07-10/attributed_ports.json", "netloc": "data.pypsa.org", "has_size": False, # data.pypsa.org manifests don't include size + "has_mtime": False, # data.pypsa.org files are immutable + }, + "gcs": { + "url": "https://storage.googleapis.com/open-tyndp-data-store/cached-http/attributed_ports/archive/2020-07-10/attributed_ports.json", + "path": "open-tyndp-data-store/cached-http/attributed_ports/archive/2020-07-10/attributed_ports.json", + "netloc": "storage.googleapis.com", + "has_size": True, + "has_mtime": True, # GCS provides modification timestamps }, } @pytest.fixture -def storage_provider(tmp_path): +def test_logger(): + """Provide a logger for testing.""" + return logging.getLogger("test") + + +@pytest.fixture +def storage_provider(tmp_path, test_logger): """Create a StorageProvider instance for testing.""" cache_dir = tmp_path / "cache" cache_dir.mkdir() @@ -48,20 +65,24 @@ def storage_provider(tmp_path): max_concurrent_downloads=3, ) - logger = logging.getLogger("test") - provider = StorageProvider( local_prefix=local_prefix, - logger=logger, + logger=test_logger, settings=settings, ) return provider -@pytest.fixture(params=["zenodo", "pypsa"]) +@pytest.fixture(params=["zenodo", "pypsa", "gcs"]) def test_config(request): - """Provide test configuration (parametrized for zenodo and pypsa).""" + """Provide test configuration (parametrized for zenodo, pypsa, and gcs).""" + return TEST_CONFIGS[request.param] + + +@pytest.fixture(params=[k for k, v in TEST_CONFIGS.items() if v["has_mtime"]]) +def mutable_test_config(request): + """Provide test configuration for mutable sources only (those with mtime support).""" return TEST_CONFIGS[request.param] @@ -112,10 +133,13 @@ async def test_storage_object_size(storage_object, test_config): @pytest.mark.asyncio -async def test_storage_object_mtime(storage_object): - """Test that mtime is 0 for immutable URLs.""" +async def test_storage_object_mtime(storage_object, test_config): + """Test that mtime is 0 for immutable URLs, non-zero for mutable sources.""" mtime = await storage_object.managed_mtime() - assert mtime == 0 + if test_config["has_mtime"]: + assert mtime > 0 + else: + assert mtime == 0 @pytest.mark.asyncio @@ -180,6 +204,7 @@ async def test_cache_functionality(storage_provider, test_config, tmp_path): local_path2.parent.mkdir(parents=True, exist_ok=True) obj2.local_path = lambda: local_path2 + # Verify no HTTP requests are made (cache hit skips download, metadata is cached) with assert_no_http_requests(storage_provider): await obj2.managed_retrieve() @@ -188,7 +213,7 @@ async def test_cache_functionality(storage_provider, test_config, tmp_path): @pytest.mark.asyncio -async def test_skip_remote_checks(test_config, tmp_path): +async def test_skip_remote_checks(test_config, tmp_path, test_logger): """Test that skip_remote_checks works correctly.""" local_prefix = tmp_path / "local" local_prefix.mkdir() @@ -200,10 +225,9 @@ async def test_skip_remote_checks(test_config, tmp_path): max_concurrent_downloads=3, ) - logger = logging.getLogger("test") provider_skip = StorageProvider( local_prefix=local_prefix, - logger=logger, + logger=test_logger, settings=settings, ) @@ -230,3 +254,85 @@ async def test_wrong_checksum_detection(storage_object, tmp_path): # Verify checksum should raise WrongChecksum with pytest.raises(WrongChecksum): await 
storage_object.verify_checksum(corrupted_path) + + +@pytest.mark.asyncio +async def test_cache_staleness_for_mutable_sources( + mutable_test_config, tmp_path, test_logger +): + """Test that stale cached files are re-downloaded for mutable sources.""" + url = mutable_test_config["url"] + + cache_dir = tmp_path / "cache" + cache_dir.mkdir() + local_prefix = tmp_path / "local" + local_prefix.mkdir() + + settings = StorageProviderSettings( + cache=str(cache_dir), + skip_remote_checks=False, + max_concurrent_downloads=3, + ) + + provider = StorageProvider( + local_prefix=local_prefix, + logger=test_logger, + settings=settings, + ) + + # First download to populate cache + obj1 = StorageObject( + query=url, + keep_local=False, + retrieve=True, + provider=provider, + ) + + local_path1 = tmp_path / "download1" / "testfile" + local_path1.parent.mkdir(parents=True, exist_ok=True) + obj1.local_path = lambda: local_path1 + + await obj1.managed_retrieve() + + # Verify cache was populated + assert provider.cache is not None + cached_path = provider.cache.get(url) + assert cached_path is not None + assert cached_path.exists() + + # Modify the cached file slightly + original_content = cached_path.read_bytes() + modified_content = original_content.replace(b"}", b', "stale": true}') + cached_path.write_bytes(modified_content) + + # Set mtime to 5 years ago + five_years_ago = time.time() - (5 * 365 * 24 * 60 * 60) + os.utime(cached_path, (five_years_ago, five_years_ago)) + + # Clear metadata cache to force re-fetch + provider._gcs_metadata_cache.clear() + + # Second download should detect stale cache and re-download + obj2 = StorageObject( + query=url, + keep_local=False, + retrieve=True, + provider=provider, + ) + + local_path2 = tmp_path / "download2" / "testfile" + local_path2.parent.mkdir(parents=True, exist_ok=True) + obj2.local_path = lambda: local_path2 + + await obj2.managed_retrieve() + + # Verify cache was populated + assert provider.cache is not None + cached_path = provider.cache.get(url) + assert cached_path is not None + assert cached_path.exists() + + # The downloaded file should be the original, not the stale modified version + downloaded_content = cached_path.read_bytes() + assert b'"stale": true' not in downloaded_content + assert downloaded_content == original_content diff --git a/tests/test_import.py b/tests/test_import.py index 194322c..c0440d0 100644 --- a/tests/test_import.py +++ b/tests/test_import.py @@ -4,8 +4,6 @@ """Basic import tests for the snakemake-storage-plugin-cached-http package.""" -import pytest - def test_import_module(): """Test that the main module can be imported."""
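
For reference, the GCS support added above derives checksums by decoding the base64 `md5Hash` field from the GCS JSON API into a hex digest, and treats a cached copy as fresh only when its local mtime is at least the object's `updated` timestamp. Below is a minimal standalone sketch of that logic, assuming those GCS object-resource field names; the helper names and the sample digest are illustrative and not part of the plugin API.

```python
import base64
import hashlib
from datetime import datetime
from pathlib import Path


def md5_base64_to_hex(md5_base64: str) -> str:
    # GCS returns MD5 checksums base64-encoded; comparisons use the hex digest.
    return base64.b64decode(md5_base64).hex()


def cached_copy_is_fresh(cached: Path, updated_rfc3339: str) -> bool:
    # Mutable sources (GCS) reuse the cache only if the local copy is at least
    # as new as the remote object's "updated" timestamp.
    # Normalizing a trailing "Z" is a portability convenience for this sketch.
    remote_mtime = datetime.fromisoformat(
        updated_rfc3339.replace("Z", "+00:00")
    ).timestamp()
    return cached.stat().st_mtime >= remote_mtime


def verify_md5(path: Path, expected_hex: str) -> bool:
    # Recompute the MD5 of the downloaded file and compare with the metadata value.
    return hashlib.md5(path.read_bytes()).hexdigest() == expected_hex


if __name__ == "__main__":
    # Illustrative value: the base64-encoded MD5 of b"hello world".
    print(md5_base64_to_hex("XrY7u+Ae7tCTyyK7j1rNww=="))
    # -> 5eb63bbbe01eeed093cb22bb8f5acdc3
```

In the plugin itself this behaviour lives in `get_gcs_metadata` and `verify_checksum`, with the freshness comparison applied in the `managed_size`, `inventory`, and `managed_retrieve` changes shown in the diff.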