diff --git a/controller_test.go b/controller_test.go index 6a3d394..36a79a1 100644 --- a/controller_test.go +++ b/controller_test.go @@ -12,7 +12,7 @@ func Test_BaseController(t *testing.T) { type testCase struct { title string method string - mws []Middleware + mws []func(w http.ResponseWriter) func (http.Handler) http.Handler out string } @@ -20,18 +20,24 @@ func Test_BaseController(t *testing.T) { { title: "register middleware for HTTP method", method: http.MethodGet, - mws: []Middleware{middlewareOne}, - out: "/mw1 before next/final handler/mw1 after next", + mws: []func(w http.ResponseWriter) func(http.Handler) http.Handler{ + middlewareOne, + }, + out: "/mw1 prepare/mw1 before next/final handler/mw1 after next", }, { title: "add middleware to existing chain", method: http.MethodGet, - mws: []Middleware{middlewareTwo, middlewareThree}, - out: "/mw1 before next/mw2 before next/mw3 before next/final handler/mw3 after next/mw2 after next/mw1 after next", + mws: []func(w http.ResponseWriter) func(http.Handler) http.Handler{ + middlewareTwo, middlewareThree, + }, + out: "/mw2 prepare/mw3 prepare/mw1 before next/mw2 before next/mw3 before next/final handler" + + "/mw3 after next/mw2 after next/mw1 after next", }, { title: "get an empty middleware chain (by default)", method: http.MethodPost, + mws: []func(w http.ResponseWriter) func(http.Handler) http.Handler{}, out: "/final handler", }, } @@ -40,10 +46,14 @@ func Test_BaseController(t *testing.T) { controller := NewBaseController() for _, tc := range cases { t.Run(tc.title, func(t *testing.T) { + w := httptest.NewRecorder() if len(tc.mws) > 0 { - controller.AddMiddleware(tc.method, tc.mws...) + var mws []Middleware + for _, mw := range tc.mws { + mws = append(mws, mw(w)) + } + controller.AddMiddleware(tc.method, mws...) } - w := httptest.NewRecorder() controller.Middleware(tc.method).Then(handlerFinal).ServeHTTP(w, nil) if w.Body.String() != tc.out { t.Errorf("handler output is expected to be %q but was %q", tc.out, w.Body.String()) diff --git a/middleware.go b/middleware.go index 5f93d6f..42d661a 100644 --- a/middleware.go +++ b/middleware.go @@ -38,7 +38,12 @@ func (mw Middleware) Use(middlewares ...Middleware) Middleware { for _, next := range middlewares { mw = func(curr, next Middleware) Middleware { return func(handler http.Handler) http.Handler { - return curr(next(handler)) + var nextHandler http.Handler + currHandler := curr(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextHandler.ServeHTTP(w, r) + })) + nextHandler = next(handler) + return currHandler } }(mw, next) } diff --git a/middleware_test.go b/middleware_test.go index 2aaee68..cd11e34 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -7,28 +7,37 @@ import ( ) var ( - middlewareOne = func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("/mw1 before next")) - next.ServeHTTP(w, r) - w.Write([]byte("/mw1 after next")) - }) + middlewareOne = func(w http.ResponseWriter) func (http.Handler) http.Handler { + w.Write([]byte("/mw1 prepare")) + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("/mw1 before next")) + next.ServeHTTP(w, r) + w.Write([]byte("/mw1 after next")) + }) + } } - middlewareTwo = func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("/mw2 before next")) - next.ServeHTTP(w, r) - w.Write([]byte("/mw2 after next")) - }) + middlewareTwo = func(w http.ResponseWriter) func (http.Handler) http.Handler { + w.Write([]byte("/mw2 prepare")) + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("/mw2 before next")) + next.ServeHTTP(w, r) + w.Write([]byte("/mw2 after next")) + }) + } } - middlewareThree = func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("/mw3 before next")) - next.ServeHTTP(w, r) - w.Write([]byte("/mw3 after next")) - }) + middlewareThree = func(w http.ResponseWriter) func (http.Handler) http.Handler { + w.Write([]byte("/mw3 prepare")) + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("/mw3 before next")) + next.ServeHTTP(w, r) + w.Write([]byte("/mw3 after next")) + }) + } } middlewareFuncOne = func(w http.ResponseWriter, r *http.Request, next http.Handler) { @@ -43,10 +52,13 @@ var ( w.Write([]byte("/mw func2 after next")) } - middlewareBreak Middleware = func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("/skip the rest")) - }) + middlewareBreak = func(w http.ResponseWriter) Middleware { + w.Write([]byte("/skip the rest prepare")) + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("/skip the rest")) + }) + } } handlerOne = func(w http.ResponseWriter, r *http.Request) { @@ -65,37 +77,48 @@ var ( func Test_Middleware(t *testing.T) { type testCase struct { title string - handler http.Handler + handler func (w http.ResponseWriter) http.Handler out string } cases := []testCase{ { title: "build handler with single middleware (one call of Use() func with single argument)", - handler: New().Use(middlewareOne).Then(handlerFinal), - out: "/mw1 before next/final handler/mw1 after next", + handler: func(w http.ResponseWriter) http.Handler { + return New().Use(middlewareOne(w)).Then(handlerFinal) + }, + out: "/mw1 prepare/mw1 before next/final handler/mw1 after next", }, { title: "build handler passing middleware to the constructor (call New() with arguments)", - handler: New(middlewareOne, middlewareTwo).Use(middlewareThree).Then(handlerFinal), - out: "/mw1 before next/mw2 before next/mw3 before next/final handler/mw3 after next/mw2 after next/mw1 after next", + handler: func(w http.ResponseWriter) http.Handler { + return New(middlewareOne(w), middlewareTwo(w)).Use(middlewareThree(w)).Then(handlerFinal) + }, + out: "/mw1 prepare/mw2 prepare/mw3 prepare/mw1 before next/mw2 before next/mw3 before next" + + "/final handler/mw3 after next/mw2 after next/mw1 after next", }, { title: "build handler with multiple middleware (adding one middleware per Use())", - handler: New().Use(middlewareOne).Use(middlewareTwo).Use(middlewareThree).Then(handlerFinal), - out: "/mw1 before next/mw2 before next/mw3 before next/final handler/mw3 after next/mw2 after next/mw1 after next", + handler: func(w http.ResponseWriter) http.Handler { + return New().Use(middlewareOne(w)).Use(middlewareTwo(w)).Use(middlewareThree(w)).Then(handlerFinal) + }, + out: "/mw1 prepare/mw2 prepare/mw3 prepare/mw1 before next/mw2 before next/mw3 before next" + + "/final handler/mw3 after next/mw2 after next/mw1 after next", }, { title: "build handler with combination of single/plural calls of Use()", - handler: New().Use(middlewareOne).Use(middlewareTwo, middlewareThree).Then(handlerFinal), - out: "/mw1 before next/mw2 before next/mw3 before next/final handler/mw3 after next/mw2 after next/mw1 after next", + handler: func(w http.ResponseWriter) http.Handler { + return New().Use(middlewareOne(w)).Use(middlewareTwo(w), middlewareThree(w)).Then(handlerFinal) + }, + out: "/mw1 prepare/mw2 prepare/mw3 prepare/mw1 before next/mw2 before next/mw3 before next" + + "/final handler/mw3 after next/mw2 after next/mw1 after next", }, } for _, tc := range cases { t.Run(tc.title, func(t *testing.T) { w := httptest.NewRecorder() - tc.handler.ServeHTTP(w, nil) + tc.handler(w).ServeHTTP(w, nil) if w.Body.String() != tc.out { t.Errorf("the output %q is expected to be %q", w.Body.String(), tc.out) } @@ -111,7 +134,7 @@ func Test_Chain(t *testing.T) { type testCase struct { title string - args []interface{} + args func (http.ResponseWriter) []interface{} out string panic bool } @@ -119,44 +142,54 @@ func Test_Chain(t *testing.T) { cases := []testCase{ { title: "building handler with unsupported argument types should panic", - args: []interface{}{ - middlewareOne, - middlewareTwo, - true, - middlewareThree, - handlerFinal, + args: func(w http.ResponseWriter) []interface{} { + return []interface{}{ + middlewareOne(w), + middlewareTwo(w), + true, + middlewareThree(w), + handlerFinal, + } }, panic: true, }, { title: "middleware should have control over the \"next\" handlers", - args: []interface{}{ - middlewareOne, - middlewareTwo, - middlewareBreak, - middlewareThree, - handlerFinal, + args: func(w http.ResponseWriter) []interface{} { + return []interface{}{ + middlewareOne(w), + middlewareTwo(w), + middlewareBreak(w), + middlewareThree(w), + handlerFinal, + } }, - out: "/mw1 before next/mw2 before next/skip the rest/mw2 after next/mw1 after next", + out: "/mw1 prepare/mw2 prepare/skip the rest prepare/mw3 prepare/mw1 before next/mw2 before next" + + "/skip the rest/mw2 after next/mw1 after next", }, { title: "calling function without any arguments should build a middleware with only blobHandler", + args: func(w http.ResponseWriter) []interface{} { + return []interface{}{} + }, out: "/blob handler", }, { title: "building handler with all kind of supported arguments should be successful", - args: []interface{}{ - middlewareOne, - Middleware(middlewareTwo), - middlewareFuncOne, - MiddlewareFunc(middlewareFuncTwo), - handlerOne, - http.HandlerFunc(handlerTwo), - middlewareThree, - handlerFinal, + args: func(w http.ResponseWriter) []interface{} { + return []interface{}{ + middlewareOne(w), + Middleware(middlewareTwo(w)), + middlewareFuncOne, + MiddlewareFunc(middlewareFuncTwo), + handlerOne, + http.HandlerFunc(handlerTwo), + middlewareThree(w), + handlerFinal, + } }, - out: "/mw1 before next/mw2 before next/mw func1 before next/mw func2 before next" + - "/first handler/second handler/mw3 before next/final handler/blob handler" + + out: "/mw1 prepare/mw2 prepare/mw3 prepare/mw1 before next/mw2 before next/mw func1 before next" + + "/mw func2 before next/first handler/second handler/mw3 before next/final handler/blob handler" + "/mw3 after next/mw func2 after next/mw func1 after next/mw2 after next/mw1 after next", }, } @@ -175,7 +208,7 @@ func Test_Chain(t *testing.T) { } }() w := httptest.NewRecorder() - Chain(tc.args...).ServeHTTP(w, nil) + Chain(tc.args(w)...).ServeHTTP(w, nil) if w.Body.String() != tc.out { t.Errorf("out %v expected to be %v", w.Body.String(), tc.out) }