From f60d4cefa8521af42e0494e42c5a0b4c1e786acc Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Fri, 21 Nov 2025 18:35:44 +0000 Subject: [PATCH 1/2] mcp: support SSE polling and server disconnects (SEP-1699) This CL implements support for SEP-1699, adding the following new API: - Event.Retry corresponds to the 'retry' field of a server sent event. This is used to signal client retry behavior. - RequestExtra.CloseStream is set by the streamable transport to allow closing a stream. The streamable server transport is updated to set CloseStream for calls, allowing server operations or middleware to close the underlying stream. If the configured 'reconnectAfter' delay is set, a 'retry' message is sent prior to terminating the stream. The streamable client transport is updated to read the configured delay. Fixes #630 --- docs/rough_edges.md | 4 + internal/docs/rough_edges.src.md | 4 + mcp/event.go | 23 ++-- mcp/shared.go | 8 ++ mcp/streamable.go | 161 ++++++++++++++++++++---- mcp/streamable_test.go | 205 +++++++++++++++++++++++++++---- 6 files changed, 341 insertions(+), 64 deletions(-) diff --git a/docs/rough_edges.md b/docs/rough_edges.md index 41888f80..5c732bdf 100644 --- a/docs/rough_edges.md +++ b/docs/rough_edges.md @@ -11,6 +11,10 @@ v2. **Workaround**: `Open` may be implemented as a no-op. +- `Event` need not have been exported: it's an implementation detail of the SSE + and streamable transports. Also the 'Name' field is a misnomer: it should be + 'event'. + - Enforcing valid tool names: with [SEP-986](https://github.com/modelcontextprotocol/modelcontextprotocol/issues/986) landing after the SDK was at v1, we missed an opportunity to panic on invalid diff --git a/internal/docs/rough_edges.src.md b/internal/docs/rough_edges.src.md index fb52391d..ff95b263 100644 --- a/internal/docs/rough_edges.src.md +++ b/internal/docs/rough_edges.src.md @@ -10,6 +10,10 @@ v2. **Workaround**: `Open` may be implemented as a no-op. +- `Event` need not have been exported: it's an implementation detail of the SSE + and streamable transports. Also the 'Name' field is a misnomer: it should be + 'event'. + - Enforcing valid tool names: with [SEP-986](https://github.com/modelcontextprotocol/modelcontextprotocol/issues/986) landing after the SDK was at v1, we missed an opportunity to panic on invalid diff --git a/mcp/event.go b/mcp/event.go index 281f5925..5c322c4a 100644 --- a/mcp/event.go +++ b/mcp/event.go @@ -29,14 +29,15 @@ const validateMemoryEventStore = false // An Event is a server-sent event. // See https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#fields. type Event struct { - Name string // the "event" field - ID string // the "id" field - Data []byte // the "data" field + Name string // the "event" field + ID string // the "id" field + Data []byte // the "data" field + Retry string // the "retry" field } // Empty reports whether the Event is empty. func (e Event) Empty() bool { - return e.Name == "" && e.ID == "" && len(e.Data) == 0 + return e.Name == "" && e.ID == "" && len(e.Data) == 0 && e.Retry == "" } // writeEvent writes the event to w, and flushes. @@ -48,6 +49,9 @@ func writeEvent(w io.Writer, evt Event) (int, error) { if evt.ID != "" { fmt.Fprintf(&b, "id: %s\n", evt.ID) } + if evt.Retry != "" { + fmt.Fprintf(&b, "retry: %s\n", evt.Retry) + } fmt.Fprintf(&b, "data: %s\n\n", string(evt.Data)) n, err := w.Write(b.Bytes()) if f, ok := w.(http.Flusher); ok { @@ -73,6 +77,7 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] { eventKey = []byte("event") idKey = []byte("id") dataKey = []byte("data") + retryKey = []byte("retry") ) return func(yield func(Event, error) bool) { @@ -119,6 +124,8 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] { evt.Name = strings.TrimSpace(string(after)) case bytes.Equal(before, idKey): evt.ID = strings.TrimSpace(string(after)) + case bytes.Equal(before, retryKey): + evt.Retry = strings.TrimSpace(string(after)) case bytes.Equal(before, dataKey): data := bytes.TrimSpace(after) if dataBuf != nil { @@ -191,12 +198,8 @@ type dataList struct { } func (dl *dataList) appendData(d []byte) { - // If we allowed empty data, we would consume memory without incrementing the size. - // We could of course account for that, but we keep it simple and assume there is no - // empty data. - if len(d) == 0 { - panic("empty data item") - } + // Empty data consumes memory but doesn't increment size. However, it should + // be rare. dl.data = append(dl.data, d) dl.size += len(d) } diff --git a/mcp/shared.go b/mcp/shared.go index 3fac40b2..459b917f 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -432,6 +432,14 @@ type ServerRequest[P Params] struct { type RequestExtra struct { TokenInfo *auth.TokenInfo // bearer token info (e.g. from OAuth) if any Header http.Header // header from HTTP request, if any + + // CloseStream closes the current request stream, if the current transport + // supports replaying requests. + // + // If reconnectAfter is nonzero, it signals to the client to reconnect after + // the given duration. Otherwise, clients may determine their own + // reconnection policy. + CloseStream func(reconnectAfter time.Duration) } func (*ClientRequest[P]) isRequest() {} diff --git a/mcp/streamable.go b/mcp/streamable.go index 67cdc390..21c73848 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -602,12 +602,25 @@ type stream struct { // HTTP connection acquires ownership of the stream by setting this field. deliver func(data []byte, final bool) error + // closeLocked sends a 'close' event to the client with a configurable retry + // delay, if there is a delivery channel available. The stream must be + // locked. + closeLocked func(reconnectAfter time.Duration) + // streamRequests is the set of unanswered incoming requests for the stream. // // Requests are removed when their response has been received. requests map[jsonrpc.ID]struct{} } +func (s *stream) close(reconnectAfter time.Duration) { + s.mu.Lock() + defer s.mu.Unlock() + if s.closeLocked != nil { + s.closeLocked(reconnectAfter) + } +} + // doneLocked reports whether the stream is logically complete. // // s.mu must be held while calling this function. @@ -704,6 +717,7 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request defer func() { stream.mu.Lock() stream.deliver = nil + stream.closeLocked = nil stream.mu.Unlock() }() @@ -720,14 +734,12 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request // writeEvent writes an SSE event to w corresponding to the given stream, data, and index. // lastIdx is incremented before writing, so that it continues to point to the index of the // last event written to the stream. -func (c *streamableServerConn) writeEvent(w http.ResponseWriter, stream *stream, data []byte, lastIdx *int) error { +func (c *streamableServerConn) writeEvent(w http.ResponseWriter, streamID string, e Event, lastIdx *int) error { *lastIdx++ - e := Event{ - Name: "message", - Data: data, - } if c.eventStore != nil { - e.ID = formatEventID(stream.id, *lastIdx) + // TODO(rfindley): this isn't quite right: we should only set the ID if the + // message was actually stored successfully. + e.ID = formatEventID(streamID, *lastIdx) } if _, err := writeEvent(w, e); err != nil { return err @@ -735,6 +747,19 @@ func (c *streamableServerConn) writeEvent(w http.ResponseWriter, stream *stream, return nil } +// writeCloseEvent writes a 'close' event to the stream, signaling to the +// caller that they should reconnect after the configured delay. +func (c *streamableServerConn) writeCloseEvent(w http.ResponseWriter, reconnectAfter time.Duration) { + reconnectStr := strconv.FormatInt(reconnectAfter.Milliseconds(), 10) + // Note: this event is not stored, since we don't want or need to replay it. + if _, err := writeEvent(w, Event{ + Name: "close", // don't make this empty, as the default event type is "message" + Retry: reconnectStr, + }); err != nil { + c.logger.Warn(fmt.Sprintf("Writing close event: %v", err)) + } +} + // acquireStream acquires the stream and replays all events since lastIdx, if // any, updating lastIdx accordingly. If non-nil, the resulting stream will be // registered for receiving new messages, and the resulting done channel will @@ -805,7 +830,9 @@ func (c *streamableServerConn) acquireStream(ctx context.Context, w http.Respons http.Error(w, "failed to replay events", http.StatusBadRequest) return nil, nil } - toReplay = append(toReplay, data) + if len(data) > 0 { + toReplay = append(toReplay, data) + } } } @@ -823,7 +850,7 @@ func (c *streamableServerConn) acquireStream(ctx context.Context, w http.Respons } for _, data := range toReplay { - if err := c.writeEvent(w, s, data, lastIdx); err != nil { + if err := c.writeEvent(w, s.id, Event{Name: "message", Data: data}, lastIdx); err != nil { return nil, nil } } @@ -835,16 +862,31 @@ func (c *streamableServerConn) acquireStream(ctx context.Context, w http.Respons // The stream is not done: register a delivery function before the stream is // unlocked, allowing the connection to write new events. - done := make(chan struct{}) + var done = make(chan struct{}) s.deliver = func(data []byte, final bool) error { + select { + case <-done: + return fmt.Errorf("stream closed") + default: + } if err := ctx.Err(); err != nil { return err } - err := c.writeEvent(w, s, data, lastIdx) if final { - close(done) + defer close(done) } - return err + return c.writeEvent(w, s.id, Event{Name: "message", Data: data}, lastIdx) + } + s.closeLocked = func(reconnectAfter time.Duration) { + select { + case <-done: + return + default: + } + if reconnectAfter > 0 { + c.writeCloseEvent(w, reconnectAfter) + } + close(done) } return s, done } @@ -914,6 +956,19 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques } if jreq.IsCall() { calls[jreq.ID] = struct{}{} + jreq.Extra.(*RequestExtra).CloseStream = func(reconnectAfter time.Duration) { + c.mu.Lock() + streamID, ok := c.requestStreams[jreq.ID] + var stream *stream + if ok { + stream = c.streams[streamID] + } + c.mu.Unlock() + + if stream != nil { + stream.close(reconnectAfter) + } + } } } } @@ -993,11 +1048,39 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques } else { // Write events in the order we receive them. lastIndex := -1 + if c.eventStore != nil { + // Write a priming event. + // We must also write it to the event store in order for indexes to + // align. + if err := c.eventStore.Append(req.Context(), c.sessionID, stream.id, nil); err != nil { + c.logger.Warn(fmt.Sprintf("Storing priming event: %v", err)) + } + if err := c.writeEvent(w, stream.id, Event{Name: "prime"}, &lastIndex); err != nil { + c.logger.Warn(fmt.Sprintf("Writing priming event: %v", err)) + } + } + stream.deliver = func(data []byte, final bool) error { + select { + case <-done: + return fmt.Errorf("stream closed") + default: + } if final { defer close(done) } - return c.writeEvent(w, stream, data, &lastIndex) + return c.writeEvent(w, stream.id, Event{Name: "message", Data: data}, &lastIndex) + } + stream.closeLocked = func(reconnectAfter time.Duration) { + select { + case <-done: + return + default: + } + if reconnectAfter > 0 { + c.writeCloseEvent(w, reconnectAfter) + } + close(done) } } @@ -1007,6 +1090,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques // TODO(rfindley): if we have no event store, we should really cancel all // remaining requests here, since the client will never get the results. stream.deliver = nil + stream.closeLocked = nil stream.mu.Unlock() }() @@ -1187,22 +1271,25 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e } delivered := false + var errs []error + // TODO(rfindley): we should only append if the response is SSE, not JSON, by + // pushing down into the delivery layer. if c.eventStore != nil { if err := c.eventStore.Append(ctx, c.sessionID, s.id, data); err != nil { - // TODO: report a side-channel error. + errs = append(errs, err) } else { delivered = true } } if s.deliver != nil { if err := s.deliver(data, s.doneLocked()); err != nil { - // TODO: report a side-channel error. + errs = append(errs, err) } else { delivered = true } } if !delivered { - return fmt.Errorf("%w: undelivered message", jsonrpc2.ErrRejected) + return fmt.Errorf("%w: undelivered message: %v", jsonrpc2.ErrRejected, errors.Join(errs...)) } return nil } @@ -1360,7 +1447,7 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) { } func (c *streamableClientConn) connectStandaloneSSE() { - resp, err := c.connectSSE("") + resp, err := c.connectSSE("", 0) if err != nil { c.fail(fmt.Errorf("standalone SSE request failed (session ID: %v): %v", c.sessionID, err)) return @@ -1589,7 +1676,7 @@ func (c *streamableClientConn) handleSSE(requestSummary string, resp *http.Respo // // Eventually, if we don't get the response, we should stop trying and // fail the request. - lastEventID, clientClosed := c.processStream(requestSummary, resp, forCall) + lastEventID, reconnectDelay, clientClosed := c.processStream(requestSummary, resp, forCall) // If the connection was closed by the client, we're done. if clientClosed { @@ -1602,7 +1689,7 @@ func (c *streamableClientConn) handleSSE(requestSummary string, resp *http.Respo } // The stream was interrupted or ended by the server. Attempt to reconnect. - newResp, err := c.connectSSE(lastEventID) + newResp, err := c.connectSSE(lastEventID, reconnectDelay) if err != nil { // All reconnection attempts failed: fail the connection. c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err)) @@ -1644,7 +1731,7 @@ func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.R // incoming channel. It returns the ID of the last processed event and a flag // indicating if the connection was closed by the client. If resp is nil, it // returns "", false. -func (c *streamableClientConn) processStream(requestSummary string, resp *http.Response, forCall *jsonrpc.Request) (lastEventID string, clientClosed bool) { +func (c *streamableClientConn) processStream(requestSummary string, resp *http.Response, forCall *jsonrpc.Request) (lastEventID string, reconnectDelay time.Duration, clientClosed bool) { defer resp.Body.Close() for evt, err := range scanEvents(resp.Body) { if err != nil { @@ -1656,7 +1743,11 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R lastEventID = evt.ID } - // Skip non-message events (e.g., "ping" events used for keep-alive) + if evt.Retry != "" { + if n, err := strconv.ParseInt(evt.Retry, 10, 64); err == nil { + reconnectDelay = time.Duration(n) * time.Millisecond + } + } // According to SSE spec, events with no name default to "message" if evt.Name != "" && evt.Name != "message" { continue @@ -1665,7 +1756,7 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R msg, err := jsonrpc.DecodeMessage(evt.Data) if err != nil { c.fail(fmt.Errorf("%s: failed to decode event: %v", requestSummary, err)) - return "", true + return "", 0, true } select { @@ -1674,12 +1765,12 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R // TODO: we should never get a response when forReq is nil (the standalone SSE request). // We should detect this case. if jsonResp.ID == forCall.ID { - return "", true + return "", 0, true } } case <-c.done: // The connection was closed by the client; exit gracefully. - return "", true + return "", 0, true } } // The loop finished without an error, indicating the server closed the stream. @@ -1696,7 +1787,7 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R case <-c.done: } } - return lastEventID, false + return lastEventID, reconnectDelay, false } // connectSSE handles the logic of connecting a text/event-stream connection. @@ -1706,7 +1797,10 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R // If connection fails, connectSSE retries with an exponential backoff // strategy. It returns a new, valid HTTP response if successful, or an error // if all retries are exhausted. -func (c *streamableClientConn) connectSSE(lastEventID string) (*http.Response, error) { +// +// reconnectDelay is the delay set by the server using the SSE retry field, or +// 0. +func (c *streamableClientConn) connectSSE(lastEventID string, reconnectDelay time.Duration) (*http.Response, error) { var finalErr error // If lastEventID is set, we've already connected successfully once, so // consider that to be the first attempt. @@ -1714,11 +1808,23 @@ func (c *streamableClientConn) connectSSE(lastEventID string) (*http.Response, e if lastEventID != "" { attempt = 1 } + delay := calculateReconnectDelay(attempt) + if reconnectDelay > 0 { + delay = reconnectDelay // honor the server's requested initial delay + } for ; attempt <= c.maxRetries; attempt++ { select { case <-c.done: return nil, fmt.Errorf("connection closed by client during reconnect") - case <-time.After(calculateReconnectDelay(attempt)): + case <-c.ctx.Done(): + // If the connection context is canceled, the request below will not + // succeed anyway. + // + // TODO(#662): we should not be using the connection context for + // reconnection: we should instead be using the call context (from + // Write). + return nil, fmt.Errorf("connection context closed") + case <-time.After(delay): req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.url, nil) if err != nil { return nil, err @@ -1731,6 +1837,7 @@ func (c *streamableClientConn) connectSSE(lastEventID string) (*http.Response, e resp, err := c.client.Do(req) if err != nil { finalErr = err // Store the error and try again. + delay = calculateReconnectDelay(attempt + 1) continue } return resp, nil diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index e9c0cbda..2298e653 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -517,6 +517,94 @@ func testClientReplay(t *testing.T, test clientReplayTest) { } } +func TestStreamableServerDisconnect(t *testing.T) { + server := NewServer(testImpl, nil) + + // Test that client replayability allows the server to terminate incoming + // requests immediately, and have the client replay them. + + // testStream exercises stream resumption by interleaving stream termination + // with progress notifications. + testStream := func(ctx context.Context, session *ServerSession, extra *RequestExtra) { + // Close the stream before the first message. We should have sent an + // initial priming message already, so the client will be able to replay + extra.CloseStream(10 * time.Millisecond) + session.NotifyProgress(ctx, &ProgressNotificationParams{Message: "msg1"}) + time.Sleep(20 * time.Millisecond) + extra.CloseStream(10 * time.Millisecond) // Closing twice should still be supported. + session.NotifyProgress(ctx, &ProgressNotificationParams{Message: "msg2"}) + } + + AddTool(server, &Tool{Name: "disconnect"}, + func(ctx context.Context, req *CallToolRequest, args map[string]any) (*CallToolResult, map[string]any, error) { + testStream(ctx, req.Session, req.Extra) + return new(CallToolResult), nil, nil + }) + + server.AddPrompt(&Prompt{Name: "disconnect"}, func(ctx context.Context, req *GetPromptRequest) (*GetPromptResult, error) { + testStream(ctx, req.Session, req.Extra) + return nil, nil + }) + + tests := []struct { + name string + doCall func(context.Context, *ClientSession) error + }{ + { + "tool", + func(ctx context.Context, cs *ClientSession) error { + _, err := cs.CallTool(ctx, &CallToolParams{Name: "disconnect"}) + return err + }, + }, + { + "prompt", + func(ctx context.Context, cs *ClientSession) error { + _, err := cs.GetPrompt(ctx, &GetPromptParams{Name: "disconnect"}) + return err + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // memory event store to support replayability. + // Then implement the new SEP, and assert that the tool call succeeds. + notifications := make(chan string, 2) + handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ + EventStore: NewMemoryEventStore(nil), + }) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + client := NewClient(testImpl, &ClientOptions{ + ProgressNotificationHandler: func(ctx context.Context, req *ProgressNotificationClientRequest) { + notifications <- req.Params.Message + }, + }) + clientSession, err := client.Connect(ctx, &StreamableClientTransport{ + Endpoint: httpServer.URL, + }, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + defer clientSession.Close() + + if err = test.doCall(ctx, clientSession); err != nil { + t.Fatalf("CallTool failed: %v", err) + } + + got := readNotifications(t, ctx, notifications, 2) + want := []string{"msg1", "msg2"} + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("got unexpected notifications (-want +got):\n%s", diff) + } + }) + } +} + func TestServerTransportCleanup(t *testing.T) { nClient := 3 @@ -691,10 +779,10 @@ func TestStreamableServerTransport(t *testing.T) { tests := []struct { name string - replay bool // if set, use a MemoryEventStore to enable stream replay - tool func(*testing.T, context.Context, *ServerSession) - requests []streamableRequest // http requests - wantSessions int // number of sessions expected after the test + replay bool // if set, use a MemoryEventStore to enable replay + tool func(*testing.T, context.Context, *CallToolRequest) // if set, called during execution + requests []streamableRequest + wantSessions int // number of sessions expected after the test }{ { name: "basic", @@ -818,9 +906,9 @@ func TestStreamableServerTransport(t *testing.T) { }, { name: "tool notification", - tool: func(t *testing.T, ctx context.Context, ss *ServerSession) { + tool: func(t *testing.T, ctx context.Context, req *CallToolRequest) { // Send an arbitrary notification. - if err := ss.NotifyProgress(ctx, &ProgressNotificationParams{}); err != nil { + if err := req.Session.NotifyProgress(ctx, &ProgressNotificationParams{}); err != nil { t.Errorf("Notify failed: %v", err) } }, @@ -843,9 +931,9 @@ func TestStreamableServerTransport(t *testing.T) { }, { name: "tool upcall", - tool: func(t *testing.T, ctx context.Context, ss *ServerSession) { + tool: func(t *testing.T, ctx context.Context, req *CallToolRequest) { // Make an arbitrary call. - if _, err := ss.ListRoots(ctx, &ListRootsParams{}); err != nil { + if _, err := req.Session.ListRoots(ctx, &ListRootsParams{}); err != nil { t.Errorf("Call failed: %v", err) } }, @@ -878,21 +966,23 @@ func TestStreamableServerTransport(t *testing.T) { name: "background", // Enabling replay is necessary here because the standalone "GET" request // is fully asynronous. Replay is needed to guarantee message delivery. + // + // TODO(rfindley): this should no longer be necessary. replay: true, - tool: func(t *testing.T, _ context.Context, ss *ServerSession) { + tool: func(t *testing.T, _ context.Context, req *CallToolRequest) { // Perform operations on a background context, and ensure the client // receives it. ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - if err := ss.NotifyProgress(ctx, &ProgressNotificationParams{}); err != nil { + if err := req.Session.NotifyProgress(ctx, &ProgressNotificationParams{}); err != nil { t.Errorf("Notify failed: %v", err) } // TODO(rfindley): finish implementing logging. // if err := ss.LoggingMessage(ctx, &LoggingMessageParams{}); err != nil { // t.Errorf("Logging failed: %v", err) // } - if _, err := ss.ListRoots(ctx, &ListRootsParams{}); err != nil { + if _, err := req.Session.ListRoots(ctx, &ListRootsParams{}); err != nil { t.Errorf("ListRoots failed: %v", err) } }, @@ -937,6 +1027,60 @@ func TestStreamableServerTransport(t *testing.T) { }, wantSessions: 0, // session deleted }, + { + name: "priming message", + replay: true, + requests: []streamableRequest{ + initialize, + initialized, + { + method: "POST", + messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, + wantBodyContaining: "prime", + }, + }, + wantSessions: 1, + }, + { + name: "close message", + replay: true, + tool: func(t *testing.T, _ context.Context, req *CallToolRequest) { + req.Extra.CloseStream(time.Millisecond) + }, + requests: []streamableRequest{ + initialize, + initialized, + { + method: "POST", + messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, + wantBodyContaining: "close", + }, + }, + wantSessions: 1, + }, + { + name: "no close message", + replay: true, + tool: func(t *testing.T, _ context.Context, req *CallToolRequest) { + req.Extra.CloseStream(0) + }, + requests: []streamableRequest{ + initialize, + initialized, + { + method: "POST", + messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, + wantBodyNotContaining: "close", + }, + }, + wantSessions: 1, + }, { name: "errors", requests: []streamableRequest{ @@ -980,7 +1124,7 @@ func TestStreamableServerTransport(t *testing.T) { &Tool{Name: "tool", InputSchema: &jsonschema.Schema{Type: "object"}}, func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { if test.tool != nil { - test.tool(t, ctx, req.Session) + test.tool(t, ctx, req) } return &CallToolResult{}, nil }) @@ -1092,11 +1236,14 @@ func testStreamableHandler(t *testing.T, handler http.Handler, requests []stream } wg.Wait() - if request.wantBodyContaining != "" { + if request.wantBodyContaining != "" || request.wantBodyNotContaining != "" { body := string(gotBody) - if !strings.Contains(body, request.wantBodyContaining) { + if request.wantBodyContaining != "" && !strings.Contains(body, request.wantBodyContaining) { t.Errorf("body does not contain %q:\n%s", request.wantBodyContaining, body) } + if request.wantBodyNotContaining != "" && strings.Contains(body, request.wantBodyNotContaining) { + t.Errorf("body contains %q:\n%s", request.wantBodyNotContaining, body) + } } else { transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() }) if diff := cmp.Diff(request.wantMessages, got, transform); diff != "" { @@ -1148,11 +1295,12 @@ type streamableRequest struct { headers http.Header // additional headers to set, overlaid on top of the default headers messages []jsonrpc.Message // messages to send - closeAfter int // if nonzero, close after receiving this many messages - wantStatusCode int // expected status code - wantBodyContaining string // if set, expect the response body to contain this text; overrides wantMessages - wantMessages []jsonrpc.Message // expected messages to receive; ignored if wantBodyContaining is set - wantSessionID bool // whether or not a session ID is expected in the response + closeAfter int // if nonzero, close after receiving this many messages + wantStatusCode int // expected status code + wantBodyContaining string // if set, expect the response body to contain this text; overrides wantMessages + wantBodyNotContaining string // if set, a negative assertion on the body; overrides wantMessages + wantMessages []jsonrpc.Message // expected messages to receive; ignored if wantBodyContaining is set + wantSessionID bool // whether or not a session ID is expected in the response } // streamingRequest makes a request to the given streamable server with the @@ -1221,13 +1369,15 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, if err != nil { return newSessionID, resp.StatusCode, nil, fmt.Errorf("reading events: %v", err) } - // TODO(rfindley): do we need to check evt.name? - // Does the MCP spec say anything about this? - msg, err := jsonrpc2.DecodeMessage(evt.Data) - if err != nil { - return newSessionID, resp.StatusCode, nil, fmt.Errorf("decoding message: %w", err) + if evt.Name == "" || evt.Name == "message" { // ordinary message + // TODO(rfindley): do we need to check evt.name? + // Does the MCP spec say anything about this? + msg, err := jsonrpc2.DecodeMessage(evt.Data) + if err != nil { + return newSessionID, resp.StatusCode, nil, fmt.Errorf("decoding message: %w", err) + } + out <- msg } - out <- msg } respBody = r.w.Bytes() } else if strings.HasPrefix(contentType, "application/json") { @@ -1882,12 +2032,13 @@ data: {"jsonrpc":"2.0","method":"test2","params":{}} // Verify that we can decode the message events but would fail on ping events for i, evt := range events { - if evt.Name == "message" { + switch evt.Name { + case "message": _, err := jsonrpc.DecodeMessage(evt.Data) if err != nil { t.Errorf("event %d: failed to decode message event: %v", i, err) } - } else if evt.Name == "ping" { + case "ping": // Ping events have non-JSON data and should fail decoding _, err := jsonrpc.DecodeMessage(evt.Data) if err == nil { From 1c7d6c803327f6b17ecf667c3b4e3bf9215950d1 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Mon, 1 Dec 2025 19:23:14 +0000 Subject: [PATCH 2/2] address review comments --- mcp/streamable_test.go | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 2298e653..46889cb5 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -568,11 +568,9 @@ func TestStreamableServerDisconnect(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - // memory event store to support replayability. - // Then implement the new SEP, and assert that the tool call succeeds. notifications := make(chan string, 2) handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ - EventStore: NewMemoryEventStore(nil), + EventStore: NewMemoryEventStore(nil), // support replayability }) httpServer := httptest.NewServer(mustNotPanic(t, handler)) defer httpServer.Close() @@ -1053,11 +1051,12 @@ func TestStreamableServerTransport(t *testing.T) { initialize, initialized, { - method: "POST", - messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, - wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, - wantBodyContaining: "close", + method: "POST", + messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, + wantBodyContaining: "close", + wantBodyNotContaining: "result", }, }, wantSessions: 1,