diff --git a/aws_lambda_opentelemetry/trace/helpers.py b/aws_lambda_opentelemetry/trace/helpers.py index 2fb6d20..9c3f20d 100644 --- a/aws_lambda_opentelemetry/trace/helpers.py +++ b/aws_lambda_opentelemetry/trace/helpers.py @@ -10,7 +10,7 @@ ) from aws_lambda_opentelemetry.typing.context import LambdaContext -from aws_lambda_opentelemetry.utils import AwsAttributesMapper +from aws_lambda_opentelemetry.utils import AwsAttributesExtractor def instrument_handler(**kwargs): @@ -54,8 +54,8 @@ def wrapper(event: dict, context: LambdaContext): span.record_exception(exc) raise finally: - mapper = AwsAttributesMapper(event, context) - mapper.add_attributes() + extractor = AwsAttributesExtractor(event, context) + extractor.add_attributes() finally: provider.force_flush() diff --git a/aws_lambda_opentelemetry/utils.py b/aws_lambda_opentelemetry/utils.py index 299c757..272bc13 100644 --- a/aws_lambda_opentelemetry/utils.py +++ b/aws_lambda_opentelemetry/utils.py @@ -1,5 +1,6 @@ import enum import os +from abc import ABC, abstractmethod from opentelemetry import trace from opentelemetry.semconv._incubating.attributes.cloud_attributes import ( @@ -58,144 +59,276 @@ class AwsDataSource(enum.Enum): OTHER = "aws.other" -class AwsAttributesMapper: - def __init__(self, event: dict, context: LambdaContext) -> None: - self.event = event - self.context = context - self.span = trace.get_current_span() - self.data_source = self._get_aws_data_source() - self.faas_trigger = self._get_faas_trigger() +class AttributeExtractor(ABC): + """Base class for AWS service-specific attribute extractors.""" - def add_attributes(self) -> None: - """ - Generic method which inspects given event/context - and tries to add as much metadata to the current span as it can. - """ - self._add_aws_attributes() + @property + @abstractmethod + def data_source(self) -> AwsDataSource: + """Return the AWS data source this extractor handles.""" + raise NotImplementedError() # pragma: no cover - match self.data_source: - case AwsDataSource.API_GATEWAY: - self._add_apigateway_attributes() - case AwsDataSource.SQS: - self._add_sqs_attributes() - case _: - ... + @abstractmethod + def can_handle(self, event: dict) -> bool: + """Determine if this extractor can handle the given event.""" + raise NotImplementedError() # pragma: no cover - def _get_aws_data_source(self) -> AwsDataSource: - # HTTP triggers - if "requestContext" in self.event: - if "apiId" in self.event["requestContext"]: - return AwsDataSource.API_GATEWAY + @abstractmethod + def get_attributes(self, event: dict, context: LambdaContext) -> dict: + """Extract related attributes from the event and context.""" + raise NotImplementedError() # pragma: no cover - if "http" in self.event["requestContext"]: - return AwsDataSource.HTTP_API - if "elb" in self.event["requestContext"]: - return AwsDataSource.ELB - - # EventBridge - if "source" in self.event and "detail-type" in self.event: - return AwsDataSource.EVENT_BRIDGE +class GenericAwsExtractor(AttributeExtractor): + @property + def data_source(self) -> AwsDataSource: + return AwsDataSource.OTHER - # SNS/SQS/S3/DynamoDB/Kinesis - if "Records" in self.event and len(self.event["Records"]) > 0: - record = self.event["Records"][0] - event_source = record.get("eventSource") + def can_handle(self, event: dict) -> bool: + return True + + def get_attributes(self, event: dict, context: LambdaContext) -> dict: + return { + FAAS_INVOCATION_ID: context.aws_request_id, + FAAS_INVOKED_NAME: context.function_name, + FAAS_INVOKED_REGION: context.region, + FAAS_INVOKED_PROVIDER: FaasInvokedProviderValues.AWS.value, + FAAS_MAX_MEMORY: context.memory_limit_in_mb, + FAAS_VERSION: context.function_version, + FAAS_COLDSTART: _check_cold_start(), + CLOUD_RESOURCE_ID: context.invoked_function_arn, + } + + +class HttpApiExtractor(AttributeExtractor): + @property + def data_source(self) -> AwsDataSource: + return AwsDataSource.HTTP_API + + def can_handle(self, event: dict) -> bool: + return "requestContext" in event and "http" in event["requestContext"] + + def get_attributes(self, event: dict, context: LambdaContext) -> dict: + request_context = event.get("requestContext", {}) + http_context = request_context.get("http", {}) + protocol = http_context.get("protocol", "") + + return { + FAAS_TRIGGER: FaasTriggerValues.HTTP.value, + HTTP_REQUEST_METHOD: http_context.get("method", ""), + HTTP_ROUTE: event.get("routeKey", ""), + HTTP_REQUEST_BODY_SIZE: len(event.get("body", "") or ""), + NETWORK_PROTOCOL_NAME: protocol.split("/")[0] if protocol else "", + NETWORK_PROTOCOL_VERSION: protocol.split("/")[-1] if protocol else "", + USER_AGENT_ORIGINAL: http_context.get("userAgent", ""), + URL_FULL: http_context.get("path", ""), + } + + +class ApiGatewayExtractor(AttributeExtractor): + @property + def data_source(self) -> AwsDataSource: + return AwsDataSource.API_GATEWAY + + def can_handle(self, event: dict) -> bool: + return ( + "requestContext" in event + and "apiId" in event["requestContext"] + and "http" not in event["requestContext"] + ) - if event_source == "aws:sns": - return AwsDataSource.SNS + def get_attributes(self, event: dict, context: LambdaContext) -> dict: + request_context = event.get("requestContext", {}) + headers = event.get("headers", {}) + protocol = request_context.get("protocol", "") - if event_source == "aws:sqs": - return AwsDataSource.SQS + return { + FAAS_TRIGGER: FaasTriggerValues.HTTP.value, + HTTP_REQUEST_METHOD: event.get("httpMethod", ""), + HTTP_ROUTE: event.get("resource", ""), + HTTP_REQUEST_BODY_SIZE: len(event.get("body", "") or ""), + NETWORK_PROTOCOL_NAME: protocol.split("/")[0], + NETWORK_PROTOCOL_VERSION: protocol.split("/")[-1], + USER_AGENT_ORIGINAL: headers.get("User-Agent", ""), + URL_FULL: event.get("path", ""), + } + + +class ElbExtractor(AttributeExtractor): + @property + def data_source(self) -> AwsDataSource: + return AwsDataSource.ELB + + def can_handle(self, event: dict) -> bool: + return "requestContext" in event and "elb" in event["requestContext"] + + def get_attributes(self, event: dict, context: LambdaContext) -> dict: + headers = event.get("headers", {}) + + return { + FAAS_TRIGGER: FaasTriggerValues.HTTP.value, + HTTP_REQUEST_METHOD: event.get("httpMethod", ""), + HTTP_ROUTE: event.get("path", ""), + HTTP_REQUEST_BODY_SIZE: len(event.get("body", "") or ""), + URL_FULL: event.get("path", ""), + USER_AGENT_ORIGINAL: headers.get("user-agent", ""), + } + + +class SqsExtractor(AttributeExtractor): + @property + def data_source(self) -> AwsDataSource: + return AwsDataSource.SQS + + def can_handle(self, event: dict) -> bool: + if "Records" not in event or len(event["Records"]) == 0: + return False + return event["Records"][0].get("eventSource") == "aws:sqs" + + def get_attributes(self, event: dict, context: LambdaContext) -> dict: + records = event.get("Records", []) + message_count = len(records) + queue_arn = records[0].get("eventSourceARN", "") if message_count > 0 else "" + queue_name = queue_arn.split(":")[-1] - if event_source == "aws:s3": - return AwsDataSource.S3 + return { + FAAS_TRIGGER: FaasTriggerValues.PUBSUB.value, + CLOUD_RESOURCE_ID: queue_arn, + MESSAGING_SYSTEM: self.data_source.value, + MESSAGING_OPERATION: MessagingOperationTypeValues.RECEIVE.value, + MESSAGING_BATCH_MESSAGE_COUNT: message_count, + MESSAGING_DESTINATION_NAME: queue_name, + } + + +class SnsExtractor(AttributeExtractor): + @property + def data_source(self) -> AwsDataSource: + return AwsDataSource.SNS + + def can_handle(self, event: dict) -> bool: + if "Records" not in event or len(event["Records"]) == 0: + return False + return event["Records"][0].get("eventSource") == "aws:sns" + + def get_attributes(self, event: dict, context: LambdaContext) -> dict: + return { + FAAS_TRIGGER: FaasTriggerValues.PUBSUB.value, + } + + +class S3Extractor(AttributeExtractor): + @property + def data_source(self) -> AwsDataSource: + return AwsDataSource.S3 + + def can_handle(self, event: dict) -> bool: + if "Records" not in event or len(event["Records"]) == 0: + return False + return event["Records"][0].get("eventSource") == "aws:s3" + + def get_attributes(self, event: dict, context: LambdaContext) -> dict: + return { + FAAS_TRIGGER: FaasTriggerValues.DATASOURCE.value, + } + + +class DynamoDbExtractor(AttributeExtractor): + @property + def data_source(self) -> AwsDataSource: + return AwsDataSource.DYNAMODB + + def can_handle(self, event: dict) -> bool: + if "Records" not in event or len(event["Records"]) == 0: + return False + return event["Records"][0].get("eventSource") == "aws:dynamodb" + + def get_attributes(self, event: dict, context: LambdaContext) -> dict: + return { + FAAS_TRIGGER: FaasTriggerValues.DATASOURCE.value, + } + + +class KinesisExtractor(AttributeExtractor): + @property + def data_source(self) -> AwsDataSource: + return AwsDataSource.KINESIS + + def can_handle(self, event: dict) -> bool: + if "Records" not in event or len(event["Records"]) == 0: + return False + return event["Records"][0].get("eventSource") == "aws:kinesis" + + def get_attributes(self, event: dict, context: LambdaContext) -> dict: + return { + FAAS_TRIGGER: FaasTriggerValues.DATASOURCE.value, + } + + +class EventBridgeExtractor(AttributeExtractor): + @property + def data_source(self) -> AwsDataSource: + return AwsDataSource.EVENT_BRIDGE + + def can_handle(self, event: dict) -> bool: + return "source" in event and "detail-type" in event + + def get_attributes(self, event: dict, context: LambdaContext) -> dict: + detail_type = event.get("detail-type", "") + trigger_type = ( + FaasTriggerValues.TIMER.value + if detail_type == "Scheduled Event" + else FaasTriggerValues.PUBSUB.value + ) - if event_source == "aws:dynamodb": - return AwsDataSource.DYNAMODB + return { + FAAS_TRIGGER: trigger_type, + } - if event_source == "aws:kinesis": - return AwsDataSource.KINESIS - # CloudWatch Logs - if "awslogs" in self.event and "data" in self.event["awslogs"]: - return AwsDataSource.CLOUDWATCH_LOGS +class CloudWatchLogsExtractor(AttributeExtractor): + @property + def data_source(self) -> AwsDataSource: + return AwsDataSource.CLOUDWATCH_LOGS - return AwsDataSource.OTHER + def can_handle(self, event: dict) -> bool: + return "awslogs" in event and "data" in event["awslogs"] - def _get_faas_trigger(self) -> FaasTriggerValues: - if self.data_source in { - AwsDataSource.API_GATEWAY, - AwsDataSource.HTTP_API, - AwsDataSource.ELB, - }: - return FaasTriggerValues.HTTP - - if self.data_source == AwsDataSource.EVENT_BRIDGE: - if self.event["detail-type"] == "Scheduled Event": - return FaasTriggerValues.TIMER - return FaasTriggerValues.PUBSUB - - if self.data_source in {AwsDataSource.SQS, AwsDataSource.SNS}: - return FaasTriggerValues.PUBSUB - - if self.data_source in { - AwsDataSource.S3, - AwsDataSource.DYNAMODB, - AwsDataSource.KINESIS, - AwsDataSource.CLOUDWATCH_LOGS, - }: - return FaasTriggerValues.DATASOURCE - - return FaasTriggerValues.OTHER - - def _add_aws_attributes(self) -> None: - self.span.set_attributes( - { - FAAS_INVOCATION_ID: self.context.aws_request_id, - FAAS_INVOKED_NAME: self.context.function_name, - FAAS_INVOKED_REGION: self.context.region, - FAAS_INVOKED_PROVIDER: FaasInvokedProviderValues.AWS.value, - FAAS_MAX_MEMORY: self.context.memory_limit_in_mb, - FAAS_VERSION: self.context.function_version, - FAAS_COLDSTART: _check_cold_start(), - FAAS_TRIGGER: self.faas_trigger.value, - CLOUD_RESOURCE_ID: self.context.invoked_function_arn, - } - ) + def get_attributes(self, event: dict, context: LambdaContext) -> dict: + return { + FAAS_TRIGGER: FaasTriggerValues.DATASOURCE.value, + } - def _add_apigateway_attributes(self) -> None: - request_context = self.event.get("requestContext", {}) - headers = self.event.get("headers", {}) - protocol = request_context.get("protocol", "") - self.span.set_attributes( - { - HTTP_REQUEST_METHOD: self.event.get("httpMethod", ""), - HTTP_ROUTE: self.event.get("resource", ""), - URL_FULL: self.event.get("path", ""), - HTTP_REQUEST_BODY_SIZE: len(self.event.get("body", "") or ""), - NETWORK_PROTOCOL_NAME: protocol.split("/")[0], - NETWORK_PROTOCOL_VERSION: protocol.split("/")[-1], - USER_AGENT_ORIGINAL: headers.get("User-Agent", ""), - } - ) +class AwsAttributesExtractor: + _EXTRACTORS: list[AttributeExtractor] = [ + GenericAwsExtractor(), + HttpApiExtractor(), + ApiGatewayExtractor(), + ElbExtractor(), + SqsExtractor(), + SnsExtractor(), + S3Extractor(), + DynamoDbExtractor(), + KinesisExtractor(), + EventBridgeExtractor(), + CloudWatchLogsExtractor(), + ] - def _add_sqs_attributes(self) -> None: - records = self.event.get("Records", []) - message_count = len(records) - queue_arn = records[0].get("eventSourceARN", "") if message_count > 0 else "" - queue_name = queue_arn.split(":")[-1] + def __init__(self, event: dict, context: LambdaContext) -> None: + self.event = event + self.context = context + self.span = trace.get_current_span() - self.span.set_attributes( - { - MESSAGING_SYSTEM: self.data_source.value, - MESSAGING_OPERATION: MessagingOperationTypeValues.RECEIVE.value, - MESSAGING_BATCH_MESSAGE_COUNT: message_count, - MESSAGING_DESTINATION_NAME: queue_name, - CLOUD_RESOURCE_ID: queue_arn, - } - ) + def add_attributes(self) -> None: + """ + Generic method which inspects given event/context + and tries to add as much metadata to the current span as it can. + """ + for extractor in self._EXTRACTORS: + if extractor.can_handle(self.event): + attributes = extractor.get_attributes(self.event, self.context) + self.span.set_attributes(attributes) def _check_cold_start() -> bool: diff --git a/tests/conftest.py b/tests/conftest.py index 5d68137..21c70c1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -36,3 +36,13 @@ def sqs_event() -> dict: @pytest.fixture def apigateway_event() -> dict: return get_fixture("apigateway.json") + + +@pytest.fixture +def http_api_event() -> dict: + return get_fixture("http_api.json") + + +@pytest.fixture +def alb_event() -> dict: + return get_fixture("alb.json") diff --git a/tests/fixtures/alb.json b/tests/fixtures/alb.json new file mode 100644 index 0000000..7eb8505 --- /dev/null +++ b/tests/fixtures/alb.json @@ -0,0 +1,28 @@ +{ + "requestContext": { + "elb": { + "targetGroupArn": "arn:aws:elasticloadbalancing:us-east-2:123456789012:targetgroup/lambda-279XGJDqGZ5rsrHC2Fjr/49e9d65c45c6791a" + } + }, + "httpMethod": "POST", + "path": "/path/to/resource", + "queryStringParameters": { + "query": "1234ABCD" + }, + "headers": { + "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8", + "accept-encoding": "gzip", + "accept-language": "en-US,en;q=0.9", + "connection": "keep-alive", + "host": "lambda-alb-123578498.us-east-2.elb.amazonaws.com", + "upgrade-insecure-requests": "1", + "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)", + "x-amzn-trace-id": "Root=1-5c536348-3d683b8b04734faae651f476", + "x-forwarded-for": "72.12.164.125", + "x-forwarded-port": "80", + "x-forwarded-proto": "http", + "x-imforwards": "20" + }, + "body": "eyJ0ZXN0IjoiYm9keSJ9", + "isBase64Encoded": true +} \ No newline at end of file diff --git a/tests/fixtures/http_api.json b/tests/fixtures/http_api.json new file mode 100644 index 0000000..59a4bf7 --- /dev/null +++ b/tests/fixtures/http_api.json @@ -0,0 +1,69 @@ +{ + "version": "2.0", + "routeKey": "$default", + "rawPath": "/path/to/resource", + "rawQueryString": "parameter1=value1¶meter1=value2¶meter2=value", + "cookies": [ + "cookie1", + "cookie2" + ], + "headers": { + "Header1": "value1", + "Header2": "value1,value2" + }, + "queryStringParameters": { + "parameter1": "value1,value2", + "parameter2": "value" + }, + "requestContext": { + "accountId": "123456789012", + "apiId": "api-id", + "authentication": { + "clientCert": { + "clientCertPem": "CERT_CONTENT", + "subjectDN": "www.example.com", + "issuerDN": "Example issuer", + "serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1", + "validity": { + "notBefore": "May 28 12:30:02 2019 GMT", + "notAfter": "Aug 5 09:36:04 2021 GMT" + } + } + }, + "authorizer": { + "jwt": { + "claims": { + "claim1": "value1", + "claim2": "value2" + }, + "scopes": [ + "scope1", + "scope2" + ] + } + }, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "http": { + "method": "POST", + "path": "/path/to/resource", + "protocol": "HTTP/1.1", + "sourceIp": "192.168.0.1/32", + "userAgent": "agent" + }, + "requestId": "id", + "routeKey": "$default", + "stage": "$default", + "time": "12/Mar/2020:19:03:58 +0000", + "timeEpoch": 1583348638390 + }, + "body": "eyJ0ZXN0IjoiYm9keSJ9", + "pathParameters": { + "parameter1": "value1" + }, + "isBase64Encoded": true, + "stageVariables": { + "stageVariable1": "value1", + "stageVariable2": "value2" + } +} diff --git a/tests/test_utils.py b/tests/test_utils.py index fbcebae..81b1dd3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,9 +2,6 @@ import pytest from opentelemetry.sdk.trace import Span -from opentelemetry.semconv._incubating.attributes.faas_attributes import ( - FaasTriggerValues, -) from aws_lambda_opentelemetry import utils from aws_lambda_opentelemetry.typing.context import LambdaContext @@ -34,66 +31,69 @@ def test_cold_start_provisioned_concurrency(self, monkeypatch): class TestLambdaDataSource: @pytest.mark.parametrize( - "key,aws_data_source", + "aws_data_source,extractor_class,event", [ - ("apiId", utils.AwsDataSource.API_GATEWAY), - ("http", utils.AwsDataSource.HTTP_API), - ("elb", utils.AwsDataSource.ELB), + ( + utils.AwsDataSource.API_GATEWAY, + utils.ApiGatewayExtractor, + {"requestContext": {"apiId": ""}}, + ), + ( + utils.AwsDataSource.HTTP_API, + utils.HttpApiExtractor, + {"requestContext": {"http": {}}}, + ), + ( + utils.AwsDataSource.ELB, + utils.ElbExtractor, + {"requestContext": {"elb": {}}}, + ), ], ) def test_http_trigger( self, - key: str, aws_data_source: utils.AwsDataSource, - lambda_context: LambdaContext, + extractor_class: type[utils.AttributeExtractor], + event: dict, ): - event = { - "requestContext": { - key: "example-api-id", - } - } - - mapper = utils.AwsAttributesMapper(event, lambda_context) - assert mapper.faas_trigger == utils.FaasTriggerValues.HTTP - assert mapper.data_source == aws_data_source + extractor = extractor_class() + assert extractor.can_handle(event) + assert extractor.data_source == aws_data_source - @pytest.mark.parametrize( - "detail_type, expected", - [ - ("Scheduled Event", FaasTriggerValues.TIMER), - ("Some Other Event", FaasTriggerValues.PUBSUB), - ], - ) - def test_eventbridge_trigger( - self, - detail_type: str, - expected: FaasTriggerValues, - lambda_context: LambdaContext, - ): - event = { - "source": "aws.events", - "detail-type": detail_type, - } + def test_eventbridge_trigger(self): + event = {"source": "aws.events", "detail-type": "Scheduled Event"} + extractor = utils.EventBridgeExtractor() - mapper = utils.AwsAttributesMapper(event, lambda_context) - assert mapper.faas_trigger == expected - assert mapper.data_source == utils.AwsDataSource.EVENT_BRIDGE + assert extractor.can_handle(event) + assert extractor.data_source == utils.AwsDataSource.EVENT_BRIDGE @pytest.mark.parametrize( - "event_source, aws_data_source, faas_trigger", + "event_source, aws_data_source, extractor_class", [ - ("aws:sns", utils.AwsDataSource.SNS, utils.FaasTriggerValues.PUBSUB), - ("aws:sqs", utils.AwsDataSource.SQS, utils.FaasTriggerValues.PUBSUB), - ("aws:s3", utils.AwsDataSource.S3, utils.FaasTriggerValues.DATASOURCE), + ( + "aws:sns", + utils.AwsDataSource.SNS, + utils.SnsExtractor, + ), + ( + "aws:sqs", + utils.AwsDataSource.SQS, + utils.SqsExtractor, + ), + ( + "aws:s3", + utils.AwsDataSource.S3, + utils.S3Extractor, + ), ( "aws:dynamodb", utils.AwsDataSource.DYNAMODB, - utils.FaasTriggerValues.DATASOURCE, + utils.DynamoDbExtractor, ), ( "aws:kinesis", utils.AwsDataSource.KINESIS, - utils.FaasTriggerValues.DATASOURCE, + utils.KinesisExtractor, ), ], ) @@ -101,8 +101,7 @@ def test_pubsub_trigger( self, event_source: str, aws_data_source: utils.AwsDataSource, - faas_trigger: FaasTriggerValues, - lambda_context: LambdaContext, + extractor_class: type[utils.AttributeExtractor], ): event = { "Records": [ @@ -112,27 +111,28 @@ def test_pubsub_trigger( ] } - mapper = utils.AwsAttributesMapper(event, lambda_context) - assert mapper.faas_trigger == faas_trigger - assert mapper.data_source == aws_data_source + extractor = extractor_class() + + assert extractor.can_handle(event) + assert extractor.data_source == aws_data_source - def test_cloudwatch_logs_trigger(self, lambda_context: LambdaContext): + def test_cloudwatch_logs_trigger(self): event = { "awslogs": { "data": "example-data", } } + extractor = utils.CloudWatchLogsExtractor() - mapper = utils.AwsAttributesMapper(event, lambda_context) - assert mapper.faas_trigger == utils.FaasTriggerValues.DATASOURCE - assert mapper.data_source == utils.AwsDataSource.CLOUDWATCH_LOGS + assert extractor.can_handle(event) + assert extractor.data_source == utils.AwsDataSource.CLOUDWATCH_LOGS - def test_unknown_trigger(self, lambda_context: LambdaContext): + def test_unknown_trigger(self): event = {} + extractor = utils.GenericAwsExtractor() - mapper = utils.AwsAttributesMapper(event, lambda_context) - assert mapper.faas_trigger == utils.FaasTriggerValues.OTHER - assert mapper.data_source == utils.AwsDataSource.OTHER + assert extractor.can_handle(event) + assert extractor.data_source == utils.AwsDataSource.OTHER class TestSetLambdaHandlerAttributes: @@ -144,8 +144,10 @@ def test_general_attributes(self, lambda_context: LambdaContext): ) as mock_span: mock_span.return_value = span - mapper = utils.AwsAttributesMapper({}, lambda_context) - mapper.add_attributes() + extractor = utils.AwsAttributesExtractor({}, lambda_context) + extractor.add_attributes() + + assert span.set_attributes.call_count == 1 attributes = span.set_attributes.call_args_list[0][0][0] assert attributes["faas.invocation_id"] == lambda_context.aws_request_id @@ -155,7 +157,6 @@ def test_general_attributes(self, lambda_context: LambdaContext): assert attributes["faas.max_memory"] == lambda_context.memory_limit_in_mb assert attributes["faas.version"] == lambda_context.function_version assert attributes["faas.coldstart"] is False - assert attributes["faas.trigger"] == "other" assert attributes["cloud.resource_id"] == lambda_context.invoked_function_arn def test_sqs_attributes(self, sqs_event: dict, lambda_context: LambdaContext): @@ -166,15 +167,22 @@ def test_sqs_attributes(self, sqs_event: dict, lambda_context: LambdaContext): ) as mock_span: mock_span.return_value = span - mapper = utils.AwsAttributesMapper(sqs_event, lambda_context) - mapper.add_attributes() + extractor = utils.AwsAttributesExtractor(sqs_event, lambda_context) + extractor.add_attributes() + + assert span.set_attributes.call_count == 2 + + general_attributes = span.set_attributes.call_args_list[0][0][0] + assert general_attributes["faas.invocation_id"] == lambda_context.aws_request_id + assert general_attributes["faas.coldstart"] is False - attributes = span.set_attributes.call_args_list[1][0][0] - assert attributes["messaging.system"] == "aws.sqs" - assert attributes["messaging.destination.name"] == "MyQueue" - assert attributes["messaging.operation"] == "receive" + sqs_attributes = span.set_attributes.call_args_list[1][0][0] + assert sqs_attributes["faas.trigger"] == "pubsub" + assert sqs_attributes["messaging.system"] == "aws.sqs" + assert sqs_attributes["messaging.destination.name"] == "MyQueue" + assert sqs_attributes["messaging.operation"] == "receive" assert ( - attributes["cloud.resource_id"] + sqs_attributes["cloud.resource_id"] == "arn:aws:sqs:us-east-1:123456789012:MyQueue" ) @@ -188,14 +196,117 @@ def test_apigateway_attributes( ) as mock_span: mock_span.return_value = span - mapper = utils.AwsAttributesMapper(apigateway_event, lambda_context) - mapper.add_attributes() - - attributes = span.set_attributes.call_args_list[1][0][0] - assert attributes["http.request.method"] == "POST" - assert attributes["url.full"] == "/path/to/resource" - assert attributes["http.route"] == "/{proxy+}" - assert attributes["http.request.body.size"] == 20 - assert attributes["network.protocol.name"] == "HTTP" - assert attributes["network.protocol.version"] == "1.1" - assert attributes["user_agent.original"] == "Custom User Agent String" + extractor = utils.AwsAttributesExtractor(apigateway_event, lambda_context) + extractor.add_attributes() + + assert span.set_attributes.call_count == 2 + + general_attributes = span.set_attributes.call_args_list[0][0][0] + assert general_attributes["faas.invocation_id"] == lambda_context.aws_request_id + assert general_attributes["faas.coldstart"] is False + + api_gateway_attributes = span.set_attributes.call_args_list[1][0][0] + assert api_gateway_attributes["faas.trigger"] == "http" + assert api_gateway_attributes["http.request.method"] == "POST" + assert api_gateway_attributes["url.full"] == "/path/to/resource" + assert api_gateway_attributes["http.route"] == "/{proxy+}" + assert api_gateway_attributes["http.request.body.size"] == 20 + assert api_gateway_attributes["network.protocol.name"] == "HTTP" + assert api_gateway_attributes["network.protocol.version"] == "1.1" + assert ( + api_gateway_attributes["user_agent.original"] == "Custom User Agent String" + ) + + def test_http_api_attributes( + self, http_api_event: dict, lambda_context: LambdaContext + ): + span = MagicMock(spec=Span) + + with patch( + "aws_lambda_opentelemetry.utils.trace.get_current_span" + ) as mock_span: + mock_span.return_value = span + + extractor = utils.AwsAttributesExtractor(http_api_event, lambda_context) + extractor.add_attributes() + + assert span.set_attributes.call_count == 2 + + general_attributes = span.set_attributes.call_args_list[0][0][0] + assert general_attributes["faas.invocation_id"] == lambda_context.aws_request_id + assert general_attributes["faas.coldstart"] is False + + http_attributes = span.set_attributes.call_args_list[1][0][0] + assert http_attributes["faas.trigger"] == "http" + assert http_attributes["http.request.method"] == "POST" + assert http_attributes["url.full"] == "/path/to/resource" + assert http_attributes["http.route"] == "$default" + assert http_attributes["http.request.body.size"] == 20 + assert http_attributes["network.protocol.name"] == "HTTP" + assert http_attributes["network.protocol.version"] == "1.1" + assert http_attributes["user_agent.original"] == "agent" + + def test_http_api_attributes_with_empty_protocol( + self, http_api_event: dict, lambda_context: LambdaContext + ): + span = MagicMock(spec=Span) + + http_api_event["requestContext"]["http"]["protocol"] = "" + + with patch( + "aws_lambda_opentelemetry.utils.trace.get_current_span" + ) as mock_span: + mock_span.return_value = span + + extractor = utils.AwsAttributesExtractor(http_api_event, lambda_context) + extractor.add_attributes() + + http_attributes = span.set_attributes.call_args_list[1][0][0] + assert http_attributes["network.protocol.name"] == "" + assert http_attributes["network.protocol.version"] == "" + + def test_http_api_attributes_with_missing_body( + self, http_api_event: dict, lambda_context: LambdaContext + ): + span = MagicMock(spec=Span) + + http_api_event["body"] = None + + with patch( + "aws_lambda_opentelemetry.utils.trace.get_current_span" + ) as mock_span: + mock_span.return_value = span + + extractor = utils.AwsAttributesExtractor(http_api_event, lambda_context) + extractor.add_attributes() + + http_attributes = span.set_attributes.call_args_list[1][0][0] + assert http_attributes["http.request.body.size"] == 0 + + def test_alb_attributes(self, alb_event: dict, lambda_context: LambdaContext): + span = MagicMock(spec=Span) + + with patch( + "aws_lambda_opentelemetry.utils.trace.get_current_span" + ) as mock_span: + mock_span.return_value = span + + extractor = utils.AwsAttributesExtractor(alb_event, lambda_context) + extractor.add_attributes() + + assert span.set_attributes.call_count == 2 + + general_attributes = span.set_attributes.call_args_list[0][0][0] + assert general_attributes["faas.invocation_id"] == lambda_context.aws_request_id + assert general_attributes["faas.coldstart"] is False + + alb_attributes = span.set_attributes.call_args_list[1][0][0] + assert alb_attributes["faas.trigger"] == "http" + assert alb_attributes["http.request.method"] == "POST" + assert alb_attributes["url.full"] == "/path/to/resource" + assert alb_attributes["http.route"] == "/path/to/resource" + assert alb_attributes["http.request.body.size"] == 20 + assert ( + alb_attributes["user_agent.original"] + == "Mozilla/5.0 (Windows NT 10.0; Win64; x64)" + )