Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions internal/jsonrpc2/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -792,13 +792,9 @@ func (c *Connection) write(ctx context.Context, msg Message) error {
err = c.writer.Write(ctx, msg)
}

// For rejected requests, we don't set the writeErr (which would break the
// connection). They can just be returned to the caller.
if errors.Is(err, ErrRejected) {
return err
}

if err != nil && ctx.Err() == nil {
// For cancelled or rejected requests, we don't set the writeErr (which would
// break the connection). They can just be returned to the caller.
if err != nil && ctx.Err() == nil && !errors.Is(err, ErrRejected) {
// The call to Write failed, and since ctx.Err() is nil we can't attribute
// the failure (even indirectly) to Context cancellation. The writer appears
// to be broken, and future writes are likely to also fail.
Expand Down
2 changes: 1 addition & 1 deletion internal/jsonrpc2/wire.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ var (
// Such failures do not indicate that the connection is broken, but rather
// should be returned to the caller to indicate that the specific request is
// invalid in the current context.
ErrRejected = NewError(-32004, "rejected by transport")
ErrRejected = NewError(-32005, "rejected by transport")
)

const wireVersion = "2.0"
Expand Down
32 changes: 29 additions & 3 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -1657,11 +1657,18 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e

resp, err := c.client.Do(req)
if err != nil {
return fmt.Errorf("%s: %v", requestSummary, err)
// Any error from client.Do means the request didn't reach the server.
// Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr
// and permanently break the connection.
return fmt.Errorf("%w: %s: %v", jsonrpc2.ErrRejected, requestSummary, err)
}

if err := c.checkResponse(requestSummary, resp); err != nil {
c.fail(err)
// Only fail the connection for non-transient errors.
// Transient errors (wrapped with ErrRejected) should not break the connection.
if !errors.Is(err, jsonrpc2.ErrRejected) {
c.fail(err)
}
return err
}

Expand Down Expand Up @@ -1826,8 +1833,13 @@ func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.R
// session is already gone.
return fmt.Errorf("%s: failed to connect (session ID: %v): %w", requestSummary, c.sessionID, errSessionMissing)
}
// Transient server errors (502, 503, 504, 429) should not break the connection.
// Wrap them with ErrRejected so the jsonrpc2 layer doesn't set writeErr.
if isTransientHTTPStatus(resp.StatusCode) {
return fmt.Errorf("%w: %s: %v", jsonrpc2.ErrRejected, requestSummary, http.StatusText(resp.StatusCode))
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("%s: failed to connect: %v", requestSummary, http.StatusText(resp.StatusCode))
return fmt.Errorf("%s: %v", requestSummary, http.StatusText(resp.StatusCode))
}
return nil
}
Expand Down Expand Up @@ -2012,3 +2024,17 @@ func calculateReconnectDelay(attempt int) time.Duration {

return backoffDuration + jitter
}

// isTransientHTTPStatus reports whether the HTTP status code indicates a
// transient server error that should not permanently break the connection.
func isTransientHTTPStatus(statusCode int) bool {
switch statusCode {
case http.StatusInternalServerError, // 500
http.StatusBadGateway, // 502
http.StatusServiceUnavailable, // 503
http.StatusGatewayTimeout, // 504
http.StatusTooManyRequests: // 429
return true
}
return false
}
178 changes: 158 additions & 20 deletions mcp/streamable_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

Expand All @@ -29,13 +30,15 @@ type streamableRequestKey struct {

type header map[string]string

// TODO: replace body and status fields with responseFunc; add helpers to reduce duplication.
type streamableResponse struct {
header header // response headers
status int // or http.StatusOK
body string // or ""
optional bool // if set, request need not be sent
wantProtocolVersion string // if "", unchecked
done chan struct{} // if set, receive from this channel before terminating the request
header header // response headers
status int // or http.StatusOK; ignored if responseFunc is set
body string // or ""; ignored if responseFunc is set
responseFunc func(r *jsonrpc.Request) (string, int) // if set, overrides body and status
optional bool // if set, request need not be sent
wantProtocolVersion string // if "", unchecked
done chan struct{} // if set, receive from this channel before terminating the request
}

type fakeResponses map[streamableRequestKey]*streamableResponse
Expand All @@ -44,17 +47,17 @@ type fakeStreamableServer struct {
t *testing.T
responses fakeResponses

callMu sync.Mutex
calls map[streamableRequestKey]int
calledMu sync.Mutex
called map[streamableRequestKey]bool
}

func (s *fakeStreamableServer) missingRequests() []streamableRequestKey {
s.callMu.Lock()
defer s.callMu.Unlock()
s.calledMu.Lock()
defer s.calledMu.Unlock()

var unused []streamableRequestKey
for k, resp := range s.responses {
if s.calls[k] == 0 && !resp.optional {
if !s.called[k] && !resp.optional {
unused = append(unused, k)
}
}
Expand All @@ -67,6 +70,7 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques
sessionID: req.Header.Get(sessionIDHeader),
lastEventID: req.Header.Get("Last-Event-ID"), // TODO: extract this to a constant, like sessionIDHeader
}
var jsonrpcReq *jsonrpc.Request
if req.Method == http.MethodPost {
body, err := io.ReadAll(req.Body)
if err != nil {
Expand All @@ -82,36 +86,44 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques
}
if r, ok := msg.(*jsonrpc.Request); ok {
key.jsonrpcMethod = r.Method
jsonrpcReq = r
}
}

s.callMu.Lock()
if s.calls == nil {
s.calls = make(map[streamableRequestKey]int)
s.calledMu.Lock()
if s.called == nil {
s.called = make(map[streamableRequestKey]bool)
}
s.calls[key]++
s.callMu.Unlock()
s.called[key] = true
s.calledMu.Unlock()

resp, ok := s.responses[key]
if !ok {
s.t.Errorf("missing response for %v", key)
http.Error(w, "no response", http.StatusInternalServerError)
return
}
for k, v := range resp.header {
w.Header().Set(k, v)
}

// Determine body and status, potentially using responseFunc for dynamic responses.
body := resp.body
status := resp.status
if resp.responseFunc != nil {
body, status = resp.responseFunc(jsonrpcReq)
}
if status == 0 {
status = http.StatusOK
}

for k, v := range resp.header {
w.Header().Set(k, v)
}
w.WriteHeader(status)
w.(http.Flusher).Flush() // flush response headers

if v := req.Header.Get(protocolVersionHeader); v != resp.wantProtocolVersion && resp.wantProtocolVersion != "" {
s.t.Errorf("%v: bad protocol version header: got %q, want %q", key, v, resp.wantProtocolVersion)
}
w.Write([]byte(resp.body))
w.Write([]byte(body))
w.(http.Flusher).Flush() // flush response

if resp.done != nil {
Expand Down Expand Up @@ -555,3 +567,129 @@ data: { "jsonrpc": "2.0", "method": "notifications/message", "params": { "level"
})
}
}

// TestStreamableClientTransientErrors verifies that transient errors (timeouts,
// 5xx HTTP status codes) do not permanently break the client connection.
// This tests the fix for issue #683.
func TestStreamableClientTransientErrors(t *testing.T) {
ctx := context.Background()

tests := []struct {
transientStatus int // HTTP status to return for the transient call
wantCallError bool // whether the transient call should error
wantSessionBroken bool // whether the session should be broken after
wantErrorContains string // substring expected in error message
}{
{
transientStatus: http.StatusServiceUnavailable,
wantCallError: true,
wantSessionBroken: false,
wantErrorContains: "Service Unavailable",
},
{
transientStatus: http.StatusBadGateway,
wantCallError: true,
wantSessionBroken: false,
wantErrorContains: "Bad Gateway",
},
{
transientStatus: http.StatusGatewayTimeout,
wantCallError: true,
wantSessionBroken: false,
wantErrorContains: "Gateway Timeout",
},
{
transientStatus: http.StatusTooManyRequests,
wantCallError: true,
wantSessionBroken: false,
wantErrorContains: "Too Many Requests",
},
{
transientStatus: http.StatusUnauthorized,
wantCallError: true,
wantSessionBroken: true,
wantErrorContains: "Unauthorized",
},
{
transientStatus: http.StatusNotFound,
wantCallError: true,
wantSessionBroken: true,
wantErrorContains: "not found", // NotFound has special handling
},
}

for _, test := range tests {
t.Run(http.StatusText(test.transientStatus), func(t *testing.T) {
var returnedError atomic.Bool
fake := &fakeStreamableServer{
t: t,
responses: fakeResponses{
{"POST", "", methodInitialize, ""}: {
header: header{
"Content-Type": "application/json",
sessionIDHeader: "123",
},
body: jsonBody(t, initResp),
},
{"POST", "123", notificationInitialized, ""}: {
status: http.StatusAccepted,
wantProtocolVersion: latestProtocolVersion,
},
{"GET", "123", "", ""}: {
status: http.StatusMethodNotAllowed,
},
{"POST", "123", methodListTools, ""}: {
header: header{
"Content-Type": "application/json",
sessionIDHeader: "123",
},
responseFunc: func(r *jsonrpc.Request) (string, int) {
// First call returns transient error, subsequent calls succeed.
if !returnedError.Swap(true) && test.transientStatus != 0 {
return "", test.transientStatus
}
return jsonBody(t, resp(r.ID.Raw().(int64), &ListToolsResult{Tools: []*Tool{}}, nil)), 0
},
optional: true,
},
{"DELETE", "123", "", ""}: {optional: true},
},
}

httpServer := httptest.NewServer(fake)
defer httpServer.Close()

transport := &StreamableClientTransport{Endpoint: httpServer.URL}
client := NewClient(testImpl, nil)
session, err := client.Connect(ctx, transport, nil)
if err != nil {
t.Fatalf("Connect failed: %v", err)
}
defer session.Close()

// First call: should trigger transient error.
_, err = session.ListTools(ctx, nil)
if test.wantCallError {
if err == nil {
t.Error("ListTools succeeded unexpectedly, want error")
} else if test.wantErrorContains != "" && !strings.Contains(err.Error(), test.wantErrorContains) {
t.Errorf("ListTools error = %q, want containing %q", err.Error(), test.wantErrorContains)
}
} else if err != nil {
t.Errorf("ListTools failed unexpectedly: %v", err)
}

// Second call: verifies whether the session is still usable.
_, err = session.ListTools(ctx, nil)
if test.wantSessionBroken {
if err == nil {
t.Error("second ListTools succeeded unexpectedly, want session broken")
}
} else {
if err != nil {
t.Errorf("second ListTools failed unexpectedly: %v (session should survive transient errors)", err)
}
}
})
}
}