From c3cad557c8954f7a11065b19843c81b241808de8 Mon Sep 17 00:00:00 2001 From: Adrien CABARBAYE Date: Fri, 24 Oct 2025 17:04:08 +0100 Subject: [PATCH 1/3] :sparkles: [headers] Added utilities for header sanitisation --- changes/20251024170139.feature | 1 + changes/20251024170358.bugfix | 1 + utils/http/headers/headers.go | 128 +++++++++++++++++++++++++---- utils/http/headers/headers_test.go | 78 ++++++++++++++++++ 4 files changed, 191 insertions(+), 17 deletions(-) create mode 100644 changes/20251024170139.feature create mode 100644 changes/20251024170358.bugfix diff --git a/changes/20251024170139.feature b/changes/20251024170139.feature new file mode 100644 index 0000000000..6ed67c7b4e --- /dev/null +++ b/changes/20251024170139.feature @@ -0,0 +1 @@ +:sparkles: [headers] Added utilities for header sanitisation diff --git a/changes/20251024170358.bugfix b/changes/20251024170358.bugfix new file mode 100644 index 0000000000..58e5565235 --- /dev/null +++ b/changes/20251024170358.bugfix @@ -0,0 +1 @@ +:bug: [headers] Fix header search due to normalisation of the name by go diff --git a/utils/http/headers/headers.go b/utils/http/headers/headers.go index 19d20acb61..9400426f82 100644 --- a/utils/http/headers/headers.go +++ b/utils/http/headers/headers.go @@ -6,6 +6,7 @@ import ( "net/http" "strings" + mapset "github.com/deckarep/golang-set/v2" "github.com/go-http-utils/headers" "github.com/ARM-software/golang-utils/utils/collection" @@ -146,6 +147,9 @@ var ( headers.XRatelimitRemaining, headers.XRatelimitReset, } + // NormalisedSafeHeaders returns a normalised list of safe headers + NormalisedSafeHeaders = collection.Map[string, string](SafeHeaders, headers.Normalize) //nolint:misspell + ) type Header struct { @@ -167,15 +171,24 @@ func (hs Headers) AppendHeader(key, value string) { } func (hs Headers) Append(h *Header) { - hs[h.Key] = *h + hs[headers.Normalize(h.Key)] = *h } func (hs Headers) Get(key string) string { + _, value := hs.get(key) + return value +} + +func (hs Headers) get(key string) (found bool, value string) { h, found := hs[key] if !found { - return "" + h, found = hs[headers.Normalize(key)] + if !found { + return + } } - return h.Value + value = h.Value + return } func (hs Headers) Has(h *Header) bool { @@ -186,10 +199,33 @@ func (hs Headers) Has(h *Header) bool { } func (hs Headers) HasHeader(key string) bool { - _, found := hs[key] + found, _ := hs.get(key) return found } +func (hs Headers) FromRequest(r *http.Request) { + if r == nil { + return + } + hs.FromGoHttpHeaders(&r.Header) +} + +func (hs Headers) FromGoHttpHeaders(headers *http.Header) { + if reflection.IsEmpty(headers) { + return + } + for key, value := range *headers { + hs.AppendHeader(key, value[0]) + } +} + +func (hs Headers) FromResponse(resp *http.Response) { + if resp == nil { + return + } + hs.FromGoHttpHeaders(&resp.Header) +} + func (hs Headers) Empty() bool { return len(hs) == 0 } @@ -210,10 +246,77 @@ func (hs Headers) AppendToRequest(r *http.Request) { } } +func (hs Headers) RemoveHeader(key string) { + delete(hs, key) + delete(hs, headers.Normalize(key)) +} + +func (hs Headers) RemoveHeaders(key ...string) { + for i := range key { + hs.RemoveHeader(key[i]) + } +} + +func (hs Headers) Clone() *Headers { + clone := make(Headers, len(hs)) + for k, v := range hs { + clone[k] = v + } + return &clone +} + +// DisallowList returns the headers minus any header defined in the disallow list. +func (hs Headers) DisallowList(key ...string) *Headers { + clone := hs.Clone() + clone.RemoveHeaders(key...) + return clone +} + +// AllowList return only safe headers and headers defined in the allow list. +func (hs Headers) AllowList(key ...string) *Headers { + clone := hs.Clone() + clone.Sanitise(key...) + return clone +} + +// Sanitise sanitises headers so no personal data is retained. +// It is possible to provide an allowed list of extra headers which would also be retained. +func (hs Headers) Sanitise(allowList ...string) { + allowedHeaders := mapset.NewSet[string](NormalisedSafeHeaders...) + allowedHeaders.Append(collection.Map[string, string](allowList, headers.Normalize)...) + var headersToRemove []string + for key := range hs { + if !allowedHeaders.Contains(headers.Normalize(key)) { + headersToRemove = append(headersToRemove, key) + } + } + hs.RemoveHeaders(headersToRemove...) +} + func NewHeaders() *Headers { return &Headers{} } +// FromRequest returns request's headers +func FromRequest(r *http.Request) *Headers { + if r == nil { + return nil + } + h := NewHeaders() + h.FromRequest(r) + return h +} + +// FromResponse returns response's headers +func FromResponse(resp *http.Response) *Headers { + if resp == nil { + return nil + } + h := NewHeaders() + h.FromResponse(resp) + return h +} + // ParseAuthorizationHeader fetches the `Authorization` header and parses it. func ParseAuthorizationHeader(r *http.Request) (string, string, error) { return ParseAuthorisationValue(FetchWebsocketAuthorisation(r)) @@ -414,17 +517,8 @@ func CreateLinkHeader(link, relation, contentType string) string { // SanitiseHeaders sanitises a collection of request headers not to include any with personal data func SanitiseHeaders(requestHeader *http.Header) *Headers { - if requestHeader == nil { - return nil - } - aHeaders := NewHeaders() - for i := range SafeHeaders { - safeHeader := SafeHeaders[i] - rHeader := requestHeader.Get(safeHeader) - if !reflection.IsEmpty(rHeader) { - aHeaders.AppendHeader(safeHeader, rHeader) - } - } - - return aHeaders + hs := NewHeaders() + hs.FromGoHttpHeaders(requestHeader) + hs.Sanitise() + return hs } diff --git a/utils/http/headers/headers_test.go b/utils/http/headers/headers_test.go index e3149d1efa..b123411b5e 100644 --- a/utils/http/headers/headers_test.go +++ b/utils/http/headers/headers_test.go @@ -144,6 +144,38 @@ func TestParseAuthorizationHeader(t *testing.T) { }) } +func TestFromToRequestResponse(t *testing.T) { + request := httptest.NewRequest(http.MethodGet, faker.URL(), nil) + request.Header.Add(headers.Authorization, faker.Password()) + request.Header.Add(HeaderWebsocketProtocol, faker.Password()) + h := FromRequest(request) + h.AppendHeader(headers.Accept, "1.0.0") + h.AppendHeader(headers.AcceptEncoding, "gzip") + r2 := httptest.NewRequest(http.MethodGet, faker.URL(), nil) + assert.Empty(t, r2.Header) + h.AppendToRequest(r2) + assert.NotEmpty(t, r2.Header) + h2 := FromRequest(r2) + assert.True(t, h2.HasHeader(headers.Authorization)) + assert.True(t, h2.HasHeader(headers.AcceptEncoding)) + assert.True(t, h2.HasHeader(headers.Accept)) + assert.True(t, h2.HasHeader(HeaderWebsocketProtocol)) + + response := httptest.NewRecorder() + response.Header().Set(HeaderWebsocketProtocol, "base64.binary.k8s.io") + response.Header().Set(headers.Authorization, faker.Password()) + h3 := FromResponse(response.Result()) + h3.AppendHeader(headers.Accept, "1.0.0") + h3.AppendHeader(headers.AcceptEncoding, "gzip") + response2 := httptest.NewRecorder() + h3.AppendToResponse(response2) + h4 := FromResponse(response2.Result()) + assert.True(t, h4.HasHeader(headers.Authorization)) + assert.True(t, h4.HasHeader(headers.AcceptEncoding)) + assert.True(t, h4.HasHeader(headers.Accept)) + assert.True(t, h4.HasHeader(HeaderWebsocketProtocol)) +} + func TestAddProductInformationToUserAgent(t *testing.T) { r, err := http.NewRequest(http.MethodGet, faker.URL(), nil) require.NoError(t, err) @@ -165,6 +197,18 @@ func TestSetLocationHeaders(t *testing.T) { assert.Equal(t, location, w.Header().Get(headers.ContentLocation)) } +func TestGetHeaders(t *testing.T) { + header := NewHeaders() + test := faker.Word() + header.AppendHeader(HeaderWebsocketProtocol, test) + assert.Equal(t, test, header.Get(headers.Normalize(HeaderWebsocketProtocol))) + assert.True(t, header.HasHeader(HeaderWebsocketProtocol)) + assert.True(t, header.HasHeader(headers.Normalize(HeaderWebsocketProtocol))) + assert.Empty(t, header.Get(headers.ContentLocation)) + assert.False(t, header.HasHeader(headers.ContentLocation)) + assert.False(t, header.HasHeader(headers.Normalize(headers.ContentLocation))) +} + func TestSanitiseHeaders(t *testing.T) { header := &http.Header{} t.Run("empty", func(t *testing.T) { @@ -197,5 +241,39 @@ func TestSanitiseHeaders(t *testing.T) { assert.False(t, actual.HasHeader( HeaderWebsocketProtocol)) }) + t.Run("allow/disallow list", func(t *testing.T) { + h := NewHeaders() + h.AppendHeader(headers.Authorization, faker.Password()) + h.AppendHeader(HeaderWebsocketProtocol, faker.Password()) + h.AppendHeader(headers.Accept, "1.0.0") + h.AppendHeader(headers.AcceptEncoding, "gzip") + h1 := h.Clone() + h1.Sanitise() + assert.True(t, h1.HasHeader(headers.Accept)) + assert.True(t, h1.HasHeader(headers.AcceptEncoding)) + assert.False(t, h1.HasHeader(HeaderWebsocketProtocol)) + assert.False(t, h1.HasHeader(headers.Authorization)) + assert.True(t, h.HasHeader(headers.Accept)) + assert.True(t, h.HasHeader(headers.AcceptEncoding)) + assert.True(t, h.HasHeader(HeaderWebsocketProtocol)) + assert.True(t, h.HasHeader(headers.Authorization)) + h11 := h.AllowList(headers.Authorization) + assert.True(t, h11.HasHeader(headers.Accept)) + assert.True(t, h11.HasHeader(headers.AcceptEncoding)) + assert.False(t, h11.HasHeader(HeaderWebsocketProtocol)) + assert.True(t, h11.HasHeader(headers.Authorization)) + h2 := h.Clone() + h2.Sanitise(headers.Authorization) + h2.RemoveHeaders(headers.AcceptEncoding, headers.Accept) + assert.False(t, h2.HasHeader(headers.Accept)) + assert.False(t, h2.HasHeader(headers.AcceptEncoding)) + assert.False(t, h2.HasHeader(HeaderWebsocketProtocol)) + assert.True(t, h2.HasHeader(headers.Authorization)) + h22 := h.DisallowList(headers.AcceptEncoding, headers.Accept) + assert.False(t, h22.HasHeader(headers.Accept)) + assert.False(t, h22.HasHeader(headers.AcceptEncoding)) + assert.True(t, h22.HasHeader(HeaderWebsocketProtocol)) + assert.True(t, h22.HasHeader(headers.Authorization)) + }) } From 17950bfb88d9082b0c868e2c4033396d6103108e Mon Sep 17 00:00:00 2001 From: Adrien CABARBAYE Date: Fri, 24 Oct 2025 17:16:36 +0100 Subject: [PATCH 2/3] :green_heart: linting --- utils/http/header_client.go | 10 ++++----- utils/http/header_client_test.go | 4 +++- utils/http/headers/headers.go | 34 ++++++++++++++++++------------ utils/http/headers/headers_test.go | 6 +++--- 4 files changed, 32 insertions(+), 22 deletions(-) diff --git a/utils/http/header_client.go b/utils/http/header_client.go index b7f9debf5a..b1e2fb0d4a 100644 --- a/utils/http/header_client.go +++ b/utils/http/header_client.go @@ -15,12 +15,12 @@ import ( type ClientWithHeaders struct { client IClient - headers headers.Headers + headers *headers.Headers } func newClientWithHeaders(underlyingClient IClient, headerValues ...string) (c *ClientWithHeaders, err error) { c = &ClientWithHeaders{ - headers: make(headers.Headers), + headers: headers.NewHeaders(), } if underlyingClient == nil { @@ -123,7 +123,7 @@ func (c *ClientWithHeaders) Close() error { func (c *ClientWithHeaders) AppendHeader(key, value string) { if c.headers == nil { - c.headers = make(headers.Headers) + c.headers = headers.NewHeaders() } c.headers.AppendHeader(key, value) } @@ -132,9 +132,9 @@ func (c *ClientWithHeaders) RemoveHeader(key string) { if c.headers == nil { return } - delete(c.headers, key) + c.headers.RemoveHeader(key) } func (c *ClientWithHeaders) ClearHeaders() { - c.headers = make(headers.Headers) + c.headers = headers.NewHeaders() } diff --git a/utils/http/header_client_test.go b/utils/http/header_client_test.go index b3357ec24a..44bce2ca05 100644 --- a/utils/http/header_client_test.go +++ b/utils/http/header_client_test.go @@ -223,7 +223,9 @@ func TestClientWithHeadersWithDifferentBodies(t *testing.T) { clientStruct.AppendHeader("hello", "world") require.NotEmpty(t, clientStruct.headers) - assert.Equal(t, headers.Header{Key: "hello", Value: "world"}, clientStruct.headers["hello"]) + header := clientStruct.headers.GetHeader("hello") + require.NotNil(t, header) + assert.Equal(t, headers.Header{Key: "hello", Value: "world"}, *header) clientStruct.RemoveHeader("hello") assert.Empty(t, clientStruct.headers) diff --git a/utils/http/headers/headers.go b/utils/http/headers/headers.go index 9400426f82..468effdec9 100644 --- a/utils/http/headers/headers.go +++ b/utils/http/headers/headers.go @@ -171,23 +171,31 @@ func (hs Headers) AppendHeader(key, value string) { } func (hs Headers) Append(h *Header) { - hs[headers.Normalize(h.Key)] = *h + hs[headers.Normalize(h.Key)] = *h //nolint:misspell } func (hs Headers) Get(key string) string { - _, value := hs.get(key) - return value + found, h := hs.get(key) + if !found { + return "" + } + return h.Value +} + +func (hs Headers) GetHeader(key string) (header *Header) { + _, header = hs.get(key) + return } -func (hs Headers) get(key string) (found bool, value string) { +func (hs Headers) get(key string) (found bool, header *Header) { h, found := hs[key] if !found { - h, found = hs[headers.Normalize(key)] + h, found = hs[headers.Normalize(key)] //nolint:misspell if !found { return } } - value = h.Value + header = &h return } @@ -207,10 +215,10 @@ func (hs Headers) FromRequest(r *http.Request) { if r == nil { return } - hs.FromGoHttpHeaders(&r.Header) + hs.FromGoHTTPHeaders(&r.Header) } -func (hs Headers) FromGoHttpHeaders(headers *http.Header) { +func (hs Headers) FromGoHTTPHeaders(headers *http.Header) { if reflection.IsEmpty(headers) { return } @@ -223,7 +231,7 @@ func (hs Headers) FromResponse(resp *http.Response) { if resp == nil { return } - hs.FromGoHttpHeaders(&resp.Header) + hs.FromGoHTTPHeaders(&resp.Header) } func (hs Headers) Empty() bool { @@ -248,7 +256,7 @@ func (hs Headers) AppendToRequest(r *http.Request) { func (hs Headers) RemoveHeader(key string) { delete(hs, key) - delete(hs, headers.Normalize(key)) + delete(hs, headers.Normalize(key)) //nolint:misspell } func (hs Headers) RemoveHeaders(key ...string) { @@ -283,10 +291,10 @@ func (hs Headers) AllowList(key ...string) *Headers { // It is possible to provide an allowed list of extra headers which would also be retained. func (hs Headers) Sanitise(allowList ...string) { allowedHeaders := mapset.NewSet[string](NormalisedSafeHeaders...) - allowedHeaders.Append(collection.Map[string, string](allowList, headers.Normalize)...) + allowedHeaders.Append(collection.Map[string, string](allowList, headers.Normalize)...) //nolint:misspell var headersToRemove []string for key := range hs { - if !allowedHeaders.Contains(headers.Normalize(key)) { + if !allowedHeaders.Contains(headers.Normalize(key)) { //nolint:misspell headersToRemove = append(headersToRemove, key) } } @@ -518,7 +526,7 @@ func CreateLinkHeader(link, relation, contentType string) string { // SanitiseHeaders sanitises a collection of request headers not to include any with personal data func SanitiseHeaders(requestHeader *http.Header) *Headers { hs := NewHeaders() - hs.FromGoHttpHeaders(requestHeader) + hs.FromGoHTTPHeaders(requestHeader) hs.Sanitise() return hs } diff --git a/utils/http/headers/headers_test.go b/utils/http/headers/headers_test.go index b123411b5e..0d6b2d17f5 100644 --- a/utils/http/headers/headers_test.go +++ b/utils/http/headers/headers_test.go @@ -201,12 +201,12 @@ func TestGetHeaders(t *testing.T) { header := NewHeaders() test := faker.Word() header.AppendHeader(HeaderWebsocketProtocol, test) - assert.Equal(t, test, header.Get(headers.Normalize(HeaderWebsocketProtocol))) + assert.Equal(t, test, header.Get(headers.Normalize(HeaderWebsocketProtocol))) //nolint:misspell assert.True(t, header.HasHeader(HeaderWebsocketProtocol)) - assert.True(t, header.HasHeader(headers.Normalize(HeaderWebsocketProtocol))) + assert.True(t, header.HasHeader(headers.Normalize(HeaderWebsocketProtocol))) //nolint:misspell assert.Empty(t, header.Get(headers.ContentLocation)) assert.False(t, header.HasHeader(headers.ContentLocation)) - assert.False(t, header.HasHeader(headers.Normalize(headers.ContentLocation))) + assert.False(t, header.HasHeader(headers.Normalize(headers.ContentLocation))) //nolint:misspell } func TestSanitiseHeaders(t *testing.T) { From 9226b9e29a9063e7e28534e5686a81611c4a9d09 Mon Sep 17 00:00:00 2001 From: Adrien CABARBAYE Date: Mon, 27 Oct 2025 09:39:41 +0000 Subject: [PATCH 3/3] :green_heart: Address review comments --- utils/http/headers/headers.go | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/utils/http/headers/headers.go b/utils/http/headers/headers.go index 468effdec9..1c0781616c 100644 --- a/utils/http/headers/headers.go +++ b/utils/http/headers/headers.go @@ -12,6 +12,7 @@ import ( "github.com/ARM-software/golang-utils/utils/collection" "github.com/ARM-software/golang-utils/utils/commonerrors" "github.com/ARM-software/golang-utils/utils/encoding/base64" + "github.com/ARM-software/golang-utils/utils/field" "github.com/ARM-software/golang-utils/utils/http/headers/useragent" "github.com/ARM-software/golang-utils/utils/http/schemes" "github.com/ARM-software/golang-utils/utils/reflection" @@ -175,7 +176,7 @@ func (hs Headers) Append(h *Header) { } func (hs Headers) Get(key string) string { - found, h := hs.get(key) + h, found := hs.get(key) if !found { return "" } @@ -183,11 +184,11 @@ func (hs Headers) Get(key string) string { } func (hs Headers) GetHeader(key string) (header *Header) { - _, header = hs.get(key) + header, _ = hs.get(key) return } -func (hs Headers) get(key string) (found bool, header *Header) { +func (hs Headers) get(key string) (header *Header, found bool) { h, found := hs[key] if !found { h, found = hs[headers.Normalize(key)] //nolint:misspell @@ -207,28 +208,25 @@ func (hs Headers) Has(h *Header) bool { } func (hs Headers) HasHeader(key string) bool { - found, _ := hs.get(key) + _, found := hs.get(key) return found } func (hs Headers) FromRequest(r *http.Request) { - if r == nil { + if reflection.IsEmpty(r) { return } hs.FromGoHTTPHeaders(&r.Header) } func (hs Headers) FromGoHTTPHeaders(headers *http.Header) { - if reflection.IsEmpty(headers) { - return - } - for key, value := range *headers { + for key, value := range field.Optional[http.Header](headers, http.Header{}) { hs.AppendHeader(key, value[0]) } } func (hs Headers) FromResponse(resp *http.Response) { - if resp == nil { + if reflection.IsEmpty(resp) { return } hs.FromGoHTTPHeaders(&resp.Header) @@ -307,7 +305,7 @@ func NewHeaders() *Headers { // FromRequest returns request's headers func FromRequest(r *http.Request) *Headers { - if r == nil { + if reflection.IsEmpty(r) { return nil } h := NewHeaders() @@ -317,7 +315,7 @@ func FromRequest(r *http.Request) *Headers { // FromResponse returns response's headers func FromResponse(resp *http.Response) *Headers { - if resp == nil { + if reflection.IsEmpty(resp) { return nil } h := NewHeaders()