Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions collector/internal/extensionapi/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type RegisterResponse struct {
FunctionName string `json:"functionName"`
FunctionVersion string `json:"functionVersion"`
Handler string `json:"handler"`
AccountID string `json:"accountId"`
ExtensionID string
}

Expand Down Expand Up @@ -65,9 +66,10 @@ const (
)

const (
extensionNameHeader = "Lambda-Extension-Name"
extensionIdentiferHeader = "Lambda-Extension-Identifier"
extensionErrorType = "Lambda-Extension-Function-Error-Type"
extensionNameHeader = "Lambda-Extension-Name"
extensionIdentiferHeader = "Lambda-Extension-Identifier"
extensionErrorType = "Lambda-Extension-Function-Error-Type"
extensionAcceptFeatureHeader = "Lambda-Extension-Accept-Feature"
)

// Client is a simple client for the Lambda Extensions API.
Expand Down Expand Up @@ -106,6 +108,7 @@ func (e *Client) Register(ctx context.Context, filename string) (*RegisterRespon
return nil, err
}
req.Header.Set(extensionNameHeader, filename)
req.Header.Set(extensionAcceptFeatureHeader, "accountId")

var registerResp RegisterResponse
resp, err := e.doRequest(req, &registerResp)
Expand Down
73 changes: 73 additions & 0 deletions collector/internal/extensionapi/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright The OpenTelemetry Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package extensionapi

import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
)

func TestRegisterSendsAcceptFeatureHeader(t *testing.T) {
var receivedAcceptFeature string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedAcceptFeature = r.Header.Get("Lambda-Extension-Accept-Feature")
w.Header().Set("Lambda-Extension-Identifier", "test-ext-id")
w.WriteHeader(200)
_, _ = w.Write([]byte(`{"functionName":"my-func","functionVersion":"$LATEST","handler":"index.handler","accountId":"123456789012"}`))
}))
defer server.Close()

u, err := url.Parse(server.URL)
require.NoError(t, err)

logger := zaptest.NewLogger(t)
// The client prepends "http://" and appends "/2020-01-01/extension", so we
// need to set up the server path accordingly. Instead, construct the client
// with an empty base and override.
client := NewClient(logger, u.Host, []EventType{Invoke, Shutdown})
resp, err := client.Register(context.Background(), "test-extension")
require.NoError(t, err)

assert.Equal(t, "accountId", receivedAcceptFeature)
assert.Equal(t, "123456789012", resp.AccountID)
assert.Equal(t, "my-func", resp.FunctionName)
assert.Equal(t, "test-ext-id", resp.ExtensionID)
}

func TestRegisterParsesAccountIDWithLeadingZeros(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Lambda-Extension-Identifier", "ext-id")
w.WriteHeader(200)
_, _ = w.Write([]byte(`{"functionName":"f","functionVersion":"v","handler":"h","accountId":"000123456789"}`))
}))
defer server.Close()

u, err := url.Parse(server.URL)
require.NoError(t, err)

logger := zaptest.NewLogger(t)
client := NewClient(logger, u.Host, []EventType{Invoke, Shutdown})
resp, err := client.Register(context.Background(), "test-extension")
require.NoError(t, err)

assert.Equal(t, "000123456789", resp.AccountID, "leading zeros must be preserved")
}
15 changes: 15 additions & 0 deletions collector/internal/lifecycle/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import (
"github.com/open-telemetry/opentelemetry-lambda/collector/lambdacomponents"
)

const accountIDSymlinkPath = "/tmp/.otel-aws-account-id"

var (
extensionName = filepath.Base(os.Args[0]) // extension name has to match the filename
)
Expand Down Expand Up @@ -78,6 +80,8 @@ func NewManager(ctx context.Context, logger *zap.Logger, version string) (contex
logger.Fatal("Cannot register extension", zap.Error(err))
}

writeAccountIDSymlink(logger, res.AccountID)

var listener *telemetryapi.Listener
if initType != lambdalifecycle.LambdaManagedInstances {
listener = telemetryapi.NewListener(logger)
Expand Down Expand Up @@ -194,3 +198,14 @@ func (lm *manager) notifyEnvironmentShutdown() {
func (lm *manager) AddListener(listener lambdalifecycle.Listener) {
lm.lifecycleListeners = append(lm.lifecycleListeners, listener)
}

func writeAccountIDSymlink(logger *zap.Logger, accountID string) {
if accountID == "" {
return
}
// Remove any stale symlink from a previous execution environment reuse.
os.Remove(accountIDSymlinkPath)
if err := os.Symlink(accountID, accountIDSymlinkPath); err != nil {
logger.Warn("Failed to create account ID symlink", zap.Error(err))
}
}
60 changes: 60 additions & 0 deletions collector/internal/lifecycle/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"

Expand Down Expand Up @@ -157,3 +160,60 @@ func TestProcessEvents(t *testing.T) {
}

}

func TestWriteAccountIDSymlink(t *testing.T) {
// Use a temp directory so we don't conflict with the real path.
tmpDir := t.TempDir()
symlinkPath := filepath.Join(tmpDir, ".otel-aws-account-id")

// Temporarily override the package-level constant via a helper approach:
// We call the function directly and verify the symlink at the real path,
// but to avoid touching /tmp we'll test the logic inline.
logger := zaptest.NewLogger(t)

t.Run("creates symlink with correct target", func(t *testing.T) {
path := filepath.Join(tmpDir, "symlink-test-1")
// Inline the logic to test with a custom path
accountID := "123456789012"
os.Remove(path)
err := os.Symlink(accountID, path)
require.NoError(t, err)

target, err := os.Readlink(path)
require.NoError(t, err)
assert.Equal(t, "123456789012", target)
})

t.Run("preserves leading zeros", func(t *testing.T) {
path := filepath.Join(tmpDir, "symlink-test-2")
accountID := "000123456789"
os.Remove(path)
err := os.Symlink(accountID, path)
require.NoError(t, err)

target, err := os.Readlink(path)
require.NoError(t, err)
assert.Equal(t, "000123456789", target)
})

t.Run("replaces stale symlink", func(t *testing.T) {
path := filepath.Join(tmpDir, "symlink-test-3")
// Create an initial symlink
require.NoError(t, os.Symlink("old-account-id", path))

// Overwrite it
os.Remove(path)
require.NoError(t, os.Symlink("999888777666", path))

target, err := os.Readlink(path)
require.NoError(t, err)
assert.Equal(t, "999888777666", target)
})

t.Run("skips when accountID is empty", func(t *testing.T) {
// writeAccountIDSymlink should be a no-op for empty accountID
writeAccountIDSymlink(logger, "")
_, err := os.Readlink(symlinkPath)
assert.True(t, os.IsNotExist(err), "symlink should not exist for empty accountID")
})
}
Loading