diff --git a/lambdacontext/example_logger_test.go b/lambdacontext/example_logger_test.go new file mode 100644 index 00000000..e20f422e --- /dev/null +++ b/lambdacontext/example_logger_test.go @@ -0,0 +1,65 @@ +//go:build go1.21 +// +build go1.21 + +package lambdacontext_test + +import ( + "context" + "log/slog" + + "github.com/aws/aws-lambda-go/lambda" + "github.com/aws/aws-lambda-go/lambdacontext" +) + +// ExampleNewLogger demonstrates the simplest usage of NewLogger for structured logging. +// The logger automatically injects requestId from Lambda context into each log record. +func ExampleNewLogger() { + // Set up the Lambda-aware slog logger + slog.SetDefault(lambdacontext.NewLogger()) + + lambda.Start(func(ctx context.Context) (string, error) { + // Use slog.InfoContext to include Lambda context in logs + slog.InfoContext(ctx, "processing request", "action", "example") + return "success", nil + }) +} + +// ExampleNewLogHandler demonstrates using NewLogHandler for more control. +func ExampleNewLogHandler() { + // Set up the Lambda-aware slog handler + slog.SetDefault(slog.New(lambdacontext.NewLogHandler())) + + lambda.Start(func(ctx context.Context) (string, error) { + slog.InfoContext(ctx, "processing request", "action", "example") + return "success", nil + }) +} + +// ExampleNewLogHandler_withOptions demonstrates NewLogHandler with additional fields. +// Use WithFunctionARN() and WithTenantID() to include extra context. +func ExampleNewLogHandler_withOptions() { + // Set up handler with function ARN and tenant ID fields + slog.SetDefault(slog.New(lambdacontext.NewLogHandler( + lambdacontext.WithFunctionARN(), + lambdacontext.WithTenantID(), + ))) + + lambda.Start(func(ctx context.Context) (string, error) { + slog.InfoContext(ctx, "multi-tenant request", "tenant", "acme-corp") + return "success", nil + }) +} + +// ExampleWithFunctionARN demonstrates using WithFunctionARN to include the function ARN. +func ExampleWithFunctionARN() { + // Include only function ARN + slog.SetDefault(lambdacontext.NewLogger( + lambdacontext.WithFunctionARN(), + )) + + lambda.Start(func(ctx context.Context) (string, error) { + // Log output will include "functionArn" field + slog.InfoContext(ctx, "function invoked") + return "success", nil + }) +} diff --git a/lambdacontext/logger.go b/lambdacontext/logger.go new file mode 100644 index 00000000..4e0bd6de --- /dev/null +++ b/lambdacontext/logger.go @@ -0,0 +1,151 @@ +//go:build go1.21 +// +build go1.21 + +// Copyright 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +package lambdacontext + +import ( + "context" + "log/slog" + "os" +) + +// logFormat is the log format from AWS_LAMBDA_LOG_FORMAT (TEXT or JSON) +var logFormat = os.Getenv("AWS_LAMBDA_LOG_FORMAT") + +// logLevel is the log level from AWS_LAMBDA_LOG_LEVEL +var logLevel = os.Getenv("AWS_LAMBDA_LOG_LEVEL") + +// field represents a Lambda context field to include in log records. +type field struct { + key string + value func(*LambdaContext) string +} + +// logOptions holds configuration for the Lambda log handler. +type logOptions struct { + fields []field +} + +// LogOption is a functional option for configuring the Lambda log handler. +type LogOption func(*logOptions) + +// WithFunctionARN includes the invoked function ARN in log records. +func WithFunctionARN() LogOption { + return func(o *logOptions) { + o.fields = append(o.fields, field{"functionArn", func(lc *LambdaContext) string { return lc.InvokedFunctionArn }}) + } +} + +// WithTenantID includes the tenant ID in log records (for multi-tenant functions). +func WithTenantID() LogOption { + return func(o *logOptions) { + o.fields = append(o.fields, field{"tenantId", func(lc *LambdaContext) string { return lc.TenantID }}) + } +} + +// NewLogHandler returns a [slog.Handler] for AWS Lambda structured logging. +// It reads AWS_LAMBDA_LOG_FORMAT and AWS_LAMBDA_LOG_LEVEL from environment, +// and injects requestId from Lambda context into each log record. +// +// By default, only requestId is injected. Use WithFunctionARN or WithTenantID to include more. +// See the package examples for usage. +func NewLogHandler(opts ...LogOption) slog.Handler { + options := &logOptions{} + for _, opt := range opts { + opt(options) + } + + level := parseLogLevel() + handlerOpts := &slog.HandlerOptions{ + Level: level, + ReplaceAttr: ReplaceAttr, + } + + var h slog.Handler + if logFormat == "JSON" { + h = slog.NewJSONHandler(os.Stdout, handlerOpts) + } else { + h = slog.NewTextHandler(os.Stdout, handlerOpts) + } + + return &lambdaHandler{handler: h, fields: options.fields} +} + +// NewLogger returns a [*slog.Logger] configured for AWS Lambda structured logging. +// This is a convenience function equivalent to slog.New(NewLogHandler(opts...)). +func NewLogger(opts ...LogOption) *slog.Logger { + return slog.New(NewLogHandler(opts...)) +} + +// ReplaceAttr maps slog's default keys to AWS Lambda's log format (time->timestamp, msg->message). +func ReplaceAttr(groups []string, attr slog.Attr) slog.Attr { + if len(groups) > 0 { + return attr + } + + switch attr.Key { + case slog.TimeKey: + attr.Key = "timestamp" + case slog.MessageKey: + attr.Key = "message" + } + return attr +} + +// lambdaHandler wraps a slog.Handler to inject Lambda context fields. +type lambdaHandler struct { + handler slog.Handler + fields []field +} + +// Enabled implements slog.Handler. +func (h *lambdaHandler) Enabled(ctx context.Context, level slog.Level) bool { + return h.handler.Enabled(ctx, level) +} + +// Handle implements slog.Handler. +func (h *lambdaHandler) Handle(ctx context.Context, r slog.Record) error { + if lc, ok := FromContext(ctx); ok { + r.AddAttrs(slog.String("requestId", lc.AwsRequestID)) + + for _, field := range h.fields { + if v := field.value(lc); v != "" { + r.AddAttrs(slog.String(field.key, v)) + } + } + } + return h.handler.Handle(ctx, r) +} + +// WithAttrs implements slog.Handler. +func (h *lambdaHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return &lambdaHandler{ + handler: h.handler.WithAttrs(attrs), + fields: h.fields, + } +} + +// WithGroup implements slog.Handler. +func (h *lambdaHandler) WithGroup(name string) slog.Handler { + return &lambdaHandler{ + handler: h.handler.WithGroup(name), + fields: h.fields, + } +} + +func parseLogLevel() slog.Level { + switch logLevel { + case "DEBUG": + return slog.LevelDebug + case "INFO": + return slog.LevelInfo + case "WARN": + return slog.LevelWarn + case "ERROR": + return slog.LevelError + default: + return slog.LevelInfo + } +} diff --git a/lambdacontext/logger_test.go b/lambdacontext/logger_test.go new file mode 100644 index 00000000..a09e49c3 --- /dev/null +++ b/lambdacontext/logger_test.go @@ -0,0 +1,399 @@ +//go:build go1.21 +// +build go1.21 + +// Copyright 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +package lambdacontext + +import ( + "bytes" + "context" + "encoding/json" + "log/slog" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestReplaceAttr(t *testing.T) { + tests := []struct { + name string + groups []string + attr slog.Attr + expected slog.Attr + }{ + { + name: "time to timestamp", + groups: nil, + attr: slog.String(slog.TimeKey, "2025-01-09T12:00:00Z"), + expected: slog.String("timestamp", "2025-01-09T12:00:00Z"), + }, + { + name: "msg to message", + groups: nil, + attr: slog.String(slog.MessageKey, "test message"), + expected: slog.String("message", "test message"), + }, + { + name: "level unchanged", + groups: nil, + attr: slog.String(slog.LevelKey, "INFO"), + expected: slog.String(slog.LevelKey, "INFO"), + }, + { + name: "custom key unchanged", + groups: nil, + attr: slog.String("customKey", "value"), + expected: slog.String("customKey", "value"), + }, + { + name: "grouped attrs not replaced", + groups: []string{"group1"}, + attr: slog.String(slog.TimeKey, "2025-01-09T12:00:00Z"), + expected: slog.String(slog.TimeKey, "2025-01-09T12:00:00Z"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ReplaceAttr(tt.groups, tt.attr) + assert.Equal(t, tt.expected.Key, result.Key) + assert.Equal(t, tt.expected.Value.String(), result.Value.String()) + }) + } +} + +func TestParseLogLevel(t *testing.T) { + tests := []struct { + name string + input string + expected slog.Level + }{ + {"DEBUG", "DEBUG", slog.LevelDebug}, + {"INFO", "INFO", slog.LevelInfo}, + {"WARN", "WARN", slog.LevelWarn}, + {"ERROR", "ERROR", slog.LevelError}, + {"empty", "", slog.LevelInfo}, + {"INVALID", "INVALID", slog.LevelInfo}, + {"lowercase debug", "debug", slog.LevelInfo}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logLevel = tt.input + result := parseLogLevel() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestLogHandler_JSONFormat(t *testing.T) { + var buf bytes.Buffer + + opts := &slog.HandlerOptions{ + Level: slog.LevelInfo, + ReplaceAttr: ReplaceAttr, + } + baseHandler := slog.NewJSONHandler(&buf, opts) + handler := &lambdaHandler{handler: baseHandler} + + lc := &LambdaContext{AwsRequestID: "test-request-123"} + ctx := NewContext(context.Background(), lc) + + logger := slog.New(handler) + logger.InfoContext(ctx, "test message", "key", "value") + + var logOutput map[string]interface{} + err := json.Unmarshal(buf.Bytes(), &logOutput) + require.NoError(t, err) + + assert.Equal(t, "INFO", logOutput["level"]) + assert.Equal(t, "test message", logOutput["message"]) + assert.Equal(t, "test-request-123", logOutput["requestId"]) + assert.Equal(t, "value", logOutput["key"]) + assert.Contains(t, logOutput, "timestamp") + assert.NotContains(t, logOutput, "functionArn") + assert.NotContains(t, logOutput, "tenantId") +} + +func TestLogHandler_NoLambdaContext(t *testing.T) { + var buf bytes.Buffer + + opts := &slog.HandlerOptions{ + Level: slog.LevelInfo, + ReplaceAttr: ReplaceAttr, + } + baseHandler := slog.NewJSONHandler(&buf, opts) + handler := &lambdaHandler{handler: baseHandler} + + ctx := context.Background() + + logger := slog.New(handler) + logger.InfoContext(ctx, "no context message") + + var logOutput map[string]interface{} + err := json.Unmarshal(buf.Bytes(), &logOutput) + require.NoError(t, err) + + assert.Equal(t, "no context message", logOutput["message"]) + assert.NotContains(t, logOutput, "requestId") +} + +func TestLogHandler_ConcurrencySafe(t *testing.T) { + var buf1, buf2 bytes.Buffer + + opts := &slog.HandlerOptions{ + Level: slog.LevelInfo, + ReplaceAttr: ReplaceAttr, + } + + handler1 := &lambdaHandler{handler: slog.NewJSONHandler(&buf1, opts)} + handler2 := &lambdaHandler{handler: slog.NewJSONHandler(&buf2, opts)} + + lc1 := &LambdaContext{AwsRequestID: "request-aaa"} + lc2 := &LambdaContext{AwsRequestID: "request-bbb"} + + ctx1 := NewContext(context.Background(), lc1) + ctx2 := NewContext(context.Background(), lc2) + + logger1 := slog.New(handler1) + logger2 := slog.New(handler2) + + logger1.InfoContext(ctx1, "message 1") + logger2.InfoContext(ctx2, "message 2") + + var output1, output2 map[string]interface{} + require.NoError(t, json.Unmarshal(buf1.Bytes(), &output1)) + require.NoError(t, json.Unmarshal(buf2.Bytes(), &output2)) + + assert.Equal(t, "request-aaa", output1["requestId"]) + assert.Equal(t, "request-bbb", output2["requestId"]) +} + +func TestLogHandler_SharedHandlerConcurrencySafe(t *testing.T) { + var buf bytes.Buffer + + opts := &slog.HandlerOptions{ + Level: slog.LevelInfo, + ReplaceAttr: ReplaceAttr, + } + + sharedHandler := &lambdaHandler{handler: slog.NewJSONHandler(&buf, opts)} + logger := slog.New(sharedHandler) + + lc1 := &LambdaContext{AwsRequestID: "request-aaa"} + lc2 := &LambdaContext{AwsRequestID: "request-bbb"} + + ctx1 := NewContext(context.Background(), lc1) + ctx2 := NewContext(context.Background(), lc2) + + logger.InfoContext(ctx1, "message 1") + logger.InfoContext(ctx2, "message 2") + + lines := bytes.Split(bytes.TrimSpace(buf.Bytes()), []byte("\n")) + require.Len(t, lines, 2) + + var output1, output2 map[string]interface{} + require.NoError(t, json.Unmarshal(lines[0], &output1)) + require.NoError(t, json.Unmarshal(lines[1], &output2)) + + assert.Equal(t, "request-aaa", output1["requestId"]) + assert.Equal(t, "request-bbb", output2["requestId"]) +} + +func TestLogHandler_WithAttrs(t *testing.T) { + var buf bytes.Buffer + + opts := &slog.HandlerOptions{ + Level: slog.LevelInfo, + ReplaceAttr: ReplaceAttr, + } + baseHandler := slog.NewJSONHandler(&buf, opts) + handler := &lambdaHandler{handler: baseHandler} + + lc := &LambdaContext{AwsRequestID: "test-request"} + ctx := NewContext(context.Background(), lc) + + logger := slog.New(handler).With("service", "test-service") + logger.InfoContext(ctx, "test message") + + var logOutput map[string]interface{} + err := json.Unmarshal(buf.Bytes(), &logOutput) + require.NoError(t, err) + + assert.Equal(t, "test-request", logOutput["requestId"]) + assert.Equal(t, "test-service", logOutput["service"]) +} + +func TestLogHandler_WithGroup(t *testing.T) { + var buf bytes.Buffer + + opts := &slog.HandlerOptions{ + Level: slog.LevelInfo, + ReplaceAttr: ReplaceAttr, + } + baseHandler := slog.NewJSONHandler(&buf, opts) + handler := &lambdaHandler{handler: baseHandler} + + lc := &LambdaContext{AwsRequestID: "test-request"} + ctx := NewContext(context.Background(), lc) + + logger := slog.New(handler).WithGroup("app").With("version", "1.0") + logger.InfoContext(ctx, "test message") + + var logOutput map[string]interface{} + err := json.Unmarshal(buf.Bytes(), &logOutput) + require.NoError(t, err) + + app, ok := logOutput["app"].(map[string]interface{}) + require.True(t, ok, "expected 'app' group in output: %s", buf.String()) + assert.Equal(t, "1.0", app["version"]) + assert.Equal(t, "test-request", app["requestId"]) +} + +func TestLogHandler_WithFields(t *testing.T) { + var buf bytes.Buffer + + opts := &slog.HandlerOptions{ + Level: slog.LevelInfo, + ReplaceAttr: ReplaceAttr, + } + baseHandler := slog.NewJSONHandler(&buf, opts) + + // Create options with fields + options := &logOptions{} + WithFunctionARN()(options) + WithTenantID()(options) + + handler := &lambdaHandler{ + handler: baseHandler, + fields: options.fields, + } + + lc := &LambdaContext{ + AwsRequestID: "test-request-123", + InvokedFunctionArn: "arn:aws:lambda:us-east-1:123456789:function:test", + TenantID: "tenant-abc", + } + ctx := NewContext(context.Background(), lc) + + logger := slog.New(handler) + logger.InfoContext(ctx, "test message") + + var logOutput map[string]interface{} + err := json.Unmarshal(buf.Bytes(), &logOutput) + require.NoError(t, err) + + assert.Equal(t, "test-request-123", logOutput["requestId"]) + assert.Equal(t, "arn:aws:lambda:us-east-1:123456789:function:test", logOutput["functionArn"]) + assert.Equal(t, "tenant-abc", logOutput["tenantId"]) +} + +func TestLogHandler_WithFieldFunctionARNOnly(t *testing.T) { + var buf bytes.Buffer + + opts := &slog.HandlerOptions{ + Level: slog.LevelInfo, + ReplaceAttr: ReplaceAttr, + } + baseHandler := slog.NewJSONHandler(&buf, opts) + + options := &logOptions{} + WithFunctionARN()(options) + + handler := &lambdaHandler{ + handler: baseHandler, + fields: options.fields, + } + + lc := &LambdaContext{ + AwsRequestID: "test-request-123", + InvokedFunctionArn: "arn:aws:lambda:us-east-1:123456789:function:test", + TenantID: "tenant-abc", + } + ctx := NewContext(context.Background(), lc) + + logger := slog.New(handler) + logger.InfoContext(ctx, "test message") + + var logOutput map[string]interface{} + err := json.Unmarshal(buf.Bytes(), &logOutput) + require.NoError(t, err) + + assert.Equal(t, "test-request-123", logOutput["requestId"]) + assert.Equal(t, "arn:aws:lambda:us-east-1:123456789:function:test", logOutput["functionArn"]) + assert.NotContains(t, logOutput, "tenantId") +} + +func TestLogHandler_FieldsEmpty(t *testing.T) { + var buf bytes.Buffer + + opts := &slog.HandlerOptions{ + Level: slog.LevelInfo, + ReplaceAttr: ReplaceAttr, + } + baseHandler := slog.NewJSONHandler(&buf, opts) + + options := &logOptions{} + WithFunctionARN()(options) + WithTenantID()(options) + + handler := &lambdaHandler{ + handler: baseHandler, + fields: options.fields, + } + + lc := &LambdaContext{ + AwsRequestID: "test-request-123", + InvokedFunctionArn: "", + TenantID: "", + } + ctx := NewContext(context.Background(), lc) + + logger := slog.New(handler) + logger.InfoContext(ctx, "test message") + + var logOutput map[string]interface{} + err := json.Unmarshal(buf.Bytes(), &logOutput) + require.NoError(t, err) + + assert.Equal(t, "test-request-123", logOutput["requestId"]) + assert.NotContains(t, logOutput, "functionArn") + assert.NotContains(t, logOutput, "tenantId") +} + +func TestWithFunctionARN(t *testing.T) { + options := &logOptions{} + WithFunctionARN()(options) + + assert.Len(t, options.fields, 1) + assert.Equal(t, "functionArn", options.fields[0].key) + + lc := &LambdaContext{InvokedFunctionArn: "arn:aws:lambda:us-east-1:123456789:function:test"} + assert.Equal(t, "arn:aws:lambda:us-east-1:123456789:function:test", options.fields[0].value(lc)) +} + +func TestWithTenantID(t *testing.T) { + options := &logOptions{} + WithTenantID()(options) + + assert.Len(t, options.fields, 1) + assert.Equal(t, "tenantId", options.fields[0].key) + + lc := &LambdaContext{TenantID: "tenant-abc"} + assert.Equal(t, "tenant-abc", options.fields[0].value(lc)) +} + +func TestNewLogger(t *testing.T) { + logger := NewLogger() + assert.NotNil(t, logger) +} + +func TestNewLogHandler(t *testing.T) { + handler := NewLogHandler() + assert.NotNil(t, handler) + + handlerWithOpts := NewLogHandler(WithFunctionARN(), WithTenantID()) + assert.NotNil(t, handlerWithOpts) +}