diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index ef20374..571d808 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -13,6 +13,11 @@ jobs:
       - name: check out repository
        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: 3.11
 
       - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v2
@@ -31,6 +36,9 @@ jobs:
        run: |
          python tests/upload_server.py &
 
+      - name: Install dependencies
+        run: pip install .
+
      - name: install pytest
        run: pip install pytest
diff --git a/docker-compose.yaml b/docker-compose.yaml
index 3d8803d..5b0c5af 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -6,7 +6,7 @@ services:
      target: development
    image: tesp-api
    environment:
-      - CONTAINER_TYPE=docker # Set to "docker", "singularity", or "both"
+      - CONTAINER_TYPE=singularity # Set to "docker", "singularity", or "both"
    container_name: tesp-api
    privileged: true
    ports:
diff --git a/docker/pulsar_rest/Dockerfile b/docker/pulsar_rest/Dockerfile
index b2edee2..64d939c 100644
--- a/docker/pulsar_rest/Dockerfile
+++ b/docker/pulsar_rest/Dockerfile
@@ -26,6 +26,28 @@ RUN pip install 'pulsar-app[web]'
 FROM python-base as development
 COPY --from=builder $PYSETUP_PATH $PYSETUP_PATH
 
+# Install dependencies required by Apptainer (Singularity)
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
+    curl \
+    gnupg-agent \
+    software-properties-common \
+    lsb-release \
+    wget \
+    libseccomp2 \
+    uidmap \
+    squashfs-tools \
+    squashfuse \
+    fuse2fs \
+    fuse-overlayfs \
+    fakeroot \
+    cryptsetup
+
+# Download and install Apptainer
+ARG APPTAINER_VERSION=1.3.6
+RUN curl -LO https://github.com/apptainer/apptainer/releases/download/v${APPTAINER_VERSION}/apptainer_${APPTAINER_VERSION}_amd64.deb \
+    && apt-get install -y ./apptainer_${APPTAINER_VERSION}_amd64.deb \
+    && rm apptainer_${APPTAINER_VERSION}_amd64.deb
+
 RUN apt-get update && apt-get install -y curl gnupg-agent software-properties-common lsb-release
 RUN curl -fsSL https://download.docker.com/linux/debian/gpg | apt-key add -
 RUN add-apt-repository "deb [arch=amd64] https://download.docker.com/linux/debian $(lsb_release -cs) stable"
@@ -37,4 +59,4 @@ WORKDIR $PYSETUP_PATH
 COPY startup.sh startup.sh
 RUN pulsar-config --host 0.0.0.0
 EXPOSE 8913
-CMD ["/bin/bash", "./startup.sh"]
\ No newline at end of file
+CMD ["/bin/bash", "./startup.sh"]
diff --git a/pyproject.toml b/pyproject.toml
index f750663..a75255b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,6 +10,7 @@ repository = "https://github.com/ndopj/tesp-api"
 
 [tool.poetry.dependencies]
 python = "^3.10.0"
+aio_pika = "^9.5.7"
 fastapi = "^0.75.1"
 orjson = "^3.6.8"
 gunicorn = "^20.1.0"
diff --git a/settings.toml b/settings.toml
index f48cf62..7665757 100644
--- a/settings.toml
+++ b/settings.toml
@@ -2,7 +2,8 @@
 db.mongodb_uri = "mongodb://localhost:27017"
 pulsar.url = "http://localhost:8913"
 pulsar.status.poll_interval = 4
-pulsar.status.max_polls = 100
+pulsar.status.max_polls = 400
+pulsar.client_timeout = 30
 logging.level = "DEBUG"
 logging.output_json = false
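A note on the `settings.toml` change: with `pulsar.status.poll_interval = 4`, raising `pulsar.status.max_polls` from 100 to 400 stretches the status-polling budget from roughly 6.5 to over 26 minutes, and the new `pulsar.client_timeout` feeds the aiohttp session timeout that used to be hard-coded to 2 seconds (see `pulsar_service.py` further down). A minimal sketch of how the three values interact; the literals mirror the settings and are not read from the config here:

```python
import aiohttp

# Values mirroring settings.toml (loaded via tesp_api.config.properties at runtime)
poll_interval = 4    # seconds between /jobs/{id}/status polls
max_polls = 400      # upper bound on poll attempts before LookupError
client_timeout = 30  # per-request aiohttp timeout, in seconds

# Worst-case wait inside job_status_complete(): 400 * 4 s = 1600 s (~26.6 min)
print(f"polling budget: {poll_interval * max_polls} s")

# Each individual status request is additionally capped by the session timeout:
timeout = aiohttp.ClientTimeout(total=client_timeout)
```
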
diff --git a/tesp_api/api/endpoints/endpoint_utils.py b/tesp_api/api/endpoints/endpoint_utils.py
index 0dc57f9..127a3ee 100644
--- a/tesp_api/api/endpoints/endpoint_utils.py
+++ b/tesp_api/api/endpoints/endpoint_utils.py
@@ -107,5 +107,14 @@ def resource_not_found_response(message: Maybe[str] = Nothing):
 
 def response_from_model(model: BaseModel, model_rules: dict = None) -> Response:
-    return Response(model.json(**(model_rules if model_rules else {}), by_alias=False),
-                    status_code=200, media_type='application/json')
+    response = Response(
+        model.json(**(model_rules if model_rules else {}), by_alias=False),
+        status_code=200,
+        media_type='application/json'
+    )
+    # FORCE NO CACHING
+    response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
+    response.headers["Pragma"] = "no-cache"
+    response.headers["Expires"] = "0"
+
+    return response
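The three headers added to `response_from_model` are the conventional trio for disabling both HTTP/1.1 (`Cache-Control`) and HTTP/1.0 (`Pragma`, `Expires`) caching, so clients polling a task always see fresh state. A quick sketch of how this can be verified, with a hypothetical FastAPI route wired to the helper:

```python
from fastapi import FastAPI
from fastapi.testclient import TestClient
from pydantic import BaseModel

from tesp_api.api.endpoints.endpoint_utils import response_from_model

app = FastAPI()

class DummyTask(BaseModel):  # hypothetical model, for illustration only
    state: str

@app.get("/demo")  # hypothetical route, not part of the real API
def demo():
    return response_from_model(DummyTask(state="RUNNING"))

client = TestClient(app)
resp = client.get("/demo")
assert resp.headers["Cache-Control"] == "no-cache, no-store, must-revalidate"
assert resp.headers["Pragma"] == "no-cache"
assert resp.headers["Expires"] == "0"
```
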
diff --git a/tesp_api/repository/task_repository.py b/tesp_api/repository/task_repository.py
index a1534f0..88e29cc 100644
--- a/tesp_api/repository/task_repository.py
+++ b/tesp_api/repository/task_repository.py
@@ -97,17 +97,20 @@ def cancel_task(
             p_author: Maybe[str],
             task_id: ObjectId
     ) -> Promise:
-        full_search_query = dict()
-        full_search_query.update({'_id': task_id})
-        full_search_query.update(p_author.maybe({}, lambda a: {'author': a}))
-
-        return Promise(lambda resolve, reject: resolve(full_search_query)) \
-            .then(self._tasks.find_one) \
-            .then(lambda _task: self.update_task(
-                {'_id': task_id},
-                {'$set': {'state': TesTaskState.CANCELED}}
-            )).map(lambda updated_task: updated_task
-                   .map(lambda _updated_task: _updated_task.id))\
+        search_query = {
+            '_id': task_id,
+            'state': {'$in': [
+                TesTaskState.QUEUED,
+                TesTaskState.INITIALIZING,
+                TesTaskState.RUNNING
+            ]}
+        }
+        search_query.update(p_author.maybe({}, lambda a: {'author': a}))
+        update_query = {'$set': {'state': TesTaskState.CANCELED}}
+
+        return self.update_task(search_query, update_query)\
+            .map(lambda updated_task: updated_task
+                 .map(lambda _updated_task: _updated_task.id))\
             .catch(handle_data_layer_error)
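The rewritten `cancel_task` folds the state precondition into the update filter itself, removing the read-then-write race of the old `find_one` → `update_task` sequence: a task that has already reached a terminal state simply matches nothing. A sketch of the underlying MongoDB primitive, assuming a Motor collection like the one wrapped by the repository (plain strings stand in for the `TesTaskState` enum, and db/collection names are assumptions):

```python
import asyncio
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient

CANCELLABLE = ['QUEUED', 'INITIALIZING', 'RUNNING']

async def cancel(tasks, task_id: ObjectId):
    # One atomic find-and-modify: the state filter is part of the match,
    # so a task already in a terminal state is left untouched (returns None).
    return await tasks.find_one_and_update(
        {'_id': task_id, 'state': {'$in': CANCELLABLE}},
        {'$set': {'state': 'CANCELED'}},
    )

async def main():
    tasks = AsyncIOMotorClient('mongodb://localhost:27017').tesp.tasks
    print(await cancel(tasks, ObjectId()))  # fresh id -> None, nothing to cancel

asyncio.run(main())
```
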
diff --git a/tesp_api/service/event_actions.py b/tesp_api/service/event_actions.py
index e436e49..5f70eaf 100644
--- a/tesp_api/service/event_actions.py
+++ b/tesp_api/service/event_actions.py
@@ -16,7 +16,7 @@
 from tesp_api.repository.task_repository import task_repository
 from tesp_api.service.file_transfer_service import file_transfer_service
 from tesp_api.service.error import pulsar_event_handle_error, TaskNotFoundError, TaskExecutorError
-from tesp_api.service.pulsar_operations import PulsarRestOperations, PulsarAmpqOperations, DataType
+from tesp_api.service.pulsar_operations import PulsarRestOperations, PulsarAmqpOperations, DataType
 from tesp_api.repository.model.task import (
     TesTaskState,
     TesTaskExecutor,
@@ -29,6 +29,7 @@
 
 CONTAINER_TYPE = os.getenv("CONTAINER_TYPE", "docker")
 
+
 @local_handler.register(event_name="queued_task")
 def handle_queued_task(event: Event) -> None:
     """
@@ -39,8 +40,9 @@ def handle_queued_task(event: Event) -> None:
     match pulsar_service.get_operations():
         case PulsarRestOperations() as pulsar_rest_operations:
             dispatch_event('queued_task_rest', {**payload, 'pulsar_operations': pulsar_rest_operations})
-        case PulsarAmpqOperations() as pulsar_ampq_operations:
-            dispatch_event('queued_task_ampq', {**payload, 'pulsar_operations': pulsar_ampq_operations})
+        case PulsarAmqpOperations() as pulsar_amqp_operations:
+            dispatch_event('queued_task_amqp', {**payload, 'pulsar_operations': pulsar_amqp_operations})
+
 
 @local_handler.register(event_name="queued_task_rest")
 async def handle_queued_task_rest(event: Event):
@@ -53,12 +55,37 @@ async def handle_queued_task_rest(event: Event):
 
     print(f"Queued task rest: {task_id}")
 
-    await Promise(lambda resolve, reject: resolve(None))\
-        .then(lambda nothing: pulsar_operations.setup_job(task_id))\
-        .map(lambda setup_job_result: dispatch_event('initialize_task', {**payload, 'task_config': setup_job_result}))\
-        .catch(lambda error: pulsar_event_handle_error(error, task_id, event_name, pulsar_operations))\
+    await Promise(lambda resolve, reject: resolve(None)) \
+        .then(lambda nothing: pulsar_operations.setup_job(task_id)) \
+        .map(lambda setup_job_result: dispatch_event('initialize_task', {**payload, 'task_config': setup_job_result})) \
+        .catch(lambda error: pulsar_event_handle_error(error, task_id, event_name, pulsar_operations)) \
         .then(lambda x: x)  # Invokes promise, potentially from error handler
 
+
+@local_handler.register(event_name="queued_task_amqp")
+async def handle_queued_task_amqp(event: Event):
+    """
+    Sets up the job in Pulsar via AMQP operations and dispatches an 'initialize_task' event.
+    """
+    event_name, payload = event
+    task_id: ObjectId = payload['task_id']
+    pulsar_operations: PulsarAmqpOperations = payload['pulsar_operations']
+
+    print(f"Queued task AMQP: {task_id}")
+
+    try:
+        # Setup job via AMQP
+        setup_job_result = await pulsar_operations.setup_job(task_id)
+
+        # Dispatch initialize event
+        await dispatch_event('initialize_task', {
+            **payload,
+            'task_config': setup_job_result
+        })
+    except Exception as error:
+        await pulsar_event_handle_error(error, task_id, event_name, pulsar_operations)
+
+
 @local_handler.register(event_name="initialize_task")
 async def handle_initializing_task(event: Event) -> None:
     """
@@ -69,12 +96,11 @@ async def handle_initializing_task(event: Event) -> None:
     task_id: ObjectId = payload['task_id']
     pulsar_operations: PulsarRestOperations = payload['pulsar_operations']
 
-    # Merged Logic: Using the feature-complete setup_data from the new version
     async def setup_data(job_id: ObjectId,
-                        resources: TesTaskResources,
-                        volumes: List[str],
-                        inputs: List[TesTaskInput],
-                        outputs: List[TesTaskOutput]):
+                         resources: TesTaskResources,
+                         volumes: List[str],
+                         inputs: List[TesTaskInput],
+                         outputs: List[TesTaskOutput]):
         resource_conf: dict
         volume_confs: List[dict] = []
         input_confs: List[dict] = []
@@ -109,28 +135,29 @@ async def setup_data(job_id: ObjectId,
         return resource_conf, volume_confs, input_confs, output_confs
 
     print(f"Initializing task: {task_id}")
-    await Promise(lambda resolve, reject: resolve(None))\
+    await Promise(lambda resolve, reject: resolve(None)) \
         .then(lambda nothing: task_repository.update_task_state(
-            task_id,
-            TesTaskState.QUEUED,
-            TesTaskState.INITIALIZING
-        )).map(lambda updated_task: get_else_throw(
-            updated_task, TaskNotFoundError(task_id, Just(TesTaskState.QUEUED))
-        )).then(lambda updated_task: setup_data(
-            task_id,
-            maybe_of(updated_task.resources).maybe(None, lambda x: x),
-            maybe_of(updated_task.volumes).maybe([], lambda x: x),
-            maybe_of(updated_task.inputs).maybe([], lambda x: x),
-            maybe_of(updated_task.outputs).maybe([], lambda x: x)
-        )).map(lambda res_input_output_confs: dispatch_event('run_task', {
-            **payload,
-            'resource_conf': res_input_output_confs[0],
-            'volume_confs': res_input_output_confs[1],
-            'input_confs': res_input_output_confs[2],
-            'output_confs': res_input_output_confs[3]
-        })).catch(lambda error: pulsar_event_handle_error(error, task_id, event_name, pulsar_operations))\
+            task_id,
+            TesTaskState.QUEUED,
+            TesTaskState.INITIALIZING
+        )).map(lambda updated_task: get_else_throw(
+            updated_task, TaskNotFoundError(task_id, Just(TesTaskState.QUEUED))
+        )).then(lambda updated_task: setup_data(
+            task_id,
+            maybe_of(updated_task.resources).maybe(None, lambda x: x),
+            maybe_of(updated_task.volumes).maybe([], lambda x: x),
+            maybe_of(updated_task.inputs).maybe([], lambda x: x),
+            maybe_of(updated_task.outputs).maybe([], lambda x: x)
+        )).map(lambda res_input_output_confs: dispatch_event('run_task', {
+            **payload,
+            'resource_conf': res_input_output_confs[0],
+            'volume_confs': res_input_output_confs[1],
+            'input_confs': res_input_output_confs[2],
+            'output_confs': res_input_output_confs[3]
+        })).catch(lambda error: pulsar_event_handle_error(error, task_id, event_name, pulsar_operations)) \
         .then(lambda x: x)
 
+
 @local_handler.register(event_name="run_task")
 async def handle_run_task(event: Event) -> None:
     """
@@ -146,8 +173,8 @@ async def handle_run_task(event: Event) -> None:
     input_confs: List[dict] = payload['input_confs']
     output_confs: List[dict] = payload['output_confs']
     pulsar_operations: PulsarRestOperations = payload['pulsar_operations']
-
-    run_command_str = None
+
+    run_command_str = None
     command_start_time = datetime.datetime.now(datetime.timezone.utc)
 
     try:
@@ -175,7 +202,7 @@ async def handle_run_task(event: Event) -> None:
         )
 
         stage_exec = TesTaskExecutor(image="willdockerhub/curl-wget:latest", command=[], workdir=Path("/downloads"))
-
+
         # Stage-in command
         stage_in_cmd = ""
         stage_in_mount = ""
@@ -211,7 +238,6 @@ async def handle_run_task(event: Event) -> None:
         non_empty_parts = [p.strip() for p in parts if p and p.strip()]
         run_command_str = " && ".join(non_empty_parts) if non_empty_parts else None
 
-        # Resume with the polished version's logic for execution and state management
         command_start_time = datetime.datetime.now(datetime.timezone.utc)
 
         command_status: dict
@@ -231,27 +257,28 @@ async def handle_run_task(event: Event) -> None:
             command_status.get('returncode', -1)
         )
 
-        current_task_monad = await task_repository.get_task(maybe_of(author), {'_id': task_id})
-        current_task_obj = get_else_throw(current_task_monad, TaskNotFoundError(task_id))
+        current_task_monad = await task_repository.get_task(maybe_of(author), {'_id': task_id})
+        current_task_obj = get_else_throw(current_task_monad, TaskNotFoundError(task_id))
 
         if current_task_obj.state == TesTaskState.CANCELED:
             print(f"Task {task_id} found CANCELED after job completion polling. Aborting state changes.")
-            return
+            return
 
         if command_status.get('returncode', -1) != 0:
-            print(f"Task {task_id} executor error (return code: {command_status.get('returncode', -1)}). Setting state to EXECUTOR_ERROR.")
+            print(
+                f"Task {task_id} executor error (return code: {command_status.get('returncode', -1)}). Setting state to EXECUTOR_ERROR.")
             await task_repository.update_task_state(task_id, TesTaskState.RUNNING, TesTaskState.EXECUTOR_ERROR)
             await pulsar_operations.erase_job(task_id)
-            return
+            return
 
         print(f"Task {task_id} completed successfully. Setting state to COMPLETE.")
         await Promise(lambda resolve, reject: resolve(None)) \
             .then(lambda ignored: task_repository.update_task_state(
-                task_id, TesTaskState.RUNNING, TesTaskState.COMPLETE
-            )) \
+                task_id, TesTaskState.RUNNING, TesTaskState.COMPLETE
+            )) \
            .map(lambda task_after_complete_update: get_else_throw(
-                task_after_complete_update, TaskNotFoundError(task_id, Just(TesTaskState.RUNNING))
-            )) \
+                task_after_complete_update, TaskNotFoundError(task_id, Just(TesTaskState.RUNNING))
+            )) \
            .then(lambda ignored: pulsar_operations.erase_job(task_id)) \
            .catch(lambda error: pulsar_event_handle_error(error, task_id, event_name, pulsar_operations)) \
            .then(lambda x: x)
@@ -262,22 +289,24 @@ async def handle_run_task(event: Event) -> None:
         await pulsar_operations.kill_job(task_id)
         await pulsar_operations.erase_job(task_id)
         print(f"Task {task_id} Pulsar job cleanup attempted after asyncio cancellation.")
-
+
     except Exception as error:
         print(f"Exception in handle_run_task for task {task_id}: {type(error).__name__} - {error}")
 
         task_state_after_error_monad = await task_repository.get_task(maybe_of(author), {'_id': task_id})
         if task_state_after_error_monad.is_just() and task_state_after_error_monad.value.state == TesTaskState.CANCELED:
-            print(f"Task {task_id} is already CANCELED. Exception '{type(error).__name__}' likely due to this. No further error processing by handler.")
-            return
+            print(
+                f"Task {task_id} is already CANCELED. Exception '{type(error).__name__}' likely due to this. No further error processing by handler.")
+            return
 
         print(f"Task {task_id} not CANCELED; proceeding with pulsar_event_handle_error for '{type(error).__name__}'.")
         error_handler_result = pulsar_event_handle_error(error, task_id, event_name, pulsar_operations)
         if asyncio.iscoroutine(error_handler_result) or isinstance(error_handler_result, _Promise):
             await error_handler_result
-
-        try:
-            print(f"Ensuring Pulsar job for task {task_id} is erased after general error handling in run_task.")
-            await pulsar_operations.erase_job(task_id)
-        except Exception as final_erase_error:
-            print(f"Error during final Pulsar erase attempt for task {task_id} after general error: {final_erase_error}")
+
+        # try:
+        #     print(f"Ensuring Pulsar job for task {task_id} is erased after general error handling in run_task.")
+        #     await pulsar_operations.erase_job(task_id)
+        # except Exception as final_erase_error:
+        #     print(
+        #         f"Error during final Pulsar erase attempt for task {task_id} after general error: {final_erase_error}")
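For orientation: `handle_run_task` builds one shell pipeline per task by concatenating the stage-in, executor, and stage-out snippets with `&&`, dropping empty parts, exactly as the `non_empty_parts` lines in the hunk above do. A standalone illustration of that assembly step (the part strings are made up):

```python
# Mirrors the command assembly in handle_run_task: empty or whitespace-only
# parts are dropped, the rest are chained so any failing stage aborts the run.
parts = [
    'mkdir -p "/downloads" && curl -f -o f.txt http://example.org/f.txt',  # stage-in
    '',                                                                    # no stage-out needed
    "singularity exec docker://alpine:latest sh -c 'wc -l f.txt'",         # executor
]
non_empty_parts = [p.strip() for p in parts if p and p.strip()]
run_command_str = " && ".join(non_empty_parts) if non_empty_parts else None
print(run_command_str)
```
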
Setting state to COMPLETE.") await Promise(lambda resolve, reject: resolve(None)) \ .then(lambda ignored: task_repository.update_task_state( - task_id, TesTaskState.RUNNING, TesTaskState.COMPLETE - )) \ + task_id, TesTaskState.RUNNING, TesTaskState.COMPLETE + )) \ .map(lambda task_after_complete_update: get_else_throw( - task_after_complete_update, TaskNotFoundError(task_id, Just(TesTaskState.RUNNING)) - )) \ + task_after_complete_update, TaskNotFoundError(task_id, Just(TesTaskState.RUNNING)) + )) \ .then(lambda ignored: pulsar_operations.erase_job(task_id)) \ .catch(lambda error: pulsar_event_handle_error(error, task_id, event_name, pulsar_operations)) \ .then(lambda x: x) @@ -262,22 +289,24 @@ async def handle_run_task(event: Event) -> None: await pulsar_operations.kill_job(task_id) await pulsar_operations.erase_job(task_id) print(f"Task {task_id} Pulsar job cleanup attempted after asyncio cancellation.") - + except Exception as error: print(f"Exception in handle_run_task for task {task_id}: {type(error).__name__} - {error}") task_state_after_error_monad = await task_repository.get_task(maybe_of(author), {'_id': task_id}) if task_state_after_error_monad.is_just() and task_state_after_error_monad.value.state == TesTaskState.CANCELED: - print(f"Task {task_id} is already CANCELED. Exception '{type(error).__name__}' likely due to this. No further error processing by handler.") - return + print( + f"Task {task_id} is already CANCELED. Exception '{type(error).__name__}' likely due to this. No further error processing by handler.") + return print(f"Task {task_id} not CANCELED; proceeding with pulsar_event_handle_error for '{type(error).__name__}'.") error_handler_result = pulsar_event_handle_error(error, task_id, event_name, pulsar_operations) if asyncio.iscoroutine(error_handler_result) or isinstance(error_handler_result, _Promise): await error_handler_result - - try: - print(f"Ensuring Pulsar job for task {task_id} is erased after general error handling in run_task.") - await pulsar_operations.erase_job(task_id) - except Exception as final_erase_error: - print(f"Error during final Pulsar erase attempt for task {task_id} after general error: {final_erase_error}") + + # try: + # print(f"Ensuring Pulsar job for task {task_id} is erased after general error handling in run_task.") + # await pulsar_operations.erase_job(task_id) + # except Exception as final_erase_error: + # print( + # f"Error during final Pulsar erase attempt for task {task_id} after general error: {final_erase_error}") diff --git a/tesp_api/service/pulsar_operations.py b/tesp_api/service/pulsar_operations.py index f9f870f..fa470c5 100644 --- a/tesp_api/service/pulsar_operations.py +++ b/tesp_api/service/pulsar_operations.py @@ -1,4 +1,6 @@ import asyncio +import json +import aio_pika from enum import Enum from typing import Literal from abc import ABC, abstractmethod @@ -38,9 +40,22 @@ def __repr__(self): class PulsarOperations(ABC): @abstractmethod - def erase_job(self, task_id: ObjectId): - pass + def erase_job(self, task_id: ObjectId): pass + + @abstractmethod + def setup_job(self, job_id: ObjectId): pass + + @abstractmethod + def run_job(self, job_id: ObjectId, run_command: str): pass + + @abstractmethod + def job_status_complete(self, job_id: str): pass + + @abstractmethod + def upload(self, job_id: ObjectId, io_type: TesTaskIOType, file_path: str, file_content: Maybe[str]): pass + @abstractmethod + def download_output(self, job_id: ObjectId, file_name: str): pass class PulsarRestOperations(PulsarOperations): @@ -66,7 +81,7 
diff --git a/tesp_api/service/pulsar_service.py b/tesp_api/service/pulsar_service.py
index a3cd21d..59bcfcd 100644
--- a/tesp_api/service/pulsar_service.py
+++ b/tesp_api/service/pulsar_service.py
@@ -3,7 +3,7 @@
 from socket import AF_INET
 
 from tesp_api.config.properties import properties
-from tesp_api.service.pulsar_operations import PulsarRestOperations, PulsarOperations
+from tesp_api.service.pulsar_operations import PulsarRestOperations, PulsarOperations, PulsarAmqpOperations
 
 
 # aiohttp tracing feature allows to log each request
@@ -14,18 +14,27 @@ async def on_request_start(session, context, params):
 
 class PulsarService:
     def __init__(self):
-        timeout = aiohttp.ClientTimeout(total=2)
+        timeout = aiohttp.ClientTimeout(total=properties.pulsar.client_timeout)
         connector = aiohttp.TCPConnector(family=AF_INET, limit_per_host=100)
         trace_config = aiohttp.TraceConfig()
         trace_config.on_request_start.append(on_request_start)
         self.pulsar_client = aiohttp.ClientSession(timeout=timeout, connector=connector, trace_configs=[trace_config])
 
     def get_operations(self) -> PulsarOperations:
+        if hasattr(properties.pulsar, 'amqp_url') and properties.pulsar.amqp_url:
+            return PulsarAmqpOperations(
+                amqp_url=properties.pulsar.amqp_url,
+                pulsar_client=self.pulsar_client,
+                base_url=properties.pulsar.url,
+                status_poll_interval=properties.pulsar.status.poll_interval,
+                status_max_polls=properties.pulsar.status.max_polls
+            )
         return PulsarRestOperations(
             self.pulsar_client,
             properties.pulsar.url,
             properties.pulsar.status.poll_interval,
-            properties.pulsar.status.max_polls)
+            properties.pulsar.status.max_polls
+        )
 
 
 pulsar_service = PulsarService()
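Backend selection is now purely config-driven: `get_operations()` returns the AMQP implementation whenever `pulsar.amqp_url` is set and non-empty, and falls back to REST otherwise. A toy illustration of the toggle with a stand-in for the Dynaconf-style `properties` object:

```python
from types import SimpleNamespace

# Hypothetical stand-in for tesp_api.config.properties
properties = SimpleNamespace(pulsar=SimpleNamespace(
    url="http://localhost:8913",
    amqp_url="amqp://guest:guest@localhost/",  # remove or set to "" for REST
))

use_amqp = hasattr(properties.pulsar, 'amqp_url') and properties.pulsar.amqp_url
print("AMQP backend" if use_amqp else "REST backend")
```
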
diff --git a/tesp_api/utils/container.py b/tesp_api/utils/container.py
index 9daea2d..b55981b 100644
--- a/tesp_api/utils/container.py
+++ b/tesp_api/utils/container.py
@@ -22,6 +22,7 @@ def __init__(self, container_type: str) -> None:
         self._envs: Dict[str, str] = {}
         self._volumes: Dict[str, str] = {}
         self._bind_mounts: Dict[str, str] = {}
+        self._dirs_to_create: List[str] = []
         self._command: Maybe[str] = Nothing
         self._stdin: Maybe[str] = Nothing
         self._stdout: Maybe[str] = Nothing
@@ -46,9 +47,21 @@ def with_volume(self, container_path: str, volume_name: str):
         self._volumes[container_path] = volume_name
         return self
 
+    def with_directory(self, path: str):
+        self._dirs_to_create.append(path)
+        return self
+
     def with_image(self, image: str):
-        if self.container_type == "singularity" and not image.startswith("docker://"):
-            self._image = Just(f"docker://{image}")
+        if self.container_type == "singularity":
+            # If it's an absolute path, use as-is
+            if os.path.isabs(image):
+                self._image = Just(image)
+            # If it's already got a URI scheme, use as-is
+            elif image.startswith(("docker://", "library://", "shub://", "oras://")):
+                self._image = Just(image)
+            # Otherwise, assume Docker Hub reference
+            else:
+                self._image = Just(f"docker://{image}")
         else:
             self._image = Just(image)
         return self
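The practical effect of `with_directory`/`_dirs_to_create` (shown across this file's hunks) is that Singularity command lines get a `mkdir -p` prefix for every host directory that will be bind-mounted, since Singularity, unlike Docker with named volumes, does not create missing host paths. A sketch of the expected output shape, using the builder from the diff (paths and image are illustrative, and the exact spacing of the generated string may differ):

```python
from tesp_api.utils.container import ContainerCommandBuilder

cmd = (ContainerCommandBuilder("singularity")
       .with_image("alpine:latest")            # rewritten to docker://alpine:latest
       .with_directory("/tmp/job1/downloads")  # pre-created via mkdir -p
       .with_bind_mount("/downloads", "/tmp/job1/downloads")
       .with_workdir("/downloads")
       .with_command(["sh", "-c", "ls -la"])
       .get_run_command())
print(cmd)
# roughly: mkdir -p "/tmp/job1/downloads" && singularity exec --pwd "/downloads"
#          -B "/tmp/job1/downloads":"/downloads" docker://alpine:latest sh -c 'ls -la'
```
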
@@ -141,6 +154,7 @@ def reset(self) -> None:
         self._envs = {}
         self._volumes = {}
         self._bind_mounts = {}
+        self._dirs_to_create = []
         self._command = Nothing
         self._stdin = Nothing
         self._stdout = Nothing
@@ -149,11 +163,11 @@ def reset(self) -> None:
 
     def get_run_command(self) -> str:
         # Common resource flags
-        cpu_flag = self._resource_cpu.maybe("", lambda cpu:
-            f"--cpus={cpu}" if self.container_type == "docker" else f"--cpu={cpu}")
-
-        mem_flag = self._resource_mem.maybe("", lambda mem:
-            f"--memory={mem}g" if self.container_type == "docker" else f"--memory={mem}G")
+        cpu_flag = self._resource_cpu.maybe("", lambda cpu:
+            f"--cpus={cpu}" if self.container_type == "docker" else f"--cpu={cpu}")
+
+        mem_flag = self._resource_mem.maybe("", lambda mem:
+            f"--memory={mem}g" if self.container_type == "docker" else f"--memory={mem}G")
 
         # Environment variables
         env_flags = []
@@ -164,14 +178,22 @@ def get_run_command(self) -> str:
             env_flags.append(f'--env {k}="{v}"')
 
         # Work directory
-        workdir_flag = self._workdir.maybe("", lambda w:
-            f'-w "{w}"' if self.container_type == "docker" else f'--pwd "{w}"')
+        workdir_flag = self._workdir.maybe("", lambda w:
+            f'-w "{w}"' if self.container_type == "docker" else f'--pwd "{w}"')
 
         # Image
         image = self._image.maybe("", lambda i: i)
 
         # Mounts
         mount_flags = []
+        mkdir_cmds = []  # for singularity pre-creation
+
+        # Handle explicit directories (Singularity only)
+        if self.container_type == "singularity":
+            for path in self._dirs_to_create:
+                mkdir_cmds.append(f'mkdir -p "{path}"')
+
+        # Handle mounts
         if self.container_type == "docker":
             for container_path, host_path in self._bind_mounts.items():
                 mount_flags.append(f'-v "{host_path}":"{container_path}"')
@@ -181,11 +203,13 @@ def get_run_command(self) -> str:
             for container_path, host_path in self._bind_mounts.items():
                 mount_flags.append(f'-B "{host_path}":"{container_path}"')
             for container_path, volume_name in self._volumes.items():
+                # Volumes are always directories, so we can auto-create them
+                mkdir_cmds.append(f'mkdir -p "{volume_name}"')
                 mount_flags.append(f'-B "{volume_name}":"{container_path}"')
 
         # Command
-        command_str = self._command.maybe("", lambda cmd:
-            " ".join(shlex.quote(arg) for arg in cmd) if isinstance(cmd, list) else cmd)
+        command_str = self._command.maybe("", lambda cmd:
+            " ".join(shlex.quote(arg) for arg in cmd) if isinstance(cmd, list) else cmd)
 
         # Build final command
         if self.container_type == "docker":
@@ -195,11 +219,14 @@ def get_run_command(self) -> str:
                 f"{image} {command_str}"
             ).strip()
         else:  # singularity
-            return (
+            mkdir_prefix = " && ".join(mkdir_cmds)
+            singularity_cmd = (
                 f"singularity exec {cpu_flag} {mem_flag} {workdir_flag} "
                 f"{' '.join(env_flags)} {' '.join(mount_flags)} "
                 f"{image} {command_str}"
             ).strip()
+            return f"{mkdir_prefix} && {singularity_cmd}" if mkdir_prefix else singularity_cmd
+
 
 # Unified command functions
 def stage_in_command(
@@ -212,7 +239,8 @@ def stage_in_command(
     builder = ContainerCommandBuilder(container_type) \
         .with_image(executor.image) \
         .with_workdir(executor.workdir) \
-        .with_resource(resource_conf)
+        .with_resource(resource_conf) \
+        .with_directory(bind_mount)
 
     commands = []
     for input_conf in input_confs:
@@ -227,12 +255,12 @@ def stage_in_command(
         if input_type == TesTaskIOType.DIRECTORY:
             # Recursive download
             commands.append(
-                f"wget --mirror --no-parent --no-host-directories "
+                f"wget -e robots=off --mirror --no-parent --no-host-directories "
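A note on the stage-in hardening in this hunk: `wget -e robots=off` keeps recursive directory downloads from being silently truncated by a `robots.txt`, and `curl -f` makes HTTP error statuses fail the command (and therefore the whole `&&` chain) instead of saving an error page as the input file. A hypothetical minimal equivalent of the CI's `tests/upload_server.py`, which after this change must serve the `tests/` tree on port 5000 so the fixtures' `/test_data/...` URLs resolve:

```python
# Hypothetical stand-in for tests/upload_server.py: serves files under
# tests/ so that /test_data/test.txt resolves like in the JSON fixtures below.
from functools import partial
from http.server import HTTPServer, SimpleHTTPRequestHandler

handler = partial(SimpleHTTPRequestHandler, directory="tests")
HTTPServer(("0.0.0.0", 5000), handler).serve_forever()
```
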
f"--directory-prefix={shlex.quote(filename)} {shlex.quote(url)}" ) else: # Single file download - commands.append(f"curl -o {shlex.quote(filename)} {shlex.quote(url)}") + commands.append(f"curl -f -o {shlex.quote(filename)} {shlex.quote(url)}") if commands: builder.with_command(["sh", "-c", " && ".join(commands)]) @@ -278,8 +306,14 @@ def run_command( builder.with_volume(volume_conf['container_path'], volume_conf['volume_name']) for input_conf in input_confs: + if input_conf.get('type') == TesTaskIOType.DIRECTORY: + builder.with_directory(input_conf['pulsar_path']) + builder.with_bind_mount(input_conf['container_path'], input_conf['pulsar_path']) + if not executor.workdir and container_type == "singularity": + builder.with_workdir("/") + return builder.get_run_command() def stage_out_command( @@ -306,11 +340,11 @@ def stage_out_command( # Recursive upload cmd = ( f"find {shlex.quote(path)} -type f -exec " - f"curl -X POST -F 'file=@{{}}' {shlex.quote(url)} \\;" + f"curl -f -X POST -F 'file=@{{}}' {shlex.quote(url)} \\;" ) else: # Single file upload - cmd = f"curl -X POST -F 'file=@{shlex.quote(path)}' {shlex.quote(url)}" + cmd = f"curl -f -X POST -F 'file=@{shlex.quote(path)}' {shlex.quote(url)}" commands.append(cmd) @@ -320,12 +354,10 @@ def stage_out_command( # Mount required directories if container_type == "singularity" and bind_mount: builder.with_bind_mount(executor.workdir, bind_mount) + builder.with_directory(bind_mount) for volume_conf in volume_confs: - if container_type == "docker": - builder.with_volume(volume_conf['container_path'], volume_conf['volume_name']) - elif container_type == "singularity" and job_directory: - builder.with_bind_mount(volume_conf['container_path'], job_directory) + builder.with_volume(volume_conf['container_path'], volume_conf['volume_name']) if executor.env: for env_name, env_value in executor.env.items(): @@ -333,7 +365,6 @@ def stage_out_command( return builder.get_run_command() -# Volume mapping function remains the same def map_volumes(job_id: str, volumes: List[str], outputs: List[TesTaskOutput]): output_confs: List[dict] = [] volume_confs: List[dict] = [] diff --git a/tests/test_data/input_file.txt b/tests/test_data/input_file.txt new file mode 100644 index 0000000..c7602bf --- /dev/null +++ b/tests/test_data/input_file.txt @@ -0,0 +1 @@ +This is an input test file. diff --git a/tests/test_jsons/inputs.json b/tests/test_jsons/inputs.json index 569d55c..02f08dd 100644 --- a/tests/test_jsons/inputs.json +++ b/tests/test_jsons/inputs.json @@ -7,7 +7,7 @@ }, { "path": "/data/file_http", - "url": "http://172.17.0.1:5000/download/test.txt", + "url": "http://172.17.0.1:5000/test_data/test.txt", "type": "FILE" } ], diff --git a/tests/test_jsons/volumes.json b/tests/test_jsons/volumes.json index 2e0b45c..0af4527 100644 --- a/tests/test_jsons/volumes.json +++ b/tests/test_jsons/volumes.json @@ -2,7 +2,7 @@ "inputs": [ { "path": "/data/file_http", - "url": "http://172.17.0.1:5000/download/test.txt", + "url": "http://172.17.0.1:5000/test_data/test.txt", "type": "FILE" } ],