"""Helpers to cap a process' address space below its cgroup memory limit.

The cap (RLIMIT_AS) is computed as:

    cgroup memory limit - total size of 'in-memory' tmpfs mounts - safety margin

so that an allocation failure surfaces as a catchable ``MemoryError`` inside
the process rather than as an out-of-memory kill of the whole container.
"""
import logging
import os
import resource
import shutil

MB_MULTIPLIER = 1024**2

# Files that may hold the memory limit of the current cgroup, probed in order:
# cgroup v1 first, then cgroup v2 (which stores the literal "max" when unlimited).
_CGROUP_LIMIT_PATHS = (
    "/sys/fs/cgroup/memory/memory.limit_in_bytes",
    "/sys/fs/cgroup/memory.max",
)
# cgroup v1 reports "unlimited" as a huge number (close to 2**63).
_CGROUP_UNLIMITED_THRESHOLD = 2**60
# A tmpfs reporting 1 PiB or more is treated as having no real size limit.
_TMPFS_UNLIMITED_THRESHOLD = 1 << 50
_DEFAULT_MEMORY_MARGIN_MB = 200


def find_tmpfs_mounts(mounts_path="/proc/mounts"):
    """
    Returns a list of tmpfs mount points whose path contains 'in-memory'.

    :param mounts_path: file in /proc/mounts format to read (defaults to /proc/mounts)
    :return: list of matching mount points; empty if none match or the file cannot be read
    """
    tmpfs_mounts = []
    try:
        with open(mounts_path, "r") as f:
            for line in f:
                # Format: <device> <mount point> <fs type> <options> <dump> <pass>
                parts = line.split()
                if len(parts) >= 3 and parts[2] == "tmpfs" and "in-memory" in parts[1]:
                    tmpfs_mounts.append(parts[1])
    except (OSError, UnicodeError) as e:
        logging.error("Error reading %s: %s", mounts_path, e)
    return tmpfs_mounts


def get_memory_limit_cgroup_bytes(limit_paths=_CGROUP_LIMIT_PATHS):
    """
    Returns the memory limit for the process (in bytes) as set by cgroups.

    The first readable file in ``limit_paths`` decides the result, so both
    cgroup v1 and cgroup v2 hosts are supported.

    :param limit_paths: candidate files holding the limit, probed in order
    :return: the limit in bytes, or None if no file is readable or the limit is unlimited
    """
    for path in limit_paths:
        try:
            with open(path, "r") as f:
                raw_value = f.read().strip()
        except OSError:
            # This cgroup flavour is not present on the host; try the next one.
            continue

        if raw_value == "max":  # cgroup v2 marker for "no limit"
            return None
        try:
            limit_bytes = int(raw_value)
        except ValueError:
            logging.warning("Unexpected content in %s: %r", path, raw_value)
            continue
        # If the limit is a very large number (e.g., 2**63), treat as unlimited
        if limit_bytes < _CGROUP_UNLIMITED_THRESHOLD:
            return limit_bytes
        return None
    return None


def get_total_tmpfs_size_bytes():
    """
    Returns the total size (in bytes) of all tmpfs mounts whose path contains 'in-memory',
    or None if none found or all unlimited.
    """
    total_size = 0
    found = False
    for mount in find_tmpfs_mounts():
        if not os.path.exists(mount):
            continue
        try:
            total, _, _ = shutil.disk_usage(mount)
        except OSError as e:
            logging.error("Error getting disk usage for %s: %s", mount, e)
            continue
        # If total is suspiciously large (>= 1 PiB), treat as unlimited and ignore the mount
        if total < _TMPFS_UNLIMITED_THRESHOLD:
            total_size += total
            found = True
    return total_size if found else None


def get_available_process_memory_bytes():
    """
    Returns the available memory for the process in bytes:
    total process memory limit (cgroup) minus the total size of all tmpfs
    filesystems whose path contains 'in-memory'. If any value is unlimited
    or not found, returns None.

    The result can be zero or negative when the tmpfs mounts are sized at or
    above the cgroup limit; callers must handle that case.
    """
    mem_limit = get_memory_limit_cgroup_bytes()
    tmpfs_size = get_total_tmpfs_size_bytes()
    if mem_limit is None or tmpfs_size is None:
        logging.warning("Could not determine available process memory (limit or tmpfs size missing/unlimited).")
        return None
    available_bytes = mem_limit - tmpfs_size
    logging.info(
        "Process memory limit: %.2f MiB, total tmpfs size: %.2f MiB, available: %.2f MiB",
        mem_limit / MB_MULTIPLIER,
        tmpfs_size / MB_MULTIPLIER,
        available_bytes / MB_MULTIPLIER,
    )
    return available_bytes


def _read_memory_margin_mb():
    """Return the safety margin in MiB from MEMORY_MARGIN_MB, falling back to the default."""
    raw_margin = os.getenv("MEMORY_MARGIN_MB", str(_DEFAULT_MEMORY_MARGIN_MB))
    if not raw_margin:
        return _DEFAULT_MEMORY_MARGIN_MB
    try:
        return int(raw_margin)
    except ValueError as err:
        logging.error(
            "Invalid MEMORY_MARGIN_MB value: %s. Using default of 200MB. Error: %s",
            raw_margin,
            err,
        )
        return _DEFAULT_MEMORY_MARGIN_MB


def limit_gcp_memory():
    """
    Sets RLIMIT_AS (soft and hard) to the available process memory minus a safety margin.

    The margin is read from the MEMORY_MARGIN_MB environment variable (in MiB,
    default 200); a negative margin is treated as 0. This function never raises:
    it is meant to be called at import time, so any failure to determine or
    apply the limit is logged and the process keeps running without a limit.

    :return: the limit applied in bytes, or None if no limit was set
    """
    available_memory_bytes = get_available_process_memory_bytes()
    if available_memory_bytes is None:
        logging.info("Could not find the total memory of the process. Memory limit not set.")
        return None
    if available_memory_bytes <= 0:
        logging.info(
            "No memory left for the process once tmpfs is subtracted (%.2f MiB). Memory limit not set.",
            available_memory_bytes / MB_MULTIPLIER,
        )
        return None

    memory_margin_bytes = max(_read_memory_margin_mb(), 0) * MB_MULTIPLIER
    logging.info(
        "Available memory: %.2f MiB, memory margin: %.2f MiB",
        available_memory_bytes / MB_MULTIPLIER,
        memory_margin_bytes / MB_MULTIPLIER,
    )
    mem_limit = available_memory_bytes - memory_margin_bytes
    if mem_limit <= 0:
        logging.warning(
            "Computed RLIMIT_AS <= 0 (%.2f MiB). Skipping setrlimit.",
            mem_limit / MB_MULTIPLIER,
        )
        return None

    try:
        # An unprivileged process cannot raise its hard limit: clamp to the
        # current one instead of letting setrlimit fail with ValueError.
        _, current_hard_limit = resource.getrlimit(resource.RLIMIT_AS)
        if current_hard_limit != resource.RLIM_INFINITY and mem_limit > current_hard_limit:
            logging.warning(
                "Computed RLIMIT_AS (%d bytes) exceeds the current hard limit (%d bytes). Using the hard limit.",
                mem_limit,
                current_hard_limit,
            )
            mem_limit = current_hard_limit
        # Set RLIMIT_AS in bytes, log the limit in MiB
        resource.setrlimit(resource.RLIMIT_AS, (mem_limit, mem_limit))
    except (ValueError, OSError) as err:
        logging.error("Could not set RLIMIT_AS to %d bytes: %s", mem_limit, err)
        return None

    logging.info(
        "RLIMIT_AS set to %.2f MiB (raw: %d bytes)",
        mem_limit / MB_MULTIPLIER,
        mem_limit,
    )
    return mem_limit


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    available = get_available_process_memory_bytes()
    if available is not None:
        print(f"Available process memory: {available / MB_MULTIPLIER:.2f} MiB")
    else:
        print("Could not determine available process memory.")
@@ from sqlalchemy import func from sqlalchemy.orm import Session +from shared.common.gcp_memory_utils import limit_gcp_memory from shared.common.gcp_utils import create_refresh_materialized_view_task from shared.database.database import with_db_session from shared.database_gen.sqlacodegen_models import Gtfsdataset, Gtfsfile, Gtfsfeed @@ -45,6 +47,9 @@ init_logger() +# Limit the available memory of the process so if an OOM exception happens it can be handled properly by our code +limit_gcp_memory() + @dataclass class DatasetFile: @@ -141,16 +146,14 @@ def download_content(self, temporary_file_path, feed_id): is_zip = zipfile.is_zipfile(temporary_file_path) return file_hash, is_zip - def upload_files_to_storage( + def upload_dataset_zip_to_storage( self, source_file_path, dataset_stable_id, - extracted_files_path, public=True, - skip_dataset_upload=False, ): """ - Uploads the dataset file and extracted files to GCP storage + Uploads the dataset zip file to GCP storage as latest.zip and versioned zip. """ bucket = storage.Client().get_bucket(self.bucket_name) target_paths = [ @@ -158,48 +161,101 @@ def upload_files_to_storage( f"{self.feed_stable_id}/{dataset_stable_id}/{dataset_stable_id}.zip", ] blob = None - if not skip_dataset_upload: - for target_path in target_paths: - blob = bucket.blob(target_path) - blob.upload_from_filename(source_file_path) - if public: - blob.make_public() - self.logger.info(f"Uploaded {blob.public_url}") - - base_path, _ = os.path.splitext(source_file_path) + for target_path in target_paths: + blob = bucket.blob(target_path) + blob.upload_from_filename(source_file_path) + if public: + blob.make_public() + self.logger.info(f"Uploaded {blob.public_url}") + return blob + + def extract_and_upload_files_from_zip( + self, + zip_file_path: str, + dataset_stable_id: str, + public: bool = True, + ) -> List[Gtfsfile]: + """ + Extract files one at a time from a ZIP archive and upload each to GCS. 
+ This minimizes local disk usage by extracting and uploading one file at a time, + then deleting the temporary extracted file before moving to the next. + + :param zip_file_path: Path to the ZIP file + :param dataset_stable_id: The dataset stable ID for the GCS path + :param public: Whether to make the uploaded files public + :return: List of Gtfsfile objects representing the extracted files + """ + if not zipfile.is_zipfile(zip_file_path): + self.logger.error("The file %s is not a valid ZIP file.", zip_file_path) + raise ValueError("File is not a valid ZIP file.") + + bucket = storage.Client().get_bucket(self.bucket_name) extracted_files: List[Gtfsfile] = [] - if not extracted_files_path or not os.path.exists(extracted_files_path): - self.logger.warning( - "Extracted files path %s does not exist.", extracted_files_path - ) - return blob, extracted_files - self.logger.info("Processing extracted files from %s", extracted_files_path) - for file_name in os.listdir(extracted_files_path): - file_path = os.path.join(extracted_files_path, file_name) - if os.path.isfile(file_path): - file_blob = bucket.blob( - f"{self.feed_stable_id}/{dataset_stable_id}/extracted/{file_name}" + working_dir = os.getenv("WORKING_DIR", "/tmp/in-memory") + + with zipfile.ZipFile(zip_file_path, "r") as zf: + for member in zf.infolist(): + # Skip directories + if member.is_dir(): + continue + + # Extract a single file to a temporary path. + # Use a unique filename with feed_stable_id and dataset_stable_id prefix to avoid collisions + # when multiple datasets are processed concurrently. Replace '/' with '_' to flatten any + # subdirectory structure from the ZIP into a single working directory. 
+ temp_extracted_path = os.path.join( + working_dir, + f"{self.feed_stable_id}-{dataset_stable_id}-{member.filename.replace('/', '_')}", ) - file_blob.upload_from_filename(file_path) - if public: - file_blob.make_public() + self.logger.info( - "Uploaded extracted file %s to %s", file_name, file_blob.public_url + "Extracting %s to %s", member.filename, temp_extracted_path ) - extracted_files.append( - Gtfsfile( - id=str(uuid.uuid4()), - file_name=file_name, - file_size_bytes=os.path.getsize(file_path), - hosted_url=file_blob.public_url if public else None, - hash=get_hash_from_file(file_path), + with zf.open(member, "r") as src, open( + temp_extracted_path, "wb" + ) as dst: + shutil.copyfileobj(src, dst) + + # Upload this single file to GCS under extracted/ + if os.path.isfile(temp_extracted_path): + target_path = f"{self.feed_stable_id}/{dataset_stable_id}/extracted/{member.filename}" + file_blob = bucket.blob(target_path) + file_blob.upload_from_filename(temp_extracted_path) + if public: + file_blob.make_public() + self.logger.info( + "Uploaded extracted file %s to %s", + member.filename, + file_blob.public_url, + ) + + extracted_files.append( + Gtfsfile( + id=str(uuid.uuid4()), + file_name=member.filename, + file_size_bytes=os.path.getsize(temp_extracted_path), + hosted_url=file_blob.public_url if public else None, + hash=get_hash_from_file(temp_extracted_path), + ) + ) + + # Remove the local temporary extracted file to free disk space + try: + if os.path.exists(temp_extracted_path): + os.remove(temp_extracted_path) + except Exception as cleanup_err: + self.logger.warning( + "Failed to remove temporary file %s: %s", + temp_extracted_path, + cleanup_err, ) - ) - return blob, extracted_files - def upload_dataset(self, feed_id, public=True) -> DatasetFile or None: + return extracted_files + + def transfer_dataset(self, feed_id, public=True) -> DatasetFile or None: """ - Uploads a dataset to a GCP bucket as /latest.zip and + Transfer a dataset from the provider url 
to the local disk then upload to the GCP bucket as + /latest.zip and /-.zip if the dataset hash is different from the latest dataset stored :return: the file hash and the hosted url as a tuple or None if no upload is required @@ -224,10 +280,6 @@ def upload_dataset(self, feed_id, public=True) -> DatasetFile or None: f"[{self.feed_stable_id}] Dataset has changed (hash {self.latest_hash}" f"-> {file_sha256_hash}). Uploading new version." ) - extracted_files_path = self.unzip_files(temp_file_path) - self.logger.info( - f"Creating file {self.feed_stable_id}/latest.zip in bucket {self.bucket_name}" - ) dataset_stable_id = self.create_dataset_stable_id( self.feed_stable_id, self.date @@ -235,13 +287,24 @@ def upload_dataset(self, feed_id, public=True) -> DatasetFile or None: dataset_full_path = ( f"{self.feed_stable_id}/{dataset_stable_id}/{dataset_stable_id}.zip" ) + + # Upload the zip file to GCS + self.logger.info( + f"Creating file {self.feed_stable_id}/latest.zip in bucket {self.bucket_name}" + ) self.logger.info( f"Creating file {dataset_full_path} in bucket {self.bucket_name}" ) - _, extracted_files = self.upload_files_to_storage( + self.upload_dataset_zip_to_storage( + temp_file_path, + dataset_stable_id, + public=public, + ) + + # Extract and upload files one at a time to minimize disk usage + extracted_files = self.extract_and_upload_files_from_zip( temp_file_path, dataset_stable_id, - extracted_files_path, public=public, ) @@ -261,6 +324,9 @@ def upload_dataset(self, feed_id, public=True) -> DatasetFile or None: f"[{self.feed_stable_id}] Datasets hash has not changed (hash {self.latest_hash} " f"-> {file_sha256_hash}). Not uploading it." 
) + except Exception as e: + self.logger.error(f"Error transferring dataset: {e}") + raise e finally: if temp_file_path and os.path.exists(temp_file_path): os.remove(temp_file_path) @@ -268,27 +334,29 @@ def upload_dataset(self, feed_id, public=True) -> DatasetFile or None: @with_db_session def process_from_bucket(self, db_session, public=True) -> Optional[DatasetFile]: + """Process an existing dataset from the GCP bucket and update related DB entities. + + To reduce local disk usage, we no longer unzip all files at once. Instead, we: + - Download the dataset ZIP to a temporary local file. + - Iterate over each member of the ZIP. + - Extract a single file to a temporary path under WORKING_DIR. + - Upload that file immediately to GCS and record it as a Gtfsfile. + - Delete the local temporary extracted file before moving to the next one. """ - Process an existing dataset from the GCP bucket updates the related database entities - :return: The DatasetFile object created - """ - temp_file_path = None + temp_zip_path = None try: - temp_file_path = self.generate_temp_filename() + temp_zip_path = self.generate_temp_filename() blob_file_path = f"{self.feed_stable_id}/{self.dataset_stable_id}/{self.dataset_stable_id}.zip" - self.logger.info(f"Processing dataset from bucket: {blob_file_path}") + self.logger.info("Processing dataset from bucket: %s", blob_file_path) download_from_gcs( - os.getenv("DATASETS_BUCKET_NAME"), blob_file_path, temp_file_path + os.getenv("DATASETS_BUCKET_NAME"), blob_file_path, temp_zip_path ) - extracted_files_path = self.unzip_files(temp_file_path) - - _, extracted_files = self.upload_files_to_storage( - temp_file_path, + # Extract and upload files one at a time to minimize disk usage + extracted_files = self.extract_and_upload_files_from_zip( + temp_zip_path, self.dataset_stable_id, - extracted_files_path, public=public, - skip_dataset_upload=True, # Skip the upload of the dataset file ) dataset_file = DatasetFile( @@ -297,11 +365,12 @@ def 
process_from_bucket(self, db_session, public=True) -> Optional[DatasetFile]: hosted_url=f"{self.public_hosted_datasets_url}/{blob_file_path}", extracted_files=extracted_files, zipped_size=( - os.path.getsize(temp_file_path) - if os.path.exists(temp_file_path) + os.path.getsize(temp_zip_path) + if os.path.exists(temp_zip_path) else None ), ) + dataset, latest = self.create_dataset_entities( dataset_file, skip_dataset_creation=True, db_session=db_session ) @@ -319,26 +388,14 @@ def process_from_bucket(self, db_session, public=True) -> Optional[DatasetFile]: raise ValueError("Dataset update failed, dataset is None.") return dataset_file finally: - if temp_file_path and os.path.exists(temp_file_path): - os.remove(temp_file_path) - - def unzip_files(self, temp_file_path): - extracted_files_path = os.path.join(temp_file_path.split(".")[0], "extracted") - self.logger.info(f"Unzipping files to {extracted_files_path}") - # Create the directory for extracted files if it does not exist - os.makedirs(extracted_files_path, exist_ok=True) - with zipfile.ZipFile(temp_file_path, "r") as zip_ref: - zip_ref.extractall(path=extracted_files_path) - # List all files in the extracted directory - extracted_files = os.listdir(extracted_files_path) - self.logger.info(f"Extracted files: {extracted_files}") - return extracted_files_path + if temp_zip_path and os.path.exists(temp_zip_path): + os.remove(temp_zip_path) def generate_temp_filename(self): """ Generates a temporary filename """ - working_dir = os.getenv("WORKING_DIR", "/in-memory") + working_dir = os.getenv("WORKING_DIR", "/tmp/in-memory") temporary_file_path = ( f"{working_dir}/{self.feed_stable_id}-{random.randint(0, 1000000)}.zip" ) @@ -431,7 +488,7 @@ def process_from_producer_url( Process the dataset and store new version in GCP bucket if any changes are detected :return: the DatasetFile object created """ - dataset_file = self.upload_dataset(feed_id) + dataset_file = self.transfer_dataset(feed_id) if dataset_file is None: 
self.logger.info(f"[{self.feed_stable_id}] No database update required.") @@ -487,6 +544,7 @@ def process_dataset(cloud_event: CloudEvent): """ logging.info("Function Started") stable_id = "UNKNOWN" + execution_id = "UNKNOWN" bucket_name = os.getenv("DATASETS_BUCKET_NAME") try: @@ -506,19 +564,17 @@ def process_dataset(cloud_event: CloudEvent): return try: - maximum_executions = os.getenv("MAXIMUM_EXECUTIONS", 1) + try: + maximum_executions = int(os.getenv("MAXIMUM_EXECUTIONS", "1")) + except (ValueError, TypeError): + maximum_executions = 1 public_hosted_datasets_url = os.getenv("PUBLIC_HOSTED_DATASETS_URL") trace_service = None dataset_file: DatasetFile = None error_message = None - # Extract data from message - data = base64.b64decode(cloud_event.data["message"]["data"]).decode() - json_payload = json.loads(data) - stable_id = json_payload["feed_stable_id"] logger = get_logger("process_dataset", stable_id) logger.info(f"JSON Payload: {json.dumps(json_payload)}") - execution_id = json_payload["execution_id"] trace_service = DatasetTraceService() trace = trace_service.get_by_execution_and_stable_ids(execution_id, stable_id) logger.info(f"Dataset trace: {trace}") @@ -557,7 +613,9 @@ def process_dataset(cloud_event: CloudEvent): # This makes sure the logger is initialized logger = get_logger("process_dataset", stable_id if stable_id else "UNKNOWN") logger.error(e) - error_message = f"Error execution: [{execution_id}] error: [{e}]" + error_message = ( + f"Error execution: [{execution_id}] error: [{type(e).__name__}]: {e}" + ) logger.error(error_message) logger.error(f"Function completed with error:{error_message}") finally: @@ -587,3 +645,41 @@ def process_dataset(cloud_event: CloudEvent): "successfully completed" if not error_message else "Failed", ) return "Completed." 
if error_message is None else error_message + + +def simulate(request) -> dict: # pragma: no cover + """HTTP endpoint to simulate a process_dataset call for testing.""" + # Hardcoded test values + payload = { + "execution_id": "task-executor-uuid-af993d49-0d95-42cb-96a4-9cffc5301e87", + "producer_url": "https://data.bus-data.dft.gov.uk/timetable/download/gtfs-file/all/", + "feed_stable_id": "mdb-2014", + "feed_id": "34434a73-0ba7-4070-b01f-dfadb6e30d42", + "dataset_stable_id": "mdb-2014-202408202259", + "dataset_hash": "abc", + "authentication_type": "0", + "authentication_info_url": "", + "api_key_parameter_name": "", + } + + # Create CloudEvent + encoded_data = base64.b64encode(json.dumps(payload).encode()).decode() + attributes = { + "type": "google.cloud.pubsub.topic.v1.messagePublished", + "source": "//pubsub.googleapis.com/test", + "specversion": "1.0", + } + data = {"message": {"data": encoded_data}} + cloud_event = CloudEvent(attributes, data) + + # Call process_dataset + process_dataset(cloud_event) + return {"status": "completed"} + + +def main(): # pragma: no cover + simulate(None) + + +if __name__ == "__main__": + main() diff --git a/functions-python/batch_process_dataset/src/scripts/download_verifier.py b/functions-python/batch_process_dataset/src/scripts/download_verifier.py index 8007bcbdc..08a8be08c 100644 --- a/functions-python/batch_process_dataset/src/scripts/download_verifier.py +++ b/functions-python/batch_process_dataset/src/scripts/download_verifier.py @@ -60,7 +60,7 @@ def verify_upload_dataset(producer_url: str): ) tempfile = processor.generate_temp_filename() logging.info(f"Temp filename: {tempfile}") - dataset_file = processor.upload_dataset("feed_id_2126", False) + dataset_file = processor.transfer_dataset("feed_id_2126", False) logging.info(f"Dataset File: {dataset_file}") diff --git a/functions-python/batch_process_dataset/tests/test_batch_process_dataset_main.py 
b/functions-python/batch_process_dataset/tests/test_batch_process_dataset_main.py index 6cb43a15e..f54066633 100644 --- a/functions-python/batch_process_dataset/tests/test_batch_process_dataset_main.py +++ b/functions-python/batch_process_dataset/tests/test_batch_process_dataset_main.py @@ -2,13 +2,11 @@ import datetime import json import os -import tempfile +import shutil import unittest from hashlib import sha256 from typing import Final -from unittest.mock import patch, MagicMock, Mock, mock_open - -import faker +from unittest.mock import patch, MagicMock, Mock from main import ( DatasetProcessor, @@ -45,11 +43,31 @@ def create_cloud_event(mock_data): class TestDatasetProcessor(unittest.TestCase): - @patch("main.DatasetProcessor.upload_files_to_storage") + @classmethod + def setUpClass(cls): + """Set up test environment with a dedicated working directory""" + # Use a test-specific working directory + cls.test_working_dir = os.path.join( + os.path.dirname(__file__), "test_working_dir" + ) + os.makedirs(cls.test_working_dir, exist_ok=True) + # Set the environment variable for all tests + os.environ["WORKING_DIR"] = cls.test_working_dir + + @classmethod + def tearDownClass(cls): + """Clean up the test working directory after all tests""" + if os.path.exists(cls.test_working_dir): + shutil.rmtree(cls.test_working_dir, ignore_errors=True) + + @patch("main.DatasetProcessor.extract_and_upload_files_from_zip") + @patch("main.DatasetProcessor.upload_dataset_zip_to_storage") @patch("main.DatasetProcessor.download_content") - @patch("main.DatasetProcessor.unzip_files") def test_upload_dataset_diff_hash( - self, mock_unzip_files, mock_download_url_content, upload_files_to_storage + self, + mock_download_url_content, + mock_upload_dataset_zip, + mock_extract_and_upload, ): """ Test upload_dataset method of DatasetProcessor class with different hash from the latest one @@ -57,9 +75,12 @@ def test_upload_dataset_diff_hash( mock_blob = MagicMock() mock_blob.public_url = 
public_url mock_blob.path = public_url - upload_files_to_storage.return_value = mock_blob, [] + + # Mock the new methods used in transfer_dataset + mock_upload_dataset_zip.return_value = mock_blob + mock_extracted_files = [] # Empty list of extracted files + mock_extract_and_upload.return_value = mock_extracted_files mock_download_url_content.return_value = file_hash, True - mock_unzip_files.return_value = [mock_blob, mock_blob] processor = DatasetProcessor( public_url, @@ -73,7 +94,7 @@ def test_upload_dataset_diff_hash( test_hosted_public_url, ) with patch.object(processor, "date", "mocked_timestamp"): - result = processor.upload_dataset("feed_id") + result = processor.transfer_dataset("feed_id") self.assertIsNotNone(result) mock_download_url_content.assert_called_once() @@ -84,19 +105,15 @@ def test_upload_dataset_diff_hash( f"/feed_stable_id-mocked_timestamp.zip", ) self.assertEqual(result.file_sha256_hash, file_hash) - self.assertEqual(upload_files_to_storage.call_count, 1) + # Verify the new methods were called + self.assertEqual(mock_upload_dataset_zip.call_count, 1) + self.assertEqual(mock_extract_and_upload.call_count, 1) - @patch("main.DatasetProcessor.upload_files_to_storage") @patch("main.DatasetProcessor.download_content") - def test_upload_dataset_same_hash( - self, mock_download_url_content, upload_files_to_storage - ): + def test_upload_dataset_same_hash(self, mock_download_url_content): """ Test upload_dataset method of DatasetProcessor class with the hash from the latest one """ - mock_blob = MagicMock() - mock_blob.public_url = public_url - upload_files_to_storage.return_value = mock_blob mock_download_url_content.return_value = file_hash, True processor = DatasetProcessor( @@ -111,24 +128,16 @@ def test_upload_dataset_same_hash( test_hosted_public_url, ) - result = processor.upload_dataset("feed_id") + result = processor.transfer_dataset("feed_id") self.assertIsNone(result) - upload_files_to_storage.blob.assert_not_called() - 
mock_blob.make_public.assert_not_called() mock_download_url_content.assert_called_once() - @patch("main.DatasetProcessor.upload_files_to_storage") @patch("main.DatasetProcessor.download_content") - def test_upload_dataset_not_zip( - self, mock_download_url_content, upload_files_to_storage - ): + def test_upload_dataset_not_zip(self, mock_download_url_content): """ Test upload_dataset method of DatasetProcessor class with a non zip file """ - mock_blob = MagicMock() - mock_blob.public_url = public_url - upload_files_to_storage.return_value = mock_blob mock_download_url_content.return_value = file_hash, False processor = DatasetProcessor( @@ -143,24 +152,16 @@ def test_upload_dataset_not_zip( test_hosted_public_url, ) - result = processor.upload_dataset("feed_id") + result = processor.transfer_dataset("feed_id") self.assertIsNone(result) - upload_files_to_storage.blob.assert_not_called() - mock_blob.make_public.assert_not_called() mock_download_url_content.assert_called_once() - @patch("main.DatasetProcessor.upload_files_to_storage") @patch("main.DatasetProcessor.download_content") - def test_upload_dataset_download_exception( - self, mock_download_url_content, upload_files_to_storage - ): + def test_upload_dataset_download_exception(self, mock_download_url_content): """ - Test upload_dataset method of DatasetProcessor class with the hash from the latest one + Test upload_dataset method of DatasetProcessor class when download fails """ - mock_blob = MagicMock() - mock_blob.public_url = public_url - upload_files_to_storage.return_value = mock_blob mock_download_url_content.side_effect = Exception("Download failed") processor = DatasetProcessor( @@ -176,47 +177,261 @@ def test_upload_dataset_download_exception( ) with self.assertRaises(Exception): - processor.upload_dataset("feed_id") + processor.transfer_dataset("feed_id") + + @patch("main.get_hash_from_file", return_value="test_file_hash_123") + @patch("main.os.path.getsize", return_value=1024) + 
@patch("main.os.remove") + @patch("main.storage.Client") + @patch("main.zipfile.is_zipfile", return_value=True) + def test_extract_and_upload_files_from_zip_success( + self, + _mock_is_zipfile, + mock_storage_client, + mock_remove, + _mock_getsize, + _mock_get_hash, + ): + """ + Test extract_and_upload_files_from_zip with a valid ZIP file containing multiple files + """ + import tempfile + import zipfile + + # Create a real temporary ZIP file with test content + with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmp_zip: + zip_path = tmp_zip.name + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("stops.txt", "stop_id,stop_name\n1,Stop A\n") + zf.writestr("routes.txt", "route_id,route_name\n1,Route 1\n") + # Add a directory to test that it's skipped + zf.writestr("subfolder/", "") + + try: + # Setup mocks + mock_blob = Mock() + mock_blob.public_url = ( + "https://storage.googleapis.com/bucket/feed/dataset/extracted/stops.txt" + ) + mock_bucket = Mock() + mock_bucket.blob.return_value = mock_blob + mock_client = Mock() + mock_client.get_bucket.return_value = mock_bucket + mock_storage_client.return_value = mock_client - def test_upload_files_to_storage(self): - bucket_name = "test-bucket" - source_file_path = "path/to/source/file" - extracted_file_path = "path/to/source/file" + # Create processor + processor = DatasetProcessor( + producer_url="https://example.com/feed.zip", + feed_id="test_feed_id", + feed_stable_id="test_feed", + execution_id="exec_123", + latest_hash="hash123", + bucket_name="test-bucket", + authentication_type=0, + api_key_parameter_name=None, + public_hosted_datasets_url="https://public.example.com", + ) - mock_blob = Mock() - mock_blob.public_url = public_url - mock_bucket = Mock() - mock_bucket.blob.return_value = mock_blob - mock_client = Mock() - mock_client.get_bucket.return_value = mock_bucket + # Call the method + result = processor.extract_and_upload_files_from_zip( + zip_file_path=zip_path, + 
dataset_stable_id="dataset_123", + public=True, + ) + + # Assertions + self.assertEqual(len(result), 2) # 2 files, directory should be skipped + self.assertEqual(result[0].file_name, "stops.txt") + self.assertEqual(result[1].file_name, "routes.txt") + self.assertEqual(result[0].file_size_bytes, 1024) + self.assertEqual(result[0].hash, "test_file_hash_123") + self.assertEqual(result[0].hosted_url, mock_blob.public_url) + + # Verify bucket.blob was called for each file + self.assertEqual(mock_bucket.blob.call_count, 2) + mock_bucket.blob.assert_any_call( + "test_feed/dataset_123/extracted/stops.txt" + ) + mock_bucket.blob.assert_any_call( + "test_feed/dataset_123/extracted/routes.txt" + ) + + # Verify upload was called for each file + self.assertEqual(mock_blob.upload_from_filename.call_count, 2) + + # Verify make_public was called (public=True) + self.assertEqual(mock_blob.make_public.call_count, 2) - # Mock open function - mock_file = mock_open() + # Verify cleanup (os.remove) was called for each extracted file + self.assertEqual(mock_remove.call_count, 2) + + finally: + # Cleanup the temporary ZIP file + if os.path.exists(zip_path): + os.remove(zip_path) + + @patch("main.storage.Client") + @patch("main.zipfile.is_zipfile", return_value=True) + def test_extract_and_upload_files_from_zip_not_public( + self, mock_is_zipfile, mock_storage_client + ): + """ + Test extract_and_upload_files_from_zip with public=False + """ + import tempfile + import zipfile + + # Create a real temporary ZIP file + with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmp_zip: + zip_path = tmp_zip.name + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("agency.txt", "agency_id,agency_name\n1,Agency\n") + + try: + # Setup mocks + mock_blob = Mock() + mock_blob.public_url = "https://storage.googleapis.com/bucket/feed/dataset/extracted/agency.txt" + mock_bucket = Mock() + mock_bucket.blob.return_value = mock_blob + mock_client = Mock() + mock_client.get_bucket.return_value 
= mock_bucket + mock_storage_client.return_value = mock_client - with patch("google.cloud.storage.Client", return_value=mock_client), patch( - "builtins.open", mock_file - ): processor = DatasetProcessor( - public_url, - "feed_id", - "feed_stable_id", - "execution_id", - "latest_hash", - bucket_name, - 0, - None, - test_hosted_public_url, + producer_url="https://example.com/feed.zip", + feed_id="test_feed_id", + feed_stable_id="test_feed", + execution_id="exec_123", + latest_hash="hash123", + bucket_name="test-bucket", + authentication_type=0, + api_key_parameter_name=None, + public_hosted_datasets_url="https://public.example.com", + ) + + # Call with public=False + result = processor.extract_and_upload_files_from_zip( + zip_file_path=zip_path, + dataset_stable_id="dataset_456", + public=False, ) - dataset_id = faker.Faker().uuid4() - result, _ = processor.upload_files_to_storage( - source_file_path, dataset_id, extracted_file_path + + # Assertions + self.assertEqual(len(result), 1) + self.assertIsNone(result[0].hosted_url) # Should be None when public=False + + # Verify make_public was NOT called + mock_blob.make_public.assert_not_called() + + finally: + if os.path.exists(zip_path): + os.remove(zip_path) + + @patch("main.zipfile.is_zipfile", return_value=False) + def test_extract_and_upload_files_from_zip_invalid_zip(self, mock_is_zipfile): + """ + Test extract_and_upload_files_from_zip with an invalid ZIP file + """ + import tempfile + + # Create a temporary non-ZIP file + with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as tmp_file: + tmp_file.write(b"This is not a ZIP file") + file_path = tmp_file.name + + try: + processor = DatasetProcessor( + producer_url="https://example.com/feed.zip", + feed_id="test_feed_id", + feed_stable_id="test_feed", + execution_id="exec_123", + latest_hash="hash123", + bucket_name="test-bucket", + authentication_type=0, + api_key_parameter_name=None, + public_hosted_datasets_url="https://public.example.com", ) - 
self.assertEqual(result.public_url, public_url) - mock_client.get_bucket.assert_called_with(bucket_name) - mock_bucket.blob.assert_called_with( - f"feed_stable_id/{dataset_id}/{dataset_id}.zip" + + # Should raise ValueError for invalid ZIP + with self.assertRaises(ValueError) as context: + processor.extract_and_upload_files_from_zip( + zip_file_path=file_path, + dataset_stable_id="dataset_789", + public=True, + ) + + self.assertIn("not a valid ZIP file", str(context.exception)) + + finally: + if os.path.exists(file_path): + os.remove(file_path) + + @patch("main.get_hash_from_file", return_value="hash_abc") + @patch("main.storage.Client") + @patch("main.zipfile.is_zipfile", return_value=True) + def test_extract_and_upload_files_from_zip_cleanup_failure( + self, mock_is_zipfile, mock_storage_client, mock_get_hash + ): + """ + Test that cleanup failures are caught and logged but don't stop processing + """ + import tempfile + import zipfile + + # Create a real temporary ZIP file + with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmp_zip: + zip_path = tmp_zip.name + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("test.txt", "test content\n") + + try: + # Setup mocks + mock_blob = Mock() + mock_blob.public_url = "https://storage.googleapis.com/bucket/file.txt" + mock_bucket = Mock() + mock_bucket.blob.return_value = mock_blob + mock_client = Mock() + mock_client.get_bucket.return_value = mock_bucket + mock_storage_client.return_value = mock_client + + processor = DatasetProcessor( + producer_url="https://example.com/feed.zip", + feed_id="test_feed_id", + feed_stable_id="test_feed", + execution_id="exec_123", + latest_hash="hash123", + bucket_name="test-bucket", + authentication_type=0, + api_key_parameter_name=None, + public_hosted_datasets_url="https://public.example.com", ) - mock_blob.upload_from_filename.assert_called() + + # Mock os.remove to fail only for extracted files (not the ZIP itself) + original_remove = os.remove + + def 
mock_remove_side_effect(path): + # Only raise exception for extracted files, not the ZIP file + if "in-memory" in path and path != zip_path: + raise Exception("Cleanup failed") + else: + # Allow cleanup of the ZIP file to succeed + original_remove(path) + + with patch("main.os.remove", side_effect=mock_remove_side_effect): + # Should not raise exception even though cleanup fails + result = processor.extract_and_upload_files_from_zip( + zip_file_path=zip_path, + dataset_stable_id="dataset_cleanup", + public=True, + ) + + # Should still return the file + self.assertEqual(len(result), 1) + self.assertEqual(result[0].file_name, "test.txt") + + finally: + if os.path.exists(zip_path): + os.remove(zip_path) @patch.dict( os.environ, {"FEEDS_CREDENTIALS": '{"test_stable_id": "test_credentials"}'} @@ -247,7 +462,7 @@ def test_process(self, db_session): test_hosted_public_url, ) - processor.upload_dataset = MagicMock( + processor.transfer_dataset = MagicMock( return_value=DatasetFile( stable_id="test_stable_id", file_sha256_hash=new_hash, @@ -260,7 +475,7 @@ def test_process(self, db_session): self.assertIsNotNone(result) self.assertEqual(result.file_sha256_hash, new_hash) - processor.upload_dataset.assert_called_once() + processor.transfer_dataset.assert_called_once() @patch.dict( os.environ, @@ -355,7 +570,7 @@ def test_process_no_change(self): test_hosted_public_url, ) - processor.upload_dataset = MagicMock(return_value=None) + processor.transfer_dataset = MagicMock(return_value=None) processor.create_dataset_entities = MagicMock() result = processor.process_from_producer_url(feed_id) @@ -453,26 +668,18 @@ def test_process_dataset_missing_stable_id(self, mock_dataset_trace): @patch.dict(os.environ, {"DATASETS_BUCKET_NAME": "test-bucket"}) @patch("main.create_pipeline_tasks") @patch("main.DatasetProcessor.create_dataset_entities") - @patch("main.DatasetProcessor.upload_files_to_storage") - @patch("main.DatasetProcessor.unzip_files") + 
@patch("main.DatasetProcessor.extract_and_upload_files_from_zip") @patch("main.download_from_gcs") def test_process_from_bucket_latest_happy_path( self, mock_download_from_gcs, - mock_unzip_files, - mock_upload_files_to_storage, + mock_extract_and_upload, mock_create_dataset_entities, _, ): # Arrange - mock_blob = MagicMock() - mock_upload_files_to_storage.return_value = ( - mock_blob, - [], - ) # (blob, extracted_files) - mock_unzip_files.return_value = ( - "/tmp/extracted" # not used deeply because upload is mocked - ) + mock_extracted_files = [] # Empty list of extracted files + mock_extract_and_upload.return_value = mock_extracted_files processor = DatasetProcessor( producer_url="https://ignored-in-bucket-mode.example.com/feed.zip", @@ -491,8 +698,9 @@ def test_process_from_bucket_latest_happy_path( # Act result = processor.process_from_bucket(public=True) - # Assert: function returns None in current implementation - self.assertIsNone(result.zipped_size) + # Assert: function returns a DatasetFile + self.assertIsNotNone(result) + self.assertIsInstance(result, DatasetFile) # Assert: downloads from the bucket latest.zip for this feed mock_download_from_gcs.assert_called_once() @@ -506,12 +714,12 @@ def test_process_from_bucket_latest_happy_path( ) # temp file path (random), so just ensure it exists self.assertNotEqual(args[2], "") # sanity - # Assert: upload of extracted files happened with skip_dataset_upload=True - mock_upload_files_to_storage.assert_called_once() - u_args, u_kwargs = mock_upload_files_to_storage.call_args - # args: (source_file_path, dataset_stable_id, extracted_files_path, ...) 
- self.assertEqual(u_args[1], "dataset-stable-id-123") - self.assertEqual(u_kwargs.get("skip_dataset_upload"), True) + # Assert: extract_and_upload_files_from_zip was called with the dataset stable ID + mock_extract_and_upload.assert_called_once() + extract_args, extract_kwargs = mock_extract_and_upload.call_args + # First arg should be the temp zip path, second arg should be dataset_stable_id + self.assertEqual(extract_args[1], "dataset-stable-id-123") + self.assertEqual(extract_kwargs.get("public"), True) # Assert: DB update called with skip_dataset_creation=True and a DatasetFile-like object mock_create_dataset_entities.assert_called_once() @@ -524,101 +732,160 @@ def test_process_from_bucket_latest_happy_path( self.assertEqual(c_args[0].stable_id, "dataset-stable-id-123") self.assertTrue(hasattr(c_args[0], "file_sha256_hash")) self.assertEqual(c_args[0].file_sha256_hash, "latest_hash_value") + self.assertEqual(c_args[0].file_sha256_hash, "latest_hash_value") - @patch("main.get_hash_from_file", return_value="fakehash123") - @patch("google.cloud.storage.Client") - def test_upload_files_to_storage_branches(self, mock_client_cls, mock_get_hash): - # Arrange global mocks - mock_blob_latest = Mock() - mock_blob_versioned = Mock() - mock_blob_extracted = Mock() - - # configure bucket.blob() to return different blobs in sequence - mock_bucket = Mock() - # First call scenario (dataset uploads happen): two blobs for latest.zip + versioned.zip - # Second call scenario (skip dataset upload): blobs created only for extracted file(s) - blob_side_effects = [mock_blob_latest, mock_blob_versioned, mock_blob_extracted] - mock_bucket.blob.side_effect = blob_side_effects + @with_db_session(db_url=default_db_url) + def test_create_dataset_entities_update_existing(self, db_session): + """ + Test create_dataset_entities when updating an existing dataset (skip_dataset_creation=True) + This specifically tests lines 457-468 of main.py + """ + from shared.database_gen.sqlacodegen_models 
import Gtfsfeed, Gtfsdataset - mock_client = Mock() - mock_client.get_bucket.return_value = mock_bucket - mock_client_cls.return_value = mock_client + # Use an existing feed from the database to avoid foreign key issues + feeds = db_session.query(Gtfsfeed).all() + if not feeds: + self.skipTest("No feeds available in test database") + + test_feed = feeds[0] + + # Create an existing dataset for this feed + existing_dataset = Gtfsdataset( + id="existing_dataset_update_test", + feed_id=test_feed.id, + stable_id="dataset_existing_update", + hash="old_hash", + hosted_url="https://storage.example.com/old.zip", + gtfsfiles=[], + zipped_size_bytes=1000, + unzipped_size_bytes=2000, + ) - # Create processor - from main import DatasetProcessor + test_feed.latest_dataset = existing_dataset + db_session.add(existing_dataset) + db_session.commit() processor = DatasetProcessor( producer_url="https://example.com/feed.zip", - feed_id="feed_id", - feed_stable_id="feed_stable_id", - execution_id="execution_id", - latest_hash="hash", - bucket_name="bucket-name", + feed_id=test_feed.id, + feed_stable_id=test_feed.stable_id, + execution_id="exec_456", + latest_hash="new_hash", + bucket_name="test-bucket", authentication_type=0, api_key_parameter_name=None, - public_hosted_datasets_url="https://public-hosted", + public_hosted_datasets_url="https://public.example.com", ) - # --- SCENARIO A: extracted path DOES NOT exist; dataset uploaded (skip_dataset_upload=False) - src_path = "/tmp/fake-src.zip" # not read, only passed to upload_from_filename - dataset_id_A = "datasetA" - non_existing_path = "/tmp/this/path/does/not/exist" - - result_blob_A, extracted_A = processor.upload_files_to_storage( - source_file_path=src_path, - dataset_stable_id=dataset_id_A, - extracted_files_path=non_existing_path, - public=True, - skip_dataset_upload=False, + # Create dataset file with new extracted files + from main import Gtfsfile + + dataset_file = DatasetFile( + stable_id="dataset_existing_update", + 
file_sha256_hash="new_hash", + hosted_url="https://storage.example.com/feed/dataset.zip", + extracted_files=[ + Gtfsfile( + id="file3", + file_name="agency.txt", + file_size_bytes=512, + hosted_url="https://storage.example.com/feed/agency.txt", + hash="agency_hash", + ), + ], + zipped_size=3000, ) - # Asserts Scenario A - self.assertIs(result_blob_A, mock_blob_versioned) # last dataset upload blob - self.assertEqual(extracted_A, []) # no extracted files - # two dataset uploads: latest.zip + versioned zip - self.assertEqual(mock_bucket.blob.call_count, 2) - mock_blob_latest.upload_from_filename.assert_called_once_with(src_path) - mock_blob_versioned.upload_from_filename.assert_called_once_with(src_path) - - # --- SCENARIO B: extracted path EXISTS; includes a file and a directory - with tempfile.TemporaryDirectory() as tmpdir: - extracted_dir = os.path.join(tmpdir, "extracted") - os.makedirs(extracted_dir, exist_ok=True) - # create one file - file_path = os.path.join(extracted_dir, "stops.txt") - with open(file_path, "wb") as f: - f.write(b"stop_id,stop_name\n1,A\n") - # create a subdirectory to ensure we skip non-files - os.makedirs(os.path.join(extracted_dir, "subdir"), exist_ok=True) - - dataset_id_B = "datasetB" - # Reset call counters for clarity - mock_bucket.blob.reset_mock() - - result_blob_B, extracted_B = processor.upload_files_to_storage( - source_file_path=src_path, - dataset_stable_id=dataset_id_B, - extracted_files_path=extracted_dir, - public=False, # ensure no make_public called - skip_dataset_upload=True, # skip dataset zips + # Mock create_refresh_materialized_view_task inside the test + with patch("main.create_refresh_materialized_view_task") as mock_refresh_task: + # Call with skip_dataset_creation=True to test the selected code branch + result_dataset, is_latest = processor.create_dataset_entities( + dataset_file=dataset_file, + db_session=db_session, + skip_dataset_creation=True, ) - # Asserts Scenario B - self.assertIsNone(result_blob_B) # 
because skip_dataset_upload=True - # Only one extracted file should be uploaded - self.assertEqual(len(extracted_B), 1) - self.assertEqual(extracted_B[0].file_name, "stops.txt") - self.assertEqual(extracted_B[0].file_size_bytes, os.path.getsize(file_path)) - self.assertEqual(extracted_B[0].hash, "fakehash123") - self.assertIsNone(extracted_B[0].hosted_url) # public=False → no hosted_url - - # bucket.blob called once for the extracted file - mock_bucket.blob.assert_called_once_with( - f"feed_stable_id/{dataset_id_B}/extracted/stops.txt" + # Assertions - should return the existing dataset updated + self.assertIsNotNone(result_dataset) + self.assertEqual(result_dataset.id, "existing_dataset_update_test") + + # Verify line 462-463: latest_dataset.gtfsfiles updated + self.assertEqual(len(result_dataset.gtfsfiles), 1) + self.assertEqual(result_dataset.gtfsfiles[0].file_name, "agency.txt") + + # Verify line 464: latest_dataset.zipped_size_bytes updated + self.assertEqual(result_dataset.zipped_size_bytes, 3000) + + # Verify line 465-467: latest_dataset.unzipped_size_bytes updated + self.assertEqual(result_dataset.unzipped_size_bytes, 512) + + mock_refresh_task.assert_called_once() + + @with_db_session(db_url=default_db_url) + def test_create_dataset_entities_update_existing_no_files(self, db_session): + """ + Test create_dataset_entities with skip_dataset_creation=True and no extracted files + This tests the else branch on line 462: dataset_file.extracted_files else [] + """ + from shared.database_gen.sqlacodegen_models import Gtfsfeed, Gtfsdataset + + # Use an existing feed from the database + feeds = db_session.query(Gtfsfeed).all() + if not feeds: + self.skipTest("No feeds available in test database") + + test_feed = feeds[0] + + # Create an existing dataset for this feed + existing_dataset = Gtfsdataset( + id="existing_dataset_no_files_test", + feed_id=test_feed.id, + stable_id="dataset_no_files_test", + hash="old_hash", + 
hosted_url="https://storage.example.com/old.zip", + gtfsfiles=[], + zipped_size_bytes=1000, + unzipped_size_bytes=2000, + ) + + test_feed.latest_dataset = existing_dataset + db_session.add(existing_dataset) + db_session.commit() + + processor = DatasetProcessor( + producer_url="https://example.com/feed.zip", + feed_id=test_feed.id, + feed_stable_id=test_feed.stable_id, + execution_id="exec_789", + latest_hash="new_hash", + bucket_name="test-bucket", + authentication_type=0, + api_key_parameter_name=None, + public_hosted_datasets_url="https://public.example.com", + ) + + # Create dataset file with NO extracted files (None) + dataset_file = DatasetFile( + stable_id="dataset_no_files_test", + file_sha256_hash="new_hash", + hosted_url="https://storage.example.com/feed/dataset.zip", + extracted_files=None, # Test the else branch + zipped_size=5000, + ) + + # Mock create_refresh_materialized_view_task inside the test + with patch("main.create_refresh_materialized_view_task") as mock_refresh_task: + # Call with skip_dataset_creation=True + result_dataset, is_latest = processor.create_dataset_entities( + dataset_file=dataset_file, + db_session=db_session, + skip_dataset_creation=True, ) - mock_blob_extracted.upload_from_filename.assert_called_once_with(file_path) - # No make_public in this branch - self.assertFalse(getattr(mock_blob_extracted, "make_public").called) - # Sanity: hash function used for extracted file - mock_get_hash.assert_called_with(file_path) + # Assertions + self.assertIsNotNone(result_dataset) + self.assertEqual(len(result_dataset.gtfsfiles), 0) # Should be empty list + self.assertEqual(result_dataset.zipped_size_bytes, 5000) + self.assertIsNone(result_dataset.unzipped_size_bytes) # None when no files + + mock_refresh_task.assert_called_once()