diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index 9ed924a4..627ffe7b 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -792,13 +792,9 @@ func (c *Connection) write(ctx context.Context, msg Message) error { err = c.writer.Write(ctx, msg) } - // For rejected requests, we don't set the writeErr (which would break the - // connection). They can just be returned to the caller. - if errors.Is(err, ErrRejected) { - return err - } - - if err != nil && ctx.Err() == nil { + // For cancelled or rejected requests, we don't set the writeErr (which would + // break the connection). They can just be returned to the caller. + if err != nil && ctx.Err() == nil && !errors.Is(err, ErrRejected) { // The call to Write failed, and since ctx.Err() is nil we can't attribute // the failure (even indirectly) to Context cancellation. The writer appears // to be broken, and future writes are likely to also fail. diff --git a/internal/jsonrpc2/wire.go b/internal/jsonrpc2/wire.go index 8be2872e..c0a41bff 100644 --- a/internal/jsonrpc2/wire.go +++ b/internal/jsonrpc2/wire.go @@ -47,7 +47,7 @@ var ( // Such failures do not indicate that the connection is broken, but rather // should be returned to the caller to indicate that the specific request is // invalid in the current context. - ErrRejected = NewError(-32004, "rejected by transport") + ErrRejected = NewError(-32005, "rejected by transport") ) const wireVersion = "2.0" diff --git a/mcp/streamable.go b/mcp/streamable.go index d46bd2f4..b4b2fa31 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1657,11 +1657,18 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e resp, err := c.client.Do(req) if err != nil { - return fmt.Errorf("%s: %v", requestSummary, err) + // An error from client.Do means no usable HTTP response was obtained; the + // request may or may not have reached the server. Wrap with ErrRejected so the + // jsonrpc2 connection doesn't set writeErr and permanently break the connection. 
+ return fmt.Errorf("%w: %s: %v", jsonrpc2.ErrRejected, requestSummary, err) } if err := c.checkResponse(requestSummary, resp); err != nil { - c.fail(err) + // Only fail the connection for non-transient errors. + // Transient errors (wrapped with ErrRejected) should not break the connection. + if !errors.Is(err, jsonrpc2.ErrRejected) { + c.fail(err) + } return err } @@ -1826,8 +1833,13 @@ func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.R // session is already gone. return fmt.Errorf("%s: failed to connect (session ID: %v): %w", requestSummary, c.sessionID, errSessionMissing) } + // Transient HTTP statuses (429, 500, 502, 503, 504) should not break the connection. + // Wrap them with ErrRejected so the jsonrpc2 layer doesn't set writeErr. + if isTransientHTTPStatus(resp.StatusCode) { + return fmt.Errorf("%w: %s: %v", jsonrpc2.ErrRejected, requestSummary, http.StatusText(resp.StatusCode)) + } if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return fmt.Errorf("%s: failed to connect: %v", requestSummary, http.StatusText(resp.StatusCode)) + return fmt.Errorf("%s: %v", requestSummary, http.StatusText(resp.StatusCode)) } return nil } @@ -2012,3 +2024,17 @@ func calculateReconnectDelay(attempt int) time.Duration { return backoffDuration + jitter } + +// isTransientHTTPStatus reports whether the HTTP status code indicates a +// transient server error that should not permanently break the connection. 
+func isTransientHTTPStatus(statusCode int) bool { + switch statusCode { + case http.StatusInternalServerError, // 500 + http.StatusBadGateway, // 502 + http.StatusServiceUnavailable, // 503 + http.StatusGatewayTimeout, // 504 + http.StatusTooManyRequests: // 429 + return true + } + return false +} diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index dcdda322..e2923325 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -12,6 +12,7 @@ import ( "net/http/httptest" "strings" "sync" + "sync/atomic" "testing" "time" @@ -29,13 +30,15 @@ type streamableRequestKey struct { type header map[string]string +// TODO: replace body and status fields with responseFunc; add helpers to reduce duplication. type streamableResponse struct { - header header // response headers - status int // or http.StatusOK - body string // or "" - optional bool // if set, request need not be sent - wantProtocolVersion string // if "", unchecked - done chan struct{} // if set, receive from this channel before terminating the request + header header // response headers + status int // or http.StatusOK; ignored if responseFunc is set + body string // or ""; ignored if responseFunc is set + responseFunc func(r *jsonrpc.Request) (string, int) // if set, overrides body and status + optional bool // if set, request need not be sent + wantProtocolVersion string // if "", unchecked + done chan struct{} // if set, receive from this channel before terminating the request } type fakeResponses map[streamableRequestKey]*streamableResponse @@ -44,17 +47,17 @@ type fakeStreamableServer struct { t *testing.T responses fakeResponses - callMu sync.Mutex - calls map[streamableRequestKey]int + calledMu sync.Mutex + called map[streamableRequestKey]bool } func (s *fakeStreamableServer) missingRequests() []streamableRequestKey { - s.callMu.Lock() - defer s.callMu.Unlock() + s.calledMu.Lock() + defer s.calledMu.Unlock() var unused []streamableRequestKey for k, resp 
:= range s.responses { - if s.calls[k] == 0 && !resp.optional { + if !s.called[k] && !resp.optional { unused = append(unused, k) } } @@ -67,6 +70,7 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques sessionID: req.Header.Get(sessionIDHeader), lastEventID: req.Header.Get("Last-Event-ID"), // TODO: extract this to a constant, like sessionIDHeader } + var jsonrpcReq *jsonrpc.Request if req.Method == http.MethodPost { body, err := io.ReadAll(req.Body) if err != nil { @@ -82,15 +86,16 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques } if r, ok := msg.(*jsonrpc.Request); ok { key.jsonrpcMethod = r.Method + jsonrpcReq = r } } - s.callMu.Lock() - if s.calls == nil { - s.calls = make(map[streamableRequestKey]int) + s.calledMu.Lock() + if s.called == nil { + s.called = make(map[streamableRequestKey]bool) } - s.calls[key]++ - s.callMu.Unlock() + s.called[key] = true + s.calledMu.Unlock() resp, ok := s.responses[key] if !ok { @@ -98,20 +103,27 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques http.Error(w, "no response", http.StatusInternalServerError) return } - for k, v := range resp.header { - w.Header().Set(k, v) - } + + // Determine body and status, potentially using responseFunc for dynamic responses. 
+ body := resp.body status := resp.status + if resp.responseFunc != nil { + body, status = resp.responseFunc(jsonrpcReq) + } if status == 0 { status = http.StatusOK } + + for k, v := range resp.header { + w.Header().Set(k, v) + } w.WriteHeader(status) w.(http.Flusher).Flush() // flush response headers if v := req.Header.Get(protocolVersionHeader); v != resp.wantProtocolVersion && resp.wantProtocolVersion != "" { s.t.Errorf("%v: bad protocol version header: got %q, want %q", key, v, resp.wantProtocolVersion) } - w.Write([]byte(resp.body)) + w.Write([]byte(body)) w.(http.Flusher).Flush() // flush response if resp.done != nil { @@ -555,3 +567,129 @@ data: { "jsonrpc": "2.0", "method": "notifications/message", "params": { "level" }) } } + +// TestStreamableClientTransientErrors verifies that transient HTTP statuses +// (429, 502, 503, 504) do not permanently break the client connection, while +// non-transient ones (401, 404) do. This tests the fix for issue #683. +func TestStreamableClientTransientErrors(t *testing.T) { + ctx := context.Background() + + tests := []struct { + transientStatus int // HTTP status to return for the transient call + wantCallError bool // whether the transient call should error + wantSessionBroken bool // whether the session should be broken after + wantErrorContains string // substring expected in error message + }{ + { + transientStatus: http.StatusServiceUnavailable, + wantCallError: true, + wantSessionBroken: false, + wantErrorContains: "Service Unavailable", + }, + { + transientStatus: http.StatusBadGateway, + wantCallError: true, + wantSessionBroken: false, + wantErrorContains: "Bad Gateway", + }, + { + transientStatus: http.StatusGatewayTimeout, + wantCallError: true, + wantSessionBroken: false, + wantErrorContains: "Gateway Timeout", + }, + { + transientStatus: http.StatusTooManyRequests, + wantCallError: true, + wantSessionBroken: false, + wantErrorContains: "Too Many Requests", + }, + { + transientStatus: http.StatusUnauthorized, + wantCallError: true, + wantSessionBroken: true, + wantErrorContains: "Unauthorized", + }, + { + transientStatus: http.StatusNotFound, + wantCallError: true, + 
wantSessionBroken: true, + wantErrorContains: "Unauthorized", + }, + { + transientStatus: http.StatusNotFound, + wantCallError: true, + wantSessionBroken: true, + wantErrorContains: "not found", // NotFound has special handling + }, + } + + for _, test := range tests { + t.Run(http.StatusText(test.transientStatus), func(t *testing.T) { + var returnedError atomic.Bool + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize, ""}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + {"POST", "123", notificationInitialized, ""}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + {"GET", "123", "", ""}: { + status: http.StatusMethodNotAllowed, + }, + {"POST", "123", methodListTools, ""}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + responseFunc: func(r *jsonrpc.Request) (string, int) { + // First call returns transient error, subsequent calls succeed. + if !returnedError.Swap(true) && test.transientStatus != 0 { + return "", test.transientStatus + } + return jsonBody(t, resp(r.ID.Raw().(int64), &ListToolsResult{Tools: []*Tool{}}, nil)), 0 + }, + optional: true, + }, + {"DELETE", "123", "", ""}: {optional: true}, + }, + } + + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("Connect failed: %v", err) + } + defer session.Close() + + // First call: should trigger transient error. 
+ _, err = session.ListTools(ctx, nil) + if test.wantCallError { + if err == nil { + t.Error("ListTools succeeded unexpectedly, want error") + } else if test.wantErrorContains != "" && !strings.Contains(err.Error(), test.wantErrorContains) { + t.Errorf("ListTools error = %q, want containing %q", err.Error(), test.wantErrorContains) + } + } else if err != nil { + t.Errorf("ListTools failed unexpectedly: %v", err) + } + + // Second call: verifies whether the session is still usable. + _, err = session.ListTools(ctx, nil) + if test.wantSessionBroken { + if err == nil { + t.Error("second ListTools succeeded unexpectedly, want session broken") + } + } else { + if err != nil { + t.Errorf("second ListTools failed unexpectedly: %v (session should survive transient errors)", err) + } + } + }) + } +}