From 00e20f274682a53c37167f9c46b781efd729245e Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Fri, 12 Dec 2025 10:44:34 -0500 Subject: [PATCH] mcp: don't break the streamable client connection for transient errors When POST requests in the streamableClientConn return a transient error, return this error to the caller rather than permanently breaking the connection. This is achieved by returning the special sentinel ErrRejected error to the jsonrpc2 layer. In doing so, the change revealed a pre-existing bug: ErrRejected had the same code as ErrConnectionClosing, and jsonrpc2.WireError implements errors.Is, so the two sentinel values could be conflated. This is fixed by using a new internal code. The new test required some additional machinery in our fake server: the ability to handle multiple requests to the same logical key. There's more to do for #683: we should also retry transient errors in handleSSE. For #683 --- internal/jsonrpc2/conn.go | 10 +- internal/jsonrpc2/wire.go | 2 +- mcp/streamable.go | 32 +++++- mcp/streamable_client_test.go | 178 ++++++++++++++++++++++++++++++---- 4 files changed, 191 insertions(+), 31 deletions(-) diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index 9ed924a4..627ffe7b 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -792,13 +792,9 @@ func (c *Connection) write(ctx context.Context, msg Message) error { err = c.writer.Write(ctx, msg) } - // For rejected requests, we don't set the writeErr (which would break the - // connection). They can just be returned to the caller. - if errors.Is(err, ErrRejected) { - return err - } - - if err != nil && ctx.Err() == nil { + // For cancelled or rejected requests, we don't set the writeErr (which would + // break the connection). They can just be returned to the caller. 
+ if err != nil && ctx.Err() == nil && !errors.Is(err, ErrRejected) { // The call to Write failed, and since ctx.Err() is nil we can't attribute // the failure (even indirectly) to Context cancellation. The writer appears // to be broken, and future writes are likely to also fail. diff --git a/internal/jsonrpc2/wire.go b/internal/jsonrpc2/wire.go index 8be2872e..c0a41bff 100644 --- a/internal/jsonrpc2/wire.go +++ b/internal/jsonrpc2/wire.go @@ -47,7 +47,7 @@ var ( // Such failures do not indicate that the connection is broken, but rather // should be returned to the caller to indicate that the specific request is // invalid in the current context. - ErrRejected = NewError(-32004, "rejected by transport") + ErrRejected = NewError(-32005, "rejected by transport") ) const wireVersion = "2.0" diff --git a/mcp/streamable.go b/mcp/streamable.go index d46bd2f4..b4b2fa31 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1657,11 +1657,18 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e resp, err := c.client.Do(req) if err != nil { - return fmt.Errorf("%s: %v", requestSummary, err) + // An error from client.Do means no usable response was obtained (the request + // may or may not have reached the server). Wrap with ErrRejected so the + // jsonrpc2 connection doesn't set writeErr and permanently break the connection. + return fmt.Errorf("%w: %s: %v", jsonrpc2.ErrRejected, requestSummary, err) } if err := c.checkResponse(requestSummary, resp); err != nil { - c.fail(err) + // Only fail the connection for non-transient errors. + // Transient errors (wrapped with ErrRejected) should not break the connection. + if !errors.Is(err, jsonrpc2.ErrRejected) { + c.fail(err) + } return err } @@ -1826,8 +1833,13 @@ func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.R // session is already gone. 
return fmt.Errorf("%s: failed to connect (session ID: %v): %w", requestSummary, c.sessionID, errSessionMissing) } + // Transient server errors (500, 502, 503, 504, 429) should not break the connection. + // Wrap them with ErrRejected so the jsonrpc2 layer doesn't set writeErr. + if isTransientHTTPStatus(resp.StatusCode) { + return fmt.Errorf("%w: %s: %v", jsonrpc2.ErrRejected, requestSummary, http.StatusText(resp.StatusCode)) + } if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return fmt.Errorf("%s: failed to connect: %v", requestSummary, http.StatusText(resp.StatusCode)) + return fmt.Errorf("%s: %v", requestSummary, http.StatusText(resp.StatusCode)) } return nil } @@ -2012,3 +2024,17 @@ func calculateReconnectDelay(attempt int) time.Duration { return backoffDuration + jitter } + +// isTransientHTTPStatus reports whether the HTTP status code indicates a +// transient server error that should not permanently break the connection. +func isTransientHTTPStatus(statusCode int) bool { + switch statusCode { + case http.StatusInternalServerError, // 500 + http.StatusBadGateway, // 502 + http.StatusServiceUnavailable, // 503 + http.StatusGatewayTimeout, // 504 + http.StatusTooManyRequests: // 429 + return true + } + return false +} diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index dcdda322..e2923325 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -12,6 +12,7 @@ import ( "net/http/httptest" "strings" "sync" + "sync/atomic" "testing" "time" @@ -29,13 +30,15 @@ type streamableRequestKey struct { type header map[string]string +// TODO: replace body and status fields with responseFunc; add helpers to reduce duplication. 
type streamableResponse struct { - header header // response headers - status int // or http.StatusOK - body string // or "" - optional bool // if set, request need not be sent - wantProtocolVersion string // if "", unchecked - done chan struct{} // if set, receive from this channel before terminating the request + header header // response headers + status int // or http.StatusOK; ignored if responseFunc is set + body string // or ""; ignored if responseFunc is set + responseFunc func(r *jsonrpc.Request) (string, int) // if set, overrides body and status + optional bool // if set, request need not be sent + wantProtocolVersion string // if "", unchecked + done chan struct{} // if set, receive from this channel before terminating the request } type fakeResponses map[streamableRequestKey]*streamableResponse @@ -44,17 +47,17 @@ type fakeStreamableServer struct { t *testing.T responses fakeResponses - callMu sync.Mutex - calls map[streamableRequestKey]int + calledMu sync.Mutex + called map[streamableRequestKey]bool } func (s *fakeStreamableServer) missingRequests() []streamableRequestKey { - s.callMu.Lock() - defer s.callMu.Unlock() + s.calledMu.Lock() + defer s.calledMu.Unlock() var unused []streamableRequestKey for k, resp := range s.responses { - if s.calls[k] == 0 && !resp.optional { + if !s.called[k] && !resp.optional { unused = append(unused, k) } } @@ -67,6 +70,7 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques sessionID: req.Header.Get(sessionIDHeader), lastEventID: req.Header.Get("Last-Event-ID"), // TODO: extract this to a constant, like sessionIDHeader } + var jsonrpcReq *jsonrpc.Request if req.Method == http.MethodPost { body, err := io.ReadAll(req.Body) if err != nil { @@ -82,15 +86,16 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques } if r, ok := msg.(*jsonrpc.Request); ok { key.jsonrpcMethod = r.Method + jsonrpcReq = r } } - s.callMu.Lock() - if s.calls == nil { - s.calls = 
make(map[streamableRequestKey]int) + s.calledMu.Lock() + if s.called == nil { + s.called = make(map[streamableRequestKey]bool) } - s.calls[key]++ - s.callMu.Unlock() + s.called[key] = true + s.calledMu.Unlock() resp, ok := s.responses[key] if !ok { @@ -98,20 +103,27 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques http.Error(w, "no response", http.StatusInternalServerError) return } - for k, v := range resp.header { - w.Header().Set(k, v) - } + + // Determine body and status, potentially using responseFunc for dynamic responses. + body := resp.body status := resp.status + if resp.responseFunc != nil { + body, status = resp.responseFunc(jsonrpcReq) + } if status == 0 { status = http.StatusOK } + + for k, v := range resp.header { + w.Header().Set(k, v) + } w.WriteHeader(status) w.(http.Flusher).Flush() // flush response headers if v := req.Header.Get(protocolVersionHeader); v != resp.wantProtocolVersion && resp.wantProtocolVersion != "" { s.t.Errorf("%v: bad protocol version header: got %q, want %q", key, v, resp.wantProtocolVersion) } - w.Write([]byte(resp.body)) + w.Write([]byte(body)) w.(http.Flusher).Flush() // flush response if resp.done != nil { @@ -555,3 +567,129 @@ data: { "jsonrpc": "2.0", "method": "notifications/message", "params": { "level" }) } } + +// TestStreamableClientTransientErrors verifies that transient HTTP statuses +// (429, 502, 503, 504) do not permanently break the client connection, while +// non-transient ones (401, 404) do. This tests the fix for issue #683. 
+func TestStreamableClientTransientErrors(t *testing.T) { + ctx := context.Background() + + tests := []struct { + transientStatus int // HTTP status to return for the transient call + wantCallError bool // whether the transient call should error + wantSessionBroken bool // whether the session should be broken after + wantErrorContains string // substring expected in error message + }{ + { + transientStatus: http.StatusServiceUnavailable, + wantCallError: true, + wantSessionBroken: false, + wantErrorContains: "Service Unavailable", + }, + { + transientStatus: http.StatusBadGateway, + wantCallError: true, + wantSessionBroken: false, + wantErrorContains: "Bad Gateway", + }, + { + transientStatus: http.StatusGatewayTimeout, + wantCallError: true, + wantSessionBroken: false, + wantErrorContains: "Gateway Timeout", + }, + { + transientStatus: http.StatusTooManyRequests, + wantCallError: true, + wantSessionBroken: false, + wantErrorContains: "Too Many Requests", + }, + { + transientStatus: http.StatusUnauthorized, + wantCallError: true, + wantSessionBroken: true, + wantErrorContains: "Unauthorized", + }, + { + transientStatus: http.StatusNotFound, + wantCallError: true, + wantSessionBroken: true, + wantErrorContains: "not found", // NotFound has special handling + }, + } + + for _, test := range tests { + t.Run(http.StatusText(test.transientStatus), func(t *testing.T) { + var returnedError atomic.Bool + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize, ""}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + {"POST", "123", notificationInitialized, ""}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + {"GET", "123", "", ""}: { + status: http.StatusMethodNotAllowed, + }, + {"POST", "123", methodListTools, ""}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + responseFunc: 
func(r *jsonrpc.Request) (string, int) { + // First call returns transient error, subsequent calls succeed. + if !returnedError.Swap(true) && test.transientStatus != 0 { + return "", test.transientStatus + } + return jsonBody(t, resp(r.ID.Raw().(int64), &ListToolsResult{Tools: []*Tool{}}, nil)), 0 + }, + optional: true, + }, + {"DELETE", "123", "", ""}: {optional: true}, + }, + } + + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("Connect failed: %v", err) + } + defer session.Close() + + // First call: should trigger transient error. + _, err = session.ListTools(ctx, nil) + if test.wantCallError { + if err == nil { + t.Error("ListTools succeeded unexpectedly, want error") + } else if test.wantErrorContains != "" && !strings.Contains(err.Error(), test.wantErrorContains) { + t.Errorf("ListTools error = %q, want containing %q", err.Error(), test.wantErrorContains) + } + } else if err != nil { + t.Errorf("ListTools failed unexpectedly: %v", err) + } + + // Second call: verifies whether the session is still usable. + _, err = session.ListTools(ctx, nil) + if test.wantSessionBroken { + if err == nil { + t.Error("second ListTools succeeded unexpectedly, want session broken") + } + } else { + if err != nil { + t.Errorf("second ListTools failed unexpectedly: %v (session should survive transient errors)", err) + } + } + }) + } +}