Skip to content

Commit 1f0554b

Browse files
committed
re-implement ai sdk changes
1 parent d690688 commit 1f0554b

26 files changed

+316
-357
lines changed

backend/src/__tests__/cost-aggregation.integration.test.ts

Lines changed: 75 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { TEST_USER_ID } from '@codebuff/common/old-constants'
2-
import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime'
2+
import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime'
33
import { getInitialSessionState } from '@codebuff/common/types/session-state'
44
import {
55
spyOn,
@@ -12,12 +12,12 @@ import {
1212
} from 'bun:test'
1313

1414
import * as messageCostTracker from '../llm-apis/message-cost-tracker'
15-
import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk'
1615
import { mainPrompt } from '../main-prompt'
1716
import * as agentRegistry from '../templates/agent-registry'
1817
import * as websocketAction from '../websockets/websocket-action'
1918

2019
import type { AgentTemplate } from '../templates/types'
20+
import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime'
2121
import type { ProjectFileContext } from '@codebuff/common/util/file'
2222
import type { WebSocket } from 'ws'
2323

@@ -99,6 +99,7 @@ class MockWebSocket {
9999
describe('Cost Aggregation Integration Tests', () => {
100100
let mockLocalAgentTemplates: Record<string, any>
101101
let mockWebSocket: MockWebSocket
102+
let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL }
102103

103104
beforeEach(async () => {
104105
mockWebSocket = new MockWebSocket()
@@ -150,33 +151,31 @@ describe('Cost Aggregation Integration Tests', () => {
150151
// Mock LLM streaming
151152
let callCount = 0
152153
const creditHistory: number[] = []
153-
spyOn(aisdk, 'promptAiSdkStream').mockImplementation(
154-
async function* (options) {
155-
callCount++
156-
const credits = callCount === 1 ? 10 : 7 // Main agent vs subagent costs
157-
creditHistory.push(credits)
158-
159-
if (options.onCostCalculated) {
160-
await options.onCostCalculated(credits)
161-
}
154+
agentRuntimeImpl.promptAiSdkStream = async function* (options) {
155+
callCount++
156+
const credits = callCount === 1 ? 10 : 7 // Main agent vs subagent costs
157+
creditHistory.push(credits)
162158

163-
// Simulate different responses based on call
164-
if (callCount === 1) {
165-
// Main agent spawns a subagent
166-
yield {
167-
type: 'text' as const,
168-
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write a simple hello world file"}]}\n</codebuff_tool_call>',
169-
}
170-
} else {
171-
// Subagent writes a file
172-
yield {
173-
type: 'text' as const,
174-
text: '<codebuff_tool_call>\n{"cb_tool_name": "write_file", "path": "hello.txt", "instructions": "Create hello world file", "content": "Hello, World!"}\n</codebuff_tool_call>',
175-
}
159+
if (options.onCostCalculated) {
160+
await options.onCostCalculated(credits)
161+
}
162+
163+
// Simulate different responses based on call
164+
if (callCount === 1) {
165+
// Main agent spawns a subagent
166+
yield {
167+
type: 'text' as const,
168+
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write a simple hello world file"}]}\n</codebuff_tool_call>',
176169
}
177-
return 'mock-message-id'
178-
},
179-
)
170+
} else {
171+
// Subagent writes a file
172+
yield {
173+
type: 'text' as const,
174+
text: '<codebuff_tool_call>\n{"cb_tool_name": "write_file", "path": "hello.txt", "instructions": "Create hello world file", "content": "Hello, World!"}\n</codebuff_tool_call>',
175+
}
176+
}
177+
return 'mock-message-id'
178+
}
180179

181180
// Mock tool call execution
182181
spyOn(websocketAction, 'requestToolCall').mockImplementation(
@@ -231,6 +230,7 @@ describe('Cost Aggregation Integration Tests', () => {
231230

232231
afterEach(() => {
233232
mock.restore()
233+
agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL }
234234
})
235235

236236
it('should correctly aggregate costs across the entire main prompt flow', async () => {
@@ -250,7 +250,7 @@ describe('Cost Aggregation Integration Tests', () => {
250250
}
251251

252252
const result = await mainPrompt({
253-
...testAgentRuntimeImpl,
253+
...agentRuntimeImpl,
254254
ws: mockWebSocket as unknown as WebSocket,
255255
action,
256256
userId: TEST_USER_ID,
@@ -285,7 +285,7 @@ describe('Cost Aggregation Integration Tests', () => {
285285

286286
// Call through websocket action handler to test full integration
287287
await websocketAction.callMainPrompt({
288-
...testAgentRuntimeImpl,
288+
...agentRuntimeImpl,
289289
ws: mockWebSocket as unknown as WebSocket,
290290
action,
291291
userId: TEST_USER_ID,
@@ -308,37 +308,35 @@ describe('Cost Aggregation Integration Tests', () => {
308308
it('should handle multi-level subagent hierarchies correctly', async () => {
309309
// Mock a more complex scenario with nested subagents
310310
let callCount = 0
311-
spyOn(aisdk, 'promptAiSdkStream').mockImplementation(
312-
async function* (options) {
313-
callCount++
311+
agentRuntimeImpl.promptAiSdkStream = async function* (options) {
312+
callCount++
314313

315-
if (options.onCostCalculated) {
316-
await options.onCostCalculated(5) // Each call costs 5 credits
317-
}
314+
if (options.onCostCalculated) {
315+
await options.onCostCalculated(5) // Each call costs 5 credits
316+
}
318317

319-
if (callCount === 1) {
320-
// Main agent spawns first-level subagent
321-
yield {
322-
type: 'text' as const,
323-
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Create files"}]}\n</codebuff_tool_call>',
324-
}
325-
} else if (callCount === 2) {
326-
// First-level subagent spawns second-level subagent
327-
yield {
328-
type: 'text' as const,
329-
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write specific file"}]}\n</codebuff_tool_call>',
330-
}
331-
} else {
332-
// Second-level subagent does actual work
333-
yield {
334-
type: 'text' as const,
335-
text: '<codebuff_tool_call>\n{"cb_tool_name": "write_file", "path": "nested.txt", "instructions": "Create nested file", "content": "Nested content"}\n</codebuff_tool_call>',
336-
}
318+
if (callCount === 1) {
319+
// Main agent spawns first-level subagent
320+
yield {
321+
type: 'text' as const,
322+
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Create files"}]}\n</codebuff_tool_call>',
323+
}
324+
} else if (callCount === 2) {
325+
// First-level subagent spawns second-level subagent
326+
yield {
327+
type: 'text' as const,
328+
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write specific file"}]}\n</codebuff_tool_call>',
337329
}
330+
} else {
331+
// Second-level subagent does actual work
332+
yield {
333+
type: 'text' as const,
334+
text: '<codebuff_tool_call>\n{"cb_tool_name": "write_file", "path": "nested.txt", "instructions": "Create nested file", "content": "Nested content"}\n</codebuff_tool_call>',
335+
}
336+
}
338337

339-
return 'mock-message-id'
340-
},
341-
)
338+
return 'mock-message-id'
339+
}
342340

343341
const sessionState = getInitialSessionState(mockFileContext)
344342
sessionState.mainAgentState.stepsRemaining = 10
@@ -355,7 +353,7 @@ describe('Cost Aggregation Integration Tests', () => {
355353
}
356354

357355
const result = await mainPrompt({
358-
...testAgentRuntimeImpl,
356+
...agentRuntimeImpl,
359357
ws: mockWebSocket as unknown as WebSocket,
360358
action,
361359
userId: TEST_USER_ID,
@@ -373,29 +371,27 @@ describe('Cost Aggregation Integration Tests', () => {
373371
it('should maintain cost integrity when subagents fail', async () => {
374372
// Mock scenario where subagent fails after incurring partial costs
375373
let callCount = 0
376-
spyOn(aisdk, 'promptAiSdkStream').mockImplementation(
377-
async function* (options) {
378-
callCount++
374+
agentRuntimeImpl.promptAiSdkStream = async function* (options) {
375+
callCount++
379376

380-
if (options.onCostCalculated) {
381-
await options.onCostCalculated(6) // Each call costs 6 credits
382-
}
377+
if (options.onCostCalculated) {
378+
await options.onCostCalculated(6) // Each call costs 6 credits
379+
}
383380

384-
if (callCount === 1) {
385-
// Main agent spawns subagent
386-
yield {
387-
type: 'text' as const,
388-
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "This will fail"}]}\n</codebuff_tool_call>',
389-
}
390-
} else {
391-
// Subagent fails after incurring cost
392-
yield { type: 'text' as const, text: 'Some response' }
393-
throw new Error('Subagent execution failed')
381+
if (callCount === 1) {
382+
// Main agent spawns subagent
383+
yield {
384+
type: 'text' as const,
385+
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "This will fail"}]}\n</codebuff_tool_call>',
394386
}
387+
} else {
388+
// Subagent fails after incurring cost
389+
yield { type: 'text' as const, text: 'Some response' }
390+
throw new Error('Subagent execution failed')
391+
}
395392

396-
return 'mock-message-id'
397-
},
398-
)
393+
return 'mock-message-id'
394+
}
399395

400396
const sessionState = getInitialSessionState(mockFileContext)
401397
sessionState.mainAgentState.agentType = 'base'
@@ -413,7 +409,7 @@ describe('Cost Aggregation Integration Tests', () => {
413409
let result
414410
try {
415411
result = await mainPrompt({
416-
...testAgentRuntimeImpl,
412+
...agentRuntimeImpl,
417413
ws: mockWebSocket as unknown as WebSocket,
418414
action,
419415
userId: TEST_USER_ID,
@@ -462,7 +458,7 @@ describe('Cost Aggregation Integration Tests', () => {
462458
}
463459

464460
await mainPrompt({
465-
...testAgentRuntimeImpl,
461+
...agentRuntimeImpl,
466462
ws: mockWebSocket as unknown as WebSocket,
467463
action,
468464
userId: TEST_USER_ID,
@@ -502,7 +498,7 @@ describe('Cost Aggregation Integration Tests', () => {
502498

503499
// Call through websocket action to test server-side reset
504500
await websocketAction.callMainPrompt({
505-
...testAgentRuntimeImpl,
501+
...agentRuntimeImpl,
506502
ws: mockWebSocket as unknown as WebSocket,
507503
action,
508504
userId: TEST_USER_ID,

backend/src/__tests__/cost-aggregation.test.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime'
1+
import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime'
22
import {
33
getInitialAgentState,
44
getInitialSessionState,
@@ -180,7 +180,7 @@ describe('Cost Aggregation System', () => {
180180
}
181181

182182
const result = handleSpawnAgents({
183-
...testAgentRuntimeImpl,
183+
...TEST_AGENT_RUNTIME_IMPL,
184184
previousToolCallFinished: Promise.resolve(),
185185
toolCall: mockToolCall,
186186
fileContext: mockFileContext,
@@ -260,7 +260,7 @@ describe('Cost Aggregation System', () => {
260260
}
261261

262262
const result = handleSpawnAgents({
263-
...testAgentRuntimeImpl,
263+
...TEST_AGENT_RUNTIME_IMPL,
264264
previousToolCallFinished: Promise.resolve(),
265265
toolCall: mockToolCall,
266266
fileContext: mockFileContext,
@@ -417,7 +417,7 @@ describe('Cost Aggregation System', () => {
417417
}
418418

419419
const result = handleSpawnAgents({
420-
...testAgentRuntimeImpl,
420+
...TEST_AGENT_RUNTIME_IMPL,
421421
previousToolCallFinished: Promise.resolve(),
422422
toolCall: mockToolCall,
423423
fileContext: mockFileContext,

backend/src/__tests__/generate-diffs-prompt.test.ts

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,8 @@
1+
import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime'
12
import { expect, describe, it } from 'bun:test'
23

34
import { parseAndGetDiffBlocksSingleFile } from '../generate-diffs-prompt'
45

5-
import type { Logger } from '@codebuff/common/types/contracts/logger'
6-
7-
const logger: Logger = {
8-
debug: () => {},
9-
info: () => {},
10-
warn: () => {},
11-
error: () => {},
12-
}
13-
146
describe('parseAndGetDiffBlocksSingleFile', () => {
157
it('should parse diff blocks with newline before closing marker', () => {
168
const oldContent = 'function test() {\n return true;\n}\n'
@@ -26,9 +18,9 @@ function test() {
2618
>>>>>>> REPLACE`
2719

2820
const result = parseAndGetDiffBlocksSingleFile({
21+
...TEST_AGENT_RUNTIME_IMPL,
2922
newContent,
3023
oldFileContent: oldContent,
31-
logger,
3224
})
3325
console.log(JSON.stringify({ result }))
3426

@@ -55,9 +47,9 @@ function test() {
5547
}>>>>>>> REPLACE`
5648

5749
const result = parseAndGetDiffBlocksSingleFile({
50+
...TEST_AGENT_RUNTIME_IMPL,
5851
newContent,
5952
oldFileContent: oldContent,
60-
logger,
6153
})
6254

6355
expect(result.diffBlocks.length).toBe(1)
@@ -108,9 +100,9 @@ function subtract(a, b) {
108100
>>>>>>> REPLACE`
109101

110102
const result = parseAndGetDiffBlocksSingleFile({
103+
...TEST_AGENT_RUNTIME_IMPL,
111104
newContent,
112105
oldFileContent: oldContent,
113-
logger,
114106
})
115107

116108
expect(result.diffBlocks.length).toBe(2)
@@ -136,9 +128,9 @@ function subtract(a, b) {
136128
>>>>>>> REPLACE`
137129

138130
const result = parseAndGetDiffBlocksSingleFile({
131+
...TEST_AGENT_RUNTIME_IMPL,
139132
newContent,
140133
oldFileContent: oldContent,
141-
logger,
142134
})
143135

144136
expect(result.diffBlocks.length).toBe(1)

0 commit comments

Comments
 (0)