Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ type ProjectStore interface {
type Config[T any] map[string]map[string]T

// Get returns the config value for the given request.
func (c Config[T]) Get(_ context.Context, path string) (v T, err error) {
func (c Config[T]) Get(_ context.Context, path string) (v T, ok bool) {
if c == nil {
return v, fmt.Errorf("config is nil")
return v, false
}

p := strings.Split(path, "/")
if len(p) < 4 {
return v, fmt.Errorf("path has not enough parts: %s", path)
return v, false
}

var (
Expand All @@ -78,15 +78,14 @@ func (c Config[T]) Get(_ context.Context, path string) (v T, err error) {
)

if packageName != "rpc" {
return v, fmt.Errorf("path doesn't include rpc: %s", path)
return v, false
}

v, ok := c[serviceName][methodName]
if !ok {
return v, fmt.Errorf("acl not defined for path: %s", path)
if v, ok = c[serviceName][methodName]; !ok {
return v, false
}

return v, nil
return v, true
}

// Verify checks that the given config is valid for the given service.
Expand Down
24 changes: 13 additions & 11 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,23 +303,25 @@ func AccessControl(acl Config[ACL], cfg Options) func(next http.Handler) http.Ha
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
acl, err := acl.Get(ctx, r.URL.Path)
if err != nil {
cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("get acl: %w", err))

acl, ok := acl.Get(ctx, r.URL.Path)
if !ok {
// no ACL defined -> delegate to the next handler
next.ServeHTTP(w, r)
return
}

if session, _ := GetSessionType(ctx); !acl.Includes(session) {
err := proto.ErrPermissionDenied
if session == proto.SessionType_Public {
err = proto.ErrUnauthorized
}

cfg.ErrHandler(r, w, err)
session, _ := GetSessionType(ctx)
if acl.Includes(session) {
next.ServeHTTP(w, r)
return
}

next.ServeHTTP(w, r)
err := proto.ErrUnauthorized
if session > proto.SessionType_Public {
err = proto.ErrPermissionDenied
}
cfg.ErrHandler(r, w, err)
})
}
}
Expand Down
33 changes: 16 additions & 17 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,25 +232,25 @@ func TestInvalid(t *testing.T) {
assert.True(t, ok)
assert.NoError(t, err)

// Invalid request path with wrong not enough parts in path for valid RPC request
// Invalid request path with wrong not enough parts in path for valid RPC request, this will delegate to next handler and return no error
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrUnauthorized)
assert.True(t, ok)
assert.NoError(t, err)

// Invalid request path with wrong "rpc"
// Invalid request path with wrong "rpc", this will delegate to next handler and return no error
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/pcr/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrUnauthorized)
assert.True(t, ok)
assert.NoError(t, err)

// Invalid Service
// Invalid Service, this will delegate to next handler and return no error
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrUnauthorized)
assert.True(t, ok)
assert.NoError(t, err)

// Invalid Method
// Invalid Method, this will delegate to next handler and return no error
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrUnauthorized)
assert.True(t, ok)
assert.NoError(t, err)

// Expired JWT Token
claims["exp"] = time.Now().Add(-5 * time.Minute).Unix() // Note: Session() middleware allows some skew.
Expand Down Expand Up @@ -283,7 +283,7 @@ func TestCustomErrHandler(t *testing.T) {

ACLConfig := authcontrol.Config[authcontrol.ACL]{
ServiceName: {
MethodName: authcontrol.NewACL(proto.SessionType_Public.OrHigher()...),
MethodName: authcontrol.NewACL(proto.SessionType_AccessKey.OrHigher()...),
},
}

Expand Down Expand Up @@ -325,16 +325,15 @@ func TestCustomErrHandler(t *testing.T) {

r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))

var claims map[string]any
claims = map[string]any{"service": "client_service"}
claims := map[string]any{"service": "client_service"}

// Valid Request
ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
assert.True(t, ok)
assert.NoError(t, err)

// Invalid service which should return custom error from overrided ErrHandler
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
// Invalid Access, should return custom error
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName))
assert.False(t, ok)
assert.ErrorIs(t, err, customErr)
}
Expand Down