Skip to content

Commit 91f93af

Browse files
committed
mcp: improve http transports error handling and make buffer size configurable
1 parent 76e6854 commit 91f93af

File tree

6 files changed

+171
-31
lines changed

6 files changed

+171
-31
lines changed

mcp/event.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,12 @@ func writeEvent(w io.Writer, evt Event) (int, error) {
6666
//
6767
// TODO(rfindley): consider a different API here that makes failure modes more
6868
// apparent.
69-
func scanEvents(r io.Reader) iter.Seq2[Event, error] {
69+
func scanEvents(r io.Reader, maxLineSize int) iter.Seq2[Event, error] {
7070
scanner := bufio.NewScanner(r)
71-
const maxTokenSize = 1 * 1024 * 1024 // 1 MiB max line size
72-
scanner.Buffer(nil, maxTokenSize)
71+
if maxLineSize == 0 {
72+
maxLineSize = 1 * 1024 * 1024 // defaults to 1MB
73+
}
74+
scanner.Buffer(nil, maxLineSize)
7375

7476
// TODO: investigate proper behavior when events are out of order, or have
7577
// non-standard names.
@@ -139,7 +141,7 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] {
139141
}
140142
if err := scanner.Err(); err != nil {
141143
if errors.Is(err, bufio.ErrTooLong) {
142-
err = fmt.Errorf("event exceeded max line length of %d", maxTokenSize)
144+
err = fmt.Errorf("event exceeded max line length of %d", maxLineSize)
143145
}
144146
if !yield(Event{}, err) {
145147
return

mcp/event_test.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@ import (
1515

1616
func TestScanEvents(t *testing.T) {
1717
tests := []struct {
18-
name string
19-
input string
20-
want []Event
21-
wantErr string
18+
name string
19+
input string
20+
want []Event
21+
wantErr string
22+
maxLineSize int
2223
}{
2324
{
2425
name: "simple event",
@@ -54,14 +55,20 @@ func TestScanEvents(t *testing.T) {
5455
input: "invalid line\n\n",
5556
wantErr: "malformed line",
5657
},
58+
{
59+
name: "event exceeds buffer size",
60+
input: "data: " + strings.Repeat("x", 200) + "\n\n",
61+
maxLineSize: 100,
62+
wantErr: "event exceeded max line length of 100",
63+
},
5764
}
5865

5966
for _, tt := range tests {
6067
t.Run(tt.name, func(t *testing.T) {
6168
r := strings.NewReader(tt.input)
6269
var got []Event
6370
var err error
64-
for e, err2 := range scanEvents(r) {
71+
for e, err2 := range scanEvents(r, tt.maxLineSize) {
6572
if err2 != nil {
6673
err = err2
6774
break

mcp/sse.go

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,9 @@ type SSEClientTransport struct {
329329
// HTTPClient is the client to use for making HTTP requests. If nil,
330330
// http.DefaultClient is used.
331331
HTTPClient *http.Client
332+
333+
// MaxLineSize is the maximum buffer size used when reading a message. It defaults to 1MB
334+
MaxLineSize int
332335
}
333336

334337
// Connect connects through the client endpoint.
@@ -353,7 +356,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {
353356

354357
msgEndpoint, err := func() (*url.URL, error) {
355358
var evt Event
356-
for evt, err = range scanEvents(resp.Body) {
359+
for evt, err = range scanEvents(resp.Body, c.MaxLineSize) {
357360
break
358361
}
359362
if err != nil {
@@ -374,20 +377,24 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {
374377
s := &sseClientConn{
375378
client: httpClient,
376379
msgEndpoint: msgEndpoint,
377-
incoming: make(chan []byte, 100),
380+
incoming: make(chan sseMessage, 100),
378381
body: resp.Body,
379382
done: make(chan struct{}),
380383
}
381384

382385
go func() {
383386
defer s.Close() // close the transport when the GET exits
384387

385-
for evt, err := range scanEvents(resp.Body) {
388+
for evt, err := range scanEvents(resp.Body, c.MaxLineSize) {
386389
if err != nil {
390+
select {
391+
case s.incoming <- sseMessage{err: err}:
392+
case <-s.done:
393+
}
387394
return
388395
}
389396
select {
390-
case s.incoming <- evt.Data:
397+
case s.incoming <- sseMessage{data: evt.Data}:
391398
case <-s.done:
392399
return
393400
}
@@ -397,15 +404,21 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {
397404
return s, nil
398405
}
399406

407+
// sseMessage represents a message or error from the SSE stream.
408+
type sseMessage struct {
409+
data []byte
410+
err error
411+
}
412+
400413
// An sseClientConn is a logical jsonrpc2 connection that implements the client
401414
// half of the SSE protocol:
402415
// - Writes are POSTS to the session endpoint.
403416
// - Reads are SSE 'message' events, and pushes them onto a buffered channel.
404417
// - Close terminates the GET request.
405418
type sseClientConn struct {
406-
client *http.Client // HTTP client to use for requests
407-
msgEndpoint *url.URL // session endpoint for POSTs
408-
incoming chan []byte // queue of incoming messages
419+
client *http.Client // HTTP client to use for requests
420+
msgEndpoint *url.URL // session endpoint for POSTs
421+
incoming chan sseMessage // queue of incoming messages or errors
409422

410423
mu sync.Mutex
411424
body io.ReadCloser // body of the hanging GET
@@ -430,12 +443,15 @@ func (c *sseClientConn) Read(ctx context.Context) (jsonrpc.Message, error) {
430443
case <-c.done:
431444
return nil, io.EOF
432445

433-
case data := <-c.incoming:
446+
case m := <-c.incoming:
447+
if m.err != nil {
448+
return nil, m.err
449+
}
434450
// TODO(rfindley): do we really need to check this? We receive from c.done above.
435451
if c.isDone() {
436452
return nil, io.EOF
437453
}
438-
msg, err := jsonrpc2.DecodeMessage(data)
454+
msg, err := jsonrpc2.DecodeMessage(m.data)
439455
if err != nil {
440456
return nil, err
441457
}

mcp/streamable.go

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,6 +1388,8 @@ type StreamableClientTransport struct {
13881388
// MaxRetries is the maximum number of times to attempt a reconnect before giving up.
13891389
// It defaults to 5. To disable retries, use a negative number.
13901390
MaxRetries int
1391+
// MaxLineSize is the maximum buffer size used when reading a message. It defaults to 1MB
1392+
MaxLineSize int
13911393

13921394
// TODO(rfindley): propose exporting these.
13931395
// If strict is set, the transport is in 'strict mode', where any violation
@@ -1453,16 +1455,17 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
14531455
// middleware), yet only cancel the standalone stream when the connection is closed.
14541456
connCtx, cancel := context.WithCancel(xcontext.Detach(ctx))
14551457
conn := &streamableClientConn{
1456-
url: t.Endpoint,
1457-
client: client,
1458-
incoming: make(chan jsonrpc.Message, 10),
1459-
done: make(chan struct{}),
1460-
maxRetries: maxRetries,
1461-
strict: t.strict,
1462-
logger: ensureLogger(t.logger), // must be non-nil for safe logging
1463-
ctx: connCtx,
1464-
cancel: cancel,
1465-
failed: make(chan struct{}),
1458+
url: t.Endpoint,
1459+
client: client,
1460+
incoming: make(chan jsonrpc.Message, 10),
1461+
done: make(chan struct{}),
1462+
maxRetries: maxRetries,
1463+
strict: t.strict,
1464+
logger: ensureLogger(t.logger), // must be non-nil for safe logging
1465+
ctx: connCtx,
1466+
cancel: cancel,
1467+
failed: make(chan struct{}),
1468+
maxLineSize: t.MaxLineSize,
14661469
}
14671470
return conn, nil
14681471
}
@@ -1497,6 +1500,7 @@ type streamableClientConn struct {
14971500
mu sync.Mutex
14981501
initializedResult *InitializeResult
14991502
sessionID string
1503+
maxLineSize int
15001504
}
15011505

15021506
// errSessionMissing distinguishes if the session is known to not be present on
@@ -1854,11 +1858,14 @@ func (c *streamableClientConn) processStream(ctx context.Context, requestSummary
18541858
io.Copy(io.Discard, resp.Body)
18551859
resp.Body.Close()
18561860
}()
1857-
for evt, err := range scanEvents(resp.Body) {
1861+
for evt, err := range scanEvents(resp.Body, c.maxLineSize) {
18581862
if err != nil {
18591863
if ctx.Err() != nil {
18601864
return "", 0, true // don't reconnect: client cancelled
18611865
}
1866+
1867+
// EOF errors are returned as nil from bufio.Scanner, so all errors should be returned back
1868+
c.fail(fmt.Errorf("%s: failed to process stream: %v", requestSummary, err))
18621869
break
18631870
}
18641871

mcp/streamable_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1425,7 +1425,7 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string,
14251425
var respBody []byte
14261426
if strings.HasPrefix(contentType, "text/event-stream") {
14271427
r := readerInto{resp.Body, new(bytes.Buffer)}
1428-
for evt, err := range scanEvents(r) {
1428+
for evt, err := range scanEvents(r, 0) {
14291429
if err != nil {
14301430
return newSessionID, resp.StatusCode, nil, fmt.Errorf("reading events: %v", err)
14311431
}
@@ -2143,7 +2143,7 @@ data: {"jsonrpc":"2.0","method":"test2","params":{}}
21432143
var events []Event
21442144

21452145
// Scan all events
2146-
for evt, err := range scanEvents(reader) {
2146+
for evt, err := range scanEvents(reader, 0) {
21472147
if err != nil {
21482148
if err != io.EOF {
21492149
t.Fatalf("scanEvents error: %v", err)

mcp/transport_test.go

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ package mcp
77
import (
88
"context"
99
"io"
10+
"net/http"
11+
"net/http/httptest"
1012
"strings"
1113
"testing"
1214

@@ -124,3 +126,109 @@ func TestIOConnRead(t *testing.T) {
124126
})
125127
}
126128
}
129+
130+
func TestScanEventsBufferError(t *testing.T) {
131+
ctx := context.Background()
132+
tests := []struct {
133+
name string
134+
clientTransport func(url string) Transport
135+
serverHandler func(server *Server) http.Handler
136+
responseLength int
137+
expectedContainsError string
138+
}{
139+
{
140+
name: "sse-large-output",
141+
clientTransport: func(url string) Transport {
142+
return &SSEClientTransport{
143+
Endpoint: url,
144+
MaxLineSize: 1024,
145+
}
146+
},
147+
serverHandler: func(server *Server) http.Handler {
148+
return NewSSEHandler(func(req *http.Request) *Server { return server }, nil)
149+
},
150+
responseLength: 10000,
151+
expectedContainsError: "exceeded max line length",
152+
},
153+
{
154+
name: "streamable-large-output",
155+
clientTransport: func(url string) Transport {
156+
return &StreamableClientTransport{
157+
Endpoint: url,
158+
MaxLineSize: 1024,
159+
}
160+
},
161+
serverHandler: func(server *Server) http.Handler {
162+
return NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
163+
},
164+
responseLength: 10000,
165+
expectedContainsError: "exceeded max line length",
166+
},
167+
{
168+
name: "sse-small-output",
169+
clientTransport: func(url string) Transport {
170+
return &SSEClientTransport{
171+
Endpoint: url,
172+
MaxLineSize: 1024,
173+
}
174+
},
175+
serverHandler: func(server *Server) http.Handler {
176+
return NewSSEHandler(func(req *http.Request) *Server { return server }, nil)
177+
},
178+
responseLength: 512,
179+
},
180+
{
181+
name: "streamable-small-output",
182+
clientTransport: func(url string) Transport {
183+
return &StreamableClientTransport{
184+
Endpoint: url,
185+
MaxLineSize: 1024,
186+
}
187+
},
188+
serverHandler: func(server *Server) http.Handler {
189+
return NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
190+
},
191+
responseLength: 512,
192+
},
193+
}
194+
for _, tt := range tests {
195+
t.Run(tt.name, func(t *testing.T) {
196+
largeResponse := strings.Repeat("x", tt.responseLength)
197+
server := NewServer(testImpl, nil)
198+
AddTool(server, &Tool{Name: "largeTool", Description: "returns large response"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) {
199+
return &CallToolResult{Content: []Content{&TextContent{Text: largeResponse}}}, nil, nil
200+
})
201+
202+
httpHandler := tt.serverHandler(server)
203+
httpServer := httptest.NewServer(mustNotPanic(t, httpHandler))
204+
defer httpServer.Close()
205+
206+
client := NewClient(testImpl, nil)
207+
clientTransport := tt.clientTransport(httpServer.URL)
208+
session, err := client.Connect(ctx, clientTransport, nil)
209+
if err != nil {
210+
t.Fatalf("client.Connect() failed: %v", err)
211+
}
212+
defer session.Close()
213+
214+
_, err = session.CallTool(ctx, &CallToolParams{
215+
Name: "largeTool",
216+
Arguments: map[string]any{},
217+
})
218+
if tt.expectedContainsError != "" {
219+
if tt.expectedContainsError != "" && err == nil {
220+
t.Fatal("expected error due to small buffer, got nil")
221+
}
222+
223+
if !strings.Contains(err.Error(), "exceeded max line length") {
224+
t.Fatalf("expected buffer-related error, got: %v", err)
225+
}
226+
} else {
227+
if err != nil {
228+
t.Fatalf("client.CallTool() unexpectedly failed: %v", err)
229+
}
230+
}
231+
232+
})
233+
}
234+
}

0 commit comments

Comments
 (0)