diff --git a/dojo/api_v2/views.py b/dojo/api_v2/views.py index 382a8952dc7..09791d1c0f0 100644 --- a/dojo/api_v2/views.py +++ b/dojo/api_v2/views.py @@ -46,6 +46,7 @@ ) from dojo.api_v2.prefetch.prefetcher import _Prefetcher from dojo.authorization.roles_permissions import Permissions +from dojo.celery_dispatch import dojo_dispatch_task from dojo.cred.queries import get_authorized_cred_mappings from dojo.endpoint.queries import ( get_authorized_endpoint_status, @@ -678,13 +679,13 @@ def update_jira_epic(self, request, pk=None): try: if engagement.has_jira_issue: - jira_helper.update_epic(engagement.id, **request.data) + dojo_dispatch_task(jira_helper.update_epic, engagement.id, **request.data) response = Response( {"info": "Jira Epic update query sent"}, status=status.HTTP_200_OK, ) else: - jira_helper.add_epic(engagement.id, **request.data) + dojo_dispatch_task(jira_helper.add_epic, engagement.id, **request.data) response = Response( {"info": "Jira Epic create query sent"}, status=status.HTTP_200_OK, diff --git a/dojo/celery.py b/dojo/celery.py index ead4a8813a8..3cf09e1bc2c 100644 --- a/dojo/celery.py +++ b/dojo/celery.py @@ -12,16 +12,56 @@ os.environ.setdefault("DJANGO_SETTINGS_MODULE", "dojo.settings.settings") -class PgHistoryTask(Task): +class DojoAsyncTask(Task): + + """ + Base task class that provides dojo_async_task functionality without using a decorator. + + This class: + - Injects user context into task kwargs + - Tracks task calls for performance testing + - Supports all Celery features (signatures, chords, groups, chains) + """ + + def apply_async(self, args=None, kwargs=None, **options): + """Override apply_async to inject user context and track tasks.""" + from dojo.decorators import dojo_async_task_counter # noqa: PLC0415 circular import + from dojo.utils import get_current_user # noqa: PLC0415 circular import + + if kwargs is None: + kwargs = {} + + # Inject user context if not already present + if "async_user" not in kwargs: + kwargs["async_user"] = get_current_user() + + # Control flag used for sync/async decision; never pass into the task itself + kwargs.pop("sync", None) + + # Track dispatch + dojo_async_task_counter.incr( + self.name, + args=args, + kwargs=kwargs, + ) + + # Call parent to execute async + return super().apply_async(args=args, kwargs=kwargs, **options) + + +class PgHistoryTask(DojoAsyncTask): """ Custom Celery base task that automatically applies pghistory context. - When a task is dispatched via dojo_async_task, the current pghistory - context is captured and passed in kwargs as "_pgh_context". This base - class extracts that context and applies it before running the task, - ensuring all database events share the same context as the original - request. + This class inherits from DojoAsyncTask to provide: + - User context injection and task tracking (from DojoAsyncTask) + - Automatic pghistory context application (from this class) + + When a task is dispatched via dojo_dispatch_task or dojo_async_task, the current + pghistory context is captured and passed in kwargs as "_pgh_context". This base + class extracts that context and applies it before running the task, ensuring all + database events share the same context as the original request. 
""" def __call__(self, *args, **kwargs): diff --git a/dojo/celery_dispatch.py b/dojo/celery_dispatch.py new file mode 100644 index 00000000000..a96c4257553 --- /dev/null +++ b/dojo/celery_dispatch.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol, cast + +from celery.canvas import Signature + +if TYPE_CHECKING: + from collections.abc import Mapping + + +class _SupportsSi(Protocol): + def si(self, *args: Any, **kwargs: Any) -> Signature: ... + + +class _SupportsApplyAsync(Protocol): + def apply_async(self, args: Any | None = None, kwargs: Any | None = None, **options: Any) -> Any: ... + + +def _inject_async_user(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: + result: dict[str, Any] = dict(kwargs or {}) + if "async_user" not in result: + from dojo.utils import get_current_user # noqa: PLC0415 circular import + + result["async_user"] = get_current_user() + return result + + +def _inject_pghistory_context(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: + """Capture and inject pghistory context if available.""" + result: dict[str, Any] = dict(kwargs or {}) + if "_pgh_context" not in result: + from dojo.pghistory_utils import get_serializable_pghistory_context # noqa: PLC0415 circular import + + if pgh_context := get_serializable_pghistory_context(): + result["_pgh_context"] = pgh_context + return result + + +def dojo_create_signature(task_or_sig: _SupportsSi | Signature, *args: Any, **kwargs: Any) -> Signature: + """ + Build a Celery signature with DefectDojo user context and pghistory context injected. + + - If passed a task, returns `task_or_sig.si(*args, **kwargs)`. + - If passed an existing signature, returns a cloned signature with merged kwargs. + """ + injected = _inject_async_user(kwargs) + injected = _inject_pghistory_context(injected) + injected.pop("countdown", None) + + if isinstance(task_or_sig, Signature): + merged_kwargs = {**(task_or_sig.kwargs or {}), **injected} + return task_or_sig.clone(kwargs=merged_kwargs) + + return task_or_sig.si(*args, **injected) + + +def dojo_dispatch_task(task_or_sig: _SupportsSi | _SupportsApplyAsync | Signature, *args: Any, **kwargs: Any) -> Any: + """ + Dispatch a task/signature using DefectDojo semantics. + + - Inject `async_user` if missing. + - Capture and inject pghistory context if available. + - Respect `sync=True` (foreground execution) and user `block_execution`. + - Support `countdown=` for async dispatch. + + Returns: + - async: AsyncResult-like return from Celery + - sync: underlying return value of the task + + """ + from dojo.decorators import dojo_async_task_counter, we_want_async # noqa: PLC0415 circular import + + countdown = cast("int", kwargs.pop("countdown", 0)) + injected = _inject_async_user(kwargs) + injected = _inject_pghistory_context(injected) + + sig = dojo_create_signature(task_or_sig if isinstance(task_or_sig, Signature) else cast("_SupportsSi", task_or_sig), *args, **injected) + sig_kwargs = dict(sig.kwargs or {}) + + if we_want_async(*sig.args, func=getattr(sig, "type", None), **sig_kwargs): + # DojoAsyncTask.apply_async tracks async dispatch. Avoid double-counting here. 
+ return sig.apply_async(countdown=countdown) + + # Track foreground execution as a "created task" as well (matches historical dojo_async_task behavior) + dojo_async_task_counter.incr(str(sig.task), args=sig.args, kwargs=sig_kwargs) + + sig_kwargs.pop("sync", None) + sig = sig.clone(kwargs=sig_kwargs) + eager = sig.apply() + return eager.get(propagate=True) diff --git a/dojo/endpoint/views.py b/dojo/endpoint/views.py index 1dc4df898c6..5d604eb9890 100644 --- a/dojo/endpoint/views.py +++ b/dojo/endpoint/views.py @@ -18,6 +18,7 @@ from dojo.authorization.authorization import user_has_permission_or_403 from dojo.authorization.authorization_decorators import user_is_authorized from dojo.authorization.roles_permissions import Permissions +from dojo.celery_dispatch import dojo_dispatch_task from dojo.endpoint.queries import get_authorized_endpoints from dojo.endpoint.utils import clean_hosts_run, endpoint_meta_import from dojo.filters import EndpointFilter, EndpointFilterWithoutObjectLookups @@ -373,7 +374,7 @@ def endpoint_bulk_update_all(request, pid=None): product_calc = list(Product.objects.filter(endpoint__id__in=endpoints_to_update).distinct()) endpoints.delete() for prod in product_calc: - calculate_grade(prod.id) + dojo_dispatch_task(calculate_grade, prod.id) if skipped_endpoint_count > 0: add_error_message_to_response(f"Skipped deletion of {skipped_endpoint_count} endpoints because you are not authorized.") diff --git a/dojo/engagement/services.py b/dojo/engagement/services.py index cd70af1ea2c..42a7c1c05e4 100644 --- a/dojo/engagement/services.py +++ b/dojo/engagement/services.py @@ -5,6 +5,7 @@ from django.dispatch import receiver import dojo.jira_link.helper as jira_helper +from dojo.celery_dispatch import dojo_dispatch_task from dojo.models import Engagement logger = logging.getLogger(__name__) @@ -16,7 +17,7 @@ def close_engagement(eng): eng.save() if jira_helper.get_jira_project(eng): - jira_helper.close_epic(eng.id, push_to_jira=True) + dojo_dispatch_task(jira_helper.close_epic, eng.id, push_to_jira=True) def reopen_engagement(eng): diff --git a/dojo/engagement/views.py b/dojo/engagement/views.py index 405f3953347..855511fce0a 100644 --- a/dojo/engagement/views.py +++ b/dojo/engagement/views.py @@ -37,6 +37,7 @@ from dojo.authorization.authorization import user_has_permission_or_403 from dojo.authorization.authorization_decorators import user_is_authorized from dojo.authorization.roles_permissions import Permissions +from dojo.celery_dispatch import dojo_dispatch_task from dojo.endpoint.utils import save_endpoints_to_add from dojo.engagement.queries import get_authorized_engagements from dojo.engagement.services import close_engagement, reopen_engagement @@ -390,7 +391,7 @@ def copy_engagement(request, eid): form = DoneForm(request.POST) if form.is_valid(): engagement_copy = engagement.copy() - calculate_grade(product.id) + dojo_dispatch_task(calculate_grade, product.id) messages.add_message( request, messages.SUCCESS, diff --git a/dojo/finding/deduplication.py b/dojo/finding/deduplication.py index eb8baf40db0..301884349a2 100644 --- a/dojo/finding/deduplication.py +++ b/dojo/finding/deduplication.py @@ -8,7 +8,6 @@ from django.db.models.query_utils import Q from dojo.celery import app -from dojo.decorators import dojo_async_task from dojo.models import Finding, System_Settings logger = logging.getLogger(__name__) @@ -45,13 +44,11 @@ def get_finding_models_for_deduplication(finding_ids): ) -@dojo_async_task @app.task def do_dedupe_finding_task(new_finding_id, *args, **kwargs): 
return do_dedupe_finding_task_internal(Finding.objects.get(id=new_finding_id), *args, **kwargs) -@dojo_async_task @app.task def do_dedupe_batch_task(finding_ids, *args, **kwargs): """ diff --git a/dojo/finding/helper.py b/dojo/finding/helper.py index 908afee38b9..03e5fd409a2 100644 --- a/dojo/finding/helper.py +++ b/dojo/finding/helper.py @@ -16,7 +16,6 @@ import dojo.jira_link.helper as jira_helper import dojo.risk_acceptance.helper as ra_helper from dojo.celery import app -from dojo.decorators import dojo_async_task from dojo.endpoint.utils import endpoint_get_or_create, save_endpoints_to_add from dojo.file_uploads.helper import delete_related_files from dojo.finding.deduplication import ( @@ -391,7 +390,6 @@ def add_findings_to_auto_group(name, findings, group_by, *, create_finding_group finding_group.findings.add(*findings) -@dojo_async_task @app.task def post_process_finding_save(finding_id, dedupe_option=True, rules_option=True, product_grading_option=True, # noqa: FBT002 issue_updater_option=True, push_to_jira=False, user=None, *args, **kwargs): # noqa: FBT002 - this is bit hard to fix nice have this universally fixed @@ -436,7 +434,9 @@ def post_process_finding_save_internal(finding, dedupe_option=True, rules_option if product_grading_option: if system_settings.enable_product_grade: - calculate_grade(finding.test.engagement.product.id) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(calculate_grade, finding.test.engagement.product.id) else: deduplicationLogger.debug("skipping product grading because it's disabled in system settings") @@ -453,7 +453,6 @@ def post_process_finding_save_internal(finding, dedupe_option=True, rules_option jira_helper.push_to_jira(finding.finding_group) -@dojo_async_task @app.task def post_process_findings_batch(finding_ids, *args, dedupe_option=True, rules_option=True, product_grading_option=True, issue_updater_option=True, push_to_jira=False, user=None, **kwargs): @@ -496,7 +495,9 @@ def post_process_findings_batch(finding_ids, *args, dedupe_option=True, rules_op tool_issue_updater.async_tool_issue_update(finding) if product_grading_option and system_settings.enable_product_grade: - calculate_grade(findings[0].test.engagement.product.id) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(calculate_grade, findings[0].test.engagement.product.id) if push_to_jira: for finding in findings: diff --git a/dojo/finding/views.py b/dojo/finding/views.py index ae224afde53..3f9ea214e6f 100644 --- a/dojo/finding/views.py +++ b/dojo/finding/views.py @@ -38,6 +38,7 @@ user_is_authorized, ) from dojo.authorization.roles_permissions import Permissions +from dojo.celery_dispatch import dojo_dispatch_task from dojo.filters import ( AcceptedFindingFilter, AcceptedFindingFilterWithoutObjectLookups, @@ -1082,7 +1083,7 @@ def process_form(self, request: HttpRequest, finding: Finding, context: dict): product = finding.test.engagement.product finding.delete() # Update the grade of the product async - calculate_grade(product.id) + dojo_dispatch_task(calculate_grade, product.id) # Add a message to the request that the finding was successfully deleted messages.add_message( request, @@ -1353,7 +1354,7 @@ def copy_finding(request, fid): test = form.cleaned_data.get("test") product = finding.test.engagement.product finding_copy = finding.copy(test=test) - calculate_grade(product.id) + dojo_dispatch_task(calculate_grade, product.id) messages.add_message( request, 
messages.SUCCESS, diff --git a/dojo/finding_group/views.py b/dojo/finding_group/views.py index 451d4dcd720..e29c401b80d 100644 --- a/dojo/finding_group/views.py +++ b/dojo/finding_group/views.py @@ -16,6 +16,7 @@ from dojo.authorization.authorization import user_has_permission_or_403 from dojo.authorization.authorization_decorators import user_is_authorized from dojo.authorization.roles_permissions import Permissions +from dojo.celery_dispatch import dojo_dispatch_task from dojo.filters import ( FindingFilter, FindingFilterWithoutObjectLookups, @@ -100,7 +101,7 @@ def view_finding_group(request, fgid): elif not finding_group.has_jira_issue: jira_helper.finding_group_link_jira(request, finding_group, jira_issue) elif push_to_jira: - jira_helper.push_to_jira(finding_group, sync=True) + dojo_dispatch_task(jira_helper.push_to_jira, finding_group, sync=True) finding_group.save() return HttpResponseRedirect(reverse("view_test", args=(finding_group.test.id,))) @@ -200,7 +201,7 @@ def push_to_jira(request, fgid): # it may look like success here, but the push_to_jira are swallowing exceptions # but cant't change too much now without having a test suite, so leave as is for now with the addition warning message to check alerts for background errors. - if jira_helper.push_to_jira(group, sync=True): + if dojo_dispatch_task(jira_helper.push_to_jira, group, sync=True): messages.add_message( request, messages.SUCCESS, diff --git a/dojo/importers/base_importer.py b/dojo/importers/base_importer.py index 9da381be678..8b82a6bfd45 100644 --- a/dojo/importers/base_importer.py +++ b/dojo/importers/base_importer.py @@ -3,7 +3,6 @@ import time from collections.abc import Iterable -from celery import chord, group from django.conf import settings from django.core.exceptions import ValidationError from django.core.files.base import ContentFile @@ -14,7 +13,7 @@ import dojo.finding.helper as finding_helper import dojo.risk_acceptance.helper as ra_helper -from dojo import utils +from dojo.celery_dispatch import dojo_dispatch_task from dojo.importers.endpoint_manager import EndpointManager from dojo.importers.options import ImporterOptions from dojo.models import ( @@ -29,7 +28,6 @@ Endpoint, FileUpload, Finding, - System_Settings, Test, Test_Import, Test_Import_Finding_Action, @@ -643,47 +641,6 @@ def update_test_type_from_internal_test(self, internal_test: ParserTest) -> None self.test.test_type.dynamic_tool = dynamic_tool self.test.test_type.save() - def maybe_launch_post_processing_chord( - self, - post_processing_task_signatures, - current_batch_number: int, - max_batch_size: int, - * - is_final_batch: bool, - ) -> tuple[list, int, bool]: - """ - Helper to optionally launch a chord of post-processing tasks with a calculate-grade callback - when async is desired. Uses exponential batch sizing up to the configured max batch size. - - Returns a tuple of (post_processing_task_signatures, current_batch_number, launched) - where launched indicates whether a chord/group was dispatched and signatures were reset. 
- """ - launched = False - if not post_processing_task_signatures: - return post_processing_task_signatures, current_batch_number, launched - - current_batch_size = min(2 ** current_batch_number, max_batch_size) - batch_full = len(post_processing_task_signatures) >= current_batch_size - - if batch_full or is_final_batch: - product = self.test.engagement.product - system_settings = System_Settings.objects.get() - if system_settings.enable_product_grade: - calculate_grade_signature = utils.calculate_grade.si(product.id) - chord(post_processing_task_signatures)(calculate_grade_signature) - else: - group(post_processing_task_signatures).apply_async() - - logger.debug( - f"Launched chord with {len(post_processing_task_signatures)} tasks (batch #{current_batch_number}, size: {len(post_processing_task_signatures)})", - ) - post_processing_task_signatures = [] - if not is_final_batch: - current_batch_number += 1 - launched = True - - return post_processing_task_signatures, current_batch_number, launched - def verify_tool_configuration_from_test(self): """ Verify that the Tool_Configuration supplied along with the @@ -930,7 +887,13 @@ def mitigate_finding( # risk_unaccept will check if finding.risk_accepted is True before proceeding ra_helper.risk_unaccept(self.user, finding, perform_save=False, post_comments=False) # Mitigate the endpoint statuses - self.endpoint_manager.mitigate_endpoint_status(finding.status_finding.all(), self.user, kwuser=self.user, sync=True) + dojo_dispatch_task( + self.endpoint_manager.mitigate_endpoint_status, + finding.status_finding.all(), + self.user, + kwuser=self.user, + sync=True, + ) # to avoid pushing a finding group multiple times, we push those outside of the loop if finding_groups_enabled and finding.finding_group: # don't try to dedupe findings that we are closing diff --git a/dojo/importers/default_importer.py b/dojo/importers/default_importer.py index 9dfa577099d..97cdc4201bf 100644 --- a/dojo/importers/default_importer.py +++ b/dojo/importers/default_importer.py @@ -7,6 +7,7 @@ from django.urls import reverse import dojo.jira_link.helper as jira_helper +from dojo.celery_dispatch import dojo_dispatch_task from dojo.finding import helper as finding_helper from dojo.importers.base_importer import BaseImporter, Parser from dojo.importers.options import ImporterOptions @@ -260,7 +261,8 @@ def process_findings( batch_finding_ids.clear() logger.debug("process_findings: dispatching batch with push_to_jira=%s (batch_size=%d, is_final=%s)", push_to_jira, len(finding_ids_batch), is_final_finding) - finding_helper.post_process_findings_batch( + dojo_dispatch_task( + finding_helper.post_process_findings_batch, finding_ids_batch, dedupe_option=True, rules_option=True, diff --git a/dojo/importers/default_reimporter.py b/dojo/importers/default_reimporter.py index d80b0de8b55..f23e4418336 100644 --- a/dojo/importers/default_reimporter.py +++ b/dojo/importers/default_reimporter.py @@ -7,6 +7,7 @@ import dojo.finding.helper as finding_helper import dojo.jira_link.helper as jira_helper +from dojo.celery_dispatch import dojo_dispatch_task from dojo.finding.deduplication import ( find_candidates_for_deduplication_hash, find_candidates_for_deduplication_uid_or_hash, @@ -414,7 +415,8 @@ def process_findings( if len(batch_finding_ids) >= dedupe_batch_max_size or is_final: finding_ids_batch = list(batch_finding_ids) batch_finding_ids.clear() - finding_helper.post_process_findings_batch( + dojo_dispatch_task( + finding_helper.post_process_findings_batch, finding_ids_batch, 
dedupe_option=True, rules_option=True, diff --git a/dojo/importers/endpoint_manager.py b/dojo/importers/endpoint_manager.py index f4b277d49fa..5fdd4603aad 100644 --- a/dojo/importers/endpoint_manager.py +++ b/dojo/importers/endpoint_manager.py @@ -5,7 +5,7 @@ from django.utils import timezone from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery_dispatch import dojo_dispatch_task from dojo.endpoint.utils import endpoint_get_or_create from dojo.models import ( Dojo_User, @@ -18,17 +18,15 @@ class EndpointManager: - @dojo_async_task - @app.task() + @app.task def add_endpoints_to_unsaved_finding( - self, - finding: Finding, + finding: Finding, # noqa: N805 endpoints: list[Endpoint], **kwargs: dict, ) -> None: """Creates Endpoint objects for a single finding and creates the link via the endpoint status""" logger.debug(f"IMPORT_SCAN: Adding {len(endpoints)} endpoints to finding: {finding}") - self.clean_unsaved_endpoints(endpoints) + EndpointManager.clean_unsaved_endpoints(endpoints) for endpoint in endpoints: ep = None eps = [] @@ -41,7 +39,8 @@ def add_endpoints_to_unsaved_finding( path=endpoint.path, query=endpoint.query, fragment=endpoint.fragment, - product=finding.test.engagement.product) + product=finding.test.engagement.product, + ) eps.append(ep) except (MultipleObjectsReturned): msg = ( @@ -58,11 +57,9 @@ def add_endpoints_to_unsaved_finding( logger.debug(f"IMPORT_SCAN: {len(endpoints)} endpoints imported") - @dojo_async_task - @app.task() + @app.task def mitigate_endpoint_status( - self, - endpoint_status_list: list[Endpoint_Status], + endpoint_status_list: list[Endpoint_Status], # noqa: N805 user: Dojo_User, **kwargs: dict, ) -> None: @@ -85,11 +82,9 @@ def mitigate_endpoint_status( batch_size=1000, ) - @dojo_async_task - @app.task() + @app.task def reactivate_endpoint_status( - self, - endpoint_status_list: list[Endpoint_Status], + endpoint_status_list: list[Endpoint_Status], # noqa: N805 **kwargs: dict, ) -> None: """Reactivate all endpoint status objects that are supplied""" @@ -118,10 +113,10 @@ def chunk_endpoints_and_disperse( endpoints: list[Endpoint], **kwargs: dict, ) -> None: - self.add_endpoints_to_unsaved_finding(finding, endpoints, sync=True) + dojo_dispatch_task(self.add_endpoints_to_unsaved_finding, finding, endpoints, sync=True) + @staticmethod def clean_unsaved_endpoints( - self, endpoints: list[Endpoint], ) -> None: """ @@ -139,7 +134,7 @@ def chunk_endpoints_and_reactivate( endpoint_status_list: list[Endpoint_Status], **kwargs: dict, ) -> None: - self.reactivate_endpoint_status(endpoint_status_list, sync=True) + dojo_dispatch_task(self.reactivate_endpoint_status, endpoint_status_list, sync=True) def chunk_endpoints_and_mitigate( self, @@ -147,7 +142,7 @@ def chunk_endpoints_and_mitigate( user: Dojo_User, **kwargs: dict, ) -> None: - self.mitigate_endpoint_status(endpoint_status_list, user, sync=True) + dojo_dispatch_task(self.mitigate_endpoint_status, endpoint_status_list, user, sync=True) def update_endpoint_status( self, diff --git a/dojo/jira_link/helper.py b/dojo/jira_link/helper.py index 34c530975dc..ef9b883b5b9 100644 --- a/dojo/jira_link/helper.py +++ b/dojo/jira_link/helper.py @@ -18,7 +18,7 @@ from requests.auth import HTTPBasicAuth from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery_dispatch import dojo_dispatch_task from dojo.forms import JIRAEngagementForm, JIRAProjectForm from dojo.models import ( Engagement, @@ -760,20 +760,19 @@ def push_to_jira(obj, *args, **kwargs): if 
isinstance(obj, Finding): if obj.has_finding_group: logger.debug("pushing finding group for %s to JIRA", obj) - return push_finding_group_to_jira(obj.finding_group.id, *args, **kwargs) - return push_finding_to_jira(obj.id, *args, **kwargs) + return dojo_dispatch_task(push_finding_group_to_jira, obj.finding_group.id, *args, **kwargs) + return dojo_dispatch_task(push_finding_to_jira, obj.id, *args, **kwargs) if isinstance(obj, Finding_Group): - return push_finding_group_to_jira(obj.id, *args, **kwargs) + return dojo_dispatch_task(push_finding_group_to_jira, obj.id, *args, **kwargs) if isinstance(obj, Engagement): - return push_engagement_to_jira(obj.id, *args, **kwargs) + return dojo_dispatch_task(push_engagement_to_jira, obj.id, *args, **kwargs) logger.error("unsupported object passed to push_to_jira: %s %i %s", obj.__name__, obj.id, obj) return None # we need thre separate celery tasks due to the decorators we're using to map to/from ids -@dojo_async_task @app.task def push_finding_to_jira(finding_id, *args, **kwargs): finding = get_object_or_none(Finding, id=finding_id) @@ -786,7 +785,6 @@ def push_finding_to_jira(finding_id, *args, **kwargs): return add_jira_issue(finding, *args, **kwargs) -@dojo_async_task @app.task def push_finding_group_to_jira(finding_group_id, *args, **kwargs): finding_group = get_object_or_none(Finding_Group, id=finding_group_id) @@ -803,7 +801,6 @@ def push_finding_group_to_jira(finding_group_id, *args, **kwargs): return add_jira_issue(finding_group, *args, **kwargs) -@dojo_async_task @app.task def push_engagement_to_jira(engagement_id, *args, **kwargs): engagement = get_object_or_none(Engagement, id=engagement_id) @@ -812,8 +809,8 @@ def push_engagement_to_jira(engagement_id, *args, **kwargs): return None if engagement.has_jira_issue: - return update_epic(engagement.id, *args, **kwargs) - return add_epic(engagement.id, *args, **kwargs) + return dojo_dispatch_task(update_epic, engagement.id, *args, **kwargs) + return dojo_dispatch_task(add_epic, engagement.id, *args, **kwargs) def add_issues_to_epic(jira, obj, epic_id, issue_keys, *, ignore_epics=True): @@ -1376,7 +1373,6 @@ def jira_check_attachment(issue, source_file_name): return file_exists -@dojo_async_task @app.task def close_epic(engagement_id, push_to_jira, **kwargs): engagement = get_object_or_none(Engagement, id=engagement_id) @@ -1425,7 +1421,6 @@ def close_epic(engagement_id, push_to_jira, **kwargs): return False -@dojo_async_task @app.task def update_epic(engagement_id, **kwargs): engagement = get_object_or_none(Engagement, id=engagement_id) @@ -1472,7 +1467,6 @@ def update_epic(engagement_id, **kwargs): return False -@dojo_async_task @app.task def add_epic(engagement_id, **kwargs): engagement = get_object_or_none(Engagement, id=engagement_id) @@ -1581,10 +1575,9 @@ def add_comment(obj, note, *, force_push=False, **kwargs): return False # Call the internal task with IDs (runs synchronously within this task) - return add_comment_internal(jira_issue.id, note.id, force_push=force_push, **kwargs) + return dojo_dispatch_task(add_comment_internal, jira_issue.id, note.id, force_push=force_push, **kwargs) -@dojo_async_task @app.task def add_comment_internal(jira_issue_id, note_id, *, force_push=False, **kwargs): """Internal Celery task that adds a comment to a JIRA issue.""" diff --git a/dojo/management/commands/dedupe.py b/dojo/management/commands/dedupe.py index 913c528f299..b7e0e669157 100644 --- a/dojo/management/commands/dedupe.py +++ b/dojo/management/commands/dedupe.py @@ -118,14 +118,25 @@ def 
_run_dedupe(self, *, restrict_to_parsers, hash_code_only, dedupe_only, dedup mass_model_updater(Finding, findings, do_dedupe_finding_task_internal, fields=None, order="desc", page_size=100, log_prefix="deduplicating ") else: # async tasks only need the id - mass_model_updater(Finding, findings.only("id"), lambda f: do_dedupe_finding_task(f.id), fields=None, order="desc", log_prefix="deduplicating ") + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + mass_model_updater( + Finding, + findings.only("id"), + lambda f: dojo_dispatch_task(do_dedupe_finding_task, f.id), + fields=None, + order="desc", + log_prefix="deduplicating ", + ) if dedupe_sync: # update the grading (if enabled) and only useful in sync mode # in async mode the background task that grades products every hour will pick it up logger.debug("Updating grades for products...") for product in Product.objects.all(): - calculate_grade(product.id) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(calculate_grade, product.id) logger.info("######## Done deduplicating (%s) ########", ("foreground" if dedupe_sync else "tasks submitted to celery")) else: @@ -172,7 +183,9 @@ def _dedupe_batch_mode(self, findings_queryset, *, dedupe_sync: bool = True): else: # Asynchronous: submit task with finding IDs logger.debug(f"Submitting async batch task for {len(batch_finding_ids)} findings for test {test_id}") - do_dedupe_batch_task(batch_finding_ids) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(do_dedupe_batch_task, batch_finding_ids) total_processed += len(batch_finding_ids) batch_finding_ids = [] diff --git a/dojo/models.py b/dojo/models.py index f610d47bd64..eb577ce2dfe 100644 --- a/dojo/models.py +++ b/dojo/models.py @@ -1081,7 +1081,9 @@ def save(self, *args, **kwargs): super(Product, product).save() # launch the async task to update all finding sla expiration dates from dojo.sla_config.helpers import async_update_sla_expiration_dates_sla_config_sync # noqa: I001, PLC0415 circular import - async_update_sla_expiration_dates_sla_config_sync(self, products, severities=severities) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(async_update_sla_expiration_dates_sla_config_sync, self, products, severities=severities) def clean(self): sla_days = [self.critical, self.high, self.medium, self.low] @@ -1243,7 +1245,9 @@ def save(self, *args, **kwargs): super(SLA_Configuration, sla_config).save() # launch the async task to update all finding sla expiration dates from dojo.sla_config.helpers import async_update_sla_expiration_dates_sla_config_sync # noqa: I001, PLC0415 circular import - async_update_sla_expiration_dates_sla_config_sync(sla_config, Product.objects.filter(id=self.id)) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(async_update_sla_expiration_dates_sla_config_sync, sla_config, Product.objects.filter(id=self.id)) def get_absolute_url(self): return reverse("view_product", args=[str(self.id)]) diff --git a/dojo/notifications/helper.py b/dojo/notifications/helper.py index c4458daec01..ae62fc8f4d7 100644 --- a/dojo/notifications/helper.py +++ b/dojo/notifications/helper.py @@ -18,7 +18,8 @@ from dojo import __version__ as dd_version from dojo.authorization.roles_permissions import Permissions from dojo.celery import app -from dojo.decorators import dojo_async_task, 
we_want_async +from dojo.celery_dispatch import dojo_dispatch_task +from dojo.decorators import we_want_async from dojo.labels import get_labels from dojo.models import ( Alerts, @@ -199,8 +200,6 @@ class SlackNotificationManger(NotificationManagerHelpers): """Manger for slack notifications and their helpers.""" - @dojo_async_task - @app.task def send_slack_notification( self, event: str, @@ -317,8 +316,6 @@ class MSTeamsNotificationManger(NotificationManagerHelpers): """Manger for Microsoft Teams notifications and their helpers.""" - @dojo_async_task - @app.task def send_msteams_notification( self, event: str, @@ -368,8 +365,6 @@ class EmailNotificationManger(NotificationManagerHelpers): """Manger for email notifications and their helpers.""" - @dojo_async_task - @app.task def send_mail_notification( self, event: str, @@ -420,8 +415,6 @@ class WebhookNotificationManger(NotificationManagerHelpers): ERROR_PERMANENT = "permanent" ERROR_TEMPORARY = "temporary" - @dojo_async_task - @app.task def send_webhooks_notification( self, event: str, @@ -480,11 +473,7 @@ def send_webhooks_notification( endpoint.first_error = now endpoint.status = Notification_Webhooks.Status.STATUS_INACTIVE_TMP # In case of failure within one day, endpoint can be deactivated temporally only for one minute - self._webhook_reactivation.apply_async( - args=[self], - kwargs={"endpoint_id": endpoint.pk}, - countdown=60, - ) + webhook_reactivation.apply_async(kwargs={"endpoint_id": endpoint.pk}, countdown=60) # There is no reason to keep endpoint active if it is returning 4xx errors else: endpoint.status = Notification_Webhooks.Status.STATUS_INACTIVE_PERMANENT @@ -559,7 +548,6 @@ def _test_webhooks_notification(self, endpoint: Notification_Webhooks) -> None: # in "send_webhooks_notification", we are doing deeper analysis, why it failed # for now, "raise_for_status" should be enough - @app.task(ignore_result=True) def _webhook_reactivation(self, endpoint_id: int, **_kwargs: dict): endpoint = Notification_Webhooks.objects.get(pk=endpoint_id) # User already changed status of endpoint @@ -832,9 +820,10 @@ def _process_notifications( notifications.other, ): logger.debug("Sending Slack Notification") - self._get_manager_instance("slack").send_slack_notification( + dojo_dispatch_task( + send_slack_notification, event, - user=notifications.user, + user_id=getattr(notifications.user, "id", None), **kwargs, ) @@ -844,9 +833,10 @@ def _process_notifications( notifications.other, ): logger.debug("Sending MSTeams Notification") - self._get_manager_instance("msteams").send_msteams_notification( + dojo_dispatch_task( + send_msteams_notification, event, - user=notifications.user, + user_id=getattr(notifications.user, "id", None), **kwargs, ) @@ -856,9 +846,10 @@ def _process_notifications( notifications.other, ): logger.debug("Sending Mail Notification") - self._get_manager_instance("mail").send_mail_notification( + dojo_dispatch_task( + send_mail_notification, event, - user=notifications.user, + user_id=getattr(notifications.user, "id", None), **kwargs, ) @@ -868,13 +859,43 @@ def _process_notifications( notifications.other, ): logger.debug("Sending Webhooks Notification") - self._get_manager_instance("webhooks").send_webhooks_notification( + dojo_dispatch_task( + send_webhooks_notification, event, - user=notifications.user, + user_id=getattr(notifications.user, "id", None), **kwargs, ) +@app.task +def send_slack_notification(event: str, user_id: int | None = None, **kwargs: dict) -> None: + user = Dojo_User.objects.get(pk=user_id) if 
user_id else None + SlackNotificationManger().send_slack_notification(event, user=user, **kwargs) + + +@app.task +def send_msteams_notification(event: str, user_id: int | None = None, **kwargs: dict) -> None: + user = Dojo_User.objects.get(pk=user_id) if user_id else None + MSTeamsNotificationManger().send_msteams_notification(event, user=user, **kwargs) + + +@app.task +def send_mail_notification(event: str, user_id: int | None = None, **kwargs: dict) -> None: + user = Dojo_User.objects.get(pk=user_id) if user_id else None + EmailNotificationManger().send_mail_notification(event, user=user, **kwargs) + + +@app.task +def send_webhooks_notification(event: str, user_id: int | None = None, **kwargs: dict) -> None: + user = Dojo_User.objects.get(pk=user_id) if user_id else None + WebhookNotificationManger().send_webhooks_notification(event, user=user, **kwargs) + + +@app.task(ignore_result=True) +def webhook_reactivation(endpoint_id: int, **_kwargs: dict) -> None: + WebhookNotificationManger()._webhook_reactivation(endpoint_id=endpoint_id) + + @app.task(ignore_result=True) def webhook_status_cleanup(*_args: list, **_kwargs: dict): # If some endpoint was affected by some outage (5xx, 429, Timeout) but it was clean during last 24 hours, @@ -902,4 +923,4 @@ def webhook_status_cleanup(*_args: list, **_kwargs: dict): ) for endpoint in broken_endpoints: manager = WebhookNotificationManger() - manager._webhook_reactivation(manager, endpoint_id=endpoint.pk) + manager._webhook_reactivation(endpoint_id=endpoint.pk) diff --git a/dojo/product/helpers.py b/dojo/product/helpers.py index aeadec0246d..8a308d9b62c 100644 --- a/dojo/product/helpers.py +++ b/dojo/product/helpers.py @@ -2,13 +2,11 @@ import logging from dojo.celery import app -from dojo.decorators import dojo_async_task from dojo.models import Endpoint, Engagement, Finding, Product, Test logger = logging.getLogger(__name__) -@dojo_async_task @app.task def propagate_tags_on_product(product_id, *args, **kwargs): with contextlib.suppress(Product.DoesNotExist): diff --git a/dojo/sla_config/helpers.py b/dojo/sla_config/helpers.py index da5899a85b0..045456f38d7 100644 --- a/dojo/sla_config/helpers.py +++ b/dojo/sla_config/helpers.py @@ -1,14 +1,12 @@ import logging from dojo.celery import app -from dojo.decorators import dojo_async_task from dojo.models import Finding, Product, SLA_Configuration, System_Settings from dojo.utils import get_custom_method, mass_model_updater logger = logging.getLogger(__name__) -@dojo_async_task @app.task def async_update_sla_expiration_dates_sla_config_sync(sla_config: SLA_Configuration, products: list[Product], *args, severities: list[str] | None = None, **kwargs): if method := get_custom_method("FINDING_SLA_EXPIRATION_CALCULATION_METHOD"): diff --git a/dojo/tags_signals.py b/dojo/tags_signals.py index 0cade958265..6b11b00644c 100644 --- a/dojo/tags_signals.py +++ b/dojo/tags_signals.py @@ -4,6 +4,7 @@ from django.db.models import signals from django.dispatch import receiver +from dojo.celery_dispatch import dojo_dispatch_task from dojo.models import Endpoint, Engagement, Finding, Product, Test from dojo.product import helpers as async_product_funcs from dojo.utils import get_system_setting @@ -19,7 +20,7 @@ def product_tags_post_add_remove(sender, instance, action, **kwargs): running_async_process = instance.running_async_process # Check if the async process is already running to avoid calling it a second time if not running_async_process and inherit_product_tags(instance): - 
async_product_funcs.propagate_tags_on_product(instance.id, countdown=5) + dojo_dispatch_task(async_product_funcs.propagate_tags_on_product, instance.id, countdown=5) instance.running_async_process = True diff --git a/dojo/tasks.py b/dojo/tasks.py index 29dfe11257c..8a566c9acfc 100644 --- a/dojo/tasks.py +++ b/dojo/tasks.py @@ -2,6 +2,7 @@ from datetime import timedelta import pghistory +from celery import Task from celery.utils.log import get_task_logger from django.apps import apps from django.conf import settings @@ -12,7 +13,7 @@ from dojo.auditlog import run_flush_auditlog from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery_dispatch import dojo_dispatch_task from dojo.finding.helper import fix_loop_duplicates from dojo.management.commands.jira_status_reconciliation import jira_status_reconciliation from dojo.models import Alerts, Announcement, Endpoint, Engagement, Finding, Product, System_Settings, User @@ -72,7 +73,7 @@ def add_alerts(self, runinterval): if system_settings.enable_product_grade: products = Product.objects.all() for product in products: - calculate_grade(product.id) + dojo_dispatch_task(calculate_grade, product.id) @app.task(bind=True) @@ -169,11 +170,18 @@ def _async_dupe_delete_impl(): if system_settings.enable_product_grade: logger.info("performing batch product grading for %s products", len(affected_products)) for product in affected_products: - calculate_grade(product.id) + dojo_dispatch_task(calculate_grade, product.id) -@app.task(ignore_result=False) +@app.task(ignore_result=False, base=Task) def celery_status(): + """ + Simple health check task to verify Celery is running. + + Uses base Task class (not PgHistoryTask) since it doesn't need: + - User context tracking + - Pghistory context (no database modifications) + """ return True @@ -237,7 +245,6 @@ def clear_sessions(*args, **kwargs): call_command("clearsessions") -@dojo_async_task @app.task def update_watson_search_index_for_model(model_name, pk_list, *args, **kwargs): """ diff --git a/dojo/templatetags/display_tags.py b/dojo/templatetags/display_tags.py index 5a36e901011..00a348daf43 100644 --- a/dojo/templatetags/display_tags.py +++ b/dojo/templatetags/display_tags.py @@ -363,7 +363,9 @@ def product_grade(product): if system_settings.enable_product_grade and product: prod_numeric_grade = product.prod_numeric_grade if not prod_numeric_grade or prod_numeric_grade is None: - calculate_grade(product.id) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(calculate_grade, product.id) if prod_numeric_grade: if prod_numeric_grade >= system_settings.product_grade_a: grade = "A" diff --git a/dojo/test/views.py b/dojo/test/views.py index d37825822c6..9d0e593152e 100644 --- a/dojo/test/views.py +++ b/dojo/test/views.py @@ -26,6 +26,7 @@ from dojo.authorization.authorization import user_has_permission_or_403 from dojo.authorization.authorization_decorators import user_is_authorized from dojo.authorization.roles_permissions import Permissions +from dojo.celery_dispatch import dojo_dispatch_task from dojo.engagement.queries import get_authorized_engagements from dojo.filters import FindingFilter, FindingFilterWithoutObjectLookups, TemplateFindingFilter, TestImportFilter from dojo.finding.queries import prefetch_for_findings @@ -343,7 +344,7 @@ def copy_test(request, tid): engagement = form.cleaned_data.get("engagement") product = test.engagement.product test_copy = test.copy(engagement=engagement) - 
calculate_grade(product.id) + dojo_dispatch_task(calculate_grade, product.id) messages.add_message( request, messages.SUCCESS, diff --git a/dojo/tools/tool_issue_updater.py b/dojo/tools/tool_issue_updater.py index 854fb989113..8211e166eed 100644 --- a/dojo/tools/tool_issue_updater.py +++ b/dojo/tools/tool_issue_updater.py @@ -3,7 +3,7 @@ import pghistory from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery_dispatch import dojo_dispatch_task from dojo.models import Finding from dojo.tools.api_sonarqube.parser import SCAN_SONARQUBE_API from dojo.tools.api_sonarqube.updater import SonarQubeApiUpdater @@ -15,7 +15,7 @@ def async_tool_issue_update(finding, *args, **kwargs): if is_tool_issue_updater_needed(finding): - tool_issue_updater(finding.id) + dojo_dispatch_task(tool_issue_updater, finding.id) def is_tool_issue_updater_needed(finding, *args, **kwargs): @@ -23,7 +23,6 @@ def is_tool_issue_updater_needed(finding, *args, **kwargs): return test_type.name == SCAN_SONARQUBE_API -@dojo_async_task @app.task def tool_issue_updater(finding_id, *args, **kwargs): finding = get_object_or_none(Finding, id=finding_id) @@ -37,7 +36,6 @@ def tool_issue_updater(finding_id, *args, **kwargs): SonarQubeApiUpdater().update_sonarqube_finding(finding) -@dojo_async_task @app.task def update_findings_from_source_issues(**kwargs): # Wrap with pghistory context for audit trail diff --git a/dojo/utils.py b/dojo/utils.py index 1463ae44970..8f5ba8fe04a 100644 --- a/dojo/utils.py +++ b/dojo/utils.py @@ -45,7 +45,6 @@ from dojo.authorization.roles_permissions import Permissions from dojo.celery import app -from dojo.decorators import dojo_async_task from dojo.finding.queries import get_authorized_findings from dojo.github import ( add_external_issue_github, @@ -1053,7 +1052,6 @@ def handle_uploaded_selenium(f, cred): cred.save() -@dojo_async_task @app.task def add_external_issue(finding_id, external_issue_provider, **kwargs): finding = get_object_or_none(Finding, id=finding_id) @@ -1069,7 +1067,6 @@ def add_external_issue(finding_id, external_issue_provider, **kwargs): add_external_issue_github(finding, prod, eng) -@dojo_async_task @app.task def update_external_issue(finding_id, old_status, external_issue_provider, **kwargs): finding = get_object_or_none(Finding, id=finding_id) @@ -1084,7 +1081,6 @@ def update_external_issue(finding_id, old_status, external_issue_provider, **kwa update_external_issue_github(finding, prod, eng) -@dojo_async_task @app.task def close_external_issue(finding_id, note, external_issue_provider, **kwargs): finding = get_object_or_none(Finding, id=finding_id) @@ -1099,7 +1095,6 @@ def close_external_issue(finding_id, note, external_issue_provider, **kwargs): close_external_issue_github(finding, note, prod, eng) -@dojo_async_task @app.task def reopen_external_issue(finding_id, note, external_issue_provider, **kwargs): finding = get_object_or_none(Finding, id=finding_id) @@ -1255,7 +1250,6 @@ def grade_product(crit, high, med, low): return max(health, 5) -@dojo_async_task @app.task def calculate_grade(product_id, *args, **kwargs): product = get_object_or_none(Product, id=product_id) @@ -1313,7 +1307,9 @@ def calculate_grade_internal(product, *args, **kwargs): def perform_product_grading(product): system_settings = System_Settings.objects.get() if system_settings.enable_product_grade: - calculate_grade(product.id) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(calculate_grade, product.id) def 
get_celery_worker_status(): @@ -2012,129 +2008,187 @@ def is_finding_groups_enabled(): return get_system_setting("enable_finding_groups") -class async_delete: - def __init__(self, *args, **kwargs): - self.mapping = { - "Product_Type": [ - (Endpoint, "product__prod_type__id"), - (Finding, "test__engagement__product__prod_type__id"), - (Test, "engagement__product__prod_type__id"), - (Engagement, "product__prod_type__id"), - (Product, "prod_type__id")], - "Product": [ - (Endpoint, "product__id"), - (Finding, "test__engagement__product__id"), - (Test, "engagement__product__id"), - (Engagement, "product__id")], - "Engagement": [ - (Finding, "test__engagement__id"), - (Test, "engagement__id")], - "Test": [(Finding, "test__id")], - } +# Mapping of object types to their related models for cascading deletes +ASYNC_DELETE_MAPPING = { + "Product_Type": [ + (Endpoint, "product__prod_type__id"), + (Finding, "test__engagement__product__prod_type__id"), + (Test, "engagement__product__prod_type__id"), + (Engagement, "product__prod_type__id"), + (Product, "prod_type__id")], + "Product": [ + (Endpoint, "product__id"), + (Finding, "test__engagement__product__id"), + (Test, "engagement__product__id"), + (Engagement, "product__id")], + "Engagement": [ + (Finding, "test__engagement__id"), + (Test, "engagement__id")], + "Test": [(Finding, "test__id")], +} + + +def _get_object_name(obj): + """Get the class name of an object or model class.""" + if obj.__class__.__name__ == "ModelBase": + return obj.__name__ + return obj.__class__.__name__ - @dojo_async_task - @app.task - def delete_chunk(self, objects, **kwargs): - # Now delete all objects with retry for deadlocks - max_retries = 3 - for obj in objects: - retry_count = 0 - while retry_count < max_retries: - try: - obj.delete() - break # Success, exit retry loop - except OperationalError as e: - error_msg = str(e) - if "deadlock detected" in error_msg.lower(): - retry_count += 1 - if retry_count < max_retries: - # Exponential backoff with jitter - wait_time = (2 ** retry_count) + random.uniform(0, 1) # noqa: S311 - logger.warning( - f"ASYNC_DELETE: Deadlock detected deleting {self.get_object_name(obj)} {obj.pk}, " - f"retrying ({retry_count}/{max_retries}) after {wait_time:.2f}s", - ) - time.sleep(wait_time) - # Refresh object from DB before retry - obj.refresh_from_db() - else: - logger.error( - f"ASYNC_DELETE: Deadlock persisted after {max_retries} retries for {self.get_object_name(obj)} {obj.pk}: {e}", - ) - raise + +@app.task +def async_delete_chunk_task(objects, **kwargs): + """ + Module-level Celery task to delete a chunk of objects. + + Accepts **kwargs for async_user and _pgh_context injected by dojo_dispatch_task. + Uses PgHistoryTask base class (default) to preserve pghistory context for audit trail. 
+ """ + max_retries = 3 + for obj in objects: + retry_count = 0 + while retry_count < max_retries: + try: + obj.delete() + break # Success, exit retry loop + except OperationalError as e: + error_msg = str(e) + if "deadlock detected" in error_msg.lower(): + retry_count += 1 + if retry_count < max_retries: + # Exponential backoff with jitter + wait_time = (2 ** retry_count) + random.uniform(0, 1) # noqa: S311 + logger.warning( + f"ASYNC_DELETE: Deadlock detected deleting {_get_object_name(obj)} {obj.pk}, " + f"retrying ({retry_count}/{max_retries}) after {wait_time:.2f}s", + ) + time.sleep(wait_time) + # Refresh object from DB before retry + obj.refresh_from_db() else: - # Not a deadlock, re-raise + logger.error( + f"ASYNC_DELETE: Deadlock persisted after {max_retries} retries for {_get_object_name(obj)} {obj.pk}: {e}", + ) raise - except AssertionError: - logger.debug("ASYNC_DELETE: object has already been deleted elsewhere. Skipping") - # The id must be None - # The object has already been deleted elsewhere - break - except LogEntry.MultipleObjectsReturned: - # Delete the log entrys first, then delete - LogEntry.objects.filter( - content_type=ContentType.objects.get_for_model(obj.__class__), - object_pk=str(obj.pk), - action=LogEntry.Action.DELETE, - ).delete() - # Now delete the object again (no retry needed for this case) - obj.delete() - break - - @dojo_async_task - @app.task + else: + # Not a deadlock, re-raise + raise + except AssertionError: + logger.debug("ASYNC_DELETE: object has already been deleted elsewhere. Skipping") + # The id must be None + # The object has already been deleted elsewhere + break + except LogEntry.MultipleObjectsReturned: + # Delete the log entrys first, then delete + LogEntry.objects.filter( + content_type=ContentType.objects.get_for_model(obj.__class__), + object_pk=str(obj.pk), + action=LogEntry.Action.DELETE, + ).delete() + # Now delete the object again (no retry needed for this case) + obj.delete() + break + + +@app.task +def async_delete_crawl_task(obj, model_list, **kwargs): + """ + Module-level Celery task to crawl and delete related objects. + + Accepts **kwargs for async_user and _pgh_context injected by dojo_dispatch_task. + Uses PgHistoryTask base class (default) to preserve pghistory context for audit trail. 
+ """ + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + logger.debug("ASYNC_DELETE: Crawling " + _get_object_name(obj) + ": " + str(obj)) + for model_info in model_list: + task_results = [] + model = model_info[0] + model_query = model_info[1] + filter_dict = {model_query: obj.id} + # Only fetch the IDs since we will make a list of IDs in the following function call + objects_to_delete = model.objects.only("id").filter(**filter_dict).distinct().order_by("id") + logger.debug("ASYNC_DELETE: Deleting " + str(len(objects_to_delete)) + " " + _get_object_name(model) + "s in chunks") + chunk_size = get_setting("ASYNC_OBEJECT_DELETE_CHUNK_SIZE") + chunks = [objects_to_delete[i:i + chunk_size] for i in range(0, len(objects_to_delete), chunk_size)] + logger.debug("ASYNC_DELETE: Split " + _get_object_name(model) + " into " + str(len(chunks)) + " chunks of " + str(chunk_size)) + for chunk in chunks: + logger.debug(f"deleting {len(chunk)} {_get_object_name(model)}") + result = dojo_dispatch_task(async_delete_chunk_task, list(chunk)) + # Collect async task results to wait for them all at once + if hasattr(result, "get"): + task_results.append(result) + # Wait for all chunk deletions to complete (they run in parallel) + for task_result in task_results: + task_result.get(timeout=300) # 5 minute timeout per chunk + # Now delete the main object after all chunks are done + result = dojo_dispatch_task(async_delete_chunk_task, [obj]) + # Wait for final deletion to complete + if hasattr(result, "get"): + result.get(timeout=300) # 5 minute timeout + logger.debug("ASYNC_DELETE: Successfully deleted " + _get_object_name(obj) + ": " + str(obj)) + + +@app.task +def async_delete_task(obj, **kwargs): + """ + Module-level Celery task to delete an object and its related objects. + + Accepts **kwargs for async_user and _pgh_context injected by dojo_dispatch_task. + Uses PgHistoryTask base class (default) to preserve pghistory context for audit trail. + """ + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + logger.debug("ASYNC_DELETE: Deleting " + _get_object_name(obj) + ": " + str(obj)) + model_list = ASYNC_DELETE_MAPPING.get(_get_object_name(obj)) + if model_list: + # The object to be deleted was found in the object list + dojo_dispatch_task(async_delete_crawl_task, obj, model_list) + else: + # The object is not supported in async delete, delete normally + logger.debug("ASYNC_DELETE: " + _get_object_name(obj) + " async delete not supported. Deleteing normally: " + str(obj)) + obj.delete() + + +class async_delete: + + """ + Entry point class for async object deletion. + + Usage: + async_del = async_delete() + async_del.delete(instance) + + This class dispatches deletion to module-level Celery tasks via dojo_dispatch_task, + which properly handles user context injection and pghistory context. + """ + + def __init__(self, *args, **kwargs): + # Keep mapping reference for backwards compatibility + self.mapping = ASYNC_DELETE_MAPPING + def delete(self, obj, **kwargs): - logger.debug("ASYNC_DELETE: Deleting " + self.get_object_name(obj) + ": " + str(obj)) - model_list = self.mapping.get(self.get_object_name(obj), None) - if model_list: - # The object to be deleted was found in the object list - self.crawl(obj, model_list) - else: - # The object is not supported in async delete, delete normally - logger.debug("ASYNC_DELETE: " + self.get_object_name(obj) + " async delete not supported. 
Deleteing normally: " + str(obj)) - obj.delete() - - @dojo_async_task - @app.task - def crawl(self, obj, model_list, **kwargs): - logger.debug("ASYNC_DELETE: Crawling " + self.get_object_name(obj) + ": " + str(obj)) - for model_info in model_list: - task_results = [] - model = model_info[0] - model_query = model_info[1] - filter_dict = {model_query: obj.id} - # Only fetch the IDs since we will make a list of IDs in the following function call - objects_to_delete = model.objects.only("id").filter(**filter_dict).distinct().order_by("id") - logger.debug("ASYNC_DELETE: Deleting " + str(len(objects_to_delete)) + " " + self.get_object_name(model) + "s in chunks") - chunks = self.chunk_list(model, objects_to_delete) - for chunk in chunks: - logger.debug(f"deleting {len(chunk)} {self.get_object_name(model)}") - result = self.delete_chunk(chunk) - # Collect async task results to wait for them all at once - if hasattr(result, "get"): - task_results.append(result) - # Wait for all chunk deletions to complete (they run in parallel) - for task_result in task_results: - task_result.get(timeout=300) # 5 minute timeout per chunk - # Now delete the main object after all chunks are done - result = self.delete_chunk([obj]) - # Wait for final deletion to complete - if hasattr(result, "get"): - result.get(timeout=300) # 5 minute timeout - logger.debug("ASYNC_DELETE: Successfully deleted " + self.get_object_name(obj) + ": " + str(obj)) - - def chunk_list(self, model, full_list): + """ + Entry point to delete an object asynchronously. + + Dispatches to async_delete_task via dojo_dispatch_task to ensure proper + handling of async_user and _pgh_context. + """ + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(async_delete_task, obj, **kwargs) + + # Keep helper methods for backwards compatibility and potential direct use + @staticmethod + def get_object_name(obj): + return _get_object_name(obj) + + @staticmethod + def chunk_list(model, full_list): chunk_size = get_setting("ASYNC_OBEJECT_DELETE_CHUNK_SIZE") - # Break the list of objects into "chunk_size" lists chunk_list = [full_list[i:i + chunk_size] for i in range(0, len(full_list), chunk_size)] - logger.debug("ASYNC_DELETE: Split " + self.get_object_name(model) + " into " + str(len(chunk_list)) + " chunks of " + str(chunk_size)) + logger.debug("ASYNC_DELETE: Split " + _get_object_name(model) + " into " + str(len(chunk_list)) + " chunks of " + str(chunk_size)) return chunk_list - def get_object_name(self, obj): - if obj.__class__.__name__ == "ModelBase": - return obj.__name__ - return obj.__class__.__name__ - @receiver(user_logged_in) def log_user_login(sender, request, user, **kwargs): diff --git a/unittests/test_async_delete.py b/unittests/test_async_delete.py new file mode 100644 index 00000000000..341723e8296 --- /dev/null +++ b/unittests/test_async_delete.py @@ -0,0 +1,313 @@ +""" +Unit tests for async_delete functionality. + +These tests verify that the async_delete class works correctly with dojo_dispatch_task, +which injects async_user and _pgh_context kwargs into task calls. + +The original bug was that @app.task decorated instance methods didn't properly handle +the injected kwargs, causing TypeError: unexpected keyword argument 'async_user'. 
+""" +import logging + +from crum import impersonate +from django.contrib.auth.models import User +from django.test import override_settings +from django.utils import timezone + +from dojo.models import Engagement, Finding, Product, Product_Type, Test, Test_Type, UserContactInfo +from dojo.utils import async_delete + +from .dojo_test_case import DojoTestCase + +logger = logging.getLogger(__name__) + + +class TestAsyncDelete(DojoTestCase): + + """ + Test async_delete functionality with dojo_dispatch_task kwargs injection. + + These tests use block_execution=True and crum.impersonate to run tasks synchronously, + which allows errors to surface immediately rather than being lost in background workers. + """ + + def setUp(self): + """Set up test user with block_execution=True and disable unneeded features.""" + super().setUp() + + # Create test user with block_execution=True to run tasks synchronously + self.testuser = User.objects.create( + username="test_async_delete_user", + is_staff=True, + is_superuser=True, + ) + UserContactInfo.objects.create(user=self.testuser, block_execution=True) + + # Log in as the test user (for API client) + self.client.force_login(self.testuser) + + # Disable features that might interfere with deletion + self.system_settings(enable_product_grade=False) + self.system_settings(enable_github=False) + self.system_settings(enable_jira=False) + + # Create base test data + self.product_type = Product_Type.objects.create(name="Test Product Type for Async Delete") + self.test_type = Test_Type.objects.get_or_create(name="Manual Test")[0] + + def tearDown(self): + """Clean up any remaining test data.""" + # Clean up in reverse order of dependencies + Finding.objects.filter(test__engagement__product__prod_type=self.product_type).delete() + Test.objects.filter(engagement__product__prod_type=self.product_type).delete() + Engagement.objects.filter(product__prod_type=self.product_type).delete() + Product.objects.filter(prod_type=self.product_type).delete() + self.product_type.delete() + + super().tearDown() + + def _create_product(self, name="Test Product"): + """Helper to create a product for testing.""" + return Product.objects.create( + name=name, + description="Test product for async delete", + prod_type=self.product_type, + ) + + def _create_engagement(self, product, name="Test Engagement"): + """Helper to create an engagement for testing.""" + return Engagement.objects.create( + name=name, + product=product, + target_start=timezone.now(), + target_end=timezone.now(), + ) + + def _create_test(self, engagement, name="Test"): + """Helper to create a test for testing.""" + return Test.objects.create( + engagement=engagement, + test_type=self.test_type, + target_start=timezone.now(), + target_end=timezone.now(), + ) + + def _create_finding(self, test, title="Test Finding"): + """Helper to create a finding for testing.""" + return Finding.objects.create( + test=test, + title=title, + severity="High", + description="Test finding for async delete", + mitigation="Test mitigation", + impact="Test impact", + reporter=self.testuser, + ) + + @override_settings(ASYNC_OBJECT_DELETE=True) + def test_async_delete_simple_object(self): + """ + Test that async_delete works for a simple object (Finding). + + Finding is not in the async_delete mapping, so it falls back to direct delete. + This tests that the module-level task accepts **kwargs properly. 
+ """ + product = self._create_product() + engagement = self._create_engagement(product) + test = self._create_test(engagement) + finding = self._create_finding(test) + finding_pk = finding.pk + + # Use impersonate to set current user context (required for block_execution to work) + with impersonate(self.testuser): + # This would raise TypeError before the fix: + # TypeError: delete() got an unexpected keyword argument 'async_user' + async_del = async_delete() + async_del.delete(finding) + + # Verify the finding was deleted + self.assertFalse( + Finding.objects.filter(pk=finding_pk).exists(), + "Finding should be deleted", + ) + + @override_settings(ASYNC_OBJECT_DELETE=True) + def test_async_delete_test_with_findings(self): + """ + Test that async_delete cascades deletion for Test objects. + + Test is in the async_delete mapping and should cascade delete its findings. + """ + product = self._create_product() + engagement = self._create_engagement(product) + test = self._create_test(engagement) + finding1 = self._create_finding(test, "Finding 1") + finding2 = self._create_finding(test, "Finding 2") + + test_pk = test.pk + finding1_pk = finding1.pk + finding2_pk = finding2.pk + + # Use impersonate to set current user context (required for block_execution to work) + with impersonate(self.testuser): + # Delete the test (should cascade to findings) + async_del = async_delete() + async_del.delete(test) + + # Verify all objects were deleted + self.assertFalse( + Test.objects.filter(pk=test_pk).exists(), + "Test should be deleted", + ) + self.assertFalse( + Finding.objects.filter(pk=finding1_pk).exists(), + "Finding 1 should be deleted via cascade", + ) + self.assertFalse( + Finding.objects.filter(pk=finding2_pk).exists(), + "Finding 2 should be deleted via cascade", + ) + + @override_settings(ASYNC_OBJECT_DELETE=True) + def test_async_delete_engagement_with_tests(self): + """ + Test that async_delete cascades deletion for Engagement objects. + + Engagement is in the async_delete mapping and should cascade delete + its tests and findings. + """ + product = self._create_product() + engagement = self._create_engagement(product) + test1 = self._create_test(engagement, "Test 1") + test2 = self._create_test(engagement, "Test 2") + finding1 = self._create_finding(test1, "Finding in Test 1") + finding2 = self._create_finding(test2, "Finding in Test 2") + + engagement_pk = engagement.pk + test1_pk = test1.pk + test2_pk = test2.pk + finding1_pk = finding1.pk + finding2_pk = finding2.pk + + # Use impersonate to set current user context (required for block_execution to work) + with impersonate(self.testuser): + # Delete the engagement (should cascade to tests and findings) + async_del = async_delete() + async_del.delete(engagement) + + # Verify all objects were deleted + self.assertFalse( + Engagement.objects.filter(pk=engagement_pk).exists(), + "Engagement should be deleted", + ) + self.assertFalse( + Test.objects.filter(pk__in=[test1_pk, test2_pk]).exists(), + "Tests should be deleted via cascade", + ) + self.assertFalse( + Finding.objects.filter(pk__in=[finding1_pk, finding2_pk]).exists(), + "Findings should be deleted via cascade", + ) + + @override_settings(ASYNC_OBJECT_DELETE=True) + def test_async_delete_product_with_hierarchy(self): + """ + Test that async_delete cascades deletion for Product objects. + + Product is in the async_delete mapping and should cascade delete + its engagements, tests, and findings. 
+ """ + product = self._create_product() + engagement = self._create_engagement(product) + test = self._create_test(engagement) + finding = self._create_finding(test) + + product_pk = product.pk + engagement_pk = engagement.pk + test_pk = test.pk + finding_pk = finding.pk + + # Use impersonate to set current user context (required for block_execution to work) + with impersonate(self.testuser): + # Delete the product (should cascade to everything) + async_del = async_delete() + async_del.delete(product) + + # Verify all objects were deleted + self.assertFalse( + Product.objects.filter(pk=product_pk).exists(), + "Product should be deleted", + ) + self.assertFalse( + Engagement.objects.filter(pk=engagement_pk).exists(), + "Engagement should be deleted via cascade", + ) + self.assertFalse( + Test.objects.filter(pk=test_pk).exists(), + "Test should be deleted via cascade", + ) + self.assertFalse( + Finding.objects.filter(pk=finding_pk).exists(), + "Finding should be deleted via cascade", + ) + + @override_settings(ASYNC_OBJECT_DELETE=True) + def test_async_delete_accepts_sync_kwarg(self): + """ + Test that async_delete passes through the sync kwarg properly. + + The sync=True kwarg forces synchronous execution for the top-level task. + However, nested task dispatches still need user context to run synchronously, + so we use impersonate here as well. + """ + product = self._create_product() + product_pk = product.pk + + # Use impersonate to ensure nested tasks also run synchronously + with impersonate(self.testuser): + # Explicitly pass sync=True + async_del = async_delete() + async_del.delete(product, sync=True) + + # Verify the product was deleted + self.assertFalse( + Product.objects.filter(pk=product_pk).exists(), + "Product should be deleted with sync=True", + ) + + def test_async_delete_helper_methods(self): + """ + Test that static helper methods on async_delete class still work. + + These are kept for backwards compatibility. + """ + product = self._create_product() + + # Test get_object_name + self.assertEqual( + async_delete.get_object_name(product), + "Product", + "get_object_name should return class name", + ) + + # Test get_object_name with model class + self.assertEqual( + async_delete.get_object_name(Product), + "Product", + "get_object_name should work with model class", + ) + + def test_async_delete_mapping_preserved(self): + """ + Test that the mapping attribute is preserved on async_delete instances. + + This ensures backwards compatibility for code that might access the mapping. 
+ """ + async_del = async_delete() + + # Verify mapping exists and has expected keys + self.assertIsNotNone(async_del.mapping) + self.assertIn("Product", async_del.mapping) + self.assertIn("Product_Type", async_del.mapping) + self.assertIn("Engagement", async_del.mapping) + self.assertIn("Test", async_del.mapping) diff --git a/unittests/test_importers_performance.py b/unittests/test_importers_performance.py index c9bd839be00..ac425c62cff 100644 --- a/unittests/test_importers_performance.py +++ b/unittests/test_importers_performance.py @@ -265,20 +265,12 @@ def test_import_reimport_reimport_performance_pghistory_async(self): configure_pghistory_triggers() self._import_reimport_performance( - - - - expected_num_queries1=305, expected_num_async_tasks1=6, expected_num_queries2=232, expected_num_async_tasks2=17, expected_num_queries3=114, expected_num_async_tasks3=16, - - - - ) @override_settings(ENABLE_AUDITLOG=True) @@ -295,14 +287,12 @@ def test_import_reimport_reimport_performance_pghistory_no_async(self): testuser.usercontactinfo.save() self._import_reimport_performance( - expected_num_queries1=312, expected_num_async_tasks1=6, expected_num_queries2=239, expected_num_async_tasks2=17, expected_num_queries3=121, expected_num_async_tasks3=16, - ) @override_settings(ENABLE_AUDITLOG=True) @@ -464,11 +454,8 @@ def test_deduplication_performance_pghistory_no_async(self): testuser.usercontactinfo.save() self._deduplication_performance( - expected_num_queries1=281, expected_num_async_tasks1=7, expected_num_queries2=246, expected_num_async_tasks2=7, - - ) diff --git a/unittests/test_jira_import_and_pushing_api.py b/unittests/test_jira_import_and_pushing_api.py index 00019a955a9..da445948351 100644 --- a/unittests/test_jira_import_and_pushing_api.py +++ b/unittests/test_jira_import_and_pushing_api.py @@ -980,7 +980,7 @@ def test_engagement_epic_mapping_disabled_no_epic_and_push_findings(self): @patch("dojo.jira_link.helper.can_be_pushed_to_jira", return_value=(True, None, None)) @patch("dojo.jira_link.helper.is_push_all_issues", return_value=False) @patch("dojo.jira_link.helper.push_to_jira", return_value=None) - @patch("dojo.notifications.helper.WebhookNotificationManger.send_webhooks_notification") + @patch("dojo.notifications.helper.send_webhooks_notification") def test_bulk_edit_mixed_findings_and_groups_jira_push_bug(self, mock_webhooks, mock_push_to_jira, mock_is_push_all_issues, mock_can_be_pushed): """ Test the bug in bulk edit: when bulk editing findings where some are in groups diff --git a/unittests/test_notifications.py b/unittests/test_notifications.py index 7c5b289a211..c1956d82679 100644 --- a/unittests/test_notifications.py +++ b/unittests/test_notifications.py @@ -649,7 +649,7 @@ def test_webhook_reactivation(self): with self.subTest("active"): wh = Notification_Webhooks.objects.filter(owner=None).first() manager = WebhookNotificationManger() - manager._webhook_reactivation(manager, endpoint_id=wh.pk) + manager._webhook_reactivation(endpoint_id=wh.pk) updated_wh = Notification_Webhooks.objects.filter(owner=None).first() self.assertEqual(updated_wh.status, Notification_Webhooks.Status.STATUS_ACTIVE) @@ -668,7 +668,7 @@ def test_webhook_reactivation(self): with self.assertLogs("dojo.notifications.helper", level="DEBUG") as cm: manager = WebhookNotificationManger() - manager._webhook_reactivation(manager, endpoint_id=wh.pk) + manager._webhook_reactivation(endpoint_id=wh.pk) updated_wh = Notification_Webhooks.objects.filter(owner=None).first() self.assertEqual(updated_wh.status, 
Notification_Webhooks.Status.STATUS_ACTIVE_TMP)
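For reference, the foreground-execution pattern these tests rely on (block_execution on the impersonated user, plus the optional sync=True flag exercised by test_async_delete_accepts_sync_kwarg) can be sketched as follows. The function and its arguments are placeholders rather than identifiers from this changeset; it assumes user has a UserContactInfo row with block_execution=True.

from crum import impersonate

from dojo.utils import async_delete


def delete_in_foreground(obj, user):
    # block_execution on the impersonated user makes nested task dispatches run
    # synchronously, and sync=True forces the top-level delete task itself to
    # run inline, so failures surface immediately instead of in a worker.
    with impersonate(user):
        async_delete().delete(obj, sync=True)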