diff --git a/cachew.hcl b/cachew.hcl index 3534e99..57595f3 100644 --- a/cachew.hcl +++ b/cachew.hcl @@ -26,3 +26,7 @@ disk { limit-mb = 250000 max-ttl = "8h" } + +gomod { + proxy = "https://proxy.golang.org" +} \ No newline at end of file diff --git a/go.mod b/go.mod index b85ba05..ca8d1a0 100644 --- a/go.mod +++ b/go.mod @@ -5,12 +5,14 @@ go 1.25.5 require ( github.com/alecthomas/hcl/v2 v2.3.1 github.com/alecthomas/kong v1.13.0 + github.com/goproxy/goproxy v0.25.0 github.com/lmittmann/tint v1.1.2 github.com/minio/minio-go/v7 v7.0.97 go.etcd.io/bbolt v1.4.3 ) require ( + github.com/aofei/backoff v1.1.0 // indirect github.com/dlclark/regexp2 v1.11.5 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/go-ini/ini v1.67.0 // indirect @@ -28,6 +30,7 @@ require ( github.com/stretchr/testify v1.11.1 // indirect github.com/tinylib/msgp v1.3.0 // indirect golang.org/x/crypto v0.44.0 // indirect + golang.org/x/mod v0.31.0 // indirect golang.org/x/net v0.47.0 // indirect golang.org/x/sys v0.38.0 // indirect golang.org/x/text v0.32.0 // indirect diff --git a/go.sum b/go.sum index f22a084..9ef6ac0 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/alecthomas/participle/v2 v2.1.4 h1:W/H79S8Sat/krZ3el6sQMvMaahJ+XcM9WS github.com/alecthomas/participle/v2 v2.1.4/go.mod h1:8tqVbpTX20Ru4NfYQgZf4mP18eXPTBViyMWiArNEgGI= github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs= github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= +github.com/aofei/backoff v1.1.0 h1:7ey7Ydpx/eFIyyrBNKPbgvTzvIuUOHcwkR3gPjjY9ag= +github.com/aofei/backoff v1.1.0/go.mod h1:IHCkMdd5vGP6dcDHD+uLn6lVuBw7+rKYaS7e7QIQwYA= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -23,6 +25,8 @@ github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/goproxy/goproxy v0.25.0 h1:TujZjUbKCwpFYrm+j04HACs1EAcBbFSGLwLMn8ynTys= +github.com/goproxy/goproxy v0.25.0/go.mod h1:6RIssMPDpQ0IHZus17gPUyBtU62RoqblQDYWx2sz/qs= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= @@ -65,6 +69,8 @@ go.etcd.io/bbolt v1.4.3 h1:dEadXpI6G79deX5prL3QRNP6JB8UxVkqo4UPnHaNXJo= go.etcd.io/bbolt v1.4.3/go.mod h1:tKQlpPaYCVFctUIgFKFnAlvbmB3tpy1vkTnDWohtc0E= golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU= golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc= +golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= +golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= @@ -73,6 +79,8 @@ golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= +golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/internal/strategy/gomod.go b/internal/strategy/gomod.go index 7b066e9..2ff4653 100644 --- a/internal/strategy/gomod.go +++ b/internal/strategy/gomod.go @@ -6,55 +6,32 @@ import ( "log/slog" "net/http" "net/url" - "strings" - "time" + + "github.com/goproxy/goproxy" "github.com/block/cachew/internal/cache" - "github.com/block/cachew/internal/httputil" "github.com/block/cachew/internal/jobscheduler" "github.com/block/cachew/internal/logging" - "github.com/block/cachew/internal/strategy/handler" ) func init() { Register("gomod", "Caches Go module proxy requests.", NewGoMod) } -// GoModConfig represents the configuration for the Go module proxy strategy. -// -// In HCL it looks like: -// -// gomod { -// proxy = "https://proxy.golang.org" -// } type GoModConfig struct { - Proxy string `hcl:"proxy,optional" help:"Upstream Go module proxy URL (defaults to proxy.golang.org)" default:"https://proxy.golang.org"` - MutableTTL time.Duration `hcl:"mutable-ttl,optional" help:"TTL for mutable Go module proxy endpoints (list, latest). Defaults to 5m." default:"5m"` - ImmutableTTL time.Duration `hcl:"immutable-ttl,optional" help:"TTL for immutable Go module proxy endpoints (versioned info, mod, zip). Defaults to 168h (7 days)." default:"168h"` + Proxy string `hcl:"proxy,optional" help:"Upstream Go module proxy URL (defaults to proxy.golang.org)" default:"https://proxy.golang.org"` } -// The GoMod strategy implements a caching proxy for the Go module proxy protocol. -// -// It supports all standard GOPROXY endpoints: -// - /$module/@v/list - Lists available versions -// - /$module/@v/$version.info - Version metadata JSON -// - /$module/@v/$version.mod - go.mod file -// - /$module/@v/$version.zip - Module source code -// - /$module/@latest - Latest version info -// -// The strategy uses differential caching: short TTL (5 minutes) for mutable -// endpoints (list, latest) and long TTL (7 days) for immutable versioned content. type GoMod struct { - config GoModConfig - cache cache.Cache - client *http.Client - logger *slog.Logger - proxy *url.URL + config GoModConfig + cache cache.Cache + logger *slog.Logger + proxy *url.URL + goproxy *goproxy.Goproxy } var _ Strategy = (*GoMod)(nil) -// NewGoMod creates a new Go module proxy strategy. func NewGoMod(ctx context.Context, config GoModConfig, _ jobscheduler.Scheduler, cache cache.Cache, mux Mux) (*GoMod, error) { parsedURL, err := url.Parse(config.Proxy) if err != nil { @@ -64,105 +41,34 @@ func NewGoMod(ctx context.Context, config GoModConfig, _ jobscheduler.Scheduler, g := &GoMod{ config: config, cache: cache, - client: http.DefaultClient, logger: logging.FromContext(ctx), proxy: parsedURL, } + g.goproxy = &goproxy.Goproxy{ + Logger: g.logger, + Fetcher: &goproxy.GoFetcher{ + Env: []string{ + "GOPROXY=" + config.Proxy, + "GOSUMDB=off", // Disable checksum database validation in fetcher, to prevent unneccessary double validation + }, + }, + Cacher: &goproxyCacher{ + cache: cache, + }, + ProxiedSumDBs: []string{ + "sum.golang.org https://sum.golang.org", + }, + } + g.logger.InfoContext(ctx, "Initialized Go module proxy strategy", slog.String("proxy", g.proxy.String())) - // Create handler with caching configuration - h := handler.New(g.client, g.cache). - CacheKey(func(r *http.Request) string { - return g.buildUpstreamURL(r).String() - }). - Transform(g.transformRequest). - TTL(g.calculateTTL) - - // Register a namespaced handler for Go module proxy patterns - mux.Handle("GET /gomod/{path...}", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - // Check if this is a valid Go module proxy endpoint - if g.isGoModulePath(path) { - h.ServeHTTP(w, r) - return - } - http.NotFound(w, r) - })) + mux.Handle("GET /gomod/{path...}", http.StripPrefix("/gomod", g.goproxy)) return g, nil } -// isGoModulePath checks if the path matches a valid Go module proxy endpoint pattern. -func (g *GoMod) isGoModulePath(path string) bool { - // Strip the /gomod prefix before checking the pattern - path = strings.TrimPrefix(path, "/gomod") - - // Valid patterns: - // - /@v/list - // - /@v/{version}.info - // - /@v/{version}.mod - // - /@v/{version}.zip - // - /@latest - return strings.HasSuffix(path, "/@v/list") || - strings.HasSuffix(path, "/@latest") || - (strings.Contains(path, "/@v/") && - (strings.HasSuffix(path, ".info") || - strings.HasSuffix(path, ".mod") || - strings.HasSuffix(path, ".zip"))) -} - func (g *GoMod) String() string { return "gomod:" + g.proxy.Host } - -// buildUpstreamURL constructs the full upstream URL from the incoming request. -func (g *GoMod) buildUpstreamURL(r *http.Request) *url.URL { - // The full path includes the module path and the endpoint - // e.g., /gomod/github.com/user/repo/@v/v1.0.0.info - // We need to strip the /gomod prefix before forwarding to the upstream proxy - path := r.URL.Path - path = strings.TrimPrefix(path, "/gomod") - if !strings.HasPrefix(path, "/") { - path = "/" + path - } - - targetURL := *g.proxy - targetURL.Path = g.proxy.Path + path - targetURL.RawQuery = r.URL.RawQuery - - return &targetURL -} - -// transformRequest creates the upstream request to the Go module proxy. -func (g *GoMod) transformRequest(r *http.Request) (*http.Request, error) { - targetURL := g.buildUpstreamURL(r) - - g.logger.DebugContext(r.Context(), "Transforming Go module request", - slog.String("original_path", r.URL.Path), - slog.String("upstream_url", targetURL.String())) - - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, targetURL.String(), nil) - if err != nil { - return nil, httputil.Errorf(http.StatusInternalServerError, "create upstream request: %w", err) - } - - return req, nil -} - -// calculateTTL returns the appropriate cache TTL based on the endpoint type. -// -// Mutable endpoints (list, latest) get short TTL (5 minutes). -// Immutable versioned content (info, mod, zip) gets long TTL (7 days). -func (g *GoMod) calculateTTL(r *http.Request) time.Duration { - path := r.URL.Path - - // Short TTL for mutable endpoints - if strings.HasSuffix(path, "/@v/list") || strings.HasSuffix(path, "/@latest") { - return g.config.MutableTTL - } - - // Long TTL for immutable versioned content (.info, .mod, .zip) - return g.config.ImmutableTTL -} diff --git a/internal/strategy/gomod_cacher.go b/internal/strategy/gomod_cacher.go new file mode 100644 index 0000000..c66a38a --- /dev/null +++ b/internal/strategy/gomod_cacher.go @@ -0,0 +1,54 @@ +package strategy + +import ( + "context" + "fmt" + "io" + "io/fs" + "strings" + + "github.com/block/cachew/internal/cache" +) + +type goproxyCacher struct { + cache cache.Cache +} + +func (g *goproxyCacher) Get(ctx context.Context, name string) (io.ReadCloser, error) { + key := cache.NewKey(name) + + rc, _, err := g.cache.Open(ctx, key) + if err != nil { + return nil, fs.ErrNotExist + } + + return rc, nil +} + +func (g *goproxyCacher) Put(ctx context.Context, name string, content io.ReadSeeker) error { + if strings.HasSuffix(name, "/@v/list") || strings.HasSuffix(name, "/@latest") { + return nil + } + + key := cache.NewKey(name) + + wc, err := g.cache.Create(ctx, key, nil, 0) + if err != nil { + return fmt.Errorf("create cache entry: %w", err) + } + defer wc.Close() + + if _, err := content.Seek(0, io.SeekStart); err != nil { + return fmt.Errorf("seek to start: %w", err) + } + + if _, err := io.Copy(wc, content); err != nil { + return fmt.Errorf("write to cache: %w", err) + } + + if err := wc.Close(); err != nil { + return fmt.Errorf("close cache entry: %w", err) + } + + return nil +} diff --git a/internal/strategy/gomod_test.go b/internal/strategy/gomod_test.go index b60ac85..c30a18f 100644 --- a/internal/strategy/gomod_test.go +++ b/internal/strategy/gomod_test.go @@ -1,11 +1,14 @@ package strategy_test import ( + "archive/zip" + "bytes" "context" "log/slog" "net/http" "net/http/httptest" "strings" + "sync" "testing" "time" @@ -20,8 +23,10 @@ import ( type mockGoModServer struct { server *httptest.Server requestCount map[string]int // Track requests by path + mu sync.Mutex // Protects requestCount lastPath string responses map[string]mockResponse + t *testing.T } type mockResponse struct { @@ -29,28 +34,17 @@ type mockResponse struct { content string } -func newMockGoModServer() *mockGoModServer { +func newMockGoModServer(t *testing.T) *mockGoModServer { m := &mockGoModServer{ requestCount: make(map[string]int), responses: make(map[string]mockResponse), + t: t, } // Set up default responses for common endpoints m.responses["/@v/list"] = mockResponse{ status: http.StatusOK, - content: "v1.0.0\nv1.0.1\nv1.1.0\n", - } - m.responses["/@v/v1.0.0.info"] = mockResponse{ - status: http.StatusOK, - content: `{"Version":"v1.0.0","Time":"2023-01-01T00:00:00Z"}`, - } - m.responses["/@v/v1.0.0.mod"] = mockResponse{ - status: http.StatusOK, - content: "module github.com/example/test\n\ngo 1.21\n", - } - m.responses["/@v/v1.0.0.zip"] = mockResponse{ - status: http.StatusOK, - content: "PK\x03\x04...", // Mock zip content + content: "v1.0.0\nv1.0.1\nv1.1.0", } m.responses["/@latest"] = mockResponse{ status: http.StatusOK, @@ -64,21 +58,44 @@ func newMockGoModServer() *mockGoModServer { return m } +func createModuleZip(t *testing.T, modulePath, version string) string { + t.Helper() + var buf bytes.Buffer + w := zip.NewWriter(&buf) + + prefix := modulePath + "@" + version + "/" + + f, err := w.Create(prefix + "go.mod") + assert.NoError(t, err) + _, err = f.Write([]byte("module " + modulePath + "\n\ngo 1.21\n")) + assert.NoError(t, err) + + f2, err := w.Create(prefix + "main.go") + assert.NoError(t, err) + _, err = f2.Write([]byte("package main\n\nfunc main() {}\n")) + assert.NoError(t, err) + + err = w.Close() + assert.NoError(t, err) + + return buf.String() +} + func (m *mockGoModServer) handleRequest(w http.ResponseWriter, r *http.Request) { path := r.URL.Path + + m.mu.Lock() m.lastPath = path m.requestCount[path]++ + m.mu.Unlock() - // Find matching response var resp mockResponse found := false - // Try exact match first if r, ok := m.responses[path]; ok { resp = r found = true } else { - // Try suffix match for module paths for suffix, r := range m.responses { if len(path) >= len(suffix) && path[len(path)-len(suffix):] == suffix { resp = r @@ -88,27 +105,34 @@ func (m *mockGoModServer) handleRequest(w http.ResponseWriter, r *http.Request) } } - // If still not found, try pattern matching for any version if !found && strings.Contains(path, "/@v/") { - switch { - case strings.HasSuffix(path, ".info"): - resp = mockResponse{ - status: http.StatusOK, - content: `{"Version":"v1.0.0","Time":"2023-01-01T00:00:00Z"}`, - } - found = true - case strings.HasSuffix(path, ".mod"): - resp = mockResponse{ - status: http.StatusOK, - content: "module github.com/example/test\n\ngo 1.21\n", - } - found = true - case strings.HasSuffix(path, ".zip"): - resp = mockResponse{ - status: http.StatusOK, - content: "PK\x03\x04...", + parts := strings.Split(path, "/@v/") + if len(parts) == 2 { + modulePath := strings.TrimPrefix(parts[0], "/") + versionPart := parts[1] + + switch { + case strings.HasSuffix(path, ".info"): + version := strings.TrimSuffix(versionPart, ".info") + resp = mockResponse{ + status: http.StatusOK, + content: `{"Version":"` + version + `","Time":"2023-01-01T00:00:00Z"}`, + } + found = true + case strings.HasSuffix(path, ".mod"): + resp = mockResponse{ + status: http.StatusOK, + content: "module " + modulePath + "\n\ngo 1.21\n", + } + found = true + case strings.HasSuffix(path, ".zip"): + version := strings.TrimSuffix(versionPart, ".zip") + resp = mockResponse{ + status: http.StatusOK, + content: createModuleZip(m.t, modulePath, version), + } + found = true } - found = true } } @@ -133,10 +157,16 @@ func (m *mockGoModServer) setResponse(path string, status int, content string) { } } +func (m *mockGoModServer) getRequestCount(path string) int { + m.mu.Lock() + defer m.mu.Unlock() + return m.requestCount[path] +} + func setupGoModTest(t *testing.T) (*mockGoModServer, *http.ServeMux, context.Context) { t.Helper() - mock := newMockGoModServer() + mock := newMockGoModServer(t) t.Cleanup(mock.close) _, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError}) @@ -147,9 +177,7 @@ func setupGoModTest(t *testing.T) (*mockGoModServer, *http.ServeMux, context.Con mux := http.NewServeMux() _, err = strategy.NewGoMod(ctx, strategy.GoModConfig{ - Proxy: mock.server.URL, - MutableTTL: 5 * time.Minute, - ImmutableTTL: 168 * time.Hour, + Proxy: mock.server.URL, }, jobscheduler.New(ctx, jobscheduler.Config{}), memCache, mux) assert.NoError(t, err) @@ -166,8 +194,11 @@ func TestGoModList(t *testing.T) { mux.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, "v1.0.0\nv1.0.1\nv1.1.0\n", w.Body.String()) - assert.Equal(t, 1, mock.requestCount["/github.com/example/test/@v/list"]) + body := strings.TrimSpace(w.Body.String()) + assert.True(t, strings.Contains(body, "v1.0.0"), "response should contain v1.0.0") + assert.True(t, strings.Contains(body, "v1.0.1"), "response should contain v1.0.1") + assert.True(t, strings.Contains(body, "v1.1.0"), "response should contain v1.1.0") + assert.Equal(t, 1, mock.getRequestCount("/github.com/example/test/@v/list")) } func TestGoModInfo(t *testing.T) { @@ -181,7 +212,7 @@ func TestGoModInfo(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, `{"Version":"v1.0.0","Time":"2023-01-01T00:00:00Z"}`, w.Body.String()) - assert.Equal(t, 1, mock.requestCount["/github.com/example/test/@v/v1.0.0.info"]) + assert.Equal(t, 1, mock.getRequestCount("/github.com/example/test/@v/v1.0.0.info")) } func TestGoModMod(t *testing.T) { @@ -195,7 +226,7 @@ func TestGoModMod(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "module github.com/example/test\n\ngo 1.21\n", w.Body.String()) - assert.Equal(t, 1, mock.requestCount["/github.com/example/test/@v/v1.0.0.mod"]) + assert.Equal(t, 1, mock.getRequestCount("/github.com/example/test/@v/v1.0.0.mod")) } func TestGoModZip(t *testing.T) { @@ -208,8 +239,8 @@ func TestGoModZip(t *testing.T) { mux.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, "PK\x03\x04...", w.Body.String()) - assert.Equal(t, 1, mock.requestCount["/github.com/example/test/@v/v1.0.0.zip"]) + assert.True(t, strings.HasPrefix(w.Body.String(), "PK"), "response should be a valid zip file") + assert.True(t, mock.getRequestCount("/github.com/example/test/@v/v1.0.0.zip") >= 1, "should have fetched zip") } func TestGoModLatest(t *testing.T) { @@ -223,7 +254,7 @@ func TestGoModLatest(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, `{"Version":"v1.1.0","Time":"2023-06-01T00:00:00Z"}`, w.Body.String()) - assert.Equal(t, 1, mock.requestCount["/github.com/example/test/@latest"]) + assert.Equal(t, 1, mock.getRequestCount("/github.com/example/test/@latest")) } func TestGoModCaching(t *testing.T) { @@ -232,16 +263,14 @@ func TestGoModCaching(t *testing.T) { path := "/gomod/github.com/example/test/@v/v1.0.0.info" upstreamPath := "/github.com/example/test/@v/v1.0.0.info" - // First request req1 := httptest.NewRequest(http.MethodGet, path, nil) req1 = req1.WithContext(ctx) w1 := httptest.NewRecorder() mux.ServeHTTP(w1, req1) assert.Equal(t, http.StatusOK, w1.Code) - assert.Equal(t, 1, mock.requestCount[upstreamPath]) + assert.Equal(t, 1, mock.getRequestCount(upstreamPath)) - // Second request should hit cache req2 := httptest.NewRequest(http.MethodGet, path, nil) req2 = req2.WithContext(ctx) w2 := httptest.NewRecorder() @@ -249,13 +278,12 @@ func TestGoModCaching(t *testing.T) { assert.Equal(t, http.StatusOK, w2.Code) assert.Equal(t, w1.Body.String(), w2.Body.String()) - assert.Equal(t, 1, mock.requestCount[upstreamPath], "second request should be served from cache") + assert.Equal(t, 1, mock.getRequestCount(upstreamPath), "second request should be served from cache") } func TestGoModComplexModulePath(t *testing.T) { mock, mux, ctx := setupGoModTest(t) - // Test module path with multiple slashes req := httptest.NewRequest(http.MethodGet, "/gomod/golang.org/x/tools/@v/v0.1.0.info", nil) req = req.WithContext(ctx) w := httptest.NewRecorder() @@ -263,34 +291,31 @@ func TestGoModComplexModulePath(t *testing.T) { mux.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, 1, mock.requestCount["/golang.org/x/tools/@v/v0.1.0.info"]) + assert.Equal(t, 1, mock.getRequestCount("/golang.org/x/tools/@v/v0.1.0.info")) } func TestGoModNonOKResponse(t *testing.T) { mock, mux, ctx := setupGoModTest(t) - // Set up 404 response upstreamPath := "/github.com/example/nonexistent/@v/v99.0.0.info" notFoundPath := "/gomod" + upstreamPath mock.setResponse(upstreamPath, http.StatusNotFound, "not found") - // First request should return 404 req1 := httptest.NewRequest(http.MethodGet, notFoundPath, nil) req1 = req1.WithContext(ctx) w1 := httptest.NewRecorder() mux.ServeHTTP(w1, req1) assert.Equal(t, http.StatusNotFound, w1.Code) - assert.Equal(t, 1, mock.requestCount[upstreamPath]) + assert.Equal(t, 1, mock.getRequestCount(upstreamPath)) - // Second request should also hit upstream (404s are not cached) req2 := httptest.NewRequest(http.MethodGet, notFoundPath, nil) req2 = req2.WithContext(ctx) w2 := httptest.NewRecorder() mux.ServeHTTP(w2, req2) assert.Equal(t, http.StatusNotFound, w2.Code) - assert.Equal(t, 2, mock.requestCount[upstreamPath], "404 responses should not be cached") + assert.Equal(t, 2, mock.getRequestCount(upstreamPath), "404 responses should not be cached") } func TestGoModMultipleConcurrentRequests(t *testing.T) { @@ -299,7 +324,6 @@ func TestGoModMultipleConcurrentRequests(t *testing.T) { path := "/gomod/github.com/example/test/@v/v1.0.0.zip" upstreamPath := "/github.com/example/test/@v/v1.0.0.zip" - // Make multiple concurrent requests results := make(chan *httptest.ResponseRecorder, 3) for range 3 { go func() { @@ -311,14 +335,58 @@ func TestGoModMultipleConcurrentRequests(t *testing.T) { }() } - // Collect results for range 3 { w := <-results assert.Equal(t, http.StatusOK, w.Code) } - // First request should have created the cache entry - // Subsequent requests might hit cache or might be in-flight - // We just verify all requests succeeded - assert.True(t, mock.requestCount[upstreamPath] >= 1, "at least one request should have been made to upstream") + assert.True(t, mock.getRequestCount(upstreamPath) >= 1, "at least one request should have been made to upstream") +} + +func TestGoModListNotCached(t *testing.T) { + mock, mux, ctx := setupGoModTest(t) + + path := "/gomod/github.com/example/test/@v/list" + upstreamPath := "/github.com/example/test/@v/list" + + req1 := httptest.NewRequest(http.MethodGet, path, nil) + req1 = req1.WithContext(ctx) + w1 := httptest.NewRecorder() + mux.ServeHTTP(w1, req1) + + assert.Equal(t, http.StatusOK, w1.Code) + assert.Equal(t, 1, mock.getRequestCount(upstreamPath)) + + req2 := httptest.NewRequest(http.MethodGet, path, nil) + req2 = req2.WithContext(ctx) + w2 := httptest.NewRecorder() + mux.ServeHTTP(w2, req2) + + assert.Equal(t, http.StatusOK, w2.Code) + assert.Equal(t, w1.Body.String(), w2.Body.String()) + assert.Equal(t, 2, mock.getRequestCount(upstreamPath), "/@v/list endpoint should not be cached") +} + +func TestGoModLatestNotCached(t *testing.T) { + mock, mux, ctx := setupGoModTest(t) + + path := "/gomod/github.com/example/test/@latest" + upstreamPath := "/github.com/example/test/@latest" + + req1 := httptest.NewRequest(http.MethodGet, path, nil) + req1 = req1.WithContext(ctx) + w1 := httptest.NewRecorder() + mux.ServeHTTP(w1, req1) + + assert.Equal(t, http.StatusOK, w1.Code) + assert.Equal(t, 1, mock.getRequestCount(upstreamPath)) + + req2 := httptest.NewRequest(http.MethodGet, path, nil) + req2 = req2.WithContext(ctx) + w2 := httptest.NewRecorder() + mux.ServeHTTP(w2, req2) + + assert.Equal(t, http.StatusOK, w2.Code) + assert.Equal(t, w1.Body.String(), w2.Body.String()) + assert.Equal(t, 2, mock.getRequestCount(upstreamPath), "/@latest endpoint should not be cached") }