Skip to content

Commit 651f03a

Browse files
committed
fix: harden sdk stream filtering
1 parent f8bae66 commit 651f03a

File tree

2 files changed

+140
-92
lines changed

2 files changed

+140
-92
lines changed

sdk/src/index.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import type { ToolName as ToolNameType } from '../../common/src/tools/constants'
2+
13
export type * from '../../common/src/types/json'
24
export type * from '../../common/src/types/messages/codebuff-message'
35
export type * from '../../common/src/types/messages/data-content'
46
export type * from '../../common/src/types/print-mode'
57
export type * from './run'
68
// Agent type exports
79
export type { AgentDefinition } from '../../common/src/templates/initial-agents-dir/types/agent-definition'
10+
export type ToolName = ToolNameType
811

912
// Re-export code analysis functionality
1013
export * from '../../packages/code-map/src/index'

sdk/src/run.ts

Lines changed: 137 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,34 @@ import {
4141

4242
type ToolXmlFilterState = {
4343
buffer: string
44-
inside: boolean
44+
activeTag: 'tool_call' | 'tool_result' | null
4545
}
4646

4747
const TOOL_XML_OPEN = `<${toolXmlName}>`
4848
const TOOL_XML_CLOSE = `</${toolXmlName}>`
4949
const TOOL_XML_PREFIX = `<${toolXmlName}`
50-
const TOOL_XML_CLOSE_PREFIX = `</${toolXmlName}`
50+
const TOOL_RESULT_OPEN = '<tool_result>'
51+
const TOOL_RESULT_CLOSE = '</tool_result>'
52+
const TOOL_RESULT_PREFIX = '<tool_result'
53+
54+
const TAG_DEFINITIONS = [
55+
{
56+
type: 'tool_call' as const,
57+
open: TOOL_XML_OPEN,
58+
close: TOOL_XML_CLOSE,
59+
prefix: TOOL_XML_PREFIX,
60+
},
61+
{
62+
type: 'tool_result' as const,
63+
open: TOOL_RESULT_OPEN,
64+
close: TOOL_RESULT_CLOSE,
65+
prefix: TOOL_RESULT_PREFIX,
66+
},
67+
]
68+
69+
const TAG_INFO_BY_TYPE = Object.fromEntries(
70+
TAG_DEFINITIONS.map((tag) => [tag.type, tag]),
71+
)
5172

5273
const getPartialStartIndex = (value: string, pattern: string): number => {
5374
const max = Math.min(pattern.length - 1, value.length)
@@ -64,18 +85,38 @@ function filterToolXmlFromText(
6485
state: ToolXmlFilterState,
6586
incoming: string,
6687
maxBuffer: number,
67-
): { text: string; removed: boolean } {
88+
): { text: string } {
6889
if (incoming) {
6990
state.buffer += incoming
7091
}
7192

7293
let sanitized = ''
73-
let removed = false
94+
7495
while (state.buffer.length > 0) {
75-
if (!state.inside) {
76-
const openIndex = state.buffer.indexOf(TOOL_XML_OPEN)
77-
if (openIndex === -1) {
78-
const partialIndex = getPartialStartIndex(state.buffer, TOOL_XML_PREFIX)
96+
if (state.activeTag == null) {
97+
let nextTag: {
98+
index: number
99+
definition: (typeof TAG_DEFINITIONS)[number]
100+
} | null = null
101+
102+
for (const definition of TAG_DEFINITIONS) {
103+
const index = state.buffer.indexOf(definition.open)
104+
if (index !== -1) {
105+
if (nextTag == null || index < nextTag.index) {
106+
nextTag = { index, definition }
107+
}
108+
}
109+
}
110+
111+
if (!nextTag) {
112+
let partialIndex = -1
113+
for (const definition of TAG_DEFINITIONS) {
114+
const idx = getPartialStartIndex(state.buffer, definition.prefix)
115+
if (idx !== -1 && (partialIndex === -1 || idx < partialIndex)) {
116+
partialIndex = idx
117+
}
118+
}
119+
79120
if (partialIndex === -1) {
80121
sanitized += state.buffer
81122
state.buffer = ''
@@ -86,40 +127,49 @@ function filterToolXmlFromText(
86127
break
87128
}
88129

89-
sanitized += state.buffer.slice(0, openIndex)
90-
state.buffer = state.buffer.slice(openIndex + TOOL_XML_OPEN.length)
91-
state.inside = true
92-
removed = true
130+
sanitized += state.buffer.slice(0, nextTag.index)
131+
state.buffer = state.buffer.slice(
132+
nextTag.index + nextTag.definition.open.length,
133+
)
134+
state.activeTag = nextTag.definition.type
93135
} else {
94-
const closeIndex = state.buffer.indexOf(TOOL_XML_CLOSE)
136+
const definition = TAG_INFO_BY_TYPE[state.activeTag]
137+
const closeIndex = state.buffer.indexOf(definition.close)
138+
95139
if (closeIndex === -1) {
96140
const partialCloseIndex = getPartialStartIndex(
97141
state.buffer,
98-
TOOL_XML_CLOSE,
142+
definition.close,
99143
)
100144
if (partialCloseIndex === -1) {
101-
const keepIndex = Math.max(
102-
0,
103-
state.buffer.length - (TOOL_XML_CLOSE.length - 1),
104-
)
105-
state.buffer = state.buffer.slice(keepIndex)
145+
const keepLength = definition.close.length - 1
146+
if (state.buffer.length > keepLength) {
147+
state.buffer = state.buffer.slice(
148+
state.buffer.length - keepLength,
149+
)
150+
}
106151
} else {
107152
state.buffer = state.buffer.slice(partialCloseIndex)
108153
}
154+
155+
if (state.buffer.length > maxBuffer) {
156+
state.buffer = state.buffer.slice(-maxBuffer)
157+
}
109158
break
110159
}
111160

112-
state.buffer = state.buffer.slice(closeIndex + TOOL_XML_CLOSE.length)
113-
state.inside = false
114-
removed = true
161+
state.buffer = state.buffer.slice(
162+
closeIndex + definition.close.length,
163+
)
164+
state.activeTag = null
115165
}
116166
}
117167

118168
if (state.buffer.length > maxBuffer) {
119169
state.buffer = state.buffer.slice(-maxBuffer)
120170
}
121171

122-
return { text: sanitized, removed }
172+
return { text: sanitized }
123173
}
124174

125175
export type CodebuffClientOptions = {
@@ -211,23 +261,29 @@ export async function run({
211261
const MAX_TOOL_XML_BUFFER = BUFFER_SIZE * 10
212262
const ROOT_AGENT_KEY = '__root__'
213263

214-
const streamFilterState: ToolXmlFilterState = { buffer: '', inside: false }
264+
const streamFilterState: ToolXmlFilterState = { buffer: '', activeTag: null }
215265
const textFilterStates = new Map<string, ToolXmlFilterState>()
216266

217-
const subagentBuffers = new Map<
218-
string,
219-
{ buffer: string; insideToolCall: boolean }
220-
>()
267+
const subagentFilterStates = new Map<string, ToolXmlFilterState>()
221268

222269
const getTextFilterState = (agentKey: string): ToolXmlFilterState => {
223270
let state = textFilterStates.get(agentKey)
224271
if (!state) {
225-
state = { buffer: '', inside: false }
272+
state = { buffer: '', activeTag: null }
226273
textFilterStates.set(agentKey, state)
227274
}
228275
return state
229276
}
230277

278+
const getSubagentFilterState = (agentId: string): ToolXmlFilterState => {
279+
let state = subagentFilterStates.get(agentId)
280+
if (!state) {
281+
state = { buffer: '', activeTag: null }
282+
subagentFilterStates.set(agentId, state)
283+
}
284+
return state
285+
}
286+
231287
const flushTextState = async (
232288
agentKey: string,
233289
eventAgentId?: string,
@@ -244,13 +300,11 @@ export async function run({
244300
)
245301
let remainder = pendingText
246302

247-
if (!state.inside && state.buffer) {
248-
if (!state.buffer.includes('<')) {
249-
remainder += state.buffer
250-
}
303+
if (state.buffer && !state.buffer.includes('<')) {
304+
remainder += state.buffer
251305
}
252306
state.buffer = ''
253-
state.inside = false
307+
state.activeTag = null
254308

255309
textFilterStates.delete(agentKey)
256310

@@ -263,6 +317,36 @@ export async function run({
263317
}
264318
}
265319

320+
const flushSubagentState = async (
321+
agentId: string,
322+
agentType?: string,
323+
): Promise<void> => {
324+
const state = subagentFilterStates.get(agentId)
325+
if (!state) {
326+
return
327+
}
328+
329+
const { text: pendingText } = filterToolXmlFromText(
330+
state,
331+
'',
332+
MAX_TOOL_XML_BUFFER,
333+
)
334+
335+
subagentFilterStates.delete(agentId)
336+
state.buffer = ''
337+
state.activeTag = null
338+
339+
const trimmed = pendingText.trim()
340+
if (trimmed) {
341+
await handleEvent?.({
342+
type: 'subagent-chunk',
343+
agentId,
344+
agentType,
345+
chunk: pendingText,
346+
} as any)
347+
}
348+
}
349+
266350
const websocketHandler = new WebSocketHandler({
267351
apiKey,
268352
onWebsocketError: (error) => {
@@ -330,15 +414,12 @@ export async function run({
330414
MAX_TOOL_XML_BUFFER,
331415
)
332416
let remainder = streamTail
333-
if (
334-
!streamFilterState.inside &&
335-
streamFilterState.buffer &&
336-
!streamFilterState.buffer.includes('<')
337-
) {
417+
418+
if (streamFilterState.buffer && !streamFilterState.buffer.includes('<')) {
338419
remainder += streamFilterState.buffer
339420
}
340421
streamFilterState.buffer = ''
341-
streamFilterState.inside = false
422+
streamFilterState.activeTag = null
342423

343424
if (remainder) {
344425
await handleStreamChunk?.(remainder)
@@ -350,6 +431,10 @@ export async function run({
350431
(chunk as typeof chunk & { agentId?: string }).agentId
351432
if (finishAgentKey && finishAgentKey !== ROOT_AGENT_KEY) {
352433
await flushTextState(finishAgentKey, finishAgentKey)
434+
await flushSubagentState(
435+
finishAgentKey,
436+
(chunk as { agentType?: string }).agentType,
437+
)
353438
}
354439
} else if (
355440
chunkType === 'subagent_finish' ||
@@ -358,6 +443,10 @@ export async function run({
358443
const subagentId = (chunk as { agentId?: string }).agentId
359444
if (subagentId) {
360445
await flushTextState(subagentId, subagentId)
446+
await flushSubagentState(
447+
subagentId,
448+
(chunk as { agentType?: string }).agentType,
449+
)
361450
}
362451
}
363452

@@ -368,63 +457,19 @@ export async function run({
368457
checkAborted(signal)
369458
const { agentId, agentType, chunk } = action
370459

371-
// Filter out tool call XML from subagent chunks
372-
let bufferState = subagentBuffers.get(agentId)
373-
if (!bufferState) {
374-
bufferState = { buffer: '', insideToolCall: false }
375-
subagentBuffers.set(agentId, bufferState)
376-
}
377-
378-
bufferState.buffer += chunk
379-
let filteredChunk = ''
380-
381-
if (
382-
!bufferState.insideToolCall &&
383-
bufferState.buffer.includes(TOOL_XML_OPEN)
384-
) {
385-
const openTagIndex = bufferState.buffer.indexOf(TOOL_XML_OPEN)
386-
const beforeTag = bufferState.buffer.substring(0, openTagIndex)
387-
if (beforeTag) {
388-
filteredChunk = beforeTag
389-
}
390-
bufferState.insideToolCall = true
391-
bufferState.buffer = bufferState.buffer.substring(openTagIndex)
392-
} else if (
393-
bufferState.insideToolCall &&
394-
bufferState.buffer.includes(TOOL_XML_CLOSE)
395-
) {
396-
const closeTagIndex = bufferState.buffer.indexOf(TOOL_XML_CLOSE)
397-
bufferState.insideToolCall = false
398-
bufferState.buffer = bufferState.buffer.substring(
399-
closeTagIndex + TOOL_XML_CLOSE.length,
400-
)
401-
} else if (!bufferState.insideToolCall) {
402-
if (bufferState.buffer.length > 50) {
403-
const safeToOutput = bufferState.buffer.substring(
404-
0,
405-
bufferState.buffer.length - 50,
406-
)
407-
filteredChunk = safeToOutput
408-
bufferState.buffer = bufferState.buffer.substring(
409-
bufferState.buffer.length - 50,
410-
)
411-
}
412-
}
413-
414-
if (
415-
bufferState.insideToolCall &&
416-
bufferState.buffer.length > MAX_TOOL_XML_BUFFER
417-
) {
418-
bufferState.buffer = bufferState.buffer.slice(-MAX_TOOL_XML_BUFFER)
419-
}
460+
const state = getSubagentFilterState(agentId)
461+
const { text: sanitized } = filterToolXmlFromText(
462+
state,
463+
chunk,
464+
MAX_TOOL_XML_BUFFER,
465+
)
420466

421-
// Send filtered chunk to handler
422-
if (filteredChunk && handleEvent) {
467+
if (sanitized && handleEvent) {
423468
await handleEvent({
424469
type: 'subagent-chunk',
425470
agentId,
426471
agentType,
427-
chunk: filteredChunk,
472+
chunk: sanitized,
428473
} as any)
429474
}
430475
},

0 commit comments

Comments
 (0)