From e9ec2cf75d01ce9d2cf20f6f8207d5baab92b9e9 Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Fri, 26 Dec 2025 18:26:16 +0100 Subject: [PATCH 01/36] remove dojo_model_to/from_id decorator --- dojo/finding/deduplication.py | 8 +++----- dojo/finding/helper.py | 19 +------------------ .../commands/test_celery_decorator.py | 14 -------------- dojo/tools/tool_issue_updater.py | 4 +--- 4 files changed, 5 insertions(+), 40 deletions(-) diff --git a/dojo/finding/deduplication.py b/dojo/finding/deduplication.py index fea6a83d584..1df0b88acc8 100644 --- a/dojo/finding/deduplication.py +++ b/dojo/finding/deduplication.py @@ -8,7 +8,7 @@ from django.db.models.query_utils import Q from dojo.celery import app -from dojo.decorators import dojo_async_task, dojo_model_from_id, dojo_model_to_id +from dojo.decorators import dojo_async_task from dojo.models import Finding, System_Settings logger = logging.getLogger(__name__) @@ -45,12 +45,10 @@ def get_finding_models_for_deduplication(finding_ids): ) -@dojo_model_to_id @dojo_async_task @app.task -@dojo_model_from_id -def do_dedupe_finding_task(new_finding, *args, **kwargs): - return do_dedupe_finding(new_finding, *args, **kwargs) +def do_dedupe_finding_task(new_finding_id, *args, **kwargs): + return do_dedupe_finding(Finding.objects.get(id=new_finding_id), *args, **kwargs) @dojo_async_task diff --git a/dojo/finding/helper.py b/dojo/finding/helper.py index 8b829455d21..db5761ed0d1 100644 --- a/dojo/finding/helper.py +++ b/dojo/finding/helper.py @@ -16,7 +16,7 @@ import dojo.jira_link.helper as jira_helper import dojo.risk_acceptance.helper as ra_helper from dojo.celery import app -from dojo.decorators import dojo_async_task, dojo_model_from_id, dojo_model_to_id +from dojo.decorators import dojo_async_task from dojo.endpoint.utils import endpoint_get_or_create, save_endpoints_to_add from dojo.file_uploads.helper import delete_related_files from dojo.finding.deduplication import ( @@ -390,25 +390,8 @@ def add_findings_to_auto_group(name, findings, group_by, *, create_finding_group finding_group.findings.add(*findings) -@dojo_model_to_id -@dojo_async_task(signature=True) -@app.task -@dojo_model_from_id -def post_process_finding_save_signature(finding, dedupe_option=True, rules_option=True, product_grading_option=True, # noqa: FBT002 - issue_updater_option=True, push_to_jira=False, user=None, *args, **kwargs): # noqa: FBT002 - this is bit hard to fix nice have this universally fixed - """ - Returns a task signature for post-processing a finding. This is useful for creating task signatures - that can be used in chords or groups or to await results. We need this extra method because of our dojo_async decorator. - If we use more of these celery features, we should probably move away from that decorator. 
- """ - return post_process_finding_save_internal(finding, dedupe_option, rules_option, product_grading_option, - issue_updater_option, push_to_jira, user, *args, **kwargs) - - -@dojo_model_to_id @dojo_async_task @app.task -@dojo_model_from_id def post_process_finding_save(finding, dedupe_option=True, rules_option=True, product_grading_option=True, # noqa: FBT002 issue_updater_option=True, push_to_jira=False, user=None, *args, **kwargs): # noqa: FBT002 - this is bit hard to fix nice have this universally fixed diff --git a/dojo/management/commands/test_celery_decorator.py b/dojo/management/commands/test_celery_decorator.py index ed9488541fa..16ccef02b39 100644 --- a/dojo/management/commands/test_celery_decorator.py +++ b/dojo/management/commands/test_celery_decorator.py @@ -7,7 +7,6 @@ from dojo.celery import app # from dojo.utils import get_system_setting, do_dedupe_finding, dojo_async_task -from dojo.decorators import dojo_async_task, dojo_model_from_id, dojo_model_to_id from dojo.models import Finding, Notes from dojo.utils import test_valentijn @@ -81,16 +80,3 @@ def wrapper(*args, **kwargs): @my_decorator_inside def my_test_task(new_finding, *args, **kwargs): logger.debug("oh la la what a nice task") - - -# example working with multiple parameters... -@dojo_model_to_id(parameter=1) -@dojo_model_to_id -@dojo_async_task -@app.task -@dojo_model_from_id(model=Notes, parameter=1) -@dojo_model_from_id -def test_valentijn_task(new_finding, note, **kwargs): - logger.debug("test_valentijn:") - logger.debug(new_finding) - logger.debug(note) diff --git a/dojo/tools/tool_issue_updater.py b/dojo/tools/tool_issue_updater.py index fd203edebea..5a3f2895658 100644 --- a/dojo/tools/tool_issue_updater.py +++ b/dojo/tools/tool_issue_updater.py @@ -1,7 +1,7 @@ import pghistory from dojo.celery import app -from dojo.decorators import dojo_async_task, dojo_model_from_id, dojo_model_to_id +from dojo.decorators import dojo_async_task from dojo.tools.api_sonarqube.parser import SCAN_SONARQUBE_API from dojo.tools.api_sonarqube.updater import SonarQubeApiUpdater from dojo.tools.api_sonarqube.updater_from_source import SonarQubeApiUpdaterFromSource @@ -17,10 +17,8 @@ def is_tool_issue_updater_needed(finding, *args, **kwargs): return test_type.name == SCAN_SONARQUBE_API -@dojo_model_to_id @dojo_async_task @app.task -@dojo_model_from_id def tool_issue_updater(finding, *args, **kwargs): test_type = finding.test.test_type From 4bda3d30b34287575a289702853c89cdf6190427 Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Fri, 26 Dec 2025 18:50:01 +0100 Subject: [PATCH 02/36] remove dojo_model_from/to_id --- dojo/api_v2/views.py | 4 +- dojo/endpoint/views.py | 2 +- dojo/engagement/services.py | 2 +- dojo/engagement/views.py | 2 +- dojo/finding/deduplication.py | 4 +- dojo/finding/helper.py | 17 +++-- dojo/finding/views.py | 20 ++--- dojo/jira_link/helper.py | 69 ++++++++++------- dojo/management/commands/dedupe.py | 6 +- dojo/models.py | 2 +- dojo/product/views.py | 2 +- dojo/settings/settings.dist.py | 4 +- dojo/tasks.py | 4 +- dojo/templatetags/display_tags.py | 2 +- dojo/test/views.py | 2 +- dojo/tools/tool_issue_updater.py | 14 +++- dojo/utils.py | 76 +++++++++++-------- unittests/test_jira_import_and_pushing_api.py | 2 +- 18 files changed, 140 insertions(+), 94 deletions(-) diff --git a/dojo/api_v2/views.py b/dojo/api_v2/views.py index 01318a70bc5..e8c7d278e65 100644 --- a/dojo/api_v2/views.py +++ b/dojo/api_v2/views.py @@ -678,13 +678,13 @@ def update_jira_epic(self, request, pk=None): try: if 
engagement.has_jira_issue: - jira_helper.update_epic(engagement, **request.data) + jira_helper.update_epic(engagement.id, **request.data) response = Response( {"info": "Jira Epic update query sent"}, status=status.HTTP_200_OK, ) else: - jira_helper.add_epic(engagement, **request.data) + jira_helper.add_epic(engagement.id, **request.data) response = Response( {"info": "Jira Epic create query sent"}, status=status.HTTP_200_OK, diff --git a/dojo/endpoint/views.py b/dojo/endpoint/views.py index 561a0135d45..1dc4df898c6 100644 --- a/dojo/endpoint/views.py +++ b/dojo/endpoint/views.py @@ -373,7 +373,7 @@ def endpoint_bulk_update_all(request, pid=None): product_calc = list(Product.objects.filter(endpoint__id__in=endpoints_to_update).distinct()) endpoints.delete() for prod in product_calc: - calculate_grade(prod) + calculate_grade(prod.id) if skipped_endpoint_count > 0: add_error_message_to_response(f"Skipped deletion of {skipped_endpoint_count} endpoints because you are not authorized.") diff --git a/dojo/engagement/services.py b/dojo/engagement/services.py index 18aed9e425b..cd70af1ea2c 100644 --- a/dojo/engagement/services.py +++ b/dojo/engagement/services.py @@ -16,7 +16,7 @@ def close_engagement(eng): eng.save() if jira_helper.get_jira_project(eng): - jira_helper.close_epic(eng, push_to_jira=True) + jira_helper.close_epic(eng.id, push_to_jira=True) def reopen_engagement(eng): diff --git a/dojo/engagement/views.py b/dojo/engagement/views.py index 9ecdabfdd9b..ebc2e09ce6b 100644 --- a/dojo/engagement/views.py +++ b/dojo/engagement/views.py @@ -390,7 +390,7 @@ def copy_engagement(request, eid): form = DoneForm(request.POST) if form.is_valid(): engagement_copy = engagement.copy() - calculate_grade(product) + calculate_grade(product.id) messages.add_message( request, messages.SUCCESS, diff --git a/dojo/finding/deduplication.py b/dojo/finding/deduplication.py index 1df0b88acc8..eb8baf40db0 100644 --- a/dojo/finding/deduplication.py +++ b/dojo/finding/deduplication.py @@ -48,7 +48,7 @@ def get_finding_models_for_deduplication(finding_ids): @dojo_async_task @app.task def do_dedupe_finding_task(new_finding_id, *args, **kwargs): - return do_dedupe_finding(Finding.objects.get(id=new_finding_id), *args, **kwargs) + return do_dedupe_finding_task_internal(Finding.objects.get(id=new_finding_id), *args, **kwargs) @dojo_async_task @@ -69,7 +69,7 @@ def do_dedupe_batch_task(finding_ids, *args, **kwargs): dedupe_batch_of_findings(findings) -def do_dedupe_finding(new_finding, *args, **kwargs): +def do_dedupe_finding_task_internal(new_finding, *args, **kwargs): from dojo.utils import get_custom_method # noqa: PLC0415 -- circular import if dedupe_method := get_custom_method("FINDING_DEDUPE_METHOD"): return dedupe_method(new_finding, *args, **kwargs) diff --git a/dojo/finding/helper.py b/dojo/finding/helper.py index db5761ed0d1..57d85086119 100644 --- a/dojo/finding/helper.py +++ b/dojo/finding/helper.py @@ -21,7 +21,7 @@ from dojo.file_uploads.helper import delete_related_files from dojo.finding.deduplication import ( dedupe_batch_of_findings, - do_dedupe_finding, + do_dedupe_finding_task_internal, get_finding_models_for_deduplication, ) from dojo.models import ( @@ -43,6 +43,7 @@ close_external_issue, do_false_positive_history, get_current_user, + get_object_or_none, mass_model_updater, to_str_typed, ) @@ -392,8 +393,12 @@ def add_findings_to_auto_group(name, findings, group_by, *, create_finding_group @dojo_async_task @app.task -def post_process_finding_save(finding, dedupe_option=True, rules_option=True, 
product_grading_option=True, # noqa: FBT002 +def post_process_finding_save(finding_id, dedupe_option=True, rules_option=True, product_grading_option=True, # noqa: FBT002 issue_updater_option=True, push_to_jira=False, user=None, *args, **kwargs): # noqa: FBT002 - this is bit hard to fix nice have this universally fixed + finding = get_object_or_none(Finding, id=finding_id) + if not finding: + logger.warning("Finding with id %s does not exist, skipping post_process_finding_save", finding_id) + return None return post_process_finding_save_internal(finding, dedupe_option, rules_option, product_grading_option, issue_updater_option, push_to_jira, user, *args, **kwargs) @@ -412,7 +417,7 @@ def post_process_finding_save_internal(finding, dedupe_option=True, rules_option if dedupe_option: if finding.hash_code is not None: if system_settings.enable_deduplication: - do_dedupe_finding(finding, *args, **kwargs) + do_dedupe_finding_task_internal(finding, *args, **kwargs) else: deduplicationLogger.debug("skipping dedupe because it's disabled in system settings") else: @@ -431,7 +436,7 @@ def post_process_finding_save_internal(finding, dedupe_option=True, rules_option if product_grading_option: if system_settings.enable_product_grade: - calculate_grade(finding.test.engagement.product) + calculate_grade(finding.test.engagement.product.id) else: deduplicationLogger.debug("skipping product grading because it's disabled in system settings") @@ -499,7 +504,7 @@ def post_process_findings_batch(finding_ids, *args, dedupe_option=True, rules_op tool_issue_updater.async_tool_issue_update(finding) if product_grading_option and system_settings.enable_product_grade: - calculate_grade(findings[0].test.engagement.product) + calculate_grade(findings[0].test.engagement.product.id) if push_to_jira: for finding in findings: @@ -1021,7 +1026,7 @@ def close_finding( ra_helper.risk_unaccept(user, finding, perform_save=False) # External issues (best effort) - close_external_issue(finding, "Closed by defectdojo", "github") + close_external_issue(finding.id, "Closed by defectdojo", "github") # JIRA sync push_to_jira = False diff --git a/dojo/finding/views.py b/dojo/finding/views.py index 999e15e6e47..b5bfb593043 100644 --- a/dojo/finding/views.py +++ b/dojo/finding/views.py @@ -1003,9 +1003,9 @@ def process_github_form(self, request: HttpRequest, finding: Finding, context: d if context["gform"].is_valid(): if GITHUB_Issue.objects.filter(finding=finding).exists(): - update_external_issue(finding, old_status, "github") + update_external_issue(finding.id, old_status, "github") else: - add_external_issue(finding, "github") + add_external_issue(finding.id, "github") return request, True add_field_errors_to_response(context["gform"]) @@ -1082,7 +1082,7 @@ def process_form(self, request: HttpRequest, finding: Finding, context: dict): product = finding.test.engagement.product finding.delete() # Update the grade of the product async - calculate_grade(product) + calculate_grade(product.id) # Add a message to the request that the finding was successfully deleted messages.add_message( request, @@ -1318,7 +1318,7 @@ def reopen_finding(request, fid): if jira_helper.is_push_all_issues(finding) or jira_helper.is_keep_in_sync_with_jira(finding): jira_helper.push_to_jira(finding) - reopen_external_issue(finding, "re-opened by defectdojo", "github") + reopen_external_issue(finding.id, "re-opened by defectdojo", "github") messages.add_message( request, messages.SUCCESS, "Finding Reopened.", extra_tags="alert-success", @@ -1353,7 +1353,7 @@ def 
copy_finding(request, fid): test = form.cleaned_data.get("test") product = finding.test.engagement.product finding_copy = finding.copy(test=test) - calculate_grade(product) + calculate_grade(product.id) messages.add_message( request, messages.SUCCESS, @@ -2101,7 +2101,7 @@ def promote_to_finding(request, fid): ).push_all_issues, ) if gform.is_valid(): - add_external_issue(new_finding, "github") + add_external_issue(new_finding.id, "github") messages.add_message( request, @@ -2733,7 +2733,7 @@ def _bulk_update_finding_status_and_severity(finds, form, request, system_settin fp.save_no_options() for prod in prods: - calculate_grade(prod) + calculate_grade(prod.id) if skipped_duplicate_count > 0: messages.add_message( @@ -2789,7 +2789,7 @@ def _bulk_update_risk_acceptance(finds, form, request, prods): ra_helper.risk_unaccept(request.user, finding) for prod in prods: - calculate_grade(prod) + calculate_grade(prod.id) if skipped_risk_accept_count > 0: messages.add_message( @@ -3084,9 +3084,9 @@ def finding_bulk_update_all(request, pid=None): old_status = finding.status() if form.cleaned_data["push_to_github"]: if GITHUB_Issue.objects.filter(finding=finding).exists(): - update_external_issue(finding, old_status, "github") + update_external_issue(finding.id, old_status, "github") else: - add_external_issue(finding, "github") + add_external_issue(finding.id, "github") if form.cleaned_data["notes"]: logger.debug("Setting bulk notes") diff --git a/dojo/jira_link/helper.py b/dojo/jira_link/helper.py index 453a33cd23d..67dfa346e00 100644 --- a/dojo/jira_link/helper.py +++ b/dojo/jira_link/helper.py @@ -40,6 +40,7 @@ add_error_message_to_response, get_file_images, get_full_url, + get_object_or_none, get_system_setting, prod_name, to_str_typed, @@ -759,34 +760,40 @@ def push_to_jira(obj, *args, **kwargs): if isinstance(obj, Finding): if obj.has_finding_group: logger.debug("pushing finding group for %s to JIRA", obj) - return push_finding_group_to_jira(obj.finding_group, *args, **kwargs) - return push_finding_to_jira(obj, *args, **kwargs) + return push_finding_group_to_jira(obj.finding_group.id, *args, **kwargs) + return push_finding_to_jira(obj.id, *args, **kwargs) if isinstance(obj, Finding_Group): - return push_finding_group_to_jira(obj, *args, **kwargs) + return push_finding_group_to_jira(obj.id, *args, **kwargs) if isinstance(obj, Engagement): - return push_engagement_to_jira(obj, *args, **kwargs) + return push_engagement_to_jira(obj.id, *args, **kwargs) logger.error("unsupported object passed to push_to_jira: %s %i %s", obj.__name__, obj.id, obj) return None # we need thre separate celery tasks due to the decorators we're using to map to/from ids -@dojo_model_to_id @dojo_async_task @app.task -@dojo_model_from_id -def push_finding_to_jira(finding, *args, **kwargs): +def push_finding_to_jira(finding_id, *args, **kwargs): + finding = get_object_or_none(Finding, id=finding_id) + if not finding: + logger.warning("Finding with id %s does not exist, skipping push_finding_to_jira", finding_id) + return None + if finding.has_jira_issue: return update_jira_issue(finding, *args, **kwargs) return add_jira_issue(finding, *args, **kwargs) -@dojo_model_to_id @dojo_async_task @app.task -@dojo_model_from_id(model=Finding_Group) -def push_finding_group_to_jira(finding_group, *args, **kwargs): +def push_finding_group_to_jira(finding_group_id, *args, **kwargs): + finding_group = get_object_or_none(Finding_Group, id=finding_group_id) + if not finding_group: + logger.warning("Finding_Group with id %s does not exist, 
skipping push_finding_group_to_jira", finding_group_id) + return None + # Look for findings that have single ticket associations separate from the group for finding in finding_group.findings.filter(jira_issue__isnull=False): update_jira_issue(finding, *args, **kwargs) @@ -796,14 +803,17 @@ def push_finding_group_to_jira(finding_group, *args, **kwargs): return add_jira_issue(finding_group, *args, **kwargs) -@dojo_model_to_id @dojo_async_task @app.task -@dojo_model_from_id(model=Engagement) -def push_engagement_to_jira(engagement, *args, **kwargs): +def push_engagement_to_jira(engagement_id, *args, **kwargs): + engagement = get_object_or_none(Engagement, id=engagement_id) + if not engagement: + logger.warning("Engagement with id %s does not exist, skipping push_engagement_to_jira", engagement_id) + return None + if engagement.has_jira_issue: - return update_epic(engagement, *args, **kwargs) - return add_epic(engagement, *args, **kwargs) + return update_epic(engagement.id, *args, **kwargs) + return add_epic(engagement.id, *args, **kwargs) def add_issues_to_epic(jira, obj, epic_id, issue_keys, *, ignore_epics=True): @@ -1366,12 +1376,13 @@ def jira_check_attachment(issue, source_file_name): return file_exists -@dojo_model_to_id @dojo_async_task @app.task -@dojo_model_from_id(model=Engagement) -def close_epic(eng, push_to_jira, **kwargs): - engagement = eng +def close_epic(engagement_id, push_to_jira, **kwargs): + engagement = get_object_or_none(Engagement, id=engagement_id) + if not engagement: + logger.warning("Engagement with id %s does not exist, skipping close_epic", engagement_id) + return False if not is_jira_enabled(): return False @@ -1387,7 +1398,7 @@ def close_epic(eng, push_to_jira, **kwargs): if jira_project and jira_project.enable_engagement_epic_mapping: if push_to_jira: try: - jissue = get_jira_issue(eng) + jissue = get_jira_issue(engagement) if jissue is None: logger.warning("JIRA close epic failed: no issue found") return False @@ -1414,11 +1425,14 @@ def close_epic(eng, push_to_jira, **kwargs): return False -@dojo_model_to_id @dojo_async_task @app.task -@dojo_model_from_id(model=Engagement) -def update_epic(engagement, **kwargs): +def update_epic(engagement_id, **kwargs): + engagement = get_object_or_none(Engagement, id=engagement_id) + if not engagement: + logger.warning("Engagement with id %s does not exist, skipping update_epic", engagement_id) + return False + logger.debug("trying to update jira EPIC for %d:%s", engagement.id, engagement.name) if not is_jira_configured_and_enabled(engagement): @@ -1458,11 +1472,14 @@ def update_epic(engagement, **kwargs): return False -@dojo_model_to_id @dojo_async_task @app.task -@dojo_model_from_id(model=Engagement) -def add_epic(engagement, **kwargs): +def add_epic(engagement_id, **kwargs): + engagement = get_object_or_none(Engagement, id=engagement_id) + if not engagement: + logger.warning("Engagement with id %s does not exist, skipping add_epic", engagement_id) + return False + logger.debug("trying to create a new jira EPIC for %d:%s", engagement.id, engagement.name) if not is_jira_configured_and_enabled(engagement): diff --git a/dojo/management/commands/dedupe.py b/dojo/management/commands/dedupe.py index 3eddccdc45d..913c528f299 100644 --- a/dojo/management/commands/dedupe.py +++ b/dojo/management/commands/dedupe.py @@ -8,8 +8,8 @@ from dojo.finding.deduplication import ( dedupe_batch_of_findings, do_dedupe_batch_task, - do_dedupe_finding, do_dedupe_finding_task, + do_dedupe_finding_task_internal, 
get_finding_models_for_deduplication, ) from dojo.models import Finding, Product @@ -115,7 +115,7 @@ def _run_dedupe(self, *, restrict_to_parsers, hash_code_only, dedupe_only, dedup if dedupe_batch_mode: self._dedupe_batch_mode(findings, dedupe_sync=dedupe_sync) elif dedupe_sync: - mass_model_updater(Finding, findings, do_dedupe_finding, fields=None, order="desc", page_size=100, log_prefix="deduplicating ") + mass_model_updater(Finding, findings, do_dedupe_finding_task_internal, fields=None, order="desc", page_size=100, log_prefix="deduplicating ") else: # async tasks only need the id mass_model_updater(Finding, findings.only("id"), lambda f: do_dedupe_finding_task(f.id), fields=None, order="desc", log_prefix="deduplicating ") @@ -125,7 +125,7 @@ def _run_dedupe(self, *, restrict_to_parsers, hash_code_only, dedupe_only, dedup # in async mode the background task that grades products every hour will pick it up logger.debug("Updating grades for products...") for product in Product.objects.all(): - calculate_grade(product) + calculate_grade(product.id) logger.info("######## Done deduplicating (%s) ########", ("foreground" if dedupe_sync else "tasks submitted to celery")) else: diff --git a/dojo/models.py b/dojo/models.py index 57ce9c18e72..0e4680de67d 100644 --- a/dojo/models.py +++ b/dojo/models.py @@ -2844,7 +2844,7 @@ def save(self, dedupe_option=True, rules_option=True, product_grading_option=Tru # only perform post processing (in celery task) if needed. this check avoids submitting 1000s of tasks to celery that will do nothing system_settings = System_Settings.objects.get() if dedupe_option or issue_updater_option or (product_grading_option and system_settings.enable_product_grade) or push_to_jira: - finding_helper.post_process_finding_save(self, dedupe_option=dedupe_option, rules_option=rules_option, product_grading_option=product_grading_option, + finding_helper.post_process_finding_save(self.id, dedupe_option=dedupe_option, rules_option=rules_option, product_grading_option=product_grading_option, issue_updater_option=issue_updater_option, push_to_jira=push_to_jira, user=user, *args, **kwargs) else: logger.debug("no options selected that require finding post processing") diff --git a/dojo/product/views.py b/dojo/product/views.py index 837e0bdfefc..fccc738851f 100644 --- a/dojo/product/views.py +++ b/dojo/product/views.py @@ -1520,7 +1520,7 @@ def process_github_form(self, request: HttpRequest, finding: Finding, context: d return request, True if context["gform"].is_valid(): - add_external_issue(finding, "github") + add_external_issue(finding.id, "github") return request, True add_field_errors_to_response(context["gform"]) diff --git a/dojo/settings/settings.dist.py b/dojo/settings/settings.dist.py index ab7918c922c..32b136a6d3c 100644 --- a/dojo/settings/settings.dist.py +++ b/dojo/settings/settings.dist.py @@ -89,8 +89,7 @@ DD_CELERY_RESULT_EXPIRES=(int, 86400), DD_CELERY_BEAT_SCHEDULE_FILENAME=(str, root("dojo.celery.beat.db")), DD_CELERY_TASK_SERIALIZER=(str, "pickle"), - DD_CELERY_PASS_MODEL_BY_ID=(str, True), - DD_CELERY_LOG_LEVEL=(str, "INFO"), + DD_CELERY_LOG_LEVEL=(str, "INFO"), DD_TAG_BULK_ADD_BATCH_SIZE=(int, 1000), # Tagulous slug truncate unique setting. 
Set to -1 to use tagulous internal default (5) DD_TAGULOUS_SLUG_TRUNCATE_UNIQUE=(int, -1), @@ -1222,7 +1221,6 @@ def saml2_attrib_map_format(din): CELERY_BEAT_SCHEDULE_FILENAME = env("DD_CELERY_BEAT_SCHEDULE_FILENAME") CELERY_ACCEPT_CONTENT = ["pickle", "json", "msgpack", "yaml"] CELERY_TASK_SERIALIZER = env("DD_CELERY_TASK_SERIALIZER") -CELERY_PASS_MODEL_BY_ID = env("DD_CELERY_PASS_MODEL_BY_ID") CELERY_LOG_LEVEL = env("DD_CELERY_LOG_LEVEL") if len(env("DD_CELERY_BROKER_TRANSPORT_OPTIONS")) > 0: diff --git a/dojo/tasks.py b/dojo/tasks.py index d02040fa5b3..29dfe11257c 100644 --- a/dojo/tasks.py +++ b/dojo/tasks.py @@ -72,7 +72,7 @@ def add_alerts(self, runinterval): if system_settings.enable_product_grade: products = Product.objects.all() for product in products: - calculate_grade(product) + calculate_grade(product.id) @app.task(bind=True) @@ -169,7 +169,7 @@ def _async_dupe_delete_impl(): if system_settings.enable_product_grade: logger.info("performing batch product grading for %s products", len(affected_products)) for product in affected_products: - calculate_grade(product) + calculate_grade(product.id) @app.task(ignore_result=False) diff --git a/dojo/templatetags/display_tags.py b/dojo/templatetags/display_tags.py index 2d26b874be7..f19c704fd55 100644 --- a/dojo/templatetags/display_tags.py +++ b/dojo/templatetags/display_tags.py @@ -304,7 +304,7 @@ def product_grade(product): if system_settings.enable_product_grade and product: prod_numeric_grade = product.prod_numeric_grade if not prod_numeric_grade or prod_numeric_grade is None: - calculate_grade(product) + calculate_grade(product.id) if prod_numeric_grade: if prod_numeric_grade >= system_settings.product_grade_a: grade = "A" diff --git a/dojo/test/views.py b/dojo/test/views.py index a05c0b3b660..d37825822c6 100644 --- a/dojo/test/views.py +++ b/dojo/test/views.py @@ -343,7 +343,7 @@ def copy_test(request, tid): engagement = form.cleaned_data.get("engagement") product = test.engagement.product test_copy = test.copy(engagement=engagement) - calculate_grade(product) + calculate_grade(product.id) messages.add_message( request, messages.SUCCESS, diff --git a/dojo/tools/tool_issue_updater.py b/dojo/tools/tool_issue_updater.py index 5a3f2895658..854fb989113 100644 --- a/dojo/tools/tool_issue_updater.py +++ b/dojo/tools/tool_issue_updater.py @@ -1,15 +1,21 @@ +import logging + import pghistory from dojo.celery import app from dojo.decorators import dojo_async_task +from dojo.models import Finding from dojo.tools.api_sonarqube.parser import SCAN_SONARQUBE_API from dojo.tools.api_sonarqube.updater import SonarQubeApiUpdater from dojo.tools.api_sonarqube.updater_from_source import SonarQubeApiUpdaterFromSource +from dojo.utils import get_object_or_none + +logger = logging.getLogger(__name__) def async_tool_issue_update(finding, *args, **kwargs): if is_tool_issue_updater_needed(finding): - tool_issue_updater(finding) + tool_issue_updater(finding.id) def is_tool_issue_updater_needed(finding, *args, **kwargs): @@ -19,7 +25,11 @@ def is_tool_issue_updater_needed(finding, *args, **kwargs): @dojo_async_task @app.task -def tool_issue_updater(finding, *args, **kwargs): +def tool_issue_updater(finding_id, *args, **kwargs): + finding = get_object_or_none(Finding, id=finding_id) + if not finding: + logger.warning("Finding with id %s does not exist, skipping tool_issue_updater", finding_id) + return test_type = finding.test.test_type diff --git a/dojo/utils.py b/dojo/utils.py index 33e99846b81..2a9f967a030 100644 --- a/dojo/utils.py +++ b/dojo/utils.py @@ 
-46,7 +46,7 @@ from dojo.authorization.roles_permissions import Permissions from dojo.celery import app -from dojo.decorators import dojo_async_task, dojo_model_from_id, dojo_model_to_id +from dojo.decorators import dojo_async_task from dojo.finding.queries import get_authorized_findings from dojo.github import ( add_external_issue_github, @@ -1054,53 +1054,65 @@ def handle_uploaded_selenium(f, cred): cred.save() -@dojo_model_to_id @dojo_async_task @app.task -@dojo_model_from_id -def add_external_issue(find, external_issue_provider, **kwargs): - eng = Engagement.objects.get(test=find.test) +def add_external_issue(finding_id, external_issue_provider, **kwargs): + finding = get_object_or_none(Finding, id=finding_id) + if not finding: + logger.warning("Finding with id %s does not exist, skipping add_external_issue", finding_id) + return + + eng = Engagement.objects.get(test=finding.test) prod = Product.objects.get(engagement=eng) logger.debug("adding external issue with provider: " + external_issue_provider) if external_issue_provider == "github": - add_external_issue_github(find, prod, eng) + add_external_issue_github(finding, prod, eng) -@dojo_model_to_id @dojo_async_task @app.task -@dojo_model_from_id -def update_external_issue(find, old_status, external_issue_provider, **kwargs): - prod = Product.objects.get(engagement=Engagement.objects.get(test=find.test)) - eng = Engagement.objects.get(test=find.test) +def update_external_issue(finding_id, old_status, external_issue_provider, **kwargs): + finding = get_object_or_none(Finding, id=finding_id) + if not finding: + logger.warning("Finding with id %s does not exist, skipping update_external_issue", finding_id) + return + + prod = Product.objects.get(engagement=Engagement.objects.get(test=finding.test)) + eng = Engagement.objects.get(test=finding.test) if external_issue_provider == "github": - update_external_issue_github(find, prod, eng) + update_external_issue_github(finding, prod, eng) -@dojo_model_to_id @dojo_async_task @app.task -@dojo_model_from_id -def close_external_issue(find, note, external_issue_provider, **kwargs): - prod = Product.objects.get(engagement=Engagement.objects.get(test=find.test)) - eng = Engagement.objects.get(test=find.test) +def close_external_issue(finding_id, note, external_issue_provider, **kwargs): + finding = get_object_or_none(Finding, id=finding_id) + if not finding: + logger.warning("Finding with id %s does not exist, skipping close_external_issue", finding_id) + return + + prod = Product.objects.get(engagement=Engagement.objects.get(test=finding.test)) + eng = Engagement.objects.get(test=finding.test) if external_issue_provider == "github": - close_external_issue_github(find, note, prod, eng) + close_external_issue_github(finding, note, prod, eng) -@dojo_model_to_id @dojo_async_task @app.task -@dojo_model_from_id -def reopen_external_issue(find, note, external_issue_provider, **kwargs): - prod = Product.objects.get(engagement=Engagement.objects.get(test=find.test)) - eng = Engagement.objects.get(test=find.test) +def reopen_external_issue(finding_id, note, external_issue_provider, **kwargs): + finding = get_object_or_none(Finding, id=finding_id) + if not finding: + logger.warning("Finding with id %s does not exist, skipping reopen_external_issue", finding_id) + return + + prod = Product.objects.get(engagement=Engagement.objects.get(test=finding.test)) + eng = Engagement.objects.get(test=finding.test) if external_issue_provider == "github": - reopen_external_issue_github(find, note, prod, eng) + 
reopen_external_issue_github(finding, note, prod, eng) def process_tag_notifications(request, note, parent_url, parent_title): @@ -1224,20 +1236,24 @@ def get_setting(setting): return getattr(settings, setting) -@dojo_model_to_id @dojo_async_task(signature=True) @app.task -@dojo_model_from_id(model=Product) -def calculate_grade_signature(product, *args, **kwargs): +def calculate_grade_signature(product_id, *args, **kwargs): """Returns a signature for calculating product grade that can be used in chords or groups.""" + product = get_object_or_none(Product, id=product_id) + if not product: + logger.warning("Product with id %s does not exist, skipping calculate_grade_signature", product_id) + return None return calculate_grade_internal(product, *args, **kwargs) -@dojo_model_to_id @dojo_async_task @app.task -@dojo_model_from_id(model=Product) -def calculate_grade(product, *args, **kwargs): +def calculate_grade(product_id, *args, **kwargs): + product = get_object_or_none(Product, id=product_id) + if not product: + logger.warning("Product with id %s does not exist, skipping calculate_grade", product_id) + return None return calculate_grade_internal(product, *args, **kwargs) diff --git a/unittests/test_jira_import_and_pushing_api.py b/unittests/test_jira_import_and_pushing_api.py index 84d173667a7..ee0808b3ca8 100644 --- a/unittests/test_jira_import_and_pushing_api.py +++ b/unittests/test_jira_import_and_pushing_api.py @@ -1064,7 +1064,7 @@ def test_bulk_edit_mixed_findings_and_groups_jira_push_bug(self, mock_webhooks, # we take a shortcut here as creating an engagement with epic mapping via the API is not implemented yet def create_engagement_epic(self, engagement): with impersonate(self.testuser): - return jira_helper.add_epic(engagement) + return jira_helper.add_epic(engagement.id) def assert_epic_issue_count(self, engagement, count): jira_issues = self.get_epic_issues(engagement) From 7a8579c15f37f9da524e8a10206f72428c8f5bfa Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Fri, 26 Dec 2025 18:53:05 +0100 Subject: [PATCH 03/36] remove dojo_model_from/to_id --- dojo/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dojo/utils.py b/dojo/utils.py index 2a9f967a030..1d3b8b24aa3 100644 --- a/dojo/utils.py +++ b/dojo/utils.py @@ -1308,7 +1308,7 @@ def calculate_grade_internal(product, *args, **kwargs): def perform_product_grading(product): system_settings = System_Settings.objects.get() if system_settings.enable_product_grade: - calculate_grade(product) + calculate_grade(product.id) def get_celery_worker_status(): From a74981e8e9f9e7c5863463dc50a5cad888acd4a9 Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Fri, 26 Dec 2025 18:58:12 +0100 Subject: [PATCH 04/36] remove dojo_model_from/to_id --- unittests/test_importers_performance.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/unittests/test_importers_performance.py b/unittests/test_importers_performance.py index 1e7b05d8fe5..d2301a1f5d7 100644 --- a/unittests/test_importers_performance.py +++ b/unittests/test_importers_performance.py @@ -310,12 +310,14 @@ def test_import_reimport_reimport_performance_pghistory_no_async_with_product_gr self.system_settings(enable_product_grade=True) self._import_reimport_performance( - expected_num_queries1=315, + + expected_num_queries1=317, expected_num_async_tasks1=8, - expected_num_queries2=241, + expected_num_queries2=243, expected_num_async_tasks2=19, - expected_num_queries3=123, + expected_num_queries3=125, expected_num_async_tasks3=18, + ) # 
Deduplication is enabled in the tests above, but to properly test it we must run the same import twice and capture the results. From b396e16949cfb360f36589a5bd5a28c1c44ba1ef Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Fri, 26 Dec 2025 19:13:48 +0100 Subject: [PATCH 05/36] remove dojo_model_from/to_id --- dojo/decorators.py | 78 +--------------------------------------- dojo/jira_link/helper.py | 74 +++++++++++++++++++++++++++----------- 2 files changed, 54 insertions(+), 98 deletions(-) diff --git a/dojo/decorators.py b/dojo/decorators.py index 91f6934b719..a2fbba9dd54 100644 --- a/dojo/decorators.py +++ b/dojo/decorators.py @@ -3,12 +3,11 @@ from functools import wraps from django.conf import settings -from django.db import models from django_ratelimit import UNSAFE from django_ratelimit.core import is_ratelimited from django_ratelimit.exceptions import Ratelimited -from dojo.models import Dojo_User, Finding +from dojo.models import Dojo_User logger = logging.getLogger(__name__) @@ -116,81 +115,6 @@ def __wrapper__(*args, **kwargs): return decorator(func) -# decorator with parameters needs another wrapper layer -# example usage: @dojo_model_to_id(parameter=0) but defaults to parameter=0 -def dojo_model_to_id(_func=None, *, parameter=0): - # logger.debug('dec_args:' + str(dec_args)) - # logger.debug('dec_kwargs:' + str(dec_kwargs)) - # logger.debug('_func:%s', _func) - - def dojo_model_to_id_internal(func, *args, **kwargs): - @wraps(func) - def __wrapper__(*args, **kwargs): - if not settings.CELERY_PASS_MODEL_BY_ID: - return func(*args, **kwargs) - - model_or_id = get_parameter_froms_args_kwargs(args, kwargs, parameter) - - if model_or_id: - if isinstance(model_or_id, models.Model) and we_want_async(*args, func=func, **kwargs): - logger.debug("converting model_or_id to id: %s", model_or_id) - args = list(args) - args[parameter] = model_or_id.id - - return func(*args, **kwargs) - - return __wrapper__ - - if _func is None: - # decorator called without parameters - return dojo_model_to_id_internal - return dojo_model_to_id_internal(_func) - - -# decorator with parameters needs another wrapper layer -# example usage: @dojo_model_from_id(parameter=0, model=Finding) but defaults to parameter 0 and model Finding -def dojo_model_from_id(_func=None, *, model=Finding, parameter=0): - # logger.debug('dec_args:' + str(dec_args)) - # logger.debug('dec_kwargs:' + str(dec_kwargs)) - # logger.debug('_func:%s', _func) - # logger.debug('model: %s', model) - - def dojo_model_from_id_internal(func, *args, **kwargs): - @wraps(func) - def __wrapper__(*args, **kwargs): - if not settings.CELERY_PASS_MODEL_BY_ID: - return func(*args, **kwargs) - - logger.debug("args:" + str(args)) - logger.debug("kwargs:" + str(kwargs)) - - logger.debug("checking if we need to convert id to model: %s for parameter: %s", model.__name__, parameter) - - model_or_id = get_parameter_froms_args_kwargs(args, kwargs, parameter) - - if model_or_id: - if not isinstance(model_or_id, models.Model) and we_want_async(*args, func=func, **kwargs): - logger.debug("instantiating model_or_id: %s for model: %s", model_or_id, model) - try: - instance = model.objects.get(id=model_or_id) - except model.DoesNotExist: - logger.warning("error instantiating model_or_id: %s for model: %s: DoesNotExist", model_or_id, model) - instance = None - args = list(args) - args[parameter] = instance - else: - logger.debug("model_or_id already a model instance %s for model: %s", model_or_id, model) - - return func(*args, **kwargs) - - return __wrapper__ - - 
if _func is None: - # decorator called without parameters - return dojo_model_from_id_internal - return dojo_model_from_id_internal(_func) - - def get_parameter_froms_args_kwargs(args, kwargs, parameter): model_or_id = None if isinstance(parameter, int): diff --git a/dojo/jira_link/helper.py b/dojo/jira_link/helper.py index 67dfa346e00..f020a4d5b19 100644 --- a/dojo/jira_link/helper.py +++ b/dojo/jira_link/helper.py @@ -18,7 +18,7 @@ from requests.auth import HTTPBasicAuth from dojo.celery import app -from dojo.decorators import dojo_async_task, dojo_model_from_id, dojo_model_to_id +from dojo.decorators import dojo_async_task from dojo.forms import JIRAEngagementForm, JIRAProjectForm from dojo.models import ( Engagement, @@ -1562,33 +1562,65 @@ def jira_get_issue(jira_project, issue_key): return None -@dojo_model_to_id(parameter=1) -@dojo_model_to_id -@dojo_async_task -@app.task -@dojo_model_from_id(model=Notes, parameter=1) -@dojo_model_from_id def add_comment(obj, note, *, force_push=False, **kwargs): + """ + Wrapper function that extracts jira_issue from obj and calls the internal Celery task. + + The decorators convert obj and note to IDs before Celery serialization. + After deserialization, obj and note are model instances again. + """ if not is_jira_configured_and_enabled(obj): return False logger.debug("trying to add a comment to a linked jira issue for: %d:%s", obj.id, obj) - if not note.private: - jira_project = get_jira_project(obj) - jira_instance = get_jira_instance(obj) - if jira_project.push_notes or force_push is True: - try: - jira = get_jira_connection(jira_instance) - j_issue = obj.jira_issue - jira.add_comment( - j_issue.jira_id, - f"({note.author.get_full_name() or note.author.username}): {note.entry}") - except JIRAError as e: - log_jira_generic_alert("Jira Add Comment Error", str(e)) - return False - return True + # Get the jira_issue from obj + jira_issue = get_jira_issue(obj) + if not jira_issue: + logger.warning("No jira_issue found for obj %s, skipping add_comment", obj) + return False + + # Call the internal task with IDs (runs synchronously within this task) + return add_comment_internal(jira_issue.id, note.id, force_push=force_push, **kwargs) + + +@dojo_async_task +@app.task +def add_comment_internal(jira_issue_id, note_id, *, force_push=False, **kwargs): + """Internal Celery task that adds a comment to a JIRA issue.""" + jira_issue = get_object_or_none(JIRA_Issue, id=jira_issue_id) + if not jira_issue: + logger.warning("JIRA_Issue with id %s does not exist, skipping add_comment_internal", jira_issue_id) + return False + + note = get_object_or_none(Notes, id=note_id) + if not note: + logger.warning("Note with id %s does not exist, skipping add_comment_internal", note_id) + return False + + if note.private: return None + + jira_project = get_jira_project(jira_issue) + if not jira_project: + logger.warning("No jira_project found for jira_issue %s, skipping add_comment_internal", jira_issue_id) + return False + + jira_instance = jira_project.jira_instance + if not jira_instance: + logger.warning("No jira_instance found for jira_project %s, skipping add_comment_internal", jira_project.id) + return False + + if jira_project.push_notes or force_push is True: + try: + jira = get_jira_connection(jira_instance) + jira.add_comment( + jira_issue.jira_id, + f"({note.author.get_full_name() or note.author.username}): {note.entry}") + except JIRAError as e: + log_jira_generic_alert("Jira Add Comment Error", str(e)) + return False + return True return None From 
e08cdffa5205b0819dd05f3441406d8cafef63ef Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Fri, 26 Dec 2025 19:43:30 +0100 Subject: [PATCH 06/36] fix tests --- unittests/test_deduplication_logic.py | 95 +++++++++++++++++++++++++++ unittests/test_product_grading.py | 2 + 2 files changed, 97 insertions(+) diff --git a/unittests/test_deduplication_logic.py b/unittests/test_deduplication_logic.py index 118ca267f91..a941de1f2c4 100644 --- a/unittests/test_deduplication_logic.py +++ b/unittests/test_deduplication_logic.py @@ -182,6 +182,7 @@ def test_identical_legacy(self): # expect: marked as duplicate finding_new, finding_24 = self.copy_and_reset_finding(find_id=24) finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=24, duplicate=True, duplicate_finding_id=finding_24.duplicate_finding.id, hash_code=finding_24.hash_code) @@ -195,12 +196,14 @@ def test_identical_ordering_legacy(self): finding_23.duplicate_finding = None finding_23.active = True finding_23.save(dedupe_option=False) + finding_23.refresh_from_db() self.assert_finding(finding_23, duplicate=False, hash_code=finding_22.hash_code) # create a copy of 22 finding_new, finding_22 = self.copy_and_reset_finding(find_id=22) finding_new.save() + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=22, duplicate=True, duplicate_finding_id=finding_22.id, hash_code=finding_22.hash_code) # self.assert_finding(finding_new, not_pk=22, duplicate=True, duplicate_finding_id=finding_23.id, hash_code=finding_22.hash_code) @@ -211,6 +214,7 @@ def test_identical_except_title_legacy(self): finding_new, finding_4 = self.copy_and_reset_finding(find_id=4) finding_new.title = "the best title" finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=24, duplicate=False, not_hash_code=finding_4.hash_code) @@ -220,6 +224,7 @@ def test_identical_except_description_legacy(self): finding_new, finding_24 = self.copy_and_reset_finding(find_id=24) finding_new.description = "useless finding" finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=24, duplicate=False, not_hash_code=finding_24.hash_code) @@ -229,6 +234,7 @@ def test_identical_except_line_legacy(self): finding_new, finding_24 = self.copy_and_reset_finding(find_id=24) finding_new.line = 666 finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=24, duplicate=False, not_hash_code=finding_24.hash_code) @@ -241,6 +247,7 @@ def test_identical_except_filepath_legacy(self): Finding.objects.get(id=22) finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=24, duplicate=False, not_hash_code=finding_24.hash_code) @@ -255,6 +262,7 @@ def test_dedupe_inside_engagement_legacy(self): finding_new.test = test_new finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=22, duplicate=False, hash_code=finding_22.hash_code) @@ -272,6 +280,7 @@ def test_dedupe_not_inside_engagement_legacy(self): finding_new.test = test_new finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=22, duplicate=True, duplicate_finding_id=22, hash_code=finding_22.hash_code) @@ -281,16 +290,19 @@ def test_identical_no_filepath_no_line_no_endpoints_legacy(self): finding_new.file_path = None finding_new.line = None finding_new.save(dedupe_option=True) + 
finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=22, duplicate=False) def test_identical_legacy_with_identical_endpoints_static(self): finding_new, finding_24 = self.copy_and_reset_finding_add_endpoints(find_id=24, static=True, dynamic=False) # has myhost.com, myhost2.com finding_new.save() + finding_new.refresh_from_db() # create an identical copy of the new finding with the same endpoints. it should be marked as duplicate finding_new2, finding_new = self.copy_and_reset_finding(find_id=finding_new.id) finding_new2.save(dedupe_option=False) + finding_new2.refresh_from_db() ep1 = Endpoint(product=finding_new2.test.engagement.product, finding=finding_new2, host="myhost.com", protocol="https") ep1.save() @@ -299,16 +311,19 @@ def test_identical_legacy_with_identical_endpoints_static(self): finding_new2.endpoints.add(ep1) finding_new2.endpoints.add(ep2) finding_new2.save() + finding_new2.refresh_from_db() self.assert_finding(finding_new2, not_pk=finding_new.pk, duplicate=True, duplicate_finding_id=finding_new.id, hash_code=finding_new.hash_code, not_hash_code=finding_24.hash_code) def test_identical_legacy_extra_endpoints_static(self): finding_new, finding_24 = self.copy_and_reset_finding_add_endpoints(find_id=24, static=True, dynamic=False) # has myhost.com, myhost2.com finding_new.save() + finding_new.refresh_from_db() # create a new finding with 3 endpoints (so 1 extra) finding_new3, finding_new = self.copy_and_reset_finding(find_id=finding_new.id) finding_new3.save(dedupe_option=False) + finding_new3.refresh_from_db() ep1 = Endpoint(product=finding_new3.test.engagement.product, finding=finding_new3, host="myhost.com", protocol="https") ep1.save() ep2 = Endpoint(product=finding_new3.test.engagement.product, finding=finding_new3, host="myhost2.com", protocol="https") @@ -319,6 +334,7 @@ def test_identical_legacy_extra_endpoints_static(self): finding_new3.endpoints.add(ep2) finding_new3.endpoints.add(ep3) finding_new3.save() + finding_new3.refresh_from_db() # expect: marked as duplicate as the requirement for static findings is that the new finding has to contain all the endpoints of the existing finding (extra is no problem) # hash_code not affected by endpoints @@ -327,10 +343,12 @@ def test_identical_legacy_extra_endpoints_static(self): def test_identical_legacy_different_endpoints_static(self): finding_new, finding_24 = self.copy_and_reset_finding_add_endpoints(find_id=24, static=True, dynamic=False) # has myhost.com, myhost2.com finding_new.save() + finding_new.refresh_from_db() # create an identical copy of the new finding, but with different endpoints finding_new3, finding_new = self.copy_and_reset_finding(find_id=finding_new.id) finding_new3.save(dedupe_option=False) + finding_new3.refresh_from_db() ep1 = Endpoint(product=finding_new3.test.engagement.product, finding=finding_new3, host="myhost4.com", protocol="https") ep1.save() ep2 = Endpoint(product=finding_new3.test.engagement.product, finding=finding_new3, host="myhost2.com", protocol="https") @@ -338,6 +356,7 @@ def test_identical_legacy_different_endpoints_static(self): finding_new3.endpoints.add(ep1) finding_new3.endpoints.add(ep2) finding_new3.save() + finding_new3.refresh_from_db() # expect: not marked as duplicate as the requirement for static findings is that the new finding has to contain all the endpoints of the existing finding and this is not met # hash_code not affected by endpoints @@ -346,11 +365,13 @@ def test_identical_legacy_different_endpoints_static(self): def 
test_identical_legacy_no_endpoints_static(self): finding_new, finding_24 = self.copy_and_reset_finding_add_endpoints(find_id=24, static=True, dynamic=False) # has myhost.com, myhost2.com finding_new.save() + finding_new.refresh_from_db() # create an identical copy of the new finding, but with 1 extra endpoint. should not be marked as duplicate finding_new3, finding_new = self.copy_and_reset_finding(find_id=finding_new.id) finding_new3.save(dedupe_option=False) finding_new3.save() + finding_new3.refresh_from_db() # expect not marked as duplicate as the new finding doesn't have endpoints and we don't have filepath/line self.assert_finding(finding_new3, not_pk=finding_new.pk, duplicate=False, hash_code=finding_new.hash_code, not_hash_code=finding_24.hash_code) @@ -358,10 +379,12 @@ def test_identical_legacy_no_endpoints_static(self): def test_identical_legacy_with_identical_endpoints_dynamic(self): finding_new, finding_24 = self.copy_and_reset_finding_add_endpoints(find_id=24, static=True, dynamic=False) # has myhost.com, myhost2.com finding_new.save() + finding_new.refresh_from_db() # create an identical copy of the new finding. it should be marked as duplicate finding_new2, finding_new = self.copy_and_reset_finding(find_id=finding_new.id) finding_new2.save(dedupe_option=False) + finding_new2.refresh_from_db() ep1 = Endpoint(product=finding_new2.test.engagement.product, finding=finding_new2, host="myhost.com", protocol="https") ep1.save() @@ -370,16 +393,19 @@ def test_identical_legacy_with_identical_endpoints_dynamic(self): finding_new2.endpoints.add(ep1) finding_new2.endpoints.add(ep2) finding_new2.save() + finding_new2.refresh_from_db() self.assert_finding(finding_new2, not_pk=finding_new.pk, duplicate=True, duplicate_finding_id=finding_new.id, hash_code=finding_new.hash_code, not_hash_code=finding_24.hash_code) def test_identical_legacy_extra_endpoints_dynamic(self): finding_new, _finding_24 = self.copy_and_reset_finding_add_endpoints(find_id=24) finding_new.save() + finding_new.refresh_from_db() # create an identical copy of the new finding, but with 1 extra endpoint. finding_new3, finding_new = self.copy_and_reset_finding(find_id=finding_new.id) finding_new3.save(dedupe_option=False) + finding_new3.refresh_from_db() ep1 = Endpoint(product=finding_new3.test.engagement.product, finding=finding_new3, host="myhost.com", protocol="https") ep1.save() ep2 = Endpoint(product=finding_new3.test.engagement.product, finding=finding_new3, host="myhost2.com", protocol="https") @@ -390,6 +416,7 @@ def test_identical_legacy_extra_endpoints_dynamic(self): finding_new3.endpoints.add(ep2) finding_new3.endpoints.add(ep3) finding_new3.save() + finding_new3.refresh_from_db() # expect: marked as duplicate as hash_code is not affected by endpoints anymore with the legacy algorithm self.assert_finding(finding_new3, not_pk=finding_new.pk, duplicate=True, hash_code=finding_new.hash_code) @@ -404,10 +431,12 @@ def test_identical_legacy_different_endpoints_dynamic(self): # hash_code not affected by endpoints finding_new, _finding_24 = self.copy_and_reset_finding_add_endpoints(find_id=24) finding_new.save() + finding_new.refresh_from_db() # create an identical copy of the new finding, but with 1 extra endpoint. 
should not be marked as duplicate finding_new3, finding_new = self.copy_and_reset_finding(find_id=finding_new.id) finding_new3.save(dedupe_option=False) + finding_new3.refresh_from_db() ep1 = Endpoint(product=finding_new3.test.engagement.product, finding=finding_new3, host="myhost4.com", protocol="https") ep1.save() ep2 = Endpoint(product=finding_new3.test.engagement.product, finding=finding_new3, host="myhost2.com", protocol="https") @@ -415,6 +444,7 @@ def test_identical_legacy_different_endpoints_dynamic(self): finding_new3.endpoints.add(ep1) finding_new3.endpoints.add(ep2) finding_new3.save() + finding_new3.refresh_from_db() # expected: hash_code is not affected by endpoints anymore in legacy algorithm # but not duplicate because the legacy dedupe algo examines not only hash_code but endpoints too @@ -423,11 +453,13 @@ def test_identical_legacy_different_endpoints_dynamic(self): def test_identical_legacy_no_endpoints_dynamic(self): finding_new, _finding_24 = self.copy_and_reset_finding_add_endpoints(find_id=24) finding_new.save() + finding_new.refresh_from_db() # create an identical copy of the new finding, but with no endpoints finding_new3, finding_new = self.copy_and_reset_finding(find_id=finding_new.id) finding_new3.save(dedupe_option=False) finding_new3.save() + finding_new3.refresh_from_db() # expect: marked as duplicate, hash_code not affected by endpoints with the legacy algorithm # but not duplicate because the legacy dedupe algo examines not only hash_code but endpoints too @@ -451,6 +483,7 @@ def test_identical_hash_code(self): finding_new, finding_2 = self.copy_with_endpoints_without_dedupe_and_reset_finding(find_id=2) finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=2, duplicate=True, duplicate_finding_id=finding_4.duplicate_finding.id, hash_code=finding_2.hash_code) def test_identical_ordering_hash_code(self): @@ -465,12 +498,14 @@ def test_identical_ordering_hash_code(self): finding_3.duplicate_finding = None finding_3.active = True finding_3.save(dedupe_option=False) + finding_3.refresh_from_db() self.assert_finding(finding_3, duplicate=False, hash_code=finding_2.hash_code) # create a copy of 2 finding_new, finding_2 = self.copy_and_reset_finding(find_id=2) finding_new.save() + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=2, duplicate=True, duplicate_finding_id=finding_2.id, hash_code=finding_2.hash_code) # self.assert_finding(finding_new, not_pk=2, duplicate=True, duplicate_finding_id=finding_3.id, hash_code=finding_2.hash_code) @@ -484,6 +519,7 @@ def test_identical_except_title_hash_code(self): finding_new, finding_4 = self.copy_and_reset_finding(find_id=4) finding_new.title = "the best title" finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=4, duplicate=False, not_hash_code=finding_4.hash_code) @@ -495,6 +531,7 @@ def test_identical_except_description_hash_code(self): finding_new.description = "useless finding" finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() if (settings.DEDUPE_ALGO_ENDPOINT_FIELDS == []): # expect duplicate, as endpoints shouldn't affect dedupe @@ -504,6 +541,7 @@ def test_identical_except_description_hash_code(self): finding_new, finding_2 = self.copy_with_endpoints_without_dedupe_and_reset_finding(find_id=2) finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=2, duplicate=True, 
duplicate_finding_id=finding_4.duplicate_finding.id, hash_code=finding_2.hash_code) # TODO: not usefile with ZAP? @@ -514,6 +552,7 @@ def test_identical_except_line_hash_code(self): finding_new, finding_4 = self.copy_and_reset_finding(find_id=4) finding_new.line = 666 finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() if (settings.DEDUPE_ALGO_ENDPOINT_FIELDS == []): # expect duplicate, as endpoints shouldn't affect dedupe @@ -524,6 +563,7 @@ def test_identical_except_line_hash_code(self): finding_new, finding_2 = self.copy_with_endpoints_without_dedupe_and_reset_finding(find_id=2) finding_new.line = 666 finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=2, duplicate=True, duplicate_finding_id=finding_4.duplicate_finding.id, hash_code=finding_2.hash_code) # TODO: not usefile with ZAP? @@ -533,6 +573,7 @@ def test_identical_except_filepath_hash_code(self): finding_new, finding_4 = self.copy_and_reset_finding(find_id=4) finding_new.file_path = "/dev/null" finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() if (settings.DEDUPE_ALGO_ENDPOINT_FIELDS == []): # expect duplicate, as endpoints shouldn't affect dedupe @@ -543,6 +584,7 @@ def test_identical_except_filepath_hash_code(self): finding_new, finding_2 = self.copy_with_endpoints_without_dedupe_and_reset_finding(find_id=2) finding_new.file_path = "/dev/null" finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=2, duplicate=True, duplicate_finding_id=finding_4.duplicate_finding.id, hash_code=finding_2.hash_code) def test_dedupe_inside_engagement_hash_code(self): @@ -553,6 +595,7 @@ def test_dedupe_inside_engagement_hash_code(self): finding_new, finding_2 = self.copy_with_endpoints_without_dedupe_and_reset_finding(find_id=2) finding_new.test = Test.objects.get(id=4) finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=2, duplicate=False, hash_code=finding_2.hash_code) @@ -566,6 +609,7 @@ def test_dedupe_not_inside_engagement_hash_code(self): finding_new, finding_2 = self.copy_with_endpoints_without_dedupe_and_reset_finding(find_id=2) finding_new.test = Test.objects.get(id=4) finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=2, duplicate=True, duplicate_finding_id=2, hash_code=finding_2.hash_code) @@ -576,6 +620,7 @@ def test_identical_no_filepath_no_line_no_endpoints_hash_code(self): finding_new.file_path = None finding_new.line = None finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=2, duplicate=True, duplicate_finding_id=2, hash_code=finding_2.hash_code) @@ -583,6 +628,7 @@ def test_identical_hash_code_with_identical_endpoints(self): # create an identical copy of the new finding, with the same endpoints finding_new, finding_2 = self.copy_with_endpoints_without_dedupe_and_reset_finding(find_id=2) # has ftp://localhost finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() # expect: marked as duplicate of original finding 2 (because finding 4 is a duplicate of finding 2 in sample data), hash_code not affected by endpoints (endpoints are not anymore in ZAP configuration for hash_code) self.assert_finding(finding_new, not_pk=finding_2.pk, duplicate=True, duplicate_finding_id=2, hash_code=finding_2.hash_code, not_hash_code=None) @@ -594,11 +640,13 @@ def test_dedupe_algo_endpoint_fields_host_port_identical(self): # 
create an identical copy of the new finding, with the same endpoints but different path finding_new, finding_2 = self.copy_and_reset_finding(find_id=2) # finding_2 has host ftp://localhost finding_new.save() + finding_new.refresh_from_db() ep = Endpoint(product=finding_new.test.engagement.product, finding=finding_new, host="localhost", protocol="ftp", path="local") ep.save() finding_new.endpoints.add(ep) finding_new.save() + finding_new.refresh_from_db() # expect: marked as duplicate of original finding 2 (because finding 4 is a duplicate of finding 2 in sample data), hash_code not affected by endpoints (endpoints are not anymore in ZAP configuration for hash_code) self.assert_finding(finding_new, not_pk=finding_2.pk, duplicate=True, duplicate_finding_id=2, hash_code=finding_2.hash_code, not_hash_code=None) @@ -613,11 +661,13 @@ def test_dedupe_algo_endpoint_field_path_different(self): # create an identical copy of the new finding, with the same endpoints but different path finding_new, finding_2 = self.copy_and_reset_finding(find_id=2) # finding_2 has host ftp://localhost finding_new.save() + finding_new.refresh_from_db() ep = Endpoint(product=finding_new.test.engagement.product, finding=finding_new, host="localhost", protocol="ftp", path="local") ep.save() finding_new.endpoints.add(ep) finding_new.save() + finding_new.refresh_from_db() # expect: marked as duplicate of original finding 2 (because finding 4 is a duplicate of finding 2 in sample data), hash_code not affected by endpoints (endpoints are not anymore in ZAP configuration for hash_code) self.assert_finding(finding_new, not_pk=finding_2.pk, duplicate=False, duplicate_finding_id=None, hash_code=finding_2.hash_code, not_hash_code=None) @@ -638,12 +688,14 @@ def test_identical_hash_code_with_intersect_endpoints(self): finding_new.endpoints.add(ep1) finding_new.endpoints.add(ep2) finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() # expect: marked not as duplicate of original finding 2 because the endpoints are different self.assert_finding(finding_new, not_pk=finding_2.pk, duplicate=False, hash_code=finding_2.hash_code) # create an identical copy of the new finding without original endpoints, but with 3 extra endpoints. 
finding_new3, finding_new = self.copy_and_reset_finding(find_id=finding_new.id) finding_new3.save(dedupe_option=False) + finding_new3.refresh_from_db() ep1 = Endpoint(product=finding_new3.test.engagement.product, finding=finding_new3, host="myhost4.com", protocol="https") ep1.save() ep2 = Endpoint(product=finding_new3.test.engagement.product, finding=finding_new3, host="myhost2.com", protocol="https") @@ -654,10 +706,12 @@ def test_identical_hash_code_with_intersect_endpoints(self): finding_new3.endpoints.add(ep2) finding_new3.endpoints.add(ep3) finding_new3.save() + finding_new3.refresh_from_db() # expect: marked not as duplicate of original finding 2 or finding_new3 because the endpoints are different self.assert_finding(finding_new3, not_pk=finding_new.pk, duplicate=True, duplicate_finding_id=finding_new.id, hash_code=finding_new.hash_code) # expect: marked not as duplicate of original finding 2 because the endpoints are different + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=finding_2.pk, duplicate=False, hash_code=finding_2.hash_code) # reset for further tests settings.DEDUPE_ALGO_ENDPOINT_FIELDS = dedupe_algo_endpoint_fields @@ -675,12 +729,14 @@ def test_identical_hash_code_with_different_endpoints(self): finding_new.endpoints.add(ep1) finding_new.endpoints.add(ep2) finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() # expect: marked not as duplicate of original finding 2 because the endpoints are different self.assert_finding(finding_new, not_pk=finding_2.pk, duplicate=False, hash_code=finding_2.hash_code) # create an identical copy of the new finding without original endpoints, but with 3 extra endpoints. finding_new3, finding_new = self.copy_and_reset_finding(find_id=finding_new.id) finding_new3.save(dedupe_option=False) + finding_new3.refresh_from_db() ep1 = Endpoint(product=finding_new3.test.engagement.product, finding=finding_new3, host="myhost4.com", protocol="https") ep1.save() ep2 = Endpoint(product=finding_new3.test.engagement.product, finding=finding_new3, host="myhost2.com", protocol="http") @@ -691,11 +747,13 @@ def test_identical_hash_code_with_different_endpoints(self): finding_new3.endpoints.add(ep2) finding_new3.endpoints.add(ep3) finding_new3.save() + finding_new3.refresh_from_db() # expect: marked not as duplicate of original finding 2 or finding_new3 because the endpoints are different self.assert_finding(finding_new3, not_pk=finding_new.pk, duplicate=False, hash_code=finding_new.hash_code) self.assert_finding(finding_new3, not_pk=finding_2.pk, duplicate=False, hash_code=finding_2.hash_code) # expect: marked not as duplicate of original finding 2 because the endpoints are different + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=finding_2.pk, duplicate=False, hash_code=finding_2.hash_code) # reset for further tests settings.DEDUPE_ALGO_ENDPOINT_FIELDS = dedupe_algo_endpoint_fields @@ -705,6 +763,7 @@ def test_identical_unique_id(self): # create identical copy finding_new, finding_124 = self.copy_and_reset_finding(find_id=124) finding_new.save() + finding_new.refresh_from_db() # expect duplicate self.assert_finding(finding_new, not_pk=124, duplicate=True, duplicate_finding_id=124, hash_code=finding_124.hash_code) @@ -714,6 +773,7 @@ def test_different_unique_id_unique_id(self): finding_new, finding_124 = self.copy_and_reset_finding(find_id=124) finding_new.unique_id_from_tool = "9999" finding_new.save() + finding_new.refresh_from_db() # expect not duplicate, but same hash_code 
self.assert_finding(finding_new, not_pk=124, duplicate=False, hash_code=finding_124.hash_code) @@ -722,6 +782,7 @@ def test_identical_ordering_unique_id(self): # create identical copy finding_new, finding_125 = self.copy_and_reset_finding(find_id=125) finding_new.save() + finding_new.refresh_from_db() # expect duplicate, but of 124 as that is first in the list, but it's newer then 125. feature or BUG? self.assert_finding(finding_new, not_pk=124, duplicate=True, duplicate_finding_id=124, hash_code=finding_125.hash_code) @@ -734,6 +795,7 @@ def test_title_description_line_filepath_different_unique_id(self): finding_new.cwe = "456" finding_new.description = "useless finding" finding_new.save() + finding_new.refresh_from_db() # expect duplicate as we only match on unique id, hash_code also different self.assert_finding(finding_new, not_pk=124, duplicate=True, duplicate_finding_id=124, not_hash_code=finding_124.hash_code) @@ -747,6 +809,7 @@ def test_title_description_line_filepath_different_and_id_different_unique_id(se finding_new.description = "useless finding" finding_new.unique_id_from_tool = "9999" finding_new.save() + finding_new.refresh_from_db() # expect not duplicate as we match on unique id, hash_code also different because fields changed self.assert_finding(finding_new, not_pk=124, duplicate=False, not_hash_code=finding_124.hash_code) @@ -766,6 +829,7 @@ def test_dedupe_not_inside_engagement_unique_id(self): finding_new.unique_id_from_tool = "888" finding_new.save() + finding_new.refresh_from_db() # expect not duplicate as dedupe_inside_engagement is True self.assert_finding(finding_new, not_pk=124, duplicate=False, hash_code=finding_124.hash_code) @@ -777,6 +841,7 @@ def test_dedupe_inside_engagement_unique_id_different_test_type(self): # first setup some finding with same unique_id in same engagement, but different test (same test_type) finding_new.test = Test.objects.get(id=90) finding_new.save() + finding_new.refresh_from_db() # expect not duplicate as the test_type doesn't match self.assert_finding(finding_new, not_pk=124, duplicate=False, hash_code=finding_124.hash_code) @@ -788,6 +853,7 @@ def test_dedupe_inside_engagement_unique_id(self): # first setup some finding with same unique_id in same engagement, but different test (same test_type) finding_new.test = Test.objects.get(id=66) finding_new.save() + finding_new.refresh_from_db() # expect duplicate as dedupe_inside_engagement is True and the other test is in the same engagement and has the same test type self.assert_finding(finding_new, not_pk=124, duplicate=True, duplicate_finding_id=124, hash_code=finding_124.hash_code) @@ -808,6 +874,7 @@ def test_dedupe_inside_engagement_unique_id2(self): finding_new.unique_id_from_tool = "888" finding_new.save() + finding_new.refresh_from_db() # expect duplicate as dedupe_inside_engagement is false self.assert_finding(finding_new, not_pk=124, duplicate=True, duplicate_finding_id=finding_22.id, hash_code=finding_124.hash_code) @@ -824,6 +891,7 @@ def test_dedupe_same_id_different_test_type_unique_id(self): self.set_dedupe_inside_engagement(False) finding_22.save(dedupe_option=False) finding_new.save() + finding_new.refresh_from_db() # expect not duplicate as the mathcing finding is from another test_type, hash_code is the same as original self.assert_finding(finding_new, not_pk=124, duplicate=False, hash_code=finding_124.hash_code) @@ -837,6 +905,7 @@ def test_identical_different_endpoints_unique_id(self): ep1.save() finding_new.endpoints.add(ep1) finding_new.save() + 
finding_new.refresh_from_db() # expect duplicate, as endpoints shouldn't affect dedupe and hash_code due to unique_id self.assert_finding(finding_new, not_pk=124, duplicate=True, duplicate_finding_id=124, hash_code=finding_124.hash_code) @@ -855,6 +924,7 @@ def test_identical_endpoints_unique_id(self): ep_n.save() finding_new.endpoints.add(ep_n) finding_new.save() + finding_new.refresh_from_db() # expect duplicate: unique_id match dominates regardless of identical endpoints self.assert_finding(finding_new, not_pk=124, duplicate=True, duplicate_finding_id=124, hash_code=finding_124.hash_code) @@ -876,6 +946,7 @@ def test_extra_endpoints_unique_id(self): finding_new.endpoints.add(ep2) finding_new.endpoints.add(ep3) finding_new.save() + finding_new.refresh_from_db() # expect duplicate: unique_id match regardless of extra endpoints self.assert_finding(finding_new, not_pk=124, duplicate=True, duplicate_finding_id=124, hash_code=finding_124.hash_code) @@ -1016,6 +1087,7 @@ def test_identical_unique_id_or_hash_code(self): # create identical copy finding_new, finding_224 = self.copy_and_reset_finding(find_id=224) finding_new.save() + finding_new.refresh_from_db() # expect duplicate as uid matches self.assert_finding(finding_new, not_pk=224, duplicate=True, duplicate_finding_id=224, hash_code=finding_224.hash_code) @@ -1034,6 +1106,7 @@ def test_identical_endpoints_unique_id_or_hash_code(self): ep_n.save() finding_new.endpoints.add(ep_n) finding_new.save() + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=224, duplicate=True, duplicate_finding_id=224, hash_code=finding_224.hash_code) @@ -1054,6 +1127,7 @@ def test_extra_endpoints_unique_id_or_hash_code(self): finding_new.endpoints.add(ep_n1) finding_new.endpoints.add(ep_n2) finding_new.save() + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=224, duplicate=True, duplicate_finding_id=224, hash_code=finding_224.hash_code) @@ -1077,6 +1151,7 @@ def test_intersect_endpoints_unique_id_or_hash_code(self): finding_new.endpoints.add(ep_n1) finding_new.endpoints.add(ep_n2) finding_new.save() + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=224, duplicate=True, duplicate_finding_id=224, hash_code=finding_224.hash_code) @@ -1087,6 +1162,7 @@ def test_identical_unique_id_or_hash_code_bug(self): finding_new, _finding_224 = self.copy_and_reset_finding(find_id=224) finding_new.title = finding_124.title # use title from 124 to get matching hashcode finding_new.save() + finding_new.refresh_from_db() # marked as duplicate of 124 as that has the same hashcode and is earlier in the list of findings ordered by id self.assert_finding(finding_new, not_pk=224, duplicate=True, duplicate_finding_id=124, hash_code=finding_124.hash_code) @@ -1096,6 +1172,7 @@ def test_different_unique_id_unique_id_or_hash_code(self): finding_new, finding_224 = self.copy_and_reset_finding(find_id=224) finding_new.unique_id_from_tool = "9999" finding_new.save() + finding_new.refresh_from_db() # expect duplicate, uid mismatch, but same hash_code self.assert_finding(finding_new, not_pk=224, duplicate=True, duplicate_finding_id=finding_224.id, hash_code=finding_224.hash_code) @@ -1105,6 +1182,7 @@ def test_different_unique_id_unique_id_or_hash_code(self): finding_new.unique_id_from_tool = "9999" finding_new.title = "no no no no no no" finding_new.save() + finding_new.refresh_from_db() # expect duplicate, uid mismatch, but same hash_code self.assert_finding(finding_new, not_pk=224, duplicate=False, 
not_hash_code=finding_224.hash_code) @@ -1129,6 +1207,7 @@ def test_uid_mismatch_hash_match_identical_endpoints_unique_id_or_hash_code(self ep_n.save() finding_new.endpoints.add(ep_n) finding_new.save() + finding_new.refresh_from_db() # expect duplicate via hash path despite UID mismatch and identical endpoints self.assert_finding(finding_new, not_pk=224, duplicate=True, duplicate_finding_id=finding_224.id, hash_code=finding_224.hash_code) @@ -1157,6 +1236,7 @@ def test_uid_mismatch_hash_match_extra_endpoints_unique_id_or_hash_code(self): finding_new.endpoints.add(ep_n1) finding_new.endpoints.add(ep_n2) finding_new.save() + finding_new.refresh_from_db() # expect duplicate via hash path despite UID mismatch and extra endpoints self.assert_finding(finding_new, not_pk=224, duplicate=True, duplicate_finding_id=finding_224.id, hash_code=finding_224.hash_code) @@ -1188,6 +1268,7 @@ def test_uid_mismatch_hash_match_intersect_endpoints_unique_id_or_hash_code(self finding_new.endpoints.add(ep_n1) finding_new.endpoints.add(ep_n2) finding_new.save() + finding_new.refresh_from_db() # expect duplicate via hash path despite UID mismatch and intersecting endpoints self.assert_finding(finding_new, not_pk=224, duplicate=True, duplicate_finding_id=finding_224.id, hash_code=finding_224.hash_code) @@ -1199,6 +1280,7 @@ def test_identical_ordering_unique_id_or_hash_code(self): # create identical copy finding_new, finding_225 = self.copy_and_reset_finding(find_id=225) finding_new.save() + finding_new.refresh_from_db() # expect duplicate, but of 124 as that is first in the list, but it's newer then 225. feature or BUG? self.assert_finding(finding_new, not_pk=224, duplicate=True, duplicate_finding_id=224, hash_code=finding_225.hash_code) @@ -1211,6 +1293,7 @@ def test_title_description_line_filepath_different_unique_id_or_hash_code(self): finding_new.cwe = "456" finding_new.description = "useless finding" finding_new.save() + finding_new.refresh_from_db() # expect duplicate as we only match on unique id, hash_code also different self.assert_finding(finding_new, not_pk=224, duplicate=True, duplicate_finding_id=224, not_hash_code=finding_224.hash_code) @@ -1224,6 +1307,7 @@ def test_title_description_line_filepath_different_and_id_different_unique_id_or finding_new.description = "useless finding" finding_new.unique_id_from_tool = "9999" finding_new.save() + finding_new.refresh_from_db() # expect not duplicate as we match on unique id, hash_code also different because fields changed self.assert_finding(finding_new, not_pk=224, duplicate=False, not_hash_code=finding_224.hash_code) @@ -1243,6 +1327,7 @@ def test_dedupe_not_inside_engagement_same_hash_unique_id_or_hash_code(self): finding_new.unique_id_from_tool = "888" finding_new.save() + finding_new.refresh_from_db() # should become duplicate of finding 22 because of the uid match, but existing BUG makes it duplicate of 224 due to hashcode match self.assert_finding(finding_new, not_pk=224, duplicate=True, duplicate_finding_id=224, hash_code=finding_224.hash_code) @@ -1263,6 +1348,7 @@ def test_dedupe_not_inside_engagement_same_hash_unique_id_or_hash_code2(self): finding_new.hash_code = finding_22.hash_code # sneaky copy of hash_code to be able to test this case icm with the bug in previous test case above finding_new.unique_id_from_tool = "333" finding_new.save() + finding_new.refresh_from_db() # expect not duplicate as dedupe_inside_engagement is True and 22 is in another engagement # but existing BUG? 
it is marked as duplicate of 124 which has the same hash and same engagement, but different unique_id_from_tool at same test_type @@ -1276,6 +1362,7 @@ def test_dedupe_inside_engagement_unique_id_or_hash_code_different_test_type(sel # first setup some finding with same unique_id in same engagement, but different test, different test_type finding_new.test = Test.objects.get(id=91) finding_new.save() + finding_new.refresh_from_db() # expect not duplicate as the test_type doesn't match self.assert_finding(finding_new, not_pk=224, duplicate=False) @@ -1288,6 +1375,7 @@ def test_dedupe_inside_engagement_unique_id_or_hash_code(self): # first setup some finding with same unique_id in same engagement, but different test (same test_type) finding_new.test = Test.objects.get(id=88) finding_new.save() + finding_new.refresh_from_db() # expect duplicate as dedupe_inside_engagement is True and the other test is in the same engagement and has the same test type self.assert_finding(finding_new, not_pk=224, duplicate=True, duplicate_finding_id=224, hash_code=finding_224.hash_code) @@ -1310,6 +1398,7 @@ def test_dedupe_inside_engagement_unique_id_or_hash_code2(self): finding_new.unique_id_from_tool = "888" finding_new.title = "hack to work around bug that matches on hash_code first" # arrange different hash_code finding_new.save() + finding_new.refresh_from_db() # expect duplicate as dedupe_inside_engagement is false self.assert_finding(finding_new, not_pk=224, duplicate=True, duplicate_finding_id=finding_22.id, not_hash_code=finding_22.hash_code) @@ -1327,6 +1416,7 @@ def test_dedupe_same_id_different_test_type_unique_id_or_hash_code(self): finding_22.save(dedupe_option=False) finding_new.title = "title to change hash_code" finding_new.save() + finding_new.refresh_from_db() # expect not duplicate as the mathcing finding is from another test_type, hash_code is also different self.assert_finding(finding_new, not_pk=224, duplicate=False, not_hash_code=finding_224.hash_code) @@ -1342,6 +1432,7 @@ def test_dedupe_same_id_different_test_type_unique_id_or_hash_code(self): self.set_dedupe_inside_engagement(False) finding_22.save(dedupe_option=False) finding_new.save() + finding_new.refresh_from_db() # expect not duplicate as the mathcing finding is from another test_type, hash_code is also different self.assert_finding(finding_new, not_pk=224, duplicate=True, duplicate_finding_id=224, hash_code=finding_224.hash_code) @@ -1449,6 +1540,7 @@ def test_identical_different_endpoints_unique_id_or_hash_code_multiple(self): finding_new2.unique_id_from_tool = 1 finding_new2.dynamic_finding = True finding_new2.save() + finding_new2.refresh_from_db() if settings.DEDUPE_ALGO_ENDPOINT_FIELDS == []: # different uid. and different endpoints, but endpoints not used for hash anymore -> duplicate @@ -1468,6 +1560,7 @@ def test_identical_different_endpoints_unique_id_or_hash_code_multiple(self): finding_new3.unique_id_from_tool = 1 finding_new3.dynamic_finding = False finding_new3.save() + finding_new3.refresh_from_db() if settings.DEDUPE_ALGO_ENDPOINT_FIELDS == []: # different uid. 
and different endpoints, dynamic_finding is set to False hash_code still not affected by endpoints @@ -1508,6 +1601,7 @@ def test_identical_legacy_dedupe_option_true_false(self): # expect duplicate when saving with dedupe_option=True finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=24, duplicate=True, duplicate_finding_id=finding_24.duplicate_finding.id, hash_code=finding_24.hash_code) def test_duplicate_after_modification(self): @@ -1531,6 +1625,7 @@ def test_case_sensitiveness_hash_code_computation(self): finding_new, finding_22 = self.copy_and_reset_finding(find_id=22) finding_new.title = finding_22.title.upper() finding_new.save(dedupe_option=True) + finding_new.refresh_from_db() self.assert_finding(finding_new, not_pk=22, duplicate=True, duplicate_finding_id=finding_22.id, hash_code=finding_22.hash_code) def test_title_case(self): diff --git a/unittests/test_product_grading.py b/unittests/test_product_grading.py index e7000e0fb48..8d5f2dce2e7 100644 --- a/unittests/test_product_grading.py +++ b/unittests/test_product_grading.py @@ -43,6 +43,8 @@ def create_single_critical_and_assert_grade(self, expected_grade, *, verified=Fa self.assertIsNone(self.product.prod_numeric_grade) # Add a single critical finding self.create_finding_on_test(severity="Critical", verified=verified) + # Refresh product from database to get updated grade + self.product.refresh_from_db() # See that the grade does not degrade at all self.assertEqual(self.product.prod_numeric_grade, expected_grade) From 597ba2f0b2ab12802de318fdb516ab8affca78ee Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Fri, 26 Dec 2025 20:13:00 +0100 Subject: [PATCH 07/36] remove leftover signature methods --- dojo/finding/helper.py | 8 -------- dojo/importers/base_importer.py | 2 +- dojo/importers/default_importer.py | 30 ++++++++-------------------- dojo/importers/default_reimporter.py | 27 ++++++++----------------- dojo/utils.py | 11 ---------- 5 files changed, 17 insertions(+), 61 deletions(-) diff --git a/dojo/finding/helper.py b/dojo/finding/helper.py index 57d85086119..908afee38b9 100644 --- a/dojo/finding/helper.py +++ b/dojo/finding/helper.py @@ -453,14 +453,6 @@ def post_process_finding_save_internal(finding, dedupe_option=True, rules_option jira_helper.push_to_jira(finding.finding_group) -@dojo_async_task(signature=True) -@app.task -def post_process_findings_batch_signature(finding_ids, *args, dedupe_option=True, rules_option=True, product_grading_option=True, - issue_updater_option=True, push_to_jira=False, user=None, **kwargs): - return post_process_findings_batch(finding_ids, *args, dedupe_option=dedupe_option, rules_option=rules_option, product_grading_option=product_grading_option, issue_updater_option=issue_updater_option, push_to_jira=push_to_jira, user=user, **kwargs) - # Pass arguments as keyword arguments to ensure Celery properly serializes them - - @dojo_async_task @app.task def post_process_findings_batch(finding_ids, *args, dedupe_option=True, rules_option=True, product_grading_option=True, diff --git a/dojo/importers/base_importer.py b/dojo/importers/base_importer.py index b9a0289f9ef..1c0b84687b1 100644 --- a/dojo/importers/base_importer.py +++ b/dojo/importers/base_importer.py @@ -668,7 +668,7 @@ def maybe_launch_post_processing_chord( product = self.test.engagement.product system_settings = System_Settings.objects.get() if system_settings.enable_product_grade: - calculate_grade_signature = utils.calculate_grade_signature(product) + 
calculate_grade_signature = utils.calculate_grade.si(product.id) chord(post_processing_task_signatures)(calculate_grade_signature) else: group(post_processing_task_signatures).apply_async() diff --git a/dojo/importers/default_importer.py b/dojo/importers/default_importer.py index 35fe6712387..95c6adeed3c 100644 --- a/dojo/importers/default_importer.py +++ b/dojo/importers/default_importer.py @@ -7,7 +7,6 @@ from django.urls import reverse import dojo.jira_link.helper as jira_helper -from dojo.decorators import we_want_async from dojo.finding import helper as finding_helper from dojo.importers.base_importer import BaseImporter, Parser from dojo.importers.options import ImporterOptions @@ -255,27 +254,14 @@ def process_findings( batch_finding_ids.clear() logger.debug("process_findings: dispatching batch with push_to_jira=%s (batch_size=%d, is_final=%s)", push_to_jira, len(finding_ids_batch), is_final_finding) - if we_want_async(async_user=self.user): - signature = finding_helper.post_process_findings_batch_signature( - finding_ids_batch, - dedupe_option=True, - rules_option=True, - product_grading_option=True, - issue_updater_option=True, - push_to_jira=push_to_jira, - ) - logger.debug("process_findings: signature created with push_to_jira=%s, signature.kwargs=%s", - push_to_jira, signature.kwargs) - signature() - else: - finding_helper.post_process_findings_batch( - finding_ids_batch, - dedupe_option=True, - rules_option=True, - product_grading_option=True, - issue_updater_option=True, - push_to_jira=push_to_jira, - ) + finding_helper.post_process_findings_batch( + finding_ids_batch, + dedupe_option=True, + rules_option=True, + product_grading_option=True, + issue_updater_option=True, + push_to_jira=push_to_jira, + ) # No chord: tasks are dispatched immediately above per batch diff --git a/dojo/importers/default_reimporter.py b/dojo/importers/default_reimporter.py index 47ce8c61acd..1ddabf5e87f 100644 --- a/dojo/importers/default_reimporter.py +++ b/dojo/importers/default_reimporter.py @@ -7,7 +7,6 @@ import dojo.finding.helper as finding_helper import dojo.jira_link.helper as jira_helper -from dojo.decorators import we_want_async from dojo.finding.deduplication import ( find_candidates_for_deduplication_hash, find_candidates_for_deduplication_uid_or_hash, @@ -413,24 +412,14 @@ def process_findings( if len(batch_finding_ids) >= dedupe_batch_max_size or is_final: finding_ids_batch = list(batch_finding_ids) batch_finding_ids.clear() - if we_want_async(async_user=self.user): - finding_helper.post_process_findings_batch_signature( - finding_ids_batch, - dedupe_option=True, - rules_option=True, - product_grading_option=True, - issue_updater_option=True, - push_to_jira=push_to_jira, - )() - else: - finding_helper.post_process_findings_batch( - finding_ids_batch, - dedupe_option=True, - rules_option=True, - product_grading_option=True, - issue_updater_option=True, - push_to_jira=push_to_jira, - ) + finding_helper.post_process_findings_batch( + finding_ids_batch, + dedupe_option=True, + rules_option=True, + product_grading_option=True, + issue_updater_option=True, + push_to_jira=push_to_jira, + ) # No chord: tasks are dispatched immediately above per batch diff --git a/dojo/utils.py b/dojo/utils.py index 1d3b8b24aa3..d87fcd59540 100644 --- a/dojo/utils.py +++ b/dojo/utils.py @@ -1236,17 +1236,6 @@ def get_setting(setting): return getattr(settings, setting) -@dojo_async_task(signature=True) -@app.task -def calculate_grade_signature(product_id, *args, **kwargs): - """Returns a signature for 
calculating product grade that can be used in chords or groups.""" - product = get_object_or_none(Product, id=product_id) - if not product: - logger.warning("Product with id %s does not exist, skipping calculate_grade_signature", product_id) - return None - return calculate_grade_internal(product, *args, **kwargs) - - @dojo_async_task @app.task def calculate_grade(product_id, *args, **kwargs): From d7354621e9b0d77d3bc0f03269b266a49081f5cb Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Fri, 26 Dec 2025 22:49:07 +0100 Subject: [PATCH 08/36] fix test counts --- unittests/test_importers_performance.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/unittests/test_importers_performance.py b/unittests/test_importers_performance.py index d2301a1f5d7..9fa27109ac9 100644 --- a/unittests/test_importers_performance.py +++ b/unittests/test_importers_performance.py @@ -265,12 +265,16 @@ def test_import_reimport_reimport_performance_pghistory_async(self): configure_pghistory_triggers() self._import_reimport_performance( + + expected_num_queries1=306, - expected_num_async_tasks1=7, + expected_num_async_tasks1=305, expected_num_queries2=232, expected_num_async_tasks2=18, expected_num_queries3=114, expected_num_async_tasks3=17, + + ) @override_settings(ENABLE_AUDITLOG=True) @@ -287,12 +291,14 @@ def test_import_reimport_reimport_performance_pghistory_no_async(self): testuser.usercontactinfo.save() self._import_reimport_performance( - expected_num_queries1=313, + + expected_num_queries1=312, expected_num_async_tasks1=6, expected_num_queries2=239, expected_num_async_tasks2=17, expected_num_queries3=121, expected_num_async_tasks3=16, + ) @override_settings(ENABLE_AUDITLOG=True) @@ -311,13 +317,15 @@ def test_import_reimport_reimport_performance_pghistory_no_async_with_product_gr self._import_reimport_performance( - expected_num_queries1=317, + + expected_num_queries1=316, expected_num_async_tasks1=8, expected_num_queries2=243, expected_num_async_tasks2=19, expected_num_queries3=125, expected_num_async_tasks3=18, + ) # Deduplication is enabled in the tests above, but to properly test it we must run the same import twice and capture the results. 
@@ -456,9 +464,11 @@ def test_deduplication_performance_pghistory_no_async(self): testuser.usercontactinfo.save() self._deduplication_performance( - expected_num_queries1=282, + + expected_num_queries1=281, expected_num_async_tasks1=7, expected_num_queries2=246, expected_num_async_tasks2=7, + ) From de166e74c094a77bbb39583d385aa59bc5ee2c9c Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Fri, 26 Dec 2025 22:49:50 +0100 Subject: [PATCH 09/36] fix test counts --- unittests/test_importers_performance.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unittests/test_importers_performance.py b/unittests/test_importers_performance.py index 9fa27109ac9..cc33053601b 100644 --- a/unittests/test_importers_performance.py +++ b/unittests/test_importers_performance.py @@ -267,6 +267,7 @@ def test_import_reimport_reimport_performance_pghistory_async(self): self._import_reimport_performance( + expected_num_queries1=306, expected_num_async_tasks1=305, expected_num_queries2=232, @@ -275,6 +276,7 @@ def test_import_reimport_reimport_performance_pghistory_async(self): expected_num_async_tasks3=17, + ) @override_settings(ENABLE_AUDITLOG=True) From 9e742056b12406d47ab8459a7d46f91508b18781 Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Fri, 26 Dec 2025 22:51:22 +0100 Subject: [PATCH 10/36] fix test counts --- scripts/update_performance_test_counts.py | 264 +++++++++++----------- unittests/test_importers_performance.py | 16 +- 2 files changed, 141 insertions(+), 139 deletions(-) diff --git a/scripts/update_performance_test_counts.py b/scripts/update_performance_test_counts.py index f7cfaae2859..bc99e108da8 100644 --- a/scripts/update_performance_test_counts.py +++ b/scripts/update_performance_test_counts.py @@ -219,26 +219,56 @@ def parse_test_output(output: str) -> list[TestCount]: # The test output format is: # FAIL: test_name (step='import1', metric='queries') # AssertionError: 118 != 120 : 118 queries executed, 120 expected - # OR for async tasks: + # + # For async tasks we may see: # FAIL: test_name (step='import1', metric='async_tasks') - # AssertionError: 7 != 8 : 7 async tasks executed, 8 expected - - # Pattern to match the full failure block: - # FAIL: test_name (full.path.to.test) (step='...', metric='...') - # AssertionError: actual != expected : actual ... executed, expected expected - # The test name may include the full path in parentheses, so we extract just the method name - failure_pattern = re.compile( - r"FAIL:\s+(test_\w+)\s+\([^)]+\)\s+\(step=['\"](\w+)['\"],\s*metric=['\"](\w+)['\"]\)\s*\n" - r".*?AssertionError:\s+(\d+)\s+!=\s+(\d+)\s+:\s+\d+\s+(?:queries|async tasks?)\s+executed,\s+\d+\s+expected", - re.MULTILINE | re.DOTALL, + # AssertionError: Expected 7 celery tasks, but 6 were created. + + # Parse failures by splitting into individual FAIL blocks, to avoid accidentally + # associating an assertion from a different FAIL with the wrong metric. 
+ fail_header = re.compile( + r"^FAIL:\s+(test_\w+)\s+\([^)]+\)\s+\(step=['\"](\w+)['\"],\s*metric=['\"](\w+)['\"]\)\s*$", + re.MULTILINE, ) - for match in failure_pattern.finditer(output): + headers = list(fail_header.finditer(output)) + for idx, match in enumerate(headers): test_name = match.group(1) step = match.group(2) metric = match.group(3) - actual = int(match.group(4)) - expected = int(match.group(5)) + + block_start = match.end() + block_end = headers[idx + 1].start() if idx + 1 < len(headers) else len(output) + block = output[block_start:block_end] + + actual: int | None = None + expected: int | None = None + + if metric == "queries": + m = re.search( + r"AssertionError:\s+(\d+)\s+!=\s+(\d+)\s+:\s+\d+\s+queries\s+executed,\s+\d+\s+expected", + block, + ) + if m: + actual = int(m.group(1)) + expected = int(m.group(2)) + elif metric == "async_tasks": + # Celery task count assertions can be in a different format. + m = re.search(r"AssertionError:\s+Expected\s+(\d+)\s+celery tasks?,\s+but\s+(\d+)\s+were created\.", block) + if m: + expected = int(m.group(1)) + actual = int(m.group(2)) + else: + m = re.search( + r"AssertionError:\s+(\d+)\s+!=\s+(\d+)\s+:\s+\d+\s+async tasks?\s+executed,\s+\d+\s+expected", + block, + ) + if m: + actual = int(m.group(1)) + expected = int(m.group(2)) + + if actual is None or expected is None: + continue count = TestCount(test_name, step, metric) count.actual = actual @@ -246,40 +276,6 @@ def parse_test_output(output: str) -> list[TestCount]: count.difference = expected - actual counts.append(count) - # Also try a simpler pattern in case the format is slightly different - if not counts: - # Look for lines with step/metric followed by AssertionError on nearby lines - lines = output.split("\n") - i = 0 - while i < len(lines): - line = lines[i] - - # Look for FAIL: test_name (may include full path in parentheses) - # Format: FAIL: test_name (full.path) (step='...', metric='...') - fail_match = re.search(r"FAIL:\s+(test_\w+)\s+\([^)]+\)\s+\(step=['\"](\w+)['\"],\s*metric=['\"](\w+)['\"]\)", line) - if fail_match: - test_name = fail_match.group(1) - step = fail_match.group(2) - metric = fail_match.group(3) - # Look ahead for AssertionError - for j in range(i, min(i + 15, len(lines))): - assertion_match = re.search( - r"AssertionError:\s+(\d+)\s+!=\s+(\d+)\s+:\s+\d+\s+(?:queries|async tasks?)\s+executed,\s+\d+\s+expected", - lines[j], - ) - - if assertion_match: - actual = int(assertion_match.group(1)) - expected = int(assertion_match.group(2)) - - count = TestCount(test_name, step, metric) - count.actual = actual - count.expected = expected - count.difference = expected - actual - counts.append(count) - break - i += 1 - if counts: print(f"\n📊 Parsed {len(counts)} count mismatch(es) from test output:") for count in counts: @@ -378,6 +374,27 @@ def update_test_file(counts: list[TestCount]): content = TEST_FILE.read_text() + def _extract_call_span(method_content: str, call_name: str) -> tuple[int, int] | None: + """Return (start, end) indices of the first call to `call_name(...)` within method_content.""" + start = method_content.find(call_name) + if start == -1: + return None + + open_paren = method_content.find("(", start) + if open_paren == -1: + return None + + depth = 0 + for idx in range(open_paren, len(method_content)): + ch = method_content[idx] + if ch == "(": + depth += 1 + elif ch == ")": + depth -= 1 + if depth == 0: + return start, idx + 1 + return None + # Create a mapping of test_name -> step_metric -> new_value updates = {} for count in counts: 
@@ -419,100 +436,49 @@ def update_test_file(counts: list[TestCount]): test_method_start = test_match.start() test_method_end = test_match.end() - # Try to find _import_reimport_performance call first - perf_call_pattern_import_reimport = re.compile( - r"(self\._import_reimport_performance\s*\(\s*)" - r"expected_num_queries1\s*=\s*(\d+)\s*,\s*" - r"expected_num_async_tasks1\s*=\s*(\d+)\s*,\s*" - r"expected_num_queries2\s*=\s*(\d+)\s*,\s*" - r"expected_num_async_tasks2\s*=\s*(\d+)\s*,\s*" - r"expected_num_queries3\s*=\s*(\d+)\s*,\s*" - r"expected_num_async_tasks3\s*=\s*(\d+)\s*," - r"(\s*\))", - re.DOTALL, - ) - - # Try to find _deduplication_performance call - perf_call_pattern_deduplication = re.compile( - r"(self\._deduplication_performance\s*\(\s*)" - r"expected_num_queries1\s*=\s*(\d+)\s*,\s*" - r"expected_num_async_tasks1\s*=\s*(\d+)\s*,\s*" - r"expected_num_queries2\s*=\s*(\d+)\s*,\s*" - r"expected_num_async_tasks2\s*=\s*(\d+)\s*," - r"(\s*\))", - re.DOTALL, - ) - - perf_match = perf_call_pattern_import_reimport.search(test_method_content) - method_type = "import_reimport" + call_span = _extract_call_span(test_method_content, "self._import_reimport_performance") param_map = param_map_import_reimport - param_order = [ - "import1_queries", - "import1_async_tasks", - "reimport1_queries", - "reimport1_async_tasks", - "reimport2_queries", - "reimport2_async_tasks", - ] - - if not perf_match: - perf_match = perf_call_pattern_deduplication.search(test_method_content) - if perf_match: - method_type = "deduplication" + if call_span is None: + call_span = _extract_call_span(test_method_content, "self._deduplication_performance") + if call_span is not None: param_map = param_map_deduplication - param_order = [ - "first_import_queries", - "first_import_async_tasks", - "second_import_queries", - "second_import_async_tasks", - ] else: - print(f"⚠️ Warning: Could not find _import_reimport_performance or _deduplication_performance call in {test_name}") + print( + f"⚠️ Warning: Could not find _import_reimport_performance or _deduplication_performance call in {test_name}", + ) continue - # Get the indentation from the original call (first line after opening paren) - call_lines = test_method_content[perf_match.start():perf_match.end()].split("\n") - indent = "" - for line in call_lines: - if "expected_num_queries1" in line: - # Extract indentation (spaces before the parameter) - indent_match = re.match(r"(\s*)expected_num_queries1", line) - if indent_match: - indent = indent_match.group(1) - break - - # If we couldn't find indentation, use a default - if not indent: - indent = " " # 12 spaces default - - replacement_parts = [perf_match.group(1)] # Opening: "self._import_reimport_performance(" - updated_params = [] - for i, step_metric in enumerate(param_order): - param_name = param_map[step_metric] - old_value = int(perf_match.group(i + 2)) # +2 because group 1 is the opening - if step_metric in test_updates: - new_value = test_updates[step_metric] - if old_value != new_value: - updated_params.append(f"{param_name}: {old_value} → {new_value}") - else: - # Keep the existing value - new_value = old_value + call_start, call_end = call_span + original_call = test_method_content[call_start:call_end] + updated_call = original_call - replacement_parts.append(f"{indent}{param_name}={new_value},") - - # Closing parenthesis - group number depends on method type - closing_group = 8 if method_type == "import_reimport" else 6 - replacement_parts.append(perf_match.group(closing_group)) # Closing parenthesis - 
replacement = "\n".join(replacement_parts) + updated_params = [] + for step_metric, param_name in param_map.items(): + if step_metric not in test_updates: + continue + new_value = test_updates[step_metric] + m = re.search(rf"({re.escape(param_name)}\s*=\s*)(\d+)", updated_call) + if not m: + continue + old_value = int(m.group(2)) + if old_value == new_value: + continue + updated_params.append(f"{param_name}: {old_value} → {new_value}") + updated_call = re.sub( + rf"({re.escape(param_name)}\s*=\s*)\d+", + rf"\g<1>{new_value}", + updated_call, + count=1, + ) if updated_params: print(f" Updated: {', '.join(updated_params)}") - # Replace the method call within the test method content + # Replace the method call within the test method content (in-place; do not reformat) updated_method_content = ( - test_method_content[: perf_match.start()] - + replacement - + test_method_content[perf_match.end() :] + test_method_content[:call_start] + + updated_call + + test_method_content[call_end:] ) # Replace the entire test method in the original content @@ -547,6 +513,30 @@ def verify_tests(test_class: str) -> bool: return True +def verify_and_get_mismatches(test_class: str) -> tuple[bool, list[TestCount]]: + """Run the full test class and return (success, parsed mismatches).""" + print(f"Verifying tests for {test_class}...") + output, return_code = run_tests(test_class) + + success, error_msg = check_test_execution_success(output, return_code) + if not success: + print(f"\n❌ Test execution failed: {error_msg}") + return False, [] + + counts = parse_test_output(output) + if counts: + print("\n❌ Some tests still have count mismatches:") + for count in counts: + print( + f" {count.test_name} - {count.step} {count.metric}: " + f"expected {count.expected}, got {count.actual}", + ) + return False, counts + + print("\n✅ All tests pass!") + return True, [] + + def main(): parser = argparse.ArgumentParser( description="Update performance test query counts", @@ -657,7 +647,17 @@ def main(): if all_counts: print(f"\n{'=' * 80}") print(f"✅ Updated {len(all_counts)} count(s) across {len({c.test_name for c in all_counts})} test(s)") - print("\nNext step: Run --verify to ensure all tests pass") + # Some performance counts can vary depending on test ordering / keepdb state. + # Do a final full-suite pass and apply any remaining mismatches so the suite passes as run in CI. + print("\nRunning a final verify pass for stability...") + success, suite_mismatches = verify_and_get_mismatches(args.test_class) + if not success and suite_mismatches: + print("\nApplying remaining mismatches from full-suite run...") + update_test_file(suite_mismatches) + print("\nRe-running verify...") + success, _ = verify_and_get_mismatches(args.test_class) + sys.exit(0 if success else 1) + sys.exit(0 if success else 1) else: print(f"\n{'=' * 80}") print("\n✅ No differences found. 
All tests are already up to date.") diff --git a/unittests/test_importers_performance.py b/unittests/test_importers_performance.py index cc33053601b..db43013a129 100644 --- a/unittests/test_importers_performance.py +++ b/unittests/test_importers_performance.py @@ -268,12 +268,14 @@ def test_import_reimport_reimport_performance_pghistory_async(self): - expected_num_queries1=306, - expected_num_async_tasks1=305, + + expected_num_queries1=305, + expected_num_async_tasks1=6, expected_num_queries2=232, - expected_num_async_tasks2=18, + expected_num_async_tasks2=17, expected_num_queries3=114, - expected_num_async_tasks3=17, + expected_num_async_tasks3=16, + @@ -445,10 +447,10 @@ def test_deduplication_performance_pghistory_async(self): self.system_settings(enable_deduplication=True) self._deduplication_performance( - expected_num_queries1=275, - expected_num_async_tasks1=8, + expected_num_queries1=274, + expected_num_async_tasks1=7, expected_num_queries2=185, - expected_num_async_tasks2=8, + expected_num_async_tasks2=7, check_duplicates=False, # Async mode - deduplication happens later ) From 8cf28101e2b14779cd55aedf4e0e42e7d53df447 Mon Sep 17 00:00:00 2001 From: valentijnscholten Date: Mon, 29 Dec 2025 18:36:52 +0100 Subject: [PATCH 11/36] Update dojo/settings/settings.dist.py Co-authored-by: Cody Maffucci <46459665+Maffooch@users.noreply.github.com> --- dojo/settings/settings.dist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dojo/settings/settings.dist.py b/dojo/settings/settings.dist.py index 32b136a6d3c..f219c1c0db1 100644 --- a/dojo/settings/settings.dist.py +++ b/dojo/settings/settings.dist.py @@ -89,7 +89,7 @@ DD_CELERY_RESULT_EXPIRES=(int, 86400), DD_CELERY_BEAT_SCHEDULE_FILENAME=(str, root("dojo.celery.beat.db")), DD_CELERY_TASK_SERIALIZER=(str, "pickle"), - DD_CELERY_LOG_LEVEL=(str, "INFO"), + DD_CELERY_LOG_LEVEL=(str, "INFO"), DD_TAG_BULK_ADD_BATCH_SIZE=(int, 1000), # Tagulous slug truncate unique setting. 
Set to -1 to use tagulous internal default (5) DD_TAGULOUS_SLUG_TRUNCATE_UNIQUE=(int, -1), From 8b90d520905fd75c9caabe039f9a2fd14dfc7c7f Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Fri, 26 Dec 2025 18:50:01 +0100 Subject: [PATCH 12/36] remove dojo_model_from/to_id --- dojo/utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/dojo/utils.py b/dojo/utils.py index d87fcd59540..1d3b8b24aa3 100644 --- a/dojo/utils.py +++ b/dojo/utils.py @@ -1236,6 +1236,17 @@ def get_setting(setting): return getattr(settings, setting) +@dojo_async_task(signature=True) +@app.task +def calculate_grade_signature(product_id, *args, **kwargs): + """Returns a signature for calculating product grade that can be used in chords or groups.""" + product = get_object_or_none(Product, id=product_id) + if not product: + logger.warning("Product with id %s does not exist, skipping calculate_grade_signature", product_id) + return None + return calculate_grade_internal(product, *args, **kwargs) + + @dojo_async_task @app.task def calculate_grade(product_id, *args, **kwargs): From dd5a95ff8c49cb63135cc96eaa010cc3209f8d85 Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Fri, 26 Dec 2025 21:02:52 +0100 Subject: [PATCH 13/36] initial base task --- dojo/celery.py | 117 +++++++++++++++++++++++++++++ dojo/finding/deduplication.py | 9 +-- dojo/finding/helper.py | 9 +-- dojo/importers/endpoint_manager.py | 28 +++---- dojo/jira_link/helper.py | 24 ++---- dojo/notifications/helper.py | 16 ++-- dojo/product/helpers.py | 6 +- dojo/sla_config/helpers.py | 6 +- dojo/tasks.py | 6 +- dojo/tools/tool_issue_updater.py | 9 +-- dojo/utils.py | 38 +++------- 11 files changed, 166 insertions(+), 102 deletions(-) diff --git a/dojo/celery.py b/dojo/celery.py index ead4a8813a8..2ffe65b5d69 100644 --- a/dojo/celery.py +++ b/dojo/celery.py @@ -44,6 +44,123 @@ def __call__(self, *args, **kwargs): app.autodiscover_tasks(lambda: settings.INSTALLED_APPS) +class DojoAsyncTask(Task): + + """ + Base task class that provides dojo_async_task functionality without using a decorator. 
+ + This class: + - Injects user context into task kwargs + - Tracks task calls for performance testing + - Handles sync/async execution based on user settings + - Supports all Celery features (signatures, chords, groups, chains) + """ + + def apply_async(self, args=None, kwargs=None, **options): + """Override apply_async to inject user context and track tasks.""" + from dojo.decorators import dojo_async_task_counter # noqa: PLC0415 circular import + from dojo.utils import get_current_user # noqa: PLC0415 circular import + + if kwargs is None: + kwargs = {} + + # Inject user context if not already present + if "async_user" not in kwargs: + kwargs["async_user"] = get_current_user() + + # Track task call (only if not already tracked by __call__) + # Check if this is a direct call to apply_async (not from __call__) + # by checking if _dojo_tracked is not set + if not getattr(self, "_dojo_tracked", False): + dojo_async_task_counter.incr( + self.name, + args=args, + kwargs=kwargs, + ) + + # Call parent to execute async + return super().apply_async(args=args, kwargs=kwargs, **options) + + def s(self, *args, **kwargs): + """Create a mutable signature with injected user context.""" + from dojo.decorators import dojo_async_task_counter # noqa: PLC0415 circular import + from dojo.utils import get_current_user # noqa: PLC0415 circular import + + if "async_user" not in kwargs: + kwargs["async_user"] = get_current_user() + + # Track task call + dojo_async_task_counter.incr( + self.name, + args=args, + kwargs=kwargs, + ) + + return super().s(*args, **kwargs) + + def si(self, *args, **kwargs): + """Create an immutable signature with injected user context.""" + from dojo.decorators import dojo_async_task_counter # noqa: PLC0415 circular import + from dojo.utils import get_current_user # noqa: PLC0415 circular import + + if "async_user" not in kwargs: + kwargs["async_user"] = get_current_user() + + # Track task call + dojo_async_task_counter.incr( + self.name, + args=args, + kwargs=kwargs, + ) + + return super().si(*args, **kwargs) + + def __call__(self, *args, **kwargs): + """ + Override __call__ to handle direct task calls with sync/async logic. + + This replicates the behavior of the dojo_async_task decorator wrapper. + """ + # In Celery worker execution, __call__ is how tasks actually run. + # We only want the sync/async decision when tasks are called directly + # from application code (task(...)), not when the worker is executing a message. 
+ if not getattr(self.request, "called_directly", True): + return super().__call__(*args, **kwargs) + + from dojo.decorators import dojo_async_task_counter, we_want_async # noqa: PLC0415 circular import + from dojo.utils import get_current_user # noqa: PLC0415 circular import + + # Inject user context if not already present + if "async_user" not in kwargs: + kwargs["async_user"] = get_current_user() + + # Track task call + dojo_async_task_counter.incr( + self.name, + args=args, + kwargs=kwargs, + ) + + # Extract countdown if present (don't pass to sync execution) + countdown = kwargs.pop("countdown", 0) + + # Check if we should run async or sync + if we_want_async(*args, func=self, **kwargs): + # Mark as tracked to avoid double tracking in apply_async + self._dojo_tracked = True + try: + # Run asynchronously + return self.apply_async(args=args, kwargs=kwargs, countdown=countdown) + finally: + # Clean up the flag + delattr(self, "_dojo_tracked") + else: + # Run synchronously in-process, matching the original decorator behavior: func(*args, **kwargs) + # Remove sync from kwargs as it's a control flag, not a task argument. + kwargs.pop("sync", None) + return self.run(*args, **kwargs) + + @app.task(bind=True) def debug_task(self): logger.info(f"Request: {self.request!r}") diff --git a/dojo/finding/deduplication.py b/dojo/finding/deduplication.py index eb8baf40db0..1d778e1c9b9 100644 --- a/dojo/finding/deduplication.py +++ b/dojo/finding/deduplication.py @@ -7,8 +7,7 @@ from django.db.models import Prefetch from django.db.models.query_utils import Q -from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery import DojoAsyncTask, app from dojo.models import Finding, System_Settings logger = logging.getLogger(__name__) @@ -45,14 +44,12 @@ def get_finding_models_for_deduplication(finding_ids): ) -@dojo_async_task -@app.task +@app.task(base=DojoAsyncTask) def do_dedupe_finding_task(new_finding_id, *args, **kwargs): return do_dedupe_finding_task_internal(Finding.objects.get(id=new_finding_id), *args, **kwargs) -@dojo_async_task -@app.task +@app.task(base=DojoAsyncTask) def do_dedupe_batch_task(finding_ids, *args, **kwargs): """ Async task to deduplicate a batch of findings. The findings are assumed to be in the same test. 
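A minimal illustrative sketch (not part of this patch series) of how a task declared on the new base class is expected to be defined and invoked, based on the DojoAsyncTask overrides shown above; the task name example_task is hypothetical, while async_user, sync=True and the we_want_async() check are taken from this diff:

from dojo.celery import DojoAsyncTask, app

@app.task(base=DojoAsyncTask)
def example_task(finding_id, *args, **kwargs):
    # DojoAsyncTask injects the calling user into kwargs as "async_user"
    return finding_id

# Direct call: DojoAsyncTask.__call__ chooses between apply_async() and an
# in-process run() via we_want_async(); sync=True forces synchronous execution
# and is popped before the task body runs.
example_task(42, sync=True)

# Signature call: s()/si() also inject async_user and record the task for the
# performance counters, so the task can still take part in groups/chords/chains.
example_task.si(42).apply_async()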
diff --git a/dojo/finding/helper.py b/dojo/finding/helper.py index 908afee38b9..e7a93e90a65 100644 --- a/dojo/finding/helper.py +++ b/dojo/finding/helper.py @@ -15,8 +15,7 @@ import dojo.jira_link.helper as jira_helper import dojo.risk_acceptance.helper as ra_helper -from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery import DojoAsyncTask, app from dojo.endpoint.utils import endpoint_get_or_create, save_endpoints_to_add from dojo.file_uploads.helper import delete_related_files from dojo.finding.deduplication import ( @@ -391,8 +390,7 @@ def add_findings_to_auto_group(name, findings, group_by, *, create_finding_group finding_group.findings.add(*findings) -@dojo_async_task -@app.task +@app.task(base=DojoAsyncTask) def post_process_finding_save(finding_id, dedupe_option=True, rules_option=True, product_grading_option=True, # noqa: FBT002 issue_updater_option=True, push_to_jira=False, user=None, *args, **kwargs): # noqa: FBT002 - this is bit hard to fix nice have this universally fixed finding = get_object_or_none(Finding, id=finding_id) @@ -453,8 +451,7 @@ def post_process_finding_save_internal(finding, dedupe_option=True, rules_option jira_helper.push_to_jira(finding.finding_group) -@dojo_async_task -@app.task +@app.task(base=DojoAsyncTask) def post_process_findings_batch(finding_ids, *args, dedupe_option=True, rules_option=True, product_grading_option=True, issue_updater_option=True, push_to_jira=False, user=None, **kwargs): diff --git a/dojo/importers/endpoint_manager.py b/dojo/importers/endpoint_manager.py index f4b277d49fa..3c1326c220c 100644 --- a/dojo/importers/endpoint_manager.py +++ b/dojo/importers/endpoint_manager.py @@ -4,8 +4,7 @@ from django.urls import reverse from django.utils import timezone -from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery import DojoAsyncTask, app from dojo.endpoint.utils import endpoint_get_or_create from dojo.models import ( Dojo_User, @@ -18,17 +17,15 @@ class EndpointManager: - @dojo_async_task - @app.task() + @app.task(base=DojoAsyncTask) def add_endpoints_to_unsaved_finding( - self, - finding: Finding, + finding: Finding, # noqa: N805 endpoints: list[Endpoint], **kwargs: dict, ) -> None: """Creates Endpoint objects for a single finding and creates the link via the endpoint status""" logger.debug(f"IMPORT_SCAN: Adding {len(endpoints)} endpoints to finding: {finding}") - self.clean_unsaved_endpoints(endpoints) + EndpointManager.clean_unsaved_endpoints(endpoints) for endpoint in endpoints: ep = None eps = [] @@ -41,7 +38,8 @@ def add_endpoints_to_unsaved_finding( path=endpoint.path, query=endpoint.query, fragment=endpoint.fragment, - product=finding.test.engagement.product) + product=finding.test.engagement.product, + ) eps.append(ep) except (MultipleObjectsReturned): msg = ( @@ -58,11 +56,9 @@ def add_endpoints_to_unsaved_finding( logger.debug(f"IMPORT_SCAN: {len(endpoints)} endpoints imported") - @dojo_async_task - @app.task() + @app.task(base=DojoAsyncTask) def mitigate_endpoint_status( - self, - endpoint_status_list: list[Endpoint_Status], + endpoint_status_list: list[Endpoint_Status], # noqa: N805 user: Dojo_User, **kwargs: dict, ) -> None: @@ -85,11 +81,9 @@ def mitigate_endpoint_status( batch_size=1000, ) - @dojo_async_task - @app.task() + @app.task(base=DojoAsyncTask) def reactivate_endpoint_status( - self, - endpoint_status_list: list[Endpoint_Status], + endpoint_status_list: list[Endpoint_Status], # noqa: N805 **kwargs: dict, ) -> None: """Reactivate all endpoint 
status objects that are supplied""" @@ -120,8 +114,8 @@ def chunk_endpoints_and_disperse( ) -> None: self.add_endpoints_to_unsaved_finding(finding, endpoints, sync=True) + @staticmethod def clean_unsaved_endpoints( - self, endpoints: list[Endpoint], ) -> None: """ diff --git a/dojo/jira_link/helper.py b/dojo/jira_link/helper.py index f020a4d5b19..e864121d57a 100644 --- a/dojo/jira_link/helper.py +++ b/dojo/jira_link/helper.py @@ -17,8 +17,7 @@ from jira.exceptions import JIRAError from requests.auth import HTTPBasicAuth -from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery import DojoAsyncTask, app from dojo.forms import JIRAEngagementForm, JIRAProjectForm from dojo.models import ( Engagement, @@ -773,8 +772,7 @@ def push_to_jira(obj, *args, **kwargs): # we need thre separate celery tasks due to the decorators we're using to map to/from ids -@dojo_async_task -@app.task +@app.task(base=DojoAsyncTask) def push_finding_to_jira(finding_id, *args, **kwargs): finding = get_object_or_none(Finding, id=finding_id) if not finding: @@ -786,8 +784,7 @@ def push_finding_to_jira(finding_id, *args, **kwargs): return add_jira_issue(finding, *args, **kwargs) -@dojo_async_task -@app.task +@app.task(base=DojoAsyncTask) def push_finding_group_to_jira(finding_group_id, *args, **kwargs): finding_group = get_object_or_none(Finding_Group, id=finding_group_id) if not finding_group: @@ -803,8 +800,7 @@ def push_finding_group_to_jira(finding_group_id, *args, **kwargs): return add_jira_issue(finding_group, *args, **kwargs) -@dojo_async_task -@app.task +@app.task(base=DojoAsyncTask) def push_engagement_to_jira(engagement_id, *args, **kwargs): engagement = get_object_or_none(Engagement, id=engagement_id) if not engagement: @@ -1376,8 +1372,7 @@ def jira_check_attachment(issue, source_file_name): return file_exists -@dojo_async_task -@app.task +@app.task(base=DojoAsyncTask) def close_epic(engagement_id, push_to_jira, **kwargs): engagement = get_object_or_none(Engagement, id=engagement_id) if not engagement: @@ -1425,8 +1420,7 @@ def close_epic(engagement_id, push_to_jira, **kwargs): return False -@dojo_async_task -@app.task +@app.task(base=DojoAsyncTask) def update_epic(engagement_id, **kwargs): engagement = get_object_or_none(Engagement, id=engagement_id) if not engagement: @@ -1472,8 +1466,7 @@ def update_epic(engagement_id, **kwargs): return False -@dojo_async_task -@app.task +@app.task(base=DojoAsyncTask) def add_epic(engagement_id, **kwargs): engagement = get_object_or_none(Engagement, id=engagement_id) if not engagement: @@ -1584,8 +1577,7 @@ def add_comment(obj, note, *, force_push=False, **kwargs): return add_comment_internal(jira_issue.id, note.id, force_push=force_push, **kwargs) -@dojo_async_task -@app.task +@app.task(base=DojoAsyncTask) def add_comment_internal(jira_issue_id, note_id, *, force_push=False, **kwargs): """Internal Celery task that adds a comment to a JIRA issue.""" jira_issue = get_object_or_none(JIRA_Issue, id=jira_issue_id) diff --git a/dojo/notifications/helper.py b/dojo/notifications/helper.py index c4458daec01..b422896e83d 100644 --- a/dojo/notifications/helper.py +++ b/dojo/notifications/helper.py @@ -17,8 +17,8 @@ from dojo import __version__ as dd_version from dojo.authorization.roles_permissions import Permissions -from dojo.celery import app -from dojo.decorators import dojo_async_task, we_want_async +from dojo.celery import DojoAsyncTask, app +from dojo.decorators import we_want_async from dojo.labels import get_labels from dojo.models import ( 
Alerts, @@ -199,8 +199,7 @@ class SlackNotificationManger(NotificationManagerHelpers): """Manger for slack notifications and their helpers.""" - @dojo_async_task - @app.task + @app.task(base=DojoAsyncTask) def send_slack_notification( self, event: str, @@ -317,8 +316,7 @@ class MSTeamsNotificationManger(NotificationManagerHelpers): """Manger for Microsoft Teams notifications and their helpers.""" - @dojo_async_task - @app.task + @app.task(base=DojoAsyncTask) def send_msteams_notification( self, event: str, @@ -368,8 +366,7 @@ class EmailNotificationManger(NotificationManagerHelpers): """Manger for email notifications and their helpers.""" - @dojo_async_task - @app.task + @app.task(base=DojoAsyncTask) def send_mail_notification( self, event: str, @@ -420,8 +417,7 @@ class WebhookNotificationManger(NotificationManagerHelpers): ERROR_PERMANENT = "permanent" ERROR_TEMPORARY = "temporary" - @dojo_async_task - @app.task + @app.task(base=DojoAsyncTask) def send_webhooks_notification( self, event: str, diff --git a/dojo/product/helpers.py b/dojo/product/helpers.py index aeadec0246d..f23e4155548 100644 --- a/dojo/product/helpers.py +++ b/dojo/product/helpers.py @@ -1,15 +1,13 @@ import contextlib import logging -from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery import DojoAsyncTask, app from dojo.models import Endpoint, Engagement, Finding, Product, Test logger = logging.getLogger(__name__) -@dojo_async_task -@app.task +@app.task(base=DojoAsyncTask) def propagate_tags_on_product(product_id, *args, **kwargs): with contextlib.suppress(Product.DoesNotExist): product = Product.objects.get(id=product_id) diff --git a/dojo/sla_config/helpers.py b/dojo/sla_config/helpers.py index da5899a85b0..dd2567729dc 100644 --- a/dojo/sla_config/helpers.py +++ b/dojo/sla_config/helpers.py @@ -1,15 +1,13 @@ import logging -from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery import DojoAsyncTask, app from dojo.models import Finding, Product, SLA_Configuration, System_Settings from dojo.utils import get_custom_method, mass_model_updater logger = logging.getLogger(__name__) -@dojo_async_task -@app.task +@app.task(base=DojoAsyncTask) def async_update_sla_expiration_dates_sla_config_sync(sla_config: SLA_Configuration, products: list[Product], *args, severities: list[str] | None = None, **kwargs): if method := get_custom_method("FINDING_SLA_EXPIRATION_CALCULATION_METHOD"): method(sla_config, products, severities=severities) diff --git a/dojo/tasks.py b/dojo/tasks.py index 29dfe11257c..d5b904601d7 100644 --- a/dojo/tasks.py +++ b/dojo/tasks.py @@ -11,8 +11,7 @@ from django.utils import timezone from dojo.auditlog import run_flush_auditlog -from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery import DojoAsyncTask, app from dojo.finding.helper import fix_loop_duplicates from dojo.management.commands.jira_status_reconciliation import jira_status_reconciliation from dojo.models import Alerts, Announcement, Endpoint, Engagement, Finding, Product, System_Settings, User @@ -237,8 +236,7 @@ def clear_sessions(*args, **kwargs): call_command("clearsessions") -@dojo_async_task -@app.task +@app.task(base=DojoAsyncTask) def update_watson_search_index_for_model(model_name, pk_list, *args, **kwargs): """ Async task to update watson search indexes for a specific model type. 
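For orientation, the hunks above all converge on the same pattern: an ordinary Celery task whose base class injects the DefectDojo user context at dispatch time. A minimal sketch of that pattern follows; the task name and body are illustrative only and not part of this patch.

from dojo.celery import DojoAsyncTask, app

@app.task(base=DojoAsyncTask)
def recalculate_example(product_id, *args, **kwargs):
    # DojoAsyncTask.apply_async injects "async_user" into kwargs at dispatch
    # time and strips the "sync" control flag before the task body runs.
    ...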
diff --git a/dojo/tools/tool_issue_updater.py b/dojo/tools/tool_issue_updater.py index 854fb989113..26cfcf6d973 100644 --- a/dojo/tools/tool_issue_updater.py +++ b/dojo/tools/tool_issue_updater.py @@ -2,8 +2,7 @@ import pghistory -from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery import DojoAsyncTask, app from dojo.models import Finding from dojo.tools.api_sonarqube.parser import SCAN_SONARQUBE_API from dojo.tools.api_sonarqube.updater import SonarQubeApiUpdater @@ -23,8 +22,7 @@ def is_tool_issue_updater_needed(finding, *args, **kwargs): return test_type.name == SCAN_SONARQUBE_API -@dojo_async_task -@app.task +@app.task(base=DojoAsyncTask) def tool_issue_updater(finding_id, *args, **kwargs): finding = get_object_or_none(Finding, id=finding_id) if not finding: @@ -37,8 +35,7 @@ def tool_issue_updater(finding_id, *args, **kwargs): SonarQubeApiUpdater().update_sonarqube_finding(finding) -@dojo_async_task -@app.task +@app.task(base=DojoAsyncTask) def update_findings_from_source_issues(**kwargs): # Wrap with pghistory context for audit trail with pghistory.context(source="sonarqube_sync"): diff --git a/dojo/utils.py b/dojo/utils.py index 1d3b8b24aa3..33347c82960 100644 --- a/dojo/utils.py +++ b/dojo/utils.py @@ -45,8 +45,7 @@ from django.utils.translation import gettext as _ from dojo.authorization.roles_permissions import Permissions -from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery import DojoAsyncTask, app from dojo.finding.queries import get_authorized_findings from dojo.github import ( add_external_issue_github, @@ -1054,8 +1053,7 @@ def handle_uploaded_selenium(f, cred): cred.save() -@dojo_async_task -@app.task +@app.task(base=DojoAsyncTask) def add_external_issue(finding_id, external_issue_provider, **kwargs): finding = get_object_or_none(Finding, id=finding_id) if not finding: @@ -1070,8 +1068,7 @@ def add_external_issue(finding_id, external_issue_provider, **kwargs): add_external_issue_github(finding, prod, eng) -@dojo_async_task -@app.task +@app.task(base=DojoAsyncTask) def update_external_issue(finding_id, old_status, external_issue_provider, **kwargs): finding = get_object_or_none(Finding, id=finding_id) if not finding: @@ -1085,8 +1082,7 @@ def update_external_issue(finding_id, old_status, external_issue_provider, **kwa update_external_issue_github(finding, prod, eng) -@dojo_async_task -@app.task +@app.task(base=DojoAsyncTask) def close_external_issue(finding_id, note, external_issue_provider, **kwargs): finding = get_object_or_none(Finding, id=finding_id) if not finding: @@ -1100,8 +1096,7 @@ def close_external_issue(finding_id, note, external_issue_provider, **kwargs): close_external_issue_github(finding, note, prod, eng) -@dojo_async_task -@app.task +@app.task(base=DojoAsyncTask) def reopen_external_issue(finding_id, note, external_issue_provider, **kwargs): finding = get_object_or_none(Finding, id=finding_id) if not finding: @@ -1236,19 +1231,7 @@ def get_setting(setting): return getattr(settings, setting) -@dojo_async_task(signature=True) -@app.task -def calculate_grade_signature(product_id, *args, **kwargs): - """Returns a signature for calculating product grade that can be used in chords or groups.""" - product = get_object_or_none(Product, id=product_id) - if not product: - logger.warning("Product with id %s does not exist, skipping calculate_grade_signature", product_id) - return None - return calculate_grade_internal(product, *args, **kwargs) - - -@dojo_async_task -@app.task 
+@app.task(base=DojoAsyncTask) def calculate_grade(product_id, *args, **kwargs): product = get_object_or_none(Product, id=product_id) if not product: @@ -2027,8 +2010,7 @@ def __init__(self, *args, **kwargs): "Test": [(Finding, "test__id")], } - @dojo_async_task - @app.task + @app.task(base=DojoAsyncTask) def delete_chunk(self, objects, **kwargs): # Now delete all objects with retry for deadlocks max_retries = 3 @@ -2076,8 +2058,7 @@ def delete_chunk(self, objects, **kwargs): obj.delete() break - @dojo_async_task - @app.task + @app.task(base=DojoAsyncTask) def delete(self, obj, **kwargs): logger.debug("ASYNC_DELETE: Deleting " + self.get_object_name(obj) + ": " + str(obj)) model_list = self.mapping.get(self.get_object_name(obj), None) @@ -2089,8 +2070,7 @@ def delete(self, obj, **kwargs): logger.debug("ASYNC_DELETE: " + self.get_object_name(obj) + " async delete not supported. Deleteing normally: " + str(obj)) obj.delete() - @dojo_async_task - @app.task + @app.task(base=DojoAsyncTask) def crawl(self, obj, model_list, **kwargs): logger.debug("ASYNC_DELETE: Crawling " + self.get_object_name(obj) + ": " + str(obj)) for model_info in model_list: From 2ed15335a50baffd088d3a233f242bf49bffe9f7 Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Fri, 26 Dec 2025 22:22:53 +0100 Subject: [PATCH 14/36] replace dojo_async_task decorator with class+helper --- dojo/api_v2/views.py | 5 +- dojo/celery.py | 90 ++----------------------- dojo/celery_dispatch.py | 76 +++++++++++++++++++++ dojo/endpoint/views.py | 3 +- dojo/engagement/services.py | 3 +- dojo/engagement/views.py | 3 +- dojo/finding/helper.py | 8 ++- dojo/finding/views.py | 5 +- dojo/finding_group/views.py | 5 +- dojo/importers/base_importer.py | 53 +++------------ dojo/importers/default_importer.py | 4 +- dojo/importers/default_reimporter.py | 4 +- dojo/importers/endpoint_manager.py | 7 +- dojo/jira_link/helper.py | 15 +++-- dojo/management/commands/dedupe.py | 19 +++++- dojo/models.py | 8 ++- dojo/notifications/helper.py | 13 ++-- dojo/tags_signals.py | 3 +- dojo/tasks.py | 5 +- dojo/templatetags/display_tags.py | 4 +- dojo/test/views.py | 3 +- dojo/tools/tool_issue_updater.py | 3 +- dojo/utils.py | 4 +- unittests/test_importers_performance.py | 17 ----- 24 files changed, 174 insertions(+), 186 deletions(-) create mode 100644 dojo/celery_dispatch.py diff --git a/dojo/api_v2/views.py b/dojo/api_v2/views.py index e8c7d278e65..d3aeb51d85c 100644 --- a/dojo/api_v2/views.py +++ b/dojo/api_v2/views.py @@ -46,6 +46,7 @@ ) from dojo.api_v2.prefetch.prefetcher import _Prefetcher from dojo.authorization.roles_permissions import Permissions +from dojo.celery_dispatch import dojo_dispatch_task from dojo.cred.queries import get_authorized_cred_mappings from dojo.endpoint.queries import ( get_authorized_endpoint_status, @@ -678,13 +679,13 @@ def update_jira_epic(self, request, pk=None): try: if engagement.has_jira_issue: - jira_helper.update_epic(engagement.id, **request.data) + dojo_dispatch_task(jira_helper.update_epic, engagement.id, **request.data) response = Response( {"info": "Jira Epic update query sent"}, status=status.HTTP_200_OK, ) else: - jira_helper.add_epic(engagement.id, **request.data) + dojo_dispatch_task(jira_helper.add_epic, engagement.id, **request.data) response = Response( {"info": "Jira Epic create query sent"}, status=status.HTTP_200_OK, diff --git a/dojo/celery.py b/dojo/celery.py index 2ffe65b5d69..3079d25d20e 100644 --- a/dojo/celery.py +++ b/dojo/celery.py @@ -52,7 +52,6 @@ class DojoAsyncTask(Task): This class: - Injects user 
context into task kwargs - Tracks task calls for performance testing - - Handles sync/async execution based on user settings - Supports all Celery features (signatures, chords, groups, chains) """ @@ -68,97 +67,18 @@ def apply_async(self, args=None, kwargs=None, **options): if "async_user" not in kwargs: kwargs["async_user"] = get_current_user() - # Track task call (only if not already tracked by __call__) - # Check if this is a direct call to apply_async (not from __call__) - # by checking if _dojo_tracked is not set - if not getattr(self, "_dojo_tracked", False): - dojo_async_task_counter.incr( - self.name, - args=args, - kwargs=kwargs, - ) + # Control flag used for sync/async decision; never pass into the task itself + kwargs.pop("sync", None) - # Call parent to execute async - return super().apply_async(args=args, kwargs=kwargs, **options) - - def s(self, *args, **kwargs): - """Create a mutable signature with injected user context.""" - from dojo.decorators import dojo_async_task_counter # noqa: PLC0415 circular import - from dojo.utils import get_current_user # noqa: PLC0415 circular import - - if "async_user" not in kwargs: - kwargs["async_user"] = get_current_user() - - # Track task call - dojo_async_task_counter.incr( - self.name, - args=args, - kwargs=kwargs, - ) - - return super().s(*args, **kwargs) - - def si(self, *args, **kwargs): - """Create an immutable signature with injected user context.""" - from dojo.decorators import dojo_async_task_counter # noqa: PLC0415 circular import - from dojo.utils import get_current_user # noqa: PLC0415 circular import - - if "async_user" not in kwargs: - kwargs["async_user"] = get_current_user() - - # Track task call - dojo_async_task_counter.incr( - self.name, - args=args, - kwargs=kwargs, - ) - - return super().si(*args, **kwargs) - - def __call__(self, *args, **kwargs): - """ - Override __call__ to handle direct task calls with sync/async logic. - - This replicates the behavior of the dojo_async_task decorator wrapper. - """ - # In Celery worker execution, __call__ is how tasks actually run. - # We only want the sync/async decision when tasks are called directly - # from application code (task(...)), not when the worker is executing a message. - if not getattr(self.request, "called_directly", True): - return super().__call__(*args, **kwargs) - - from dojo.decorators import dojo_async_task_counter, we_want_async # noqa: PLC0415 circular import - from dojo.utils import get_current_user # noqa: PLC0415 circular import - - # Inject user context if not already present - if "async_user" not in kwargs: - kwargs["async_user"] = get_current_user() - - # Track task call + # Track dispatch dojo_async_task_counter.incr( self.name, args=args, kwargs=kwargs, ) - # Extract countdown if present (don't pass to sync execution) - countdown = kwargs.pop("countdown", 0) - - # Check if we should run async or sync - if we_want_async(*args, func=self, **kwargs): - # Mark as tracked to avoid double tracking in apply_async - self._dojo_tracked = True - try: - # Run asynchronously - return self.apply_async(args=args, kwargs=kwargs, countdown=countdown) - finally: - # Clean up the flag - delattr(self, "_dojo_tracked") - else: - # Run synchronously in-process, matching the original decorator behavior: func(*args, **kwargs) - # Remove sync from kwargs as it's a control flag, not a task argument. 
- kwargs.pop("sync", None) - return self.run(*args, **kwargs) + # Call parent to execute async + return super().apply_async(args=args, kwargs=kwargs, **options) @app.task(bind=True) diff --git a/dojo/celery_dispatch.py b/dojo/celery_dispatch.py new file mode 100644 index 00000000000..5bbf5a6ea7c --- /dev/null +++ b/dojo/celery_dispatch.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol, cast + +from celery.canvas import Signature + +if TYPE_CHECKING: + from collections.abc import Mapping + + +class _SupportsSi(Protocol): + def si(self, *args: Any, **kwargs: Any) -> Signature: ... + + +class _SupportsApplyAsync(Protocol): + def apply_async(self, args: Any | None = None, kwargs: Any | None = None, **options: Any) -> Any: ... + + +def _inject_async_user(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: + result: dict[str, Any] = dict(kwargs or {}) + if "async_user" not in result: + from dojo.utils import get_current_user # noqa: PLC0415 circular import + + result["async_user"] = get_current_user() + return result + + +def dojo_create_signature(task_or_sig: _SupportsSi | Signature, *args: Any, **kwargs: Any) -> Signature: + """ + Build a Celery signature with DefectDojo user context injected. + + - If passed a task, returns `task_or_sig.si(*args, **kwargs)`. + - If passed an existing signature, returns a cloned signature with merged kwargs. + """ + injected = _inject_async_user(kwargs) + injected.pop("countdown", None) + + if isinstance(task_or_sig, Signature): + merged_kwargs = {**(task_or_sig.kwargs or {}), **injected} + return task_or_sig.clone(kwargs=merged_kwargs) + + return task_or_sig.si(*args, **injected) + + +def dojo_dispatch_task(task_or_sig: _SupportsSi | _SupportsApplyAsync | Signature, *args: Any, **kwargs: Any) -> Any: + """ + Dispatch a task/signature using DefectDojo semantics. + + - Inject `async_user` if missing. + - Respect `sync=True` (foreground execution) and user `block_execution`. + - Support `countdown=` for async dispatch. + + Returns: + - async: AsyncResult-like return from Celery + - sync: underlying return value of the task + + """ + from dojo.decorators import dojo_async_task_counter, we_want_async # noqa: PLC0415 circular import + + countdown = cast("int", kwargs.pop("countdown", 0)) + injected = _inject_async_user(kwargs) + + sig = dojo_create_signature(task_or_sig if isinstance(task_or_sig, Signature) else cast("_SupportsSi", task_or_sig), *args, **injected) + sig_kwargs = dict(sig.kwargs or {}) + + if we_want_async(*sig.args, func=getattr(sig, "type", None), **sig_kwargs): + # DojoAsyncTask.apply_async tracks async dispatch. Avoid double-counting here. 
+ return sig.apply_async(countdown=countdown) + + # Track foreground execution as a "created task" as well (matches historical dojo_async_task behavior) + dojo_async_task_counter.incr(str(sig.task), args=sig.args, kwargs=sig_kwargs) + + sig_kwargs.pop("sync", None) + sig = sig.clone(kwargs=sig_kwargs) + eager = sig.apply() + return eager.get(propagate=True) diff --git a/dojo/endpoint/views.py b/dojo/endpoint/views.py index 1dc4df898c6..5d604eb9890 100644 --- a/dojo/endpoint/views.py +++ b/dojo/endpoint/views.py @@ -18,6 +18,7 @@ from dojo.authorization.authorization import user_has_permission_or_403 from dojo.authorization.authorization_decorators import user_is_authorized from dojo.authorization.roles_permissions import Permissions +from dojo.celery_dispatch import dojo_dispatch_task from dojo.endpoint.queries import get_authorized_endpoints from dojo.endpoint.utils import clean_hosts_run, endpoint_meta_import from dojo.filters import EndpointFilter, EndpointFilterWithoutObjectLookups @@ -373,7 +374,7 @@ def endpoint_bulk_update_all(request, pid=None): product_calc = list(Product.objects.filter(endpoint__id__in=endpoints_to_update).distinct()) endpoints.delete() for prod in product_calc: - calculate_grade(prod.id) + dojo_dispatch_task(calculate_grade, prod.id) if skipped_endpoint_count > 0: add_error_message_to_response(f"Skipped deletion of {skipped_endpoint_count} endpoints because you are not authorized.") diff --git a/dojo/engagement/services.py b/dojo/engagement/services.py index cd70af1ea2c..42a7c1c05e4 100644 --- a/dojo/engagement/services.py +++ b/dojo/engagement/services.py @@ -5,6 +5,7 @@ from django.dispatch import receiver import dojo.jira_link.helper as jira_helper +from dojo.celery_dispatch import dojo_dispatch_task from dojo.models import Engagement logger = logging.getLogger(__name__) @@ -16,7 +17,7 @@ def close_engagement(eng): eng.save() if jira_helper.get_jira_project(eng): - jira_helper.close_epic(eng.id, push_to_jira=True) + dojo_dispatch_task(jira_helper.close_epic, eng.id, push_to_jira=True) def reopen_engagement(eng): diff --git a/dojo/engagement/views.py b/dojo/engagement/views.py index ebc2e09ce6b..a1b417083fb 100644 --- a/dojo/engagement/views.py +++ b/dojo/engagement/views.py @@ -37,6 +37,7 @@ from dojo.authorization.authorization import user_has_permission_or_403 from dojo.authorization.authorization_decorators import user_is_authorized from dojo.authorization.roles_permissions import Permissions +from dojo.celery_dispatch import dojo_dispatch_task from dojo.endpoint.utils import save_endpoints_to_add from dojo.engagement.queries import get_authorized_engagements from dojo.engagement.services import close_engagement, reopen_engagement @@ -390,7 +391,7 @@ def copy_engagement(request, eid): form = DoneForm(request.POST) if form.is_valid(): engagement_copy = engagement.copy() - calculate_grade(product.id) + dojo_dispatch_task(calculate_grade, product.id) messages.add_message( request, messages.SUCCESS, diff --git a/dojo/finding/helper.py b/dojo/finding/helper.py index e7a93e90a65..65f4d2bf9d3 100644 --- a/dojo/finding/helper.py +++ b/dojo/finding/helper.py @@ -434,7 +434,9 @@ def post_process_finding_save_internal(finding, dedupe_option=True, rules_option if product_grading_option: if system_settings.enable_product_grade: - calculate_grade(finding.test.engagement.product.id) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(calculate_grade, finding.test.engagement.product.id) else: 
deduplicationLogger.debug("skipping product grading because it's disabled in system settings") @@ -493,7 +495,9 @@ def post_process_findings_batch(finding_ids, *args, dedupe_option=True, rules_op tool_issue_updater.async_tool_issue_update(finding) if product_grading_option and system_settings.enable_product_grade: - calculate_grade(findings[0].test.engagement.product.id) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(calculate_grade, findings[0].test.engagement.product.id) if push_to_jira: for finding in findings: diff --git a/dojo/finding/views.py b/dojo/finding/views.py index b5bfb593043..b3abe239cdb 100644 --- a/dojo/finding/views.py +++ b/dojo/finding/views.py @@ -38,6 +38,7 @@ user_is_authorized, ) from dojo.authorization.roles_permissions import Permissions +from dojo.celery_dispatch import dojo_dispatch_task from dojo.filters import ( AcceptedFindingFilter, AcceptedFindingFilterWithoutObjectLookups, @@ -1082,7 +1083,7 @@ def process_form(self, request: HttpRequest, finding: Finding, context: dict): product = finding.test.engagement.product finding.delete() # Update the grade of the product async - calculate_grade(product.id) + dojo_dispatch_task(calculate_grade, product.id) # Add a message to the request that the finding was successfully deleted messages.add_message( request, @@ -1353,7 +1354,7 @@ def copy_finding(request, fid): test = form.cleaned_data.get("test") product = finding.test.engagement.product finding_copy = finding.copy(test=test) - calculate_grade(product.id) + dojo_dispatch_task(calculate_grade, product.id) messages.add_message( request, messages.SUCCESS, diff --git a/dojo/finding_group/views.py b/dojo/finding_group/views.py index 451d4dcd720..e29c401b80d 100644 --- a/dojo/finding_group/views.py +++ b/dojo/finding_group/views.py @@ -16,6 +16,7 @@ from dojo.authorization.authorization import user_has_permission_or_403 from dojo.authorization.authorization_decorators import user_is_authorized from dojo.authorization.roles_permissions import Permissions +from dojo.celery_dispatch import dojo_dispatch_task from dojo.filters import ( FindingFilter, FindingFilterWithoutObjectLookups, @@ -100,7 +101,7 @@ def view_finding_group(request, fgid): elif not finding_group.has_jira_issue: jira_helper.finding_group_link_jira(request, finding_group, jira_issue) elif push_to_jira: - jira_helper.push_to_jira(finding_group, sync=True) + dojo_dispatch_task(jira_helper.push_to_jira, finding_group, sync=True) finding_group.save() return HttpResponseRedirect(reverse("view_test", args=(finding_group.test.id,))) @@ -200,7 +201,7 @@ def push_to_jira(request, fgid): # it may look like success here, but the push_to_jira are swallowing exceptions # but cant't change too much now without having a test suite, so leave as is for now with the addition warning message to check alerts for background errors. 
- if jira_helper.push_to_jira(group, sync=True): + if dojo_dispatch_task(jira_helper.push_to_jira, group, sync=True): messages.add_message( request, messages.SUCCESS, diff --git a/dojo/importers/base_importer.py b/dojo/importers/base_importer.py index 1c0b84687b1..f7fa29e3c3a 100644 --- a/dojo/importers/base_importer.py +++ b/dojo/importers/base_importer.py @@ -3,7 +3,6 @@ import time from collections.abc import Iterable -from celery import chord, group from django.conf import settings from django.core.exceptions import ValidationError from django.core.files.base import ContentFile @@ -13,7 +12,7 @@ from django.utils.timezone import make_aware import dojo.finding.helper as finding_helper -from dojo import utils +from dojo.celery_dispatch import dojo_dispatch_task from dojo.importers.endpoint_manager import EndpointManager from dojo.importers.options import ImporterOptions from dojo.models import ( @@ -28,7 +27,6 @@ Endpoint, FileUpload, Finding, - System_Settings, Test, Test_Import, Test_Import_Finding_Action, @@ -642,47 +640,6 @@ def update_test_type_from_internal_test(self, internal_test: ParserTest) -> None self.test.test_type.dynamic_tool = dynamic_tool self.test.test_type.save() - def maybe_launch_post_processing_chord( - self, - post_processing_task_signatures, - current_batch_number: int, - max_batch_size: int, - * - is_final_batch: bool, - ) -> tuple[list, int, bool]: - """ - Helper to optionally launch a chord of post-processing tasks with a calculate-grade callback - when async is desired. Uses exponential batch sizing up to the configured max batch size. - - Returns a tuple of (post_processing_task_signatures, current_batch_number, launched) - where launched indicates whether a chord/group was dispatched and signatures were reset. - """ - launched = False - if not post_processing_task_signatures: - return post_processing_task_signatures, current_batch_number, launched - - current_batch_size = min(2 ** current_batch_number, max_batch_size) - batch_full = len(post_processing_task_signatures) >= current_batch_size - - if batch_full or is_final_batch: - product = self.test.engagement.product - system_settings = System_Settings.objects.get() - if system_settings.enable_product_grade: - calculate_grade_signature = utils.calculate_grade.si(product.id) - chord(post_processing_task_signatures)(calculate_grade_signature) - else: - group(post_processing_task_signatures).apply_async() - - logger.debug( - f"Launched chord with {len(post_processing_task_signatures)} tasks (batch #{current_batch_number}, size: {len(post_processing_task_signatures)})", - ) - post_processing_task_signatures = [] - if not is_final_batch: - current_batch_number += 1 - launched = True - - return post_processing_task_signatures, current_batch_number, launched - def verify_tool_configuration_from_test(self): """ Verify that the Tool_Configuration supplied along with the @@ -922,7 +879,13 @@ def mitigate_finding( entry=note_message, ) # Mitigate the endpoint statuses - self.endpoint_manager.mitigate_endpoint_status(finding.status_finding.all(), self.user, kwuser=self.user, sync=True) + dojo_dispatch_task( + self.endpoint_manager.mitigate_endpoint_status, + finding.status_finding.all(), + self.user, + kwuser=self.user, + sync=True, + ) # to avoid pushing a finding group multiple times, we push those outside of the loop if finding_groups_enabled and finding.finding_group: # don't try to dedupe findings that we are closing diff --git a/dojo/importers/default_importer.py b/dojo/importers/default_importer.py index 
95c6adeed3c..edd88f74005 100644 --- a/dojo/importers/default_importer.py +++ b/dojo/importers/default_importer.py @@ -7,6 +7,7 @@ from django.urls import reverse import dojo.jira_link.helper as jira_helper +from dojo.celery_dispatch import dojo_dispatch_task from dojo.finding import helper as finding_helper from dojo.importers.base_importer import BaseImporter, Parser from dojo.importers.options import ImporterOptions @@ -254,7 +255,8 @@ def process_findings( batch_finding_ids.clear() logger.debug("process_findings: dispatching batch with push_to_jira=%s (batch_size=%d, is_final=%s)", push_to_jira, len(finding_ids_batch), is_final_finding) - finding_helper.post_process_findings_batch( + dojo_dispatch_task( + finding_helper.post_process_findings_batch, finding_ids_batch, dedupe_option=True, rules_option=True, diff --git a/dojo/importers/default_reimporter.py b/dojo/importers/default_reimporter.py index 1ddabf5e87f..ac49689cb6a 100644 --- a/dojo/importers/default_reimporter.py +++ b/dojo/importers/default_reimporter.py @@ -7,6 +7,7 @@ import dojo.finding.helper as finding_helper import dojo.jira_link.helper as jira_helper +from dojo.celery_dispatch import dojo_dispatch_task from dojo.finding.deduplication import ( find_candidates_for_deduplication_hash, find_candidates_for_deduplication_uid_or_hash, @@ -412,7 +413,8 @@ def process_findings( if len(batch_finding_ids) >= dedupe_batch_max_size or is_final: finding_ids_batch = list(batch_finding_ids) batch_finding_ids.clear() - finding_helper.post_process_findings_batch( + dojo_dispatch_task( + finding_helper.post_process_findings_batch, finding_ids_batch, dedupe_option=True, rules_option=True, diff --git a/dojo/importers/endpoint_manager.py b/dojo/importers/endpoint_manager.py index 3c1326c220c..6092ca82c77 100644 --- a/dojo/importers/endpoint_manager.py +++ b/dojo/importers/endpoint_manager.py @@ -5,6 +5,7 @@ from django.utils import timezone from dojo.celery import DojoAsyncTask, app +from dojo.celery_dispatch import dojo_dispatch_task from dojo.endpoint.utils import endpoint_get_or_create from dojo.models import ( Dojo_User, @@ -112,7 +113,7 @@ def chunk_endpoints_and_disperse( endpoints: list[Endpoint], **kwargs: dict, ) -> None: - self.add_endpoints_to_unsaved_finding(finding, endpoints, sync=True) + dojo_dispatch_task(self.add_endpoints_to_unsaved_finding, finding, endpoints, sync=True) @staticmethod def clean_unsaved_endpoints( @@ -133,7 +134,7 @@ def chunk_endpoints_and_reactivate( endpoint_status_list: list[Endpoint_Status], **kwargs: dict, ) -> None: - self.reactivate_endpoint_status(endpoint_status_list, sync=True) + dojo_dispatch_task(self.reactivate_endpoint_status, endpoint_status_list, sync=True) def chunk_endpoints_and_mitigate( self, @@ -141,7 +142,7 @@ def chunk_endpoints_and_mitigate( user: Dojo_User, **kwargs: dict, ) -> None: - self.mitigate_endpoint_status(endpoint_status_list, user, sync=True) + dojo_dispatch_task(self.mitigate_endpoint_status, endpoint_status_list, user, sync=True) def update_endpoint_status( self, diff --git a/dojo/jira_link/helper.py b/dojo/jira_link/helper.py index e864121d57a..513acf4ef5b 100644 --- a/dojo/jira_link/helper.py +++ b/dojo/jira_link/helper.py @@ -18,6 +18,7 @@ from requests.auth import HTTPBasicAuth from dojo.celery import DojoAsyncTask, app +from dojo.celery_dispatch import dojo_dispatch_task from dojo.forms import JIRAEngagementForm, JIRAProjectForm from dojo.models import ( Engagement, @@ -759,14 +760,14 @@ def push_to_jira(obj, *args, **kwargs): if isinstance(obj, Finding): if 
obj.has_finding_group: logger.debug("pushing finding group for %s to JIRA", obj) - return push_finding_group_to_jira(obj.finding_group.id, *args, **kwargs) - return push_finding_to_jira(obj.id, *args, **kwargs) + return dojo_dispatch_task(push_finding_group_to_jira, obj.finding_group.id, *args, **kwargs) + return dojo_dispatch_task(push_finding_to_jira, obj.id, *args, **kwargs) if isinstance(obj, Finding_Group): - return push_finding_group_to_jira(obj.id, *args, **kwargs) + return dojo_dispatch_task(push_finding_group_to_jira, obj.id, *args, **kwargs) if isinstance(obj, Engagement): - return push_engagement_to_jira(obj.id, *args, **kwargs) + return dojo_dispatch_task(push_engagement_to_jira, obj.id, *args, **kwargs) logger.error("unsupported object passed to push_to_jira: %s %i %s", obj.__name__, obj.id, obj) return None @@ -808,8 +809,8 @@ def push_engagement_to_jira(engagement_id, *args, **kwargs): return None if engagement.has_jira_issue: - return update_epic(engagement.id, *args, **kwargs) - return add_epic(engagement.id, *args, **kwargs) + return dojo_dispatch_task(update_epic, engagement.id, *args, **kwargs) + return dojo_dispatch_task(add_epic, engagement.id, *args, **kwargs) def add_issues_to_epic(jira, obj, epic_id, issue_keys, *, ignore_epics=True): @@ -1574,7 +1575,7 @@ def add_comment(obj, note, *, force_push=False, **kwargs): return False # Call the internal task with IDs (runs synchronously within this task) - return add_comment_internal(jira_issue.id, note.id, force_push=force_push, **kwargs) + return dojo_dispatch_task(add_comment_internal, jira_issue.id, note.id, force_push=force_push, **kwargs) @app.task(base=DojoAsyncTask) diff --git a/dojo/management/commands/dedupe.py b/dojo/management/commands/dedupe.py index 913c528f299..b7e0e669157 100644 --- a/dojo/management/commands/dedupe.py +++ b/dojo/management/commands/dedupe.py @@ -118,14 +118,25 @@ def _run_dedupe(self, *, restrict_to_parsers, hash_code_only, dedupe_only, dedup mass_model_updater(Finding, findings, do_dedupe_finding_task_internal, fields=None, order="desc", page_size=100, log_prefix="deduplicating ") else: # async tasks only need the id - mass_model_updater(Finding, findings.only("id"), lambda f: do_dedupe_finding_task(f.id), fields=None, order="desc", log_prefix="deduplicating ") + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + mass_model_updater( + Finding, + findings.only("id"), + lambda f: dojo_dispatch_task(do_dedupe_finding_task, f.id), + fields=None, + order="desc", + log_prefix="deduplicating ", + ) if dedupe_sync: # update the grading (if enabled) and only useful in sync mode # in async mode the background task that grades products every hour will pick it up logger.debug("Updating grades for products...") for product in Product.objects.all(): - calculate_grade(product.id) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(calculate_grade, product.id) logger.info("######## Done deduplicating (%s) ########", ("foreground" if dedupe_sync else "tasks submitted to celery")) else: @@ -172,7 +183,9 @@ def _dedupe_batch_mode(self, findings_queryset, *, dedupe_sync: bool = True): else: # Asynchronous: submit task with finding IDs logger.debug(f"Submitting async batch task for {len(batch_finding_ids)} findings for test {test_id}") - do_dedupe_batch_task(batch_finding_ids) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(do_dedupe_batch_task, 
batch_finding_ids) total_processed += len(batch_finding_ids) batch_finding_ids = [] diff --git a/dojo/models.py b/dojo/models.py index 0e4680de67d..95abdb64fed 100644 --- a/dojo/models.py +++ b/dojo/models.py @@ -1095,7 +1095,9 @@ def save(self, *args, **kwargs): super(Product, product).save() # launch the async task to update all finding sla expiration dates from dojo.sla_config.helpers import async_update_sla_expiration_dates_sla_config_sync # noqa: I001, PLC0415 circular import - async_update_sla_expiration_dates_sla_config_sync(self, products, severities=severities) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(async_update_sla_expiration_dates_sla_config_sync, self, products, severities=severities) def clean(self): sla_days = [self.critical, self.high, self.medium, self.low] @@ -1257,7 +1259,9 @@ def save(self, *args, **kwargs): super(SLA_Configuration, sla_config).save() # launch the async task to update all finding sla expiration dates from dojo.sla_config.helpers import async_update_sla_expiration_dates_sla_config_sync # noqa: I001, PLC0415 circular import - async_update_sla_expiration_dates_sla_config_sync(sla_config, Product.objects.filter(id=self.id)) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(async_update_sla_expiration_dates_sla_config_sync, sla_config, Product.objects.filter(id=self.id)) def get_absolute_url(self): return reverse("view_product", args=[str(self.id)]) diff --git a/dojo/notifications/helper.py b/dojo/notifications/helper.py index b422896e83d..4f1fe992875 100644 --- a/dojo/notifications/helper.py +++ b/dojo/notifications/helper.py @@ -18,6 +18,7 @@ from dojo import __version__ as dd_version from dojo.authorization.roles_permissions import Permissions from dojo.celery import DojoAsyncTask, app +from dojo.celery_dispatch import dojo_dispatch_task from dojo.decorators import we_want_async from dojo.labels import get_labels from dojo.models import ( @@ -828,7 +829,8 @@ def _process_notifications( notifications.other, ): logger.debug("Sending Slack Notification") - self._get_manager_instance("slack").send_slack_notification( + dojo_dispatch_task( + self._get_manager_instance("slack").send_slack_notification, event, user=notifications.user, **kwargs, @@ -840,7 +842,8 @@ def _process_notifications( notifications.other, ): logger.debug("Sending MSTeams Notification") - self._get_manager_instance("msteams").send_msteams_notification( + dojo_dispatch_task( + self._get_manager_instance("msteams").send_msteams_notification, event, user=notifications.user, **kwargs, @@ -852,7 +855,8 @@ def _process_notifications( notifications.other, ): logger.debug("Sending Mail Notification") - self._get_manager_instance("mail").send_mail_notification( + dojo_dispatch_task( + self._get_manager_instance("mail").send_mail_notification, event, user=notifications.user, **kwargs, @@ -864,7 +868,8 @@ def _process_notifications( notifications.other, ): logger.debug("Sending Webhooks Notification") - self._get_manager_instance("webhooks").send_webhooks_notification( + dojo_dispatch_task( + self._get_manager_instance("webhooks").send_webhooks_notification, event, user=notifications.user, **kwargs, diff --git a/dojo/tags_signals.py b/dojo/tags_signals.py index 0cade958265..6b11b00644c 100644 --- a/dojo/tags_signals.py +++ b/dojo/tags_signals.py @@ -4,6 +4,7 @@ from django.db.models import signals from django.dispatch import receiver +from dojo.celery_dispatch import 
dojo_dispatch_task from dojo.models import Endpoint, Engagement, Finding, Product, Test from dojo.product import helpers as async_product_funcs from dojo.utils import get_system_setting @@ -19,7 +20,7 @@ def product_tags_post_add_remove(sender, instance, action, **kwargs): running_async_process = instance.running_async_process # Check if the async process is already running to avoid calling it a second time if not running_async_process and inherit_product_tags(instance): - async_product_funcs.propagate_tags_on_product(instance.id, countdown=5) + dojo_dispatch_task(async_product_funcs.propagate_tags_on_product, instance.id, countdown=5) instance.running_async_process = True diff --git a/dojo/tasks.py b/dojo/tasks.py index d5b904601d7..50d471cf68d 100644 --- a/dojo/tasks.py +++ b/dojo/tasks.py @@ -12,6 +12,7 @@ from dojo.auditlog import run_flush_auditlog from dojo.celery import DojoAsyncTask, app +from dojo.celery_dispatch import dojo_dispatch_task from dojo.finding.helper import fix_loop_duplicates from dojo.management.commands.jira_status_reconciliation import jira_status_reconciliation from dojo.models import Alerts, Announcement, Endpoint, Engagement, Finding, Product, System_Settings, User @@ -71,7 +72,7 @@ def add_alerts(self, runinterval): if system_settings.enable_product_grade: products = Product.objects.all() for product in products: - calculate_grade(product.id) + dojo_dispatch_task(calculate_grade, product.id) @app.task(bind=True) @@ -168,7 +169,7 @@ def _async_dupe_delete_impl(): if system_settings.enable_product_grade: logger.info("performing batch product grading for %s products", len(affected_products)) for product in affected_products: - calculate_grade(product.id) + dojo_dispatch_task(calculate_grade, product.id) @app.task(ignore_result=False) diff --git a/dojo/templatetags/display_tags.py b/dojo/templatetags/display_tags.py index f19c704fd55..e6a85adc4f3 100644 --- a/dojo/templatetags/display_tags.py +++ b/dojo/templatetags/display_tags.py @@ -304,7 +304,9 @@ def product_grade(product): if system_settings.enable_product_grade and product: prod_numeric_grade = product.prod_numeric_grade if not prod_numeric_grade or prod_numeric_grade is None: - calculate_grade(product.id) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(calculate_grade, product.id) if prod_numeric_grade: if prod_numeric_grade >= system_settings.product_grade_a: grade = "A" diff --git a/dojo/test/views.py b/dojo/test/views.py index d37825822c6..9d0e593152e 100644 --- a/dojo/test/views.py +++ b/dojo/test/views.py @@ -26,6 +26,7 @@ from dojo.authorization.authorization import user_has_permission_or_403 from dojo.authorization.authorization_decorators import user_is_authorized from dojo.authorization.roles_permissions import Permissions +from dojo.celery_dispatch import dojo_dispatch_task from dojo.engagement.queries import get_authorized_engagements from dojo.filters import FindingFilter, FindingFilterWithoutObjectLookups, TemplateFindingFilter, TestImportFilter from dojo.finding.queries import prefetch_for_findings @@ -343,7 +344,7 @@ def copy_test(request, tid): engagement = form.cleaned_data.get("engagement") product = test.engagement.product test_copy = test.copy(engagement=engagement) - calculate_grade(product.id) + dojo_dispatch_task(calculate_grade, product.id) messages.add_message( request, messages.SUCCESS, diff --git a/dojo/tools/tool_issue_updater.py b/dojo/tools/tool_issue_updater.py index 26cfcf6d973..93e6d93857f 100644 --- 
a/dojo/tools/tool_issue_updater.py +++ b/dojo/tools/tool_issue_updater.py @@ -3,6 +3,7 @@ import pghistory from dojo.celery import DojoAsyncTask, app +from dojo.celery_dispatch import dojo_dispatch_task from dojo.models import Finding from dojo.tools.api_sonarqube.parser import SCAN_SONARQUBE_API from dojo.tools.api_sonarqube.updater import SonarQubeApiUpdater @@ -14,7 +15,7 @@ def async_tool_issue_update(finding, *args, **kwargs): if is_tool_issue_updater_needed(finding): - tool_issue_updater(finding.id) + dojo_dispatch_task(tool_issue_updater, finding.id) def is_tool_issue_updater_needed(finding, *args, **kwargs): diff --git a/dojo/utils.py b/dojo/utils.py index 33347c82960..0806af4ce5c 100644 --- a/dojo/utils.py +++ b/dojo/utils.py @@ -1291,7 +1291,9 @@ def calculate_grade_internal(product, *args, **kwargs): def perform_product_grading(product): system_settings = System_Settings.objects.get() if system_settings.enable_product_grade: - calculate_grade(product.id) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(calculate_grade, product.id) def get_celery_worker_status(): diff --git a/unittests/test_importers_performance.py b/unittests/test_importers_performance.py index db43013a129..1f08b443899 100644 --- a/unittests/test_importers_performance.py +++ b/unittests/test_importers_performance.py @@ -265,20 +265,12 @@ def test_import_reimport_reimport_performance_pghistory_async(self): configure_pghistory_triggers() self._import_reimport_performance( - - - - expected_num_queries1=305, expected_num_async_tasks1=6, expected_num_queries2=232, expected_num_async_tasks2=17, expected_num_queries3=114, expected_num_async_tasks3=16, - - - - ) @override_settings(ENABLE_AUDITLOG=True) @@ -295,14 +287,12 @@ def test_import_reimport_reimport_performance_pghistory_no_async(self): testuser.usercontactinfo.save() self._import_reimport_performance( - expected_num_queries1=312, expected_num_async_tasks1=6, expected_num_queries2=239, expected_num_async_tasks2=17, expected_num_queries3=121, expected_num_async_tasks3=16, - ) @override_settings(ENABLE_AUDITLOG=True) @@ -320,16 +310,12 @@ def test_import_reimport_reimport_performance_pghistory_no_async_with_product_gr self.system_settings(enable_product_grade=True) self._import_reimport_performance( - - expected_num_queries1=316, expected_num_async_tasks1=8, expected_num_queries2=243, expected_num_async_tasks2=19, expected_num_queries3=125, expected_num_async_tasks3=18, - - ) # Deduplication is enabled in the tests above, but to properly test it we must run the same import twice and capture the results. 
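As a usage note for the dispatch helper wired in throughout these hunks, here is a sketch of the call patterns it supports; `product` stands in for any Product instance and is not defined in this patch.

from dojo.celery_dispatch import dojo_dispatch_task
from dojo.product.helpers import propagate_tags_on_product
from dojo.utils import calculate_grade

# Background dispatch; whether it really runs async still depends on we_want_async().
dojo_dispatch_task(calculate_grade, product.id)

# Foreground execution that returns the task's result (the sync=True pattern
# used by several call sites above).
dojo_dispatch_task(calculate_grade, product.id, sync=True)

# Delayed dispatch, as used for tag propagation.
dojo_dispatch_task(propagate_tags_on_product, product.id, countdown=5)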
@@ -468,11 +454,8 @@ def test_deduplication_performance_pghistory_no_async(self): testuser.usercontactinfo.save() self._deduplication_performance( - expected_num_queries1=281, expected_num_async_tasks1=7, expected_num_queries2=246, expected_num_async_tasks2=7, - - ) From 2869813f760023184d15fa4e4fe1bdcb4042361d Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Sat, 27 Dec 2025 09:56:06 +0100 Subject: [PATCH 15/36] fix notifications --- dojo/notifications/helper.py | 58 +++++++++++++------ unittests/test_jira_import_and_pushing_api.py | 2 +- 2 files changed, 40 insertions(+), 20 deletions(-) diff --git a/dojo/notifications/helper.py b/dojo/notifications/helper.py index 4f1fe992875..dc4cb434c1a 100644 --- a/dojo/notifications/helper.py +++ b/dojo/notifications/helper.py @@ -200,7 +200,6 @@ class SlackNotificationManger(NotificationManagerHelpers): """Manger for slack notifications and their helpers.""" - @app.task(base=DojoAsyncTask) def send_slack_notification( self, event: str, @@ -317,7 +316,6 @@ class MSTeamsNotificationManger(NotificationManagerHelpers): """Manger for Microsoft Teams notifications and their helpers.""" - @app.task(base=DojoAsyncTask) def send_msteams_notification( self, event: str, @@ -367,7 +365,6 @@ class EmailNotificationManger(NotificationManagerHelpers): """Manger for email notifications and their helpers.""" - @app.task(base=DojoAsyncTask) def send_mail_notification( self, event: str, @@ -418,7 +415,6 @@ class WebhookNotificationManger(NotificationManagerHelpers): ERROR_PERMANENT = "permanent" ERROR_TEMPORARY = "temporary" - @app.task(base=DojoAsyncTask) def send_webhooks_notification( self, event: str, @@ -477,11 +473,7 @@ def send_webhooks_notification( endpoint.first_error = now endpoint.status = Notification_Webhooks.Status.STATUS_INACTIVE_TMP # In case of failure within one day, endpoint can be deactivated temporally only for one minute - self._webhook_reactivation.apply_async( - args=[self], - kwargs={"endpoint_id": endpoint.pk}, - countdown=60, - ) + webhook_reactivation.apply_async(kwargs={"endpoint_id": endpoint.pk}, countdown=60) # There is no reason to keep endpoint active if it is returning 4xx errors else: endpoint.status = Notification_Webhooks.Status.STATUS_INACTIVE_PERMANENT @@ -556,7 +548,6 @@ def _test_webhooks_notification(self, endpoint: Notification_Webhooks) -> None: # in "send_webhooks_notification", we are doing deeper analysis, why it failed # for now, "raise_for_status" should be enough - @app.task(ignore_result=True) def _webhook_reactivation(self, endpoint_id: int, **_kwargs: dict): endpoint = Notification_Webhooks.objects.get(pk=endpoint_id) # User already changed status of endpoint @@ -830,9 +821,9 @@ def _process_notifications( ): logger.debug("Sending Slack Notification") dojo_dispatch_task( - self._get_manager_instance("slack").send_slack_notification, + send_slack_notification, event, - user=notifications.user, + user_id=getattr(notifications.user, "id", None), **kwargs, ) @@ -843,9 +834,9 @@ def _process_notifications( ): logger.debug("Sending MSTeams Notification") dojo_dispatch_task( - self._get_manager_instance("msteams").send_msteams_notification, + send_msteams_notification, event, - user=notifications.user, + user_id=getattr(notifications.user, "id", None), **kwargs, ) @@ -856,9 +847,9 @@ def _process_notifications( ): logger.debug("Sending Mail Notification") dojo_dispatch_task( - self._get_manager_instance("mail").send_mail_notification, + send_mail_notification, event, - user=notifications.user, + 
user_id=getattr(notifications.user, "id", None), **kwargs, ) @@ -869,13 +860,42 @@ def _process_notifications( ): logger.debug("Sending Webhooks Notification") dojo_dispatch_task( - self._get_manager_instance("webhooks").send_webhooks_notification, + send_webhooks_notification, event, - user=notifications.user, + user_id=getattr(notifications.user, "id", None), **kwargs, ) +@app.task(base=DojoAsyncTask) +def send_slack_notification(event: str, user_id: int | None = None, **kwargs: dict) -> None: + user = Dojo_User.objects.get(pk=user_id) if user_id else None + SlackNotificationManger().send_slack_notification(event, user=user, **kwargs) + + +@app.task(base=DojoAsyncTask) +def send_msteams_notification(event: str, user_id: int | None = None, **kwargs: dict) -> None: + user = Dojo_User.objects.get(pk=user_id) if user_id else None + MSTeamsNotificationManger().send_msteams_notification(event, user=user, **kwargs) + + +@app.task(base=DojoAsyncTask) +def send_mail_notification(event: str, user_id: int | None = None, **kwargs: dict) -> None: + user = Dojo_User.objects.get(pk=user_id) if user_id else None + EmailNotificationManger().send_mail_notification(event, user=user, **kwargs) + + +@app.task(base=DojoAsyncTask) +def send_webhooks_notification(event: str, user_id: int | None = None, **kwargs: dict) -> None: + user = Dojo_User.objects.get(pk=user_id) if user_id else None + WebhookNotificationManger().send_webhooks_notification(event, user=user, **kwargs) + + +@app.task(ignore_result=True) +def webhook_reactivation(endpoint_id: int, **_kwargs: dict) -> None: + WebhookNotificationManger()._webhook_reactivation(endpoint_id=endpoint_id) + + @app.task(ignore_result=True) def webhook_status_cleanup(*_args: list, **_kwargs: dict): # If some endpoint was affected by some outage (5xx, 429, Timeout) but it was clean during last 24 hours, @@ -903,4 +923,4 @@ def webhook_status_cleanup(*_args: list, **_kwargs: dict): ) for endpoint in broken_endpoints: manager = WebhookNotificationManger() - manager._webhook_reactivation(manager, endpoint_id=endpoint.pk) + manager._webhook_reactivation(endpoint_id=endpoint.pk) diff --git a/unittests/test_jira_import_and_pushing_api.py b/unittests/test_jira_import_and_pushing_api.py index ee0808b3ca8..fc165118b90 100644 --- a/unittests/test_jira_import_and_pushing_api.py +++ b/unittests/test_jira_import_and_pushing_api.py @@ -971,7 +971,7 @@ def test_engagement_epic_mapping_disabled_no_epic_and_push_findings(self): @patch("dojo.jira_link.helper.can_be_pushed_to_jira", return_value=(True, None, None)) @patch("dojo.jira_link.helper.is_push_all_issues", return_value=False) @patch("dojo.jira_link.helper.push_to_jira", return_value=None) - @patch("dojo.notifications.helper.WebhookNotificationManger.send_webhooks_notification") + @patch("dojo.notifications.helper.send_webhooks_notification") def test_bulk_edit_mixed_findings_and_groups_jira_push_bug(self, mock_webhooks, mock_push_to_jira, mock_is_push_all_issues, mock_can_be_pushed): """ Test the bug in bulk edit: when bulk editing findings where some are in groups From 917aa72f043e286c86118905e3cfb923b086fb87 Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Sat, 27 Dec 2025 21:05:18 +0100 Subject: [PATCH 16/36] fix test --- unittests/test_notifications.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unittests/test_notifications.py b/unittests/test_notifications.py index 7c5b289a211..c1956d82679 100644 --- a/unittests/test_notifications.py +++ b/unittests/test_notifications.py @@ -649,7 +649,7 @@ 
def test_webhook_reactivation(self): with self.subTest("active"): wh = Notification_Webhooks.objects.filter(owner=None).first() manager = WebhookNotificationManger() - manager._webhook_reactivation(manager, endpoint_id=wh.pk) + manager._webhook_reactivation(endpoint_id=wh.pk) updated_wh = Notification_Webhooks.objects.filter(owner=None).first() self.assertEqual(updated_wh.status, Notification_Webhooks.Status.STATUS_ACTIVE) @@ -668,7 +668,7 @@ def test_webhook_reactivation(self): with self.assertLogs("dojo.notifications.helper", level="DEBUG") as cm: manager = WebhookNotificationManger() - manager._webhook_reactivation(manager, endpoint_id=wh.pk) + manager._webhook_reactivation(endpoint_id=wh.pk) updated_wh = Notification_Webhooks.objects.filter(owner=None).first() self.assertEqual(updated_wh.status, Notification_Webhooks.Status.STATUS_ACTIVE_TMP) From a33fb20ad50c726a079bc97e1b21cde760f5ba54 Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Mon, 5 Jan 2026 20:36:18 +0100 Subject: [PATCH 17/36] fix test --- unittests/test_deduplication_logic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unittests/test_deduplication_logic.py b/unittests/test_deduplication_logic.py index a941de1f2c4..56deb1e96d1 100644 --- a/unittests/test_deduplication_logic.py +++ b/unittests/test_deduplication_logic.py @@ -1075,6 +1075,7 @@ def test_multiple_findings_same_unique_id_mixed_states_unique_id(self): finding_new.is_mitigated = False finding_new.unique_id_from_tool = original_unique_id finding_new.save() + finding_new.refresh_from_db() # The new finding should be marked as duplicate of the active finding, # not the mitigated one (even though mitigated has lower ID) From e43ef4078006d323d4fa4b1ebba3a4bdcd27aa6f Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Wed, 7 Jan 2026 18:15:48 +0100 Subject: [PATCH 18/36] pghistory inherits from dojoasynctask --- dojo/celery.py | 67 ++++++++++++++++-------------- dojo/celery_dispatch.py | 16 ++++++- dojo/finding/deduplication.py | 6 +-- dojo/finding/helper.py | 6 +-- dojo/importers/endpoint_manager.py | 8 ++-- dojo/jira_link/helper.py | 16 +++---- dojo/notifications/helper.py | 10 ++--- dojo/product/helpers.py | 4 +- dojo/sla_config/helpers.py | 4 +- dojo/tasks.py | 4 +- dojo/tools/tool_issue_updater.py | 6 +-- dojo/utils.py | 18 ++++---- 12 files changed, 91 insertions(+), 74 deletions(-) diff --git a/dojo/celery.py b/dojo/celery.py index 3079d25d20e..3cf09e1bc2c 100644 --- a/dojo/celery.py +++ b/dojo/celery.py @@ -12,38 +12,6 @@ os.environ.setdefault("DJANGO_SETTINGS_MODULE", "dojo.settings.settings") -class PgHistoryTask(Task): - - """ - Custom Celery base task that automatically applies pghistory context. - - When a task is dispatched via dojo_async_task, the current pghistory - context is captured and passed in kwargs as "_pgh_context". This base - class extracts that context and applies it before running the task, - ensuring all database events share the same context as the original - request. - """ - - def __call__(self, *args, **kwargs): - # Import here to avoid circular imports during Celery startup - from dojo.pghistory_utils import get_pghistory_context_manager # noqa: PLC0415 - - # Extract context from kwargs (won't be passed to task function) - pgh_context = kwargs.pop("_pgh_context", None) - - with get_pghistory_context_manager(pgh_context): - return super().__call__(*args, **kwargs) - - -app = Celery("dojo", task_cls=PgHistoryTask) - -# Using a string here means the worker will not have to -# pickle the object when using Windows. 
-app.config_from_object("django.conf:settings", namespace="CELERY") - -app.autodiscover_tasks(lambda: settings.INSTALLED_APPS) - - class DojoAsyncTask(Task): """ @@ -81,6 +49,41 @@ def apply_async(self, args=None, kwargs=None, **options): return super().apply_async(args=args, kwargs=kwargs, **options) +class PgHistoryTask(DojoAsyncTask): + + """ + Custom Celery base task that automatically applies pghistory context. + + This class inherits from DojoAsyncTask to provide: + - User context injection and task tracking (from DojoAsyncTask) + - Automatic pghistory context application (from this class) + + When a task is dispatched via dojo_dispatch_task or dojo_async_task, the current + pghistory context is captured and passed in kwargs as "_pgh_context". This base + class extracts that context and applies it before running the task, ensuring all + database events share the same context as the original request. + """ + + def __call__(self, *args, **kwargs): + # Import here to avoid circular imports during Celery startup + from dojo.pghistory_utils import get_pghistory_context_manager # noqa: PLC0415 + + # Extract context from kwargs (won't be passed to task function) + pgh_context = kwargs.pop("_pgh_context", None) + + with get_pghistory_context_manager(pgh_context): + return super().__call__(*args, **kwargs) + + +app = Celery("dojo", task_cls=PgHistoryTask) + +# Using a string here means the worker will not have to +# pickle the object when using Windows. +app.config_from_object("django.conf:settings", namespace="CELERY") + +app.autodiscover_tasks(lambda: settings.INSTALLED_APPS) + + @app.task(bind=True) def debug_task(self): logger.info(f"Request: {self.request!r}") diff --git a/dojo/celery_dispatch.py b/dojo/celery_dispatch.py index 5bbf5a6ea7c..a96c4257553 100644 --- a/dojo/celery_dispatch.py +++ b/dojo/celery_dispatch.py @@ -25,14 +25,26 @@ def _inject_async_user(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: return result +def _inject_pghistory_context(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: + """Capture and inject pghistory context if available.""" + result: dict[str, Any] = dict(kwargs or {}) + if "_pgh_context" not in result: + from dojo.pghistory_utils import get_serializable_pghistory_context # noqa: PLC0415 circular import + + if pgh_context := get_serializable_pghistory_context(): + result["_pgh_context"] = pgh_context + return result + + def dojo_create_signature(task_or_sig: _SupportsSi | Signature, *args: Any, **kwargs: Any) -> Signature: """ - Build a Celery signature with DefectDojo user context injected. + Build a Celery signature with DefectDojo user context and pghistory context injected. - If passed a task, returns `task_or_sig.si(*args, **kwargs)`. - If passed an existing signature, returns a cloned signature with merged kwargs. """ injected = _inject_async_user(kwargs) + injected = _inject_pghistory_context(injected) injected.pop("countdown", None) if isinstance(task_or_sig, Signature): @@ -47,6 +59,7 @@ def dojo_dispatch_task(task_or_sig: _SupportsSi | _SupportsApplyAsync | Signatur Dispatch a task/signature using DefectDojo semantics. - Inject `async_user` if missing. + - Capture and inject pghistory context if available. - Respect `sync=True` (foreground execution) and user `block_execution`. - Support `countdown=` for async dispatch. 
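The docstring above describes the pghistory handoff; a short sketch of how it plays out end to end (the `source` value and `product` are assumptions for illustration only):

import pghistory

from dojo.celery_dispatch import dojo_dispatch_task
from dojo.utils import calculate_grade

with pghistory.context(source="bulk_edit"):
    # dojo_dispatch_task captures the active pghistory context into the
    # "_pgh_context" kwarg; on the worker, PgHistoryTask pops it and re-enters
    # the same context, so audit events written by the task match the request.
    dojo_dispatch_task(calculate_grade, product.id)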
@@ -59,6 +72,7 @@ def dojo_dispatch_task(task_or_sig: _SupportsSi | _SupportsApplyAsync | Signatur countdown = cast("int", kwargs.pop("countdown", 0)) injected = _inject_async_user(kwargs) + injected = _inject_pghistory_context(injected) sig = dojo_create_signature(task_or_sig if isinstance(task_or_sig, Signature) else cast("_SupportsSi", task_or_sig), *args, **injected) sig_kwargs = dict(sig.kwargs or {}) diff --git a/dojo/finding/deduplication.py b/dojo/finding/deduplication.py index 1d778e1c9b9..301884349a2 100644 --- a/dojo/finding/deduplication.py +++ b/dojo/finding/deduplication.py @@ -7,7 +7,7 @@ from django.db.models import Prefetch from django.db.models.query_utils import Q -from dojo.celery import DojoAsyncTask, app +from dojo.celery import app from dojo.models import Finding, System_Settings logger = logging.getLogger(__name__) @@ -44,12 +44,12 @@ def get_finding_models_for_deduplication(finding_ids): ) -@app.task(base=DojoAsyncTask) +@app.task def do_dedupe_finding_task(new_finding_id, *args, **kwargs): return do_dedupe_finding_task_internal(Finding.objects.get(id=new_finding_id), *args, **kwargs) -@app.task(base=DojoAsyncTask) +@app.task def do_dedupe_batch_task(finding_ids, *args, **kwargs): """ Async task to deduplicate a batch of findings. The findings are assumed to be in the same test. diff --git a/dojo/finding/helper.py b/dojo/finding/helper.py index 65f4d2bf9d3..03e5fd409a2 100644 --- a/dojo/finding/helper.py +++ b/dojo/finding/helper.py @@ -15,7 +15,7 @@ import dojo.jira_link.helper as jira_helper import dojo.risk_acceptance.helper as ra_helper -from dojo.celery import DojoAsyncTask, app +from dojo.celery import app from dojo.endpoint.utils import endpoint_get_or_create, save_endpoints_to_add from dojo.file_uploads.helper import delete_related_files from dojo.finding.deduplication import ( @@ -390,7 +390,7 @@ def add_findings_to_auto_group(name, findings, group_by, *, create_finding_group finding_group.findings.add(*findings) -@app.task(base=DojoAsyncTask) +@app.task def post_process_finding_save(finding_id, dedupe_option=True, rules_option=True, product_grading_option=True, # noqa: FBT002 issue_updater_option=True, push_to_jira=False, user=None, *args, **kwargs): # noqa: FBT002 - this is bit hard to fix nice have this universally fixed finding = get_object_or_none(Finding, id=finding_id) @@ -453,7 +453,7 @@ def post_process_finding_save_internal(finding, dedupe_option=True, rules_option jira_helper.push_to_jira(finding.finding_group) -@app.task(base=DojoAsyncTask) +@app.task def post_process_findings_batch(finding_ids, *args, dedupe_option=True, rules_option=True, product_grading_option=True, issue_updater_option=True, push_to_jira=False, user=None, **kwargs): diff --git a/dojo/importers/endpoint_manager.py b/dojo/importers/endpoint_manager.py index 6092ca82c77..5fdd4603aad 100644 --- a/dojo/importers/endpoint_manager.py +++ b/dojo/importers/endpoint_manager.py @@ -4,7 +4,7 @@ from django.urls import reverse from django.utils import timezone -from dojo.celery import DojoAsyncTask, app +from dojo.celery import app from dojo.celery_dispatch import dojo_dispatch_task from dojo.endpoint.utils import endpoint_get_or_create from dojo.models import ( @@ -18,7 +18,7 @@ class EndpointManager: - @app.task(base=DojoAsyncTask) + @app.task def add_endpoints_to_unsaved_finding( finding: Finding, # noqa: N805 endpoints: list[Endpoint], @@ -57,7 +57,7 @@ def add_endpoints_to_unsaved_finding( logger.debug(f"IMPORT_SCAN: {len(endpoints)} endpoints imported") - 
@app.task(base=DojoAsyncTask) + @app.task def mitigate_endpoint_status( endpoint_status_list: list[Endpoint_Status], # noqa: N805 user: Dojo_User, @@ -82,7 +82,7 @@ def mitigate_endpoint_status( batch_size=1000, ) - @app.task(base=DojoAsyncTask) + @app.task def reactivate_endpoint_status( endpoint_status_list: list[Endpoint_Status], # noqa: N805 **kwargs: dict, diff --git a/dojo/jira_link/helper.py b/dojo/jira_link/helper.py index 513acf4ef5b..7154c54a3a6 100644 --- a/dojo/jira_link/helper.py +++ b/dojo/jira_link/helper.py @@ -17,7 +17,7 @@ from jira.exceptions import JIRAError from requests.auth import HTTPBasicAuth -from dojo.celery import DojoAsyncTask, app +from dojo.celery import app from dojo.celery_dispatch import dojo_dispatch_task from dojo.forms import JIRAEngagementForm, JIRAProjectForm from dojo.models import ( @@ -773,7 +773,7 @@ def push_to_jira(obj, *args, **kwargs): # we need thre separate celery tasks due to the decorators we're using to map to/from ids -@app.task(base=DojoAsyncTask) +@app.task def push_finding_to_jira(finding_id, *args, **kwargs): finding = get_object_or_none(Finding, id=finding_id) if not finding: @@ -785,7 +785,7 @@ def push_finding_to_jira(finding_id, *args, **kwargs): return add_jira_issue(finding, *args, **kwargs) -@app.task(base=DojoAsyncTask) +@app.task def push_finding_group_to_jira(finding_group_id, *args, **kwargs): finding_group = get_object_or_none(Finding_Group, id=finding_group_id) if not finding_group: @@ -801,7 +801,7 @@ def push_finding_group_to_jira(finding_group_id, *args, **kwargs): return add_jira_issue(finding_group, *args, **kwargs) -@app.task(base=DojoAsyncTask) +@app.task def push_engagement_to_jira(engagement_id, *args, **kwargs): engagement = get_object_or_none(Engagement, id=engagement_id) if not engagement: @@ -1373,7 +1373,7 @@ def jira_check_attachment(issue, source_file_name): return file_exists -@app.task(base=DojoAsyncTask) +@app.task def close_epic(engagement_id, push_to_jira, **kwargs): engagement = get_object_or_none(Engagement, id=engagement_id) if not engagement: @@ -1421,7 +1421,7 @@ def close_epic(engagement_id, push_to_jira, **kwargs): return False -@app.task(base=DojoAsyncTask) +@app.task def update_epic(engagement_id, **kwargs): engagement = get_object_or_none(Engagement, id=engagement_id) if not engagement: @@ -1467,7 +1467,7 @@ def update_epic(engagement_id, **kwargs): return False -@app.task(base=DojoAsyncTask) +@app.task def add_epic(engagement_id, **kwargs): engagement = get_object_or_none(Engagement, id=engagement_id) if not engagement: @@ -1578,7 +1578,7 @@ def add_comment(obj, note, *, force_push=False, **kwargs): return dojo_dispatch_task(add_comment_internal, jira_issue.id, note.id, force_push=force_push, **kwargs) -@app.task(base=DojoAsyncTask) +@app.task def add_comment_internal(jira_issue_id, note_id, *, force_push=False, **kwargs): """Internal Celery task that adds a comment to a JIRA issue.""" jira_issue = get_object_or_none(JIRA_Issue, id=jira_issue_id) diff --git a/dojo/notifications/helper.py b/dojo/notifications/helper.py index dc4cb434c1a..ae62fc8f4d7 100644 --- a/dojo/notifications/helper.py +++ b/dojo/notifications/helper.py @@ -17,7 +17,7 @@ from dojo import __version__ as dd_version from dojo.authorization.roles_permissions import Permissions -from dojo.celery import DojoAsyncTask, app +from dojo.celery import app from dojo.celery_dispatch import dojo_dispatch_task from dojo.decorators import we_want_async from dojo.labels import get_labels @@ -867,25 +867,25 @@ def 
_process_notifications( ) -@app.task(base=DojoAsyncTask) +@app.task def send_slack_notification(event: str, user_id: int | None = None, **kwargs: dict) -> None: user = Dojo_User.objects.get(pk=user_id) if user_id else None SlackNotificationManger().send_slack_notification(event, user=user, **kwargs) -@app.task(base=DojoAsyncTask) +@app.task def send_msteams_notification(event: str, user_id: int | None = None, **kwargs: dict) -> None: user = Dojo_User.objects.get(pk=user_id) if user_id else None MSTeamsNotificationManger().send_msteams_notification(event, user=user, **kwargs) -@app.task(base=DojoAsyncTask) +@app.task def send_mail_notification(event: str, user_id: int | None = None, **kwargs: dict) -> None: user = Dojo_User.objects.get(pk=user_id) if user_id else None EmailNotificationManger().send_mail_notification(event, user=user, **kwargs) -@app.task(base=DojoAsyncTask) +@app.task def send_webhooks_notification(event: str, user_id: int | None = None, **kwargs: dict) -> None: user = Dojo_User.objects.get(pk=user_id) if user_id else None WebhookNotificationManger().send_webhooks_notification(event, user=user, **kwargs) diff --git a/dojo/product/helpers.py b/dojo/product/helpers.py index f23e4155548..8a308d9b62c 100644 --- a/dojo/product/helpers.py +++ b/dojo/product/helpers.py @@ -1,13 +1,13 @@ import contextlib import logging -from dojo.celery import DojoAsyncTask, app +from dojo.celery import app from dojo.models import Endpoint, Engagement, Finding, Product, Test logger = logging.getLogger(__name__) -@app.task(base=DojoAsyncTask) +@app.task def propagate_tags_on_product(product_id, *args, **kwargs): with contextlib.suppress(Product.DoesNotExist): product = Product.objects.get(id=product_id) diff --git a/dojo/sla_config/helpers.py b/dojo/sla_config/helpers.py index dd2567729dc..045456f38d7 100644 --- a/dojo/sla_config/helpers.py +++ b/dojo/sla_config/helpers.py @@ -1,13 +1,13 @@ import logging -from dojo.celery import DojoAsyncTask, app +from dojo.celery import app from dojo.models import Finding, Product, SLA_Configuration, System_Settings from dojo.utils import get_custom_method, mass_model_updater logger = logging.getLogger(__name__) -@app.task(base=DojoAsyncTask) +@app.task def async_update_sla_expiration_dates_sla_config_sync(sla_config: SLA_Configuration, products: list[Product], *args, severities: list[str] | None = None, **kwargs): if method := get_custom_method("FINDING_SLA_EXPIRATION_CALCULATION_METHOD"): method(sla_config, products, severities=severities) diff --git a/dojo/tasks.py b/dojo/tasks.py index 50d471cf68d..3268d9d4b69 100644 --- a/dojo/tasks.py +++ b/dojo/tasks.py @@ -11,7 +11,7 @@ from django.utils import timezone from dojo.auditlog import run_flush_auditlog -from dojo.celery import DojoAsyncTask, app +from dojo.celery import app from dojo.celery_dispatch import dojo_dispatch_task from dojo.finding.helper import fix_loop_duplicates from dojo.management.commands.jira_status_reconciliation import jira_status_reconciliation @@ -237,7 +237,7 @@ def clear_sessions(*args, **kwargs): call_command("clearsessions") -@app.task(base=DojoAsyncTask) +@app.task def update_watson_search_index_for_model(model_name, pk_list, *args, **kwargs): """ Async task to update watson search indexes for a specific model type. 
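The decorator change repeated across these modules follows from the app configuration above; a short sketch of the resulting pattern (example_task is hypothetical; app, DojoAsyncTask and PgHistoryTask are the ones defined in dojo/celery.py):

    from dojo.celery import app

    # Before, each task opted in explicitly:
    #     @app.task(base=DojoAsyncTask)
    # The Celery app is now created with task_cls=PgHistoryTask, and PgHistoryTask
    # extends DojoAsyncTask, so the plain decorator already provides user-context
    # handling plus pghistory context application:
    @app.task
    def example_task(finding_id, *args, **kwargs):
        # kwargs may still carry "async_user" injected by dojo_dispatch_task;
        # "_pgh_context" is consumed by PgHistoryTask.__call__ before this runs.
        ...
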
diff --git a/dojo/tools/tool_issue_updater.py b/dojo/tools/tool_issue_updater.py index 93e6d93857f..8211e166eed 100644 --- a/dojo/tools/tool_issue_updater.py +++ b/dojo/tools/tool_issue_updater.py @@ -2,7 +2,7 @@ import pghistory -from dojo.celery import DojoAsyncTask, app +from dojo.celery import app from dojo.celery_dispatch import dojo_dispatch_task from dojo.models import Finding from dojo.tools.api_sonarqube.parser import SCAN_SONARQUBE_API @@ -23,7 +23,7 @@ def is_tool_issue_updater_needed(finding, *args, **kwargs): return test_type.name == SCAN_SONARQUBE_API -@app.task(base=DojoAsyncTask) +@app.task def tool_issue_updater(finding_id, *args, **kwargs): finding = get_object_or_none(Finding, id=finding_id) if not finding: @@ -36,7 +36,7 @@ def tool_issue_updater(finding_id, *args, **kwargs): SonarQubeApiUpdater().update_sonarqube_finding(finding) -@app.task(base=DojoAsyncTask) +@app.task def update_findings_from_source_issues(**kwargs): # Wrap with pghistory context for audit trail with pghistory.context(source="sonarqube_sync"): diff --git a/dojo/utils.py b/dojo/utils.py index 0806af4ce5c..d964e8cc736 100644 --- a/dojo/utils.py +++ b/dojo/utils.py @@ -45,7 +45,7 @@ from django.utils.translation import gettext as _ from dojo.authorization.roles_permissions import Permissions -from dojo.celery import DojoAsyncTask, app +from dojo.celery import app from dojo.finding.queries import get_authorized_findings from dojo.github import ( add_external_issue_github, @@ -1053,7 +1053,7 @@ def handle_uploaded_selenium(f, cred): cred.save() -@app.task(base=DojoAsyncTask) +@app.task def add_external_issue(finding_id, external_issue_provider, **kwargs): finding = get_object_or_none(Finding, id=finding_id) if not finding: @@ -1068,7 +1068,7 @@ def add_external_issue(finding_id, external_issue_provider, **kwargs): add_external_issue_github(finding, prod, eng) -@app.task(base=DojoAsyncTask) +@app.task def update_external_issue(finding_id, old_status, external_issue_provider, **kwargs): finding = get_object_or_none(Finding, id=finding_id) if not finding: @@ -1082,7 +1082,7 @@ def update_external_issue(finding_id, old_status, external_issue_provider, **kwa update_external_issue_github(finding, prod, eng) -@app.task(base=DojoAsyncTask) +@app.task def close_external_issue(finding_id, note, external_issue_provider, **kwargs): finding = get_object_or_none(Finding, id=finding_id) if not finding: @@ -1096,7 +1096,7 @@ def close_external_issue(finding_id, note, external_issue_provider, **kwargs): close_external_issue_github(finding, note, prod, eng) -@app.task(base=DojoAsyncTask) +@app.task def reopen_external_issue(finding_id, note, external_issue_provider, **kwargs): finding = get_object_or_none(Finding, id=finding_id) if not finding: @@ -1231,7 +1231,7 @@ def get_setting(setting): return getattr(settings, setting) -@app.task(base=DojoAsyncTask) +@app.task def calculate_grade(product_id, *args, **kwargs): product = get_object_or_none(Product, id=product_id) if not product: @@ -2012,7 +2012,7 @@ def __init__(self, *args, **kwargs): "Test": [(Finding, "test__id")], } - @app.task(base=DojoAsyncTask) + @app.task def delete_chunk(self, objects, **kwargs): # Now delete all objects with retry for deadlocks max_retries = 3 @@ -2060,7 +2060,7 @@ def delete_chunk(self, objects, **kwargs): obj.delete() break - @app.task(base=DojoAsyncTask) + @app.task def delete(self, obj, **kwargs): logger.debug("ASYNC_DELETE: Deleting " + self.get_object_name(obj) + ": " + str(obj)) model_list = 
self.mapping.get(self.get_object_name(obj), None) @@ -2072,7 +2072,7 @@ def delete(self, obj, **kwargs): logger.debug("ASYNC_DELETE: " + self.get_object_name(obj) + " async delete not supported. Deleteing normally: " + str(obj)) obj.delete() - @app.task(base=DojoAsyncTask) + @app.task def crawl(self, obj, model_list, **kwargs): logger.debug("ASYNC_DELETE: Crawling " + self.get_object_name(obj) + ": " + str(obj)) for model_info in model_list: From 6d73654ba7d44fb3199f1aedb84d34125073ea57 Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Wed, 7 Jan 2026 18:37:07 +0100 Subject: [PATCH 19/36] fix system settings celery_status --- dojo/tasks.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/dojo/tasks.py b/dojo/tasks.py index 3268d9d4b69..8a566c9acfc 100644 --- a/dojo/tasks.py +++ b/dojo/tasks.py @@ -2,6 +2,7 @@ from datetime import timedelta import pghistory +from celery import Task from celery.utils.log import get_task_logger from django.apps import apps from django.conf import settings @@ -172,8 +173,15 @@ def _async_dupe_delete_impl(): dojo_dispatch_task(calculate_grade, product.id) -@app.task(ignore_result=False) +@app.task(ignore_result=False, base=Task) def celery_status(): + """ + Simple health check task to verify Celery is running. + + Uses base Task class (not PgHistoryTask) since it doesn't need: + - User context tracking + - Pghistory context (no database modifications) + """ return True From ce845173517d19bd465ef1b6b8a7ba393b95e818 Mon Sep 17 00:00:00 2001 From: Cody Maffucci <46459665+Maffooch@users.noreply.github.com> Date: Wed, 14 Jan 2026 10:49:20 -0700 Subject: [PATCH 20/36] Enforce readonly name field for existing Test_Type instances in form --- dojo/forms.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/dojo/forms.py b/dojo/forms.py index b2b39509933..000dec362a0 100644 --- a/dojo/forms.py +++ b/dojo/forms.py @@ -324,6 +324,17 @@ class Meta: model = Test_Type exclude = ["dynamically_generated"] + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if self.instance.pk: + self.fields["name"].widget.attrs["readonly"] = True + + def clean_name(self): + if self.instance.pk: + return self.instance.name + return self.cleaned_data["name"] + class Development_EnvironmentForm(forms.ModelForm): class Meta: From 764d7cc1e9f54c28135025c9cefce4640d1266eb Mon Sep 17 00:00:00 2001 From: Cody Maffucci <46459665+Maffooch@users.noreply.github.com> Date: Wed, 14 Jan 2026 10:49:30 -0700 Subject: [PATCH 21/36] Add TestTypeCreateSerializer and enforce readonly name field in TestTypeSerializer --- dojo/api_v2/serializers.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/dojo/api_v2/serializers.py b/dojo/api_v2/serializers.py index 6bc18b5115f..9a17e4be985 100644 --- a/dojo/api_v2/serializers.py +++ b/dojo/api_v2/serializers.py @@ -1472,8 +1472,15 @@ class Meta: exclude = ("inherited_tags",) +class TestTypeCreateSerializer(serializers.ModelSerializer): + + class Meta: + model = Test_Type + exclude = ("dynamically_generated",) + + class TestTypeSerializer(serializers.ModelSerializer): - tags = TagListSerializerField(required=False) + name = serializers.ReadOnlyField() class Meta: model = Test_Type From 8a40cd1e4828a9acc8d99e945a9b7a5451e35aa1 Mon Sep 17 00:00:00 2001 From: Cody Maffucci <46459665+Maffooch@users.noreply.github.com> Date: Wed, 14 Jan 2026 10:49:36 -0700 Subject: [PATCH 22/36] Add dynamic serializer selection in TestTypesViewSet for create action --- dojo/api_v2/views.py | 5 +++++ 1 
file changed, 5 insertions(+) diff --git a/dojo/api_v2/views.py b/dojo/api_v2/views.py index d3aeb51d85c..a6c89639147 100644 --- a/dojo/api_v2/views.py +++ b/dojo/api_v2/views.py @@ -2264,6 +2264,11 @@ class TestTypesViewSet( def get_queryset(self): return Test_Type.objects.all().order_by("id") + def get_serializer_class(self): + if self.action == "create": + return serializers.TestTypeCreateSerializer + return serializers.TestTypeSerializer + # @extend_schema_view(**schema_with_prefetch()) # Nested models with prefetch make the response schema too long for Swagger UI From 396bd08819ee025cd5cc54dc46408bbafa956836 Mon Sep 17 00:00:00 2001 From: Cody Maffucci <46459665+Maffooch@users.noreply.github.com> Date: Wed, 14 Jan 2026 11:44:38 -0700 Subject: [PATCH 23/36] Update test payload to set 'active' field instead of 'name' --- unittests/test_rest_framework.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unittests/test_rest_framework.py b/unittests/test_rest_framework.py index f32350e2e86..f89b20b4d66 100644 --- a/unittests/test_rest_framework.py +++ b/unittests/test_rest_framework.py @@ -3256,7 +3256,7 @@ def __init__(self, *args, **kwargs): self.viewname = "test_type" self.viewset = TestTypesViewSet self.payload = { - "name": "Test_1", + "active": False, } self.update_fields = {"name": "Test_2"} self.test_type = TestType.CONFIGURATION_PERMISSIONS From 8eb4ee103b8857d1055f9048cf6a8c26224d06ae Mon Sep 17 00:00:00 2001 From: Cody Maffucci <46459665+Maffooch@users.noreply.github.com> Date: Wed, 14 Jan 2026 12:06:41 -0700 Subject: [PATCH 24/36] Update TestTypeTest payload to use 'name' and modify update_fields to 'active' --- unittests/test_rest_framework.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unittests/test_rest_framework.py b/unittests/test_rest_framework.py index f89b20b4d66..89969a09321 100644 --- a/unittests/test_rest_framework.py +++ b/unittests/test_rest_framework.py @@ -3256,9 +3256,9 @@ def __init__(self, *args, **kwargs): self.viewname = "test_type" self.viewset = TestTypesViewSet self.payload = { - "active": False, + "name": "Test_1", } - self.update_fields = {"name": "Test_2"} + self.update_fields = {"active": False} self.test_type = TestType.CONFIGURATION_PERMISSIONS self.deleted_objects = 1 BaseClass.RESTEndpointTest.__init__(self, *args, **kwargs) From 4d23535efd1c3e6f7c435395f7fed2b518cdbbfa Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Sat, 17 Jan 2026 12:18:59 +0100 Subject: [PATCH 25/36] fix async delete, add tests --- dojo/utils.py | 285 ++++++++++++++++++------------ unittests/test_async_delete.py | 313 +++++++++++++++++++++++++++++++++ 2 files changed, 486 insertions(+), 112 deletions(-) create mode 100644 unittests/test_async_delete.py diff --git a/dojo/utils.py b/dojo/utils.py index 48be725ec57..c435cb3470a 100644 --- a/dojo/utils.py +++ b/dojo/utils.py @@ -2008,126 +2008,187 @@ def is_finding_groups_enabled(): return get_system_setting("enable_finding_groups") -class async_delete: - def __init__(self, *args, **kwargs): - self.mapping = { - "Product_Type": [ - (Endpoint, "product__prod_type__id"), - (Finding, "test__engagement__product__prod_type__id"), - (Test, "engagement__product__prod_type__id"), - (Engagement, "product__prod_type__id"), - (Product, "prod_type__id")], - "Product": [ - (Endpoint, "product__id"), - (Finding, "test__engagement__product__id"), - (Test, "engagement__product__id"), - (Engagement, "product__id")], - "Engagement": [ - (Finding, "test__engagement__id"), - (Test, 
"engagement__id")], - "Test": [(Finding, "test__id")], - } +# Mapping of object types to their related models for cascading deletes +ASYNC_DELETE_MAPPING = { + "Product_Type": [ + (Endpoint, "product__prod_type__id"), + (Finding, "test__engagement__product__prod_type__id"), + (Test, "engagement__product__prod_type__id"), + (Engagement, "product__prod_type__id"), + (Product, "prod_type__id")], + "Product": [ + (Endpoint, "product__id"), + (Finding, "test__engagement__product__id"), + (Test, "engagement__product__id"), + (Engagement, "product__id")], + "Engagement": [ + (Finding, "test__engagement__id"), + (Test, "engagement__id")], + "Test": [(Finding, "test__id")], +} + + +def _get_object_name(obj): + """Get the class name of an object or model class.""" + if obj.__class__.__name__ == "ModelBase": + return obj.__name__ + return obj.__class__.__name__ - @app.task - def delete_chunk(self, objects, **kwargs): - # Now delete all objects with retry for deadlocks - max_retries = 3 - for obj in objects: - retry_count = 0 - while retry_count < max_retries: - try: - obj.delete() - break # Success, exit retry loop - except OperationalError as e: - error_msg = str(e) - if "deadlock detected" in error_msg.lower(): - retry_count += 1 - if retry_count < max_retries: - # Exponential backoff with jitter - wait_time = (2 ** retry_count) + random.uniform(0, 1) # noqa: S311 - logger.warning( - f"ASYNC_DELETE: Deadlock detected deleting {self.get_object_name(obj)} {obj.pk}, " - f"retrying ({retry_count}/{max_retries}) after {wait_time:.2f}s", - ) - time.sleep(wait_time) - # Refresh object from DB before retry - obj.refresh_from_db() - else: - logger.error( - f"ASYNC_DELETE: Deadlock persisted after {max_retries} retries for {self.get_object_name(obj)} {obj.pk}: {e}", - ) - raise + +@app.task +def async_delete_chunk_task(objects, **kwargs): + """ + Module-level Celery task to delete a chunk of objects. + + Accepts **kwargs for async_user and _pgh_context injected by dojo_dispatch_task. + Uses PgHistoryTask base class (default) to preserve pghistory context for audit trail. + """ + max_retries = 3 + for obj in objects: + retry_count = 0 + while retry_count < max_retries: + try: + obj.delete() + break # Success, exit retry loop + except OperationalError as e: + error_msg = str(e) + if "deadlock detected" in error_msg.lower(): + retry_count += 1 + if retry_count < max_retries: + # Exponential backoff with jitter + wait_time = (2 ** retry_count) + random.uniform(0, 1) # noqa: S311 + logger.warning( + f"ASYNC_DELETE: Deadlock detected deleting {_get_object_name(obj)} {obj.pk}, " + f"retrying ({retry_count}/{max_retries}) after {wait_time:.2f}s", + ) + time.sleep(wait_time) + # Refresh object from DB before retry + obj.refresh_from_db() else: - # Not a deadlock, re-raise + logger.error( + f"ASYNC_DELETE: Deadlock persisted after {max_retries} retries for {_get_object_name(obj)} {obj.pk}: {e}", + ) raise - except AssertionError: - logger.debug("ASYNC_DELETE: object has already been deleted elsewhere. 
Skipping") - # The id must be None - # The object has already been deleted elsewhere - break - except LogEntry.MultipleObjectsReturned: - # Delete the log entrys first, then delete - LogEntry.objects.filter( - content_type=ContentType.objects.get_for_model(obj.__class__), - object_pk=str(obj.pk), - action=LogEntry.Action.DELETE, - ).delete() - # Now delete the object again (no retry needed for this case) - obj.delete() - break - - @app.task + else: + # Not a deadlock, re-raise + raise + except AssertionError: + logger.debug("ASYNC_DELETE: object has already been deleted elsewhere. Skipping") + # The id must be None + # The object has already been deleted elsewhere + break + except LogEntry.MultipleObjectsReturned: + # Delete the log entrys first, then delete + LogEntry.objects.filter( + content_type=ContentType.objects.get_for_model(obj.__class__), + object_pk=str(obj.pk), + action=LogEntry.Action.DELETE, + ).delete() + # Now delete the object again (no retry needed for this case) + obj.delete() + break + + +@app.task +def async_delete_crawl_task(obj, model_list, **kwargs): + """ + Module-level Celery task to crawl and delete related objects. + + Accepts **kwargs for async_user and _pgh_context injected by dojo_dispatch_task. + Uses PgHistoryTask base class (default) to preserve pghistory context for audit trail. + """ + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + logger.debug("ASYNC_DELETE: Crawling " + _get_object_name(obj) + ": " + str(obj)) + for model_info in model_list: + task_results = [] + model = model_info[0] + model_query = model_info[1] + filter_dict = {model_query: obj.id} + # Only fetch the IDs since we will make a list of IDs in the following function call + objects_to_delete = model.objects.only("id").filter(**filter_dict).distinct().order_by("id") + logger.debug("ASYNC_DELETE: Deleting " + str(len(objects_to_delete)) + " " + _get_object_name(model) + "s in chunks") + chunk_size = get_setting("ASYNC_OBEJECT_DELETE_CHUNK_SIZE") + chunks = [objects_to_delete[i:i + chunk_size] for i in range(0, len(objects_to_delete), chunk_size)] + logger.debug("ASYNC_DELETE: Split " + _get_object_name(model) + " into " + str(len(chunks)) + " chunks of " + str(chunk_size)) + for chunk in chunks: + logger.debug(f"deleting {len(chunk)} {_get_object_name(model)}") + result = dojo_dispatch_task(async_delete_chunk_task, list(chunk)) + # Collect async task results to wait for them all at once + if hasattr(result, "get"): + task_results.append(result) + # Wait for all chunk deletions to complete (they run in parallel) + for task_result in task_results: + task_result.get(timeout=300) # 5 minute timeout per chunk + # Now delete the main object after all chunks are done + result = dojo_dispatch_task(async_delete_chunk_task, [obj]) + # Wait for final deletion to complete + if hasattr(result, "get"): + result.get(timeout=300) # 5 minute timeout + logger.debug("ASYNC_DELETE: Successfully deleted " + _get_object_name(obj) + ": " + str(obj)) + + +@app.task +def async_delete_task(obj, **kwargs): + """ + Module-level Celery task to delete an object and its related objects. + + Accepts **kwargs for async_user and _pgh_context injected by dojo_dispatch_task. + Uses PgHistoryTask base class (default) to preserve pghistory context for audit trail. 
+ """ + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + logger.debug("ASYNC_DELETE: Deleting " + _get_object_name(obj) + ": " + str(obj)) + model_list = ASYNC_DELETE_MAPPING.get(_get_object_name(obj)) + if model_list: + # The object to be deleted was found in the object list + dojo_dispatch_task(async_delete_crawl_task, obj, model_list) + else: + # The object is not supported in async delete, delete normally + logger.debug("ASYNC_DELETE: " + _get_object_name(obj) + " async delete not supported. Deleteing normally: " + str(obj)) + obj.delete() + + +class async_delete: + + """ + Entry point class for async object deletion. + + Usage: + async_del = async_delete() + async_del.delete(instance) + + This class dispatches deletion to module-level Celery tasks via dojo_dispatch_task, + which properly handles user context injection and pghistory context. + """ + + def __init__(self, *args, **kwargs): + # Keep mapping reference for backwards compatibility + self.mapping = ASYNC_DELETE_MAPPING + def delete(self, obj, **kwargs): - logger.debug("ASYNC_DELETE: Deleting " + self.get_object_name(obj) + ": " + str(obj)) - model_list = self.mapping.get(self.get_object_name(obj), None) - if model_list: - # The object to be deleted was found in the object list - self.crawl(obj, model_list) - else: - # The object is not supported in async delete, delete normally - logger.debug("ASYNC_DELETE: " + self.get_object_name(obj) + " async delete not supported. Deleteing normally: " + str(obj)) - obj.delete() - - @app.task - def crawl(self, obj, model_list, **kwargs): - logger.debug("ASYNC_DELETE: Crawling " + self.get_object_name(obj) + ": " + str(obj)) - for model_info in model_list: - task_results = [] - model = model_info[0] - model_query = model_info[1] - filter_dict = {model_query: obj.id} - # Only fetch the IDs since we will make a list of IDs in the following function call - objects_to_delete = model.objects.only("id").filter(**filter_dict).distinct().order_by("id") - logger.debug("ASYNC_DELETE: Deleting " + str(len(objects_to_delete)) + " " + self.get_object_name(model) + "s in chunks") - chunks = self.chunk_list(model, objects_to_delete) - for chunk in chunks: - logger.debug(f"deleting {len(chunk)} {self.get_object_name(model)}") - result = self.delete_chunk(chunk) - # Collect async task results to wait for them all at once - if hasattr(result, "get"): - task_results.append(result) - # Wait for all chunk deletions to complete (they run in parallel) - for task_result in task_results: - task_result.get(timeout=300) # 5 minute timeout per chunk - # Now delete the main object after all chunks are done - result = self.delete_chunk([obj]) - # Wait for final deletion to complete - if hasattr(result, "get"): - result.get(timeout=300) # 5 minute timeout - logger.debug("ASYNC_DELETE: Successfully deleted " + self.get_object_name(obj) + ": " + str(obj)) - - def chunk_list(self, model, full_list): + """ + Entry point to delete an object asynchronously. + + Dispatches to async_delete_task via dojo_dispatch_task to ensure proper + handling of async_user and _pgh_context. 
+ """ + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(async_delete_task, obj, **kwargs) + + # Keep helper methods for backwards compatibility and potential direct use + @staticmethod + def get_object_name(obj): + return _get_object_name(obj) + + @staticmethod + def chunk_list(model, full_list): chunk_size = get_setting("ASYNC_OBEJECT_DELETE_CHUNK_SIZE") - # Break the list of objects into "chunk_size" lists chunk_list = [full_list[i:i + chunk_size] for i in range(0, len(full_list), chunk_size)] - logger.debug("ASYNC_DELETE: Split " + self.get_object_name(model) + " into " + str(len(chunk_list)) + " chunks of " + str(chunk_size)) + logger.debug("ASYNC_DELETE: Split " + _get_object_name(model) + " into " + str(len(chunk_list)) + " chunks of " + str(chunk_size)) return chunk_list - def get_object_name(self, obj): - if obj.__class__.__name__ == "ModelBase": - return obj.__name__ - return obj.__class__.__name__ - @receiver(user_logged_in) def log_user_login(sender, request, user, **kwargs): diff --git a/unittests/test_async_delete.py b/unittests/test_async_delete.py new file mode 100644 index 00000000000..341723e8296 --- /dev/null +++ b/unittests/test_async_delete.py @@ -0,0 +1,313 @@ +""" +Unit tests for async_delete functionality. + +These tests verify that the async_delete class works correctly with dojo_dispatch_task, +which injects async_user and _pgh_context kwargs into task calls. + +The original bug was that @app.task decorated instance methods didn't properly handle +the injected kwargs, causing TypeError: unexpected keyword argument 'async_user'. +""" +import logging + +from crum import impersonate +from django.contrib.auth.models import User +from django.test import override_settings +from django.utils import timezone + +from dojo.models import Engagement, Finding, Product, Product_Type, Test, Test_Type, UserContactInfo +from dojo.utils import async_delete + +from .dojo_test_case import DojoTestCase + +logger = logging.getLogger(__name__) + + +class TestAsyncDelete(DojoTestCase): + + """ + Test async_delete functionality with dojo_dispatch_task kwargs injection. + + These tests use block_execution=True and crum.impersonate to run tasks synchronously, + which allows errors to surface immediately rather than being lost in background workers. 
+ """ + + def setUp(self): + """Set up test user with block_execution=True and disable unneeded features.""" + super().setUp() + + # Create test user with block_execution=True to run tasks synchronously + self.testuser = User.objects.create( + username="test_async_delete_user", + is_staff=True, + is_superuser=True, + ) + UserContactInfo.objects.create(user=self.testuser, block_execution=True) + + # Log in as the test user (for API client) + self.client.force_login(self.testuser) + + # Disable features that might interfere with deletion + self.system_settings(enable_product_grade=False) + self.system_settings(enable_github=False) + self.system_settings(enable_jira=False) + + # Create base test data + self.product_type = Product_Type.objects.create(name="Test Product Type for Async Delete") + self.test_type = Test_Type.objects.get_or_create(name="Manual Test")[0] + + def tearDown(self): + """Clean up any remaining test data.""" + # Clean up in reverse order of dependencies + Finding.objects.filter(test__engagement__product__prod_type=self.product_type).delete() + Test.objects.filter(engagement__product__prod_type=self.product_type).delete() + Engagement.objects.filter(product__prod_type=self.product_type).delete() + Product.objects.filter(prod_type=self.product_type).delete() + self.product_type.delete() + + super().tearDown() + + def _create_product(self, name="Test Product"): + """Helper to create a product for testing.""" + return Product.objects.create( + name=name, + description="Test product for async delete", + prod_type=self.product_type, + ) + + def _create_engagement(self, product, name="Test Engagement"): + """Helper to create an engagement for testing.""" + return Engagement.objects.create( + name=name, + product=product, + target_start=timezone.now(), + target_end=timezone.now(), + ) + + def _create_test(self, engagement, name="Test"): + """Helper to create a test for testing.""" + return Test.objects.create( + engagement=engagement, + test_type=self.test_type, + target_start=timezone.now(), + target_end=timezone.now(), + ) + + def _create_finding(self, test, title="Test Finding"): + """Helper to create a finding for testing.""" + return Finding.objects.create( + test=test, + title=title, + severity="High", + description="Test finding for async delete", + mitigation="Test mitigation", + impact="Test impact", + reporter=self.testuser, + ) + + @override_settings(ASYNC_OBJECT_DELETE=True) + def test_async_delete_simple_object(self): + """ + Test that async_delete works for a simple object (Finding). + + Finding is not in the async_delete mapping, so it falls back to direct delete. + This tests that the module-level task accepts **kwargs properly. + """ + product = self._create_product() + engagement = self._create_engagement(product) + test = self._create_test(engagement) + finding = self._create_finding(test) + finding_pk = finding.pk + + # Use impersonate to set current user context (required for block_execution to work) + with impersonate(self.testuser): + # This would raise TypeError before the fix: + # TypeError: delete() got an unexpected keyword argument 'async_user' + async_del = async_delete() + async_del.delete(finding) + + # Verify the finding was deleted + self.assertFalse( + Finding.objects.filter(pk=finding_pk).exists(), + "Finding should be deleted", + ) + + @override_settings(ASYNC_OBJECT_DELETE=True) + def test_async_delete_test_with_findings(self): + """ + Test that async_delete cascades deletion for Test objects. 
+ + Test is in the async_delete mapping and should cascade delete its findings. + """ + product = self._create_product() + engagement = self._create_engagement(product) + test = self._create_test(engagement) + finding1 = self._create_finding(test, "Finding 1") + finding2 = self._create_finding(test, "Finding 2") + + test_pk = test.pk + finding1_pk = finding1.pk + finding2_pk = finding2.pk + + # Use impersonate to set current user context (required for block_execution to work) + with impersonate(self.testuser): + # Delete the test (should cascade to findings) + async_del = async_delete() + async_del.delete(test) + + # Verify all objects were deleted + self.assertFalse( + Test.objects.filter(pk=test_pk).exists(), + "Test should be deleted", + ) + self.assertFalse( + Finding.objects.filter(pk=finding1_pk).exists(), + "Finding 1 should be deleted via cascade", + ) + self.assertFalse( + Finding.objects.filter(pk=finding2_pk).exists(), + "Finding 2 should be deleted via cascade", + ) + + @override_settings(ASYNC_OBJECT_DELETE=True) + def test_async_delete_engagement_with_tests(self): + """ + Test that async_delete cascades deletion for Engagement objects. + + Engagement is in the async_delete mapping and should cascade delete + its tests and findings. + """ + product = self._create_product() + engagement = self._create_engagement(product) + test1 = self._create_test(engagement, "Test 1") + test2 = self._create_test(engagement, "Test 2") + finding1 = self._create_finding(test1, "Finding in Test 1") + finding2 = self._create_finding(test2, "Finding in Test 2") + + engagement_pk = engagement.pk + test1_pk = test1.pk + test2_pk = test2.pk + finding1_pk = finding1.pk + finding2_pk = finding2.pk + + # Use impersonate to set current user context (required for block_execution to work) + with impersonate(self.testuser): + # Delete the engagement (should cascade to tests and findings) + async_del = async_delete() + async_del.delete(engagement) + + # Verify all objects were deleted + self.assertFalse( + Engagement.objects.filter(pk=engagement_pk).exists(), + "Engagement should be deleted", + ) + self.assertFalse( + Test.objects.filter(pk__in=[test1_pk, test2_pk]).exists(), + "Tests should be deleted via cascade", + ) + self.assertFalse( + Finding.objects.filter(pk__in=[finding1_pk, finding2_pk]).exists(), + "Findings should be deleted via cascade", + ) + + @override_settings(ASYNC_OBJECT_DELETE=True) + def test_async_delete_product_with_hierarchy(self): + """ + Test that async_delete cascades deletion for Product objects. + + Product is in the async_delete mapping and should cascade delete + its engagements, tests, and findings. 
+ """ + product = self._create_product() + engagement = self._create_engagement(product) + test = self._create_test(engagement) + finding = self._create_finding(test) + + product_pk = product.pk + engagement_pk = engagement.pk + test_pk = test.pk + finding_pk = finding.pk + + # Use impersonate to set current user context (required for block_execution to work) + with impersonate(self.testuser): + # Delete the product (should cascade to everything) + async_del = async_delete() + async_del.delete(product) + + # Verify all objects were deleted + self.assertFalse( + Product.objects.filter(pk=product_pk).exists(), + "Product should be deleted", + ) + self.assertFalse( + Engagement.objects.filter(pk=engagement_pk).exists(), + "Engagement should be deleted via cascade", + ) + self.assertFalse( + Test.objects.filter(pk=test_pk).exists(), + "Test should be deleted via cascade", + ) + self.assertFalse( + Finding.objects.filter(pk=finding_pk).exists(), + "Finding should be deleted via cascade", + ) + + @override_settings(ASYNC_OBJECT_DELETE=True) + def test_async_delete_accepts_sync_kwarg(self): + """ + Test that async_delete passes through the sync kwarg properly. + + The sync=True kwarg forces synchronous execution for the top-level task. + However, nested task dispatches still need user context to run synchronously, + so we use impersonate here as well. + """ + product = self._create_product() + product_pk = product.pk + + # Use impersonate to ensure nested tasks also run synchronously + with impersonate(self.testuser): + # Explicitly pass sync=True + async_del = async_delete() + async_del.delete(product, sync=True) + + # Verify the product was deleted + self.assertFalse( + Product.objects.filter(pk=product_pk).exists(), + "Product should be deleted with sync=True", + ) + + def test_async_delete_helper_methods(self): + """ + Test that static helper methods on async_delete class still work. + + These are kept for backwards compatibility. + """ + product = self._create_product() + + # Test get_object_name + self.assertEqual( + async_delete.get_object_name(product), + "Product", + "get_object_name should return class name", + ) + + # Test get_object_name with model class + self.assertEqual( + async_delete.get_object_name(Product), + "Product", + "get_object_name should work with model class", + ) + + def test_async_delete_mapping_preserved(self): + """ + Test that the mapping attribute is preserved on async_delete instances. + + This ensures backwards compatibility for code that might access the mapping. 
+ """ + async_del = async_delete() + + # Verify mapping exists and has expected keys + self.assertIsNotNone(async_del.mapping) + self.assertIn("Product", async_del.mapping) + self.assertIn("Product_Type", async_del.mapping) + self.assertIn("Engagement", async_del.mapping) + self.assertIn("Test", async_del.mapping) From 945e35923df0b4174f94f7c938e59627e3ea76da Mon Sep 17 00:00:00 2001 From: Cody Maffucci <46459665+Maffooch@users.noreply.github.com> Date: Fri, 16 Jan 2026 17:34:28 -0700 Subject: [PATCH 26/36] Add additional fields to AssetSerializer for business criticality, platform, lifecycle, and origin --- dojo/asset/api/serializers.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/dojo/asset/api/serializers.py b/dojo/asset/api/serializers.py index 688d772ce9b..70b639345a6 100644 --- a/dojo/asset/api/serializers.py +++ b/dojo/asset/api/serializers.py @@ -41,7 +41,13 @@ class AssetSerializer(serializers.ModelSerializer): enable_asset_tag_inheritance = serializers.BooleanField(source="enable_product_tag_inheritance") asset_managers = serializers.PrimaryKeyRelatedField( source="product_manager", - queryset=Dojo_User.objects.exclude(is_active=False)) + queryset=Dojo_User.objects.exclude(is_active=False), + required=False, allow_null=True, + ) + business_criticality = serializers.ChoiceField(choices=Product.BUSINESS_CRITICALITY_CHOICES, allow_blank=True, allow_null=True, required=False) + platform = serializers.ChoiceField(choices=Product.PLATFORM_CHOICES, allow_blank=True, allow_null=True, required=False) + lifecycle = serializers.ChoiceField(choices=Product.LIFECYCLE_CHOICES, allow_blank=True, allow_null=True, required=False) + origin = serializers.ChoiceField(choices=Product.ORIGIN_CHOICES, allow_blank=True, allow_null=True, required=False) class Meta: model = Product From 70f381eb543e403e170c1e6c1ee0450bff114078 Mon Sep 17 00:00:00 2001 From: Cody Maffucci <46459665+Maffooch@users.noreply.github.com> Date: Fri, 16 Jan 2026 17:50:08 -0700 Subject: [PATCH 27/36] Correct some filters too --- dojo/asset/api/filters.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/dojo/asset/api/filters.py b/dojo/asset/api/filters.py index 991fd329ac8..91281a9ebe2 100644 --- a/dojo/asset/api/filters.py +++ b/dojo/asset/api/filters.py @@ -8,12 +8,13 @@ CharFieldInFilter, DateRangeFilter, DojoFilter, + MultipleChoiceFilter, NumberInFilter, ProductSLAFilter, - custom_filter, ) from dojo.labels import get_labels from dojo.models import ( + Product, Product_API_Scan_Configuration, Product_Group, Product_Member, @@ -38,18 +39,18 @@ class ApiAssetFilter(DojoFilter): name = CharFilter(lookup_expr="icontains") name_exact = CharFilter(field_name="name", lookup_expr="iexact") description = CharFilter(lookup_expr="icontains") - business_criticality = CharFilter(method=custom_filter, field_name="business_criticality") - platform = CharFilter(method=custom_filter, field_name="platform") - lifecycle = CharFilter(method=custom_filter, field_name="lifecycle") - origin = CharFilter(method=custom_filter, field_name="origin") + business_criticality = MultipleChoiceFilter(choices=Product.BUSINESS_CRITICALITY_CHOICES) + platform = MultipleChoiceFilter(choices=Product.PLATFORM_CHOICES) + lifecycle = MultipleChoiceFilter(choices=Product.LIFECYCLE_CHOICES) + origin = MultipleChoiceFilter(choices=Product.ORIGIN_CHOICES) # NumberInFilter id = NumberInFilter(field_name="id", lookup_expr="in") asset_manager = NumberInFilter(field_name="product_manager", 
lookup_expr="in") technical_contact = NumberInFilter(field_name="technical_contact", lookup_expr="in") team_manager = NumberInFilter(field_name="team_manager", lookup_expr="in") - prod_type = NumberInFilter(field_name="prod_type", lookup_expr="in") + organization = NumberInFilter(field_name="prod_type", lookup_expr="in") tid = NumberInFilter(field_name="tid", lookup_expr="in") - prod_numeric_grade = NumberInFilter(field_name="prod_numeric_grade", lookup_expr="in") + asset_numeric_grade = NumberInFilter(field_name="prod_numeric_grade", lookup_expr="in") user_records = NumberInFilter(field_name="user_records", lookup_expr="in") regulations = NumberInFilter(field_name="regulations", lookup_expr="in") @@ -80,7 +81,7 @@ class ApiAssetFilter(DojoFilter): ("tid", "tid"), ("name", "name"), ("created", "created"), - ("prod_numeric_grade", "prod_numeric_grade"), + ("prod_numeric_grade", "asset_numeric_grade"), ("business_criticality", "business_criticality"), ("platform", "platform"), ("lifecycle", "lifecycle"), @@ -97,8 +98,8 @@ class ApiAssetFilter(DojoFilter): ("team_manager", "team_manager"), ("team_manager__first_name", "team_manager__first_name"), ("team_manager__last_name", "team_manager__last_name"), - ("prod_type", "prod_type"), - ("prod_type__name", "prod_type__name"), + ("prod_type", "organization"), + ("prod_type__name", "organization__name"), ("updated", "updated"), ("user_records", "user_records"), ), From 4e60765f9fc8ff0e54d9e49b752f6f3311c75b0f Mon Sep 17 00:00:00 2001 From: Cody Maffucci <46459665+Maffooch@users.noreply.github.com> Date: Mon, 12 Jan 2026 13:04:34 -0700 Subject: [PATCH 28/36] Update AssetSerializer fields to allow null values and set defaults --- dojo/asset/api/serializers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dojo/asset/api/serializers.py b/dojo/asset/api/serializers.py index 70b639345a6..4b3bd0e9d5d 100644 --- a/dojo/asset/api/serializers.py +++ b/dojo/asset/api/serializers.py @@ -37,8 +37,8 @@ class AssetSerializer(serializers.ModelSerializer): # V3 fields asset_meta = ProductMetaSerializer(source="product_meta", read_only=True, many=True) organization = RelatedOrganizationField(source="prod_type") - asset_numeric_grade = serializers.IntegerField(source="prod_numeric_grade") - enable_asset_tag_inheritance = serializers.BooleanField(source="enable_product_tag_inheritance") + asset_numeric_grade = serializers.IntegerField(source="prod_numeric_grade", required=False, allow_null=True) + enable_asset_tag_inheritance = serializers.BooleanField(source="enable_product_tag_inheritance", required=False, default=False) asset_managers = serializers.PrimaryKeyRelatedField( source="product_manager", queryset=Dojo_User.objects.exclude(is_active=False), From cf5c84c1af70600cfd1d73cd371f59fcb4528701 Mon Sep 17 00:00:00 2001 From: Cody Maffucci <46459665+Maffooch@users.noreply.github.com> Date: Mon, 12 Jan 2026 13:04:44 -0700 Subject: [PATCH 29/36] Refactor authorization functions to use type hints for better clarity and maintainability --- dojo/authorization/authorization.py | 44 +++++++++++++++-------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/dojo/authorization/authorization.py b/dojo/authorization/authorization.py index b410bb3a95d..840eeb7ea35 100644 --- a/dojo/authorization/authorization.py +++ b/dojo/authorization/authorization.py @@ -1,4 +1,5 @@ from django.core.exceptions import PermissionDenied +from django.db.models import Model, QuerySet from dojo.authorization.roles_permissions import ( Permissions, @@ -11,6 
+12,7 @@ Cred_Mapping, Dojo_Group, Dojo_Group_Member, + Dojo_User, Endpoint, Engagement, Finding, @@ -30,7 +32,7 @@ from dojo.request_cache import cache_for_request -def user_has_configuration_permission(user, permission): +def user_has_configuration_permission(user: Dojo_User, permission: str): if not user: return False @@ -40,7 +42,7 @@ def user_has_configuration_permission(user, permission): return user.has_perm(permission) -def user_is_superuser_or_global_owner(user): +def user_is_superuser_or_global_owner(user: Dojo_User) -> bool: """ Returns True if the user is a superuser or has a global role (directly or via group membership) whose Role.is_owner is True. @@ -69,7 +71,7 @@ def user_is_superuser_or_global_owner(user): return False -def user_has_permission(user, obj, permission): +def user_has_permission(user: Dojo_User, obj: Model, permission: int) -> bool: if user.is_anonymous: return False @@ -229,7 +231,7 @@ def user_has_permission(user, obj, permission): raise NoAuthorizationImplementedError(msg) -def user_has_global_permission(user, permission): +def user_has_global_permission(user: Dojo_User, permission: int) -> bool: if not user: return False @@ -263,22 +265,22 @@ def user_has_global_permission(user, permission): return False -def user_has_configuration_permission_or_403(user, permission): +def user_has_configuration_permission_or_403(user: Dojo_User, permission: str) -> None: if not user_has_configuration_permission(user, permission): raise PermissionDenied -def user_has_permission_or_403(user, obj, permission): +def user_has_permission_or_403(user: Dojo_User, obj: Model, permission: int) -> None: if not user_has_permission(user, obj, permission): raise PermissionDenied -def user_has_global_permission_or_403(user, permission): +def user_has_global_permission_or_403(user: Dojo_User, permission: int) -> None: if not user_has_global_permission(user, permission): raise PermissionDenied -def get_roles_for_permission(permission): +def get_roles_for_permission(permission: int) -> set[int]: if not Permissions.has_value(permission): msg = f"Permission {permission} does not exist" raise PermissionDoesNotExistError(msg) @@ -291,7 +293,7 @@ def get_roles_for_permission(permission): return roles_for_permissions -def role_has_permission(role, permission): +def role_has_permission(role: int, permission: int) -> bool: if role is None: return False if not Roles.has_value(role): @@ -304,7 +306,7 @@ def role_has_permission(role, permission): return permission in permissions -def role_has_global_permission(role, permission): +def role_has_global_permission(role: int, permission: int) -> bool: if role is None: return False if not Roles.has_value(role): @@ -332,12 +334,12 @@ def __init__(self, message): self.message = message -def get_product_member(user, product): +def get_product_member(user: Dojo_User, product: Product) -> Product_Member | None: return get_product_member_dict(user).get(product.id) @cache_for_request -def get_product_member_dict(user): +def get_product_member_dict(user: Dojo_User) -> dict[int, Product_Member]: pm_dict = {} for product_member in ( Product_Member.objects.select_related("product") @@ -348,12 +350,12 @@ def get_product_member_dict(user): return pm_dict -def get_product_type_member(user, product_type): +def get_product_type_member(user: Dojo_User, product_type: Product_Type) -> Product_Type_Member | None: return get_product_type_member_dict(user).get(product_type.id) @cache_for_request -def get_product_type_member_dict(user): +def get_product_type_member_dict(user: 
Dojo_User) -> dict[int, Product_Type_Member]: ptm_dict = {} for product_type_member in ( Product_Type_Member.objects.select_related("product_type") @@ -364,12 +366,12 @@ def get_product_type_member_dict(user): return ptm_dict -def get_product_groups(user, product): +def get_product_groups(user: Dojo_User, product: Product) -> list[Product_Group]: return get_product_groups_dict(user).get(product.id, []) @cache_for_request -def get_product_groups_dict(user): +def get_product_groups_dict(user: Dojo_User) -> dict[int, list[Product_Group]]: pg_dict = {} for product_group in ( Product_Group.objects.select_related("product") @@ -382,12 +384,12 @@ def get_product_groups_dict(user): return pg_dict -def get_product_type_groups(user, product_type): +def get_product_type_groups(user: Dojo_User, product_type: Product_Type) -> list[Product_Type_Group]: return get_product_type_groups_dict(user).get(product_type.id, []) @cache_for_request -def get_product_type_groups_dict(user): +def get_product_type_groups_dict(user: Dojo_User) -> dict[int, list[Product_Type_Group]]: pgt_dict = {} for product_type_group in ( Product_Type_Group.objects.select_related("product_type") @@ -404,16 +406,16 @@ def get_product_type_groups_dict(user): @cache_for_request -def get_groups(user): +def get_groups(user: Dojo_User) -> QuerySet[Dojo_Group]: return Dojo_Group.objects.select_related("global_role").filter(users=user) -def get_group_member(user, group): +def get_group_member(user: Dojo_User, group: Dojo_Group) -> dict[int, Dojo_Group_Member]: return get_group_members_dict(user).get(group.id) @cache_for_request -def get_group_members_dict(user): +def get_group_members_dict(user: Dojo_User) -> dict[int, Dojo_Group_Member]: gu_dict = {} for group_member in ( Dojo_Group_Member.objects.select_related("group") From 214186dd987674ecc96350dd7cfc418eb75fabea Mon Sep 17 00:00:00 2001 From: Cody Maffucci <46459665+Maffooch@users.noreply.github.com> Date: Mon, 12 Jan 2026 13:04:52 -0700 Subject: [PATCH 30/36] Enhance permission checks to support multiple primary key attributes in post requests --- dojo/api_v2/permissions.py | 65 +++++++++++++++++++++++--------------- 1 file changed, 39 insertions(+), 26 deletions(-) diff --git a/dojo/api_v2/permissions.py b/dojo/api_v2/permissions.py index 421fb87b526..166ff05ec41 100644 --- a/dojo/api_v2/permissions.py +++ b/dojo/api_v2/permissions.py @@ -1,5 +1,7 @@ import re +from collections.abc import Iterable +from django.db.models import Model from django.shortcuts import get_object_or_404 from rest_framework import permissions, serializers from rest_framework.exceptions import ( @@ -7,6 +9,7 @@ PermissionDenied, ValidationError, ) +from rest_framework.request import Request from dojo.authorization.authorization import ( user_has_configuration_permission, @@ -29,24 +32,34 @@ ) -def check_post_permission(request, post_model, post_pk, post_permission): +def check_post_permission(request: Request, post_model: Model, post_pk: str | Iterable[str], post_permission: int) -> bool: if request.method == "POST": - if request.data.get(post_pk) is None: - msg = f"Unable to check for permissions: Attribute '{post_pk}' is required" + eligible_post_pk = None + # Support both single PK string and list of PK strings + searchable_post_pks = post_pk if isinstance(post_pk, Iterable) else [post_pk] + # Iterate until we find a matching PK in the request data + for pk in searchable_post_pks: + if request.data.get(pk) is not None: + eligible_post_pk = pk + break + # Raise an error if we never find anything + if 
eligible_post_pk is None: + msg = f"Unable to check for permissions: No valid attribute in '{post_pk}' is required" raise ParseError(msg) - obj = get_object_or_404(post_model, pk=request.data.get(post_pk)) + # Attempt to get the object + obj = get_object_or_404(post_model, pk=request.data.get(eligible_post_pk)) return user_has_permission(request.user, obj, post_permission) return True def check_object_permission( - request, - obj, - get_permission, - put_permission, - delete_permission, - post_permission=None, -): + request: Request, + obj: Model, + get_permission: int, + put_permission: int, + delete_permission: int, + post_permission: int | None = None, +) -> bool: if request.method == "GET": return user_has_permission(request.user, obj, get_permission) if request.method in {"PUT", "PATCH"}: @@ -61,7 +74,7 @@ def check_object_permission( class UserHasAppAnalysisPermission(permissions.BasePermission): def has_permission(self, request, view): return check_post_permission( - request, Product, "product", Permissions.Technology_Add, + request, Product, ["product", "asset"], Permissions.Technology_Add, ) def has_object_permission(self, request, view, obj): @@ -78,7 +91,7 @@ class UserHasCredentialPermission(permissions.BasePermission): def has_permission(self, request, view): if request.data.get("product") is not None: return check_post_permission( - request, Cred_Mapping, "product", Permissions.Credential_Add, + request, Cred_Mapping, ["product", "asset"], Permissions.Credential_Add, ) if request.data.get("engagement") is not None: return check_post_permission( @@ -93,7 +106,7 @@ def has_permission(self, request, view): request, Cred_Mapping, "finding", Permissions.Credential_Add, ) return check_post_permission( - request, Cred_Mapping, "product", Permissions.Credential_Add, + request, Cred_Mapping, ["product", "asset"], Permissions.Credential_Add, ) def has_object_permission(self, request, view, obj): @@ -231,7 +244,7 @@ def has_object_permission(self, request, view, obj): class UserHasToolProductSettingsPermission(permissions.BasePermission): def has_permission(self, request, view): return check_post_permission( - request, Product, "product", Permissions.Product_Edit, + request, Product, ["product", "asset"], Permissions.Product_Edit, ) def has_object_permission(self, request, view, obj): @@ -247,7 +260,7 @@ def has_object_permission(self, request, view, obj): class UserHasEndpointPermission(permissions.BasePermission): def has_permission(self, request, view): return check_post_permission( - request, Product, "product", Permissions.Endpoint_Add, + request, Product, ["product", "asset"], Permissions.Endpoint_Add, ) def has_object_permission(self, request, view, obj): @@ -287,7 +300,7 @@ def has_permission(self, request, view): request.path, ) or UserHasEngagementPermission.path_engagement.match(request.path): return check_post_permission( - request, Product, "product", Permissions.Engagement_Add, + request, Product, ["product", "asset"], Permissions.Engagement_Add, ) # related object only need object permission return True @@ -326,7 +339,7 @@ def has_permission(self, request, view): request.path, ): return check_post_permission( - request, Product, "product", Permissions.Risk_Acceptance, + request, Product, ["product", "asset"], Permissions.Risk_Acceptance, ) # related object only need object permission return True @@ -493,7 +506,7 @@ def has_permission(self, request, view): return check_post_permission( request, Product_Type, - "prod_type", + ["prod_type", "organization"], 
Permissions.Product_Type_Add_Product, ) @@ -510,7 +523,7 @@ def has_object_permission(self, request, view, obj): class UserHasProductMemberPermission(permissions.BasePermission): def has_permission(self, request, view): return check_post_permission( - request, Product, "product", Permissions.Product_Manage_Members, + request, Product, ["product", "asset"], Permissions.Product_Manage_Members, ) def has_object_permission(self, request, view, obj): @@ -526,7 +539,7 @@ def has_object_permission(self, request, view, obj): class UserHasProductGroupPermission(permissions.BasePermission): def has_permission(self, request, view): return check_post_permission( - request, Product, "product", Permissions.Product_Group_Add, + request, Product, ["product", "asset"], Permissions.Product_Group_Add, ) def has_object_permission(self, request, view, obj): @@ -562,7 +575,7 @@ def has_permission(self, request, view): return check_post_permission( request, Product_Type, - "product_type", + ["product_type", "organization"], Permissions.Product_Type_Manage_Members, ) @@ -581,7 +594,7 @@ def has_permission(self, request, view): return check_post_permission( request, Product_Type, - "product_type", + ["product_type", "organization"], Permissions.Product_Type_Group_Add, ) @@ -707,7 +720,7 @@ def has_object_permission(self, request, view, obj): class UserHasLanguagePermission(permissions.BasePermission): def has_permission(self, request, view): return check_post_permission( - request, Product, "product", Permissions.Language_Add, + request, Product, ["product", "asset"], Permissions.Language_Add, ) def has_object_permission(self, request, view, obj): @@ -725,7 +738,7 @@ def has_permission(self, request, view): return check_post_permission( request, Product, - "product", + ["product", "asset"], Permissions.Product_API_Scan_Configuration_Add, ) @@ -881,7 +894,7 @@ def has_permission(self, request, view): class UserHasEngagementPresetPermission(permissions.BasePermission): def has_permission(self, request, view): return check_post_permission( - request, Product, "product", Permissions.Product_Edit, + request, Product, ["product", "asset"], Permissions.Product_Edit, ) def has_object_permission(self, request, view, obj): From 08d488b5ace6484239c05db04d5260d8a94bb04e Mon Sep 17 00:00:00 2001 From: Cody Maffucci <46459665+Maffooch@users.noreply.github.com> Date: Mon, 12 Jan 2026 14:10:46 -0700 Subject: [PATCH 31/36] Refactor check_post_permission to use list type for post_pk parameter --- dojo/api_v2/permissions.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/dojo/api_v2/permissions.py b/dojo/api_v2/permissions.py index 166ff05ec41..80611d262e8 100644 --- a/dojo/api_v2/permissions.py +++ b/dojo/api_v2/permissions.py @@ -1,5 +1,4 @@ import re -from collections.abc import Iterable from django.db.models import Model from django.shortcuts import get_object_or_404 @@ -32,11 +31,11 @@ ) -def check_post_permission(request: Request, post_model: Model, post_pk: str | Iterable[str], post_permission: int) -> bool: +def check_post_permission(request: Request, post_model: Model, post_pk: str | list[str], post_permission: int) -> bool: if request.method == "POST": eligible_post_pk = None # Support both single PK string and list of PK strings - searchable_post_pks = post_pk if isinstance(post_pk, Iterable) else [post_pk] + searchable_post_pks = post_pk if isinstance(post_pk, list) else [post_pk] # Iterate until we find a matching PK in the request data for pk in searchable_post_pks: if request.data.get(pk) is not 
None: From 34359c904a76c7b4802f7bb5e454f981a89c13b5 Mon Sep 17 00:00:00 2001 From: Cody Maffucci <46459665+Maffooch@users.noreply.github.com> Date: Mon, 12 Jan 2026 14:11:03 -0700 Subject: [PATCH 32/36] Refactor Organization serializers to handle default values for critical and key assets, and update OrganizationViewSet to use OrganizationFilterSet for filtering. --- dojo/organization/api/serializers.py | 6 +++--- dojo/organization/api/views.py | 13 +++++-------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/dojo/organization/api/serializers.py b/dojo/organization/api/serializers.py index d624c72524d..73eb68e0294 100644 --- a/dojo/organization/api/serializers.py +++ b/dojo/organization/api/serializers.py @@ -51,7 +51,7 @@ def validate(self, data): if self.instance is not None and not data.get("role").is_owner: owners = ( Product_Type_Member.objects.filter( - product_type=data.get("organization"), role__is_owner=True, + product_type=data.get("organization", data.get("product_type")), role__is_owner=True, ) .exclude(id=self.instance.id) .count() @@ -115,8 +115,8 @@ def validate(self, data): class OrganizationSerializer(serializers.ModelSerializer): - critical_asset = serializers.BooleanField(source="critical_product") - key_asset = serializers.BooleanField(source="key_product") + critical_asset = serializers.BooleanField(source="critical_product", default=False) + key_asset = serializers.BooleanField(source="key_product", default=False) class Meta: model = Product_Type diff --git a/dojo/organization/api/views.py b/dojo/organization/api/views.py index dc9f3fc0cc2..ea3aa9005d9 100644 --- a/dojo/organization/api/views.py +++ b/dojo/organization/api/views.py @@ -17,6 +17,7 @@ ) from dojo.organization.api import serializers from dojo.organization.api.filters import ( + OrganizationFilterSet, OrganizationGroupFilterSet, OrganizationMemberFilterSet, ) @@ -36,14 +37,7 @@ class OrganizationViewSet( serializer_class = serializers.OrganizationSerializer queryset = Product_Type.objects.none() filter_backends = (DjangoFilterBackend,) - filterset_fields = [ - "id", - "name", - "critical_product", - "key_product", - "created", - "updated", - ] + filterset_class = OrganizationFilterSet permission_classes = ( IsAuthenticated, permissions.UserHasProductTypePermission, @@ -60,6 +54,9 @@ def perform_create(self, serializer): product_type_data = serializer.data product_type_data.pop("authorization_groups") product_type_data.pop("members") + # Manage custom fields separately with default fields of false + product_type_data["critical_product"] = product_type_data.pop("critical_asset", False) + product_type_data["key_product"] = product_type_data.pop("key_asset", False) member = Product_Type_Member() member.user = self.request.user member.product_type = Product_Type(**product_type_data) From 7ca4e7c350637bc2de7131eedf9c69aa1e560af0 Mon Sep 17 00:00:00 2001 From: Cody Maffucci <46459665+Maffooch@users.noreply.github.com> Date: Mon, 12 Jan 2026 14:11:09 -0700 Subject: [PATCH 33/36] Refactor API tests to include asset and organization endpoints, enhancing coverage for asset-related functionalities. 
--- unittests/test_apiv2_methods_and_endpoints.py | 19 +- unittests/test_rest_framework.py | 194 ++++++++++++++++++ 2 files changed, 209 insertions(+), 4 deletions(-) diff --git a/unittests/test_apiv2_methods_and_endpoints.py b/unittests/test_apiv2_methods_and_endpoints.py index 3ca9f05cd24..38e40f96b89 100644 --- a/unittests/test_apiv2_methods_and_endpoints.py +++ b/unittests/test_apiv2_methods_and_endpoints.py @@ -49,15 +49,18 @@ def test_is_defined(self): "configuration_permissions", "questionnaire_questions", "questionnaire_answers", "questionnaire_answered_questionnaires", "questionnaire_engagement_questionnaires", "questionnaire_general_questionnaires", - "dojo_group_members", "product_members", "product_groups", "product_type_groups", - "product_type_members", "asset_members", "asset_groups", "organization_groups", - "organization_members", # pghistory Event models (should not be exposed via API) "dojo_userevents", "endpointevents", "engagementevents", "findingevents", "finding_groupevents", "product_typeevents", "productevents", "testevents", "risk_acceptanceevents", "finding_templateevents", "cred_userevents", "notification_webhooksevents", } + patch_exempt_list = { + "dojo_group_members", "product_members", "product_groups", "product_type_groups", + "product_type_members", "asset_members", "asset_groups", "organization_groups", + "organization_members", + } + for reg, _, _ in sorted(self.registry): if reg in exempt_list: continue @@ -67,7 +70,15 @@ def test_is_defined(self): f"Endpoint: {reg}, Method: {method}", ) - for method in ["get", "put", "patch", "delete"]: + for method in ["get", "put", "delete"]: + self.assertIsNotNone( + self.schema["paths"][f"{BASE_API_URL}/{reg}" + "/{id}/"].get(method), + f"Endpoint: {reg}, Method: {method}", + ) + + for method in ["patch"]: + if reg in patch_exempt_list: + continue self.assertIsNotNone( self.schema["paths"][f"{BASE_API_URL}/{reg}" + "/{id}/"].get(method), f"Endpoint: {reg}, Method: {method}", diff --git a/unittests/test_rest_framework.py b/unittests/test_rest_framework.py index 89969a09321..a7b9f898ecf 100644 --- a/unittests/test_rest_framework.py +++ b/unittests/test_rest_framework.py @@ -86,6 +86,12 @@ UserContactInfoViewSet, UsersViewSet, ) +from dojo.asset.api.views import ( + AssetAPIScanConfigurationViewSet, + AssetGroupViewSet, + AssetMemberViewSet, + AssetViewSet, +) from dojo.authorization.roles_permissions import Permissions from dojo.models import ( Announcement, @@ -140,6 +146,11 @@ User, UserContactInfo, ) +from dojo.organization.api.views import ( + OrganizationGroupViewSet, + OrganizationMemberViewSet, + OrganizationViewSet, +) from .dojo_test_case import DojoAPITestCase, get_unit_tests_scans_path @@ -1905,6 +1916,29 @@ def __init__(self, *args, **kwargs): BaseClass.RESTEndpointTest.__init__(self, *args, **kwargs) +class Asset_API_Scan_ConfigurationTest(BaseClass.BaseClassTest): + fixtures = ["dojo_testdata.json"] + + def __init__(self, *args, **kwargs): + self.endpoint_model = Product_API_Scan_Configuration + self.endpoint_path = "asset_api_scan_configurations" + self.viewname = "asset_api_scan_configuration" + self.viewset = AssetAPIScanConfigurationViewSet + self.payload = { + "asset": 2, + "service_key_1": "dojo_sonar_key", + "tool_configuration": 3, + } + self.update_fields = {"tool_configuration": 2} + self.test_type = TestType.OBJECT_PERMISSIONS + self.permission_check_class = Product_API_Scan_Configuration + self.permission_create = Permissions.Product_API_Scan_Configuration_Add + self.permission_update = 
Permissions.Product_API_Scan_Configuration_Edit + self.permission_delete = Permissions.Product_API_Scan_Configuration_Delete + self.deleted_objects = 1 + BaseClass.RESTEndpointTest.__init__(self, *args, **kwargs) + + class ProductTest(BaseClass.BaseClassTest): fixtures = ["dojo_testdata.json"] @@ -1932,6 +1966,33 @@ def __init__(self, *args, **kwargs): BaseClass.RESTEndpointTest.__init__(self, *args, **kwargs) +class AssetTest(BaseClass.BaseClassTest): + fixtures = ["dojo_testdata.json"] + + def __init__(self, *args, **kwargs): + self.endpoint_model = Product + self.endpoint_path = "assets" + self.viewname = "asset" + self.viewset = AssetViewSet + self.payload = { + "product_manager": 2, + "technical_contact": 3, + "team_manager": 2, + "organization": 1, + "name": "Test Product", + "description": "test product", + "tags": ["mytag", "yourtag"], + } + self.update_fields = {"organization": 2} + self.test_type = TestType.OBJECT_PERMISSIONS + self.permission_check_class = Product + self.permission_create = Permissions.Product_Type_Add_Product + self.permission_update = Permissions.Product_Edit + self.permission_delete = Permissions.Product_Delete + self.deleted_objects = 25 + BaseClass.RESTEndpointTest.__init__(self, *args, **kwargs) + + class StubFindingsTest(BaseClass.BaseClassTest): fixtures = ["dojo_testdata.json"] @@ -2873,6 +2934,47 @@ def test_create_authorized_owner(self): self.assertEqual(201, response.status_code, response.content[:1000]) +class OrganizationTest(BaseClass.BaseClassTest): + fixtures = ["dojo_testdata.json"] + + def __init__(self, *args, **kwargs): + self.endpoint_model = Product_Type + self.endpoint_path = "organizations" + self.viewname = "organization" + self.viewset = OrganizationViewSet + self.payload = { + "name": "Test Organization", + "description": "Test", + "key_product": True, + "critical_product": False, + } + self.update_fields = {"description": "changed"} + self.test_type = TestType.OBJECT_PERMISSIONS + self.permission_check_class = Product_Type + self.permission_update = Permissions.Product_Type_Edit + self.permission_delete = Permissions.Product_Type_Delete + self.deleted_objects = 25 + BaseClass.RESTEndpointTest.__init__(self, *args, **kwargs) + + def test_create_object_not_authorized(self): + self.setUp_not_authorized() + + response = self.client.post(self.url, self.payload) + self.assertEqual(403, response.status_code, response.content[:1000]) + + def test_create_not_authorized_reader(self): + self.setUp_global_reader() + + response = self.client.post(self.url, self.payload) + self.assertEqual(403, response.status_code, response.content[:1000]) + + def test_create_authorized_owner(self): + self.setUp_global_owner() + + response = self.client.post(self.url, self.payload) + self.assertEqual(201, response.status_code, response.content[:1000]) + + class DojoGroupsTest(BaseClass.BaseClassTest): fixtures = ["dojo_testdata.json"] @@ -3016,6 +3118,29 @@ def __init__(self, *args, **kwargs): BaseClass.RESTEndpointTest.__init__(self, *args, **kwargs) +class OrganizationMemberTest(BaseClass.MemberEndpointTest): + fixtures = ["dojo_testdata.json"] + + def __init__(self, *args, **kwargs): + self.endpoint_model = Product_Type_Member + self.endpoint_path = "organization_members" + self.viewname = "organization_member" + self.viewset = OrganizationMemberViewSet + self.payload = { + "organization": 1, + "user": 3, + "role": 2, + } + self.update_fields = {"role": 3} + self.test_type = TestType.OBJECT_PERMISSIONS + self.permission_check_class = Product_Type_Member + 
self.permission_create = Permissions.Product_Type_Manage_Members
+        self.permission_update = Permissions.Product_Type_Manage_Members
+        self.permission_delete = Permissions.Product_Type_Member_Delete
+        self.deleted_objects = 1
+        BaseClass.RESTEndpointTest.__init__(self, *args, **kwargs)
+
+
 class ProductMemberTest(BaseClass.MemberEndpointTest):
     fixtures = ["dojo_testdata.json"]
 
@@ -3039,6 +3164,29 @@ def __init__(self, *args, **kwargs):
         BaseClass.RESTEndpointTest.__init__(self, *args, **kwargs)
 
 
+class AssetMemberTest(BaseClass.MemberEndpointTest):
+    fixtures = ["dojo_testdata.json"]
+
+    def __init__(self, *args, **kwargs):
+        self.endpoint_model = Product_Member
+        self.endpoint_path = "asset_members"
+        self.viewname = "asset_member"
+        self.viewset = AssetMemberViewSet
+        self.payload = {
+            "asset": 3,
+            "user": 2,
+            "role": 2,
+        }
+        self.update_fields = {"role": 3}
+        self.test_type = TestType.OBJECT_PERMISSIONS
+        self.permission_check_class = Product_Member
+        self.permission_create = Permissions.Product_Manage_Members
+        self.permission_update = Permissions.Product_Manage_Members
+        self.permission_delete = Permissions.Product_Member_Delete
+        self.deleted_objects = 1
+        BaseClass.RESTEndpointTest.__init__(self, *args, **kwargs)
+
+
 class ProductTypeGroupTest(BaseClass.MemberEndpointTest):
     fixtures = ["dojo_testdata.json"]
 
@@ -3062,6 +3210,29 @@ def __init__(self, *args, **kwargs):
         BaseClass.RESTEndpointTest.__init__(self, *args, **kwargs)
 
 
+class OrganizationGroupTest(BaseClass.MemberEndpointTest):
+    fixtures = ["dojo_testdata.json"]
+
+    def __init__(self, *args, **kwargs):
+        self.endpoint_model = Product_Type_Group
+        self.endpoint_path = "organization_groups"
+        self.viewname = "organization_group"
+        self.viewset = OrganizationGroupViewSet
+        self.payload = {
+            "organization": 1,
+            "group": 2,
+            "role": 2,
+        }
+        self.update_fields = {"role": 3}
+        self.test_type = TestType.OBJECT_PERMISSIONS
+        self.permission_check_class = Product_Type_Group
+        self.permission_create = Permissions.Product_Type_Group_Add
+        self.permission_update = Permissions.Product_Type_Group_Edit
+        self.permission_delete = Permissions.Product_Type_Group_Delete
+        self.deleted_objects = 1
+        BaseClass.RESTEndpointTest.__init__(self, *args, **kwargs)
+
+
 class ProductGroupTest(BaseClass.MemberEndpointTest):
     fixtures = ["dojo_testdata.json"]
 
@@ -3085,6 +3256,29 @@ def __init__(self, *args, **kwargs):
         BaseClass.RESTEndpointTest.__init__(self, *args, **kwargs)
 
 
+class AssetGroupTest(BaseClass.MemberEndpointTest):
+    fixtures = ["dojo_testdata.json"]
+
+    def __init__(self, *args, **kwargs):
+        self.endpoint_model = Product_Group
+        self.endpoint_path = "asset_groups"
+        self.viewname = "asset_group"
+        self.viewset = AssetGroupViewSet
+        self.payload = {
+            "asset": 1,
+            "group": 2,
+            "role": 2,
+        }
+        self.update_fields = {"role": 3}
+        self.test_type = TestType.OBJECT_PERMISSIONS
+        self.permission_check_class = Product_Group
+        self.permission_create = Permissions.Product_Group_Add
+        self.permission_update = Permissions.Product_Group_Edit
+        self.permission_delete = Permissions.Product_Group_Delete
+        self.deleted_objects = 1
+        BaseClass.RESTEndpointTest.__init__(self, *args, **kwargs)
+
+
 class LanguageTypeTest(BaseClass.BaseClassTest):
     fixtures = ["dojo_testdata.json"]
 

From 40808f0aaa0fdad33b7ef3140a5e3cd60e8f15ab Mon Sep 17 00:00:00 2001
From: Cody Maffucci <46459665+Maffooch@users.noreply.github.com>
Date: Mon, 12 Jan 2026 14:33:22 -0700
Subject: [PATCH 34/36] Refactor permission classes to use asset and organization-specific permissions, 
enhancing clarity and maintainability. --- dojo/api_v2/permissions.py | 171 +++++++++++++++++++++++++++------ dojo/asset/api/views.py | 8 +- dojo/organization/api/views.py | 6 +- 3 files changed, 150 insertions(+), 35 deletions(-) diff --git a/dojo/api_v2/permissions.py b/dojo/api_v2/permissions.py index 80611d262e8..9ba08bc0b0e 100644 --- a/dojo/api_v2/permissions.py +++ b/dojo/api_v2/permissions.py @@ -33,20 +33,10 @@ def check_post_permission(request: Request, post_model: Model, post_pk: str | list[str], post_permission: int) -> bool: if request.method == "POST": - eligible_post_pk = None - # Support both single PK string and list of PK strings - searchable_post_pks = post_pk if isinstance(post_pk, list) else [post_pk] - # Iterate until we find a matching PK in the request data - for pk in searchable_post_pks: - if request.data.get(pk) is not None: - eligible_post_pk = pk - break - # Raise an error if we never find anything - if eligible_post_pk is None: - msg = f"Unable to check for permissions: No valid attribute in '{post_pk}' is required" + if request.data.get(post_pk) is None: + msg = f"Unable to check for permissions: Attribute '{post_pk}' is required" raise ParseError(msg) - # Attempt to get the object - obj = get_object_or_404(post_model, pk=request.data.get(eligible_post_pk)) + obj = get_object_or_404(post_model, pk=request.data.get(post_pk)) return user_has_permission(request.user, obj, post_permission) return True @@ -73,7 +63,7 @@ def check_object_permission( class UserHasAppAnalysisPermission(permissions.BasePermission): def has_permission(self, request, view): return check_post_permission( - request, Product, ["product", "asset"], Permissions.Technology_Add, + request, Product, "product", Permissions.Technology_Add, ) def has_object_permission(self, request, view, obj): @@ -90,7 +80,7 @@ class UserHasCredentialPermission(permissions.BasePermission): def has_permission(self, request, view): if request.data.get("product") is not None: return check_post_permission( - request, Cred_Mapping, ["product", "asset"], Permissions.Credential_Add, + request, Cred_Mapping, "product", Permissions.Credential_Add, ) if request.data.get("engagement") is not None: return check_post_permission( @@ -105,7 +95,7 @@ def has_permission(self, request, view): request, Cred_Mapping, "finding", Permissions.Credential_Add, ) return check_post_permission( - request, Cred_Mapping, ["product", "asset"], Permissions.Credential_Add, + request, Cred_Mapping, "product", Permissions.Credential_Add, ) def has_object_permission(self, request, view, obj): @@ -243,7 +233,7 @@ def has_object_permission(self, request, view, obj): class UserHasToolProductSettingsPermission(permissions.BasePermission): def has_permission(self, request, view): return check_post_permission( - request, Product, ["product", "asset"], Permissions.Product_Edit, + request, Product, "product", Permissions.Product_Edit, ) def has_object_permission(self, request, view, obj): @@ -259,7 +249,7 @@ def has_object_permission(self, request, view, obj): class UserHasEndpointPermission(permissions.BasePermission): def has_permission(self, request, view): return check_post_permission( - request, Product, ["product", "asset"], Permissions.Endpoint_Add, + request, Product, "product", Permissions.Endpoint_Add, ) def has_object_permission(self, request, view, obj): @@ -299,7 +289,7 @@ def has_permission(self, request, view): request.path, ) or UserHasEngagementPermission.path_engagement.match(request.path): return check_post_permission( - request, 
Product, ["product", "asset"], Permissions.Engagement_Add, + request, Product, "product", Permissions.Engagement_Add, ) # related object only need object permission return True @@ -338,7 +328,7 @@ def has_permission(self, request, view): request.path, ): return check_post_permission( - request, Product, ["product", "asset"], Permissions.Risk_Acceptance, + request, Product, "product", Permissions.Risk_Acceptance, ) # related object only need object permission return True @@ -505,7 +495,26 @@ def has_permission(self, request, view): return check_post_permission( request, Product_Type, - ["prod_type", "organization"], + "prod_type", + Permissions.Product_Type_Add_Product, + ) + + def has_object_permission(self, request, view, obj): + return check_object_permission( + request, + obj, + Permissions.Product_View, + Permissions.Product_Edit, + Permissions.Product_Delete, + ) + + +class UserHasAssetPermission(permissions.BasePermission): + def has_permission(self, request, view): + return check_post_permission( + request, + Product_Type, + "organization", Permissions.Product_Type_Add_Product, ) @@ -522,7 +531,23 @@ def has_object_permission(self, request, view, obj): class UserHasProductMemberPermission(permissions.BasePermission): def has_permission(self, request, view): return check_post_permission( - request, Product, ["product", "asset"], Permissions.Product_Manage_Members, + request, Product, "product", Permissions.Product_Manage_Members, + ) + + def has_object_permission(self, request, view, obj): + return check_object_permission( + request, + obj, + Permissions.Product_View, + Permissions.Product_Manage_Members, + Permissions.Product_Member_Delete, + ) + + +class UserHasAssetMemberPermission(permissions.BasePermission): + def has_permission(self, request, view): + return check_post_permission( + request, Product, "asset", Permissions.Product_Manage_Members, ) def has_object_permission(self, request, view, obj): @@ -538,7 +563,23 @@ def has_object_permission(self, request, view, obj): class UserHasProductGroupPermission(permissions.BasePermission): def has_permission(self, request, view): return check_post_permission( - request, Product, ["product", "asset"], Permissions.Product_Group_Add, + request, Product, "product", Permissions.Product_Group_Add, + ) + + def has_object_permission(self, request, view, obj): + return check_object_permission( + request, + obj, + Permissions.Product_Group_View, + Permissions.Product_Group_Edit, + Permissions.Product_Group_Delete, + ) + + +class UserHasAssetGroupPermission(permissions.BasePermission): + def has_permission(self, request, view): + return check_post_permission( + request, Product, "asset", Permissions.Product_Group_Add, ) def has_object_permission(self, request, view, obj): @@ -569,12 +610,49 @@ def has_object_permission(self, request, view, obj): ) +class UserHasOrganizationPermission(permissions.BasePermission): + def has_permission(self, request, view): + if request.method == "POST": + return user_has_global_permission( + request.user, Permissions.Product_Type_Add, + ) + return True + + def has_object_permission(self, request, view, obj): + return check_object_permission( + request, + obj, + Permissions.Product_Type_View, + Permissions.Product_Type_Edit, + Permissions.Product_Type_Delete, + ) + + class UserHasProductTypeMemberPermission(permissions.BasePermission): def has_permission(self, request, view): return check_post_permission( request, Product_Type, - ["product_type", "organization"], + "product_type", + 
Permissions.Product_Type_Manage_Members, + ) + + def has_object_permission(self, request, view, obj): + return check_object_permission( + request, + obj, + Permissions.Product_Type_View, + Permissions.Product_Type_Manage_Members, + Permissions.Product_Type_Member_Delete, + ) + + +class UserHasOrganizationMemberPermission(permissions.BasePermission): + def has_permission(self, request, view): + return check_post_permission( + request, + Product_Type, + "organization", Permissions.Product_Type_Manage_Members, ) @@ -593,7 +671,25 @@ def has_permission(self, request, view): return check_post_permission( request, Product_Type, - ["product_type", "organization"], + "product_type", + Permissions.Product_Type_Group_Add, + ) + + def has_object_permission(self, request, view, obj): + return check_object_permission( + request, + obj, + Permissions.Product_Type_Group_View, + Permissions.Product_Type_Group_Edit, + Permissions.Product_Type_Group_Delete, + ) + +class UserHasOrganizationGroupPermission(permissions.BasePermission): + def has_permission(self, request, view): + return check_post_permission( + request, + Product_Type, + "organization", Permissions.Product_Type_Group_Add, ) @@ -719,7 +815,7 @@ def has_object_permission(self, request, view, obj): class UserHasLanguagePermission(permissions.BasePermission): def has_permission(self, request, view): return check_post_permission( - request, Product, ["product", "asset"], Permissions.Language_Add, + request, Product, "product", Permissions.Language_Add, ) def has_object_permission(self, request, view, obj): @@ -737,7 +833,26 @@ def has_permission(self, request, view): return check_post_permission( request, Product, - ["product", "asset"], + "product", + Permissions.Product_API_Scan_Configuration_Add, + ) + + def has_object_permission(self, request, view, obj): + return check_object_permission( + request, + obj, + Permissions.Product_API_Scan_Configuration_View, + Permissions.Product_API_Scan_Configuration_Edit, + Permissions.Product_API_Scan_Configuration_Delete, + ) + + +class UserHasAssetAPIScanConfigurationPermission(permissions.BasePermission): + def has_permission(self, request, view): + return check_post_permission( + request, + Product, + "asset", Permissions.Product_API_Scan_Configuration_Add, ) @@ -893,7 +1008,7 @@ def has_permission(self, request, view): class UserHasEngagementPresetPermission(permissions.BasePermission): def has_permission(self, request, view): return check_post_permission( - request, Product, ["product", "asset"], Permissions.Product_Edit, + request, Product, "product", Permissions.Product_Edit, ) def has_object_permission(self, request, view, obj): diff --git a/dojo/asset/api/views.py b/dojo/asset/api/views.py index d3a873f97da..0e01499c466 100644 --- a/dojo/asset/api/views.py +++ b/dojo/asset/api/views.py @@ -43,7 +43,7 @@ class AssetAPIScanConfigurationViewSet( filterset_class = AssetAPIScanConfigurationFilterSet permission_classes = ( IsAuthenticated, - permissions.UserHasProductAPIScanConfigurationPermission, + permissions.UserHasAssetAPIScanConfigurationPermission, ) def get_queryset(self): @@ -68,7 +68,7 @@ class AssetViewSet( filterset_class = ApiAssetFilter permission_classes = ( IsAuthenticated, - permissions.UserHasProductPermission, + permissions.UserHasAssetPermission, ) def get_queryset(self): @@ -138,7 +138,7 @@ class AssetMemberViewSet( filterset_class = AssetMemberFilterSet permission_classes = ( IsAuthenticated, - permissions.UserHasProductMemberPermission, + permissions.UserHasAssetMemberPermission, ) 
def get_queryset(self): @@ -166,7 +166,7 @@ class AssetGroupViewSet( filterset_class = AssetGroupFilterSet permission_classes = ( IsAuthenticated, - permissions.UserHasProductGroupPermission, + permissions.UserHasAssetGroupPermission, ) def get_queryset(self): diff --git a/dojo/organization/api/views.py b/dojo/organization/api/views.py index ea3aa9005d9..0cbbf561eaf 100644 --- a/dojo/organization/api/views.py +++ b/dojo/organization/api/views.py @@ -40,7 +40,7 @@ class OrganizationViewSet( filterset_class = OrganizationFilterSet permission_classes = ( IsAuthenticated, - permissions.UserHasProductTypePermission, + permissions.UserHasOrganizationPermission, ) def get_queryset(self): @@ -121,7 +121,7 @@ class OrganizationMemberViewSet( filterset_class = OrganizationMemberFilterSet permission_classes = ( IsAuthenticated, - permissions.UserHasProductTypeMemberPermission, + permissions.UserHasOrganizationMemberPermission, ) def get_queryset(self): @@ -163,7 +163,7 @@ class OrganizationGroupViewSet( filterset_class = OrganizationGroupFilterSet permission_classes = ( IsAuthenticated, - permissions.UserHasProductTypeGroupPermission, + permissions.UserHasOrganizationGroupPermission, ) def get_queryset(self): From e4da98af98b8c8022cfe94e03268384a4d27442a Mon Sep 17 00:00:00 2001 From: Cody Maffucci <46459665+Maffooch@users.noreply.github.com> Date: Mon, 12 Jan 2026 14:34:23 -0700 Subject: [PATCH 35/36] Add blank line before UserHasOrganizationGroupPermission class for improved readability --- dojo/api_v2/permissions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dojo/api_v2/permissions.py b/dojo/api_v2/permissions.py index 9ba08bc0b0e..905ddf99b58 100644 --- a/dojo/api_v2/permissions.py +++ b/dojo/api_v2/permissions.py @@ -684,6 +684,7 @@ def has_object_permission(self, request, view, obj): Permissions.Product_Type_Group_Delete, ) + class UserHasOrganizationGroupPermission(permissions.BasePermission): def has_permission(self, request, view): return check_post_permission( From 5f3b466d9cf1cfa0985c7f4e9d3c1666b579fd9b Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Sun, 25 Jan 2026 21:17:30 +0100 Subject: [PATCH 36/36] ruff --- dojo/importers/base_importer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dojo/importers/base_importer.py b/dojo/importers/base_importer.py index 009102c7769..8b82a6bfd45 100644 --- a/dojo/importers/base_importer.py +++ b/dojo/importers/base_importer.py @@ -12,9 +12,8 @@ from django.utils.timezone import make_aware import dojo.finding.helper as finding_helper -from dojo.celery_dispatch import dojo_dispatch_task import dojo.risk_acceptance.helper as ra_helper -from dojo import utils +from dojo.celery_dispatch import dojo_dispatch_task from dojo.importers.endpoint_manager import EndpointManager from dojo.importers.options import ImporterOptions from dojo.models import (