Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
TEST_FLAGS ?= -p 8 -failfast -race -shuffle on
GOTOOLCHAIN := $(shell cat go.mod | grep "^go" | tr -d ' ')

all:
@echo "make <cmd>:"
Expand All @@ -22,11 +23,10 @@ test-coverage-inspect: test-coverage
go tool cover -html=coverage.out

generate:
go generate -x ./...
WEBRPC_SCHEMA_VERSION=$(shell git log -1 --date=format:'v0-%y.%-m.%-d' --format='%ad+%h' ./proto/*.ridl) \
GOTOOLCHAIN=$(GOTOOLCHAIN) go generate -x ./...

.PHONY: proto
proto:
go generate -x ./proto/...
proto: generate

lint:
golangci-lint run ./... --fix -c .golangci.yml
Expand Down
6 changes: 3 additions & 3 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func origin(v string) requestOption {
}
}

func executeRequest(t *testing.T, ctx context.Context, handler http.Handler, path string, options ...requestOption) (bool, error) {
func executeRequest(t *testing.T, ctx context.Context, handler http.Handler, path string, options ...requestOption) (bool, http.Header, error) {
req, err := http.NewRequest("POST", path, nil)
require.NoError(t, err)

Expand All @@ -57,10 +57,10 @@ func executeRequest(t *testing.T, ctx context.Context, handler http.Handler, pat
webrpcErr := proto.WebRPCError{}
err = json.Unmarshal(rr.Body.Bytes(), &webrpcErr)
require.NoError(t, err, "failed to unmarshal response body: %s", rr.Body.Bytes())
return false, webrpcErr
return false, rr.Header(), webrpcErr
}

return true, nil
return true, rr.Header(), nil
}

func TestVerify(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions go.work.sum
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtX
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
Expand Down
9 changes: 4 additions & 5 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"log/slog"
"net/http"
"slices"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -211,11 +212,9 @@ func Session(cfg Options) func(next http.Handler) http.Handler {
}

if adminClaim {
if scopeClaim == "" || scopeClaim == cfg.ServiceName || strings.Contains(scopeClaim, cfg.ServiceName) {
// Allow admin if no scope claim is provided or if it matches service name.
sessionType = proto.SessionType_Admin
} else {
// Reduce to public if scope claim does not match.
sessionType = proto.SessionType_Admin
// Reduce to public if a scope is provided and the claim does not match.
if scopeClaim != "" && !slices.Contains(strings.Split(scopeClaim, ","), cfg.ServiceName) {
sessionType = proto.SessionType_Public
}
}
Expand Down
66 changes: 47 additions & 19 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func TestSession(t *testing.T) {
session = proto.SessionType_AccessKey
}

ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", service, method), options...)
ok, _, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", service, method), options...)
if !expectedACL.Includes(session) {
assert.Error(t, err)
assert.False(t, ok)
Expand Down Expand Up @@ -196,6 +196,7 @@ func TestInvalid(t *testing.T) {
AdminAddress: true,
},
AccessKeyFuncs: []authcontrol.AccessKeyFunc{keyFunc},
ServiceName: ServiceName,
}

r := chi.NewRouter()
Expand All @@ -216,39 +217,39 @@ func TestInvalid(t *testing.T) {
}))

// Without JWT
ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(""))
ok, _, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(""))
assert.True(t, ok)
assert.NoError(t, err)

// Wrong JWT
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt("wrong-secret"))
ok, _, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt("wrong-secret"))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrUnauthorized)

claims := map[string]any{"service": "client_service"}

// Valid Request
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
// Valid S2S Request
ok, _, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
assert.True(t, ok)
assert.NoError(t, err)

// Invalid request path with wrong not enough parts in path for valid RPC request, this will delegate to next handler and return no error
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
ok, _, err = executeRequest(t, ctx, r, fmt.Sprintf("/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
assert.True(t, ok)
assert.NoError(t, err)

// Invalid request path with wrong "rpc", this will delegate to next handler and return no error
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/pcr/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
ok, _, err = executeRequest(t, ctx, r, fmt.Sprintf("/pcr/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
assert.True(t, ok)
assert.NoError(t, err)

// Invalid Service, this will delegate to next handler and return no error
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
ok, _, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
assert.True(t, ok)
assert.NoError(t, err)

// Invalid Method, this will delegate to next handler and return no error
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
ok, _, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
assert.True(t, ok)
assert.NoError(t, err)

Expand All @@ -257,19 +258,46 @@ func TestInvalid(t *testing.T) {
expiredJWT := authcontrol.S2SToken(JWTSecret, claims)

// Expired JWT Token valid method
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(expiredJWT))
ok, _, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(expiredJWT))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrSessionExpired)

// Expired JWT Token invalid service
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), accessKey(AccessKey), jwt(expiredJWT))
ok, _, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), accessKey(AccessKey), jwt(expiredJWT))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrSessionExpired)

// Expired JWT Token invalid method
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), accessKey(AccessKey), jwt(expiredJWT))
ok, _, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), accessKey(AccessKey), jwt(expiredJWT))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrSessionExpired)

// Valid Admin Request (no scope claim)
claims = map[string]any{"account": AdminAddress, "admin": true}
ok, _, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
assert.True(t, ok)
assert.NoError(t, err)

// Valid Admin Request (with matching scope claim)
claims = map[string]any{"account": AdminAddress, "admin": true, "scope": ServiceName}
ok, headers, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
assert.True(t, ok)
assert.NoError(t, err)
assert.Equal(t, "Admin", headers.Get(authcontrol.HeaderSessionType))

// Valid Admin Request (with multiple scope claims)
claims = map[string]any{"account": AdminAddress, "admin": true, "scope": ServiceName + ",other_service"}
ok, headers, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
assert.True(t, ok)
assert.NoError(t, err)
assert.Equal(t, "Admin", headers.Get(authcontrol.HeaderSessionType))

// Invalid Admin Request (with non-matching scope claim)
claims = map[string]any{"account": AdminAddress, "admin": true, "scope": "other_service"}
ok, headers, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
assert.True(t, ok)
assert.NoError(t, err)
assert.NotEqual(t, "User", headers.Get(authcontrol.HeaderSessionType))
}

func TestCustomErrHandler(t *testing.T) {
Expand Down Expand Up @@ -328,12 +356,12 @@ func TestCustomErrHandler(t *testing.T) {
claims := map[string]any{"service": "client_service"}

// Valid Request
ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
ok, _, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
assert.True(t, ok)
assert.NoError(t, err)

// Invalid Access, should return custom error
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName))
ok, _, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName))
assert.False(t, ok)
assert.ErrorIs(t, err, customErr)
}
Expand All @@ -356,17 +384,17 @@ func TestOrigin(t *testing.T) {
})

// No Origin header
ok, err := executeRequest(t, ctx, r, "", jwt(token))
ok, _, err := executeRequest(t, ctx, r, "", jwt(token))
assert.True(t, ok)
assert.NoError(t, err)

// Valid Origin header
ok, err = executeRequest(t, ctx, r, "", jwt(token), origin("http://localhost"))
ok, _, err = executeRequest(t, ctx, r, "", jwt(token), origin("http://localhost"))
assert.True(t, ok)
assert.NoError(t, err)

// Invalid Origin header
ok, err = executeRequest(t, ctx, r, "", jwt(token), origin("http://evil.com"))
ok, _, err = executeRequest(t, ctx, r, "", jwt(token), origin("http://evil.com"))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrUnauthorized)
}
Expand Down Expand Up @@ -403,7 +431,7 @@ func TestProjectVerifier(t *testing.T) {
"project_id": projectID,
})

ok, err := executeRequest(t, ctx, r, "", jwt(token))
ok, _, err := executeRequest(t, ctx, r, "", jwt(token))
assert.True(t, ok)
assert.NoError(t, err)

Expand All @@ -429,7 +457,7 @@ func TestProjectVerifier(t *testing.T) {
})
require.NoError(t, err)

ok, err = executeRequest(t, ctx, r, "", jwt(token))
ok, _, err = executeRequest(t, ctx, r, "", jwt(token))
assert.True(t, ok)
assert.NoError(t, err)
}
23 changes: 11 additions & 12 deletions proto/authcontrol.errors.ridl
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
webrpc = v1

name = authcontrol
version = v0.9.1

error 1000 Unauthorized "Unauthorized access" HTTP 401
error 1001 PermissionDenied "Permission denied" HTTP 403
error 1002 SessionExpired "Session expired" HTTP 403
error 1003 MethodNotFound "Method not found" HTTP 404
error 1004 RequestConflict "Conflict with target resource" HTTP 409
error 1005 Aborted "Request aborted" HTTP 400
error 1006 Geoblocked "Geoblocked region" HTTP 451
error 1007 RateLimited "Rate-limited. Please slow down." HTTP 429
error 1008 ProjectNotFound "Project not found" HTTP 401
error 1009 SecretKeyCorsDisallowed "CORS disallowed. Admin API Secret Key can't be used from a web app." HTTP 403
version = v0.4.12

error 1000 Unauthorized "Unauthorized access" HTTP 401
error 1001 PermissionDenied "Permission denied" HTTP 403
error 1002 SessionExpired "Session expired" HTTP 403
error 1003 MethodNotFound "Method not found" HTTP 404
error 1004 RequestConflict "Conflict with target resource" HTTP 409
error 1005 Aborted "Request aborted" HTTP 400
error 1006 Geoblocked "Geoblocked region" HTTP 451
error 1007 RateLimited "Rate limit exceeded. Configure an Access Key to increase your limits: https://dashboard.trails.build or https://sequence.build" HTTP 429
error 1008 ProjectNotFound "Project not found" HTTP 401
error 1009 SecretKeyCorsDisallowed "CORS disallowed. Admin API Secret Key can't be used from a web app." HTTP 403
6 changes: 3 additions & 3 deletions proto/authcontrol.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions proto/authcontrol.gen.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* eslint-disable */
// authcontrol v0.9.1 efc70751a8d3d04b62886c568ebe71265a4e3d5b
// authcontrol v0.9.1 6855ab0cf75fb9058df94efbbb43b02641cd0918
// --
// Code generated by webrpc-gen@v0.22.1 with typescript generator. DO NOT EDIT.
//
Expand All @@ -16,7 +16,7 @@ export const WebRPCVersion = "v1"
export const WebRPCSchemaVersion = "v0.9.1"

// Schema hash generated from your RIDL schema
export const WebRPCSchemaHash = "efc70751a8d3d04b62886c568ebe71265a4e3d5b"
export const WebRPCSchemaHash = "6855ab0cf75fb9058df94efbbb43b02641cd0918"

type WebrpcGenVersions = {
webrpcGenVersion: string;
Expand Down Expand Up @@ -391,7 +391,7 @@ export class RateLimitedError extends WebrpcError {
constructor(
name: string = 'RateLimited',
code: number = 1007,
message: string = `Rate-limited. Please slow down.`,
message: string = `Rate limit exceeded. Configure an Access Key to increase your limits: https://dashboard.trails.build or https://sequence.build`,
status: number = 0,
cause?: string
) {
Expand Down