Skip to content

Commit c3cad55

Browse files
committed
✨ [headers] Added utilities for header sanitisation
1 parent cf4da6d commit c3cad55

File tree

4 files changed

+191
-17
lines changed

4 files changed

+191
-17
lines changed

changes/20251024170139.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
:sparkles: [headers] Added utilities for header sanitisation

changes/20251024170358.bugfix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
:bug: [headers] Fix header search due to normalisation of the name by go

utils/http/headers/headers.go

Lines changed: 111 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"net/http"
77
"strings"
88

9+
mapset "github.com/deckarep/golang-set/v2"
910
"github.com/go-http-utils/headers"
1011

1112
"github.com/ARM-software/golang-utils/utils/collection"
@@ -146,6 +147,9 @@ var (
146147
headers.XRatelimitRemaining,
147148
headers.XRatelimitReset,
148149
}
150+
// NormalisedSafeHeaders returns a normalised list of safe headers
151+
NormalisedSafeHeaders = collection.Map[string, string](SafeHeaders, headers.Normalize) //nolint:misspell
152+
149153
)
150154

151155
type Header struct {
@@ -167,15 +171,24 @@ func (hs Headers) AppendHeader(key, value string) {
167171
}
168172

169173
func (hs Headers) Append(h *Header) {
170-
hs[h.Key] = *h
174+
hs[headers.Normalize(h.Key)] = *h
171175
}
172176

173177
func (hs Headers) Get(key string) string {
178+
_, value := hs.get(key)
179+
return value
180+
}
181+
182+
func (hs Headers) get(key string) (found bool, value string) {
174183
h, found := hs[key]
175184
if !found {
176-
return ""
185+
h, found = hs[headers.Normalize(key)]
186+
if !found {
187+
return
188+
}
177189
}
178-
return h.Value
190+
value = h.Value
191+
return
179192
}
180193

181194
func (hs Headers) Has(h *Header) bool {
@@ -186,10 +199,33 @@ func (hs Headers) Has(h *Header) bool {
186199
}
187200

188201
func (hs Headers) HasHeader(key string) bool {
189-
_, found := hs[key]
202+
found, _ := hs.get(key)
190203
return found
191204
}
192205

206+
func (hs Headers) FromRequest(r *http.Request) {
207+
if r == nil {
208+
return
209+
}
210+
hs.FromGoHttpHeaders(&r.Header)
211+
}
212+
213+
func (hs Headers) FromGoHttpHeaders(headers *http.Header) {
214+
if reflection.IsEmpty(headers) {
215+
return
216+
}
217+
for key, value := range *headers {
218+
hs.AppendHeader(key, value[0])
219+
}
220+
}
221+
222+
func (hs Headers) FromResponse(resp *http.Response) {
223+
if resp == nil {
224+
return
225+
}
226+
hs.FromGoHttpHeaders(&resp.Header)
227+
}
228+
193229
func (hs Headers) Empty() bool {
194230
return len(hs) == 0
195231
}
@@ -210,10 +246,77 @@ func (hs Headers) AppendToRequest(r *http.Request) {
210246
}
211247
}
212248

249+
func (hs Headers) RemoveHeader(key string) {
250+
delete(hs, key)
251+
delete(hs, headers.Normalize(key))
252+
}
253+
254+
func (hs Headers) RemoveHeaders(key ...string) {
255+
for i := range key {
256+
hs.RemoveHeader(key[i])
257+
}
258+
}
259+
260+
func (hs Headers) Clone() *Headers {
261+
clone := make(Headers, len(hs))
262+
for k, v := range hs {
263+
clone[k] = v
264+
}
265+
return &clone
266+
}
267+
268+
// DisallowList returns the headers minus any header defined in the disallow list.
269+
func (hs Headers) DisallowList(key ...string) *Headers {
270+
clone := hs.Clone()
271+
clone.RemoveHeaders(key...)
272+
return clone
273+
}
274+
275+
// AllowList return only safe headers and headers defined in the allow list.
276+
func (hs Headers) AllowList(key ...string) *Headers {
277+
clone := hs.Clone()
278+
clone.Sanitise(key...)
279+
return clone
280+
}
281+
282+
// Sanitise sanitises headers so no personal data is retained.
283+
// It is possible to provide an allowed list of extra headers which would also be retained.
284+
func (hs Headers) Sanitise(allowList ...string) {
285+
allowedHeaders := mapset.NewSet[string](NormalisedSafeHeaders...)
286+
allowedHeaders.Append(collection.Map[string, string](allowList, headers.Normalize)...)
287+
var headersToRemove []string
288+
for key := range hs {
289+
if !allowedHeaders.Contains(headers.Normalize(key)) {
290+
headersToRemove = append(headersToRemove, key)
291+
}
292+
}
293+
hs.RemoveHeaders(headersToRemove...)
294+
}
295+
213296
func NewHeaders() *Headers {
214297
return &Headers{}
215298
}
216299

300+
// FromRequest returns request's headers
301+
func FromRequest(r *http.Request) *Headers {
302+
if r == nil {
303+
return nil
304+
}
305+
h := NewHeaders()
306+
h.FromRequest(r)
307+
return h
308+
}
309+
310+
// FromResponse returns response's headers
311+
func FromResponse(resp *http.Response) *Headers {
312+
if resp == nil {
313+
return nil
314+
}
315+
h := NewHeaders()
316+
h.FromResponse(resp)
317+
return h
318+
}
319+
217320
// ParseAuthorizationHeader fetches the `Authorization` header and parses it.
218321
func ParseAuthorizationHeader(r *http.Request) (string, string, error) {
219322
return ParseAuthorisationValue(FetchWebsocketAuthorisation(r))
@@ -414,17 +517,8 @@ func CreateLinkHeader(link, relation, contentType string) string {
414517

415518
// SanitiseHeaders sanitises a collection of request headers not to include any with personal data
416519
func SanitiseHeaders(requestHeader *http.Header) *Headers {
417-
if requestHeader == nil {
418-
return nil
419-
}
420-
aHeaders := NewHeaders()
421-
for i := range SafeHeaders {
422-
safeHeader := SafeHeaders[i]
423-
rHeader := requestHeader.Get(safeHeader)
424-
if !reflection.IsEmpty(rHeader) {
425-
aHeaders.AppendHeader(safeHeader, rHeader)
426-
}
427-
}
428-
429-
return aHeaders
520+
hs := NewHeaders()
521+
hs.FromGoHttpHeaders(requestHeader)
522+
hs.Sanitise()
523+
return hs
430524
}

utils/http/headers/headers_test.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,38 @@ func TestParseAuthorizationHeader(t *testing.T) {
144144
})
145145
}
146146

147+
func TestFromToRequestResponse(t *testing.T) {
148+
request := httptest.NewRequest(http.MethodGet, faker.URL(), nil)
149+
request.Header.Add(headers.Authorization, faker.Password())
150+
request.Header.Add(HeaderWebsocketProtocol, faker.Password())
151+
h := FromRequest(request)
152+
h.AppendHeader(headers.Accept, "1.0.0")
153+
h.AppendHeader(headers.AcceptEncoding, "gzip")
154+
r2 := httptest.NewRequest(http.MethodGet, faker.URL(), nil)
155+
assert.Empty(t, r2.Header)
156+
h.AppendToRequest(r2)
157+
assert.NotEmpty(t, r2.Header)
158+
h2 := FromRequest(r2)
159+
assert.True(t, h2.HasHeader(headers.Authorization))
160+
assert.True(t, h2.HasHeader(headers.AcceptEncoding))
161+
assert.True(t, h2.HasHeader(headers.Accept))
162+
assert.True(t, h2.HasHeader(HeaderWebsocketProtocol))
163+
164+
response := httptest.NewRecorder()
165+
response.Header().Set(HeaderWebsocketProtocol, "base64.binary.k8s.io")
166+
response.Header().Set(headers.Authorization, faker.Password())
167+
h3 := FromResponse(response.Result())
168+
h3.AppendHeader(headers.Accept, "1.0.0")
169+
h3.AppendHeader(headers.AcceptEncoding, "gzip")
170+
response2 := httptest.NewRecorder()
171+
h3.AppendToResponse(response2)
172+
h4 := FromResponse(response2.Result())
173+
assert.True(t, h4.HasHeader(headers.Authorization))
174+
assert.True(t, h4.HasHeader(headers.AcceptEncoding))
175+
assert.True(t, h4.HasHeader(headers.Accept))
176+
assert.True(t, h4.HasHeader(HeaderWebsocketProtocol))
177+
}
178+
147179
func TestAddProductInformationToUserAgent(t *testing.T) {
148180
r, err := http.NewRequest(http.MethodGet, faker.URL(), nil)
149181
require.NoError(t, err)
@@ -165,6 +197,18 @@ func TestSetLocationHeaders(t *testing.T) {
165197
assert.Equal(t, location, w.Header().Get(headers.ContentLocation))
166198
}
167199

200+
func TestGetHeaders(t *testing.T) {
201+
header := NewHeaders()
202+
test := faker.Word()
203+
header.AppendHeader(HeaderWebsocketProtocol, test)
204+
assert.Equal(t, test, header.Get(headers.Normalize(HeaderWebsocketProtocol)))
205+
assert.True(t, header.HasHeader(HeaderWebsocketProtocol))
206+
assert.True(t, header.HasHeader(headers.Normalize(HeaderWebsocketProtocol)))
207+
assert.Empty(t, header.Get(headers.ContentLocation))
208+
assert.False(t, header.HasHeader(headers.ContentLocation))
209+
assert.False(t, header.HasHeader(headers.Normalize(headers.ContentLocation)))
210+
}
211+
168212
func TestSanitiseHeaders(t *testing.T) {
169213
header := &http.Header{}
170214
t.Run("empty", func(t *testing.T) {
@@ -197,5 +241,39 @@ func TestSanitiseHeaders(t *testing.T) {
197241
assert.False(t, actual.HasHeader(
198242
HeaderWebsocketProtocol))
199243
})
244+
t.Run("allow/disallow list", func(t *testing.T) {
245+
h := NewHeaders()
246+
h.AppendHeader(headers.Authorization, faker.Password())
247+
h.AppendHeader(HeaderWebsocketProtocol, faker.Password())
248+
h.AppendHeader(headers.Accept, "1.0.0")
249+
h.AppendHeader(headers.AcceptEncoding, "gzip")
250+
h1 := h.Clone()
251+
h1.Sanitise()
252+
assert.True(t, h1.HasHeader(headers.Accept))
253+
assert.True(t, h1.HasHeader(headers.AcceptEncoding))
254+
assert.False(t, h1.HasHeader(HeaderWebsocketProtocol))
255+
assert.False(t, h1.HasHeader(headers.Authorization))
256+
assert.True(t, h.HasHeader(headers.Accept))
257+
assert.True(t, h.HasHeader(headers.AcceptEncoding))
258+
assert.True(t, h.HasHeader(HeaderWebsocketProtocol))
259+
assert.True(t, h.HasHeader(headers.Authorization))
260+
h11 := h.AllowList(headers.Authorization)
261+
assert.True(t, h11.HasHeader(headers.Accept))
262+
assert.True(t, h11.HasHeader(headers.AcceptEncoding))
263+
assert.False(t, h11.HasHeader(HeaderWebsocketProtocol))
264+
assert.True(t, h11.HasHeader(headers.Authorization))
265+
h2 := h.Clone()
266+
h2.Sanitise(headers.Authorization)
267+
h2.RemoveHeaders(headers.AcceptEncoding, headers.Accept)
268+
assert.False(t, h2.HasHeader(headers.Accept))
269+
assert.False(t, h2.HasHeader(headers.AcceptEncoding))
270+
assert.False(t, h2.HasHeader(HeaderWebsocketProtocol))
271+
assert.True(t, h2.HasHeader(headers.Authorization))
272+
h22 := h.DisallowList(headers.AcceptEncoding, headers.Accept)
273+
assert.False(t, h22.HasHeader(headers.Accept))
274+
assert.False(t, h22.HasHeader(headers.AcceptEncoding))
275+
assert.True(t, h22.HasHeader(HeaderWebsocketProtocol))
276+
assert.True(t, h22.HasHeader(headers.Authorization))
277+
})
200278

201279
}

0 commit comments

Comments
 (0)