diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f64249..32c4e67 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Prevent double slash in URLs for root path prefix. Thanks [Jan Kott](https://github.com/boostvolt)! [PR #487](https://github.com/riverqueue/riverui/pull/487) +- Serve UI HTML for wildcard or missing Accept headers and return 406 for explicit non-HTML requests. Fixes #485. [PR #493](https://github.com/riverqueue/riverui/pull/493). ## [v0.14.0] - 2026-01-02 diff --git a/spa_response_writer.go b/spa_response_writer.go index 8a0cbdd..a37d9ed 100644 --- a/spa_response_writer.go +++ b/spa_response_writer.go @@ -3,11 +3,14 @@ package riverui import ( "bytes" "encoding/json" - "fmt" "html/template" "io" + "mime" "net/http" + "slices" + "strconv" "strings" + "time" ) func intercept404(handler, on404 http.Handler) http.Handler { @@ -28,70 +31,179 @@ func intercept404(handler, on404 http.Handler) http.Handler { } func serveIndexHTML(devMode bool, manifest map[string]any, pathPrefix string, files http.FileSystem) http.HandlerFunc { + cachedIndex := indexTemplateResult{} + if !devMode { + cachedIndex = loadIndexTemplate(files) + } + return func(rw http.ResponseWriter, req *http.Request) { - // Restrict only to instances where the browser is looking for an HTML file - if !strings.Contains(req.Header.Get("Accept"), "text/html") { - rw.WriteHeader(http.StatusNotFound) - fmt.Fprint(rw, "404 not found") + if req.Method != http.MethodGet && req.Method != http.MethodHead { + rw.Header().Set("Allow", "GET, HEAD") + http.Error(rw, "method not allowed", http.StatusMethodNotAllowed) + return + } + + addVaryHeader(rw.Header(), "Accept") + // Restrict only to instances where the browser is looking for an HTML file + if !acceptsHTML(req) { + http.Error(rw, "not acceptable: only text/html is available", http.StatusNotAcceptable) return } - rawIndex, err := files.Open("index.html") - if err != nil { - http.Error(rw, "could not open index.html", http.StatusInternalServerError) + indexTemplate := cachedIndex + if devMode { + indexTemplate = loadIndexTemplate(files) + } + if indexTemplate.err != nil { + http.Error(rw, indexTemplate.errMessage, http.StatusInternalServerError) return } - config := struct { - APIURL string `json:"apiUrl"` //nolint:tagliatelle - Base string `json:"base"` - }{ + config := indexTemplateConfig{ APIURL: pathPrefix + "/api", Base: pathPrefix, } - templateData := map[string]any{ - "Config": config, - "Dev": devMode, - "Manifest": manifest, - "Base": pathPrefix, + templateData := indexTemplateData{ + Config: config, + Dev: devMode, + Manifest: manifest, + Base: pathPrefix, } - fileInfo, err := rawIndex.Stat() - if err != nil { - http.Error(rw, "could not stat index.html", http.StatusInternalServerError) + var output bytes.Buffer + if err := indexTemplate.tmpl.Execute(&output, templateData); err != nil { + http.Error(rw, "could not execute index.html", http.StatusInternalServerError) return } - indexBuf, err := io.ReadAll(rawIndex) - if err != nil { - http.Error(rw, "could not read index.html", http.StatusInternalServerError) - return + indexReader := bytes.NewReader(output.Bytes()) + + rw.Header().Set("Content-Type", "text/html; charset=utf-8") + http.ServeContent(rw, req, indexTemplate.name, indexTemplate.modTime, indexReader) + } +} + +func acceptsHTML(req *http.Request) bool { + acceptValues := req.Header.Values("Accept") + if len(acceptValues) == 0 { + return true + } + + return slices.ContainsFunc(acceptValues, acceptsHTMLValue) +} + +func acceptsHTMLValue(accept string) bool { + if strings.TrimSpace(accept) == "" { + return true + } + + for part := range strings.SplitSeq(accept, ",") { + part = strings.TrimSpace(part) + if part == "" { + continue } - tmpl, err := template.New("index.html").Funcs(template.FuncMap{ - "marshal": func(v any) template.JS { - a, _ := json.Marshal(v) - return template.JS(a) //nolint:gosec - }, - }).Parse(string(indexBuf)) + mediaType, params, err := mime.ParseMediaType(part) if err != nil { - http.Error(rw, "could not parse index.html", http.StatusInternalServerError) - return + mediaType = strings.TrimSpace(strings.SplitN(part, ";", 2)[0]) + params = nil } - var output bytes.Buffer - if err = tmpl.Execute(&output, templateData); err != nil { - http.Error(rw, "could not execute index.html", http.StatusInternalServerError) - return + quality := 1.0 + if params != nil { + if qRaw, ok := params["q"]; ok { + if parsed, err := strconv.ParseFloat(qRaw, 64); err == nil { + quality = parsed + } + } } - index := bytes.NewReader(output.Bytes()) + if quality <= 0 { + continue + } - rw.Header().Set("Content-Type", "text/html; charset=utf-8") - http.ServeContent(rw, req, fileInfo.Name(), fileInfo.ModTime(), index) + switch mediaType { + case "text/html", "application/xhtml+xml", "text/*", "*/*": + return true + } + } + + return false +} + +func addVaryHeader(headers http.Header, value string) { + for _, existing := range headers.Values("Vary") { + for part := range strings.SplitSeq(existing, ",") { + if strings.EqualFold(strings.TrimSpace(part), value) { + return + } + } + } + + headers.Add("Vary", value) +} + +type indexTemplateConfig struct { + APIURL string `json:"apiUrl"` //nolint:tagliatelle + Base string `json:"base"` +} + +type indexTemplateData struct { + Config indexTemplateConfig + Dev bool + Manifest map[string]any + Base string +} + +type indexTemplateResult struct { + tmpl *template.Template + name string + modTime time.Time + err error + errMessage string +} + +func loadIndexTemplate(files http.FileSystem) indexTemplateResult { + rawIndex, err := files.Open("index.html") + if err != nil { + return indexTemplateResult{err: err, errMessage: "could not open index.html"} + } + defer rawIndex.Close() + + fileInfo, err := rawIndex.Stat() + if err != nil { + return indexTemplateResult{err: err, errMessage: "could not stat index.html"} + } + + indexBuf, err := io.ReadAll(rawIndex) + if err != nil { + return indexTemplateResult{err: err, errMessage: "could not read index.html"} + } + + tmpl, err := parseIndexTemplate(indexBuf) + if err != nil { + return indexTemplateResult{err: err, errMessage: "could not parse index.html"} } + + return indexTemplateResult{ + tmpl: tmpl, + name: fileInfo.Name(), + modTime: fileInfo.ModTime(), + } +} + +func parseIndexTemplate(indexBuf []byte) (*template.Template, error) { + return template.New("index.html").Funcs(template.FuncMap{ + "marshal": func(v any) (template.JS, error) { + payload, err := json.Marshal(v) + if err != nil { + return "", err + } + return template.JS(payload), nil //nolint:gosec + }, + }).Parse(string(indexBuf)) } type spaResponseWriter struct { diff --git a/spa_response_writer_test.go b/spa_response_writer_test.go new file mode 100644 index 0000000..5943233 --- /dev/null +++ b/spa_response_writer_test.go @@ -0,0 +1,167 @@ +package riverui + +import ( + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "testing/fstest" + + "github.com/stretchr/testify/require" +) + +func TestServeIndexHTMLAcceptNegotiation(t *testing.T) { + t.Parallel() + + handler := serveIndexHTML(false, map[string]any{}, "/riverui", newIndexFileSystem()) + + tests := []struct { + name string + acceptHeaders []string + wantStatus int + }{ + { + name: "AcceptHTML", + acceptHeaders: []string{"text/html"}, + wantStatus: http.StatusOK, + }, + { + name: "AcceptWildcard", + acceptHeaders: []string{"*/*"}, + wantStatus: http.StatusOK, + }, + { + name: "AcceptTextWildcard", + acceptHeaders: []string{"text/*"}, + wantStatus: http.StatusOK, + }, + { + name: "AcceptMissing", + wantStatus: http.StatusOK, + }, + { + name: "AcceptJSON", + acceptHeaders: []string{"application/json"}, + wantStatus: http.StatusNotAcceptable, + }, + { + name: "AcceptXHTML", + acceptHeaders: []string{"application/xhtml+xml"}, + wantStatus: http.StatusOK, + }, + { + name: "AcceptHTMLWithQualityZero", + acceptHeaders: []string{"text/html;q=0"}, + wantStatus: http.StatusNotAcceptable, + }, + { + name: "AcceptMultipleHeaders", + acceptHeaders: []string{"application/json", "text/html"}, + wantStatus: http.StatusOK, + }, + { + name: "AcceptMultipleValues", + acceptHeaders: []string{"application/json, text/html"}, + wantStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + recorder := performRequest(handler, http.MethodGet, tt.acceptHeaders) + require.Equal(t, tt.wantStatus, recorder.Result().StatusCode) + }) + } +} + +func TestServeIndexHTMLVaryHeader(t *testing.T) { + t.Parallel() + + handler := serveIndexHTML(false, map[string]any{}, "/riverui", newIndexFileSystem()) + + recorder := performRequest(handler, http.MethodGet, []string{"application/json"}) + require.Equal(t, http.StatusNotAcceptable, recorder.Result().StatusCode) + require.Contains(t, strings.Join(recorder.Header().Values("Vary"), ","), "Accept") +} + +func TestServeIndexHTMLMethodNotAllowed(t *testing.T) { + t.Parallel() + + handler := serveIndexHTML(false, map[string]any{}, "/riverui", newIndexFileSystem()) + + recorder := performRequest(handler, http.MethodPost, []string{"text/html"}) + require.Equal(t, http.StatusMethodNotAllowed, recorder.Result().StatusCode) + require.Contains(t, recorder.Header().Get("Allow"), "GET") + require.Contains(t, recorder.Header().Get("Allow"), "HEAD") +} + +func TestServeIndexHTMLHead(t *testing.T) { + t.Parallel() + + handler := serveIndexHTML(false, map[string]any{}, "/riverui", newIndexFileSystem()) + + recorder := performRequest(handler, http.MethodHead, []string{"text/html"}) + require.Equal(t, http.StatusOK, recorder.Result().StatusCode) + require.Empty(t, recorder.Body.String()) +} + +func TestServeIndexHTMLTemplateCaching(t *testing.T) { + t.Parallel() + + t.Run("CachesWhenNotDevMode", func(t *testing.T) { + t.Parallel() + + counting := &countingFS{fs: newIndexFileSystem()} + handler := serveIndexHTML(false, map[string]any{}, "/riverui", counting) + + require.Equal(t, http.StatusOK, performRequest(handler, http.MethodGet, []string{"text/html"}).Result().StatusCode) + require.Equal(t, http.StatusOK, performRequest(handler, http.MethodGet, []string{"text/html"}).Result().StatusCode) + + require.Equal(t, int32(1), counting.opens.Load()) + }) + + t.Run("NoCacheInDevMode", func(t *testing.T) { + t.Parallel() + + counting := &countingFS{fs: newIndexFileSystem()} + handler := serveIndexHTML(true, map[string]any{}, "/riverui", counting) + + require.Equal(t, http.StatusOK, performRequest(handler, http.MethodGet, []string{"text/html"}).Result().StatusCode) + require.Equal(t, http.StatusOK, performRequest(handler, http.MethodGet, []string{"text/html"}).Result().StatusCode) + + require.Equal(t, int32(2), counting.opens.Load()) + }) +} + +func performRequest(handler http.Handler, method string, acceptHeaders []string) *httptest.ResponseRecorder { + req := httptest.NewRequest(method, "/", nil) + for _, header := range acceptHeaders { + req.Header.Add("Accept", header) + } + + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, req) + + return recorder +} + +func newIndexFileSystem() http.FileSystem { + files := fstest.MapFS{ + "index.html": &fstest.MapFile{Data: []byte("ok")}, + } + + return http.FS(files) +} + +type countingFS struct { + fs http.FileSystem + opens atomic.Int32 +} + +func (counting *countingFS) Open(name string) (http.File, error) { + counting.opens.Add(1) + return counting.fs.Open(name) +}