From 863f7c100d0bc840a9ea699be3424cccf6c04de7 Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Fri, 23 Jan 2026 18:02:30 +1100 Subject: [PATCH] refactor: switch cache to use http.Header I previously used textproto.MIMEHeader to keep the lower level cache interface separated from the HTTP protocol, but it does introduce friction and I think this is simpler. --- internal/cache/api.go | 12 ++++++------ internal/cache/cachetest/suite.go | 7 +++---- internal/cache/disk.go | 11 +++++------ internal/cache/disk_metadb.go | 8 ++++---- internal/cache/http.go | 5 ++--- internal/cache/memory.go | 13 ++++++------- internal/cache/remote.go | 11 +++++------ internal/cache/s3.go | 15 +++++++-------- internal/cache/tiered.go | 8 ++++---- internal/strategy/apiv1.go | 3 +-- internal/strategy/git/bundle.go | 4 ++-- internal/strategy/handler/handler.go | 3 +-- 12 files changed, 46 insertions(+), 54 deletions(-) diff --git a/internal/cache/api.go b/internal/cache/api.go index b7dd930..2dbdaaa 100644 --- a/internal/cache/api.go +++ b/internal/cache/api.go @@ -6,7 +6,7 @@ import ( "crypto/sha256" "encoding/hex" "io" - "net/textproto" + "net/http" "time" "github.com/alecthomas/errors" @@ -97,8 +97,8 @@ func (k *Key) MarshalText() ([]byte, error) { // FilterTransportHeaders returns a copy of the given headers with standard HTTP transport headers removed. // These headers are typically added by HTTP clients/servers and should not be cached. -func FilterTransportHeaders(headers textproto.MIMEHeader) textproto.MIMEHeader { - filtered := make(textproto.MIMEHeader) +func FilterTransportHeaders(headers http.Header) http.Header { + filtered := make(http.Header) for key, values := range headers { // Skip standard HTTP headers added by transport layer or that shouldn't be cached if key == "Content-Length" || key == "Date" || key == "Accept-Encoding" || @@ -120,13 +120,13 @@ type Cache interface { // // Expired files MUST not be returned. // Must return os.ErrNotExist if the file does not exist. - Stat(ctx context.Context, key Key) (textproto.MIMEHeader, error) + Stat(ctx context.Context, key Key) (http.Header, error) // Open an existing file in the cache. // // Expired files MUST NOT be returned. // The returned headers MUST include a Last-Modified header. // Must return os.ErrNotExist if the file does not exist. - Open(ctx context.Context, key Key) (io.ReadCloser, textproto.MIMEHeader, error) + Open(ctx context.Context, key Key) (io.ReadCloser, http.Header, error) // Create a new file in the cache. // // If "ttl" is zero, a maximum TTL MUST be used by the implementation. @@ -134,7 +134,7 @@ type Cache interface { // The file MUST NOT be available for read until completely written and closed. // // If the context is cancelled the object MUST NOT be made available in the cache. - Create(ctx context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) + Create(ctx context.Context, key Key, headers http.Header, ttl time.Duration) (io.WriteCloser, error) // Delete a file from the cache. // // MUST be atomic. diff --git a/internal/cache/cachetest/suite.go b/internal/cache/cachetest/suite.go index ed4e2c9..db43d3c 100644 --- a/internal/cache/cachetest/suite.go +++ b/internal/cache/cachetest/suite.go @@ -4,7 +4,6 @@ import ( "context" "io" "net/http" - "net/textproto" "os" "testing" "time" @@ -215,7 +214,7 @@ func testHeaders(t *testing.T, c cache.Cache) { key := cache.NewKey("test-key-with-headers") // Create headers to store - headers := textproto.MIMEHeader{ + headers := http.Header{ "Content-Type": []string{"application/json"}, "Cache-Control": []string{"max-age=3600"}, "X-Custom-Field": []string{"custom-value"}, @@ -258,7 +257,7 @@ func testContextCancellation(t *testing.T, c cache.Cache) { // Create an object with the cancellable context key := cache.NewKey("test-cancelled") - writer, err := c.Create(cancelledCtx, key, textproto.MIMEHeader{}, time.Hour) + writer, err := c.Create(cancelledCtx, key, http.Header{}, time.Hour) assert.NoError(t, err) // Write some data @@ -310,7 +309,7 @@ func testLastModified(t *testing.T, c cache.Cache) { // Test with explicit Last-Modified header key2 := cache.NewKey("test-last-modified-explicit") explicitTime := time.Date(2023, 1, 15, 12, 30, 0, 0, time.UTC) - explicitHeaders := textproto.MIMEHeader{ + explicitHeaders := http.Header{ "Last-Modified": []string{explicitTime.Format(http.TimeFormat)}, } diff --git a/internal/cache/disk.go b/internal/cache/disk.go index 4e1b7d2..63b5b08 100644 --- a/internal/cache/disk.go +++ b/internal/cache/disk.go @@ -7,7 +7,6 @@ import ( "log/slog" "maps" "net/http" - "net/textproto" "os" "path/filepath" "sort" @@ -130,14 +129,14 @@ func (d *Disk) Size() int64 { return d.size.Load() } -func (d *Disk) Create(ctx context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) { +func (d *Disk) Create(ctx context.Context, key Key, headers http.Header, ttl time.Duration) (io.WriteCloser, error) { if ttl > d.config.MaxTTL || ttl == 0 { ttl = d.config.MaxTTL } now := time.Now() // Clone headers to avoid concurrent map writes - clonedHeaders := make(textproto.MIMEHeader) + clonedHeaders := make(http.Header) maps.Copy(clonedHeaders, headers) if clonedHeaders.Get("Last-Modified") == "" { clonedHeaders.Set("Last-Modified", now.UTC().Format(http.TimeFormat)) @@ -204,7 +203,7 @@ func (d *Disk) Delete(_ context.Context, key Key) error { return nil } -func (d *Disk) Stat(ctx context.Context, key Key) (textproto.MIMEHeader, error) { +func (d *Disk) Stat(ctx context.Context, key Key) (http.Header, error) { path := d.keyToPath(key) fullPath := filepath.Join(d.config.Root, path) @@ -229,7 +228,7 @@ func (d *Disk) Stat(ctx context.Context, key Key) (textproto.MIMEHeader, error) return headers, nil } -func (d *Disk) Open(ctx context.Context, key Key) (io.ReadCloser, textproto.MIMEHeader, error) { +func (d *Disk) Open(ctx context.Context, key Key) (io.ReadCloser, http.Header, error) { path := d.keyToPath(key) fullPath := filepath.Join(d.config.Root, path) @@ -378,7 +377,7 @@ type diskWriter struct { path string tempPath string expiresAt time.Time - headers textproto.MIMEHeader + headers http.Header size int64 ctx context.Context } diff --git a/internal/cache/disk_metadb.go b/internal/cache/disk_metadb.go index 673658c..82017df 100644 --- a/internal/cache/disk_metadb.go +++ b/internal/cache/disk_metadb.go @@ -2,7 +2,7 @@ package cache import ( "encoding/json" - "net/textproto" + "net/http" "time" "github.com/alecthomas/errors" @@ -55,7 +55,7 @@ func (s *diskMetaDB) setTTL(key Key, expiresAt time.Time) error { })) } -func (s *diskMetaDB) set(key Key, expiresAt time.Time, headers textproto.MIMEHeader) error { +func (s *diskMetaDB) set(key Key, expiresAt time.Time, headers http.Header) error { ttlBytes, err := expiresAt.MarshalBinary() if err != nil { return errors.Errorf("failed to marshal TTL: %w", err) @@ -90,8 +90,8 @@ func (s *diskMetaDB) getTTL(key Key) (time.Time, error) { return expiresAt, errors.WithStack(err) } -func (s *diskMetaDB) getHeaders(key Key) (textproto.MIMEHeader, error) { - var headers textproto.MIMEHeader +func (s *diskMetaDB) getHeaders(key Key) (http.Header, error) { + var headers http.Header err := s.db.View(func(tx *bbolt.Tx) error { bucket := tx.Bucket(headersBucketName) headersBytes := bucket.Get(key[:]) diff --git a/internal/cache/http.go b/internal/cache/http.go index ca2c7e1..ec2ba01 100644 --- a/internal/cache/http.go +++ b/internal/cache/http.go @@ -4,7 +4,6 @@ import ( "io" "maps" "net/http" - "net/textproto" "os" "github.com/alecthomas/errors" @@ -27,7 +26,7 @@ func Fetch(client *http.Client, r *http.Request, c Cache) (*http.Response, error Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, - Header: http.Header(headers), + Header: headers, Body: cr, ContentLength: -1, Request: r, @@ -53,7 +52,7 @@ func FetchDirect(client *http.Client, r *http.Request, c Cache, key Key) (*http. return resp, nil } - responseHeaders := textproto.MIMEHeader(maps.Clone(resp.Header)) + responseHeaders := maps.Clone(resp.Header) cw, err := c.Create(r.Context(), key, responseHeaders, 0) if err != nil { _ = resp.Body.Close() diff --git a/internal/cache/memory.go b/internal/cache/memory.go index 4e0ad89..91f8c41 100644 --- a/internal/cache/memory.go +++ b/internal/cache/memory.go @@ -7,7 +7,6 @@ import ( "io" "maps" "net/http" - "net/textproto" "os" "sync" "time" @@ -33,7 +32,7 @@ type MemoryConfig struct { type memoryEntry struct { data []byte expiresAt time.Time - headers textproto.MIMEHeader + headers http.Header } type Memory struct { @@ -53,7 +52,7 @@ func NewMemory(ctx context.Context, config MemoryConfig) (*Memory, error) { func (m *Memory) String() string { return fmt.Sprintf("memory:%dMB", m.config.LimitMB) } -func (m *Memory) Stat(_ context.Context, key Key) (textproto.MIMEHeader, error) { +func (m *Memory) Stat(_ context.Context, key Key) (http.Header, error) { m.mu.RLock() defer m.mu.RUnlock() @@ -69,7 +68,7 @@ func (m *Memory) Stat(_ context.Context, key Key) (textproto.MIMEHeader, error) return entry.headers, nil } -func (m *Memory) Open(_ context.Context, key Key) (io.ReadCloser, textproto.MIMEHeader, error) { +func (m *Memory) Open(_ context.Context, key Key) (io.ReadCloser, http.Header, error) { m.mu.RLock() defer m.mu.RUnlock() @@ -85,14 +84,14 @@ func (m *Memory) Open(_ context.Context, key Key) (io.ReadCloser, textproto.MIME return io.NopCloser(bytes.NewReader(entry.data)), entry.headers, nil } -func (m *Memory) Create(ctx context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) { +func (m *Memory) Create(ctx context.Context, key Key, headers http.Header, ttl time.Duration) (io.WriteCloser, error) { if ttl == 0 { ttl = m.config.MaxTTL } now := time.Now() // Clone headers to avoid concurrent map writes - clonedHeaders := make(textproto.MIMEHeader) + clonedHeaders := make(http.Header) maps.Copy(clonedHeaders, headers) if clonedHeaders.Get("Last-Modified") == "" { clonedHeaders.Set("Last-Modified", now.UTC().Format(http.TimeFormat)) @@ -136,7 +135,7 @@ type memoryWriter struct { key Key buf *bytes.Buffer expiresAt time.Time - headers textproto.MIMEHeader + headers http.Header closed bool ctx context.Context } diff --git a/internal/cache/remote.go b/internal/cache/remote.go index b8becf7..fff4142 100644 --- a/internal/cache/remote.go +++ b/internal/cache/remote.go @@ -6,7 +6,6 @@ import ( "io" "maps" "net/http" - "net/textproto" "os" "time" @@ -32,7 +31,7 @@ func NewRemote(baseURL string) *Remote { func (c *Remote) String() string { return "remote:" + c.baseURL } // Open retrieves an object from the remote. -func (c *Remote) Open(ctx context.Context, key Key) (io.ReadCloser, textproto.MIMEHeader, error) { +func (c *Remote) Open(ctx context.Context, key Key) (io.ReadCloser, http.Header, error) { url := fmt.Sprintf("%s/%s", c.baseURL, key.String()) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -53,13 +52,13 @@ func (c *Remote) Open(ctx context.Context, key Key) (io.ReadCloser, textproto.MI } // Filter out HTTP transport headers - headers := FilterTransportHeaders(textproto.MIMEHeader(resp.Header)) + headers := FilterTransportHeaders(resp.Header) return resp.Body, headers, nil } // Stat retrieves headers for an object from the remote. -func (c *Remote) Stat(ctx context.Context, key Key) (textproto.MIMEHeader, error) { +func (c *Remote) Stat(ctx context.Context, key Key) (http.Header, error) { url := fmt.Sprintf("%s/%s", c.baseURL, key.String()) req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) if err != nil { @@ -81,13 +80,13 @@ func (c *Remote) Stat(ctx context.Context, key Key) (textproto.MIMEHeader, error } // Filter out HTTP transport headers - headers := FilterTransportHeaders(textproto.MIMEHeader(resp.Header)) + headers := FilterTransportHeaders(resp.Header) return headers, nil } // Create stores a new object in the remote. -func (c *Remote) Create(ctx context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) { +func (c *Remote) Create(ctx context.Context, key Key, headers http.Header, ttl time.Duration) (io.WriteCloser, error) { pr, pw := io.Pipe() url := fmt.Sprintf("%s/%s", c.baseURL, key.String()) diff --git a/internal/cache/s3.go b/internal/cache/s3.go index b66d2b1..23a3017 100644 --- a/internal/cache/s3.go +++ b/internal/cache/s3.go @@ -9,7 +9,6 @@ import ( "log/slog" "maps" "net/http" - "net/textproto" "os" "runtime" "time" @@ -162,7 +161,7 @@ func (s *S3) keyToPath(key Key) string { return hexKey[:2] + "/" + hexKey } -func (s *S3) Stat(ctx context.Context, key Key) (textproto.MIMEHeader, error) { +func (s *S3) Stat(ctx context.Context, key Key) (http.Header, error) { objectName := s.keyToPath(key) // Get object info to check metadata @@ -190,7 +189,7 @@ func (s *S3) Stat(ctx context.Context, key Key) (textproto.MIMEHeader, error) { // Retrieve headers from metadata // Note: UserMetadata keys are returned WITHOUT the "X-Amz-Meta-" prefix by minio-go - headers := make(textproto.MIMEHeader) + headers := make(http.Header) if headersJSON := objInfo.UserMetadata["Headers"]; headersJSON != "" { if err := json.Unmarshal([]byte(headersJSON), &headers); err != nil { return nil, errors.Errorf("failed to unmarshal headers: %w", err) @@ -205,7 +204,7 @@ func (s *S3) Stat(ctx context.Context, key Key) (textproto.MIMEHeader, error) { return headers, nil } -func (s *S3) Open(ctx context.Context, key Key) (io.ReadCloser, textproto.MIMEHeader, error) { +func (s *S3) Open(ctx context.Context, key Key) (io.ReadCloser, http.Header, error) { objectName := s.keyToPath(key) // Get object info to retrieve metadata and check expiration @@ -230,7 +229,7 @@ func (s *S3) Open(ctx context.Context, key Key) (io.ReadCloser, textproto.MIMEHe } // Retrieve headers from metadata - headers := make(textproto.MIMEHeader) + headers := make(http.Header) if headersJSON := objInfo.UserMetadata["Headers"]; headersJSON != "" { if err := json.Unmarshal([]byte(headersJSON), &headers); err != nil { return nil, nil, errors.Errorf("failed to unmarshal headers: %w", err) @@ -251,13 +250,13 @@ func (s *S3) Open(ctx context.Context, key Key) (io.ReadCloser, textproto.MIMEHe return obj, headers, nil } -func (s *S3) Create(ctx context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) { +func (s *S3) Create(ctx context.Context, key Key, headers http.Header, ttl time.Duration) (io.WriteCloser, error) { if ttl > s.config.MaxTTL || ttl == 0 { ttl = s.config.MaxTTL } // Clone headers to avoid concurrent access issues - clonedHeaders := make(textproto.MIMEHeader) + clonedHeaders := make(http.Header) maps.Copy(clonedHeaders, headers) expiresAt := time.Now().Add(ttl) @@ -296,7 +295,7 @@ type s3Writer struct { key Key pipe *io.PipeWriter expiresAt time.Time - headers textproto.MIMEHeader + headers http.Header ctx context.Context errCh chan error } diff --git a/internal/cache/tiered.go b/internal/cache/tiered.go index 8b3505f..5a5b5be 100644 --- a/internal/cache/tiered.go +++ b/internal/cache/tiered.go @@ -3,7 +3,7 @@ package cache import ( "context" "io" - "net/textproto" + "net/http" "os" "strings" "sync" @@ -50,7 +50,7 @@ func (t Tiered) Close() error { } // Create a new object. All underlying caches will be written to in sequence. -func (t Tiered) Create(ctx context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) { +func (t Tiered) Create(ctx context.Context, key Key, headers http.Header, ttl time.Duration) (io.WriteCloser, error) { // The first error will cancel all outstanding writes. ctx, cancel := context.WithCancelCause(ctx) @@ -91,7 +91,7 @@ func (t Tiered) Delete(ctx context.Context, key Key) error { // Stat returns headers from the first cache that succeeds. // // If all caches fail, all errors are returned. -func (t Tiered) Stat(ctx context.Context, key Key) (textproto.MIMEHeader, error) { +func (t Tiered) Stat(ctx context.Context, key Key) (http.Header, error) { errs := make([]error, len(t.caches)) for i, c := range t.caches { headers, err := c.Stat(ctx, key) @@ -109,7 +109,7 @@ func (t Tiered) Stat(ctx context.Context, key Key) (textproto.MIMEHeader, error) // Open returns a reader from the first cache that succeeds. // // If all caches fail, all errors are returned. -func (t Tiered) Open(ctx context.Context, key Key) (io.ReadCloser, textproto.MIMEHeader, error) { +func (t Tiered) Open(ctx context.Context, key Key) (io.ReadCloser, http.Header, error) { errs := make([]error, len(t.caches)) for i, c := range t.caches { r, headers, err := c.Open(ctx, key) diff --git a/internal/strategy/apiv1.go b/internal/strategy/apiv1.go index 215a08b..6dd14f0 100644 --- a/internal/strategy/apiv1.go +++ b/internal/strategy/apiv1.go @@ -7,7 +7,6 @@ import ( "log/slog" "maps" "net/http" - "net/textproto" "os" "time" @@ -109,7 +108,7 @@ func (d *APIV1) putObject(w http.ResponseWriter, r *http.Request) { } // Extract and filter headers from request - headers := cache.FilterTransportHeaders(textproto.MIMEHeader(r.Header)) + headers := cache.FilterTransportHeaders(r.Header) cw, err := d.cache.Create(r.Context(), key, headers, ttl) if err != nil { diff --git a/internal/strategy/git/bundle.go b/internal/strategy/git/bundle.go index 874d3ce..951b90d 100644 --- a/internal/strategy/git/bundle.go +++ b/internal/strategy/git/bundle.go @@ -4,7 +4,7 @@ import ( "context" "io" "log/slog" - "net/textproto" + "net/http" "strings" "time" @@ -20,7 +20,7 @@ func (s *Strategy) generateAndUploadBundle(ctx context.Context, c *clone) { cacheKey := cache.NewKey(c.upstreamURL + ".bundle") - headers := textproto.MIMEHeader{ + headers := http.Header{ "Content-Type": []string{"application/x-git-bundle"}, } ttl := 7 * 24 * time.Hour diff --git a/internal/strategy/handler/handler.go b/internal/strategy/handler/handler.go index 6628a59..4056329 100644 --- a/internal/strategy/handler/handler.go +++ b/internal/strategy/handler/handler.go @@ -5,7 +5,6 @@ import ( "log/slog" "maps" "net/http" - "net/textproto" "os" "time" @@ -169,7 +168,7 @@ func (h *Handler) streamNonOKResponse(w http.ResponseWriter, resp *http.Response func (h *Handler) streamAndCache(w http.ResponseWriter, r *http.Request, key cache.Key, resp *http.Response, logger *slog.Logger) { ttl := h.ttlFunc(r) - responseHeaders := textproto.MIMEHeader(maps.Clone(resp.Header)) + responseHeaders := maps.Clone(resp.Header) cw, err := h.cache.Create(r.Context(), key, responseHeaders, ttl) if err != nil { h.errorHandler(httputil.Errorf(http.StatusInternalServerError, "failed to create cache entry: %w", err), w, r)