diff --git a/modules/machinery/az.py b/modules/machinery/az.py index a8f5edeed16..1f67b2b60d1 100644 --- a/modules/machinery/az.py +++ b/modules/machinery/az.py @@ -8,9 +8,10 @@ import threading import time import timeit -from typing import Optional, cast +from typing import Optional import sqlalchemy +from sqlalchemy import select # Cuckoo-specific imports from lib.cuckoo.common.config import Config @@ -608,30 +609,30 @@ def find_machine_to_service_task(self, task: Task) -> Optional[Machine]: task_archs, task_tags = self.db._task_arch_tags_helper(task) os_version = self.db._package_vm_requires_check(task.package) - def get_first_machine(query: sqlalchemy.orm.Query) -> Optional[Machine]: + def get_first_machine(statement: sqlalchemy.Select) -> Optional[Machine]: # Select for update a machine, preferring one that is available and was the one that was used the # longest time ago. This will give us a machine that can get locked or, if there are none that are # currently available, we'll at least know that the task is serviceable. - return cast( - Optional[Machine], query.order_by(Machine.locked, Machine.locked_changed_on).with_for_update(of=Machine).first() - ) + statement = statement.order_by(Machine.locked, Machine.locked_changed_on).with_for_update(of=Machine) + return self.db.session.scalars(statement).first() - machines = self.db.session.query(Machine).options(sqlalchemy.orm.joinedload(Machine.tags)) + machines_stmt = select(Machine).options(sqlalchemy.orm.joinedload(Machine.tags)) filter_kwargs = { - "statement": machines, + "statement": machines_stmt, "label": task.machine, "tags": task_tags, "archs": task_archs, "os_version": os_version, } - filtered_machines = self.db.filter_machines_to_task(include_reserved=False, **filter_kwargs) - machine = get_first_machine(filtered_machines) + + filtered_machines_stmt = self.db.filter_machines_to_task(include_reserved=False, **filter_kwargs) + machine = get_first_machine(filtered_machines_stmt) if machine is None and not task.machine and task_tags: # The task was given at least 1 tag, but there are no non-reserved machines # that could satisfy the request. So let's see if there are any "reserved" # machines that can satisfy it. - filtered_machines = self.db.filter_machines_to_task(include_reserved=True, **filter_kwargs) - machine = get_first_machine(filtered_machines) + filtered_machines_stmt = self.db.filter_machines_to_task(include_reserved=True, **filter_kwargs) + machine = get_first_machine(filtered_machines_stmt) if machine is None: self._scale_from_zero(task, os_version, task_tags) @@ -999,17 +1000,20 @@ def _thr_create_vmss(self, vmss_name, vmss_image_ref, vmss_image_os): "wait": False, } self.required_vmsss[vmss_name]["exists"] = True - try: - with self.db.session.begin(): - if machine_pools[vmss_name]["size"] == 0: - self._insert_placeholder_machine(vmss_name, self.required_vmsss[vmss_name]) - else: - self._add_machines_to_db(vmss_name) - except sqlalchemy.exc.InvalidRequestError: + with self.db.session.begin_nested() if self.db.session().in_transaction() else self.db.session.begin() as session: if machine_pools[vmss_name]["size"] == 0: self._insert_placeholder_machine(vmss_name, self.required_vmsss[vmss_name]) else: self._add_machines_to_db(vmss_name) + try: + session.commit() + except sqlalchemy.exc.InvalidRequestError: + session.rollback() + # Retry logic might be needed here if the session was already committed by an outer scope. + if machine_pools[vmss_name]["size"] == 0: + self._insert_placeholder_machine(vmss_name, self.required_vmsss[vmss_name]) + else: + self._add_machines_to_db(vmss_name) def _thr_reimage_vmss(self, vmss_name): """ @@ -1039,11 +1043,15 @@ def _thr_reimage_vmss(self, vmss_name): else: log.exception(repr(e)) raise - try: - with self.db.session.begin(): - self._add_machines_to_db(vmss_name) - except sqlalchemy.exc.InvalidRequestError: + with self.db.session.begin_nested() if self.db.session().in_transaction() else self.db.session.begin() as session: self._add_machines_to_db(vmss_name) + try: + session.commit() + except sqlalchemy.exc.InvalidRequestError: + session.rollback() + # Retry logic might be needed here if the session was already committed by an outer scope. + # For now, just call it again outside a transaction. + self._add_machines_to_db(vmss_name) def _thr_scale_machine_pool(self, tag, per_platform=False): """ @@ -1052,8 +1060,12 @@ def _thr_scale_machine_pool(self, tag, per_platform=False): @param per_platform: A boolean flag indicating that we should scale machine pools "per platform" vs. "per tag" @return: Ends method call """ - with self.db.session.begin(): - return self._scale_machine_pool(tag, per_platform=per_platform) + with self.db.session.begin_nested() if self.db.session().in_transaction() else self.db.session.begin() as session: + self._scale_machine_pool(tag, per_platform=per_platform) + try: + session.commit() + except sqlalchemy.exc.InvalidRequestError: + session.rollback() def _scale_machine_pool(self, tag, per_platform=False): global current_vmss_operations