Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/20251024170139.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:sparkles: [headers] Added utilities for header sanitisation
1 change: 1 addition & 0 deletions changes/20251024170358.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:bug: [headers] Fix header search due to normalisation of the name by go
10 changes: 5 additions & 5 deletions utils/http/header_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ import (

type ClientWithHeaders struct {
client IClient
headers headers.Headers
headers *headers.Headers
}

func newClientWithHeaders(underlyingClient IClient, headerValues ...string) (c *ClientWithHeaders, err error) {
c = &ClientWithHeaders{
headers: make(headers.Headers),
headers: headers.NewHeaders(),
}

if underlyingClient == nil {
Expand Down Expand Up @@ -123,7 +123,7 @@ func (c *ClientWithHeaders) Close() error {

func (c *ClientWithHeaders) AppendHeader(key, value string) {
if c.headers == nil {
c.headers = make(headers.Headers)
c.headers = headers.NewHeaders()
}
c.headers.AppendHeader(key, value)
}
Expand All @@ -132,9 +132,9 @@ func (c *ClientWithHeaders) RemoveHeader(key string) {
if c.headers == nil {
return
}
delete(c.headers, key)
c.headers.RemoveHeader(key)
}

func (c *ClientWithHeaders) ClearHeaders() {
c.headers = make(headers.Headers)
c.headers = headers.NewHeaders()
}
4 changes: 3 additions & 1 deletion utils/http/header_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,9 @@ func TestClientWithHeadersWithDifferentBodies(t *testing.T) {

clientStruct.AppendHeader("hello", "world")
require.NotEmpty(t, clientStruct.headers)
assert.Equal(t, headers.Header{Key: "hello", Value: "world"}, clientStruct.headers["hello"])
header := clientStruct.headers.GetHeader("hello")
require.NotNil(t, header)
assert.Equal(t, headers.Header{Key: "hello", Value: "world"}, *header)

clientStruct.RemoveHeader("hello")
assert.Empty(t, clientStruct.headers)
Expand Down
134 changes: 118 additions & 16 deletions utils/http/headers/headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/http"
"strings"

mapset "github.com/deckarep/golang-set/v2"
"github.com/go-http-utils/headers"

"github.com/ARM-software/golang-utils/utils/collection"
Expand Down Expand Up @@ -146,6 +147,9 @@ var (
headers.XRatelimitRemaining,
headers.XRatelimitReset,
}
// NormalisedSafeHeaders returns a normalised list of safe headers
NormalisedSafeHeaders = collection.Map[string, string](SafeHeaders, headers.Normalize) //nolint:misspell

)

type Header struct {
Expand All @@ -167,17 +171,34 @@ func (hs Headers) AppendHeader(key, value string) {
}

func (hs Headers) Append(h *Header) {
hs[h.Key] = *h
hs[headers.Normalize(h.Key)] = *h //nolint:misspell
}

func (hs Headers) Get(key string) string {
h, found := hs[key]
found, h := hs.get(key)
if !found {
return ""
}
return h.Value
}

func (hs Headers) GetHeader(key string) (header *Header) {
_, header = hs.get(key)
return
}

func (hs Headers) get(key string) (found bool, header *Header) {
h, found := hs[key]
if !found {
h, found = hs[headers.Normalize(key)] //nolint:misspell
if !found {
return
}
}
header = &h
return
}

func (hs Headers) Has(h *Header) bool {
if h == nil {
return false
Expand All @@ -186,10 +207,33 @@ func (hs Headers) Has(h *Header) bool {
}

func (hs Headers) HasHeader(key string) bool {
_, found := hs[key]
found, _ := hs.get(key)
return found
}

func (hs Headers) FromRequest(r *http.Request) {
if r == nil {
return
}
hs.FromGoHTTPHeaders(&r.Header)
}

func (hs Headers) FromGoHTTPHeaders(headers *http.Header) {
if reflection.IsEmpty(headers) {
return
}
for key, value := range *headers {
hs.AppendHeader(key, value[0])
}
}

func (hs Headers) FromResponse(resp *http.Response) {
if resp == nil {
return
}
hs.FromGoHTTPHeaders(&resp.Header)
}

func (hs Headers) Empty() bool {
return len(hs) == 0
}
Expand All @@ -210,10 +254,77 @@ func (hs Headers) AppendToRequest(r *http.Request) {
}
}

func (hs Headers) RemoveHeader(key string) {
delete(hs, key)
delete(hs, headers.Normalize(key)) //nolint:misspell
}

func (hs Headers) RemoveHeaders(key ...string) {
for i := range key {
hs.RemoveHeader(key[i])
}
}

func (hs Headers) Clone() *Headers {
clone := make(Headers, len(hs))
for k, v := range hs {
clone[k] = v
}
return &clone
}

// DisallowList returns the headers minus any header defined in the disallow list.
func (hs Headers) DisallowList(key ...string) *Headers {
clone := hs.Clone()
clone.RemoveHeaders(key...)
return clone
}

// AllowList return only safe headers and headers defined in the allow list.
func (hs Headers) AllowList(key ...string) *Headers {
clone := hs.Clone()
clone.Sanitise(key...)
return clone
}

// Sanitise sanitises headers so no personal data is retained.
// It is possible to provide an allowed list of extra headers which would also be retained.
func (hs Headers) Sanitise(allowList ...string) {
allowedHeaders := mapset.NewSet[string](NormalisedSafeHeaders...)
allowedHeaders.Append(collection.Map[string, string](allowList, headers.Normalize)...) //nolint:misspell
var headersToRemove []string
for key := range hs {
if !allowedHeaders.Contains(headers.Normalize(key)) { //nolint:misspell
headersToRemove = append(headersToRemove, key)
}
}
hs.RemoveHeaders(headersToRemove...)
}

func NewHeaders() *Headers {
return &Headers{}
}

// FromRequest returns request's headers
func FromRequest(r *http.Request) *Headers {
if r == nil {
return nil
}
h := NewHeaders()
h.FromRequest(r)
return h
}

// FromResponse returns response's headers
func FromResponse(resp *http.Response) *Headers {
if resp == nil {
return nil
}
h := NewHeaders()
h.FromResponse(resp)
return h
}

// ParseAuthorizationHeader fetches the `Authorization` header and parses it.
func ParseAuthorizationHeader(r *http.Request) (string, string, error) {
return ParseAuthorisationValue(FetchWebsocketAuthorisation(r))
Expand Down Expand Up @@ -414,17 +525,8 @@ func CreateLinkHeader(link, relation, contentType string) string {

// SanitiseHeaders sanitises a collection of request headers not to include any with personal data
func SanitiseHeaders(requestHeader *http.Header) *Headers {
if requestHeader == nil {
return nil
}
aHeaders := NewHeaders()
for i := range SafeHeaders {
safeHeader := SafeHeaders[i]
rHeader := requestHeader.Get(safeHeader)
if !reflection.IsEmpty(rHeader) {
aHeaders.AppendHeader(safeHeader, rHeader)
}
}

return aHeaders
hs := NewHeaders()
hs.FromGoHTTPHeaders(requestHeader)
hs.Sanitise()
return hs
}
78 changes: 78 additions & 0 deletions utils/http/headers/headers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,38 @@ func TestParseAuthorizationHeader(t *testing.T) {
})
}

func TestFromToRequestResponse(t *testing.T) {
request := httptest.NewRequest(http.MethodGet, faker.URL(), nil)
request.Header.Add(headers.Authorization, faker.Password())
request.Header.Add(HeaderWebsocketProtocol, faker.Password())
h := FromRequest(request)
h.AppendHeader(headers.Accept, "1.0.0")
h.AppendHeader(headers.AcceptEncoding, "gzip")
r2 := httptest.NewRequest(http.MethodGet, faker.URL(), nil)
assert.Empty(t, r2.Header)
h.AppendToRequest(r2)
assert.NotEmpty(t, r2.Header)
h2 := FromRequest(r2)
assert.True(t, h2.HasHeader(headers.Authorization))
assert.True(t, h2.HasHeader(headers.AcceptEncoding))
assert.True(t, h2.HasHeader(headers.Accept))
assert.True(t, h2.HasHeader(HeaderWebsocketProtocol))

response := httptest.NewRecorder()
response.Header().Set(HeaderWebsocketProtocol, "base64.binary.k8s.io")
response.Header().Set(headers.Authorization, faker.Password())
h3 := FromResponse(response.Result())
h3.AppendHeader(headers.Accept, "1.0.0")
h3.AppendHeader(headers.AcceptEncoding, "gzip")
response2 := httptest.NewRecorder()
h3.AppendToResponse(response2)
h4 := FromResponse(response2.Result())
assert.True(t, h4.HasHeader(headers.Authorization))
assert.True(t, h4.HasHeader(headers.AcceptEncoding))
assert.True(t, h4.HasHeader(headers.Accept))
assert.True(t, h4.HasHeader(HeaderWebsocketProtocol))
}

func TestAddProductInformationToUserAgent(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, faker.URL(), nil)
require.NoError(t, err)
Expand All @@ -165,6 +197,18 @@ func TestSetLocationHeaders(t *testing.T) {
assert.Equal(t, location, w.Header().Get(headers.ContentLocation))
}

func TestGetHeaders(t *testing.T) {
header := NewHeaders()
test := faker.Word()
header.AppendHeader(HeaderWebsocketProtocol, test)
assert.Equal(t, test, header.Get(headers.Normalize(HeaderWebsocketProtocol))) //nolint:misspell
assert.True(t, header.HasHeader(HeaderWebsocketProtocol))
assert.True(t, header.HasHeader(headers.Normalize(HeaderWebsocketProtocol))) //nolint:misspell
assert.Empty(t, header.Get(headers.ContentLocation))
assert.False(t, header.HasHeader(headers.ContentLocation))
assert.False(t, header.HasHeader(headers.Normalize(headers.ContentLocation))) //nolint:misspell
}

func TestSanitiseHeaders(t *testing.T) {
header := &http.Header{}
t.Run("empty", func(t *testing.T) {
Expand Down Expand Up @@ -197,5 +241,39 @@ func TestSanitiseHeaders(t *testing.T) {
assert.False(t, actual.HasHeader(
HeaderWebsocketProtocol))
})
t.Run("allow/disallow list", func(t *testing.T) {
h := NewHeaders()
h.AppendHeader(headers.Authorization, faker.Password())
h.AppendHeader(HeaderWebsocketProtocol, faker.Password())
h.AppendHeader(headers.Accept, "1.0.0")
h.AppendHeader(headers.AcceptEncoding, "gzip")
h1 := h.Clone()
h1.Sanitise()
assert.True(t, h1.HasHeader(headers.Accept))
assert.True(t, h1.HasHeader(headers.AcceptEncoding))
assert.False(t, h1.HasHeader(HeaderWebsocketProtocol))
assert.False(t, h1.HasHeader(headers.Authorization))
assert.True(t, h.HasHeader(headers.Accept))
assert.True(t, h.HasHeader(headers.AcceptEncoding))
assert.True(t, h.HasHeader(HeaderWebsocketProtocol))
assert.True(t, h.HasHeader(headers.Authorization))
h11 := h.AllowList(headers.Authorization)
assert.True(t, h11.HasHeader(headers.Accept))
assert.True(t, h11.HasHeader(headers.AcceptEncoding))
assert.False(t, h11.HasHeader(HeaderWebsocketProtocol))
assert.True(t, h11.HasHeader(headers.Authorization))
h2 := h.Clone()
h2.Sanitise(headers.Authorization)
h2.RemoveHeaders(headers.AcceptEncoding, headers.Accept)
assert.False(t, h2.HasHeader(headers.Accept))
assert.False(t, h2.HasHeader(headers.AcceptEncoding))
assert.False(t, h2.HasHeader(HeaderWebsocketProtocol))
assert.True(t, h2.HasHeader(headers.Authorization))
h22 := h.DisallowList(headers.AcceptEncoding, headers.Accept)
assert.False(t, h22.HasHeader(headers.Accept))
assert.False(t, h22.HasHeader(headers.AcceptEncoding))
assert.True(t, h22.HasHeader(HeaderWebsocketProtocol))
assert.True(t, h22.HasHeader(headers.Authorization))
})

}
Loading