From 70b986d1fe63eb164dbb3f3d4b9cd46836d758b9 Mon Sep 17 00:00:00 2001 From: Blake Gentry Date: Thu, 22 Jan 2026 20:20:53 -0600 Subject: [PATCH 1/2] serve UI HTML for wildcard or missing Accept header Fixes #485. --- CHANGELOG.md | 1 + spa_response_writer.go | 48 ++++++++++++++++++++++-- spa_response_writer_test.go | 73 +++++++++++++++++++++++++++++++++++++ 3 files changed, 119 insertions(+), 3 deletions(-) create mode 100644 spa_response_writer_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f64249..f1856db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Prevent double slash in URLs for root path prefix. Thanks [Jan Kott](https://github.com/boostvolt)! [PR #487](https://github.com/riverqueue/riverui/pull/487) +- Serve UI HTML for wildcard or missing Accept headers and return 406 for explicit non-HTML requests. Fixes #485. [PR #XXX](https://github.com/riverqueue/riverui/pull/XXX). ## [v0.14.0] - 2026-01-02 diff --git a/spa_response_writer.go b/spa_response_writer.go index 8a0cbdd..98a58f4 100644 --- a/spa_response_writer.go +++ b/spa_response_writer.go @@ -6,7 +6,9 @@ import ( "fmt" "html/template" "io" + "mime" "net/http" + "strconv" "strings" ) @@ -30,9 +32,9 @@ func intercept404(handler, on404 http.Handler) http.Handler { func serveIndexHTML(devMode bool, manifest map[string]any, pathPrefix string, files http.FileSystem) http.HandlerFunc { return func(rw http.ResponseWriter, req *http.Request) { // Restrict only to instances where the browser is looking for an HTML file - if !strings.Contains(req.Header.Get("Accept"), "text/html") { - rw.WriteHeader(http.StatusNotFound) - fmt.Fprint(rw, "404 not found") + if !acceptsHTML(req) { + rw.WriteHeader(http.StatusNotAcceptable) + fmt.Fprint(rw, "not acceptable: only text/html is available") return } @@ -94,6 +96,46 @@ func serveIndexHTML(devMode bool, manifest map[string]any, pathPrefix string, fi } } +func acceptsHTML(req *http.Request) bool { + accept := strings.TrimSpace(req.Header.Get("Accept")) + if accept == "" { + return true + } + + for part := range strings.SplitSeq(accept, ",") { + part = strings.TrimSpace(part) + if part == "" { + continue + } + + mediaType, params, err := mime.ParseMediaType(part) + if err != nil { + mediaType = strings.TrimSpace(strings.SplitN(part, ";", 2)[0]) + params = nil + } + + quality := 1.0 + if params != nil { + if qRaw, ok := params["q"]; ok { + if parsed, err := strconv.ParseFloat(qRaw, 64); err == nil { + quality = parsed + } + } + } + + if quality <= 0 { + continue + } + + switch mediaType { + case "text/html", "text/*", "*/*": + return true + } + } + + return false +} + type spaResponseWriter struct { http.ResponseWriter diff --git a/spa_response_writer_test.go b/spa_response_writer_test.go new file mode 100644 index 0000000..a3d9fe1 --- /dev/null +++ b/spa_response_writer_test.go @@ -0,0 +1,73 @@ +package riverui + +import ( + "net/http" + "net/http/httptest" + "testing" + "testing/fstest" + + "github.com/stretchr/testify/require" +) + +func TestServeIndexHTMLAcceptNegotiation(t *testing.T) { + t.Parallel() + + files := fstest.MapFS{ + "index.html": &fstest.MapFile{Data: []byte("ok")}, + } + + handler := serveIndexHTML(false, map[string]any{}, "/riverui", http.FS(files)) + + tests := []struct { + name string + acceptHeader string + setAccept bool + wantStatus int + }{ + { + name: "AcceptHTML", + acceptHeader: "text/html", + setAccept: true, + wantStatus: http.StatusOK, + }, + { + name: "AcceptWildcard", + acceptHeader: "*/*", + setAccept: true, + wantStatus: http.StatusOK, + }, + { + name: "AcceptTextWildcard", + acceptHeader: "text/*", + setAccept: true, + wantStatus: http.StatusOK, + }, + { + name: "AcceptMissing", + setAccept: false, + wantStatus: http.StatusOK, + }, + { + name: "AcceptJSON", + acceptHeader: "application/json", + setAccept: true, + wantStatus: http.StatusNotAcceptable, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tt.setAccept { + req.Header.Set("Accept", tt.acceptHeader) + } + + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, req) + + require.Equal(t, tt.wantStatus, recorder.Result().StatusCode) + }) + } +} From 1321cc9fa76f90a5f84b6fd5657c35239e1a2782 Mon Sep 17 00:00:00 2001 From: Blake Gentry Date: Thu, 22 Jan 2026 20:33:04 -0600 Subject: [PATCH 2/2] harden & refactor SPA response writer Tighten Accept negotiation and method handling, add Vary support, and cache the parsed index template outside dev mode with tests. --- CHANGELOG.md | 2 +- spa_response_writer.go | 160 +++++++++++++++++++++++++---------- spa_response_writer_test.go | 162 ++++++++++++++++++++++++++++-------- 3 files changed, 244 insertions(+), 80 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f1856db..32c4e67 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Prevent double slash in URLs for root path prefix. Thanks [Jan Kott](https://github.com/boostvolt)! [PR #487](https://github.com/riverqueue/riverui/pull/487) -- Serve UI HTML for wildcard or missing Accept headers and return 406 for explicit non-HTML requests. Fixes #485. [PR #XXX](https://github.com/riverqueue/riverui/pull/XXX). +- Serve UI HTML for wildcard or missing Accept headers and return 406 for explicit non-HTML requests. Fixes #485. [PR #493](https://github.com/riverqueue/riverui/pull/493). ## [v0.14.0] - 2026-01-02 diff --git a/spa_response_writer.go b/spa_response_writer.go index 98a58f4..a37d9ed 100644 --- a/spa_response_writer.go +++ b/spa_response_writer.go @@ -3,13 +3,14 @@ package riverui import ( "bytes" "encoding/json" - "fmt" "html/template" "io" "mime" "net/http" + "slices" "strconv" "strings" + "time" ) func intercept404(handler, on404 http.Handler) http.Handler { @@ -30,75 +31,71 @@ func intercept404(handler, on404 http.Handler) http.Handler { } func serveIndexHTML(devMode bool, manifest map[string]any, pathPrefix string, files http.FileSystem) http.HandlerFunc { - return func(rw http.ResponseWriter, req *http.Request) { - // Restrict only to instances where the browser is looking for an HTML file - if !acceptsHTML(req) { - rw.WriteHeader(http.StatusNotAcceptable) - fmt.Fprint(rw, "not acceptable: only text/html is available") + cachedIndex := indexTemplateResult{} + if !devMode { + cachedIndex = loadIndexTemplate(files) + } + return func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet && req.Method != http.MethodHead { + rw.Header().Set("Allow", "GET, HEAD") + http.Error(rw, "method not allowed", http.StatusMethodNotAllowed) return } - rawIndex, err := files.Open("index.html") - if err != nil { - http.Error(rw, "could not open index.html", http.StatusInternalServerError) - return - } + addVaryHeader(rw.Header(), "Accept") - config := struct { - APIURL string `json:"apiUrl"` //nolint:tagliatelle - Base string `json:"base"` - }{ - APIURL: pathPrefix + "/api", - Base: pathPrefix, + // Restrict only to instances where the browser is looking for an HTML file + if !acceptsHTML(req) { + http.Error(rw, "not acceptable: only text/html is available", http.StatusNotAcceptable) + return } - templateData := map[string]any{ - "Config": config, - "Dev": devMode, - "Manifest": manifest, - "Base": pathPrefix, + indexTemplate := cachedIndex + if devMode { + indexTemplate = loadIndexTemplate(files) } - - fileInfo, err := rawIndex.Stat() - if err != nil { - http.Error(rw, "could not stat index.html", http.StatusInternalServerError) + if indexTemplate.err != nil { + http.Error(rw, indexTemplate.errMessage, http.StatusInternalServerError) return } - indexBuf, err := io.ReadAll(rawIndex) - if err != nil { - http.Error(rw, "could not read index.html", http.StatusInternalServerError) - return + config := indexTemplateConfig{ + APIURL: pathPrefix + "/api", + Base: pathPrefix, } - tmpl, err := template.New("index.html").Funcs(template.FuncMap{ - "marshal": func(v any) template.JS { - a, _ := json.Marshal(v) - return template.JS(a) //nolint:gosec - }, - }).Parse(string(indexBuf)) - if err != nil { - http.Error(rw, "could not parse index.html", http.StatusInternalServerError) - return + templateData := indexTemplateData{ + Config: config, + Dev: devMode, + Manifest: manifest, + Base: pathPrefix, } var output bytes.Buffer - if err = tmpl.Execute(&output, templateData); err != nil { + if err := indexTemplate.tmpl.Execute(&output, templateData); err != nil { http.Error(rw, "could not execute index.html", http.StatusInternalServerError) return } - index := bytes.NewReader(output.Bytes()) + indexReader := bytes.NewReader(output.Bytes()) rw.Header().Set("Content-Type", "text/html; charset=utf-8") - http.ServeContent(rw, req, fileInfo.Name(), fileInfo.ModTime(), index) + http.ServeContent(rw, req, indexTemplate.name, indexTemplate.modTime, indexReader) } } func acceptsHTML(req *http.Request) bool { - accept := strings.TrimSpace(req.Header.Get("Accept")) - if accept == "" { + acceptValues := req.Header.Values("Accept") + if len(acceptValues) == 0 { + return true + } + + return slices.ContainsFunc(acceptValues, acceptsHTMLValue) +} + +func acceptsHTMLValue(accept string) bool { + if strings.TrimSpace(accept) == "" { return true } @@ -128,7 +125,7 @@ func acceptsHTML(req *http.Request) bool { } switch mediaType { - case "text/html", "text/*", "*/*": + case "text/html", "application/xhtml+xml", "text/*", "*/*": return true } } @@ -136,6 +133,79 @@ func acceptsHTML(req *http.Request) bool { return false } +func addVaryHeader(headers http.Header, value string) { + for _, existing := range headers.Values("Vary") { + for part := range strings.SplitSeq(existing, ",") { + if strings.EqualFold(strings.TrimSpace(part), value) { + return + } + } + } + + headers.Add("Vary", value) +} + +type indexTemplateConfig struct { + APIURL string `json:"apiUrl"` //nolint:tagliatelle + Base string `json:"base"` +} + +type indexTemplateData struct { + Config indexTemplateConfig + Dev bool + Manifest map[string]any + Base string +} + +type indexTemplateResult struct { + tmpl *template.Template + name string + modTime time.Time + err error + errMessage string +} + +func loadIndexTemplate(files http.FileSystem) indexTemplateResult { + rawIndex, err := files.Open("index.html") + if err != nil { + return indexTemplateResult{err: err, errMessage: "could not open index.html"} + } + defer rawIndex.Close() + + fileInfo, err := rawIndex.Stat() + if err != nil { + return indexTemplateResult{err: err, errMessage: "could not stat index.html"} + } + + indexBuf, err := io.ReadAll(rawIndex) + if err != nil { + return indexTemplateResult{err: err, errMessage: "could not read index.html"} + } + + tmpl, err := parseIndexTemplate(indexBuf) + if err != nil { + return indexTemplateResult{err: err, errMessage: "could not parse index.html"} + } + + return indexTemplateResult{ + tmpl: tmpl, + name: fileInfo.Name(), + modTime: fileInfo.ModTime(), + } +} + +func parseIndexTemplate(indexBuf []byte) (*template.Template, error) { + return template.New("index.html").Funcs(template.FuncMap{ + "marshal": func(v any) (template.JS, error) { + payload, err := json.Marshal(v) + if err != nil { + return "", err + } + return template.JS(payload), nil //nolint:gosec + }, + }).Parse(string(indexBuf)) +} + type spaResponseWriter struct { http.ResponseWriter diff --git a/spa_response_writer_test.go b/spa_response_writer_test.go index a3d9fe1..5943233 100644 --- a/spa_response_writer_test.go +++ b/spa_response_writer_test.go @@ -3,6 +3,8 @@ package riverui import ( "net/http" "net/http/httptest" + "strings" + "sync/atomic" "testing" "testing/fstest" @@ -12,46 +14,56 @@ import ( func TestServeIndexHTMLAcceptNegotiation(t *testing.T) { t.Parallel() - files := fstest.MapFS{ - "index.html": &fstest.MapFile{Data: []byte("ok")}, - } - - handler := serveIndexHTML(false, map[string]any{}, "/riverui", http.FS(files)) + handler := serveIndexHTML(false, map[string]any{}, "/riverui", newIndexFileSystem()) tests := []struct { - name string - acceptHeader string - setAccept bool - wantStatus int + name string + acceptHeaders []string + wantStatus int }{ { - name: "AcceptHTML", - acceptHeader: "text/html", - setAccept: true, - wantStatus: http.StatusOK, + name: "AcceptHTML", + acceptHeaders: []string{"text/html"}, + wantStatus: http.StatusOK, }, { - name: "AcceptWildcard", - acceptHeader: "*/*", - setAccept: true, - wantStatus: http.StatusOK, + name: "AcceptWildcard", + acceptHeaders: []string{"*/*"}, + wantStatus: http.StatusOK, }, { - name: "AcceptTextWildcard", - acceptHeader: "text/*", - setAccept: true, - wantStatus: http.StatusOK, + name: "AcceptTextWildcard", + acceptHeaders: []string{"text/*"}, + wantStatus: http.StatusOK, }, { name: "AcceptMissing", - setAccept: false, wantStatus: http.StatusOK, }, { - name: "AcceptJSON", - acceptHeader: "application/json", - setAccept: true, - wantStatus: http.StatusNotAcceptable, + name: "AcceptJSON", + acceptHeaders: []string{"application/json"}, + wantStatus: http.StatusNotAcceptable, + }, + { + name: "AcceptXHTML", + acceptHeaders: []string{"application/xhtml+xml"}, + wantStatus: http.StatusOK, + }, + { + name: "AcceptHTMLWithQualityZero", + acceptHeaders: []string{"text/html;q=0"}, + wantStatus: http.StatusNotAcceptable, + }, + { + name: "AcceptMultipleHeaders", + acceptHeaders: []string{"application/json", "text/html"}, + wantStatus: http.StatusOK, + }, + { + name: "AcceptMultipleValues", + acceptHeaders: []string{"application/json, text/html"}, + wantStatus: http.StatusOK, }, } @@ -59,15 +71,97 @@ func TestServeIndexHTMLAcceptNegotiation(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - req := httptest.NewRequest(http.MethodGet, "/", nil) - if tt.setAccept { - req.Header.Set("Accept", tt.acceptHeader) - } - - recorder := httptest.NewRecorder() - handler.ServeHTTP(recorder, req) - + recorder := performRequest(handler, http.MethodGet, tt.acceptHeaders) require.Equal(t, tt.wantStatus, recorder.Result().StatusCode) }) } } + +func TestServeIndexHTMLVaryHeader(t *testing.T) { + t.Parallel() + + handler := serveIndexHTML(false, map[string]any{}, "/riverui", newIndexFileSystem()) + + recorder := performRequest(handler, http.MethodGet, []string{"application/json"}) + require.Equal(t, http.StatusNotAcceptable, recorder.Result().StatusCode) + require.Contains(t, strings.Join(recorder.Header().Values("Vary"), ","), "Accept") +} + +func TestServeIndexHTMLMethodNotAllowed(t *testing.T) { + t.Parallel() + + handler := serveIndexHTML(false, map[string]any{}, "/riverui", newIndexFileSystem()) + + recorder := performRequest(handler, http.MethodPost, []string{"text/html"}) + require.Equal(t, http.StatusMethodNotAllowed, recorder.Result().StatusCode) + require.Contains(t, recorder.Header().Get("Allow"), "GET") + require.Contains(t, recorder.Header().Get("Allow"), "HEAD") +} + +func TestServeIndexHTMLHead(t *testing.T) { + t.Parallel() + + handler := serveIndexHTML(false, map[string]any{}, "/riverui", newIndexFileSystem()) + + recorder := performRequest(handler, http.MethodHead, []string{"text/html"}) + require.Equal(t, http.StatusOK, recorder.Result().StatusCode) + require.Empty(t, recorder.Body.String()) +} + +func TestServeIndexHTMLTemplateCaching(t *testing.T) { + t.Parallel() + + t.Run("CachesWhenNotDevMode", func(t *testing.T) { + t.Parallel() + + counting := &countingFS{fs: newIndexFileSystem()} + handler := serveIndexHTML(false, map[string]any{}, "/riverui", counting) + + require.Equal(t, http.StatusOK, performRequest(handler, http.MethodGet, []string{"text/html"}).Result().StatusCode) + require.Equal(t, http.StatusOK, performRequest(handler, http.MethodGet, []string{"text/html"}).Result().StatusCode) + + require.Equal(t, int32(1), counting.opens.Load()) + }) + + t.Run("NoCacheInDevMode", func(t *testing.T) { + t.Parallel() + + counting := &countingFS{fs: newIndexFileSystem()} + handler := serveIndexHTML(true, map[string]any{}, "/riverui", counting) + + require.Equal(t, http.StatusOK, performRequest(handler, http.MethodGet, []string{"text/html"}).Result().StatusCode) + require.Equal(t, http.StatusOK, performRequest(handler, http.MethodGet, []string{"text/html"}).Result().StatusCode) + + require.Equal(t, int32(2), counting.opens.Load()) + }) +} + +func performRequest(handler http.Handler, method string, acceptHeaders []string) *httptest.ResponseRecorder { + req := httptest.NewRequest(method, "/", nil) + for _, header := range acceptHeaders { + req.Header.Add("Accept", header) + } + + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, req) + + return recorder +} + +func newIndexFileSystem() http.FileSystem { + files := fstest.MapFS{ + "index.html": &fstest.MapFile{Data: []byte("ok")}, + } + + return http.FS(files) +} + +type countingFS struct { + fs http.FileSystem + opens atomic.Int32 +} + +func (counting *countingFS) Open(name string) (http.File, error) { + counting.opens.Add(1) + return counting.fs.Open(name) +}