Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .github/workflows/base-lambdas-reusable-deploy-all.yml
Original file line number Diff line number Diff line change
Expand Up @@ -796,3 +796,17 @@ jobs:
lambda_layer_names: "core_lambda_layer"
secrets:
AWS_ASSUME_ROLE: ${{ secrets.AWS_ASSUME_ROLE }}

deploy_concurrency_controller_lambda:
name: Deploy Concurrency Controller Lambda
uses: ./.github/workflows/base-lambdas-reusable-deploy.yml
with:
environment: ${{ inputs.environment }}
python_version: ${{ inputs.python_version }}
build_branch: ${{ inputs.build_branch }}
sandbox: ${{ inputs.sandbox }}
lambda_handler_name: concurrency_controller_handler
lambda_aws_name: ConcurrencyController
lambda_layer_names: "core_lambda_layer"
secrets:
AWS_ASSUME_ROLE: ${{ secrets.AWS_ASSUME_ROLE }}
41 changes: 41 additions & 0 deletions lambdas/handlers/concurrency_controller_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from services.concurrency_controller_service import ConcurrencyControllerService
from utils.audit_logging_setup import LoggingService
from utils.decorators.handle_lambda_exceptions import handle_lambda_exceptions
from utils.decorators.override_error_check import override_error_check
from utils.decorators.set_audit_arg import set_request_context_for_logging

logger = LoggingService(__name__)


def validate_event(event):
target_function = event.get("targetFunction")
reserved_concurrency = event.get("reservedConcurrency")

if not target_function:
logger.error("Missing required parameter: targetFunction")
raise ValueError("targetFunction is required")

if reserved_concurrency is None:
logger.error("Missing required parameter: reservedConcurrency")
raise ValueError("reservedConcurrency is required")

return target_function, reserved_concurrency


@set_request_context_for_logging
@override_error_check
@handle_lambda_exceptions
def lambda_handler(event, _context):
target_function, reserved_concurrency = validate_event(event)

service = ConcurrencyControllerService()
updated_concurrency = service.update_function_concurrency(target_function, reserved_concurrency)

return {
"statusCode": 200,
"body": {
"message": "Concurrency updated successfully",
"function": target_function,
"reservedConcurrency": updated_concurrency
}
}
48 changes: 48 additions & 0 deletions lambdas/services/concurrency_controller_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import boto3
from botocore.exceptions import ClientError
from utils.audit_logging_setup import LoggingService

logger = LoggingService(__name__)


class ConcurrencyControllerService:
def __init__(self):
self.lambda_client = boto3.client("lambda")

def update_function_concurrency(self, target_function, reserved_concurrency):
logger.info(
f"Updating reserved concurrency for function '{target_function}' to {reserved_concurrency}"
)

try:
response = self.lambda_client.put_function_concurrency(
FunctionName=target_function,
ReservedConcurrentExecutions=reserved_concurrency
)

updated_concurrency = response.get("ReservedConcurrentExecutions")

if updated_concurrency is None:
logger.error("Response did not contain ReservedConcurrentExecutions")
raise ValueError("Failed to confirm concurrency update from AWS response")

if updated_concurrency != reserved_concurrency:
logger.error(
f"Concurrency mismatch: requested {reserved_concurrency}, "
f"AWS returned {updated_concurrency}"
)
raise ValueError("Concurrency update verification failed")

logger.info(
f"Successfully updated concurrency for '{target_function}'. "
f"Reserved concurrency set to: {updated_concurrency}"
)

return updated_concurrency
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "")
if error_code == "ResourceNotFoundException":
logger.error(f"Lambda function '{target_function}' not found")
else:
logger.error(f"Failed to update concurrency: {str(e)}")
raise
217 changes: 217 additions & 0 deletions lambdas/tests/unit/handlers/test_concurrency_controller_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
import json
import pytest
from botocore.exceptions import ClientError
from handlers.concurrency_controller_handler import lambda_handler, validate_event
from unittest.mock import MagicMock


@pytest.fixture
def mock_concurrency_controller_service(mocker):
mocked_class = mocker.patch(
"handlers.concurrency_controller_handler.ConcurrencyControllerService"
)
mocked_instance = mocked_class.return_value
yield mocked_instance


@pytest.fixture
def mock_logger(mocker):
return mocker.patch("handlers.concurrency_controller_handler.logger")


@pytest.fixture
def valid_event():
return {
"targetFunction": "test-lambda-function",
"reservedConcurrency": 10
}


@pytest.fixture
def event_with_zero_concurrency():
return {
"targetFunction": "test-lambda-function",
"reservedConcurrency": 0
}


def test_lambda_handler_success(valid_event, context, mock_concurrency_controller_service):
mock_concurrency_controller_service.update_function_concurrency.return_value = 10

result = lambda_handler(valid_event, context)

mock_concurrency_controller_service.update_function_concurrency.assert_called_once_with(
"test-lambda-function", 10
)

assert result["statusCode"] == 200
assert result["body"]["message"] == "Concurrency updated successfully"
assert result["body"]["function"] == "test-lambda-function"
assert result["body"]["reservedConcurrency"] == 10


def test_lambda_handler_with_zero_concurrency(
event_with_zero_concurrency, context, mock_concurrency_controller_service
):
mock_concurrency_controller_service.update_function_concurrency.return_value = 0

result = lambda_handler(event_with_zero_concurrency, context)

mock_concurrency_controller_service.update_function_concurrency.assert_called_once_with(
"test-lambda-function", 0
)

assert result["statusCode"] == 200
assert result["body"]["message"] == "Concurrency updated successfully"
assert result["body"]["function"] == "test-lambda-function"
assert result["body"]["reservedConcurrency"] == 0


def test_lambda_handler_with_large_concurrency(context, mock_concurrency_controller_service):
event = {
"targetFunction": "test-lambda-function",
"reservedConcurrency": 1000
}

mock_concurrency_controller_service.update_function_concurrency.return_value = 1000

result = lambda_handler(event, context)

mock_concurrency_controller_service.update_function_concurrency.assert_called_once_with(
"test-lambda-function", 1000
)

assert result["statusCode"] == 200
assert result["body"]["message"] == "Concurrency updated successfully"
assert result["body"]["function"] == "test-lambda-function"
assert result["body"]["reservedConcurrency"] == 1000


def test_validate_event_success(valid_event):
target_function, reserved_concurrency = validate_event(valid_event)

assert target_function == "test-lambda-function"
assert reserved_concurrency == 10


def test_validate_event_missing_target_function(mock_logger):
event = {
"reservedConcurrency": 10
}

with pytest.raises(ValueError) as exc_info:
validate_event(event)

assert str(exc_info.value) == "targetFunction is required"
mock_logger.error.assert_called_once_with("Missing required parameter: targetFunction")


def test_validate_event_missing_reserved_concurrency(mock_logger):
event = {
"targetFunction": "test-lambda-function"
}

with pytest.raises(ValueError) as exc_info:
validate_event(event)

assert str(exc_info.value) == "reservedConcurrency is required"
mock_logger.error.assert_called_once_with("Missing required parameter: reservedConcurrency")


def test_validate_event_both_parameters_missing(mock_logger):
event = {}

with pytest.raises(ValueError) as exc_info:
validate_event(event)

# Should fail on first missing parameter
assert str(exc_info.value) == "targetFunction is required"


def test_validate_event_empty_target_function(mock_logger):
event = {
"targetFunction": "",
"reservedConcurrency": 10
}

with pytest.raises(ValueError) as exc_info:
validate_event(event)

assert str(exc_info.value) == "targetFunction is required"
mock_logger.error.assert_called_once_with("Missing required parameter: targetFunction")


def test_validate_event_reserved_concurrency_zero_is_valid():
event = {
"targetFunction": "test-lambda-function",
"reservedConcurrency": 0
}

target_function, reserved_concurrency = validate_event(event)

assert target_function == "test-lambda-function"
assert reserved_concurrency == 0


def test_validate_event_with_additional_fields():
event = {
"targetFunction": "test-lambda-function",
"reservedConcurrency": 10,
"extraField": "should-be-ignored"
}

target_function, reserved_concurrency = validate_event(event)

assert target_function == "test-lambda-function"
assert reserved_concurrency == 10


def test_lambda_handler_service_raises_resource_not_found(
valid_event, context, mock_concurrency_controller_service
):
error_response = {
'Error': {
'Code': 'ResourceNotFoundException',
'Message': 'Function not found'
}
}

mock_concurrency_controller_service.update_function_concurrency.side_effect = ClientError(
error_response, 'PutFunctionConcurrency'
)

result = lambda_handler(valid_event, context)

# The decorators convert exceptions to API Gateway error responses
assert result['statusCode'] == 500
body = json.loads(result['body'])
assert body['message'] == 'Failed to utilise AWS client/resource'
assert body['err_code'] == 'GWY_5001'


def test_lambda_handler_service_raises_invalid_parameter(
context, mock_concurrency_controller_service
):
event = {
"targetFunction": "test-lambda-function",
"reservedConcurrency": -1
}

error_response = {
'Error': {
'Code': 'InvalidParameterValueException',
'Message': 'Reserved concurrency value must be non-negative'
}
}

mock_concurrency_controller_service.update_function_concurrency.side_effect = ClientError(
error_response, 'PutFunctionConcurrency'
)

result = lambda_handler(event, context)

# The decorators convert exceptions to API Gateway error responses
assert result['statusCode'] == 500
body = json.loads(result['body'])
assert body['message'] == 'Failed to utilise AWS client/resource'
assert body['err_code'] == 'GWY_5001'
Loading
Loading