From 81d44a62deb684ad0d2c0f58d9358c5a6430d7a9 Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Fri, 28 Nov 2025 20:42:15 +0100 Subject: [PATCH 1/2] fix: AccessControl middleware returns ErrMethodNotFound --- middleware.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/middleware.go b/middleware.go index 668dc66..e6c0cb6 100644 --- a/middleware.go +++ b/middleware.go @@ -305,7 +305,7 @@ func AccessControl(acl Config[ACL], cfg Options) func(next http.Handler) http.Ha ctx := r.Context() acl, err := acl.Get(ctx, r.URL.Path) if err != nil { - cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("get acl: %w", err)) + cfg.ErrHandler(r, w, proto.ErrMethodNotFound) return } From 6591a9a2ea2d99bd2de10f432fa43ad56522c727 Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Mon, 1 Dec 2025 09:04:53 +0100 Subject: [PATCH 2/2] Return no error when ACL is not found --- common.go | 15 +++++++-------- middleware.go | 24 +++++++++++++----------- middleware_test.go | 33 ++++++++++++++++----------------- 3 files changed, 36 insertions(+), 36 deletions(-) diff --git a/common.go b/common.go index 1dc3b8a..b9c1b47 100644 --- a/common.go +++ b/common.go @@ -61,14 +61,14 @@ type ProjectStore interface { type Config[T any] map[string]map[string]T // Get returns the config value for the given request. -func (c Config[T]) Get(_ context.Context, path string) (v T, err error) { +func (c Config[T]) Get(_ context.Context, path string) (v T, ok bool) { if c == nil { - return v, fmt.Errorf("config is nil") + return v, false } p := strings.Split(path, "/") if len(p) < 4 { - return v, fmt.Errorf("path has not enough parts: %s", path) + return v, false } var ( @@ -78,15 +78,14 @@ func (c Config[T]) Get(_ context.Context, path string) (v T, err error) { ) if packageName != "rpc" { - return v, fmt.Errorf("path doesn't include rpc: %s", path) + return v, false } - v, ok := c[serviceName][methodName] - if !ok { - return v, fmt.Errorf("acl not defined for path: %s", path) + if v, ok = c[serviceName][methodName]; !ok { + return v, false } - return v, nil + return v, true } // Verify checks that the given config is valid for the given service. diff --git a/middleware.go b/middleware.go index e6c0cb6..f74da9b 100644 --- a/middleware.go +++ b/middleware.go @@ -303,23 +303,25 @@ func AccessControl(acl Config[ACL], cfg Options) func(next http.Handler) http.Ha return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - acl, err := acl.Get(ctx, r.URL.Path) - if err != nil { - cfg.ErrHandler(r, w, proto.ErrMethodNotFound) + + acl, ok := acl.Get(ctx, r.URL.Path) + if !ok { + // no ACL defined -> delegate to the next handler + next.ServeHTTP(w, r) return } - if session, _ := GetSessionType(ctx); !acl.Includes(session) { - err := proto.ErrPermissionDenied - if session == proto.SessionType_Public { - err = proto.ErrUnauthorized - } - - cfg.ErrHandler(r, w, err) + session, _ := GetSessionType(ctx) + if acl.Includes(session) { + next.ServeHTTP(w, r) return } - next.ServeHTTP(w, r) + err := proto.ErrUnauthorized + if session > proto.SessionType_Public { + err = proto.ErrPermissionDenied + } + cfg.ErrHandler(r, w, err) }) } } diff --git a/middleware_test.go b/middleware_test.go index 147a23c..e463d87 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -232,25 +232,25 @@ func TestInvalid(t *testing.T) { assert.True(t, ok) assert.NoError(t, err) - // Invalid request path with wrong not enough parts in path for valid RPC request + // Invalid request path with wrong not enough parts in path for valid RPC request, this will delegate to next handler and return no error ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims))) - assert.False(t, ok) - assert.ErrorIs(t, err, proto.ErrUnauthorized) + assert.True(t, ok) + assert.NoError(t, err) - // Invalid request path with wrong "rpc" + // Invalid request path with wrong "rpc", this will delegate to next handler and return no error ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/pcr/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims))) - assert.False(t, ok) - assert.ErrorIs(t, err, proto.ErrUnauthorized) + assert.True(t, ok) + assert.NoError(t, err) - // Invalid Service + // Invalid Service, this will delegate to next handler and return no error ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims))) - assert.False(t, ok) - assert.ErrorIs(t, err, proto.ErrUnauthorized) + assert.True(t, ok) + assert.NoError(t, err) - // Invalid Method + // Invalid Method, this will delegate to next handler and return no error ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims))) - assert.False(t, ok) - assert.ErrorIs(t, err, proto.ErrUnauthorized) + assert.True(t, ok) + assert.NoError(t, err) // Expired JWT Token claims["exp"] = time.Now().Add(-5 * time.Minute).Unix() // Note: Session() middleware allows some skew. @@ -283,7 +283,7 @@ func TestCustomErrHandler(t *testing.T) { ACLConfig := authcontrol.Config[authcontrol.ACL]{ ServiceName: { - MethodName: authcontrol.NewACL(proto.SessionType_Public.OrHigher()...), + MethodName: authcontrol.NewACL(proto.SessionType_AccessKey.OrHigher()...), }, } @@ -325,16 +325,15 @@ func TestCustomErrHandler(t *testing.T) { r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - var claims map[string]any - claims = map[string]any{"service": "client_service"} + claims := map[string]any{"service": "client_service"} // Valid Request ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims))) assert.True(t, ok) assert.NoError(t, err) - // Invalid service which should return custom error from overrided ErrHandler - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims))) + // Invalid Access, should return custom error + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName)) assert.False(t, ok) assert.ErrorIs(t, err, customErr) }