Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/20251031150814.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:sparkles: `[http/proxy]` Add helpers for proxying requests and responses
3 changes: 2 additions & 1 deletion utils/http/httptest/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ import (

// NewTestServer creates a test server
func NewTestServer(t *testing.T, ctx context.Context, handler http.Handler, port string) {
t.Helper()
list, err := net.Listen("tcp", fmt.Sprintf(":%v", port))
require.Nil(t, err)
require.NoError(t, err)
srv := &http.Server{
Handler: handler,
ReadHeaderTimeout: time.Minute,
Expand Down
157 changes: 157 additions & 0 deletions utils/http/proxy/proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package proxy

import (
"bytes"
"context"
"io"
"net/http"
"strconv"
"strings"

"github.com/go-http-utils/headers"

"github.com/ARM-software/golang-utils/utils/commonerrors"
httpheaders "github.com/ARM-software/golang-utils/utils/http/headers"
"github.com/ARM-software/golang-utils/utils/reflection"
"github.com/ARM-software/golang-utils/utils/safecast"
"github.com/ARM-software/golang-utils/utils/safeio"
)

// ProxyDisallowList describes headers which are not proxied back.
var ProxyDisallowList = []string{
headers.AccessControlAllowOrigin,
headers.AccessControlAllowMethods,
headers.AccessControlAllowHeaders,
headers.AccessControlExposeHeaders,
headers.AccessControlMaxAge,
headers.AccessControlAllowCredentials,
}

// ProxyRequest proxies a request to a new endpoint. The method can also be changed. Headers are sanitised during the process.
func ProxyRequest(r *http.Request, proxyMethod, endpoint string) (proxiedRequest *http.Request, err error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have typically built functions like this with an object (struct) input for the configuration because they've always grown over time to add headers, rewriting, redirect following, etc, although I appreciate that might not be the norm in Go.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is exactly what http.Request is : it contains all the information for making a request. It does not do the request, That is the role of the client

if reflection.IsEmpty(r) {
err = commonerrors.UndefinedVariable("request to proxy")
return
}
ctx := r.Context()
// Note: It is important to know that an 0 or -1 content length does not mean there is no body. This is likely the case but it could also be because the body was never read and its size never assessed.
contentLength := determineRequestContentLength(r)
h := httpheaders.FromRequest(r).AllowList(headers.Authorization)
if reflection.IsEmpty(proxyMethod) {
proxyMethod = http.MethodGet
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does it change it to get if it is not set?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is usually what happens in HTTP. the default method is GET https://en.wikipedia.org/wiki/Post/Redirect/Get

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough, I didn't know about that.

It still seems like a proxy should not do that, that Post/Redirect/Get thing looks like it is more about allowing page refreshing without form resubmission which is different to proxying.

But I will leave the decision up to you

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And they introduced 307 and 308 because the original ones were causing so many issues for people.

I think to me, a proxy is different from a redirect. I see a proxy like a middleman so it is just part of the normal request therefore it shouldn't change behavior, whereas a redirect seems like a more explicit change to the destination.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this Post/Redirect/Get was more to give an example. this should never happen but I preferred that the proxy would be attempted rather than a 500 being returned

}
proxiedRequest, err = http.NewRequestWithContext(ctx, proxyMethod, endpoint, r.Body)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the origin returns a redirect, say 301, will this natively follow the 301 to resolution or just return the 301 back to the client to follow?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this does not make the call this is just creating the request object (the configuration object you were suggesting)
It is rewriting the endpoint path and method but it also needs to "clone" quite a few things manually such as headers and body

Copy link
Contributor Author

@acabarbaye acabarbaye Oct 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you need to look into the implementation of the newRequest to see what is actually done. It is trying to do things a bit like https://cs.opensource.google/go/go/+/refs/tags/go1.25.3:src/net/http/request.go;l=386 but the request object does not expose all the fields and so, these are weird workarounds to make sure everything is correctly set

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not ideal and fairly unnecessary complex but was the only way to cover all the corner-cases I encountered

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is left to the client to follow redirect

if err != nil {
err = commonerrors.WrapError(commonerrors.ErrUnexpected, err, "could not create a proxied request")
return
}

if proxiedRequest.ContentLength <= 0 {
if proxiedRequest.Body == nil || proxiedRequest.Body == http.NoBody {
if contentLength > 0 {
// In this case, NewRequestWithContext does not understand/expect the request body type (not a string/byte buffer as it may be wrapped into a bigger structure) and so, the body of the proxied request is set to nil
// This makes sure this does not happen without performing a copy of the body and the use of unnecessary memory.
proxiedRequest.Body = r.Body
proxiedRequest.GetBody = r.GetBody
} else {
// In this case, it will attempt a copy of the request body which should not be costly as the request is unlikely to have a body. Although it may still do as contentlength may not have actually been evaluated. However, we want to make sure it is set to the same type as the original request.
proxiedRequest, err = http.NewRequestWithContext(ctx, proxyMethod, endpoint, convertBody(ctx, r.Body))
if err != nil {
err = commonerrors.WrapError(commonerrors.ErrUnexpected, err, "could not create a proxied request")
return
}
}
} else {
// In this case, the original request is unlikely to have a body but we want to make sure that the body is of the same type.
if contentLength <= 0 {
proxiedRequest, err = http.NewRequestWithContext(ctx, proxyMethod, endpoint, convertBody(ctx, r.Body))
if err != nil {
err = commonerrors.WrapError(commonerrors.ErrUnexpected, err, "could not create a proxied request")
return
}
}
}
if contentLength > 0 && proxiedRequest.ContentLength <= 0 {
proxiedRequest.ContentLength = contentLength
h.AppendHeader(headers.ContentLength, strconv.FormatInt(contentLength, 10))
}
}
if contentLength > 0 && contentLength != proxiedRequest.ContentLength {
err = commonerrors.Newf(commonerrors.ErrUnexpected, "proxied request does not have the same content length `%v` as original request `%v`", proxiedRequest.ContentLength, contentLength)
return
}
h.AppendToRequest(proxiedRequest)
return
}

func determineRequestContentLength(r *http.Request) int64 {
if reflection.IsEmpty(r) {
return -1
}
if r.ContentLength > 0 {
return r.ContentLength
}
// Following what was done in https://github.com/luraproject/lura/blob/b9ad9ab654dd6149aeb58a5d6ffe731aba41717e/proxy/http.go#L99C1-L105C4
v := r.Header.Values(headers.ContentLength)
if len(v) == 1 && v[0] != "chunked" {
if size, err := strconv.Atoi(v[0]); err == nil {
return safecast.ToInt64(size)
}
}
return -1
}

func convertBody(_ context.Context, body io.Reader) io.Reader {
if body == nil || body == http.NoBody {
return http.NoBody
}
switch v := body.(type) {
case *bytes.Buffer:
return body
case *bytes.Reader:
return body
case *strings.Reader:
return body
default:
// see example https://github.com/luraproject/lura/blob/b9ad9ab654dd6149aeb58a5d6ffe731aba41717e/proxy/http.go#L73
buf := new(bytes.Buffer)
_, err := buf.ReadFrom(v)
if err != nil {
return http.NoBody
}
if b, ok := body.(io.ReadCloser); ok {
_ = b.Close()
}
return buf
}
}

// ProxyResponse proxies a response to a writer. Headers are sanitised and some headers such as CORS headers will be removed from the response.
func ProxyResponse(ctx context.Context, resp *http.Response, w http.ResponseWriter) (err error) {
if w == nil {
err = commonerrors.UndefinedVariable("response writer")
return
}
if reflection.IsEmpty(resp) {
err = commonerrors.UndefinedVariable("response")
return
}
h := httpheaders.FromResponse(resp)
h.Sanitise()

var written int64
_, err = safeio.CopyDataWithContext(ctx, resp.Body, w)
if resp.Body != nil && resp.Body != http.NoBody {
written, err = safeio.CopyDataWithContext(ctx, resp.Body, w)
if err != nil {
err = commonerrors.DescribeCircumstance(err, "failed copying response body")
}
}
if written >= 0 {
h.AppendHeader(headers.ContentLength, strconv.FormatInt(written, 10))
}
h.RemoveHeaders(ProxyDisallowList...)
h.AppendToResponse(w)
w.WriteHeader(resp.StatusCode)
return
}
164 changes: 164 additions & 0 deletions utils/http/proxy/proxy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
package proxy

import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"

"github.com/go-faker/faker/v4"
"github.com/go-http-utils/headers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/ARM-software/golang-utils/utils/commonerrors"
"github.com/ARM-software/golang-utils/utils/commonerrors/errortest"
"github.com/ARM-software/golang-utils/utils/safecast"
"github.com/ARM-software/golang-utils/utils/safeio"
)

func TestProxy(t *testing.T) {
content := faker.Paragraph()
path := faker.URL()
password := faker.Password()
tests := []struct {
request *http.Request
}{
{
request: httptest.NewRequest(http.MethodGet, faker.URL(), io.NopCloser(strings.NewReader(content))),
},
{
request: httptest.NewRequest(http.MethodGet, faker.URL(), strings.NewReader(content)),
},
{
request: httptest.NewRequest(http.MethodGet, faker.URL(), io.NopCloser(bytes.NewReader([]byte(content)))),
},
{
request: httptest.NewRequest(http.MethodGet, faker.URL(), bytes.NewReader([]byte(content))),
},
{
request: httptest.NewRequest(http.MethodGet, faker.URL(), io.NopCloser(bytes.NewBuffer([]byte(content)))),
},
{
request: httptest.NewRequest(http.MethodGet, faker.URL(), bytes.NewBuffer([]byte(content))),
},
}
for i := range tests {
test := tests[i]
t.Run(strconv.Itoa(i), func(t *testing.T) {
req := test.request
req.Header.Set(headers.AccessControlAllowOrigin, faker.Word())
req.Header.Set(headers.XHTTPMethodOverride, http.MethodPut)
req.Header.Set(headers.Authorization, password)
assert.NotEqual(t, req.URL.String(), path)
_, err := ProxyRequest(nil, http.MethodPost, "/")
errortest.AssertError(t, err, commonerrors.ErrUndefined)
preq, err := ProxyRequest(req, " ", path)
require.NoError(t, err)
require.NotNil(t, preq)
assert.Equal(t, path, preq.URL.String())
assert.Equal(t, http.MethodGet, preq.Method)
assert.NotEmpty(t, preq.Header.Get(headers.AccessControlAllowOrigin))
assert.NotEmpty(t, preq.Header.Get(headers.Authorization))
assert.NotZero(t, preq.ContentLength)
resp := generateTestResponseBasedOnRequest(t, preq)
defer func() {
if resp != nil {
_ = resp.Body.Close()
}
}()
w := httptest.NewRecorder()
require.NoError(t, ProxyResponse(context.Background(), resp, w))
proxiedResp := w.Result()
defer func() { _ = proxiedResp.Body.Close() }()
assert.Empty(t, w.Header().Get(headers.AccessControlAllowOrigin))
assert.Equal(t, http.MethodPut, w.Header().Get(headers.XHTTPMethodOverride))
assert.Equal(t, http.StatusOK, resp.StatusCode)
responseContent, err := safeio.ReadAll(context.Background(), proxiedResp.Body)
require.NoError(t, err)
assert.Equal(t, content, string(responseContent))
})
}
}

func TestEmptyResponse(t *testing.T) {
path := faker.URL()
tests := []struct {
request *http.Request
}{
{
httptest.NewRequest(http.MethodGet, faker.URL(), nil),
},
{
request: httptest.NewRequest(http.MethodGet, faker.URL(), http.NoBody),
},
{
request: httptest.NewRequest(http.MethodGet, faker.URL(), io.NopCloser(http.NoBody)),
},
{
request: httptest.NewRequest(http.MethodGet, faker.URL(), bytes.NewReader(nil)),
},
{
request: httptest.NewRequest(http.MethodGet, faker.URL(), io.NopCloser(bytes.NewBuffer(nil))),
},
{
request: httptest.NewRequest(http.MethodGet, faker.URL(), strings.NewReader("")),
},
{
request: httptest.NewRequest(http.MethodGet, faker.URL(), io.NopCloser(strings.NewReader(""))),
},
}
for i := range tests {
test := tests[i]
t.Run(strconv.Itoa(i), func(t *testing.T) {
req := test.request
assert.NotEqual(t, req.URL.String(), path)
preq, err := ProxyRequest(req, http.MethodPost, path)
require.NoError(t, err)
require.NotNil(t, preq)
assert.Equal(t, path, preq.URL.String())
assert.Equal(t, http.MethodPost, preq.Method)
assert.Zero(t, preq.ContentLength)

resp := generateTestResponseBasedOnRequest(t, preq)
defer func() {
if resp != nil {
_ = resp.Body.Close()
}
}()
w := httptest.NewRecorder()
require.NoError(t, ProxyResponse(context.Background(), resp, w))
require.NoError(t, err)
returnedResp := w.Result()
assert.LessOrEqual(t, returnedResp.ContentLength, safecast.ToInt64(0))
assert.Equal(t, http.StatusOK, returnedResp.StatusCode)
})
}
}

func loopTestHandler(t *testing.T, w http.ResponseWriter, r *http.Request) {
t.Helper()
require.NotNil(t, r)
require.NotNil(t, w)
for k, v := range r.Header {
for h := range v {
w.Header().Add(k, v[h])
}
}
written, err := safeio.CopyDataWithContext(r.Context(), r.Body, w)
require.NoError(t, err)
w.Header().Add(headers.ContentLength, strconv.FormatInt(written, 10))
w.WriteHeader(http.StatusOK)
}

func generateTestResponseBasedOnRequest(t *testing.T, r *http.Request) *http.Response {
t.Helper()
require.NotNil(t, r)
w := httptest.NewRecorder()
loopTestHandler(t, w, r)
return w.Result()
}
Loading