From a4ececf39591bbb0af3e7ad9fe05ad08e5f7f6d8 Mon Sep 17 00:00:00 2001 From: Kaarel Raspel Date: Wed, 3 Mar 2021 13:34:32 +0200 Subject: [PATCH] Change the order how middlewares are prepared for the request Previously the middlewares were initialized for the request in the reversed order. Now the middlewares are initialized in the order they are registered. --- controller_test.go | 24 +++++--- middleware.go | 7 ++- middleware_test.go | 149 +++++++++++++++++++++++++++------------------ 3 files changed, 114 insertions(+), 66 deletions(-) diff --git a/controller_test.go b/controller_test.go index 6a3d394..36a79a1 100644 --- a/controller_test.go +++ b/controller_test.go @@ -12,7 +12,7 @@ func Test_BaseController(t *testing.T) { type testCase struct { title string method string - mws []Middleware + mws []func(w http.ResponseWriter) func (http.Handler) http.Handler out string } @@ -20,18 +20,24 @@ func Test_BaseController(t *testing.T) { { title: "register middleware for HTTP method", method: http.MethodGet, - mws: []Middleware{middlewareOne}, - out: "/mw1 before next/final handler/mw1 after next", + mws: []func(w http.ResponseWriter) func(http.Handler) http.Handler{ + middlewareOne, + }, + out: "/mw1 prepare/mw1 before next/final handler/mw1 after next", }, { title: "add middleware to existing chain", method: http.MethodGet, - mws: []Middleware{middlewareTwo, middlewareThree}, - out: "/mw1 before next/mw2 before next/mw3 before next/final handler/mw3 after next/mw2 after next/mw1 after next", + mws: []func(w http.ResponseWriter) func(http.Handler) http.Handler{ + middlewareTwo, middlewareThree, + }, + out: "/mw2 prepare/mw3 prepare/mw1 before next/mw2 before next/mw3 before next/final handler" + + "/mw3 after next/mw2 after next/mw1 after next", }, { title: "get an empty middleware chain (by default)", method: http.MethodPost, + mws: []func(w http.ResponseWriter) func(http.Handler) http.Handler{}, out: "/final handler", }, } @@ -40,10 +46,14 @@ func Test_BaseController(t *testing.T) { controller := NewBaseController() for _, tc := range cases { t.Run(tc.title, func(t *testing.T) { + w := httptest.NewRecorder() if len(tc.mws) > 0 { - controller.AddMiddleware(tc.method, tc.mws...) + var mws []Middleware + for _, mw := range tc.mws { + mws = append(mws, mw(w)) + } + controller.AddMiddleware(tc.method, mws...) } - w := httptest.NewRecorder() controller.Middleware(tc.method).Then(handlerFinal).ServeHTTP(w, nil) if w.Body.String() != tc.out { t.Errorf("handler output is expected to be %q but was %q", tc.out, w.Body.String()) diff --git a/middleware.go b/middleware.go index 5f93d6f..42d661a 100644 --- a/middleware.go +++ b/middleware.go @@ -38,7 +38,12 @@ func (mw Middleware) Use(middlewares ...Middleware) Middleware { for _, next := range middlewares { mw = func(curr, next Middleware) Middleware { return func(handler http.Handler) http.Handler { - return curr(next(handler)) + var nextHandler http.Handler + currHandler := curr(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextHandler.ServeHTTP(w, r) + })) + nextHandler = next(handler) + return currHandler } }(mw, next) } diff --git a/middleware_test.go b/middleware_test.go index 2aaee68..cd11e34 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -7,28 +7,37 @@ import ( ) var ( - middlewareOne = func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("/mw1 before next")) - next.ServeHTTP(w, r) - w.Write([]byte("/mw1 after next")) - }) + middlewareOne = func(w http.ResponseWriter) func (http.Handler) http.Handler { + w.Write([]byte("/mw1 prepare")) + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("/mw1 before next")) + next.ServeHTTP(w, r) + w.Write([]byte("/mw1 after next")) + }) + } } - middlewareTwo = func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("/mw2 before next")) - next.ServeHTTP(w, r) - w.Write([]byte("/mw2 after next")) - }) + middlewareTwo = func(w http.ResponseWriter) func (http.Handler) http.Handler { + w.Write([]byte("/mw2 prepare")) + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("/mw2 before next")) + next.ServeHTTP(w, r) + w.Write([]byte("/mw2 after next")) + }) + } } - middlewareThree = func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("/mw3 before next")) - next.ServeHTTP(w, r) - w.Write([]byte("/mw3 after next")) - }) + middlewareThree = func(w http.ResponseWriter) func (http.Handler) http.Handler { + w.Write([]byte("/mw3 prepare")) + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("/mw3 before next")) + next.ServeHTTP(w, r) + w.Write([]byte("/mw3 after next")) + }) + } } middlewareFuncOne = func(w http.ResponseWriter, r *http.Request, next http.Handler) { @@ -43,10 +52,13 @@ var ( w.Write([]byte("/mw func2 after next")) } - middlewareBreak Middleware = func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("/skip the rest")) - }) + middlewareBreak = func(w http.ResponseWriter) Middleware { + w.Write([]byte("/skip the rest prepare")) + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("/skip the rest")) + }) + } } handlerOne = func(w http.ResponseWriter, r *http.Request) { @@ -65,37 +77,48 @@ var ( func Test_Middleware(t *testing.T) { type testCase struct { title string - handler http.Handler + handler func (w http.ResponseWriter) http.Handler out string } cases := []testCase{ { title: "build handler with single middleware (one call of Use() func with single argument)", - handler: New().Use(middlewareOne).Then(handlerFinal), - out: "/mw1 before next/final handler/mw1 after next", + handler: func(w http.ResponseWriter) http.Handler { + return New().Use(middlewareOne(w)).Then(handlerFinal) + }, + out: "/mw1 prepare/mw1 before next/final handler/mw1 after next", }, { title: "build handler passing middleware to the constructor (call New() with arguments)", - handler: New(middlewareOne, middlewareTwo).Use(middlewareThree).Then(handlerFinal), - out: "/mw1 before next/mw2 before next/mw3 before next/final handler/mw3 after next/mw2 after next/mw1 after next", + handler: func(w http.ResponseWriter) http.Handler { + return New(middlewareOne(w), middlewareTwo(w)).Use(middlewareThree(w)).Then(handlerFinal) + }, + out: "/mw1 prepare/mw2 prepare/mw3 prepare/mw1 before next/mw2 before next/mw3 before next" + + "/final handler/mw3 after next/mw2 after next/mw1 after next", }, { title: "build handler with multiple middleware (adding one middleware per Use())", - handler: New().Use(middlewareOne).Use(middlewareTwo).Use(middlewareThree).Then(handlerFinal), - out: "/mw1 before next/mw2 before next/mw3 before next/final handler/mw3 after next/mw2 after next/mw1 after next", + handler: func(w http.ResponseWriter) http.Handler { + return New().Use(middlewareOne(w)).Use(middlewareTwo(w)).Use(middlewareThree(w)).Then(handlerFinal) + }, + out: "/mw1 prepare/mw2 prepare/mw3 prepare/mw1 before next/mw2 before next/mw3 before next" + + "/final handler/mw3 after next/mw2 after next/mw1 after next", }, { title: "build handler with combination of single/plural calls of Use()", - handler: New().Use(middlewareOne).Use(middlewareTwo, middlewareThree).Then(handlerFinal), - out: "/mw1 before next/mw2 before next/mw3 before next/final handler/mw3 after next/mw2 after next/mw1 after next", + handler: func(w http.ResponseWriter) http.Handler { + return New().Use(middlewareOne(w)).Use(middlewareTwo(w), middlewareThree(w)).Then(handlerFinal) + }, + out: "/mw1 prepare/mw2 prepare/mw3 prepare/mw1 before next/mw2 before next/mw3 before next" + + "/final handler/mw3 after next/mw2 after next/mw1 after next", }, } for _, tc := range cases { t.Run(tc.title, func(t *testing.T) { w := httptest.NewRecorder() - tc.handler.ServeHTTP(w, nil) + tc.handler(w).ServeHTTP(w, nil) if w.Body.String() != tc.out { t.Errorf("the output %q is expected to be %q", w.Body.String(), tc.out) } @@ -111,7 +134,7 @@ func Test_Chain(t *testing.T) { type testCase struct { title string - args []interface{} + args func (http.ResponseWriter) []interface{} out string panic bool } @@ -119,44 +142,54 @@ func Test_Chain(t *testing.T) { cases := []testCase{ { title: "building handler with unsupported argument types should panic", - args: []interface{}{ - middlewareOne, - middlewareTwo, - true, - middlewareThree, - handlerFinal, + args: func(w http.ResponseWriter) []interface{} { + return []interface{}{ + middlewareOne(w), + middlewareTwo(w), + true, + middlewareThree(w), + handlerFinal, + } }, panic: true, }, { title: "middleware should have control over the \"next\" handlers", - args: []interface{}{ - middlewareOne, - middlewareTwo, - middlewareBreak, - middlewareThree, - handlerFinal, + args: func(w http.ResponseWriter) []interface{} { + return []interface{}{ + middlewareOne(w), + middlewareTwo(w), + middlewareBreak(w), + middlewareThree(w), + handlerFinal, + } }, - out: "/mw1 before next/mw2 before next/skip the rest/mw2 after next/mw1 after next", + out: "/mw1 prepare/mw2 prepare/skip the rest prepare/mw3 prepare/mw1 before next/mw2 before next" + + "/skip the rest/mw2 after next/mw1 after next", }, { title: "calling function without any arguments should build a middleware with only blobHandler", + args: func(w http.ResponseWriter) []interface{} { + return []interface{}{} + }, out: "/blob handler", }, { title: "building handler with all kind of supported arguments should be successful", - args: []interface{}{ - middlewareOne, - Middleware(middlewareTwo), - middlewareFuncOne, - MiddlewareFunc(middlewareFuncTwo), - handlerOne, - http.HandlerFunc(handlerTwo), - middlewareThree, - handlerFinal, + args: func(w http.ResponseWriter) []interface{} { + return []interface{}{ + middlewareOne(w), + Middleware(middlewareTwo(w)), + middlewareFuncOne, + MiddlewareFunc(middlewareFuncTwo), + handlerOne, + http.HandlerFunc(handlerTwo), + middlewareThree(w), + handlerFinal, + } }, - out: "/mw1 before next/mw2 before next/mw func1 before next/mw func2 before next" + - "/first handler/second handler/mw3 before next/final handler/blob handler" + + out: "/mw1 prepare/mw2 prepare/mw3 prepare/mw1 before next/mw2 before next/mw func1 before next" + + "/mw func2 before next/first handler/second handler/mw3 before next/final handler/blob handler" + "/mw3 after next/mw func2 after next/mw func1 after next/mw2 after next/mw1 after next", }, } @@ -175,7 +208,7 @@ func Test_Chain(t *testing.T) { } }() w := httptest.NewRecorder() - Chain(tc.args...).ServeHTTP(w, nil) + Chain(tc.args(w)...).ServeHTTP(w, nil) if w.Body.String() != tc.out { t.Errorf("out %v expected to be %v", w.Body.String(), tc.out) }