Skip to content

Commit 487378b

Browse files
authored
fix(core): completions tool chunk concat behavior (#9544)
1 parent 0ad007b commit 487378b

File tree

5 files changed

+212
-74
lines changed

5 files changed

+212
-74
lines changed

.changeset/free-ends-shave.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@langchain/core": patch
3+
---
4+
5+
fix tool chunk concat behavior (#9450)

libs/langchain-core/src/messages/ai.ts

Lines changed: 2 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import { parsePartialJson } from "../utils/json.js";
21
import {
32
BaseMessage,
43
BaseMessageChunk,
@@ -21,7 +20,7 @@ import {
2120
ToolCallChunk,
2221
defaultToolCallParser,
2322
} from "./tool.js";
24-
import { Constructor } from "./utils.js";
23+
import { collapseToolCallChunks, Constructor } from "./utils.js";
2524

2625
export interface AIMessageFields<
2726
TStructure extends MessageStructure = MessageStructure
@@ -283,77 +282,9 @@ export class AIMessageChunk<
283282
: undefined,
284283
};
285284
} else {
286-
const toolCallChunks = fields.tool_call_chunks ?? [];
287-
const groupedToolCallChunks = toolCallChunks.reduce((acc, chunk) => {
288-
const matchedChunkIndex = acc.findIndex(([match]) => {
289-
// If chunk has an id and index, match if both are present
290-
if (
291-
"id" in chunk &&
292-
chunk.id &&
293-
"index" in chunk &&
294-
chunk.index !== undefined
295-
) {
296-
return chunk.id === match.id && chunk.index === match.index;
297-
}
298-
// If chunk has an id, we match on id
299-
if ("id" in chunk && chunk.id) {
300-
return chunk.id === match.id;
301-
}
302-
// If chunk has an index, we match on index
303-
if ("index" in chunk && chunk.index !== undefined) {
304-
return chunk.index === match.index;
305-
}
306-
return false;
307-
});
308-
if (matchedChunkIndex !== -1) {
309-
acc[matchedChunkIndex].push(chunk);
310-
} else {
311-
acc.push([chunk]);
312-
}
313-
return acc;
314-
}, [] as ToolCallChunk[][]);
315-
316-
const toolCalls: ToolCall[] = [];
317-
const invalidToolCalls: InvalidToolCall[] = [];
318-
for (const chunks of groupedToolCallChunks) {
319-
let parsedArgs: Record<string, unknown> | null = null;
320-
const name = chunks[0]?.name ?? "";
321-
const joinedArgs = chunks
322-
.map((c) => c.args || "")
323-
.join("")
324-
.trim();
325-
const argsStr = joinedArgs.length ? joinedArgs : "{}";
326-
const id = chunks[0]?.id;
327-
try {
328-
parsedArgs = parsePartialJson(argsStr);
329-
if (
330-
!id ||
331-
parsedArgs === null ||
332-
typeof parsedArgs !== "object" ||
333-
Array.isArray(parsedArgs)
334-
) {
335-
throw new Error("Malformed tool call chunk args.");
336-
}
337-
toolCalls.push({
338-
name,
339-
args: parsedArgs,
340-
id,
341-
type: "tool_call",
342-
});
343-
} catch {
344-
invalidToolCalls.push({
345-
name,
346-
args: argsStr,
347-
id,
348-
error: "Malformed args.",
349-
type: "invalid_tool_call",
350-
});
351-
}
352-
}
353285
initParams = {
354286
...fields,
355-
tool_calls: toolCalls,
356-
invalid_tool_calls: invalidToolCalls,
287+
...collapseToolCallChunks(fields.tool_call_chunks ?? []),
357288
usage_metadata:
358289
fields.usage_metadata !== undefined
359290
? fields.usage_metadata

libs/langchain-core/src/messages/base.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -448,8 +448,10 @@ export function _mergeDicts(
448448
} else if (
449449
["id", "name", "output_version", "model_provider"].includes(key)
450450
) {
451-
// Keep the incoming value for these fields
452-
merged[key] = value;
451+
// Keep the incoming value for these fields if its defined
452+
if (value) {
453+
merged[key] = value;
454+
}
453455
} else {
454456
merged[key] += value;
455457
}

libs/langchain-core/src/messages/tests/ai.test.ts

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,26 @@ describe("AIMessageChunk", () => {
384384
'{"ideas":["read about Angular 19 updates"]}'
385385
);
386386
expect(secondCall?.id).toBe("5abf542e-87f3-4899-87c6-8f7d9cb6a28d");
387+
388+
expect(merged.tool_calls).toHaveLength(2);
389+
expect(merged.tool_calls).toEqual([
390+
{
391+
id: "9fb5c937-6944-4173-84be-ad1caee1cedd",
392+
type: "tool_call",
393+
name: "add_new_task",
394+
args: {
395+
tasks: ["buy tomatoes", "help child with math"],
396+
},
397+
},
398+
{
399+
id: "5abf542e-87f3-4899-87c6-8f7d9cb6a28d",
400+
type: "tool_call",
401+
name: "add_ideas",
402+
args: {
403+
ideas: ["read about Angular 19 updates"],
404+
},
405+
},
406+
]);
387407
});
388408

389409
it("should properly merge tool call chunks that have matching indices and at least one id is blank", () => {
@@ -418,6 +438,75 @@ describe("AIMessageChunk", () => {
418438
'{"tasks":["buy tomatoes","help child with math"]}'
419439
);
420440
expect(firstCall?.id).toBe("9fb5c937-6944-4173-84be-ad1caee1cedd");
441+
442+
expect(merged.tool_calls).toHaveLength(1);
443+
expect(merged.tool_calls).toEqual([
444+
{
445+
type: "tool_call",
446+
name: "add_new_task",
447+
args: {
448+
tasks: ["buy tomatoes", "help child with math"],
449+
},
450+
id: "9fb5c937-6944-4173-84be-ad1caee1cedd",
451+
},
452+
]);
453+
});
454+
455+
// https://github.com/langchain-ai/langchainjs/issues/9450
456+
it("should properly concat a string of old completions-style tool call chunks", () => {
457+
const chunk1 = new AIMessageChunk({
458+
tool_call_chunks: [
459+
{
460+
name: "get_weather",
461+
args: "",
462+
id: "call_7171a25538d44feea5155a",
463+
index: 0,
464+
type: "tool_call_chunk",
465+
},
466+
],
467+
});
468+
const chunk2 = new AIMessageChunk({
469+
tool_call_chunks: [
470+
{
471+
name: undefined,
472+
args: '{"city": "',
473+
id: "",
474+
index: 0,
475+
type: "tool_call_chunk",
476+
},
477+
],
478+
});
479+
const chunk3 = new AIMessageChunk({
480+
tool_call_chunks: [
481+
{
482+
name: undefined,
483+
args: 'sf"}',
484+
id: "",
485+
index: 0,
486+
type: "tool_call_chunk",
487+
},
488+
],
489+
});
490+
491+
const merged = chunk1.concat(chunk2).concat(chunk3);
492+
expect(merged.tool_call_chunks).toHaveLength(1);
493+
494+
const firstCall = merged.tool_call_chunks?.[0];
495+
expect(firstCall?.name).toBe("get_weather");
496+
expect(firstCall?.args).toBe('{"city": "sf"}');
497+
expect(firstCall?.id).toBe("call_7171a25538d44feea5155a");
498+
499+
expect(merged.tool_calls).toHaveLength(1);
500+
expect(merged.tool_calls).toEqual([
501+
{
502+
type: "tool_call",
503+
name: "get_weather",
504+
args: {
505+
city: "sf",
506+
},
507+
id: "call_7171a25538d44feea5155a",
508+
},
509+
]);
421510
});
422511

423512
it("should properly merge tool call chunks that have matching indices no IDs at all", () => {

libs/langchain-core/src/messages/utils.ts

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { addLangChainErrorFields } from "../errors/index.js";
22
import { SerializedConstructor } from "../load/serializable.js";
33
import { _isToolCall } from "../tools/utils.js";
4+
import { parsePartialJson } from "../utils/json.js";
45
import { AIMessage, AIMessageChunk, AIMessageChunkFields } from "./ai.js";
56
import {
67
BaseMessageLike,
@@ -20,7 +21,13 @@ import {
2021
import { HumanMessage, HumanMessageChunk } from "./human.js";
2122
import { RemoveMessage } from "./modifier.js";
2223
import { SystemMessage, SystemMessageChunk } from "./system.js";
23-
import { ToolCall, ToolMessage, ToolMessageFields } from "./tool.js";
24+
import {
25+
InvalidToolCall,
26+
ToolCall,
27+
ToolCallChunk,
28+
ToolMessage,
29+
ToolMessageFields,
30+
} from "./tool.js";
2431

2532
export type $Expand<T> = T extends infer U ? { [K in keyof U]: U[K] } : never;
2633

@@ -445,3 +452,107 @@ export function convertToChunk(message: BaseMessage) {
445452
throw new Error("Unknown message type.");
446453
}
447454
}
455+
456+
/**
457+
* Collapses an array of tool call chunks into complete tool calls.
458+
*
459+
* This function groups tool call chunks by their id and/or index, then attempts to
460+
* parse and validate the accumulated arguments for each group. Successfully parsed
461+
* tool calls are returned as valid `ToolCall` objects, while malformed ones are
462+
* returned as `InvalidToolCall` objects.
463+
*
464+
* @param chunks - An array of `ToolCallChunk` objects to collapse
465+
* @returns An object containing:
466+
* - `tool_call_chunks`: The original input chunks
467+
* - `tool_calls`: An array of successfully parsed and validated tool calls
468+
* - `invalid_tool_calls`: An array of tool calls that failed parsing or validation
469+
*
470+
* @remarks
471+
* Chunks are grouped using the following matching logic:
472+
* - If a chunk has both an id and index, it matches chunks with the same id and index
473+
* - If a chunk has only an id, it matches chunks with the same id
474+
* - If a chunk has only an index, it matches chunks with the same index
475+
*
476+
* For each group, the function:
477+
* 1. Concatenates all `args` strings from the chunks
478+
* 2. Attempts to parse the concatenated string as JSON
479+
* 3. Validates that the result is a non-null object with a valid id
480+
* 4. Creates either a `ToolCall` (if valid) or `InvalidToolCall` (if invalid)
481+
*/
482+
export function collapseToolCallChunks(chunks: ToolCallChunk[]): {
483+
tool_call_chunks: ToolCallChunk[];
484+
tool_calls: ToolCall[];
485+
invalid_tool_calls: InvalidToolCall[];
486+
} {
487+
const groupedToolCallChunks = chunks.reduce((acc, chunk) => {
488+
const matchedChunkIndex = acc.findIndex(([match]) => {
489+
// If chunk has an id and index, match if both are present
490+
if (
491+
"id" in chunk &&
492+
chunk.id &&
493+
"index" in chunk &&
494+
chunk.index !== undefined
495+
) {
496+
return chunk.id === match.id && chunk.index === match.index;
497+
}
498+
// If chunk has an id, we match on id
499+
if ("id" in chunk && chunk.id) {
500+
return chunk.id === match.id;
501+
}
502+
// If chunk has an index, we match on index
503+
if ("index" in chunk && chunk.index !== undefined) {
504+
return chunk.index === match.index;
505+
}
506+
return false;
507+
});
508+
if (matchedChunkIndex !== -1) {
509+
acc[matchedChunkIndex].push(chunk);
510+
} else {
511+
acc.push([chunk]);
512+
}
513+
return acc;
514+
}, [] as ToolCallChunk[][]);
515+
516+
const toolCalls: ToolCall[] = [];
517+
const invalidToolCalls: InvalidToolCall[] = [];
518+
for (const chunks of groupedToolCallChunks) {
519+
let parsedArgs: Record<string, unknown> | null = null;
520+
const name = chunks[0]?.name ?? "";
521+
const joinedArgs = chunks
522+
.map((c) => c.args || "")
523+
.join("")
524+
.trim();
525+
const argsStr = joinedArgs.length ? joinedArgs : "{}";
526+
const id = chunks[0]?.id;
527+
try {
528+
parsedArgs = parsePartialJson(argsStr);
529+
if (
530+
!id ||
531+
parsedArgs === null ||
532+
typeof parsedArgs !== "object" ||
533+
Array.isArray(parsedArgs)
534+
) {
535+
throw new Error("Malformed tool call chunk args.");
536+
}
537+
toolCalls.push({
538+
name,
539+
args: parsedArgs,
540+
id,
541+
type: "tool_call",
542+
});
543+
} catch {
544+
invalidToolCalls.push({
545+
name,
546+
args: argsStr,
547+
id,
548+
error: "Malformed args.",
549+
type: "invalid_tool_call",
550+
});
551+
}
552+
}
553+
return {
554+
tool_call_chunks: chunks,
555+
tool_calls: toolCalls,
556+
invalid_tool_calls: invalidToolCalls,
557+
};
558+
}

0 commit comments

Comments
 (0)