diff --git a/backend/src/__tests__/agent-registry.test.ts b/backend/src/__tests__/agent-registry.test.ts index aa9c38d7ba..d4f8637836 100644 --- a/backend/src/__tests__/agent-registry.test.ts +++ b/backend/src/__tests__/agent-registry.test.ts @@ -26,6 +26,13 @@ import type { DynamicAgentTemplate } from '@codebuff/common/types/dynamic-agent- import type { ProjectFileContext } from '@codebuff/common/util/file' import type { Logger } from '@codebuff/types/logger' +const logger: Logger = { + debug: () => {}, + error: () => {}, + info: () => {}, + warn: () => {}, +} + // Create mock static templates that will be used by the agent registry const mockStaticTemplates: Record = { base: { @@ -101,16 +108,6 @@ describe('Agent Registry', () => { desc: (field: any) => ({ type: 'desc', field }), eq: (field: any, value: any) => ({ type: 'eq', field, value }), })) - - // Mock logger - mockModule('@codebuff/backend/util/logger', () => ({ - logger: { - debug: () => {}, - log: () => {}, - error: () => {}, - warn: () => {}, - }, - })) }) let mockFileContext: ProjectFileContext @@ -150,16 +147,6 @@ describe('Agent Registry', () => { desc: (field: any) => ({ type: 'desc', field }), eq: (field: any, value: any) => ({ type: 'eq', field, value }), })) - - // Mock logger - mockModule('@codebuff/backend/util/logger', () => ({ - logger: { - debug: () => {}, - log: () => {}, - error: () => {}, - warn: () => {}, - }, - })) }) beforeEach(async () => { @@ -253,33 +240,80 @@ describe('Agent Registry', () => { } as AgentTemplate, } - const result = await getAgentTemplate('my-agent', localAgents) + const logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, + } + const result = await getAgentTemplate({ + agentId: 'my-agent', + localAgentTemplates: localAgents, + logger, + }) expect(result).toBeTruthy() expect(result?.id).toBe('my-agent') }) it('should handle agent IDs with publisher but no version', async () => { - const result = await getAgentTemplate('publisher/agent-name', {}) + const logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, + } + const result = await getAgentTemplate({ + agentId: 'publisher/agent-name', + localAgentTemplates: {}, + logger, + }) expect(result).toBeNull() }) it('should handle agent IDs with publisher and version', async () => { - const result = await getAgentTemplate('publisher/agent-name@1.0.0', {}) + const logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, + } + const result = await getAgentTemplate({ + agentId: 'publisher/agent-name@1.0.0', + localAgentTemplates: {}, + logger, + }) expect(result).toBeNull() }) it('should return null for invalid agent ID formats', async () => { - const result = await getAgentTemplate( - 'invalid/format/with/too/many/slashes', - {}, - ) + const logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, + } + const result = await getAgentTemplate({ + agentId: 'invalid/format/with/too/many/slashes', + localAgentTemplates: {}, + logger, + }) expect(result).toBeNull() }) }) describe('fetchAgentFromDatabase', () => { it('should return null when agent not found in database', async () => { - const result = await getAgentTemplate('nonexistent/agent@1.0.0', {}) + const logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, + } + const result = await getAgentTemplate({ + agentId: 'nonexistent/agent@1.0.0', + localAgentTemplates: {}, + logger, + }) expect(result).toBeNull() }) @@ -317,10 +351,17 @@ describe('Agent Registry', () => { }) as any, ) - const result = await getAgentTemplate( - 'test-publisher/test-agent@1.0.0', - {}, - ) + const logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, + } + const result = await getAgentTemplate({ + agentId: 'test-publisher/test-agent@1.0.0', + localAgentTemplates: {}, + logger, + }) expect(result).toBeTruthy() expect(result?.id).toBe('test-publisher/test-agent@1.0.0') }) @@ -347,7 +388,17 @@ describe('Agent Registry', () => { } as AgentTemplate, } - const result = await getAgentTemplate('test-agent', localAgents) + const logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, + } + const result = await getAgentTemplate({ + agentId: 'test-agent', + localAgentTemplates: localAgents, + logger, + }) expect(result).toBeTruthy() expect(result?.displayName).toBe('Local Test Agent') }) @@ -387,18 +438,20 @@ describe('Agent Registry', () => { ) // First call - should hit database - const result1 = await getAgentTemplate( - 'test-publisher/cached-agent@1.0.0', - {}, - ) + const result1 = await getAgentTemplate({ + agentId: 'test-publisher/cached-agent@1.0.0', + localAgentTemplates: {}, + logger, + }) expect(result1).toBeTruthy() expect(selectSpy).toHaveBeenCalledTimes(1) // Second call - should use cache - const result2 = await getAgentTemplate( - 'test-publisher/cached-agent@1.0.0', - {}, - ) + const result2 = await getAgentTemplate({ + agentId: 'test-publisher/cached-agent@1.0.0', + localAgentTemplates: {}, + logger, + }) expect(result2).toBeTruthy() expect(result2?.displayName).toBe('Cached Agent') expect(selectSpy).toHaveBeenCalledTimes(1) @@ -426,7 +479,13 @@ describe('Agent Registry', () => { }, } - const result = assembleLocalAgentTemplates(fileContext) + const logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, + } + const result = assembleLocalAgentTemplates({ fileContext, logger }) // Should have dynamic template expect(result.agentTemplates).toHaveProperty('custom-agent') @@ -450,7 +509,13 @@ describe('Agent Registry', () => { }, } - const result = assembleLocalAgentTemplates(fileContext) + const logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, + } + const result = assembleLocalAgentTemplates({ fileContext, logger }) // Should not have invalid template expect(result.agentTemplates).not.toHaveProperty('invalid-agent') @@ -465,7 +530,13 @@ describe('Agent Registry', () => { agentTemplates: {}, } - const result = assembleLocalAgentTemplates(fileContext) + const logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, + } + const result = assembleLocalAgentTemplates({ fileContext, logger }) // Should have no validation errors expect(result.validationErrors).toHaveLength(0) @@ -511,35 +582,83 @@ describe('Agent Registry', () => { ) // First call - should hit database and populate cache - await getAgentTemplate('test-publisher/cache-test-agent@1.0.0', {}) + const logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, + } + await getAgentTemplate({ + agentId: 'test-publisher/cache-test-agent@1.0.0', + localAgentTemplates: {}, + logger, + }) expect(selectSpy).toHaveBeenCalledTimes(1) // Second call - should use cache - await getAgentTemplate('test-publisher/cache-test-agent@1.0.0', {}) + await getAgentTemplate({ + agentId: 'test-publisher/cache-test-agent@1.0.0', + localAgentTemplates: {}, + logger, + }) expect(selectSpy).toHaveBeenCalledTimes(1) // Clear cache clearDatabaseCache() // Third call - should hit database again after cache clear - await getAgentTemplate('test-publisher/cache-test-agent@1.0.0', {}) + await getAgentTemplate({ + agentId: 'test-publisher/cache-test-agent@1.0.0', + localAgentTemplates: {}, + logger, + }) expect(selectSpy).toHaveBeenCalledTimes(2) }) }) describe('edge cases', () => { it('should handle empty agent ID', async () => { - const result = await getAgentTemplate('', {}) + const logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, + } + const result = await getAgentTemplate({ + agentId: '', + localAgentTemplates: {}, + logger, + }) expect(result).toBeNull() }) it('should handle agent ID with multiple @ symbols', async () => { - const result = await getAgentTemplate('publisher/agent@1.0.0@extra', {}) + const logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, + } + const result = await getAgentTemplate({ + agentId: 'publisher/agent@1.0.0@extra', + localAgentTemplates: {}, + logger, + }) expect(result).toBeNull() }) it('should handle agent ID with only @ symbol', async () => { - const result = await getAgentTemplate('publisher/agent@', {}) + const logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, + } + const result = await getAgentTemplate({ + agentId: 'publisher/agent@', + localAgentTemplates: {}, + logger, + }) expect(result).toBeNull() }) @@ -549,7 +668,17 @@ describe('Agent Registry', () => { throw new Error('Database connection failed') }) - const result = await getAgentTemplate('publisher/agent@1.0.0', {}) + const logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, + } + const result = await getAgentTemplate({ + agentId: 'publisher/agent@1.0.0', + localAgentTemplates: {}, + logger, + }) expect(result).toBeNull() }) @@ -579,10 +708,17 @@ describe('Agent Registry', () => { }) as any, ) - const result = await getAgentTemplate( - 'publisher/malformed-agent@1.0.0', - {}, - ) + const logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, + } + const result = await getAgentTemplate({ + agentId: 'publisher/malformed-agent@1.0.0', + localAgentTemplates: {}, + logger, + }) expect(result).toBeNull() }) }) diff --git a/backend/src/__tests__/cost-aggregation.integration.test.ts b/backend/src/__tests__/cost-aggregation.integration.test.ts index 01ad5a45d7..2b52da0dd1 100644 --- a/backend/src/__tests__/cost-aggregation.integration.test.ts +++ b/backend/src/__tests__/cost-aggregation.integration.test.ts @@ -229,7 +229,7 @@ describe('Cost Aggregation Integration Tests', () => { // Mock getAgentTemplate to return our mock templates spyOn(agentRegistry, 'getAgentTemplate').mockImplementation( - async (agentId, localAgentTemplates) => { + async ({ agentId, localAgentTemplates }) => { return localAgentTemplates[agentId] || null }, ) diff --git a/backend/src/__tests__/cost-aggregation.test.ts b/backend/src/__tests__/cost-aggregation.test.ts index cadaade2ba..5d75378d68 100644 --- a/backend/src/__tests__/cost-aggregation.test.ts +++ b/backend/src/__tests__/cost-aggregation.test.ts @@ -18,8 +18,16 @@ import { handleSpawnAgents } from '../tools/handlers/tool/spawn-agents' import type { AgentState } from '@codebuff/common/types/session-state' import type { ProjectFileContext } from '@codebuff/common/util/file' +import type { Logger } from '@codebuff/types/logger' import type { WebSocket } from 'ws' +const logger: Logger = { + debug: () => {}, + error: () => {}, + info: () => {}, + warn: () => {}, +} + const mockFileContext: ProjectFileContext = { projectRoot: '/test', cwd: '/test', @@ -187,6 +195,7 @@ describe('Cost Aggregation System', () => { writeToClient: () => {}, getLatestState: () => ({ messages: [] }), state: mockValidatedState, + logger, }) await result.result @@ -266,6 +275,7 @@ describe('Cost Aggregation System', () => { writeToClient: () => {}, getLatestState: () => ({ messages: [] }), state: mockValidatedState, + logger, }) await result.result @@ -422,6 +432,7 @@ describe('Cost Aggregation System', () => { writeToClient: () => {}, getLatestState: () => ({ messages: [] }), state: mockValidatedState, + logger, }) await result.result diff --git a/backend/src/__tests__/generate-diffs-prompt.test.ts b/backend/src/__tests__/generate-diffs-prompt.test.ts index 7cbfa78255..cfc1d9cb44 100644 --- a/backend/src/__tests__/generate-diffs-prompt.test.ts +++ b/backend/src/__tests__/generate-diffs-prompt.test.ts @@ -2,6 +2,15 @@ import { expect, describe, it } from 'bun:test' import { parseAndGetDiffBlocksSingleFile } from '../generate-diffs-prompt' +import type { Logger } from '@codebuff/types/logger' + +const logger: Logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, +} + describe('parseAndGetDiffBlocksSingleFile', () => { it('should parse diff blocks with newline before closing marker', () => { const oldContent = 'function test() {\n return true;\n}\n' @@ -16,7 +25,11 @@ function test() { } >>>>>>> REPLACE` - const result = parseAndGetDiffBlocksSingleFile(newContent, oldContent) + const result = parseAndGetDiffBlocksSingleFile({ + newContent, + oldFileContent: oldContent, + logger, + }) console.log(JSON.stringify({ result })) expect(result.diffBlocks.length).toBe(1) @@ -41,7 +54,11 @@ function test() { return true; }>>>>>>> REPLACE` - const result = parseAndGetDiffBlocksSingleFile(newContent, oldContent) + const result = parseAndGetDiffBlocksSingleFile({ + newContent, + oldFileContent: oldContent, + logger, + }) expect(result.diffBlocks.length).toBe(1) expect(result.diffBlocksThatDidntMatch.length).toBe(0) @@ -90,7 +107,11 @@ function subtract(a, b) { } >>>>>>> REPLACE` - const result = parseAndGetDiffBlocksSingleFile(newContent, oldContent) + const result = parseAndGetDiffBlocksSingleFile({ + newContent, + oldFileContent: oldContent, + logger, + }) expect(result.diffBlocks.length).toBe(2) expect(result.diffBlocksThatDidntMatch.length).toBe(0) @@ -114,7 +135,11 @@ function subtract(a, b) { ======= >>>>>>> REPLACE` - const result = parseAndGetDiffBlocksSingleFile(newContent, oldContent) + const result = parseAndGetDiffBlocksSingleFile({ + newContent, + oldFileContent: oldContent, + logger, + }) expect(result.diffBlocks.length).toBe(1) expect(result.diffBlocksThatDidntMatch.length).toBe(0) diff --git a/backend/src/__tests__/live-user-inputs.test.ts b/backend/src/__tests__/live-user-inputs.test.ts index 89554b5984..1bc1a6cc51 100644 --- a/backend/src/__tests__/live-user-inputs.test.ts +++ b/backend/src/__tests__/live-user-inputs.test.ts @@ -10,6 +10,15 @@ import { resetLiveUserInputsState, } from '../live-user-inputs' +import type { Logger } from '@codebuff/types/logger' + +const logger: Logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, +} + describe('live-user-inputs', () => { beforeEach(() => { // Clear any existing state before each test @@ -54,7 +63,7 @@ describe('live-user-inputs', () => { startUserInput({ userId: 'user-1', userInputId: 'input-123' }) startUserInput({ userId: 'user-1', userInputId: 'input-456' }) - cancelUserInput({ userId: 'user-1', userInputId: 'input-123' }) + cancelUserInput({ userId: 'user-1', userInputId: 'input-123', logger }) const liveInputs = getLiveUserInputIds('user-1') expect(liveInputs).toEqual(['input-456']) @@ -63,7 +72,7 @@ describe('live-user-inputs', () => { it('should remove user from tracking when all inputs cancelled', () => { startUserInput({ userId: 'user-1', userInputId: 'input-123' }) - cancelUserInput({ userId: 'user-1', userInputId: 'input-123' }) + cancelUserInput({ userId: 'user-1', userInputId: 'input-123', logger }) const liveInputs = getLiveUserInputIds('user-1') expect(liveInputs).toBeUndefined() @@ -74,7 +83,11 @@ describe('live-user-inputs', () => { // Should not throw expect(() => { - cancelUserInput({ userId: 'user-1', userInputId: 'input-nonexistent' }) + cancelUserInput({ + userId: 'user-1', + userInputId: 'input-nonexistent', + logger, + }) }).not.toThrow() const liveInputs = getLiveUserInputIds('user-1') @@ -84,7 +97,11 @@ describe('live-user-inputs', () => { it('should handle cancelling for non-existent user gracefully', () => { // Should not throw expect(() => { - cancelUserInput({ userId: 'user-nonexistent', userInputId: 'input-123' }) + cancelUserInput({ + userId: 'user-nonexistent', + userInputId: 'input-123', + logger, + }) }).not.toThrow() }) }) @@ -94,7 +111,7 @@ describe('live-user-inputs', () => { // Note: Testing the actual behavior requires integration with the constants module // For unit testing, we'll test the function directly startUserInput({ userId: 'user-1', userInputId: 'input-123' }) - cancelUserInput({ userId: 'user-1', userInputId: 'input-123' }) // This simulates the behavior when async agents disabled + cancelUserInput({ userId: 'user-1', userInputId: 'input-123', logger }) // This simulates the behavior when async agents disabled const liveInputs = getLiveUserInputIds('user-1') expect(liveInputs).toBeUndefined() @@ -116,7 +133,11 @@ describe('live-user-inputs', () => { startUserInput({ userId: 'user-1', userInputId: 'input-123' }) setSessionConnected('session-1', true) - const isLive = checkLiveUserInput('user-1', 'input-123', 'session-1') + const isLive = checkLiveUserInput({ + userId: 'user-1', + userInputId: 'input-123', + clientSessionId: 'session-1', + }) expect(isLive).toBe(true) }) @@ -124,29 +145,33 @@ describe('live-user-inputs', () => { startUserInput({ userId: 'user-1', userInputId: 'input-123' }) setSessionConnected('session-1', true) - const isLive = checkLiveUserInput( - 'user-1', - 'input-123-async-agent', - 'session-1', - ) + const isLive = checkLiveUserInput({ + userId: 'user-1', + userInputId: 'input-123-async-agent', + clientSessionId: 'session-1', + }) expect(isLive).toBe(true) }) it('should return false for non-existent user', () => { setSessionConnected('session-1', true) - const isLive = checkLiveUserInput( - 'user-nonexistent', - 'input-123', - 'session-1', - ) + const isLive = checkLiveUserInput({ + userId: 'user-nonexistent', + userInputId: 'input-123', + clientSessionId: 'session-1', + }) expect(isLive).toBe(false) }) it('should return false for undefined user', () => { setSessionConnected('session-1', true) - const isLive = checkLiveUserInput(undefined, 'input-123', 'session-1') + const isLive = checkLiveUserInput({ + userId: undefined, + userInputId: 'input-123', + clientSessionId: 'session-1', + }) expect(isLive).toBe(false) }) @@ -154,7 +179,11 @@ describe('live-user-inputs', () => { startUserInput({ userId: 'user-1', userInputId: 'input-123' }) setSessionConnected('session-1', false) - const isLive = checkLiveUserInput('user-1', 'input-123', 'session-1') + const isLive = checkLiveUserInput({ + userId: 'user-1', + userInputId: 'input-123', + clientSessionId: 'session-1', + }) expect(isLive).toBe(false) }) @@ -162,14 +191,22 @@ describe('live-user-inputs', () => { startUserInput({ userId: 'user-1', userInputId: 'input-123' }) setSessionConnected('session-1', true) - const isLive = checkLiveUserInput('user-1', 'input-456', 'session-1') + const isLive = checkLiveUserInput({ + userId: 'user-1', + userInputId: 'input-456', + clientSessionId: 'session-1', + }) expect(isLive).toBe(false) }) it('should return true when live user input check is disabled', () => { disableLiveUserInputCheck() - const isLive = checkLiveUserInput('user-1', 'input-123', 'session-1') + const isLive = checkLiveUserInput({ + userId: 'user-1', + userInputId: 'input-123', + clientSessionId: 'session-1', + }) expect(isLive).toBe(true) }) }) @@ -179,7 +216,11 @@ describe('live-user-inputs', () => { setSessionConnected('session-1', true) startUserInput({ userId: 'user-1', userInputId: 'input-123' }) - const isLive = checkLiveUserInput('user-1', 'input-123', 'session-1') + const isLive = checkLiveUserInput({ + userId: 'user-1', + userInputId: 'input-123', + clientSessionId: 'session-1', + }) expect(isLive).toBe(true) }) @@ -188,11 +229,23 @@ describe('live-user-inputs', () => { startUserInput({ userId: 'user-1', userInputId: 'input-123' }) // First verify it's connected - expect(checkLiveUserInput('user-1', 'input-123', 'session-1')).toBe(true) + expect( + checkLiveUserInput({ + userId: 'user-1', + userInputId: 'input-123', + clientSessionId: 'session-1', + }), + ).toBe(true) // Then disconnect setSessionConnected('session-1', false) - expect(checkLiveUserInput('user-1', 'input-123', 'session-1')).toBe(false) + expect( + checkLiveUserInput({ + userId: 'user-1', + userInputId: 'input-123', + clientSessionId: 'session-1', + }), + ).toBe(false) }) it('should handle multiple sessions independently', () => { @@ -201,8 +254,20 @@ describe('live-user-inputs', () => { startUserInput({ userId: 'user-1', userInputId: 'input-123' }) - expect(checkLiveUserInput('user-1', 'input-123', 'session-1')).toBe(true) - expect(checkLiveUserInput('user-1', 'input-123', 'session-2')).toBe(false) + expect( + checkLiveUserInput({ + userId: 'user-1', + userInputId: 'input-123', + clientSessionId: 'session-1', + }), + ).toBe(true) + expect( + checkLiveUserInput({ + userId: 'user-1', + userInputId: 'input-123', + clientSessionId: 'session-2', + }), + ).toBe(false) }) }) @@ -233,14 +298,26 @@ describe('live-user-inputs', () => { startUserInput({ userId: 'user-1', userInputId: 'input-123' }) // Verify input is live - expect(checkLiveUserInput('user-1', 'input-123', 'session-1')).toBe(true) + expect( + checkLiveUserInput({ + userId: 'user-1', + userInputId: 'input-123', + clientSessionId: 'session-1', + }), + ).toBe(true) expect(getLiveUserInputIds('user-1')).toEqual(['input-123']) // End user input - cancelUserInput({ userId: 'user-1', userInputId: 'input-123' }) + cancelUserInput({ userId: 'user-1', userInputId: 'input-123', logger }) // Verify input is no longer live - expect(checkLiveUserInput('user-1', 'input-123', 'session-1')).toBe(false) + expect( + checkLiveUserInput({ + userId: 'user-1', + userInputId: 'input-123', + clientSessionId: 'session-1', + }), + ).toBe(false) expect(getLiveUserInputIds('user-1')).toBeUndefined() }) @@ -250,13 +327,25 @@ describe('live-user-inputs', () => { startUserInput({ userId: 'user-1', userInputId: 'input-123' }) // Verify input is live - expect(checkLiveUserInput('user-1', 'input-123', 'session-1')).toBe(true) + expect( + checkLiveUserInput({ + userId: 'user-1', + userInputId: 'input-123', + clientSessionId: 'session-1', + }), + ).toBe(true) // Disconnect session setSessionConnected('session-1', false) // Input should no longer be considered live - expect(checkLiveUserInput('user-1', 'input-123', 'session-1')).toBe(false) + expect( + checkLiveUserInput({ + userId: 'user-1', + userInputId: 'input-123', + clientSessionId: 'session-1', + }), + ).toBe(false) // But input ID should still exist (for potential reconnection) expect(getLiveUserInputIds('user-1')).toEqual(['input-123']) @@ -268,15 +357,39 @@ describe('live-user-inputs', () => { startUserInput({ userId: 'user-1', userInputId: 'input-123' }) startUserInput({ userId: 'user-1', userInputId: 'input-456' }) - expect(checkLiveUserInput('user-1', 'input-123', 'session-1')).toBe(true) - expect(checkLiveUserInput('user-1', 'input-456', 'session-1')).toBe(true) + expect( + checkLiveUserInput({ + userId: 'user-1', + userInputId: 'input-123', + clientSessionId: 'session-1', + }), + ).toBe(true) + expect( + checkLiveUserInput({ + userId: 'user-1', + userInputId: 'input-456', + clientSessionId: 'session-1', + }), + ).toBe(true) expect(getLiveUserInputIds('user-1')).toEqual(['input-123', 'input-456']) // Cancel one input - cancelUserInput({ userId: 'user-1', userInputId: 'input-123' }) - - expect(checkLiveUserInput('user-1', 'input-123', 'session-1')).toBe(false) - expect(checkLiveUserInput('user-1', 'input-456', 'session-1')).toBe(true) + cancelUserInput({ userId: 'user-1', userInputId: 'input-123', logger }) + + expect( + checkLiveUserInput({ + userId: 'user-1', + userInputId: 'input-123', + clientSessionId: 'session-1', + }), + ).toBe(false) + expect( + checkLiveUserInput({ + userId: 'user-1', + userInputId: 'input-456', + clientSessionId: 'session-1', + }), + ).toBe(true) expect(getLiveUserInputIds('user-1')).toEqual(['input-456']) }) }) diff --git a/backend/src/__tests__/loop-agent-steps.test.ts b/backend/src/__tests__/loop-agent-steps.test.ts index c10e7ad658..158eca8ecd 100644 --- a/backend/src/__tests__/loop-agent-steps.test.ts +++ b/backend/src/__tests__/loop-agent-steps.test.ts @@ -18,6 +18,7 @@ import { mock, spyOn, } from 'bun:test' +import { z } from 'zod/v4' import { withAppContext } from '../context/app-context' import { loopAgentSteps } from '../run-agent-step' @@ -25,11 +26,20 @@ import { clearAgentGeneratorCache } from '../run-programmatic-step' import { mockFileContext, MockWebSocket } from './test-utils' import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk' +import type { getAgentTemplate } from '../templates/agent-registry' import type { AgentTemplate } from '../templates/types' import type { StepGenerator } from '@codebuff/common/types/agent-template' import type { AgentState } from '@codebuff/common/types/session-state' +import type { ParamsOf } from '@codebuff/types/common' +import type { Logger } from '@codebuff/types/logger' import type { WebSocket } from 'ws' -import { z } from 'zod/v4' + +const logger: Logger = { + debug: () => {}, + error: () => {}, + info: () => {}, + warn: () => {}, +} describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => { let mockTemplate: AgentTemplate @@ -37,8 +47,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => let llmCallCount: number const runLoopAgentStepsWithContext = async ( - ws: WebSocket, - options: Parameters[1], + options: ParamsOf, ) => { return await withAppContext( { @@ -49,22 +58,11 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => currentUserId: options.userId, processedRepoId: 'test-repo', }, - async () => loopAgentSteps(ws, options), + async () => loopAgentSteps(options), ) } beforeAll(() => { - // Mock logger - mockModule('@codebuff/backend/util/logger', () => ({ - logger: { - debug: () => {}, - error: () => {}, - info: () => {}, - warn: () => {}, - }, - withLoggerContext: async (context: any, fn: () => Promise) => fn(), - })) - // Mock bigquery mockModule('@codebuff/bigquery', () => ({ insertTrace: () => {}, @@ -72,8 +70,11 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => // Mock agent registry mockModule('@codebuff/backend/templates/agent-registry', () => ({ - getAgentTemplate: async (agentType: string, localTemplates: any) => { - return localTemplates[agentType] || mockTemplate + getAgentTemplate: async ({ + agentId, + localAgentTemplates, + }: ParamsOf) => { + return localAgentTemplates[agentId] || mockTemplate }, })) @@ -100,7 +101,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => }) beforeEach(() => { - clearAgentGeneratorCache() + clearAgentGeneratorCache({ logger }) llmCallCount = 0 @@ -171,7 +172,7 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => }) afterEach(() => { - clearAgentGeneratorCache() + clearAgentGeneratorCache({ logger }) mock.restore() }) @@ -204,22 +205,21 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => 'test-agent': mockTemplate, } - const result = await runLoopAgentStepsWithContext( - new MockWebSocket() as unknown as WebSocket, - { - userInputId: 'test-user-input', - agentType: 'test-agent', - agentState: mockAgentState, - prompt: 'Test prompt', - params: undefined, - fingerprintId: 'test-fingerprint', - fileContext: mockFileContext, - localAgentTemplates, - userId: TEST_USER_ID, - clientSessionId: 'test-session', - onResponseChunk: () => {}, - }, - ) + const result = await runLoopAgentStepsWithContext({ + ws: new MockWebSocket() as unknown as WebSocket, + userInputId: 'test-user-input', + agentType: 'test-agent', + agentState: mockAgentState, + prompt: 'Test prompt', + params: undefined, + fingerprintId: 'test-fingerprint', + fileContext: mockFileContext, + localAgentTemplates, + userId: TEST_USER_ID, + clientSessionId: 'test-session', + onResponseChunk: () => {}, + logger, + }) console.log(`LLM calls made: ${llmCallCount}`) console.log(`Step count: ${stepCount}`) @@ -251,22 +251,21 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => 'test-agent': mockTemplate, } - const result = await runLoopAgentStepsWithContext( - new MockWebSocket() as unknown as WebSocket, - { - userInputId: 'test-user-input', - agentType: 'test-agent', - agentState: mockAgentState, - prompt: 'Test prompt', - params: undefined, - fingerprintId: 'test-fingerprint', - fileContext: mockFileContext, - localAgentTemplates, - userId: TEST_USER_ID, - clientSessionId: 'test-session', - onResponseChunk: () => {}, - }, - ) + const result = await runLoopAgentStepsWithContext({ + ws: new MockWebSocket() as unknown as WebSocket, + userInputId: 'test-user-input', + agentType: 'test-agent', + agentState: mockAgentState, + prompt: 'Test prompt', + params: undefined, + fingerprintId: 'test-fingerprint', + fileContext: mockFileContext, + localAgentTemplates, + userId: TEST_USER_ID, + clientSessionId: 'test-session', + onResponseChunk: () => {}, + logger, + }) // Should NOT call LLM since the programmatic agent ended with end_turn expect(llmCallCount).toBe(0) @@ -300,22 +299,21 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => 'test-agent': mockTemplate, } - const result = await runLoopAgentStepsWithContext( - new MockWebSocket() as unknown as WebSocket, - { - userInputId: 'test-user-input', - agentType: 'test-agent', - agentState: mockAgentState, - prompt: 'Test execution order', - params: undefined, - fingerprintId: 'test-fingerprint', - fileContext: mockFileContext, - localAgentTemplates, - userId: TEST_USER_ID, - clientSessionId: 'test-session', - onResponseChunk: () => {}, - }, - ) + const result = await runLoopAgentStepsWithContext({ + ws: new MockWebSocket() as unknown as WebSocket, + userInputId: 'test-user-input', + agentType: 'test-agent', + agentState: mockAgentState, + prompt: 'Test execution order', + params: undefined, + fingerprintId: 'test-fingerprint', + fileContext: mockFileContext, + localAgentTemplates, + userId: TEST_USER_ID, + clientSessionId: 'test-session', + onResponseChunk: () => {}, + logger, + }) // Verify execution order: // 1. Programmatic step function was called once (creates generator) @@ -348,22 +346,21 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => 'test-agent': mockTemplate, } - const result = await runLoopAgentStepsWithContext( - new MockWebSocket() as unknown as WebSocket, - { - userInputId: 'test-user-input', - agentType: 'test-agent', - agentState: mockAgentState, - prompt: 'Test STEP_ALL behavior', - params: undefined, - fingerprintId: 'test-fingerprint', - fileContext: mockFileContext, - localAgentTemplates, - userId: TEST_USER_ID, - clientSessionId: 'test-session', - onResponseChunk: () => {}, - }, - ) + const result = await runLoopAgentStepsWithContext({ + ws: new MockWebSocket() as unknown as WebSocket, + userInputId: 'test-user-input', + agentType: 'test-agent', + agentState: mockAgentState, + prompt: 'Test STEP_ALL behavior', + params: undefined, + fingerprintId: 'test-fingerprint', + fileContext: mockFileContext, + localAgentTemplates, + userId: TEST_USER_ID, + clientSessionId: 'test-session', + onResponseChunk: () => {}, + logger, + }) expect(stepCount).toBe(1) // Generator function called once expect(llmCallCount).toBe(1) // LLM should be called once @@ -389,22 +386,21 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => 'test-agent': mockTemplate, } - const result = await runLoopAgentStepsWithContext( - new MockWebSocket() as unknown as WebSocket, - { - userInputId: 'test-user-input', - agentType: 'test-agent', - agentState: mockAgentState, - prompt: 'Test no LLM call', - params: undefined, - fingerprintId: 'test-fingerprint', - fileContext: mockFileContext, - localAgentTemplates, - userId: TEST_USER_ID, - clientSessionId: 'test-session', - onResponseChunk: () => {}, - }, - ) + const result = await runLoopAgentStepsWithContext({ + ws: new MockWebSocket() as unknown as WebSocket, + userInputId: 'test-user-input', + agentType: 'test-agent', + agentState: mockAgentState, + prompt: 'Test no LLM call', + params: undefined, + fingerprintId: 'test-fingerprint', + fileContext: mockFileContext, + localAgentTemplates, + userId: TEST_USER_ID, + clientSessionId: 'test-session', + onResponseChunk: () => {}, + logger, + }) expect(llmCallCount).toBe(0) // No LLM calls should be made expect(result.agentState).toBeDefined() @@ -422,22 +418,21 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => 'test-agent': llmOnlyTemplate, } - const result = await runLoopAgentStepsWithContext( - new MockWebSocket() as unknown as WebSocket, - { - userInputId: 'test-user-input', - agentType: 'test-agent', - agentState: mockAgentState, - prompt: 'Test LLM-only agent', - params: undefined, - fingerprintId: 'test-fingerprint', - fileContext: mockFileContext, - localAgentTemplates, - userId: TEST_USER_ID, - clientSessionId: 'test-session', - onResponseChunk: () => {}, - }, - ) + const result = await runLoopAgentStepsWithContext({ + ws: new MockWebSocket() as unknown as WebSocket, + userInputId: 'test-user-input', + agentType: 'test-agent', + agentState: mockAgentState, + prompt: 'Test LLM-only agent', + params: undefined, + fingerprintId: 'test-fingerprint', + fileContext: mockFileContext, + localAgentTemplates, + userId: TEST_USER_ID, + clientSessionId: 'test-session', + onResponseChunk: () => {}, + logger, + }) expect(llmCallCount).toBe(1) // LLM should be called once expect(result.agentState).toBeDefined() @@ -457,22 +452,21 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => 'test-agent': mockTemplate, } - const result = await runLoopAgentStepsWithContext( - new MockWebSocket() as unknown as WebSocket, - { - userInputId: 'test-user-input', - agentType: 'test-agent', - agentState: mockAgentState, - prompt: 'Test error handling', - params: undefined, - fingerprintId: 'test-fingerprint', - fileContext: mockFileContext, - localAgentTemplates, - userId: TEST_USER_ID, - clientSessionId: 'test-session', - onResponseChunk: () => {}, - }, - ) + const result = await runLoopAgentStepsWithContext({ + ws: new MockWebSocket() as unknown as WebSocket, + userInputId: 'test-user-input', + agentType: 'test-agent', + agentState: mockAgentState, + prompt: 'Test error handling', + params: undefined, + fingerprintId: 'test-fingerprint', + fileContext: mockFileContext, + localAgentTemplates, + userId: TEST_USER_ID, + clientSessionId: 'test-session', + onResponseChunk: () => {}, + logger, + }) // After programmatic step error, should end turn and not call LLM expect(llmCallCount).toBe(0) @@ -509,22 +503,21 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => 'test-agent': mockTemplate, } - const result = await runLoopAgentStepsWithContext( - new MockWebSocket() as unknown as WebSocket, - { - userInputId: 'test-user-input', - agentType: 'test-agent', - agentState: mockAgentState, - prompt: 'Test multiple STEP interactions', - params: undefined, - fingerprintId: 'test-fingerprint', - fileContext: mockFileContext, - localAgentTemplates, - userId: TEST_USER_ID, - clientSessionId: 'test-session', - onResponseChunk: () => {}, - }, - ) + const result = await runLoopAgentStepsWithContext({ + ws: new MockWebSocket() as unknown as WebSocket, + userInputId: 'test-user-input', + agentType: 'test-agent', + agentState: mockAgentState, + prompt: 'Test multiple STEP interactions', + params: undefined, + fingerprintId: 'test-fingerprint', + fileContext: mockFileContext, + localAgentTemplates, + userId: TEST_USER_ID, + clientSessionId: 'test-session', + onResponseChunk: () => {}, + logger, + }) expect(stepCount).toBe(1) // Generator function called once expect(llmCallCount).toBe(1) // LLM called once after STEP @@ -539,10 +532,10 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => const mockedRunProgrammaticStep = await mockModule( '@codebuff/backend/run-programmatic-step', () => ({ - runProgrammaticStep: async (agentState: any, options: any) => { - runProgrammaticStepCalls.push({ agentState, options }) + runProgrammaticStep: async (params: any) => { + runProgrammaticStepCalls.push(params) // Return default behavior - return { agentState, endTurn: false } + return { agentState: params.agentState, endTurn: false } }, clearAgentGeneratorCache: () => {}, agentIdToStepAll: new Set(), @@ -559,22 +552,21 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => 'test-agent': mockTemplate, } - await runLoopAgentStepsWithContext( - new MockWebSocket() as unknown as WebSocket, - { - userInputId: 'test-user-input', - agentType: 'test-agent', - agentState: mockAgentState, - prompt: 'Test shouldEndTurn to stepsComplete flow', - params: undefined, - fingerprintId: 'test-fingerprint', - fileContext: mockFileContext, - localAgentTemplates, - userId: TEST_USER_ID, - clientSessionId: 'test-session', - onResponseChunk: () => {}, - }, - ) + await runLoopAgentStepsWithContext({ + ws: new MockWebSocket() as unknown as WebSocket, + userInputId: 'test-user-input', + agentType: 'test-agent', + agentState: mockAgentState, + prompt: 'Test shouldEndTurn to stepsComplete flow', + params: undefined, + fingerprintId: 'test-fingerprint', + fileContext: mockFileContext, + localAgentTemplates, + userId: TEST_USER_ID, + clientSessionId: 'test-session', + onResponseChunk: () => {}, + logger, + }) mockedRunProgrammaticStep.clear() @@ -584,10 +576,10 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => expect(runProgrammaticStepCalls).toHaveLength(2) // First call should have stepsComplete: false - expect(runProgrammaticStepCalls[0].options.stepsComplete).toBe(false) + expect(runProgrammaticStepCalls[0].stepsComplete).toBe(false) // Second call should have stepsComplete: true (after end_turn tool was called) - expect(runProgrammaticStepCalls[1].options.stepsComplete).toBe(true) + expect(runProgrammaticStepCalls[1].stepsComplete).toBe(true) }) it('should restart loop when agent finishes without setting required output', async () => { @@ -647,22 +639,21 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => mockAgentState.output = undefined capturedAgentState = mockAgentState - const result = await runLoopAgentStepsWithContext( - new MockWebSocket() as unknown as WebSocket, - { - userInputId: 'test-user-input', - agentType: 'test-agent', - agentState: mockAgentState, - prompt: 'Test output schema validation', - params: undefined, - fingerprintId: 'test-fingerprint', - fileContext: mockFileContext, - localAgentTemplates, - userId: TEST_USER_ID, - clientSessionId: 'test-session', - onResponseChunk: () => {}, - }, - ) + const result = await runLoopAgentStepsWithContext({ + ws: new MockWebSocket() as unknown as WebSocket, + userInputId: 'test-user-input', + agentType: 'test-agent', + agentState: mockAgentState, + prompt: 'Test output schema validation', + params: undefined, + fingerprintId: 'test-fingerprint', + fileContext: mockFileContext, + localAgentTemplates, + userId: TEST_USER_ID, + clientSessionId: 'test-session', + onResponseChunk: () => {}, + logger, + }) // Should call LLM twice: once to try ending without output, once after reminder expect(llmCallNumber).toBe(2) @@ -721,22 +712,21 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => mockAgentState.output = undefined capturedAgentState = mockAgentState - const result = await runLoopAgentStepsWithContext( - new MockWebSocket() as unknown as WebSocket, - { - userInputId: 'test-user-input', - agentType: 'test-agent', - agentState: mockAgentState, - prompt: 'Test with correct output', - params: undefined, - fingerprintId: 'test-fingerprint', - fileContext: mockFileContext, - localAgentTemplates, - userId: TEST_USER_ID, - clientSessionId: 'test-session', - onResponseChunk: () => {}, - }, - ) + const result = await runLoopAgentStepsWithContext({ + ws: new MockWebSocket() as unknown as WebSocket, + userInputId: 'test-user-input', + agentType: 'test-agent', + agentState: mockAgentState, + prompt: 'Test with correct output', + params: undefined, + fingerprintId: 'test-fingerprint', + fileContext: mockFileContext, + localAgentTemplates, + userId: TEST_USER_ID, + clientSessionId: 'test-session', + onResponseChunk: () => {}, + logger, + }) // Should only call LLM once since output was set correctly expect(llmCallNumber).toBe(1) @@ -768,22 +758,21 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => return 'mock-message-id' }) - const result = await runLoopAgentStepsWithContext( - new MockWebSocket() as unknown as WebSocket, - { - userInputId: 'test-user-input', - agentType: 'test-agent', - agentState: mockAgentState, - prompt: 'Test without output schema', - params: undefined, - fingerprintId: 'test-fingerprint', - fileContext: mockFileContext, - localAgentTemplates, - userId: TEST_USER_ID, - clientSessionId: 'test-session', - onResponseChunk: () => {}, - }, - ) + const result = await runLoopAgentStepsWithContext({ + ws: new MockWebSocket() as unknown as WebSocket, + userInputId: 'test-user-input', + agentType: 'test-agent', + agentState: mockAgentState, + prompt: 'Test without output schema', + params: undefined, + fingerprintId: 'test-fingerprint', + fileContext: mockFileContext, + localAgentTemplates, + userId: TEST_USER_ID, + clientSessionId: 'test-session', + onResponseChunk: () => {}, + logger, + }) // Should only call LLM once and end normally expect(llmCallNumber).toBe(1) @@ -837,22 +826,21 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => mockAgentState.output = undefined capturedAgentState = mockAgentState - const result = await runLoopAgentStepsWithContext( - new MockWebSocket() as unknown as WebSocket, - { - userInputId: 'test-user-input', - agentType: 'test-agent', - agentState: mockAgentState, - prompt: 'Test loop continues', - params: undefined, - fingerprintId: 'test-fingerprint', - fileContext: mockFileContext, - localAgentTemplates, - userId: TEST_USER_ID, - clientSessionId: 'test-session', - onResponseChunk: () => {}, - }, - ) + const result = await runLoopAgentStepsWithContext({ + ws: new MockWebSocket() as unknown as WebSocket, + userInputId: 'test-user-input', + agentType: 'test-agent', + agentState: mockAgentState, + prompt: 'Test loop continues', + params: undefined, + fingerprintId: 'test-fingerprint', + fileContext: mockFileContext, + localAgentTemplates, + userId: TEST_USER_ID, + clientSessionId: 'test-session', + onResponseChunk: () => {}, + logger, + }) // Should call LLM twice: once for work, once to set output and end expect(llmCallNumber).toBe(2) diff --git a/backend/src/__tests__/main-prompt.integration.test.ts b/backend/src/__tests__/main-prompt.integration.test.ts index 66f4ced134..d77525561c 100644 --- a/backend/src/__tests__/main-prompt.integration.test.ts +++ b/backend/src/__tests__/main-prompt.integration.test.ts @@ -17,15 +17,22 @@ import * as checkTerminalCommandModule from '../check-terminal-command' import * as requestFilesPrompt from '../find-files/request-files-prompt' import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk' import { mainPrompt } from '../main-prompt' -import { logger } from '../util/logger' import * as websocketAction from '../websockets/websocket-action' import type { PrintModeEvent } from '@codebuff/common/types/print-mode' import type { ProjectFileContext } from '@codebuff/common/util/file' +import type { Logger } from '@codebuff/types/logger' import type { WebSocket } from 'ws' // --- Shared Mocks & Helpers --- +const logger: Logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, +} + class MockWebSocket { send(msg: string) {} close() {} @@ -107,13 +114,6 @@ describe.skip('mainPrompt (Integration)', () => { }) it('should delete a specified function while preserving other code', async () => { - // Mock necessary non-LLM functions - spyOn(logger, 'debug').mockImplementation(() => {}) - spyOn(logger, 'error').mockImplementation(() => {}) - spyOn(logger, 'info').mockImplementation(() => {}) - spyOn(logger, 'warn').mockImplementation(() => {}) - spyOn(requestFilesPrompt, 'requestRelevantFiles').mockResolvedValue([]) - const initialContent = `import { Message } from '@codebuff/common/types/message' import { withCacheControl } from '@codebuff/common/util/messages' diff --git a/backend/src/__tests__/malformed-tool-call.test.ts b/backend/src/__tests__/malformed-tool-call.test.ts index b2ed8db11d..9effde90fb 100644 --- a/backend/src/__tests__/malformed-tool-call.test.ts +++ b/backend/src/__tests__/malformed-tool-call.test.ts @@ -131,6 +131,7 @@ describe('malformed tool call error handling', () => { agentContext: {}, onResponseChunk, fullResponse: '', + logger, }) // Should have tool result errors in the final message history @@ -187,6 +188,7 @@ describe('malformed tool call error handling', () => { agentContext: {}, onResponseChunk, fullResponse: '', + logger, }) // Should have multiple error tool results @@ -233,6 +235,7 @@ describe('malformed tool call error handling', () => { agentContext: {}, onResponseChunk, fullResponse: '', + logger, }) // Should have error in both toolResults and message history @@ -283,6 +286,7 @@ describe('malformed tool call error handling', () => { agentContext: {}, onResponseChunk, fullResponse: '', + logger, }) const toolMessages = result.state.messages.filter( @@ -335,6 +339,7 @@ describe('malformed tool call error handling', () => { agentContext: {}, onResponseChunk, fullResponse: '', + logger, }) const toolMessages = result.state.messages.filter( @@ -389,6 +394,7 @@ describe('malformed tool call error handling', () => { agentContext: {}, onResponseChunk, fullResponse: '', + logger, }) const toolMessages = result.state.messages.filter( diff --git a/backend/src/__tests__/process-str-replace.test.ts b/backend/src/__tests__/process-str-replace.test.ts index d224bc2794..3728b0393a 100644 --- a/backend/src/__tests__/process-str-replace.test.ts +++ b/backend/src/__tests__/process-str-replace.test.ts @@ -1,14 +1,14 @@ -import { - describe, - expect, - it, - spyOn, - beforeEach, - afterEach, - mock, -} from 'bun:test' +import { describe, expect, it, mock } from 'bun:test' import { applyPatch } from 'diff' +import { processStrReplace } from '../process-str-replace' +import { + executeBatchStrReplaces, + benchifyCanFixLanguage, +} from '../tools/batch-str-replace' + +import type { Logger } from '@codebuff/types/logger' + // Mock the benchify module to simulate missing API key mock.module('benchify', () => ({ Benchify: class MockBenchify { @@ -19,15 +19,6 @@ mock.module('benchify', () => ({ }, })) -import { processStrReplace } from '../process-str-replace' -import { mockFileContext } from './test-utils' -import { - executeBatchStrReplaces, - benchifyCanFixLanguage, -} from '../tools/batch-str-replace' - -import type { Logger } from '@codebuff/types/logger' - const logger: Logger = { debug: () => {}, info: () => {}, @@ -283,12 +274,12 @@ describe('processStrReplace', () => { const oldStr = 'const x' const newStr = 'let x' - const result = await processStrReplace({ - path: 'test.ts', - replacements: [{ old: oldStr, new: newStr, allowMultiple: false }], - initialContentPromise: Promise.resolve(initialContent), - logger, - }) + const result = await processStrReplace({ + path: 'test.ts', + replacements: [{ old: oldStr, new: newStr, allowMultiple: false }], + initialContentPromise: Promise.resolve(initialContent), + logger, + }) expect(result).not.toBeNull() expect('error' in result).toBe(true) @@ -563,6 +554,7 @@ describe('Benchify resilience', () => { onResponseChunk: () => {}, state: { messages: [] }, userId: 'test-user', + logger, }) // Should complete without error even when Benchify is unavailable @@ -590,6 +582,7 @@ describe('Benchify resilience', () => { onResponseChunk: () => {}, state: { messages: [] }, userId: 'test-user', + logger, }), ).resolves.toBeUndefined() // Should complete without throwing }) @@ -627,6 +620,7 @@ describe('Benchify resilience', () => { onResponseChunk: () => {}, state: { messages: [] }, userId: 'test-user', + logger, }) // Should complete without throwing an error diff --git a/backend/src/__tests__/prompt-caching-subagents.test.ts b/backend/src/__tests__/prompt-caching-subagents.test.ts index 323797ea14..cbeb32775f 100644 --- a/backend/src/__tests__/prompt-caching-subagents.test.ts +++ b/backend/src/__tests__/prompt-caching-subagents.test.ts @@ -1,15 +1,9 @@ import { TEST_USER_ID } from '@codebuff/common/old-constants' -import { - clearMockedModules, - mockModule, -} from '@codebuff/common/testing/mock-modules' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { spyOn, beforeEach, afterEach, - beforeAll, - afterAll, describe, expect, it, @@ -23,8 +17,16 @@ import * as websocketAction from '../websockets/websocket-action' import type { AgentTemplate } from '../templates/types' import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { ProjectFileContext } from '@codebuff/common/util/file' +import type { Logger } from '@codebuff/types/logger' import type { WebSocket } from 'ws' +const logger: Logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, +} + const mockFileContext: ProjectFileContext = { projectRoot: '/test', cwd: '/test', @@ -62,19 +64,6 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { let mockLocalAgentTemplates: Record let capturedMessages: Message[] = [] - beforeAll(() => { - // Mock logger - mockModule('@codebuff/backend/util/logger', () => ({ - logger: { - debug: () => {}, - error: () => {}, - info: () => {}, - warn: () => {}, - }, - withLoggerContext: async (context: any, fn: () => Promise) => fn(), - })) - }) - beforeEach(() => { capturedMessages = [] @@ -167,16 +156,13 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { mock.restore() }) - afterAll(() => { - clearMockedModules() - }) - it('should inherit parent system prompt when inheritParentSystemPrompt is true', async () => { const sessionState = getInitialSessionState(mockFileContext) const ws = new MockWebSocket() as unknown as WebSocket // Run parent agent first to establish system prompt - const parentResult = await loopAgentSteps(ws, { + const parentResult = await loopAgentSteps({ + ws, userInputId: 'test-parent', prompt: 'Parent task', params: undefined, @@ -188,6 +174,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, + logger, }) // Capture parent's messages which include the system prompt @@ -208,7 +195,8 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { messageHistory: [], } - await loopAgentSteps(ws, { + await loopAgentSteps({ + ws, userInputId: 'test-child', prompt: 'Child task', params: undefined, @@ -221,6 +209,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { clientSessionId: 'test-session', onResponseChunk: () => {}, parentSystemPrompt: parentSystemPrompt, + logger, }) // Verify child uses parent's system prompt @@ -255,7 +244,8 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { mockLocalAgentTemplates['standalone-child'] = standaloneChild // Run parent agent first - const parentResult = await loopAgentSteps(ws, { + const parentResult = await loopAgentSteps({ + ws, userInputId: 'test-parent', prompt: 'Parent task', params: undefined, @@ -267,6 +257,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, + logger, }) const parentMessages = capturedMessages @@ -281,7 +272,8 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { messageHistory: [], } - await loopAgentSteps(ws, { + await loopAgentSteps({ + ws, userInputId: 'test-child', prompt: 'Child task', params: undefined, @@ -294,6 +286,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { clientSessionId: 'test-session', onResponseChunk: () => {}, parentSystemPrompt: parentSystemPrompt, + logger, }) const childMessages = capturedMessages @@ -329,7 +322,8 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { mockLocalAgentTemplates['message-history-child'] = messageHistoryChild // Run parent agent first - await loopAgentSteps(ws, { + await loopAgentSteps({ + ws, userInputId: 'test-parent', prompt: 'Parent task', params: undefined, @@ -341,6 +335,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, + logger, }) const parentMessages = capturedMessages @@ -358,7 +353,8 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { ], } - await loopAgentSteps(ws, { + await loopAgentSteps({ + ws, userInputId: 'test-child', prompt: 'Child task', params: undefined, @@ -371,6 +367,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { clientSessionId: 'test-session', onResponseChunk: () => {}, parentSystemPrompt: parentSystemPrompt, + logger, }) const childMessages = capturedMessages @@ -432,7 +429,8 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { const ws = new MockWebSocket() as unknown as WebSocket // Run parent agent - const parentResult = await loopAgentSteps(ws, { + const parentResult = await loopAgentSteps({ + ws, userInputId: 'test-parent', prompt: 'Parent task', params: undefined, @@ -444,6 +442,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, + logger, }) const parentMessages = capturedMessages @@ -458,7 +457,8 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { messageHistory: [], } - await loopAgentSteps(ws, { + await loopAgentSteps({ + ws, userInputId: 'test-child', prompt: 'Child task', params: undefined, @@ -471,6 +471,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { clientSessionId: 'test-session', onResponseChunk: () => {}, parentSystemPrompt: parentSystemPrompt, + logger, }) const childMessages = capturedMessages @@ -510,7 +511,8 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { mockLocalAgentTemplates['full-inherit-child'] = fullInheritChild // Run parent agent first with some message history - const parentResult = await loopAgentSteps(ws, { + const parentResult = await loopAgentSteps({ + ws, userInputId: 'test-parent', prompt: 'Parent task', params: undefined, @@ -528,6 +530,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { userId: TEST_USER_ID, clientSessionId: 'test-session', onResponseChunk: () => {}, + logger, }) const parentMessages = capturedMessages @@ -545,7 +548,8 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { ], } - await loopAgentSteps(ws, { + await loopAgentSteps({ + ws, userInputId: 'test-child', prompt: 'Child task', params: undefined, @@ -558,6 +562,7 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { clientSessionId: 'test-session', onResponseChunk: () => {}, parentSystemPrompt: parentSystemPrompt, + logger, }) const childMessages = capturedMessages diff --git a/backend/src/__tests__/read-docs-tool.test.ts b/backend/src/__tests__/read-docs-tool.test.ts index 6294b911c1..76d7b0ee76 100644 --- a/backend/src/__tests__/read-docs-tool.test.ts +++ b/backend/src/__tests__/read-docs-tool.test.ts @@ -27,6 +27,13 @@ import * as websocketAction from '../websockets/websocket-action' import type { Logger } from '@codebuff/types/logger' import type { WebSocket } from 'ws' +const logger: Logger = { + debug: () => {}, + error: () => {}, + info: () => {}, + warn: () => {}, +} + function mockAgentStream(content: string | string[]) { spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) { if (typeof content === 'string') { @@ -42,12 +49,6 @@ function mockAgentStream(content: string | string[]) { describe('read_docs tool with researcher agent', () => { // Track all mocked functions to verify they're being used const mockedFunctions: Array<{ name: string; spy: any }> = [] - const logger: Logger = { - debug: () => {}, - error: () => {}, - info: () => {}, - warn: () => {}, - } beforeEach(() => { // Clear tracked mocks @@ -291,34 +292,33 @@ describe('read_docs tool with researcher agent', () => { ...sessionState.mainAgentState, agentType: 'researcher' as const, } - const { agentTemplates } = assembleLocalAgentTemplates( - mockFileContextWithAgents, - ) + const { agentTemplates } = assembleLocalAgentTemplates({ + fileContext: mockFileContextWithAgents, + logger, + }) - const { agentState: newAgentState } = await runAgentStep( - new MockWebSocket() as unknown as WebSocket, - { - system: 'Test system prompt', - userId: TEST_USER_ID, - userInputId: 'test-input', - clientSessionId: 'test-session', - fingerprintId: 'test-fingerprint', - onResponseChunk: () => {}, - agentType: 'researcher', - fileContext: mockFileContextWithAgents, - localAgentTemplates: agentTemplates, - agentState, - prompt: 'Get React documentation', - params: undefined, - }, - ) + const { agentState: newAgentState } = await runAgentStep({ + ws: new MockWebSocket() as unknown as WebSocket, + system: 'Test system prompt', + userId: TEST_USER_ID, + userInputId: 'test-input', + clientSessionId: 'test-session', + fingerprintId: 'test-fingerprint', + onResponseChunk: () => {}, + agentType: 'researcher', + fileContext: mockFileContextWithAgents, + localAgentTemplates: agentTemplates, + agentState, + prompt: 'Get React documentation', + params: undefined, + logger, + }) - expect(context7Api.fetchContext7LibraryDocumentation).toHaveBeenCalledWith( - 'React', - { - topic: 'hooks', - }, - ) + expect(context7Api.fetchContext7LibraryDocumentation).toHaveBeenCalledWith({ + query: 'React', + topic: 'hooks', + logger: expect.anything(), + }) // Check that the documentation was added to the message history const toolResultMessages = newAgentState.messageHistory.filter( @@ -365,11 +365,13 @@ describe('read_docs tool with researcher agent', () => { ...sessionState.mainAgentState, agentType: 'researcher' as const, } - const { agentTemplates } = assembleLocalAgentTemplates( - mockFileContextWithAgents, - ) + const { agentTemplates } = assembleLocalAgentTemplates({ + fileContext: mockFileContextWithAgents, + logger, + }) - await runAgentStep(new MockWebSocket() as unknown as WebSocket, { + await runAgentStep({ + ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, userInputId: 'test-input', @@ -382,15 +384,15 @@ describe('read_docs tool with researcher agent', () => { agentState, prompt: 'Get React hooks documentation', params: undefined, + logger, }) - expect(context7Api.fetchContext7LibraryDocumentation).toHaveBeenCalledWith( - 'React', - { - topic: 'hooks', - tokens: 5000, - }, - ) + expect(context7Api.fetchContext7LibraryDocumentation).toHaveBeenCalledWith({ + query: 'React', + topic: 'hooks', + tokens: 5000, + logger: expect.anything(), + }) }, 10000) test('should handle case when no documentation is found', async () => { @@ -413,27 +415,27 @@ describe('read_docs tool with researcher agent', () => { ...sessionState.mainAgentState, agentType: 'researcher' as const, } - const { agentTemplates } = assembleLocalAgentTemplates( - mockFileContextWithAgents, - ) + const { agentTemplates } = assembleLocalAgentTemplates({ + fileContext: mockFileContextWithAgents, + logger, + }) - const { agentState: newAgentState } = await runAgentStep( - new MockWebSocket() as unknown as WebSocket, - { - system: 'Test system prompt', - userId: TEST_USER_ID, - userInputId: 'test-input', - clientSessionId: 'test-session', - fingerprintId: 'test-fingerprint', - onResponseChunk: () => {}, - agentType: 'researcher', - fileContext: mockFileContextWithAgents, - localAgentTemplates: agentTemplates, - agentState, - prompt: 'Get documentation for NonExistentLibrary', - params: undefined, - }, - ) + const { agentState: newAgentState } = await runAgentStep({ + ws: new MockWebSocket() as unknown as WebSocket, + system: 'Test system prompt', + userId: TEST_USER_ID, + userInputId: 'test-input', + clientSessionId: 'test-session', + fingerprintId: 'test-fingerprint', + onResponseChunk: () => {}, + agentType: 'researcher', + fileContext: mockFileContextWithAgents, + localAgentTemplates: agentTemplates, + agentState, + prompt: 'Get documentation for NonExistentLibrary', + params: undefined, + logger, + }) // Check that the "no documentation found" message was added const toolResultMessages = newAgentState.messageHistory.filter( @@ -480,27 +482,27 @@ describe('read_docs tool with researcher agent', () => { ...sessionState.mainAgentState, agentType: 'researcher' as const, } - const { agentTemplates } = assembleLocalAgentTemplates( - mockFileContextWithAgents, - ) + const { agentTemplates } = assembleLocalAgentTemplates({ + fileContext: mockFileContextWithAgents, + logger, + }) - const { agentState: newAgentState } = await runAgentStep( - new MockWebSocket() as unknown as WebSocket, - { - system: 'Test system prompt', - userId: TEST_USER_ID, - userInputId: 'test-input', - clientSessionId: 'test-session', - fingerprintId: 'test-fingerprint', - onResponseChunk: () => {}, - agentType: 'researcher', - fileContext: mockFileContextWithAgents, - localAgentTemplates: agentTemplates, - agentState, - prompt: 'Get React documentation', - params: undefined, - }, - ) + const { agentState: newAgentState } = await runAgentStep({ + ws: new MockWebSocket() as unknown as WebSocket, + system: 'Test system prompt', + userId: TEST_USER_ID, + userInputId: 'test-input', + clientSessionId: 'test-session', + fingerprintId: 'test-fingerprint', + onResponseChunk: () => {}, + agentType: 'researcher', + fileContext: mockFileContextWithAgents, + localAgentTemplates: agentTemplates, + agentState, + prompt: 'Get React documentation', + params: undefined, + logger, + }) // Check that the error message was added const toolResultMessages = newAgentState.messageHistory.filter( @@ -546,27 +548,27 @@ describe('read_docs tool with researcher agent', () => { ...sessionState.mainAgentState, agentType: 'researcher' as const, } - const { agentTemplates } = assembleLocalAgentTemplates( - mockFileContextWithAgents, - ) + const { agentTemplates } = assembleLocalAgentTemplates({ + fileContext: mockFileContextWithAgents, + logger, + }) - const { agentState: newAgentState } = await runAgentStep( - new MockWebSocket() as unknown as WebSocket, - { - system: 'Test system prompt', - userId: TEST_USER_ID, - userInputId: 'test-input', - clientSessionId: 'test-session', - fingerprintId: 'test-fingerprint', - onResponseChunk: () => {}, - agentType: 'researcher', - fileContext: mockFileContextWithAgents, - localAgentTemplates: agentTemplates, - agentState, - prompt: 'Get React server components documentation', - params: undefined, - }, - ) + const { agentState: newAgentState } = await runAgentStep({ + ws: new MockWebSocket() as unknown as WebSocket, + system: 'Test system prompt', + userId: TEST_USER_ID, + userInputId: 'test-input', + clientSessionId: 'test-session', + fingerprintId: 'test-fingerprint', + onResponseChunk: () => {}, + agentType: 'researcher', + fileContext: mockFileContextWithAgents, + localAgentTemplates: agentTemplates, + agentState, + prompt: 'Get React server components documentation', + params: undefined, + logger, + }) // Check that the topic is included in the error message const toolResultMessages = newAgentState.messageHistory.filter( @@ -613,27 +615,27 @@ describe('read_docs tool with researcher agent', () => { ...sessionState.mainAgentState, agentType: 'researcher' as const, } - const { agentTemplates } = assembleLocalAgentTemplates( - mockFileContextWithAgents, - ) + const { agentTemplates } = assembleLocalAgentTemplates({ + fileContext: mockFileContextWithAgents, + logger, + }) - const { agentState: newAgentState } = await runAgentStep( - new MockWebSocket() as unknown as WebSocket, - { - system: 'Test system prompt', - userId: TEST_USER_ID, - userInputId: 'test-input', - clientSessionId: 'test-session', - fingerprintId: 'test-fingerprint', - onResponseChunk: () => {}, - agentType: 'researcher', - fileContext: mockFileContextWithAgents, - localAgentTemplates: agentTemplates, - agentState, - prompt: 'Get React documentation', - params: undefined, - }, - ) + const { agentState: newAgentState } = await runAgentStep({ + ws: new MockWebSocket() as unknown as WebSocket, + system: 'Test system prompt', + userId: TEST_USER_ID, + userInputId: 'test-input', + clientSessionId: 'test-session', + fingerprintId: 'test-fingerprint', + onResponseChunk: () => {}, + agentType: 'researcher', + fileContext: mockFileContextWithAgents, + localAgentTemplates: agentTemplates, + agentState, + prompt: 'Get React documentation', + params: undefined, + logger, + }) // Check that the generic error message was added const toolResultMessages = newAgentState.messageHistory.filter( diff --git a/backend/src/__tests__/request-files-prompt.test.ts b/backend/src/__tests__/request-files-prompt.test.ts index b7fdb04459..d11ab037a2 100644 --- a/backend/src/__tests__/request-files-prompt.test.ts +++ b/backend/src/__tests__/request-files-prompt.test.ts @@ -58,6 +58,12 @@ describe('requestRelevantFiles', () => { const mockUserId = 'user1' const mockCostMode: CostMode = 'normal' const mockRepoId = 'owner/repo' + const logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, + } let getCustomFilePickerConfigForOrgSpy: any // Explicitly typed as any @@ -127,17 +133,19 @@ describe('requestRelevantFiles', () => { }) it('should use default file counts and maxFiles when no custom config', async () => { - await OriginalRequestFilesPromptModule.requestRelevantFiles( - { messages: mockMessages, system: mockSystem }, - mockFileContext, - mockAssistantPrompt, - mockAgentStepId, - mockClientSessionId, - mockFingerprintId, - mockUserInputId, - mockUserId, - mockRepoId, - ) + await OriginalRequestFilesPromptModule.requestRelevantFiles({ + messages: mockMessages, + system: mockSystem, + fileContext: mockFileContext, + assistantPrompt: mockAssistantPrompt, + agentStepId: mockAgentStepId, + clientSessionId: mockClientSessionId, + fingerprintId: mockFingerprintId, + userInputId: mockUserInputId, + userId: mockUserId, + repoId: mockRepoId, + logger, + }) expect( geminiWithFallbacksModule.promptFlashWithFallbacks, ).toHaveBeenCalled() @@ -152,17 +160,19 @@ describe('requestRelevantFiles', () => { } getCustomFilePickerConfigForOrgSpy!.mockResolvedValue(customConfig as any) - await OriginalRequestFilesPromptModule.requestRelevantFiles( - { messages: mockMessages, system: mockSystem }, - mockFileContext, - mockAssistantPrompt, - mockAgentStepId, - mockClientSessionId, - mockFingerprintId, - mockUserInputId, - mockUserId, - mockRepoId, - ) + await OriginalRequestFilesPromptModule.requestRelevantFiles({ + messages: mockMessages, + system: mockSystem, + fileContext: mockFileContext, + assistantPrompt: mockAssistantPrompt, + agentStepId: mockAgentStepId, + clientSessionId: mockClientSessionId, + fingerprintId: mockFingerprintId, + userInputId: mockUserInputId, + userId: mockUserId, + repoId: mockRepoId, + logger, + }) expect( geminiWithFallbacksModule.promptFlashWithFallbacks, ).toHaveBeenCalled() @@ -176,17 +186,19 @@ describe('requestRelevantFiles', () => { } getCustomFilePickerConfigForOrgSpy!.mockResolvedValue(customConfig as any) - const result = await OriginalRequestFilesPromptModule.requestRelevantFiles( - { messages: mockMessages, system: mockSystem }, - mockFileContext, - mockAssistantPrompt, - mockAgentStepId, - mockClientSessionId, - mockFingerprintId, - mockUserInputId, - mockUserId, - mockRepoId, - ) + const result = await OriginalRequestFilesPromptModule.requestRelevantFiles({ + messages: mockMessages, + system: mockSystem, + fileContext: mockFileContext, + assistantPrompt: mockAssistantPrompt, + agentStepId: mockAgentStepId, + clientSessionId: mockClientSessionId, + fingerprintId: mockFingerprintId, + userInputId: mockUserInputId, + userId: mockUserId, + repoId: mockRepoId, + logger, + }) expect(result).toBeArray() if (result) { expect(result.length).toBeLessThanOrEqual(3) @@ -200,21 +212,22 @@ describe('requestRelevantFiles', () => { } getCustomFilePickerConfigForOrgSpy!.mockResolvedValue(customConfig as any) - await OriginalRequestFilesPromptModule.requestRelevantFiles( - { messages: mockMessages, system: mockSystem }, - mockFileContext, - mockAssistantPrompt, - mockAgentStepId, - mockClientSessionId, - mockFingerprintId, - mockUserInputId, - mockUserId, - mockRepoId, - ) + await OriginalRequestFilesPromptModule.requestRelevantFiles({ + messages: mockMessages, + system: mockSystem, + fileContext: mockFileContext, + assistantPrompt: mockAssistantPrompt, + agentStepId: mockAgentStepId, + clientSessionId: mockClientSessionId, + fingerprintId: mockFingerprintId, + userInputId: mockUserInputId, + userId: mockUserId, + repoId: mockRepoId, + logger, + }) expect( geminiWithFallbacksModule.promptFlashWithFallbacks, ).toHaveBeenCalledWith( - expect.anything(), expect.objectContaining({ useFinetunedModel: finetunedVertexModels.ft_filepicker_010, }), @@ -228,22 +241,23 @@ describe('requestRelevantFiles', () => { } getCustomFilePickerConfigForOrgSpy!.mockResolvedValue(customConfig as any) - await OriginalRequestFilesPromptModule.requestRelevantFiles( - { messages: mockMessages, system: mockSystem }, - mockFileContext, - mockAssistantPrompt, - mockAgentStepId, - mockClientSessionId, - mockFingerprintId, - mockUserInputId, - mockUserId, - mockRepoId, - ) + await OriginalRequestFilesPromptModule.requestRelevantFiles({ + messages: mockMessages, + system: mockSystem, + fileContext: mockFileContext, + assistantPrompt: mockAssistantPrompt, + agentStepId: mockAgentStepId, + clientSessionId: mockClientSessionId, + fingerprintId: mockFingerprintId, + userInputId: mockUserInputId, + userId: mockUserId, + repoId: mockRepoId, + logger, + }) const expectedModel = finetunedVertexModels.ft_filepicker_010 expect( geminiWithFallbacksModule.promptFlashWithFallbacks, ).toHaveBeenCalledWith( - expect.anything(), expect.objectContaining({ useFinetunedModel: expectedModel, }), diff --git a/backend/src/__tests__/run-agent-step-tools.test.ts b/backend/src/__tests__/run-agent-step-tools.test.ts index b0906e8286..302dc687e6 100644 --- a/backend/src/__tests__/run-agent-step-tools.test.ts +++ b/backend/src/__tests__/run-agent-step-tools.test.ts @@ -113,7 +113,7 @@ describe('runAgentStep - set_output tool', () => { spyOn(aisdk, 'promptAiSdk').mockImplementation(() => Promise.resolve('Test response'), ) - clearAgentGeneratorCache() + clearAgentGeneratorCache({ logger }) }) afterEach(() => { @@ -121,7 +121,7 @@ describe('runAgentStep - set_output tool', () => { }) afterAll(() => { - clearAgentGeneratorCache() + clearAgentGeneratorCache({ logger }) }) class MockWebSocket { @@ -176,23 +176,22 @@ describe('runAgentStep - set_output tool', () => { 'test-set-output-agent': testAgent, } - const result = await runAgentStep( - new MockWebSocket() as unknown as WebSocket, - { - userId: TEST_USER_ID, - userInputId: 'test-input', - clientSessionId: 'test-session', - fingerprintId: 'test-fingerprint', - onResponseChunk: () => {}, - agentType: 'test-set-output-agent', - fileContext: mockFileContext, - localAgentTemplates, - agentState, - prompt: 'Analyze the codebase', - params: undefined, - system: 'Test system prompt', - }, - ) + const result = await runAgentStep({ + ws: new MockWebSocket() as unknown as WebSocket, + userId: TEST_USER_ID, + userInputId: 'test-input', + clientSessionId: 'test-session', + fingerprintId: 'test-fingerprint', + onResponseChunk: () => {}, + agentType: 'test-set-output-agent', + fileContext: mockFileContext, + localAgentTemplates, + agentState, + prompt: 'Analyze the codebase', + params: undefined, + system: 'Test system prompt', + logger, + }) expect(result.agentState.output).toEqual({ message: 'Hi', @@ -219,23 +218,22 @@ describe('runAgentStep - set_output tool', () => { 'test-set-output-agent': testAgent, } - const result = await runAgentStep( - new MockWebSocket() as unknown as WebSocket, - { - userId: TEST_USER_ID, - userInputId: 'test-input', - clientSessionId: 'test-session', - fingerprintId: 'test-fingerprint', - onResponseChunk: () => {}, - agentType: 'test-set-output-agent', - fileContext: mockFileContext, - localAgentTemplates, - agentState, - prompt: 'Analyze the codebase', - params: undefined, - system: 'Test system prompt', - }, - ) + const result = await runAgentStep({ + ws: new MockWebSocket() as unknown as WebSocket, + userId: TEST_USER_ID, + userInputId: 'test-input', + clientSessionId: 'test-session', + fingerprintId: 'test-fingerprint', + onResponseChunk: () => {}, + agentType: 'test-set-output-agent', + fileContext: mockFileContext, + localAgentTemplates, + agentState, + prompt: 'Analyze the codebase', + params: undefined, + system: 'Test system prompt', + logger, + }) expect(result.agentState.output).toEqual({ message: 'Analysis complete', @@ -268,23 +266,22 @@ describe('runAgentStep - set_output tool', () => { 'test-set-output-agent': testAgent, } - const result = await runAgentStep( - new MockWebSocket() as unknown as WebSocket, - { - userId: TEST_USER_ID, - userInputId: 'test-input', - clientSessionId: 'test-session', - fingerprintId: 'test-fingerprint', - onResponseChunk: () => {}, - agentType: 'test-set-output-agent', - fileContext: mockFileContext, - localAgentTemplates, - agentState, - prompt: 'Update the output', - params: undefined, - system: 'Test system prompt', - }, - ) + const result = await runAgentStep({ + ws: new MockWebSocket() as unknown as WebSocket, + userId: TEST_USER_ID, + userInputId: 'test-input', + clientSessionId: 'test-session', + fingerprintId: 'test-fingerprint', + onResponseChunk: () => {}, + agentType: 'test-set-output-agent', + fileContext: mockFileContext, + localAgentTemplates, + agentState, + prompt: 'Update the output', + params: undefined, + system: 'Test system prompt', + logger, + }) expect(result.agentState.output).toEqual({ newField: 'new value', @@ -308,23 +305,22 @@ describe('runAgentStep - set_output tool', () => { 'test-set-output-agent': testAgent, } - const result = await runAgentStep( - new MockWebSocket() as unknown as WebSocket, - { - userId: TEST_USER_ID, - userInputId: 'test-input', - clientSessionId: 'test-session', - fingerprintId: 'test-fingerprint', - onResponseChunk: () => {}, - agentType: 'test-set-output-agent', - fileContext: mockFileContext, - localAgentTemplates, - agentState, - prompt: 'Update with empty object', - params: undefined, - system: 'Test system prompt', - }, - ) + const result = await runAgentStep({ + ws: new MockWebSocket() as unknown as WebSocket, + userId: TEST_USER_ID, + userInputId: 'test-input', + clientSessionId: 'test-session', + fingerprintId: 'test-fingerprint', + onResponseChunk: () => {}, + agentType: 'test-set-output-agent', + fileContext: mockFileContext, + localAgentTemplates, + agentState, + prompt: 'Update with empty object', + params: undefined, + system: 'Test system prompt', + logger, + }) // Should replace with empty object expect(result.agentState.output).toEqual({}) @@ -405,23 +401,22 @@ describe('runAgentStep - set_output tool', () => { const initialMessageCount = agentState.messageHistory.length - const result = await runAgentStep( - new MockWebSocket() as unknown as WebSocket, - { - userId: TEST_USER_ID, - userInputId: 'test-input', - clientSessionId: 'test-session', - fingerprintId: 'test-fingerprint', - onResponseChunk: () => {}, - agentType: 'test-handlesteps-agent', - fileContext: mockFileContext, - localAgentTemplates: mockAgentRegistry, - agentState, - prompt: 'Test the handleSteps functionality', - params: undefined, - system: 'Test system prompt', - }, - ) + const result = await runAgentStep({ + ws: new MockWebSocket() as unknown as WebSocket, + userId: TEST_USER_ID, + userInputId: 'test-input', + clientSessionId: 'test-session', + fingerprintId: 'test-fingerprint', + onResponseChunk: () => {}, + agentType: 'test-handlesteps-agent', + fileContext: mockFileContext, + localAgentTemplates: mockAgentRegistry, + agentState, + prompt: 'Test the handleSteps functionality', + params: undefined, + system: 'Test system prompt', + logger, + }) // Should end turn because toolCalls.length === 0 && toolResults.length === 0 from LLM processing // (The programmatic step tool results don't count toward this calculation) @@ -566,23 +561,22 @@ describe('runAgentStep - set_output tool', () => { }, ] - const result = await runAgentStep( - new MockWebSocket() as unknown as WebSocket, - { - userId: TEST_USER_ID, - userInputId: 'test-input', - clientSessionId: 'test-session', - fingerprintId: 'test-fingerprint', - onResponseChunk: () => {}, - agentType: 'parent-agent', - fileContext: mockFileContext, - localAgentTemplates: mockAgentRegistry, - agentState, - prompt: 'Spawn an inline agent to clean up messages', - params: undefined, - system: 'Parent system prompt', - }, - ) + const result = await runAgentStep({ + ws: new MockWebSocket() as unknown as WebSocket, + userId: TEST_USER_ID, + userInputId: 'test-input', + clientSessionId: 'test-session', + fingerprintId: 'test-fingerprint', + onResponseChunk: () => {}, + agentType: 'parent-agent', + fileContext: mockFileContext, + localAgentTemplates: mockAgentRegistry, + agentState, + prompt: 'Spawn an inline agent to clean up messages', + params: undefined, + system: 'Parent system prompt', + logger, + }) const finalMessages = result.agentState.messageHistory diff --git a/backend/src/__tests__/run-programmatic-step.test.ts b/backend/src/__tests__/run-programmatic-step.test.ts index db98f01dc7..9f50d7722d 100644 --- a/backend/src/__tests__/run-programmatic-step.test.ts +++ b/backend/src/__tests__/run-programmatic-step.test.ts @@ -3,7 +3,6 @@ import { TEST_USER_ID } from '@codebuff/common/old-constants' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { afterEach, - beforeAll, beforeEach, describe, expect, @@ -32,6 +31,13 @@ import type { AgentState } from '@codebuff/common/types/session-state' import type { Logger } from '@codebuff/types/logger' import type { WebSocket } from 'ws' +const logger: Logger = { + debug: () => {}, + error: () => {}, + info: () => {}, + warn: () => {}, +} + describe('runProgrammaticStep', () => { let mockTemplate: AgentTemplate let mockAgentState: AgentState @@ -40,17 +46,6 @@ describe('runProgrammaticStep', () => { let getRequestContextSpy: any let addAgentStepSpy: any let sendActionSpy: any - let logger: Logger - - beforeAll(() => { - // Mock logger - logger = { - debug: () => {}, - error: () => {}, - info: () => {}, - warn: () => {}, - } - }) beforeEach(() => { // Mock analytics @@ -125,6 +120,7 @@ describe('runProgrammaticStep', () => { // Create mock params mockParams = { + agentState: mockAgentState, template: mockTemplate, prompt: 'Test prompt', params: { testParam: 'value' }, @@ -140,13 +136,14 @@ describe('runProgrammaticStep', () => { localAgentTemplates: {}, stepsComplete: false, stepNumber: 1, + logger, } }) afterEach(() => { mock.restore() // Clear the generator cache between tests - clearAgentGeneratorCache() + clearAgentGeneratorCache({ logger }) }) describe('generator lifecycle', () => { @@ -157,7 +154,7 @@ describe('runProgrammaticStep', () => { mockTemplate.handleSteps = () => mockGenerator - const result = await runProgrammaticStep(mockAgentState, mockParams) + const result = await runProgrammaticStep(mockParams) expect(result.endTurn).toBe(true) expect(result.agentState).toBeDefined() @@ -175,11 +172,11 @@ describe('runProgrammaticStep', () => { mockTemplate.handleSteps = createGenerator // First call - await runProgrammaticStep(mockAgentState, mockParams) + await runProgrammaticStep(mockParams) expect(callCount).toBe(1) // Second call with same agent ID should create a new generator - await runProgrammaticStep(mockAgentState, mockParams) + await runProgrammaticStep(mockParams) expect(callCount).toBe(2) // Should create new generator }) @@ -192,12 +189,15 @@ describe('runProgrammaticStep', () => { mockTemplate.handleSteps = () => mockGenerator // First call to set STEP_ALL state - const result1 = await runProgrammaticStep(mockAgentState, mockParams) + const result1 = await runProgrammaticStep(mockParams) expect(result1.endTurn).toBe(false) // Second call should return early due to STEP_ALL state // Use the same agent state with the same runId - const result2 = await runProgrammaticStep(result1.agentState, mockParams) + const result2 = await runProgrammaticStep({ + ...mockParams, + agentState: result1.agentState, + }) expect(result2.endTurn).toBe(false) expect(result2.agentState.agentId).toEqual(result1.agentState.agentId) }) @@ -205,9 +205,9 @@ describe('runProgrammaticStep', () => { it('should throw error when template has no handleStep', async () => { mockTemplate.handleSteps = undefined - await expect( - runProgrammaticStep(mockAgentState, mockParams), - ).rejects.toThrow('No step handler found for agent template test-agent') + await expect(runProgrammaticStep(mockParams)).rejects.toThrow( + 'No step handler found for agent template test-agent', + ) }) }) @@ -234,7 +234,7 @@ describe('runProgrammaticStep', () => { } }) - const result = await runProgrammaticStep(mockAgentState, mockParams) + const result = await runProgrammaticStep(mockParams) // Verify add_message tool was executed expect(executeToolCallSpy).toHaveBeenCalledWith( @@ -284,7 +284,7 @@ describe('runProgrammaticStep', () => { mockTemplate.handleSteps = () => mockGenerator - const result = await runProgrammaticStep(mockAgentState, mockParams) + const result = await runProgrammaticStep(mockParams) expect(executeToolCallSpy).toHaveBeenCalledTimes(2) expect(executeToolCallSpy).toHaveBeenCalledWith( @@ -336,7 +336,7 @@ describe('runProgrammaticStep', () => { return {} }) - const result = await runProgrammaticStep(mockAgentState, mockParams) + const result = await runProgrammaticStep(mockParams) expect(executeToolCallSpy).toHaveBeenCalledWith( expect.objectContaining({ @@ -372,7 +372,7 @@ describe('runProgrammaticStep', () => { mockTemplate.handleSteps = () => mockGenerator - const result = await runProgrammaticStep(mockAgentState, mockParams) + const result = await runProgrammaticStep(mockParams) expect(executeToolCallSpy).toHaveBeenCalledTimes(3) expect(result.endTurn).toBe(true) @@ -553,7 +553,7 @@ describe('runProgrammaticStep', () => { }) // First call - should execute all tools and transition to STEP_ALL - const result1 = await runProgrammaticStep(mockAgentState, mockParams) + const result1 = await runProgrammaticStep(mockParams) // Verify all tools were executed expect(executeToolCallSpy).toHaveBeenCalledTimes(7) // 7 tools before STEP_ALL @@ -619,9 +619,10 @@ describe('runProgrammaticStep', () => { executeToolCallSpy.mockClear() // Second call - should return early due to STEP_ALL state - const result2 = await runProgrammaticStep(result1.agentState, { + const result2 = await runProgrammaticStep({ ...mockParams, // Use the updated agent state from first call + agentState: result1.agentState, }) // Verify STEP_ALL behavior @@ -631,8 +632,9 @@ describe('runProgrammaticStep', () => { expect(stepCount).toBe(1) // Generator should not have run again // Third call - verify STEP_ALL state persists - const result3 = await runProgrammaticStep(result2.agentState, { + const result3 = await runProgrammaticStep({ ...mockParams, + agentState: result2.agentState, }) expect(executeToolCallSpy).not.toHaveBeenCalled() @@ -673,7 +675,7 @@ describe('runProgrammaticStep', () => { } }) - await runProgrammaticStep(mockAgentState, mockParams) + await runProgrammaticStep(mockParams) expect(receivedToolResult).toEqual([ { @@ -697,7 +699,7 @@ describe('runProgrammaticStep', () => { mockTemplate.handleSteps = () => mockGenerator - const result = await runProgrammaticStep(mockAgentState, mockParams) + const result = await runProgrammaticStep(mockParams) expect(executeToolCallSpy).toHaveBeenCalledTimes(1) // Only first tool call expect(result.endTurn).toBe(false) @@ -711,7 +713,7 @@ describe('runProgrammaticStep', () => { mockTemplate.handleSteps = () => mockGenerator - const result = await runProgrammaticStep(mockAgentState, mockParams) + const result = await runProgrammaticStep(mockParams) expect(result.endTurn).toBe(true) }) @@ -728,7 +730,7 @@ describe('runProgrammaticStep', () => { mockTemplate.handleSteps = () => mockGenerator - const result = await runProgrammaticStep(mockAgentState, mockParams) + const result = await runProgrammaticStep(mockParams) expect(executeToolCallSpy).toHaveBeenCalledTimes(2) // read_files + end_turn expect(result.endTurn).toBe(true) @@ -755,7 +757,7 @@ describe('runProgrammaticStep', () => { } }) - const result = await runProgrammaticStep(mockAgentState, mockParams) + const result = await runProgrammaticStep(mockParams) expect(result.agentState.output).toEqual({ status: 'complete' }) }) @@ -767,7 +769,7 @@ describe('runProgrammaticStep', () => { mockTemplate.handleSteps = () => mockGenerator - const result = await runProgrammaticStep(mockAgentState, mockParams) + const result = await runProgrammaticStep(mockParams) expect(result.agentState.messageHistory).toEqual([ ...mockAgentState.messageHistory, @@ -791,7 +793,7 @@ describe('runProgrammaticStep', () => { const responseChunks: string[] = [] mockParams.onResponseChunk = (chunk: string) => responseChunks.push(chunk) - const result = await runProgrammaticStep(mockAgentState, mockParams) + const result = await runProgrammaticStep(mockParams) expect(result.endTurn).toBe(true) expect(result.agentState.output?.error).toContain('Generator error') @@ -812,7 +814,7 @@ describe('runProgrammaticStep', () => { const responseChunks: string[] = [] mockParams.onResponseChunk = (chunk: string) => responseChunks.push(chunk) - const result = await runProgrammaticStep(mockAgentState, mockParams) + const result = await runProgrammaticStep(mockParams) expect(result.endTurn).toBe(true) expect(result.agentState.output?.error).toContain('Tool execution failed') @@ -825,7 +827,7 @@ describe('runProgrammaticStep', () => { mockTemplate.handleSteps = () => mockGenerator - const result = await runProgrammaticStep(mockAgentState, mockParams) + const result = await runProgrammaticStep(mockParams) expect(result.endTurn).toBe(true) expect(result.agentState.output?.error).toContain('Unknown error') @@ -867,7 +869,7 @@ describe('runProgrammaticStep', () => { // Don't mock executeToolCall - let it use the real implementation executeToolCallSpy.mockRestore() - const result = await runProgrammaticStep(mockAgentState, { + const result = await runProgrammaticStep({ ...mockParams, template: schemaTemplate, localAgentTemplates: { 'test-agent': schemaTemplate }, @@ -917,7 +919,7 @@ describe('runProgrammaticStep', () => { const responseChunks: string[] = [] mockParams.onResponseChunk = (chunk: string) => responseChunks.push(chunk) - const result = await runProgrammaticStep(mockAgentState, { + const result = await runProgrammaticStep({ ...mockParams, template: schemaTemplate, localAgentTemplates: { 'test-agent': schemaTemplate }, @@ -953,7 +955,7 @@ describe('runProgrammaticStep', () => { // Don't mock executeToolCall - let it use the real implementation executeToolCallSpy.mockRestore() - const result = await runProgrammaticStep(mockAgentState, { + const result = await runProgrammaticStep({ ...mockParams, template: noSchemaTemplate, localAgentTemplates: { 'test-agent': noSchemaTemplate }, @@ -990,7 +992,7 @@ describe('runProgrammaticStep', () => { // Don't mock executeToolCall - let it use the real implementation executeToolCallSpy.mockRestore() - const result = await runProgrammaticStep(mockAgentState, { + const result = await runProgrammaticStep({ ...mockParams, template: schemaWithoutSchemaTemplate, localAgentTemplates: { 'test-agent': schemaWithoutSchemaTemplate }, @@ -1019,7 +1021,7 @@ describe('runProgrammaticStep', () => { mockTemplate.handleSteps = () => mockGenerator - await runProgrammaticStep(mockAgentState, { + await runProgrammaticStep({ ...mockParams, stepsComplete: false, }) @@ -1041,7 +1043,7 @@ describe('runProgrammaticStep', () => { mockTemplate.handleSteps = () => mockGenerator - await runProgrammaticStep(mockAgentState, { + await runProgrammaticStep({ ...mockParams, stepsComplete: true, }) @@ -1081,7 +1083,7 @@ describe('runProgrammaticStep', () => { } }) - const result = await runProgrammaticStep(mockAgentState, { + const result = await runProgrammaticStep({ ...mockParams, stepsComplete: true, }) @@ -1120,7 +1122,7 @@ describe('runProgrammaticStep', () => { mockTemplate.toolNames = ['set_output', 'end_turn'] // First call to set STEP_ALL state - const result1 = await runProgrammaticStep(mockAgentState, { + const result1 = await runProgrammaticStep({ ...mockParams, stepsComplete: false, }) @@ -1128,8 +1130,9 @@ describe('runProgrammaticStep', () => { expect(generatorCallCount).toBe(1) // Second call with stepsComplete=false should return early due to STEP_ALL - const result2 = await runProgrammaticStep(result1.agentState, { + const result2 = await runProgrammaticStep({ ...mockParams, + agentState: result1.agentState, stepsComplete: false, }) expect(result2.endTurn).toBe(false) @@ -1142,8 +1145,9 @@ describe('runProgrammaticStep', () => { } }) - const result3 = await runProgrammaticStep(result2.agentState, { + const result3 = await runProgrammaticStep({ ...mockParams, + agentState: result2.agentState, stepsComplete: true, }) @@ -1206,7 +1210,7 @@ describe('runProgrammaticStep', () => { }) // First call with stepsComplete=true (post-processing mode) - const result = await runProgrammaticStep(mockAgentState, { + const result = await runProgrammaticStep({ ...mockParams, stepsComplete: true, }) @@ -1219,7 +1223,7 @@ describe('runProgrammaticStep', () => { expect(result.endTurn).toBe(false) // Should not end due to STEP expect(executeToolCallSpy).toHaveBeenCalledTimes(2) // read_files + write_file - const finalResult = await runProgrammaticStep(mockAgentState, { + const finalResult = await runProgrammaticStep({ ...mockParams, stepsComplete: false, }) @@ -1265,7 +1269,7 @@ describe('runProgrammaticStep', () => { mockTemplate.toolNames = ['read_files', 'write_file', 'end_turn'] // First call with stepsComplete=true - const result = await runProgrammaticStep(mockAgentState, { + const result = await runProgrammaticStep({ ...mockParams, stepsComplete: true, }) @@ -1386,7 +1390,7 @@ describe('runProgrammaticStep', () => { }) // Call with stepsComplete=true to trigger post-processing - const result = await runProgrammaticStep(mockAgentState, { + const result = await runProgrammaticStep({ ...mockParams, stepsComplete: true, }) @@ -1416,7 +1420,7 @@ describe('runProgrammaticStep', () => { mockTemplate.handleSteps = () => mockGenerator - await runProgrammaticStep(mockAgentState, mockParams) + await runProgrammaticStep(mockParams) // Logger is mocked, but we can verify the function completes without error expect(true).toBe(true) @@ -1429,7 +1433,7 @@ describe('runProgrammaticStep', () => { mockTemplate.handleSteps = () => mockGenerator - await runProgrammaticStep(mockAgentState, mockParams) + await runProgrammaticStep(mockParams) expect(getRequestContextSpy).toHaveBeenCalled() }) @@ -1442,7 +1446,7 @@ describe('runProgrammaticStep', () => { mockTemplate.handleSteps = () => mockGenerator - await runProgrammaticStep(mockAgentState, mockParams) + await runProgrammaticStep(mockParams) expect(executeToolCallSpy).toHaveBeenCalledWith( expect.objectContaining({ diff --git a/backend/src/__tests__/sandbox-generator.test.ts b/backend/src/__tests__/sandbox-generator.test.ts index cbe179d004..feae067aa5 100644 --- a/backend/src/__tests__/sandbox-generator.test.ts +++ b/backend/src/__tests__/sandbox-generator.test.ts @@ -1,7 +1,3 @@ -import { - clearMockedModules, - mockModule, -} from '@codebuff/common/testing/mock-modules' import { getInitialAgentState, type AgentState, @@ -18,15 +14,23 @@ import * as requestContext from '../websockets/request-context' import * as websocketAction from '../websockets/websocket-action' import type { AgentTemplate } from '../templates/types' +import type { Logger } from '@codebuff/types/logger' import type { WebSocket } from 'ws' +const logger: Logger = { + debug: () => {}, + error: () => {}, + info: () => {}, + warn: () => {}, +} + describe('QuickJS Sandbox Generator', () => { let mockAgentState: AgentState let mockParams: any let mockTemplate: AgentTemplate beforeEach(() => { - clearAgentGeneratorCache() + clearAgentGeneratorCache({ logger }) // Mock dependencies spyOn(agentRun, 'addAgentStep').mockImplementation( @@ -41,17 +45,6 @@ describe('QuickJS Sandbox Generator', () => { 'mock-uuid-0000-0000-0000-000000000000' as `${string}-${string}-${string}-${string}-${string}`, ) - // Mock logger - mockModule('@codebuff/backend/util/logger', () => ({ - logger: { - debug: () => {}, - error: () => {}, - info: () => {}, - warn: () => {}, - }, - withLoggerContext: async (context: any, fn: () => Promise) => fn(), - })) - // Reuse common test data structure mockAgentState = { ...getInitialAgentState(), @@ -85,6 +78,7 @@ describe('QuickJS Sandbox Generator', () => { // Common params structure mockParams = { + agentState: mockAgentState, template: mockTemplate, prompt: 'Test prompt', params: { testParam: 'value' }, @@ -101,12 +95,12 @@ describe('QuickJS Sandbox Generator', () => { localAgentTemplates: {}, stepsComplete: false, stepNumber: 1, + logger, } }) afterEach(() => { - clearAgentGeneratorCache() - clearMockedModules() + clearAgentGeneratorCache({ logger }) }) test('should execute string-based generator in QuickJS sandbox', async () => { @@ -126,7 +120,7 @@ describe('QuickJS Sandbox Generator', () => { mockParams.template = mockTemplate mockParams.localAgentTemplates = { 'test-vm-agent': mockTemplate } - const result = await runProgrammaticStep(mockAgentState, mockParams) + const result = await runProgrammaticStep(mockParams) expect(result.agentState.output).toEqual({ message: 'Hello from QuickJS sandbox!', @@ -155,7 +149,7 @@ describe('QuickJS Sandbox Generator', () => { mockParams.params = {} mockParams.localAgentTemplates = { 'test-vm-agent-error': mockTemplate } - const result = await runProgrammaticStep(mockAgentState, mockParams) + const result = await runProgrammaticStep(mockParams) expect(result.endTurn).toBe(true) expect(result.agentState.output?.error).toContain( diff --git a/backend/src/__tests__/spawn-agents-message-history.test.ts b/backend/src/__tests__/spawn-agents-message-history.test.ts index 102116e053..d27ef4da4a 100644 --- a/backend/src/__tests__/spawn-agents-message-history.test.ts +++ b/backend/src/__tests__/spawn-agents-message-history.test.ts @@ -13,28 +13,26 @@ import { import { mockFileContext, MockWebSocket } from './test-utils' import * as runAgentStep from '../run-agent-step' import { handleSpawnAgents } from '../tools/handlers/tool/spawn-agents' -import * as loggerModule from '../util/logger' import type { CodebuffToolCall } from '@codebuff/common/tools/list' import type { AgentTemplate } from '@codebuff/common/types/agent-template' import type { Message } from '@codebuff/common/types/messages/codebuff-message' +import type { Logger } from '@codebuff/types/logger' import type { WebSocket } from 'ws' +const logger: Logger = { + debug: () => {}, + error: () => {}, + info: () => {}, + warn: () => {}, +} + describe('Spawn Agents Message History', () => { let mockSendSubagentChunk: any let mockLoopAgentSteps: any let capturedSubAgentState: any beforeEach(() => { - // Mock logger to reduce noise in tests - spyOn(loggerModule.logger, 'debug').mockImplementation(() => {}) - spyOn(loggerModule.logger, 'error').mockImplementation(() => {}) - spyOn(loggerModule.logger, 'info').mockImplementation(() => {}) - spyOn(loggerModule.logger, 'warn').mockImplementation(() => {}) - spyOn(loggerModule, 'withLoggerContext').mockImplementation( - async (context: any, fn: () => Promise) => fn(), - ) - // Mock sendSubagentChunk mockSendSubagentChunk = mock(() => {}) @@ -42,7 +40,7 @@ describe('Spawn Agents Message History', () => { mockLoopAgentSteps = spyOn( runAgentStep, 'loopAgentSteps', - ).mockImplementation(async (ws, options) => { + ).mockImplementation(async (options) => { capturedSubAgentState = options.agentState return { agentState: { @@ -134,6 +132,7 @@ describe('Spawn Agents Message History', () => { agentState: sessionState.mainAgentState, system: 'Test system prompt', }, + logger, }) await result @@ -205,6 +204,7 @@ describe('Spawn Agents Message History', () => { agentState: sessionState.mainAgentState, system: 'Test system prompt', }, + logger, }) await result @@ -241,6 +241,7 @@ describe('Spawn Agents Message History', () => { agentState: sessionState.mainAgentState, system: 'Test system prompt', }, + logger, }) await result @@ -280,6 +281,7 @@ describe('Spawn Agents Message History', () => { agentState: sessionState.mainAgentState, system: 'Test system prompt', }, + logger, }) await result diff --git a/backend/src/__tests__/spawn-agents-permissions.test.ts b/backend/src/__tests__/spawn-agents-permissions.test.ts index 807595f7d7..8d70ea1863 100644 --- a/backend/src/__tests__/spawn-agents-permissions.test.ts +++ b/backend/src/__tests__/spawn-agents-permissions.test.ts @@ -15,12 +15,19 @@ import * as runAgentStep from '../run-agent-step' import { handleSpawnAgentInline } from '../tools/handlers/tool/spawn-agent-inline' import { getMatchingSpawn } from '../tools/handlers/tool/spawn-agent-utils' import { handleSpawnAgents } from '../tools/handlers/tool/spawn-agents' -import * as loggerModule from '../util/logger' import type { CodebuffToolCall } from '@codebuff/common/tools/list' import type { AgentTemplate } from '@codebuff/common/types/agent-template' +import type { Logger } from '@codebuff/types/logger' import type { WebSocket } from 'ws' +const logger: Logger = { + debug: () => {}, + error: () => {}, + info: () => {}, + warn: () => {}, +} + describe('Spawn Agents Permissions', () => { let mockSendSubagentChunk: any let mockLoopAgentSteps: any @@ -50,15 +57,6 @@ describe('Spawn Agents Permissions', () => { }) beforeEach(() => { - // Mock logger to reduce noise in tests - spyOn(loggerModule.logger, 'debug').mockImplementation(() => {}) - spyOn(loggerModule.logger, 'error').mockImplementation(() => {}) - spyOn(loggerModule.logger, 'info').mockImplementation(() => {}) - spyOn(loggerModule.logger, 'warn').mockImplementation(() => {}) - spyOn(loggerModule, 'withLoggerContext').mockImplementation( - async (context: any, fn: () => Promise) => fn(), - ) - // Mock sendSubagentChunk mockSendSubagentChunk = mock(() => {}) @@ -66,7 +64,7 @@ describe('Spawn Agents Permissions', () => { mockLoopAgentSteps = spyOn( runAgentStep, 'loopAgentSteps', - ).mockImplementation(async (ws, options) => { + ).mockImplementation(async (options) => { return { agentState: { ...options.agentState, @@ -262,6 +260,7 @@ describe('Spawn Agents Permissions', () => { agentState: sessionState.mainAgentState, system: 'Test system prompt', }, + logger, }) const output = await result @@ -295,6 +294,7 @@ describe('Spawn Agents Permissions', () => { agentState: sessionState.mainAgentState, system: 'Test system prompt', }, + logger, }) const output = await result @@ -330,6 +330,7 @@ describe('Spawn Agents Permissions', () => { agentState: sessionState.mainAgentState, system: 'Test system prompt', }, + logger, }) const output = await result @@ -367,6 +368,7 @@ describe('Spawn Agents Permissions', () => { agentState: sessionState.mainAgentState, system: 'Test system prompt', }, + logger, }) const output = await result @@ -403,6 +405,7 @@ describe('Spawn Agents Permissions', () => { agentState: sessionState.mainAgentState, system: 'Test system prompt', }, + logger, }) const output = await result @@ -436,6 +439,7 @@ describe('Spawn Agents Permissions', () => { agentState: sessionState.mainAgentState, system: 'Test system prompt', }, + logger, }) const output = await result @@ -486,6 +490,7 @@ describe('Spawn Agents Permissions', () => { agentState: sessionState.mainAgentState, system: 'Test system prompt', }, + logger, }) const output = await result @@ -536,6 +541,7 @@ describe('Spawn Agents Permissions', () => { agentState: sessionState.mainAgentState, system: 'Test system prompt', }, + logger, }) await result // Should not throw @@ -567,6 +573,7 @@ describe('Spawn Agents Permissions', () => { agentState: sessionState.mainAgentState, system: 'Test system prompt', }, + logger, }) await expect(result).rejects.toThrow( @@ -599,6 +606,7 @@ describe('Spawn Agents Permissions', () => { agentState: sessionState.mainAgentState, system: 'Test system prompt', }, + logger, }) await expect(result).rejects.toThrow('Agent type nonexistent not found') @@ -630,6 +638,7 @@ describe('Spawn Agents Permissions', () => { agentState: sessionState.mainAgentState, system: 'Test system prompt', }, + logger, }) await result // Should not throw @@ -664,6 +673,7 @@ describe('Spawn Agents Permissions', () => { agentState: sessionState.mainAgentState, system: 'Test system prompt', }, + logger, }) await result // Should not throw @@ -695,6 +705,7 @@ describe('Spawn Agents Permissions', () => { agentState: sessionState.mainAgentState, system: 'Test system prompt', }, + logger, }) await expect(result).rejects.toThrow( @@ -721,6 +732,7 @@ describe('Spawn Agents Permissions', () => { agentTemplate: parentAgent, localAgentTemplates: {}, }, + logger, }) }).toThrow('Missing WebSocket in state') expect(mockLoopAgentSteps).not.toHaveBeenCalled() diff --git a/backend/src/__tests__/subagent-streaming.test.ts b/backend/src/__tests__/subagent-streaming.test.ts index cfd24cda1d..4ca8982185 100644 --- a/backend/src/__tests__/subagent-streaming.test.ts +++ b/backend/src/__tests__/subagent-streaming.test.ts @@ -20,9 +20,17 @@ import * as loggerModule from '../util/logger' import type { AgentTemplate } from '../templates/types' import type { SendSubagentChunk } from '../tools/handlers/tool/spawn-agents' import type { CodebuffToolCall } from '@codebuff/common/tools/list' +import type { Logger } from '@codebuff/types/logger' import type { Mock } from 'bun:test' import type { WebSocket } from 'ws' +const logger: Logger = { + debug: () => {}, + error: () => {}, + info: () => {}, + warn: () => {}, +} + describe('Subagent Streaming', () => { let mockSendSubagentChunk: Mock let mockLoopAgentSteps: Mock<(typeof runAgentStep)['loopAgentSteps']> @@ -79,7 +87,7 @@ describe('Subagent Streaming', () => { mockLoopAgentSteps = spyOn( runAgentStep, 'loopAgentSteps', - ).mockImplementation(async (ws, options) => { + ).mockImplementation(async (options) => { // Simulate streaming chunks by calling the callback if (options.onResponseChunk) { options.onResponseChunk('Thinking about the problem...') @@ -165,6 +173,7 @@ describe('Subagent Streaming', () => { agentState, system: 'Test system prompt', }, + logger, }) await result @@ -242,6 +251,7 @@ describe('Subagent Streaming', () => { agentState, system: 'Test system prompt', }, + logger, }) await result diff --git a/backend/src/__tests__/tool-call-schema.test.ts b/backend/src/__tests__/tool-call-schema.test.ts index b810560fa5..ed22191e4b 100644 --- a/backend/src/__tests__/tool-call-schema.test.ts +++ b/backend/src/__tests__/tool-call-schema.test.ts @@ -1,7 +1,5 @@ import { describe, it, expect, beforeEach } from 'bun:test' -import { logger } from '../util/logger' - import type { WebSocket } from 'ws' describe('Backend Tool Call Schema', () => { @@ -221,7 +219,6 @@ export async function generateMockProjectStructureAnalysis( return analysis } catch (error) { - logger.error({ error }, 'Project analysis failed') return `Project analysis failed: ${error instanceof Error ? error.message : error}` } } @@ -276,7 +273,6 @@ src/utils.ts:2:import { readFileSync } from 'fs'`, return analysis } catch (error) { - logger.error({ error }, 'Dependency analysis failed') return `Dependency analysis failed: ${error instanceof Error ? error.message : error}` } } @@ -318,7 +314,6 @@ export async function generateMockFileContentAnalysis( return analysis } catch (error) { - logger.error({ error }, 'File content analysis failed') return `File content analysis failed: ${error instanceof Error ? error.message : error}` } } diff --git a/backend/src/__tests__/web-search-tool.test.ts b/backend/src/__tests__/web-search-tool.test.ts index c458d5179a..1b36608809 100644 --- a/backend/src/__tests__/web-search-tool.test.ts +++ b/backend/src/__tests__/web-search-tool.test.ts @@ -120,11 +120,13 @@ describe('web_search tool with researcher agent', () => { ...sessionState.mainAgentState, agentType: 'researcher' as const, } - const { agentTemplates } = assembleLocalAgentTemplates( - mockFileContextWithAgents, - ) + const { agentTemplates } = assembleLocalAgentTemplates({ + fileContext: mockFileContextWithAgents, + logger, + }) - await runAgentStep(new MockWebSocket() as unknown as WebSocket, { + await runAgentStep({ + ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, userInputId: 'test-input', @@ -137,11 +139,14 @@ describe('web_search tool with researcher agent', () => { agentState, prompt: 'Search for test', params: undefined, + logger, }) // Just verify that searchWeb was called - expect(linkupApi.searchWeb).toHaveBeenCalledWith('test query', { + expect(linkupApi.searchWeb).toHaveBeenCalledWith({ + query: 'test query', depth: 'standard', + logger: expect.anything(), }) }) @@ -165,34 +170,33 @@ describe('web_search tool with researcher agent', () => { ...sessionState.mainAgentState, agentType: 'researcher' as const, } - const { agentTemplates } = assembleLocalAgentTemplates( - mockFileContextWithAgents, - ) + const { agentTemplates } = assembleLocalAgentTemplates({ + fileContext: mockFileContextWithAgents, + logger, + }) - const { agentState: newAgentState } = await runAgentStep( - new MockWebSocket() as unknown as WebSocket, - { - system: 'Test system prompt', - userId: TEST_USER_ID, - userInputId: 'test-input', - clientSessionId: 'test-session', - fingerprintId: 'test-fingerprint', - onResponseChunk: () => {}, - agentType: 'researcher', - fileContext: mockFileContext, - localAgentTemplates: agentTemplates, - agentState, - prompt: 'Search for Next.js 15 new features', - params: undefined, - }, - ) + const { agentState: newAgentState } = await runAgentStep({ + ws: new MockWebSocket() as unknown as WebSocket, + system: 'Test system prompt', + userId: TEST_USER_ID, + userInputId: 'test-input', + clientSessionId: 'test-session', + fingerprintId: 'test-fingerprint', + onResponseChunk: () => {}, + agentType: 'researcher', + fileContext: mockFileContext, + localAgentTemplates: agentTemplates, + agentState, + prompt: 'Search for Next.js 15 new features', + params: undefined, + logger, + }) - expect(linkupApi.searchWeb).toHaveBeenCalledWith( - 'Next.js 15 new features', - { - depth: 'standard', - }, - ) + expect(linkupApi.searchWeb).toHaveBeenCalledWith({ + query: 'Next.js 15 new features', + depth: 'standard', + logger: expect.anything(), + }) // Check that the search results were added to the message history const toolResultMessages = newAgentState.messageHistory.filter( @@ -225,11 +229,13 @@ describe('web_search tool with researcher agent', () => { ...sessionState.mainAgentState, agentType: 'researcher' as const, } - const { agentTemplates } = assembleLocalAgentTemplates( - mockFileContextWithAgents, - ) + const { agentTemplates } = assembleLocalAgentTemplates({ + fileContext: mockFileContextWithAgents, + logger, + }) - await runAgentStep(new MockWebSocket() as unknown as WebSocket, { + await runAgentStep({ + ws: new MockWebSocket() as unknown as WebSocket, system: 'Test system prompt', userId: TEST_USER_ID, userInputId: 'test-input', @@ -242,14 +248,14 @@ describe('web_search tool with researcher agent', () => { agentState, prompt: 'Search for React Server Components tutorial with deep search', params: undefined, + logger, }) - expect(linkupApi.searchWeb).toHaveBeenCalledWith( - 'React Server Components tutorial', - { - depth: 'deep', - }, - ) + expect(linkupApi.searchWeb).toHaveBeenCalledWith({ + query: 'React Server Components tutorial', + depth: 'deep', + logger: expect.anything(), + }) }) test('should handle case when no search results are found', async () => { @@ -267,35 +273,34 @@ describe('web_search tool with researcher agent', () => { ...sessionState.mainAgentState, agentType: 'researcher' as const, } - const { agentTemplates } = assembleLocalAgentTemplates( - mockFileContextWithAgents, - ) + const { agentTemplates } = assembleLocalAgentTemplates({ + fileContext: mockFileContextWithAgents, + logger, + }) - const { agentState: newAgentState } = await runAgentStep( - new MockWebSocket() as unknown as WebSocket, - { - system: 'Test system prompt', - userId: TEST_USER_ID, - userInputId: 'test-input', - clientSessionId: 'test-session', - fingerprintId: 'test-fingerprint', - onResponseChunk: () => {}, - agentType: 'researcher', - fileContext: mockFileContext, - localAgentTemplates: agentTemplates, - agentState, - prompt: "Search for something that doesn't exist", - params: undefined, - }, - ) + const { agentState: newAgentState } = await runAgentStep({ + ws: new MockWebSocket() as unknown as WebSocket, + system: 'Test system prompt', + userId: TEST_USER_ID, + userInputId: 'test-input', + clientSessionId: 'test-session', + fingerprintId: 'test-fingerprint', + onResponseChunk: () => {}, + agentType: 'researcher', + fileContext: mockFileContext, + localAgentTemplates: agentTemplates, + agentState, + prompt: "Search for something that doesn't exist", + params: undefined, + logger, + }) // Verify that searchWeb was called - expect(linkupApi.searchWeb).toHaveBeenCalledWith( - 'very obscure search query that returns nothing', - { - depth: 'standard', - }, - ) + expect(linkupApi.searchWeb).toHaveBeenCalledWith({ + query: 'very obscure search query that returns nothing', + depth: 'standard', + logger: expect.anything(), + }) // Check that the "no results found" message was added const toolResultMessages = newAgentState.messageHistory.filter( @@ -326,31 +331,33 @@ describe('web_search tool with researcher agent', () => { ...sessionState.mainAgentState, agentType: 'researcher' as const, } - const { agentTemplates } = assembleLocalAgentTemplates( - mockFileContextWithAgents, - ) + const { agentTemplates } = assembleLocalAgentTemplates({ + fileContext: mockFileContextWithAgents, + logger, + }) - const { agentState: newAgentState } = await runAgentStep( - new MockWebSocket() as unknown as WebSocket, - { - system: 'Test system prompt', - userId: TEST_USER_ID, - userInputId: 'test-input', - clientSessionId: 'test-session', - fingerprintId: 'test-fingerprint', - onResponseChunk: () => {}, - agentType: 'researcher', - fileContext: mockFileContext, - localAgentTemplates: agentTemplates, - agentState, - prompt: 'Search for something', - params: undefined, - }, - ) + const { agentState: newAgentState } = await runAgentStep({ + ws: new MockWebSocket() as unknown as WebSocket, + system: 'Test system prompt', + userId: TEST_USER_ID, + userInputId: 'test-input', + clientSessionId: 'test-session', + fingerprintId: 'test-fingerprint', + onResponseChunk: () => {}, + agentType: 'researcher', + fileContext: mockFileContext, + localAgentTemplates: agentTemplates, + agentState, + prompt: 'Search for something', + params: undefined, + logger, + }) // Verify that searchWeb was called - expect(linkupApi.searchWeb).toHaveBeenCalledWith('test query', { + expect(linkupApi.searchWeb).toHaveBeenCalledWith({ + query: 'test query', depth: 'standard', + logger: expect.anything(), }) // Check that the error message was added @@ -381,31 +388,33 @@ describe('web_search tool with researcher agent', () => { ...sessionState.mainAgentState, agentType: 'researcher' as const, } - const { agentTemplates } = assembleLocalAgentTemplates( - mockFileContextWithAgents, - ) + const { agentTemplates } = assembleLocalAgentTemplates({ + fileContext: mockFileContextWithAgents, + logger, + }) - const { agentState: newAgentState } = await runAgentStep( - new MockWebSocket() as unknown as WebSocket, - { - system: 'Test system prompt', - userId: TEST_USER_ID, - userInputId: 'test-input', - clientSessionId: 'test-session', - fingerprintId: 'test-fingerprint', - onResponseChunk: () => {}, - agentType: 'researcher', - fileContext: mockFileContext, - localAgentTemplates: agentTemplates, - agentState, - prompt: 'Search for something', - params: undefined, - }, - ) + const { agentState: newAgentState } = await runAgentStep({ + ws: new MockWebSocket() as unknown as WebSocket, + system: 'Test system prompt', + userId: TEST_USER_ID, + userInputId: 'test-input', + clientSessionId: 'test-session', + fingerprintId: 'test-fingerprint', + onResponseChunk: () => {}, + agentType: 'researcher', + fileContext: mockFileContext, + localAgentTemplates: agentTemplates, + agentState, + prompt: 'Search for something', + params: undefined, + logger, + }) // Verify that searchWeb was called - expect(linkupApi.searchWeb).toHaveBeenCalledWith('test query', { + expect(linkupApi.searchWeb).toHaveBeenCalledWith({ + query: 'test query', depth: 'standard', + logger: expect.anything(), }) }) @@ -426,31 +435,33 @@ describe('web_search tool with researcher agent', () => { ...sessionState.mainAgentState, agentType: 'researcher' as const, } - const { agentTemplates } = assembleLocalAgentTemplates( - mockFileContextWithAgents, - ) + const { agentTemplates } = assembleLocalAgentTemplates({ + fileContext: mockFileContextWithAgents, + logger, + }) - const { agentState: newAgentState } = await runAgentStep( - new MockWebSocket() as unknown as WebSocket, - { - system: 'Test system prompt', - userId: TEST_USER_ID, - userInputId: 'test-input', - clientSessionId: 'test-session', - fingerprintId: 'test-fingerprint', - onResponseChunk: () => {}, - agentType: 'researcher', - fileContext: mockFileContext, - localAgentTemplates: agentTemplates, - agentState, - prompt: 'Search for something', - params: undefined, - }, - ) + const { agentState: newAgentState } = await runAgentStep({ + ws: new MockWebSocket() as unknown as WebSocket, + system: 'Test system prompt', + userId: TEST_USER_ID, + userInputId: 'test-input', + clientSessionId: 'test-session', + fingerprintId: 'test-fingerprint', + onResponseChunk: () => {}, + agentType: 'researcher', + fileContext: mockFileContext, + localAgentTemplates: agentTemplates, + agentState, + prompt: 'Search for something', + params: undefined, + logger, + }) // Verify that searchWeb was called - expect(linkupApi.searchWeb).toHaveBeenCalledWith('test query', { + expect(linkupApi.searchWeb).toHaveBeenCalledWith({ + query: 'test query', depth: 'standard', + logger: expect.anything(), }) // Check that the error message was added @@ -483,31 +494,33 @@ describe('web_search tool with researcher agent', () => { ...sessionState.mainAgentState, agentType: 'researcher' as const, } - const { agentTemplates } = assembleLocalAgentTemplates( - mockFileContextWithAgents, - ) + const { agentTemplates } = assembleLocalAgentTemplates({ + fileContext: mockFileContextWithAgents, + logger, + }) - const { agentState: newAgentState } = await runAgentStep( - new MockWebSocket() as unknown as WebSocket, - { - system: 'Test system prompt', - userId: TEST_USER_ID, - userInputId: 'test-input', - clientSessionId: 'test-session', - fingerprintId: 'test-fingerprint', - onResponseChunk: () => {}, - agentType: 'researcher', - fileContext: mockFileContextWithAgents, - localAgentTemplates: agentTemplates, - agentState, - prompt: 'Test search result formatting', - params: undefined, - }, - ) + const { agentState: newAgentState } = await runAgentStep({ + ws: new MockWebSocket() as unknown as WebSocket, + system: 'Test system prompt', + userId: TEST_USER_ID, + userInputId: 'test-input', + clientSessionId: 'test-session', + fingerprintId: 'test-fingerprint', + onResponseChunk: () => {}, + agentType: 'researcher', + fileContext: mockFileContextWithAgents, + localAgentTemplates: agentTemplates, + agentState, + prompt: 'Test search result formatting', + params: undefined, + logger, + }) // Verify that searchWeb was called - expect(linkupApi.searchWeb).toHaveBeenCalledWith('test formatting', { + expect(linkupApi.searchWeb).toHaveBeenCalledWith({ + query: 'test formatting', depth: 'standard', + logger: expect.anything(), }) // Check that the search results were formatted correctly diff --git a/backend/src/admin/grade-runs.ts b/backend/src/admin/grade-runs.ts index 1d9f9b884f..d2517eaefa 100644 --- a/backend/src/admin/grade-runs.ts +++ b/backend/src/admin/grade-runs.ts @@ -4,6 +4,7 @@ import { closeXml } from '@codebuff/common/util/xml' import { promptAiSdk } from '../llm-apis/vercel-ai-sdk/ai-sdk' import type { Relabel, GetRelevantFilesTrace } from '@codebuff/bigquery' +import type { Logger } from '@codebuff/types/logger' const PROMPT = ` You are an evaluator system, measuring how well various models perform at selecting the most relevant files for a given user request. @@ -96,11 +97,12 @@ function extractResponse(response: string): { } } -export async function gradeRun(tracesAndRelabels: { +export async function gradeRun(params: { trace: GetRelevantFilesTrace relabels: Relabel[] + logger: Logger }) { - const { trace, relabels } = tracesAndRelabels + const { trace, relabels, logger } = params const messages = trace.payload.messages const originalOutput = trace.payload.output @@ -152,6 +154,7 @@ export async function gradeRun(tracesAndRelabels: { // type: 'enabled', // budget_tokens: 10000, // }, + logger, }) const { scores } = extractResponse(response) diff --git a/backend/src/admin/relabelRuns.ts b/backend/src/admin/relabelRuns.ts index d643a1298a..5852deae5e 100644 --- a/backend/src/admin/relabelRuns.ts +++ b/backend/src/admin/relabelRuns.ts @@ -14,7 +14,6 @@ import { closeXml } from '@codebuff/common/util/xml' import { rerank } from '../llm-apis/relace-api' import { promptAiSdk } from '../llm-apis/vercel-ai-sdk/ai-sdk' -import { logger } from '../util/logger' import { messagesWithSystem } from '../util/messages' import type { System } from '../llm-apis/claude' @@ -26,11 +25,17 @@ import type { Relabel, } from '@codebuff/bigquery' import type { Message } from '@codebuff/common/types/messages/codebuff-message' +import type { Logger } from '@codebuff/types/logger' import type { Request, Response } from 'express' // --- GET Handler Logic --- -export async function getTracesForUserHandler(req: Request, res: Response) { +export async function getTracesForUserHandler(params: { + req: Request + res: Response + logger: Logger +}) { + const { req, res, logger } = params try { // Extract userId from the query parameters const userId = req.query.userId as string @@ -124,7 +129,12 @@ const modelsToRelabel = [ finetunedVertexModels.ft_filepicker_topk_002, ] as const -export async function relabelForUserHandler(req: Request, res: Response) { +export async function relabelForUserHandler(params: { + req: Request + res: Response + logger: Logger +}) { + const { req, res, logger } = params try { // Extract userId from the URL query params const userId = req.query.userId as string @@ -140,7 +150,11 @@ export async function relabelForUserHandler(req: Request, res: Response) { const allResults = [] - const relaceResults = relabelUsingFullFilesForUser({ userId, limit }) + const relaceResults = relabelUsingFullFilesForUser({ + userId, + limit, + logger, + }) // Process each model for (const model of modelsToRelabel) { @@ -178,6 +192,7 @@ export async function relabelForUserHandler(req: Request, res: Response) { fingerprintId: 'relabel-trace-api', userInputId: 'relabel-trace-api', userId: TEST_USER_ID, + logger, }) // Create relabel record @@ -253,8 +268,9 @@ export async function relabelForUserHandler(req: Request, res: Response) { async function relabelUsingFullFilesForUser(params: { userId: string limit: number + logger: Logger }) { - const { userId, limit } = params + const { userId, limit, logger } = params // TODO: We need to figure out changing _everything_ to use `getTracesAndAllDataForUser` const tracesBundles = await getTracesAndAllDataForUser(userId) @@ -278,7 +294,7 @@ async function relabelUsingFullFilesForUser(params: { } if (!traceBundle.relabels.some((r) => r.model === 'relace-ranker')) { - relabelPromises.push(relabelWithRelace({ trace, fileBlobs })) + relabelPromises.push(relabelWithRelace({ trace, fileBlobs, logger })) didRelabel = true } for (const model of [ @@ -291,7 +307,12 @@ async function relabelUsingFullFilesForUser(params: { ) ) { relabelPromises.push( - relabelWithClaudeWithFullFileContext({ trace, fileBlobs, model }), + relabelWithClaudeWithFullFileContext({ + trace, + fileBlobs, + model, + logger, + }), ) didRelabel = true } @@ -315,8 +336,9 @@ async function relabelUsingFullFilesForUser(params: { async function relabelWithRelace(params: { trace: GetRelevantFilesTrace fileBlobs: GetExpandedFileContextForTrainingBlobTrace + logger: Logger }) { - const { trace, fileBlobs } = params + const { trace, fileBlobs, logger } = params logger.info(`Relabeling ${trace.id} with Relace`) const messages = trace.payload.messages || [] const queryBody = @@ -335,12 +357,15 @@ async function relabelWithRelace(params: { }), ) - const relaced = await rerank(filesWithPath, query, { + const relaced = await rerank({ + files: filesWithPath, + prompt: query, clientSessionId: trace.payload.client_session_id, fingerprintId: trace.payload.fingerprint_id, userInputId: trace.payload.user_input_id, userId: 'test-user-id', // Make sure we don't bill em for it!! messageId: trace.id, + logger, }) const relabel = { @@ -367,8 +392,9 @@ export async function relabelWithClaudeWithFullFileContext(params: { fileBlobs: GetExpandedFileContextForTrainingBlobTrace model: string dataset?: string + logger: Logger }) { - const { trace, fileBlobs, model, dataset } = params + const { trace, fileBlobs, model, dataset, logger } = params if (dataset) { await setupBigQuery({ dataset, logger }) } @@ -401,13 +427,17 @@ export async function relabelWithClaudeWithFullFileContext(params: { } const output = await promptAiSdk({ - messages: messagesWithSystem({ messages: trace.payload.messages as Message[], system }), + messages: messagesWithSystem({ + messages: trace.payload.messages as Message[], + system, + }), model: model as any, // Model type is string here for flexibility clientSessionId: 'relabel-trace-api', fingerprintId: 'relabel-trace-api', userInputId: 'relabel-trace-api', userId: TEST_USER_ID, maxOutputTokens: 1000, + logger, }) const relabel = { diff --git a/backend/src/api/__tests__/validate-agent-name.test.ts b/backend/src/api/__tests__/validate-agent-name.test.ts index cc42186585..8e6c455755 100644 --- a/backend/src/api/__tests__/validate-agent-name.test.ts +++ b/backend/src/api/__tests__/validate-agent-name.test.ts @@ -19,9 +19,9 @@ import type { } from 'express' function createMockReq(query: Record): Partial { - return { - query, - headers: { 'x-codebuff-api-key': 'test-api-key' } + return { + query, + headers: { 'x-codebuff-api-key': 'test-api-key' }, } as any } @@ -78,7 +78,11 @@ describe('validateAgentNameHandler', () => { await validateAgentNameHandler(req as any, res as any, noopNext) - expect(spy).toHaveBeenCalledWith(agentId, {}) + expect(spy).toHaveBeenCalledWith({ + agentId, + localAgentTemplates: {}, + logger: expect.anything(), + }) expect(res.status).toHaveBeenCalledWith(200) expect(res.jsonPayload.valid).toBe(true) expect(res.jsonPayload.source).toBe('published') @@ -96,7 +100,11 @@ describe('validateAgentNameHandler', () => { await validateAgentNameHandler(req as any, res as any, noopNext) - expect(spy).toHaveBeenCalledWith(agentId, {}) + expect(spy).toHaveBeenCalledWith({ + agentId, + localAgentTemplates: {}, + logger: expect.anything(), + }) expect(res.status).toHaveBeenCalledWith(200) expect(res.jsonPayload.valid).toBe(true) expect(res.jsonPayload.source).toBe('published') @@ -114,7 +122,11 @@ describe('validateAgentNameHandler', () => { await validateAgentNameHandler(req as any, res as any, noopNext) - expect(spy).toHaveBeenCalledWith(agentId, {}) + expect(spy).toHaveBeenCalledWith({ + agentId, + localAgentTemplates: {}, + logger: expect.anything(), + }) expect(res.status).toHaveBeenCalledWith(200) expect(res.jsonPayload.valid).toBe(false) }) diff --git a/backend/src/api/agents.ts b/backend/src/api/agents.ts index 8563249348..979a2598e7 100644 --- a/backend/src/api/agents.ts +++ b/backend/src/api/agents.ts @@ -77,7 +77,7 @@ export async function validateAgentNameHandler( } // Check published agents (database) - const found = await getAgentTemplate(agentId, {}) + const found = await getAgentTemplate({ agentId, localAgentTemplates: {}, logger }) if (found) { const result = { valid: true as const, diff --git a/backend/src/api/usage.ts b/backend/src/api/usage.ts index 5154b15ac5..c4d9b013c0 100644 --- a/backend/src/api/usage.ts +++ b/backend/src/api/usage.ts @@ -48,6 +48,7 @@ async function usageHandler( fingerprintId, authToken, clientSessionId, + logger, }) if (authResult) { const errorMessage = @@ -95,6 +96,7 @@ async function usageHandler( fingerprintId, userId, clientSessionId, + logger, }) return res.status(200).json(usageResponse) diff --git a/backend/src/fast-rewrite.ts b/backend/src/fast-rewrite.ts index be95d8fb1e..f29f60920c 100644 --- a/backend/src/fast-rewrite.ts +++ b/backend/src/fast-rewrite.ts @@ -7,12 +7,12 @@ import { promptFlashWithFallbacks } from './llm-apis/gemini-with-fallbacks' import { promptRelaceAI } from './llm-apis/relace-api' import { promptAiSdk } from './llm-apis/vercel-ai-sdk/ai-sdk' -import type { Logger } from '@codebuff/types/logger' import type { CodebuffToolMessage } from '@codebuff/common/tools/list' import type { Message, ToolMessage, } from '@codebuff/common/types/messages/codebuff-message' +import type { Logger } from '@codebuff/types/logger' export async function fastRewrite(params: { initialContent: string @@ -40,19 +40,18 @@ export async function fastRewrite(params: { } = params const relaceStartTime = Date.now() const messageId = generateCompactId('cb-') - let response = await promptRelaceAI( - initialContent, + let response = await promptRelaceAI({ + initialCode: initialContent, editSnippet, instructions, - { - clientSessionId, - fingerprintId, - userInputId, - userId, - userMessage, - messageId, - }, - ) + clientSessionId, + fingerprintId, + userInputId, + userId, + userMessage, + messageId, + logger, + }) const relaceDuration = Date.now() - relaceStartTime // Check if response still contains lazy edits @@ -148,6 +147,7 @@ Please output just the complete updated file content with the edit applied and n fingerprintId, userInputId, userId, + logger, }) return parseMarkdownCodeBlock(response) + '\n' @@ -243,12 +243,14 @@ Do not write anything else. content: prompt, }, ) - const response = await promptFlashWithFallbacks(messages, { + const response = await promptFlashWithFallbacks({ + messages, clientSessionId, fingerprintId, userInputId, model: models.openrouter_gemini2_5_flash, userId, + logger, }) const shouldAddPlaceholderComments = response.includes('LOCAL_CHANGE_ONLY') logger.debug( diff --git a/backend/src/find-files/request-files-prompt.ts b/backend/src/find-files/request-files-prompt.ts index b714a950a8..fd3def9def 100644 --- a/backend/src/find-files/request-files-prompt.ts +++ b/backend/src/find-files/request-files-prompt.ts @@ -15,7 +15,6 @@ import { range, shuffle, uniq } from 'lodash' import { CustomFilePickerConfigSchema } from './custom-file-picker-config' import { promptFlashWithFallbacks } from '../llm-apis/gemini-with-fallbacks' import { promptAiSdk } from '../llm-apis/vercel-ai-sdk/ai-sdk' -import { logger } from '../util/logger' import { castAssistantMessage, messagesWithSystem, @@ -31,14 +30,17 @@ import type { } from '@codebuff/bigquery' import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { ProjectFileContext } from '@codebuff/common/util/file' +import type { Logger } from '@codebuff/types/logger' const NUMBER_OF_EXAMPLE_FILES = 100 const MAX_FILES_PER_REQUEST = 30 -export async function getCustomFilePickerConfigForOrg( - orgId: string | undefined, - isRepoApprovedForUserInOrg: boolean | undefined, -): Promise { +export async function getCustomFilePickerConfigForOrg(params: { + orgId: string | undefined + isRepoApprovedForUserInOrg: boolean | undefined + logger: Logger +}): Promise { + const { orgId, isRepoApprovedForUserInOrg, logger } = params if (!orgId || !isRepoApprovedForUserInOrg) { return null } @@ -118,30 +120,40 @@ function isValidFilePickerModelName( return Object.keys(finetunedVertexModels).includes(modelName) } -export async function requestRelevantFiles( - { +export async function requestRelevantFiles(params: { + messages: Message[] + system: string | Array + fileContext: ProjectFileContext + assistantPrompt: string | null + agentStepId: string + clientSessionId: string + fingerprintId: string + userInputId: string + userId: string | undefined + repoId: string | undefined + logger: Logger +}) { + const { messages, system, - }: { - messages: Message[] - system: string | Array - }, - fileContext: ProjectFileContext, - assistantPrompt: string | null, - agentStepId: string, - clientSessionId: string, - fingerprintId: string, - userInputId: string, - userId: string | undefined, - repoId: string | undefined, -) { + fileContext, + assistantPrompt, + agentStepId, + clientSessionId, + fingerprintId, + userInputId, + userId, + repoId, + logger, + } = params // Check for organization custom file picker feature const requestContext = getRequestContext() const orgId = requestContext?.approvedOrgIdForRepo - const customFilePickerConfig = await getCustomFilePickerConfigForOrg( + const customFilePickerConfig = await getCustomFilePickerConfigForOrg({ orgId, - requestContext?.isRepoApprovedForUserInOrg, - ) + isRepoApprovedForUserInOrg: requestContext?.isRepoApprovedForUserInOrg, + logger, + }) const countPerRequest = 12 @@ -180,21 +192,20 @@ export async function requestRelevantFiles( } } - const keyPromise = getRelevantFiles( - { - messages: messagesExcludingLastIfByUser, - system, - }, - keyPrompt, - 'Key', + const keyPromise = getRelevantFiles({ + messages: messagesExcludingLastIfByUser, + system, + userPrompt: keyPrompt, + requestType: 'Key', agentStepId, clientSessionId, fingerprintId, userInputId, userId, repoId, - modelIdForRequest, - ).catch((error) => { + modelId: modelIdForRequest, + logger, + }).catch((error) => { logger.error({ error }, 'Error requesting key files') return { files: [] as string[], duration: 0 } }) @@ -216,23 +227,32 @@ export async function requestRelevantFiles( return candidateFiles.slice(0, maxFilesPerRequest) } -export async function requestRelevantFilesForTraining( - { +export async function requestRelevantFilesForTraining(params: { + messages: Message[] + system: string | Array + fileContext: ProjectFileContext + assistantPrompt: string | null + agentStepId: string + clientSessionId: string + fingerprintId: string + userInputId: string + userId: string | undefined + repoId: string | undefined + logger: Logger +}) { + const { messages, system, - }: { - messages: Message[] - system: string | Array - }, - fileContext: ProjectFileContext, - assistantPrompt: string | null, - agentStepId: string, - clientSessionId: string, - fingerprintId: string, - userInputId: string, - userId: string | undefined, - repoId: string | undefined, -) { + fileContext, + assistantPrompt, + agentStepId, + clientSessionId, + fingerprintId, + userInputId, + userId, + repoId, + logger, + } = params const COUNT = 50 const lastMessage = messages[messages.length - 1] @@ -258,35 +278,33 @@ export async function requestRelevantFilesForTraining( COUNT, ) - const keyFiles = await getRelevantFilesForTraining( - { - messages: messagesExcludingLastIfByUser, - system, - }, - keyFilesPrompt, - 'Key', + const keyFiles = await getRelevantFilesForTraining({ + messages: messagesExcludingLastIfByUser, + system, + userPrompt: keyFilesPrompt, + requestType: 'Key', agentStepId, clientSessionId, fingerprintId, userInputId, userId, repoId, - ) + logger, + }) - const nonObviousFiles = await getRelevantFilesForTraining( - { - messages: messagesExcludingLastIfByUser, - system, - }, - nonObviousPrompt, - 'Non-Obvious', + const nonObviousFiles = await getRelevantFilesForTraining({ + messages: messagesExcludingLastIfByUser, + system, + userPrompt: nonObviousPrompt, + requestType: 'Non-Obvious', agentStepId, clientSessionId, fingerprintId, userInputId, userId, repoId, - ) + logger, + }) const candidateFiles = [...keyFiles.files, ...nonObviousFiles.files] const validatedFiles = validateFilePaths(uniq(candidateFiles)) @@ -297,37 +315,51 @@ export async function requestRelevantFilesForTraining( return validatedFiles.slice(0, MAX_FILES_PER_REQUEST) } -async function getRelevantFiles( - { +async function getRelevantFiles(params: { + messages: Message[] + system: string | Array + userPrompt: string + requestType: string + agentStepId: string + clientSessionId: string + fingerprintId: string + userInputId: string + userId: string | undefined + repoId: string | undefined + modelId?: FinetunedVertexModel + logger: Logger +}) { + const { messages, system, - }: { - messages: Message[] - system: string | Array - }, - userPrompt: string, - requestType: string, - agentStepId: string, - clientSessionId: string, - fingerprintId: string, - userInputId: string, - userId: string | undefined, - repoId: string | undefined, - modelId?: FinetunedVertexModel, -) { + userPrompt, + requestType, + agentStepId, + clientSessionId, + fingerprintId, + userInputId, + userId, + repoId, + modelId, + logger, + } = params const bufferTokens = 100_000 - const messagesWithPrompt = getMessagesSubset( - [ + const messagesWithPrompt = getMessagesSubset({ + messages: [ ...messages, { role: 'user' as const, content: userPrompt, }, ], - bufferTokens, - ) + otherTokens: bufferTokens, + logger, + }) const start = performance.now() - let codebuffMessages = messagesWithSystem({ messages: messagesWithPrompt, system }) + let codebuffMessages = messagesWithSystem({ + messages: messagesWithPrompt, + system, + }) // Converts assistant messages to user messages for finetuned model codebuffMessages = codebuffMessages @@ -341,13 +373,15 @@ async function getRelevantFiles( .filter((msg) => msg !== null) const finetunedModel = modelId ?? finetunedVertexModels.ft_filepicker_010 - let response = await promptFlashWithFallbacks(codebuffMessages, { + let response = await promptFlashWithFallbacks({ + messages: codebuffMessages, clientSessionId, userInputId, model: models.openrouter_gemini2_5_flash, userId, useFinetunedModel: finetunedModel, fingerprintId, + logger, }) const end = performance.now() const duration = end - start @@ -380,34 +414,44 @@ async function getRelevantFiles( return { files, duration, requestType, response } } -async function getRelevantFilesForTraining( - { +async function getRelevantFilesForTraining(params: { + messages: Message[] + system: string | Array + userPrompt: string + requestType: string + agentStepId: string + clientSessionId: string + fingerprintId: string + userInputId: string + userId: string | undefined + repoId: string | undefined + logger: Logger +}) { + const { messages, system, - }: { - messages: Message[] - system: string | Array - }, - userPrompt: string, - requestType: string, - agentStepId: string, - clientSessionId: string, - fingerprintId: string, - userInputId: string, - userId: string | undefined, - repoId: string | undefined, -) { + userPrompt, + requestType, + agentStepId, + clientSessionId, + fingerprintId, + userInputId, + userId, + repoId, + logger, + } = params const bufferTokens = 100_000 - const messagesWithPrompt = getMessagesSubset( - [ + const messagesWithPrompt = getMessagesSubset({ + messages: [ ...messages, { role: 'user' as const, content: userPrompt, }, ], - bufferTokens, - ) + otherTokens: bufferTokens, + logger, + }) const start = performance.now() let response = await promptAiSdk({ messages: messagesWithSystem({ messages: messagesWithPrompt, system }), @@ -417,6 +461,7 @@ async function getRelevantFilesForTraining( model: models.openrouter_claude_sonnet_4, userId, chargeUser: false, + logger, }) const end = performance.now() @@ -489,7 +534,10 @@ function generateNonObviousRequestFilesPrompt( fileContext: ProjectFileContext, count: number, ): string { - const exampleFiles = getExampleFileList({ fileContext, count: NUMBER_OF_EXAMPLE_FILES }) + const exampleFiles = getExampleFileList({ + fileContext, + count: NUMBER_OF_EXAMPLE_FILES, + }) return ` Your task is to find the second-order relevant files for the following user request (in quotes). @@ -544,7 +592,10 @@ function generateKeyRequestFilesPrompt( fileContext: ProjectFileContext, count: number, ): string { - const exampleFiles = getExampleFileList({ fileContext, count: NUMBER_OF_EXAMPLE_FILES }) + const exampleFiles = getExampleFileList({ + fileContext, + count: NUMBER_OF_EXAMPLE_FILES, + }) return ` Your task is to find the most relevant files for the following user request (in quotes). diff --git a/backend/src/generate-diffs-prompt.ts b/backend/src/generate-diffs-prompt.ts index f9a113be42..3a8d04847b 100644 --- a/backend/src/generate-diffs-prompt.ts +++ b/backend/src/generate-diffs-prompt.ts @@ -5,12 +5,15 @@ import { } from '@codebuff/common/util/file' import { promptAiSdk } from './llm-apis/vercel-ai-sdk/ai-sdk' -import { logger } from './util/logger' -export const parseAndGetDiffBlocksSingleFile = ( - newContent: string, - oldFileContent: string, -) => { +import type { Logger } from '@codebuff/types/logger' + +export const parseAndGetDiffBlocksSingleFile = (params: { + newContent: string + oldFileContent: string + logger: Logger +}) => { + const { newContent, oldFileContent, logger } = params const diffBlocksThatDidntMatch: { searchContent: string replaceContent: string @@ -138,6 +141,7 @@ export async function retryDiffBlocksPrompt(params: { userInputId: string userId: string | undefined diffBlocksThatDidntMatch: { searchContent: string; replaceContent: string }[] + logger: Logger }) { const { filePath, @@ -147,6 +151,7 @@ export async function retryDiffBlocksPrompt(params: { userInputId, userId, diffBlocksThatDidntMatch, + logger, } = params const newPrompt = `The assistant failed to find a match for the following changes. Please help the assistant understand what the changes should be. @@ -169,11 +174,16 @@ Provide a new set of SEARCH/REPLACE changes to make the intended edit from the o fingerprintId, userInputId, userId, + logger, }) const { diffBlocks: newDiffBlocks, diffBlocksThatDidntMatch: newDiffBlocksThatDidntMatch, - } = parseAndGetDiffBlocksSingleFile(response, oldContent) + } = parseAndGetDiffBlocksSingleFile({ + newContent: response, + oldFileContent: oldContent, + logger, + }) if (newDiffBlocksThatDidntMatch.length > 0) { logger.error( diff --git a/backend/src/get-documentation-for-query.ts b/backend/src/get-documentation-for-query.ts index a442085d28..9ccc5e3331 100644 --- a/backend/src/get-documentation-for-query.ts +++ b/backend/src/get-documentation-for-query.ts @@ -5,7 +5,8 @@ import { z } from 'zod/v4' import { fetchContext7LibraryDocumentation } from './llm-apis/context7-api' import { promptAiSdkStructured } from './llm-apis/vercel-ai-sdk/ai-sdk' -import { logger } from './util/logger' + +import type { Logger } from '@codebuff/types/logger' const DELIMITER = `\n\n----------------------------------------\n\n` @@ -19,20 +20,35 @@ const DELIMITER = `\n\n----------------------------------------\n\n` * @param options.userId The ID of the user making the request * @returns The documentation text chunks or null if no relevant docs found */ -export async function getDocumentationForQuery( - query: string, - options: { - tokens?: number - clientSessionId: string - userInputId: string - fingerprintId: string - userId?: string - }, -): Promise { +export async function getDocumentationForQuery(params: { + query: string + tokens?: number + clientSessionId: string + userInputId: string + fingerprintId: string + userId?: string + logger: Logger +}): Promise { + const { + query, + tokens, + clientSessionId, + userInputId, + fingerprintId, + userId, + logger, + } = params const startTime = Date.now() // 1. Search for relevant libraries - const libraryResults = await suggestLibraries(query, options) + const libraryResults = await suggestLibraries({ + query, + clientSessionId, + userInputId, + fingerprintId, + userId, + logger, + }) if (!libraryResults || libraryResults.libraries.length === 0) { logger.info( @@ -53,9 +69,11 @@ export async function getDocumentationForQuery( const allRawChunks = ( await Promise.all( libraries.map(({ libraryName, topic }) => - fetchContext7LibraryDocumentation(libraryName, { - tokens: options.tokens, + fetchContext7LibraryDocumentation({ + query: libraryName, + tokens, topic, + logger, }), ), ) @@ -85,11 +103,15 @@ export async function getDocumentationForQuery( } // 3. Filter relevant chunks using another LLM call - const filterResults = await filterRelevantChunks( + const filterResults = await filterRelevantChunks({ query, - allUniqueChunks, - options, - ) + allChunks: allUniqueChunks, + clientSessionId, + userInputId, + fingerprintId, + userId, + logger, + }) const totalDuration = Date.now() - startTime @@ -135,15 +157,16 @@ export async function getDocumentationForQuery( return relevantChunks.join(DELIMITER) } -const suggestLibraries = async ( - query: string, - options: { - clientSessionId: string - userInputId: string - fingerprintId: string - userId?: string - }, -) => { +const suggestLibraries = async (params: { + query: string + clientSessionId: string + userInputId: string + fingerprintId: string + userId?: string + logger: Logger +}) => { + const { query, clientSessionId, userInputId, fingerprintId, userId, logger } = + params const prompt = `You are an expert at documentation for libraries. Given a user's query return a list of (library name, topic) where each library name is the name of a library and topic is a keyword or phrase that specifies a topic within the library that is most relevant to the user's query. @@ -163,9 +186,11 @@ ${closeXml('user_query')} const geminiStartTime = Date.now() try { const response = await promptAiSdkStructured({ - ...options, messages: [{ role: 'user', content: prompt }], - userId: options.userId, + clientSessionId, + userInputId, + fingerprintId, + userId, model: models.openrouter_gemini2_5_flash, temperature: 0, schema: z.object({ @@ -177,6 +202,7 @@ ${closeXml('user_query')} ), }), timeout: 5_000, + logger, }) return { libraries: response.libraries, @@ -198,16 +224,24 @@ ${closeXml('user_query')} * @param options Common request options including session and user identifiers. * @returns A promise that resolves to an object containing the relevant chunks and Gemini call duration, or null if an error occurs. */ -async function filterRelevantChunks( - query: string, - allChunks: string[], - options: { - clientSessionId: string - userInputId: string - fingerprintId: string - userId?: string - }, -): Promise<{ relevantChunks: string[]; geminiDuration: number } | null> { +async function filterRelevantChunks(params: { + query: string + allChunks: string[] + clientSessionId: string + userInputId: string + fingerprintId: string + userId?: string + logger: Logger +}): Promise<{ relevantChunks: string[]; geminiDuration: number } | null> { + const { + query, + allChunks, + clientSessionId, + userInputId, + fingerprintId, + userId, + logger, + } = params const prompt = `You are an expert at analyzing documentation queries. Given a user's query and a list of documentation chunks, determine which chunks are relevant to the query. Choose as few chunks as possible, likely none. Only include chunks if they are relevant to the user query. @@ -223,16 +257,17 @@ ${closeXml('documentation_chunks')} try { const response = await promptAiSdkStructured({ messages: [{ role: 'user', content: prompt }], - clientSessionId: options.clientSessionId, - userInputId: options.userInputId, - fingerprintId: options.fingerprintId, - userId: options.userId, + clientSessionId, + userInputId, + fingerprintId, + userId, model: models.openrouter_gemini2_5_flash, temperature: 0, schema: z.object({ relevant_chunks: z.array(z.number()), }), timeout: 20_000, + logger, }) const geminiDuration = Date.now() - geminiStartTime diff --git a/backend/src/index.ts b/backend/src/index.ts index 9aeeaa17f5..1f1848b81e 100644 --- a/backend/src/index.ts +++ b/backend/src/index.ts @@ -94,7 +94,7 @@ server.listen(port, () => { logger.debug(`🚀 Server is running on port ${port}`) console.log(`🚀 Server is running on port ${port}`) }) -webSocketListen({ server, path: '/ws' }) +webSocketListen({ server, path: '/ws', logger }) let shutdownInProgress = false // Graceful shutdown handler for both SIGTERM and SIGINT diff --git a/backend/src/live-user-inputs.ts b/backend/src/live-user-inputs.ts index 6630ba10e9..1b2250aa01 100644 --- a/backend/src/live-user-inputs.ts +++ b/backend/src/live-user-inputs.ts @@ -1,4 +1,4 @@ -import { logger } from './util/logger' +import type { Logger } from '@codebuff/types/logger' let liveUserInputCheckEnabled = true export const disableLiveUserInputCheck = () => { @@ -25,8 +25,9 @@ export function startUserInput(params: { export function cancelUserInput(params: { userId: string userInputId: string + logger: Logger }): void { - const { userId, userInputId } = params + const { userId, userInputId, logger } = params if (live[userId] && live[userId].includes(userInputId)) { live[userId] = live[userId].filter((id) => id !== userInputId) if (live[userId].length === 0) { @@ -40,11 +41,13 @@ export function cancelUserInput(params: { } } -export function checkLiveUserInput( - userId: string | undefined, - userInputId: string, - sessionId: string, -): boolean { +export function checkLiveUserInput(params: { + userId: string | undefined + userInputId: string + clientSessionId: string +}): boolean { + const { userId, userInputId, clientSessionId } = params + if (!liveUserInputCheckEnabled) { return true } @@ -53,7 +56,7 @@ export function checkLiveUserInput( } // Check if WebSocket is still connected for this session - if (!sessionConnections[sessionId]) { + if (!sessionConnections[clientSessionId]) { return false } diff --git a/backend/src/llm-apis/__tests__/linkup-api.test.ts b/backend/src/llm-apis/__tests__/linkup-api.test.ts index ab21c62f08..a7743fc3d3 100644 --- a/backend/src/llm-apis/__tests__/linkup-api.test.ts +++ b/backend/src/llm-apis/__tests__/linkup-api.test.ts @@ -35,10 +35,6 @@ describe('Linkup API', () => { }, })) - mockModule('@codebuff/backend/util/logger', () => ({ - logger: mockLogger, - })) - // Mock withTimeout utility mockModule('@codebuff/common/util/promise', () => ({ withTimeout: async (promise: Promise, timeout: number) => promise, @@ -84,7 +80,10 @@ describe('Linkup API', () => { }), ) - const result = await searchWeb('React tutorial') + const result = await searchWeb({ + query: 'React tutorial', + logger: mockLogger, + }) expect(result).toBe( 'React is a JavaScript library for building user interfaces. You can learn how to build your first React application by following the official documentation.', @@ -128,8 +127,10 @@ describe('Linkup API', () => { }), ) - const result = await searchWeb('React patterns', { + const result = await searchWeb({ + query: 'React patterns', depth: 'deep', + logger: mockLogger, }) expect(result).toBe( @@ -157,7 +158,7 @@ describe('Linkup API', () => { }), ) - const result = await searchWeb('test query') + const result = await searchWeb({ query: 'test query', logger: mockLogger }) expect(result).toBeNull() }) @@ -165,7 +166,7 @@ describe('Linkup API', () => { test('should handle network errors', async () => { spyOn(global, 'fetch').mockRejectedValue(new Error('Network error')) - const result = await searchWeb('test query') + const result = await searchWeb({ query: 'test query', logger: mockLogger }) expect(result).toBeNull() }) @@ -178,7 +179,7 @@ describe('Linkup API', () => { }), ) - const result = await searchWeb('test query') + const result = await searchWeb({ query: 'test query', logger: mockLogger }) expect(result).toBeNull() }) @@ -191,7 +192,7 @@ describe('Linkup API', () => { }), ) - const result = await searchWeb('test query') + const result = await searchWeb({ query: 'test query', logger: mockLogger }) expect(result).toBeNull() }) @@ -208,7 +209,7 @@ describe('Linkup API', () => { }), ) - const result = await searchWeb('test query') + const result = await searchWeb({ query: 'test query', logger: mockLogger }) expect(result).toBeNull() }) @@ -228,7 +229,7 @@ describe('Linkup API', () => { }), ) - await searchWeb('test query') + await searchWeb({ query: 'test query', logger: mockLogger }) // Verify fetch was called with default parameters expect(fetch).toHaveBeenCalledWith( @@ -251,7 +252,7 @@ describe('Linkup API', () => { }), ) - const result = await searchWeb('test query') + const result = await searchWeb({ query: 'test query', logger: mockLogger }) expect(result).toBeNull() // Verify that error logging was called @@ -269,7 +270,10 @@ describe('Linkup API', () => { }), ) - const result = await searchWeb('test query for 404') + const result = await searchWeb({ + query: 'test query for 404', + logger: mockLogger, + }) expect(result).toBeNull() // Verify that detailed error logging was called with 404 info diff --git a/backend/src/llm-apis/context7-api.ts b/backend/src/llm-apis/context7-api.ts index 1a115cb1d4..e52aa366ad 100644 --- a/backend/src/llm-apis/context7-api.ts +++ b/backend/src/llm-apis/context7-api.ts @@ -1,7 +1,8 @@ import { withTimeout } from '@codebuff/common/util/promise' import { env } from '@codebuff/internal/env' -import { logger } from '../util/logger' +import type { ParamsOf } from '@codebuff/types/common' +import type { Logger } from '@codebuff/types/logger' const CONTEXT7_API_BASE_URL = 'https://context7.com/api/v1' const DEFAULT_TYPE = 'txt' @@ -42,9 +43,12 @@ export interface SearchResult { * Lists all available documentation projects from Context7 * @returns Array of projects with their metadata, or null if the request fails */ -export async function searchLibraries( - query: string, -): Promise { +export async function searchLibraries(params: { + query: string + logger: Logger +}): Promise { + const { query, logger } = params + const searchStartTime = Date.now() const searchContext = { query, @@ -127,23 +131,26 @@ export async function searchLibraries( * @returns The documentation text or null if the request fails */ export async function fetchContext7LibraryDocumentation( - query: string, - options: { + params: { + query: string tokens?: number topic?: string folders?: string - } = {}, + logger: Logger + } & ParamsOf, ): Promise { + const { query, tokens, topic, folders, logger } = params + const apiStartTime = Date.now() const apiContext = { query, - requestedTokens: options.tokens, - topic: options.topic, - folders: options.folders, + requestedTokens: tokens, + topic, + folders, } const searchStartTime = Date.now() - const libraries = await searchLibraries(query) + const libraries = await searchLibraries(params) const searchDuration = Date.now() - searchStartTime if (!libraries || libraries.length === 0) { @@ -179,10 +186,9 @@ export async function fetchContext7LibraryDocumentation( try { const url = new URL(`${CONTEXT7_API_BASE_URL}/${libraryId}`) - if (options.tokens) - url.searchParams.set('tokens', options.tokens.toString()) - if (options.topic) url.searchParams.set('topic', options.topic) - if (options.folders) url.searchParams.set('folders', options.folders) + if (tokens) url.searchParams.set('tokens', tokens.toString()) + if (topic) url.searchParams.set('topic', topic) + if (folders) url.searchParams.set('folders', folders) url.searchParams.set('type', DEFAULT_TYPE) const fetchStartTime = Date.now() diff --git a/backend/src/llm-apis/gemini-with-fallbacks.ts b/backend/src/llm-apis/gemini-with-fallbacks.ts index 90f39e1ca3..4de7033ac6 100644 --- a/backend/src/llm-apis/gemini-with-fallbacks.ts +++ b/backend/src/llm-apis/gemini-with-fallbacks.ts @@ -1,6 +1,5 @@ import { openaiModels, openrouterModels } from '@codebuff/common/old-constants' -import { logger } from '../util/logger' import { promptAiSdk } from './vercel-ai-sdk/ai-sdk' import type { @@ -9,6 +8,7 @@ import type { Model, } from '@codebuff/common/old-constants' import type { Message } from '@codebuff/common/types/messages/codebuff-message' +import type { Logger } from '@codebuff/types/logger' /** * Prompts a Gemini model with fallback logic. @@ -35,28 +35,29 @@ import type { Message } from '@codebuff/common/types/messages/codebuff-message' * @returns A promise that resolves to the complete response string from the successful API call. * @throws If all API calls (primary and fallbacks) fail. */ -export async function promptFlashWithFallbacks( - messages: Message[], - options: { - clientSessionId: string - fingerprintId: string - userInputId: string - model: Model - userId: string | undefined - maxTokens?: number - temperature?: number - costMode?: CostMode - useGPT4oInsteadOfClaude?: boolean - thinkingBudget?: number - useFinetunedModel?: FinetunedVertexModel | undefined - }, -): Promise { +export async function promptFlashWithFallbacks(params: { + messages: Message[] + clientSessionId: string + fingerprintId: string + userInputId: string + model: Model + userId: string | undefined + maxTokens?: number + temperature?: number + costMode?: CostMode + useGPT4oInsteadOfClaude?: boolean + thinkingBudget?: number + useFinetunedModel?: FinetunedVertexModel | undefined + logger: Logger +}): Promise { const { + messages, costMode, useGPT4oInsteadOfClaude, useFinetunedModel, + logger, ...geminiOptions - } = options + } = params // Try finetuned model first if enabled if (useFinetunedModel) { @@ -65,6 +66,7 @@ export async function promptFlashWithFallbacks( ...geminiOptions, messages, model: useFinetunedModel, + logger, }) } catch (error) { logger.warn( @@ -76,7 +78,7 @@ export async function promptFlashWithFallbacks( try { // First try Gemini - return await promptAiSdk({ ...geminiOptions, messages }) + return await promptAiSdk({ ...geminiOptions, messages, logger }) } catch (error) { logger.warn( { error }, @@ -94,6 +96,7 @@ export async function promptFlashWithFallbacks( experimental: openrouterModels.openrouter_claude_3_5_haiku, ask: openrouterModels.openrouter_claude_3_5_haiku, }[costMode ?? 'normal'], + logger, }) } } diff --git a/backend/src/llm-apis/linkup-api.ts b/backend/src/llm-apis/linkup-api.ts index 09aa8183aa..44c348abbb 100644 --- a/backend/src/llm-apis/linkup-api.ts +++ b/backend/src/llm-apis/linkup-api.ts @@ -1,7 +1,7 @@ import { withTimeout } from '@codebuff/common/util/promise' import { env } from '@codebuff/internal' -import { logger } from '../util/logger' +import type { Logger } from '@codebuff/types/logger' const LINKUP_API_BASE_URL = 'https://api.linkup.so/v1' const FETCH_TIMEOUT_MS = 30_000 @@ -23,13 +23,12 @@ export interface LinkupSearchResponse { * @param options Search options including depth and max results * @returns Array containing a single result with the sourced answer or null if the request fails */ -export async function searchWeb( - query: string, - options: { - depth?: 'standard' | 'deep' - } = {}, -): Promise { - const { depth = 'standard' } = options +export async function searchWeb(options: { + query: string + depth?: 'standard' | 'deep' + logger: Logger +}): Promise { + const { query, depth = 'standard', logger } = options const apiStartTime = Date.now() const requestBody = { diff --git a/backend/src/llm-apis/message-cost-tracker.ts b/backend/src/llm-apis/message-cost-tracker.ts index 916a779b63..9922cfcb8e 100644 --- a/backend/src/llm-apis/message-cost-tracker.ts +++ b/backend/src/llm-apis/message-cost-tracker.ts @@ -12,13 +12,14 @@ import Stripe from 'stripe' import { WebSocket } from 'ws' import { getRequestContext } from '../context/app-context' -import { logger, withLoggerContext } from '../util/logger' +import { withLoggerContext } from '../util/logger' import { stripNullCharsFromObject } from '../util/object' import { SWITCHBOARD } from '../websockets/server' import { sendAction } from '../websockets/websocket-action' import type { ClientState } from '../websockets/switchboard' import type { Message } from '@codebuff/common/types/messages/codebuff-message' +import type { Logger } from '@codebuff/types/logger' export const PROFIT_MARGIN = 0.055 @@ -212,13 +213,14 @@ const calcCost = ( const VERBOSE = false -async function syncMessageToStripe(messageData: { +async function syncMessageToStripe(params: { messageId: string userId: string costInCents: number finishedAt: Date + logger: Logger }) { - const { messageId, userId, costInCents, finishedAt } = messageData + const { messageId, userId, costInCents, finishedAt, logger } = params if (!userId || userId === TEST_USER_ID) { if (VERBOSE) { @@ -329,17 +331,19 @@ type InsertMessageParams = { latencyMs: number } -export async function insertMessageRecordWithRetries( - params: InsertMessageParams, - maxRetries = 3, -): Promise { +export async function insertMessageRecordWithRetries(params: { + messageParams: InsertMessageParams + logger: Logger + maxRetries?: number +}): Promise { + const { messageParams, logger, maxRetries = 3 } = params for (let attempt = 1; attempt <= maxRetries; attempt++) { try { - return await insertMessageRecord(params) + return await insertMessageRecord(messageParams) } catch (error) { if (attempt === maxRetries) { logger.error( - { messageId: params.messageId, error, attempt }, + { messageId: messageParams.messageId, error, attempt }, `Failed to save message after ${maxRetries} attempts`, ) return null @@ -347,7 +351,7 @@ export async function insertMessageRecordWithRetries( // throw error } else { logger.warn( - { messageId: params.messageId, error: error }, + { messageId: messageParams.messageId, error: error }, `Retrying save message to DB (attempt ${attempt}/${maxRetries})`, ) await new Promise((resolve) => setTimeout(resolve, 1000 * attempt)) @@ -416,12 +420,14 @@ async function insertMessageRecord( return insertResult[0] } -async function sendCostResponseToClient( - clientSessionId: string, - userInputId: string, - creditsUsed: number, - agentId?: string, -): Promise { +async function sendCostResponseToClient(params: { + clientSessionId: string + userInputId: string + creditsUsed: number + agentId?: string + logger: Logger +}): Promise { + const { clientSessionId, userInputId, creditsUsed, agentId, logger } = params try { const clientEntry = Array.from(SWITCHBOARD.clients.entries()).find( ([_, state]: [WebSocket, ClientState]) => @@ -462,17 +468,19 @@ type CreditConsumptionResult = { fromPurchased: number } -async function updateUserCycleUsageWithRetries( - userId: string, - creditsUsed: number, - maxRetries = 3, -): Promise { +async function updateUserCycleUsageWithRetries(params: { + userId: string + creditsUsed: number + logger: Logger + maxRetries?: number +}): Promise { + const { userId, creditsUsed, logger, maxRetries = 3 } = params const requestContext = getRequestContext() const orgId = requestContext?.approvedOrgIdForRepo for (let attempt = 1; attempt <= maxRetries; attempt++) { try { - return await updateUserCycleUsage(userId, creditsUsed) + return await updateUserCycleUsage({ userId, creditsUsed, logger }) } catch (error) { if (attempt === maxRetries) { logger.error( @@ -496,10 +504,12 @@ async function updateUserCycleUsageWithRetries( throw new Error('Failed to update user cycle usage after all attempts.') } -async function updateUserCycleUsage( - userId: string, - creditsUsed: number, -): Promise { +async function updateUserCycleUsage(params: { + userId: string + creditsUsed: number + logger: Logger +}): Promise { + const { userId, creditsUsed, logger } = params if (creditsUsed <= 0) { if (VERBOSE) { logger.debug( @@ -590,6 +600,7 @@ export const saveMessage = async (value: { chargeUser?: boolean costOverrideDollars?: number agentId?: string + logger: Logger }): Promise => withLoggerContext( { @@ -598,6 +609,7 @@ export const saveMessage = async (value: { fingerprintId: value.fingerprintId, }, async () => { + const { logger } = value const cost = value.costOverrideDollars ?? calcCost( @@ -652,17 +664,21 @@ export const saveMessage = async (value: { ) } - sendCostResponseToClient( - value.clientSessionId, - value.userInputId, + sendCostResponseToClient({ + clientSessionId: value.clientSessionId, + userInputId: value.userInputId, creditsUsed, - value.agentId, - ) + agentId: value.agentId, + logger, + }) await insertMessageRecordWithRetries({ - ...value, - cost, - creditsUsed, + messageParams: { + ...value, + cost, + creditsUsed, + }, + logger, }) if (!value.userId) { @@ -673,10 +689,11 @@ export const saveMessage = async (value: { return 0 } - const consumptionResult = await updateUserCycleUsageWithRetries( - value.userId, + const consumptionResult = await updateUserCycleUsageWithRetries({ + userId: value.userId, creditsUsed, - ) + logger, + }) // Only sync the portion from purchased credits to Stripe if (consumptionResult.fromPurchased > 0) { @@ -688,6 +705,7 @@ export const saveMessage = async (value: { userId: value.userId, costInCents: purchasedCostInCents, finishedAt: value.finishedAt, + logger, }).catch((syncError) => { logger.error( { messageId: value.messageId, error: syncError }, diff --git a/backend/src/llm-apis/relace-api.ts b/backend/src/llm-apis/relace-api.ts index 7b5a253bc1..1ac4d56258 100644 --- a/backend/src/llm-apis/relace-api.ts +++ b/backend/src/llm-apis/relace-api.ts @@ -6,36 +6,40 @@ import { import { env } from '@codebuff/internal' import { saveMessage } from '../llm-apis/message-cost-tracker' -import { logger } from '../util/logger' import { countTokens } from '../util/token-counter' import { promptAiSdk } from './vercel-ai-sdk/ai-sdk' +import type { Logger } from '@codebuff/types/logger' + const timeoutPromise = (ms: number) => new Promise((_, reject) => setTimeout(() => reject(new Error('Relace API request timed out')), ms), ) -export async function promptRelaceAI( - initialCode: string, - editSnippet: string, - instructions: string | undefined, - options: { - clientSessionId: string - fingerprintId: string - userInputId: string - userId: string | undefined - messageId: string - userMessage?: string - }, -) { +export async function promptRelaceAI(params: { + initialCode: string + editSnippet: string + instructions: string | undefined + clientSessionId: string + fingerprintId: string + userInputId: string + userId: string | undefined + messageId: string + userMessage?: string + logger: Logger +}) { const { + initialCode, + editSnippet, + instructions, clientSessionId, fingerprintId, userInputId, userId, userMessage, messageId, - } = options + logger, + } = params const startTime = Date.now() try { @@ -90,6 +94,7 @@ export async function promptRelaceAI( outputTokens: countTokens(content), finishedAt: new Date(), latencyMs: Date.now() - startTime, + logger, }) return content + '\n' } catch (error) { @@ -134,6 +139,7 @@ Please output just the complete updated file content with no other text.` userInputId, model: models.o3mini, userId, + logger, }) return parseMarkdownCodeBlock(content) + '\n' @@ -150,19 +156,26 @@ export type FileWithPath = { content: string } -export async function rerank( - files: FileWithPath[], - prompt: string, - options: { - clientSessionId: string - fingerprintId: string - userInputId: string - userId: string | undefined - messageId: string - }, -) { - const { clientSessionId, fingerprintId, userInputId, userId, messageId } = - options +export async function rerank(params: { + files: FileWithPath[] + prompt: string + clientSessionId: string + fingerprintId: string + userInputId: string + userId: string | undefined + messageId: string + logger: Logger +}) { + const { + files, + prompt, + clientSessionId, + fingerprintId, + userInputId, + userId, + messageId, + logger, + } = params const startTime = Date.now() if (!prompt || !files.length) { @@ -227,6 +240,7 @@ export async function rerank( outputTokens: countTokens(JSON.stringify(rankings)), finishedAt: new Date(), latencyMs: Date.now() - startTime, + logger, }) return rankings diff --git a/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts b/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts index 84ba2da6e0..ee7e3f75e7 100644 --- a/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts +++ b/backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts @@ -12,13 +12,13 @@ import { generateCompactId } from '@codebuff/common/util/string' import { APICallError, generateObject, generateText, streamText } from 'ai' import { checkLiveUserInput, getLiveUserInputIds } from '../../live-user-inputs' -import { logger } from '../../util/logger' import { saveMessage } from '../message-cost-tracker' import { openRouterLanguageModel } from '../openrouter' import { vertexFinetuned } from './vertex-finetuned' import type { Model, OpenAIModel } from '@codebuff/common/old-constants' import type { Message } from '@codebuff/common/types/messages/codebuff-message' +import type { Logger } from '@codebuff/types/logger' import type { OpenRouterProviderOptions, OpenRouterUsageAccounting, @@ -60,7 +60,7 @@ const modelToAiSDKModel = (model: Model): LanguageModel => { // also take an array of form [{model: Model, retries: number}, {model: Model, retries: number}...] // eg: [{model: "gemini-2.0-flash-001"}, {model: "vertex/gemini-2.0-flash-001"}, {model: "claude-3-5-haiku", retries: 3}] export const promptAiSdkStream = async function* ( - options: { + params: { messages: Message[] clientSessionId: string fingerprintId: string @@ -73,20 +73,18 @@ export const promptAiSdkStream = async function* ( maxRetries?: number onCostCalculated?: (credits: number) => Promise includeCacheControl?: boolean + logger: Logger } & Omit[0], 'model' | 'messages'>, ): AsyncGenerator { + const { logger } = params if ( - !checkLiveUserInput( - options.userId, - options.userInputId, - options.clientSessionId, - ) + !checkLiveUserInput({ ...params, clientSessionId: params.clientSessionId }) ) { logger.info( { - userId: options.userId, - userInputId: options.userInputId, - liveUserInputId: getLiveUserInputIds(options.userId), + userId: params.userId, + userInputId: params.userInputId, + liveUserInputId: getLiveUserInputIds(params.userId), }, 'Skipping stream due to canceled user input', ) @@ -94,16 +92,16 @@ export const promptAiSdkStream = async function* ( } const startTime = Date.now() - let aiSDKModel = modelToAiSDKModel(options.model) + let aiSDKModel = modelToAiSDKModel(params.model) const response = streamText({ - ...options, + ...params, model: aiSDKModel, - messages: convertCbToModelMessages(options), + messages: convertCbToModelMessages(params), }) let content = '' - const stopSequenceHandler = new StopSequenceHandler(options.stopSequences) + const stopSequenceHandler = new StopSequenceHandler(params.stopSequences) for await (const chunk of response.fullStream) { if (chunk.type !== 'text-delta') { @@ -121,7 +119,7 @@ export const promptAiSdkStream = async function* ( { chunk: { ...chunk, error: undefined }, error: getErrorObject(chunk.error), - model: options.model, + model: params.model, }, 'Error from AI SDK', ) @@ -135,7 +133,7 @@ export const promptAiSdkStream = async function* ( : typeof chunk.error === 'string' ? chunk.error : JSON.stringify(chunk.error) - const errorMessage = `Error from AI SDK (model ${options.model}): ${buildArray([mainErrorMessage, errorBody]).join('\n')}` + const errorMessage = `Error from AI SDK (model ${params.model}): ${buildArray([mainErrorMessage, errorBody]).join('\n')}` yield { type: 'error', message: errorMessage, @@ -146,7 +144,7 @@ export const promptAiSdkStream = async function* ( if (chunk.type === 'reasoning-delta') { if ( ( - options.providerOptions?.openrouter as + params.providerOptions?.openrouter as | OpenRouterProviderOptions | undefined )?.reasoning?.exclude @@ -159,7 +157,7 @@ export const promptAiSdkStream = async function* ( } } if (chunk.type === 'text-delta') { - if (!options.stopSequences) { + if (!params.stopSequences) { content += chunk.text if (chunk.text) { yield { @@ -223,12 +221,12 @@ export const promptAiSdkStream = async function* ( const messageId = (await response.response).id const creditsUsedPromise = saveMessage({ messageId, - userId: options.userId, - clientSessionId: options.clientSessionId, - fingerprintId: options.fingerprintId, - userInputId: options.userInputId, - model: options.model, - request: options.messages, + userId: params.userId, + clientSessionId: params.clientSessionId, + fingerprintId: params.fingerprintId, + userInputId: params.userInputId, + model: params.model, + request: params.messages, response: content, inputTokens, outputTokens, @@ -236,15 +234,16 @@ export const promptAiSdkStream = async function* ( cacheReadInputTokens, finishedAt: new Date(), latencyMs: Date.now() - startTime, - chargeUser: options.chargeUser ?? true, + chargeUser: params.chargeUser ?? true, costOverrideDollars, - agentId: options.agentId, + agentId: params.agentId, + logger, }) // Call the cost callback if provided - if (options.onCostCalculated) { + if (params.onCostCalculated) { const creditsUsed = await creditsUsedPromise - await options.onCostCalculated(creditsUsed) + await params.onCostCalculated(creditsUsed) } return messageId @@ -252,7 +251,7 @@ export const promptAiSdkStream = async function* ( // TODO: figure out a nice way to unify stream & non-stream versions maybe? export const promptAiSdk = async function ( - options: { + params: { messages: Message[] clientSessionId: string fingerprintId: string @@ -264,20 +263,17 @@ export const promptAiSdk = async function ( onCostCalculated?: (credits: number) => Promise includeCacheControl?: boolean maxRetries?: number + logger: Logger } & Omit[0], 'model' | 'messages'>, ): Promise { - if ( - !checkLiveUserInput( - options.userId, - options.userInputId, - options.clientSessionId, - ) - ) { + const { logger } = params + + if (!checkLiveUserInput(params)) { logger.info( { - userId: options.userId, - userInputId: options.userInputId, - liveUserInputId: getLiveUserInputIds(options.userId), + userId: params.userId, + userInputId: params.userInputId, + liveUserInputId: getLiveUserInputIds(params.userId), }, 'Skipping prompt due to canceled user input', ) @@ -285,12 +281,12 @@ export const promptAiSdk = async function ( } const startTime = Date.now() - let aiSDKModel = modelToAiSDKModel(options.model) + let aiSDKModel = modelToAiSDKModel(params.model) const response = await generateText({ - ...options, + ...params, model: aiSDKModel, - messages: convertCbToModelMessages(options), + messages: convertCbToModelMessages(params), }) const content = response.text const inputTokens = response.usage.inputTokens || 0 @@ -298,32 +294,33 @@ export const promptAiSdk = async function ( const creditsUsedPromise = saveMessage({ messageId: generateCompactId(), - userId: options.userId, - clientSessionId: options.clientSessionId, - fingerprintId: options.fingerprintId, - userInputId: options.userInputId, - model: options.model, - request: options.messages, + userId: params.userId, + clientSessionId: params.clientSessionId, + fingerprintId: params.fingerprintId, + userInputId: params.userInputId, + model: params.model, + request: params.messages, response: content, inputTokens, outputTokens, finishedAt: new Date(), latencyMs: Date.now() - startTime, - chargeUser: options.chargeUser ?? true, - agentId: options.agentId, + chargeUser: params.chargeUser ?? true, + agentId: params.agentId, + logger, }) // Call the cost callback if provided - if (options.onCostCalculated) { + if (params.onCostCalculated) { const creditsUsed = await creditsUsedPromise - await options.onCostCalculated(creditsUsed) + await params.onCostCalculated(creditsUsed) } return content } // Copied over exactly from promptAiSdk but with a schema -export const promptAiSdkStructured = async function (options: { +export const promptAiSdkStructured = async function (params: { messages: Message[] schema: z.ZodType clientSessionId: string @@ -339,62 +336,60 @@ export const promptAiSdkStructured = async function (options: { onCostCalculated?: (credits: number) => Promise includeCacheControl?: boolean maxRetries?: number + logger: Logger }): Promise { - if ( - !checkLiveUserInput( - options.userId, - options.userInputId, - options.clientSessionId, - ) - ) { + const { logger } = params + + if (!checkLiveUserInput(params)) { logger.info( { - userId: options.userId, - userInputId: options.userInputId, - liveUserInputId: getLiveUserInputIds(options.userId), + userId: params.userId, + userInputId: params.userInputId, + liveUserInputId: getLiveUserInputIds(params.userId), }, 'Skipping structured prompt due to canceled user input', ) return {} as T } const startTime = Date.now() - let aiSDKModel = modelToAiSDKModel(options.model) + let aiSDKModel = modelToAiSDKModel(params.model) const responsePromise = generateObject, 'object'>({ - ...options, + ...params, model: aiSDKModel, output: 'object', - messages: convertCbToModelMessages(options), + messages: convertCbToModelMessages(params), }) - const response = await (options.timeout === undefined + const response = await (params.timeout === undefined ? responsePromise - : withTimeout(responsePromise, options.timeout)) + : withTimeout(responsePromise, params.timeout)) const content = response.object const inputTokens = response.usage.inputTokens || 0 const outputTokens = response.usage.inputTokens || 0 const creditsUsedPromise = saveMessage({ messageId: generateCompactId(), - userId: options.userId, - clientSessionId: options.clientSessionId, - fingerprintId: options.fingerprintId, - userInputId: options.userInputId, - model: options.model, - request: options.messages, + userId: params.userId, + clientSessionId: params.clientSessionId, + fingerprintId: params.fingerprintId, + userInputId: params.userInputId, + model: params.model, + request: params.messages, response: JSON.stringify(content), inputTokens, outputTokens, finishedAt: new Date(), latencyMs: Date.now() - startTime, - chargeUser: options.chargeUser ?? true, - agentId: options.agentId, + chargeUser: params.chargeUser ?? true, + agentId: params.agentId, + logger, }) // Call the cost callback if provided - if (options.onCostCalculated) { + if (params.onCostCalculated) { const creditsUsed = await creditsUsedPromise - await options.onCostCalculated(creditsUsed) + await params.onCostCalculated(creditsUsed) } return content diff --git a/backend/src/main-prompt.ts b/backend/src/main-prompt.ts index cca7f9c152..b029763060 100644 --- a/backend/src/main-prompt.ts +++ b/backend/src/main-prompt.ts @@ -62,7 +62,7 @@ export const mainPrompt = async (params: { let agentType: AgentTemplateType if (agentId) { - if (!(await getAgentTemplate(agentId, localAgentTemplates))) { + if (!(await getAgentTemplate({ agentId, localAgentTemplates, logger }))) { throw new Error( `Invalid agent ID: "${agentId}". Available agents: ${availableAgents.join(', ')}`, ) @@ -81,7 +81,13 @@ export const mainPrompt = async (params: { // Check for base agent in config const configBaseAgent = fileContext.codebuffConfig?.baseAgent if (configBaseAgent) { - if (!(await getAgentTemplate(configBaseAgent, localAgentTemplates))) { + if ( + !(await getAgentTemplate({ + agentId: configBaseAgent, + localAgentTemplates, + logger, + })) + ) { throw new Error( `Invalid base agent in config: "${configBaseAgent}". Available agents: ${availableAgents.join(', ')}`, ) @@ -111,7 +117,11 @@ export const mainPrompt = async (params: { mainAgentState.agentType = agentType - let mainAgentTemplate = await getAgentTemplate(agentType, localAgentTemplates) + let mainAgentTemplate = await getAgentTemplate({ + agentId: agentType, + localAgentTemplates, + logger, + }) if (!mainAgentTemplate) { throw new Error(`Agent template not found for type: ${agentType}`) } @@ -198,7 +208,8 @@ export const mainPrompt = async (params: { } } - const { agentState, output } = await loopAgentSteps(ws, { + const { agentState, output } = await loopAgentSteps({ + ws, userInputId: promptId, prompt, content, @@ -211,6 +222,7 @@ export const mainPrompt = async (params: { clientSessionId, onResponseChunk, localAgentTemplates, + logger, }) logger.debug({ agentState, output }, 'Main prompt finished') diff --git a/backend/src/process-file-block.ts b/backend/src/process-file-block.ts index 2ca24e3e20..4e4e8fa37b 100644 --- a/backend/src/process-file-block.ts +++ b/backend/src/process-file-block.ts @@ -11,9 +11,8 @@ import { import { promptAiSdk } from './llm-apis/vercel-ai-sdk/ai-sdk' import { countTokens } from './util/token-counter' -import type { Logger } from '@codebuff/types/logger' - import type { Message } from '@codebuff/common/types/messages/codebuff-message' +import type { Logger } from '@codebuff/types/logger' export async function processFileBlock(params: { path: string @@ -295,10 +294,15 @@ Please output just the SEARCH/REPLACE blocks like this: fingerprintId, userInputId, userId, + logger, }) const { diffBlocks, diffBlocksThatDidntMatch } = - parseAndGetDiffBlocksSingleFile(response, oldContent) + parseAndGetDiffBlocksSingleFile({ + newContent: response, + oldFileContent: oldContent, + logger, + }) let updatedContent = oldContent for (const { searchContent, replaceContent } of diffBlocks) { @@ -328,6 +332,7 @@ Please output just the SEARCH/REPLACE blocks like this: userInputId, userId, diffBlocksThatDidntMatch, + logger, }) if (newDiffBlocksThatDidntMatch.length > 0) { diff --git a/backend/src/prompt-agent-stream.ts b/backend/src/prompt-agent-stream.ts index 61521a72c6..ace57da851 100644 --- a/backend/src/prompt-agent-stream.ts +++ b/backend/src/prompt-agent-stream.ts @@ -6,6 +6,7 @@ import { globalStopSequence } from './tools/constants' import type { AgentTemplate } from './templates/types' import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { OpenRouterProviderOptions } from '@codebuff/internal/openrouter-ai-sdk' +import type { Logger } from '@codebuff/types/logger' export const getAgentStreamFromTemplate = (params: { clientSessionId: string @@ -17,6 +18,7 @@ export const getAgentStreamFromTemplate = (params: { includeCacheControl?: boolean template: AgentTemplate + logger: Logger }) => { const { clientSessionId, @@ -27,6 +29,7 @@ export const getAgentStreamFromTemplate = (params: { agentId, includeCacheControl, template, + logger, } = params if (!template) { @@ -36,7 +39,7 @@ export const getAgentStreamFromTemplate = (params: { const { model } = template const getStream = (messages: Message[]) => { - const options: Parameters[0] = { + const aiSdkStreamParams: Parameters[0] = { messages, model, stopSequences: [globalStopSequence], @@ -49,6 +52,7 @@ export const getAgentStreamFromTemplate = (params: { includeCacheControl, agentId, maxRetries: 3, + logger, } // Add Gemini-specific options if needed @@ -56,17 +60,17 @@ export const getAgentStreamFromTemplate = (params: { const provider = providerModelNames[primaryModel as keyof typeof providerModelNames] - if (!options.providerOptions) { - options.providerOptions = {} + if (!aiSdkStreamParams.providerOptions) { + aiSdkStreamParams.providerOptions = {} } - if (!options.providerOptions.openrouter) { - options.providerOptions.openrouter = {} + if (!aiSdkStreamParams.providerOptions.openrouter) { + aiSdkStreamParams.providerOptions.openrouter = {} } ;( - options.providerOptions.openrouter as OpenRouterProviderOptions + aiSdkStreamParams.providerOptions.openrouter as OpenRouterProviderOptions ).reasoning = template.reasoningOptions - return promptAiSdkStream(options) + return promptAiSdkStream(aiSdkStreamParams) } return { getStream } diff --git a/backend/src/run-agent-step.ts b/backend/src/run-agent-step.ts index 5e2192ed5a..a090a64111 100644 --- a/backend/src/run-agent-step.ts +++ b/backend/src/run-agent-step.ts @@ -16,7 +16,6 @@ import { additionalSystemPrompts } from './system-prompt/prompts' import { getAgentTemplate } from './templates/agent-registry' import { getAgentPrompt } from './templates/strings' import { processStreamWithTools } from './tools/stream-parser' -import { logger } from './util/logger' import { asSystemInstruction, asSystemMessage, @@ -45,6 +44,7 @@ import type { AgentOutput, } from '@codebuff/common/types/session-state' import type { ProjectFileContext } from '@codebuff/common/util/file' +import type { Logger } from '@codebuff/types/logger' import type { WebSocket } from 'ws' /** @@ -72,6 +72,7 @@ function buildUserMessageContent( } export interface AgentOptions { + ws: WebSocket userId: string | undefined userInputId: string clientSessionId: string @@ -86,10 +87,10 @@ export interface AgentOptions { prompt: string | undefined params: Record | undefined system: string + logger: Logger } export const runAgentStep = async ( - ws: WebSocket, options: AgentOptions, ): Promise<{ agentState: AgentState @@ -98,6 +99,7 @@ export const runAgentStep = async ( messageId: string | null }> => { const { + ws, userId, userInputId, fingerprintId, @@ -109,6 +111,7 @@ export const runAgentStep = async ( prompt, params, system, + logger, } = options let agentState = options.agentState @@ -170,7 +173,11 @@ export const runAgentStep = async ( } } - const agentTemplate = await getAgentTemplate(agentType, localAgentTemplates) + const agentTemplate = await getAgentTemplate({ + agentId: agentType, + localAgentTemplates, + logger, + }) if (!agentTemplate) { throw new Error( `Agent template not found for type: ${agentType}. Available types: ${Object.keys(localAgentTemplates).join(', ')}`, @@ -183,6 +190,7 @@ export const runAgentStep = async ( fileContext, agentState, agentTemplates: localAgentTemplates, + logger, additionalToolDefinitions: () => { const additionalToolDefinitions = cloneDeep( Object.fromEntries( @@ -250,6 +258,7 @@ export const runAgentStep = async ( } }, includeCacheControl: supportsCacheControl(agentTemplate.model), + logger, }) const iterationNum = agentState.messageHistory.length @@ -305,6 +314,7 @@ export const runAgentStep = async ( agentContext, onResponseChunk, fullResponse, + logger, }) toolResults.push(...newToolResults) @@ -394,45 +404,50 @@ export const runAgentStep = async ( } } -export const loopAgentSteps = async ( - ws: WebSocket, - { - userInputId, - agentType, - agentState, - prompt, - content, - params, - fingerprintId, - fileContext, - localAgentTemplates, - userId, - clientSessionId, - onResponseChunk, - clearUserPromptMessagesAfterResponse = true, - parentSystemPrompt, - }: { - userInputId: string - agentType: AgentTemplateType - agentState: AgentState - prompt: string | undefined - content?: Array - params: Record | undefined - fingerprintId: string - fileContext: ProjectFileContext - localAgentTemplates: Record - clearUserPromptMessagesAfterResponse?: boolean - parentSystemPrompt?: string - - userId: string | undefined - clientSessionId: string - onResponseChunk: (chunk: string | PrintModeEvent) => void - }, -): Promise<{ +export const loopAgentSteps = async ({ + ws, + userInputId, + agentType, + agentState, + prompt, + content, + params, + fingerprintId, + fileContext, + localAgentTemplates, + userId, + clientSessionId, + onResponseChunk, + clearUserPromptMessagesAfterResponse = true, + parentSystemPrompt, + logger, +}: { + ws: WebSocket + userInputId: string + agentType: AgentTemplateType + agentState: AgentState + prompt: string | undefined + content?: Array + params: Record | undefined + fingerprintId: string + fileContext: ProjectFileContext + localAgentTemplates: Record + clearUserPromptMessagesAfterResponse?: boolean + parentSystemPrompt?: string + + userId: string | undefined + clientSessionId: string + onResponseChunk: (chunk: string | PrintModeEvent) => void + logger: Logger +}): Promise<{ agentState: AgentState output: AgentOutput }> => { - const agentTemplate = await getAgentTemplate(agentType, localAgentTemplates) + const agentTemplate = await getAgentTemplate({ + agentId: agentType, + localAgentTemplates, + logger, + }) if (!agentTemplate) { throw new Error(`Agent template not found for type: ${agentType}`) } @@ -460,6 +475,7 @@ export const loopAgentSteps = async ( fileContext, agentState, agentTemplates: localAgentTemplates, + logger, additionalToolDefinitions: () => { const additionalToolDefinitions = cloneDeep( Object.fromEntries( @@ -489,6 +505,7 @@ export const loopAgentSteps = async ( fileContext, agentState, agentTemplates: localAgentTemplates, + logger, additionalToolDefinitions: () => { const additionalToolDefinitions = cloneDeep( Object.fromEntries( @@ -547,7 +564,7 @@ export const loopAgentSteps = async ( try { while (true) { totalSteps++ - if (!checkLiveUserInput(userId, userInputId, clientSessionId)) { + if (!checkLiveUserInput({ userId, userInputId, clientSessionId })) { logger.warn( { userId, @@ -570,7 +587,8 @@ export const loopAgentSteps = async ( agentState: programmaticAgentState, endTurn, stepNumber, - } = await runProgrammaticStep(currentAgentState, { + } = await runProgrammaticStep({ + agentState: currentAgentState, userId, userInputId, clientSessionId, @@ -581,10 +599,11 @@ export const loopAgentSteps = async ( template: agentTemplate, localAgentTemplates, prompt: currentPrompt, - params: currentParams, + toolCallParams: currentParams, system, stepsComplete: shouldEndTurn, stepNumber: totalSteps, + logger, }) currentAgentState = programmaticAgentState totalSteps = stepNumber @@ -640,7 +659,8 @@ export const loopAgentSteps = async ( agentState: newAgentState, shouldEndTurn: llmShouldEndTurn, messageId, - } = await runAgentStep(ws, { + } = await runAgentStep({ + ws, userId, userInputId, clientSessionId, @@ -653,6 +673,7 @@ export const loopAgentSteps = async ( prompt: currentPrompt, params: currentParams, system, + logger, }) if (newAgentState.runId) { @@ -685,7 +706,7 @@ export const loopAgentSteps = async ( ) } - const status = checkLiveUserInput(userId, userInputId, clientSessionId) + const status = checkLiveUserInput({ userId, userInputId, clientSessionId }) ? 'completed' : 'cancelled' await finishAgentRun({ @@ -717,7 +738,7 @@ export const loopAgentSteps = async ( ) const errorMessage = typeof error === 'string' ? error : `${error}` - const status = checkLiveUserInput(userId, userInputId, clientSessionId) + const status = checkLiveUserInput({ userId, userInputId, clientSessionId }) ? 'failed' : 'cancelled' await finishAgentRun({ diff --git a/backend/src/run-programmatic-step.ts b/backend/src/run-programmatic-step.ts index c599cde10c..db4ef99363 100644 --- a/backend/src/run-programmatic-step.ts +++ b/backend/src/run-programmatic-step.ts @@ -4,7 +4,6 @@ import { cloneDeep } from 'lodash' import { addAgentStep } from './agent-run' import { executeToolCall } from './tools/tool-executor' -import { logger } from './util/logger' import { SandboxManager } from './util/quickjs-sandbox' import { getRequestContext } from './websockets/request-context' import { sendAction } from './websockets/websocket-action' @@ -22,6 +21,8 @@ import type { import type { PrintModeEvent } from '@codebuff/common/types/print-mode' import type { AgentState } from '@codebuff/common/types/session-state' import type { ProjectFileContext } from '@codebuff/common/util/file' +import type { ParamsOf } from '@codebuff/types/common' +import type { Logger } from '@codebuff/types/logger' import type { WebSocket } from 'ws' // Global sandbox manager for QuickJS contexts @@ -32,22 +33,41 @@ const runIdToGenerator: Record = {} export const runIdToStepAll: Set = new Set() // Function to clear the generator cache for testing purposes -export function clearAgentGeneratorCache() { +export function clearAgentGeneratorCache( + params: ParamsOf, +) { for (const key in runIdToGenerator) { delete runIdToGenerator[key] } runIdToStepAll.clear() // Clean up QuickJS sandboxes - sandboxManager.dispose() + sandboxManager.dispose(params) } // Function to handle programmatic agents -export async function runProgrammaticStep( - agentState: AgentState, - { +export async function runProgrammaticStep(params: { + agentState: AgentState + template: AgentTemplate + prompt: string | undefined + toolCallParams: Record | undefined + system: string | undefined + userId: string | undefined + userInputId: string + clientSessionId: string + fingerprintId: string + onResponseChunk: (chunk: string | PrintModeEvent) => void + fileContext: ProjectFileContext + ws: WebSocket + localAgentTemplates: Record + stepsComplete: boolean + stepNumber: number + logger: Logger +}): Promise<{ agentState: AgentState; endTurn: boolean; stepNumber: number }> { + const { + agentState, template, prompt, - params, + toolCallParams, system, userId, userInputId, @@ -58,24 +78,10 @@ export async function runProgrammaticStep( ws, localAgentTemplates, stepsComplete, - stepNumber, - }: { - template: AgentTemplate - prompt: string | undefined - params: Record | undefined - system: string | undefined - userId: string | undefined - userInputId: string - clientSessionId: string - fingerprintId: string - onResponseChunk: (chunk: string | PrintModeEvent) => void - fileContext: ProjectFileContext - ws: WebSocket - localAgentTemplates: Record - stepsComplete: boolean - stepNumber: number - }, -): Promise<{ agentState: AgentState; endTurn: boolean; stepNumber: number }> { + logger, + } = params + let { stepNumber } = params + if (!template.handleSteps) { throw new Error('No step handler found for agent template ' + template.id) } @@ -119,11 +125,12 @@ export async function runProgrammaticStep( initialInput: { agentState, prompt, - params, + params: toolCallParams, logger: streamingLogger, }, config: undefined, // config - logger: streamingLogger, // pass the streaming logger instance for internal use + sandboxLogger: streamingLogger, // pass the streaming logger instance for internal use + logger, }) } else { // Initialize native generator @@ -278,6 +285,7 @@ export async function runProgrammaticStep( userId, autoInsertEndStepParam: true, excludeToolFromMessageHistory, + logger, }) // TODO: Remove messages from state and always use agentState.messageHistory. @@ -362,7 +370,7 @@ export async function runProgrammaticStep( if (endTurn) { if (sandbox) { // Clean up QuickJS sandbox if execution is complete - sandboxManager.removeSandbox({ runId: agentState.runId }) + sandboxManager.removeSandbox({ runId: agentState.runId, logger }) } delete runIdToGenerator[agentState.runId] runIdToStepAll.delete(agentState.runId) diff --git a/backend/src/system-prompt/prompts.ts b/backend/src/system-prompt/prompts.ts index 428a7314c2..9e409b563d 100644 --- a/backend/src/system-prompt/prompts.ts +++ b/backend/src/system-prompt/prompts.ts @@ -13,6 +13,7 @@ import { schemaToJsonStr } from '@codebuff/common/util/zod-schema' import { truncateFileTreeBasedOnTokenBudget } from './truncate-file-tree' +import type { Logger } from '@codebuff/types/logger' import type { ProjectFileContext } from '@codebuff/common/util/file' export const configSchemaPrompt = ` @@ -146,16 +147,19 @@ export const additionalSystemPrompts = { compact: compactPrompt, } as const -export const getProjectFileTreePrompt = ( - fileContext: ProjectFileContext, - fileTreeTokenBudget: number, - mode: 'search' | 'agent', -) => { +export const getProjectFileTreePrompt = (params: { + fileContext: ProjectFileContext + fileTreeTokenBudget: number + mode: 'search' | 'agent' + logger: Logger +}) => { + const { fileContext, fileTreeTokenBudget, mode, logger } = params const { projectRoot } = fileContext - const { printedTree, truncationLevel } = truncateFileTreeBasedOnTokenBudget( + const { printedTree, truncationLevel } = truncateFileTreeBasedOnTokenBudget({ fileContext, - Math.max(0, fileTreeTokenBudget), - ) + tokenBudget: Math.max(0, fileTreeTokenBudget), + logger, + }) const truncationNote = truncationLevel === 'none' diff --git a/backend/src/system-prompt/search-system-prompt.ts b/backend/src/system-prompt/search-system-prompt.ts index ca83d910d0..68240ceb1c 100644 --- a/backend/src/system-prompt/search-system-prompt.ts +++ b/backend/src/system-prompt/search-system-prompt.ts @@ -6,14 +6,15 @@ import { getProjectFileTreePrompt, getSystemInfoPrompt, } from './prompts' -import { logger } from '../util/logger' import { countTokens, countTokensJson } from '../util/token-counter' +import type { Logger } from '@codebuff/types/logger' import type { ProjectFileContext } from '@codebuff/common/util/file' export function getSearchSystemPrompt(params: { fileContext: ProjectFileContext messagesTokens: number + logger: Logger options: { agentStepId: string clientSessionId: string @@ -22,7 +23,7 @@ export function getSearchSystemPrompt(params: { userId: string | undefined } }): string { - const { fileContext, messagesTokens, options } = params + const { fileContext, messagesTokens, logger, options } = params const startTime = Date.now() const maxTokens = 500_000 // costMode === 'lite' ? 64_000 : @@ -41,17 +42,23 @@ export function getSearchSystemPrompt(params: { 20_000, ) * 20_000 - const projectFileTreePrompt = getProjectFileTreePrompt( + const projectFileTreePrompt = getProjectFileTreePrompt({ fileContext, fileTreeTokenBudget, - 'search', - ) + mode: 'search', + logger, + }) const t = Date.now() const truncationBudgets = [5_000, 20_000, 40_000, 100_000, 500_000] const truncatedTrees = truncationBudgets.reduce( (acc, budget) => { - acc[budget] = getProjectFileTreePrompt(fileContext, budget, 'search') + acc[budget] = getProjectFileTreePrompt({ + fileContext, + fileTreeTokenBudget: budget, + mode: 'search', + logger, + }) return acc }, {} as Record, diff --git a/backend/src/system-prompt/truncate-file-tree.ts b/backend/src/system-prompt/truncate-file-tree.ts index 66e82d1cac..89eafb8c1d 100644 --- a/backend/src/system-prompt/truncate-file-tree.ts +++ b/backend/src/system-prompt/truncate-file-tree.ts @@ -4,9 +4,10 @@ import { } from '@codebuff/common/util/file' import { sampleSizeWithSeed } from '@codebuff/common/util/random' -import { logger } from '../util/logger' import { countTokens, countTokensJson } from '../util/token-counter' +import type { Logger } from '@codebuff/types/logger' + import type { FileTreeNode, ProjectFileContext, @@ -15,14 +16,16 @@ import type { type TruncationLevel = 'none' | 'unimportant-files' | 'tokens' | 'depth-based' const DEBUG = false -export const truncateFileTreeBasedOnTokenBudget = ( - fileContext: ProjectFileContext, - tokenBudget: number, -): { +export const truncateFileTreeBasedOnTokenBudget = (params: { + fileContext: ProjectFileContext + tokenBudget: number + logger: Logger +}): { printedTree: string tokenCount: number truncationLevel: TruncationLevel } => { + const { fileContext, tokenBudget, logger } = params const startTime = performance.now() const { fileTree, fileTokenScores } = fileContext @@ -69,6 +72,7 @@ export const truncateFileTreeBasedOnTokenBudget = ( fileTree: filteredTree, fileTokenScores, tokenBudget, + logger, }) if (tokenCount <= tokenBudget) { @@ -220,8 +224,9 @@ function pruneFileTokenScores(params: { fileTree: FileTreeNode[] fileTokenScores: Record> tokenBudget: number + logger: Logger }) { - const { fileTree, fileTokenScores, tokenBudget } = params + const { fileTree, fileTokenScores, tokenBudget, logger } = params const startTime = performance.now() // Create sorted array of tokens by score diff --git a/backend/src/templates/agent-registry.ts b/backend/src/templates/agent-registry.ts index 269d12977c..4da5042eca 100644 --- a/backend/src/templates/agent-registry.ts +++ b/backend/src/templates/agent-registry.ts @@ -8,7 +8,7 @@ import { parsePublishedAgentId } from '@codebuff/common/util/agent-id-parsing' import { DEFAULT_ORG_PREFIX } from '@codebuff/common/util/agent-name-normalization' import { and, desc, eq } from 'drizzle-orm' -import { logger } from '../util/logger' +import type { Logger } from '@codebuff/types/logger' import type { DynamicAgentValidationError } from '@codebuff/common/templates/agent-validation' import type { AgentTemplate } from '@codebuff/common/types/agent-template' @@ -23,11 +23,15 @@ const databaseAgentCache = new Map() /** * Fetch an agent from the database by publisher/agent-id[@version] format */ -async function fetchAgentFromDatabase(parsedAgentId: { - publisherId: string - agentId: string - version?: string +async function fetchAgentFromDatabase(params: { + parsedAgentId: { + publisherId: string + agentId: string + version?: string + } + logger: Logger }): Promise { + const { parsedAgentId, logger } = params const { publisherId, agentId, version } = parsedAgentId try { @@ -128,10 +132,12 @@ async function fetchAgentFromDatabase(parsedAgentId: { * 2. Database cache * 3. Database query */ -export async function getAgentTemplate( - agentId: string, - localAgentTemplates: Record, -): Promise { +export async function getAgentTemplate(params: { + agentId: string + localAgentTemplates: Record + logger: Logger +}): Promise { + const { agentId, localAgentTemplates, logger } = params // 1. Check localAgentTemplates first (dynamic agents + static templates) if (localAgentTemplates[agentId]) { return localAgentTemplates[agentId] @@ -148,7 +154,7 @@ export async function getAgentTemplate( `${DEFAULT_ORG_PREFIX}${agentId}`, ) if (codebuffParsed) { - const dbAgent = await fetchAgentFromDatabase(codebuffParsed) + const dbAgent = await fetchAgentFromDatabase({ parsedAgentId: codebuffParsed, logger }) if (dbAgent) { databaseAgentCache.set(dbAgent.id, dbAgent) return dbAgent @@ -159,7 +165,7 @@ export async function getAgentTemplate( } // 3. Query database (only for publisher/agent-id format) - const dbAgent = await fetchAgentFromDatabase(parsed) + const dbAgent = await fetchAgentFromDatabase({ parsedAgentId: parsed, logger }) if (dbAgent && parsed.version && parsed.version !== 'latest') { // Cache only specific versions to avoid stale 'latest' results databaseAgentCache.set(dbAgent.id, dbAgent) @@ -170,10 +176,14 @@ export async function getAgentTemplate( /** * Assemble local agent templates from fileContext + static templates */ -export function assembleLocalAgentTemplates(fileContext: ProjectFileContext): { +export function assembleLocalAgentTemplates(params: { + fileContext: ProjectFileContext + logger: Logger +}): { agentTemplates: Record validationErrors: DynamicAgentValidationError[] } { + const { fileContext, logger } = params // Load dynamic agents using the service const { templates: dynamicTemplates, validationErrors } = validateAgents({ agentTemplates: fileContext.agentTemplates, diff --git a/backend/src/templates/prompts.ts b/backend/src/templates/prompts.ts index 02ca5cdbc2..6d1c4f29b1 100644 --- a/backend/src/templates/prompts.ts +++ b/backend/src/templates/prompts.ts @@ -4,12 +4,15 @@ import { getAgentTemplate } from './agent-registry' import type { AgentTemplate } from '@codebuff/common/types/agent-template' import type { AgentTemplateType } from '@codebuff/common/types/session-state' +import type { Logger } from '@codebuff/types/logger' import { buildArray } from '@codebuff/common/util/array' -export async function buildSpawnableAgentsDescription( - spawnableAgents: AgentTemplateType[], - agentTemplates: Record, -): Promise { +export async function buildSpawnableAgentsDescription(params: { + spawnableAgents: AgentTemplateType[] + agentTemplates: Record + logger: Logger +}): Promise { + const { spawnableAgents, agentTemplates, logger } = params if (spawnableAgents.length === 0) { return '' } @@ -18,7 +21,7 @@ export async function buildSpawnableAgentsDescription( spawnableAgents.map(async (agentType) => { return [ agentType, - await getAgentTemplate(agentType, agentTemplates), + await getAgentTemplate({ agentId: agentType, localAgentTemplates: agentTemplates, logger }), ] as const }), ) diff --git a/backend/src/templates/strings.ts b/backend/src/templates/strings.ts index b10957f179..51768713a7 100644 --- a/backend/src/templates/strings.ts +++ b/backend/src/templates/strings.ts @@ -18,6 +18,7 @@ import { } from '../tools/prompts' import { parseUserMessage } from '../util/messages' +import type { Logger } from '@codebuff/types/logger' import type { AgentTemplate, PlaceholderValue } from './types' import type { AgentState, @@ -34,6 +35,7 @@ export async function formatPrompt({ agentTemplates, intitialAgentPrompt, additionalToolDefinitions, + logger, }: { prompt: string fileContext: ProjectFileContext @@ -45,6 +47,7 @@ export async function formatPrompt({ additionalToolDefinitions: () => Promise< ProjectFileContext['customToolDefinitions'] > + logger: Logger }): Promise { const { messageHistory } = agentState const lastUserMessage = messageHistory.findLast( @@ -58,7 +61,7 @@ export async function formatPrompt({ : undefined const agentTemplate = agentState.agentType - ? await getAgentTemplate(agentState.agentType, agentTemplates) + ? await getAgentTemplate({ agentId: agentState.agentType, localAgentTemplates: agentTemplates, logger }) : null const toInject: Record string | Promise> = { @@ -66,9 +69,19 @@ export async function formatPrompt({ agentTemplate ? agentTemplate.displayName || 'Unknown Agent' : 'Buffy', [PLACEHOLDER.CONFIG_SCHEMA]: () => schemaToJsonStr(CodebuffConfigSchema), [PLACEHOLDER.FILE_TREE_PROMPT_SMALL]: () => - getProjectFileTreePrompt(fileContext, 2_500, 'agent'), + getProjectFileTreePrompt({ + fileContext, + fileTreeTokenBudget: 2_500, + mode: 'agent', + logger, + }), [PLACEHOLDER.FILE_TREE_PROMPT]: () => - getProjectFileTreePrompt(fileContext, 10_000, 'agent'), + getProjectFileTreePrompt({ + fileContext, + fileTreeTokenBudget: 10_000, + mode: 'agent', + logger, + }), [PLACEHOLDER.GIT_CHANGES_PROMPT]: () => getGitChangesPrompt(fileContext), [PLACEHOLDER.REMAINING_STEPS]: () => `${agentState.stepsRemaining!}`, [PLACEHOLDER.PROJECT_ROOT]: () => fileContext.projectRoot, @@ -76,7 +89,7 @@ export async function formatPrompt({ [PLACEHOLDER.TOOLS_PROMPT]: async () => getToolsInstructions(tools, await additionalToolDefinitions()), [PLACEHOLDER.AGENTS_PROMPT]: () => - buildSpawnableAgentsDescription(spawnableAgents, agentTemplates), + buildSpawnableAgentsDescription({ spawnableAgents, agentTemplates, logger }), [PLACEHOLDER.USER_CWD]: () => fileContext.cwd, [PLACEHOLDER.USER_INPUT_PROMPT]: () => escapeString(lastUserInput ?? ''), [PLACEHOLDER.INITIAL_AGENT_PROMPT]: () => @@ -142,6 +155,7 @@ export async function getAgentPrompt({ agentState, agentTemplates, additionalToolDefinitions, + logger, }: { agentTemplate: AgentTemplate promptType: { type: T } @@ -151,6 +165,7 @@ export async function getAgentPrompt({ additionalToolDefinitions: () => Promise< ProjectFileContext['customToolDefinitions'] > + logger: Logger }): Promise { let promptValue = agentTemplate[promptType.type] for (const placeholder of additionalPlaceholders[promptType.type]) { @@ -171,6 +186,7 @@ export async function getAgentPrompt({ spawnableAgents: agentTemplate.spawnableAgents, agentTemplates, additionalToolDefinitions, + logger, }) let addendum = '' @@ -192,10 +208,11 @@ export async function getAgentPrompt({ '\n\n' + toolsInstructions + '\n\n' + - (await buildSpawnableAgentsDescription( - agentTemplate.spawnableAgents, + (await buildSpawnableAgentsDescription({ + spawnableAgents: agentTemplate.spawnableAgents, agentTemplates, - )) + logger, + })) const parentInstructions = await collectParentInstructions({ agentType: agentState.agentType, diff --git a/backend/src/tools/batch-str-replace.ts b/backend/src/tools/batch-str-replace.ts index 9264112026..95de68dc4e 100644 --- a/backend/src/tools/batch-str-replace.ts +++ b/backend/src/tools/batch-str-replace.ts @@ -1,19 +1,18 @@ -import { handleStrReplace } from './handlers/tool/str-replace' -import { getFileProcessingValues } from './handlers/tool/write-file' -import { logger } from '../util/logger' -import { Benchify } from 'benchify' +import { withRetry, withTimeout } from '@codebuff/common/util/promise' import { env } from '@codebuff/internal/env' +import { Benchify } from 'benchify' + import { requestToolCall, requestFiles } from '../websockets/websocket-action' -import { ParsedDiff, parsePatch } from 'diff' -import { withRetry, withTimeout } from '@codebuff/common/util/promise' -import { match, P } from 'ts-pattern' +import { handleStrReplace } from './handlers/tool/str-replace' +import { getFileProcessingValues } from './handlers/tool/write-file' + import type { CodebuffToolCall, CodebuffToolOutput, } from '@codebuff/common/tools/list' import type { ToolResultPart } from '@codebuff/common/types/messages/content-part' import type { PrintModeEvent } from '@codebuff/common/types/print-mode' - +import type { Logger } from '@codebuff/types/logger' import type { WebSocket } from 'ws' export type DeferredStrReplace = { @@ -46,7 +45,8 @@ let benchifyCircuitBreaker = { const CIRCUIT_BREAKER_THRESHOLD = 3 // Open circuit after 3 consecutive failures const CIRCUIT_BREAKER_TIMEOUT = 60000 // Keep circuit open for 1 minute -export function getBenchifyClient(): Benchify | null { +export function getBenchifyClient(params: { logger: Logger }): Benchify | null { + const { logger } = params if (!benchifyClient) { let benchifyApiKey: string | undefined try { @@ -82,18 +82,7 @@ type BatchContext = { intendedChanges: Record } -export async function executeBatchStrReplaces({ - deferredStrReplaces, - toolCalls, - toolResults, - ws, - agentStepId, - clientSessionId, - userInputId, - onResponseChunk, - state, - userId, -}: { +export async function executeBatchStrReplaces(params: { deferredStrReplaces: DeferredStrReplace[] toolCalls: (CodebuffToolCall | any)[] toolResults: ToolResultPart[] @@ -104,7 +93,22 @@ export async function executeBatchStrReplaces({ onResponseChunk: (chunk: string | PrintModeEvent) => void state: Record userId: string | undefined + logger: Logger }) { + const { + deferredStrReplaces, + toolCalls, + toolResults, + ws, + agentStepId, + clientSessionId, + userInputId, + onResponseChunk, + state, + userId, + logger, + } = params + if (deferredStrReplaces.length === 0) { return } @@ -120,13 +124,18 @@ export async function executeBatchStrReplaces({ } // Pre-load original content for all paths that support benchify - const originalContents = await preloadOriginalContent(operationsByPath, ws) + const originalContents = await preloadOriginalContent({ + operationsByPath, + ws, + logger, + }) // Extract intended changes for benchify (before execution) - const intendedChanges = await extractAllIntendedChanges( + const intendedChanges = await extractAllIntendedChanges({ operationsByPath, originalContents, - ) + logger, + }) // Track edited files during processing const editedFiles: Record = {} @@ -138,7 +147,8 @@ export async function executeBatchStrReplaces({ const pathPromises: Record> = {} for (const [path, operations] of Object.entries(operationsByPath)) { - pathPromises[path] = processPathOperations(operations, { + pathPromises[path] = processPathOperations({ + operations, toolCalls, toolResults, agentStepId, @@ -147,6 +157,7 @@ export async function executeBatchStrReplaces({ state, editedFiles, requestClientToolCall, + logger, }) } @@ -171,6 +182,7 @@ export async function executeBatchStrReplaces({ userId, toolResults, toolCalls: deferredStrReplaces.map((d) => d.toolCall), + logger, }, ) logger.debug({ agentStepId }, 'Completed batch processing') @@ -180,10 +192,12 @@ export async function executeBatchStrReplaces({ * Pre-loads original file content for all paths that support benchify * Returns a record of path to content for files that were successfully loaded */ -async function preloadOriginalContent( - operationsByPath: Record, - ws: WebSocket, -): Promise> { +async function preloadOriginalContent(params: { + operationsByPath: Record + ws: WebSocket + logger: Logger +}): Promise> { + const { operationsByPath, ws, logger } = params const pathsToLoad = Object.keys(operationsByPath).filter( benchifyCanFixLanguage, ) @@ -220,10 +234,12 @@ async function preloadOriginalContent( * Extracts intended changes for all operations (for benchify) * Returns an object mapping path to intended content after all operations are applied */ -async function extractAllIntendedChanges( - operationsByPath: Record, - originalContents: Record, -): Promise> { +async function extractAllIntendedChanges(params: { + operationsByPath: Record + originalContents: Record + logger: Logger +}): Promise> { + const { operationsByPath, originalContents, logger } = params const intendedChanges: Record = {} for (const [path, operations] of Object.entries(operationsByPath)) { @@ -237,8 +253,11 @@ async function extractAllIntendedChanges( // Apply all operations sequentially to get final intended content for (const { toolCall } of operations) { currentContent = - (await extractIntendedContent(toolCall, currentContent)) || - currentContent + (await extractIntendedContent({ + toolCall, + currentContent, + logger, + })) || currentContent } intendedChanges[path] = currentContent @@ -256,28 +275,33 @@ async function extractAllIntendedChanges( /** * Processes all operations for a single file path sequentially */ -async function processPathOperations( - operations: DeferredStrReplace[], - context: { - toolCalls: (CodebuffToolCall | any)[] - toolResults: ToolResultPart[] - agentStepId: string - userInputId: string - onResponseChunk: (chunk: string | PrintModeEvent) => void - state: Record - editedFiles: Record - requestClientToolCall: ( - clientToolCall: any, - ) => Promise> - }, -) { +async function processPathOperations(params: { + operations: DeferredStrReplace[] + toolCalls: (CodebuffToolCall | any)[] + toolResults: ToolResultPart[] + agentStepId: string + userInputId: string + onResponseChunk: (chunk: string | PrintModeEvent) => void + state: Record + editedFiles: Record + requestClientToolCall: ( + clientToolCall: any, + ) => Promise> + logger: Logger +}) { + const { operations } = params let previousPromise = Promise.resolve() for (let i = 0; i < operations.length; i++) { const { toolCall } = operations[i] previousPromise = previousPromise.then(() => - executeSingleStrReplace(toolCall, i + 1, operations.length, context), + executeSingleStrReplace({ + ...params, + toolCall, + operationIndex: i + 1, + totalOperations: operations.length, + }), ) } @@ -287,24 +311,26 @@ async function processPathOperations( /** * Executes a single str_replace operation with proper error handling */ -async function executeSingleStrReplace( - toolCall: CodebuffToolCall<'str_replace'>, - operationIndex: number, - totalOperations: number, - context: { - toolCalls: (CodebuffToolCall | any)[] - toolResults: ToolResultPart[] - agentStepId: string - userInputId: string - onResponseChunk: (chunk: string | PrintModeEvent) => void - state: Record - editedFiles: Record - requestClientToolCall: ( - clientToolCall: any, - ) => Promise> - }, -) { +async function executeSingleStrReplace(params: { + toolCall: CodebuffToolCall<'str_replace'> + operationIndex: number + totalOperations: number + toolCalls: (CodebuffToolCall | any)[] + toolResults: ToolResultPart[] + agentStepId: string + userInputId: string + onResponseChunk: (chunk: string | PrintModeEvent) => void + state: Record + editedFiles: Record + requestClientToolCall: ( + clientToolCall: any, + ) => Promise> + logger: Logger +}) { const { + toolCall, + operationIndex, + totalOperations, userInputId, onResponseChunk, state, @@ -313,7 +339,8 @@ async function executeSingleStrReplace( toolResults, agentStepId, requestClientToolCall, - } = context + logger, + } = params try { // Create isolated state for each operation @@ -331,6 +358,7 @@ async function executeSingleStrReplace( toolCall, requestClientToolCall, writeToClient: onResponseChunk, + logger, getLatestState: () => getFileProcessingValues(isolatedState), state: isolatedState, }) @@ -359,11 +387,16 @@ async function executeSingleStrReplace( toolCalls.push(toolCall) } catch (error) { - handleStrReplaceError(error, toolCall, operationIndex, totalOperations, { + handleStrReplaceError({ + error, + toolCall, + operationIndex, + totalOperations, toolResults, agentStepId, userInputId, onResponseChunk, + logger, }) } } @@ -427,19 +460,28 @@ function trackEditedFile( /** * Handles errors from str_replace operations with proper logging and error results */ -function handleStrReplaceError( - error: unknown, - toolCall: CodebuffToolCall<'str_replace'>, - operationIndex: number, - totalOperations: number, - context: { - toolResults: ToolResultPart[] - agentStepId: string - userInputId: string - onResponseChunk: (chunk: string | PrintModeEvent) => void - }, -) { - const { toolResults, agentStepId, userInputId, onResponseChunk } = context +function handleStrReplaceError(params: { + error: unknown + toolCall: CodebuffToolCall<'str_replace'> + operationIndex: number + totalOperations: number + toolResults: ToolResultPart[] + agentStepId: string + userInputId: string + onResponseChunk: (chunk: string | PrintModeEvent) => void + logger: Logger +}) { + const { + error, + toolCall, + operationIndex, + totalOperations, + toolResults, + agentStepId, + userInputId, + onResponseChunk, + logger, + } = params logger.error( { @@ -493,62 +535,77 @@ async function applyBenchifyIfNeeded( userId: string | undefined toolResults: ToolResultPart[] toolCalls: CodebuffToolCall<'str_replace'>[] + logger: Logger }, ) { + const { logger } = options // Early exit conditions - fail gracefully without blocking user edits if (Object.keys(batchContext.intendedChanges).length === 0) { return } // Check circuit breaker - if (isBenchifyCircuitOpen()) { + if (isBenchifyCircuitOpen({ logger })) { return } try { // Filter and validate intended changes for Benchify - const filteredChanges = filterBenchifyFiles( - Object.entries(batchContext.intendedChanges).map(([path, contents]) => ({ - path, - contents, - })), - options.agentStepId, - ) + const filteredChanges = filterBenchifyFiles({ + files: Object.entries(batchContext.intendedChanges).map( + ([path, contents]) => ({ + path, + contents, + }), + ), + agentStepId: options.agentStepId, + logger, + }) if (filteredChanges.length === 0) { return } // Call Benchify with timeout and retry logic - const benchifyResult = await callBenchifyWithResilience( - filteredChanges, - options, - ) + const benchifyResult = await callBenchifyWithResilience({ + editedFiles: filteredChanges, + context: options, + logger, + }) if (benchifyResult && benchifyResult.length > 0) { // Apply results with individual error handling to prevent one failure from blocking others - await applyBenchifyResultsGracefully(filteredChanges, benchifyResult, { - ws: batchContext.ws, - onResponseChunk: batchContext.onResponseChunk, - state: { - ...batchContext.state, - originalContents: batchContext.originalContents, + await applyBenchifyResultsGracefully({ + editedFiles: filteredChanges, + benchifyDiff: benchifyResult, + context: { + ws: batchContext.ws, + onResponseChunk: batchContext.onResponseChunk, + state: { + ...batchContext.state, + originalContents: batchContext.originalContents, + }, + toolResults: options.toolResults, + toolCalls: options.toolCalls, + userInputId: options.userInputId, + agentStepId: options.agentStepId, }, - toolResults: options.toolResults, - toolCalls: options.toolCalls, - userInputId: options.userInputId, - agentStepId: options.agentStepId, + logger, }) } // Reset circuit breaker on success - resetBenchifyCircuitBreaker() + resetBenchifyCircuitBreaker({ logger }) } catch (error) { // Handle Benchify failure gracefully without blocking user edits - handleBenchifyFailure(error, { - intendedChangeFiles: Object.keys(batchContext.intendedChanges), - agentStepId: options.agentStepId, - userInputId: options.userInputId, + handleBenchifyFailure({ + error, + context: { + intendedChangeFiles: Object.keys(batchContext.intendedChanges), + agentStepId: options.agentStepId, + userInputId: options.userInputId, + }, + logger, }) } } @@ -556,10 +613,12 @@ async function applyBenchifyIfNeeded( /** * Filters files for Benchify processing based on size and count limits */ -function filterBenchifyFiles( - files: { path: string; contents: string }[], - agentStepId: string, -): { path: string; contents: string }[] { +function filterBenchifyFiles(params: { + files: { path: string; contents: string }[] + agentStepId: string + logger: Logger +}): { path: string; contents: string }[] { + const { files, agentStepId, logger } = params const filtered = files.filter((file) => { // Check file size limit if (file.contents.length > BENCHIFY_MAX_FILE_SIZE) { @@ -597,16 +656,18 @@ function filterBenchifyFiles( /** * Calls benchify API with timeout and retry logic using common utilities */ -async function callBenchifyWithResilience( - editedFiles: { path: string; contents: string }[], +async function callBenchifyWithResilience(params: { + editedFiles: { path: string; contents: string }[] context: { agentStepId: string clientSessionId: string userInputId: string userId: string | undefined - }, -): Promise { - const client = getBenchifyClient() + } + logger: Logger +}): Promise { + const { editedFiles, context, logger } = params + const client = getBenchifyClient({ logger }) if (!client) { return null } @@ -686,9 +747,9 @@ function shouldRetryBenchifyError(error: Error): boolean { /** * Applies benchify results back to the file system with individual error handling */ -async function applyBenchifyResultsGracefully( - editedFiles: { path: string; contents: string }[], - benchifyDiff: string, +async function applyBenchifyResultsGracefully(params: { + editedFiles: { path: string; contents: string }[] + benchifyDiff: string context: { ws: WebSocket onResponseChunk: (chunk: string | PrintModeEvent) => void @@ -697,12 +758,19 @@ async function applyBenchifyResultsGracefully( toolCalls: CodebuffToolCall<'str_replace'>[] userInputId: string agentStepId: string - }, -) { + } + logger: Logger +}) { + const { editedFiles, benchifyDiff, context, logger } = params const results = await Promise.allSettled( editedFiles.map((editedFile) => { if (benchifyDiff) { - applyBenchifyResultSafely(editedFile, benchifyDiff, context) + applyBenchifyResultSafely({ + benchifyFile: editedFile, + benchifyDiff, + context, + logger, + }) } else { logger.warn( { file: editedFile.path }, @@ -729,9 +797,9 @@ async function applyBenchifyResultsGracefully( /** * Safely applies a single Benchify result with comprehensive error handling */ -async function applyBenchifyResultSafely( - benchifyFile: { path: string; contents: string }, - benchifyDiff: string, +async function applyBenchifyResultSafely(params: { + benchifyFile: { path: string; contents: string } + benchifyDiff: string context: { ws: WebSocket onResponseChunk: (chunk: string | PrintModeEvent) => void @@ -740,8 +808,10 @@ async function applyBenchifyResultSafely( toolCalls: CodebuffToolCall<'str_replace'>[] userInputId: string agentStepId: string - }, -): Promise { + } + logger: Logger +}): Promise { + const { benchifyFile, benchifyDiff, context, logger } = params try { // Find the corresponding tool call for this file const relatedToolCall = context.toolCalls.find( @@ -844,10 +914,12 @@ async function applyBenchifyResultSafely( /** * Extracts the intended file content by applying str_replace operations to the current content */ -async function extractIntendedContent( - toolCall: CodebuffToolCall<'str_replace'>, - currentContent: string, -): Promise { +async function extractIntendedContent(params: { + toolCall: CodebuffToolCall<'str_replace'> + currentContent: string + logger: Logger +}): Promise { + const { toolCall, currentContent, logger } = params try { let content = currentContent @@ -895,7 +967,8 @@ async function extractIntendedContent( /** * Circuit breaker functions for Benchify resilience */ -function isBenchifyCircuitOpen(): boolean { +function isBenchifyCircuitOpen(params: { logger: Logger }): boolean { + const { logger } = params const now = Date.now() // Check if circuit should be half-open (reset after timeout) @@ -908,14 +981,16 @@ function isBenchifyCircuitOpen(): boolean { return benchifyCircuitBreaker.isOpen } -function handleBenchifyFailure( - error: unknown, +function handleBenchifyFailure(params: { + error: unknown context: { intendedChangeFiles: string[] agentStepId: string userInputId: string - }, -): void { + } + logger: Logger +}): void { + const { error, context, logger } = params benchifyCircuitBreaker.failureCount++ benchifyCircuitBreaker.lastFailureTime = Date.now() @@ -949,7 +1024,8 @@ function handleBenchifyFailure( ) } -function resetBenchifyCircuitBreaker(): void { +function resetBenchifyCircuitBreaker(params: { logger: Logger }): void { + const { logger } = params if (benchifyCircuitBreaker.failureCount > 0) { logger.debug( { previousFailures: benchifyCircuitBreaker.failureCount }, diff --git a/backend/src/tools/handlers/handler-function-type.ts b/backend/src/tools/handlers/handler-function-type.ts index ebcc6d8249..0c49136cbf 100644 --- a/backend/src/tools/handlers/handler-function-type.ts +++ b/backend/src/tools/handlers/handler-function-type.ts @@ -8,6 +8,7 @@ import type { } from '@codebuff/common/tools/list' import type { PrintModeEvent } from '@codebuff/common/types/print-mode' import type { ProjectFileContext } from '@codebuff/common/util/file' +import type { Logger } from '@codebuff/types/logger' type PresentOrAbsent = | { [P in K]: V } @@ -29,6 +30,7 @@ export type CodebuffToolHandlerFunction = ( getLatestState: () => any state: { [K in string]?: any } + logger: Logger } & PresentOrAbsent< 'requestClientToolCall', ( diff --git a/backend/src/tools/handlers/tool/create-plan.ts b/backend/src/tools/handlers/tool/create-plan.ts index 81f54cba60..f6313a8c2b 100644 --- a/backend/src/tools/handlers/tool/create-plan.ts +++ b/backend/src/tools/handlers/tool/create-plan.ts @@ -2,9 +2,9 @@ import { trackEvent } from '@codebuff/common/analytics' import { AnalyticsEvent } from '@codebuff/common/constants/analytics-events' import { getFileProcessingValues, postStreamProcessing } from './write-file' -import { logger } from '../../../util/logger' import type { CodebuffToolHandlerFunction } from '../handler-function-type' +import type { Logger } from '@codebuff/types/logger' import type { FileProcessingState, OptionalFileProcessingState, @@ -22,6 +22,7 @@ export const handleCreatePlan = ((params: { toolCall: ClientToolCall<'create_plan'>, ) => Promise> writeToClient: (chunk: string) => void + logger: Logger getLatestState: () => FileProcessingState state: { @@ -41,6 +42,7 @@ export const handleCreatePlan = ((params: { toolCall, requestClientToolCall, writeToClient, + logger, getLatestState, state, } = params diff --git a/backend/src/tools/handlers/tool/find-files.ts b/backend/src/tools/handlers/tool/find-files.ts index 76ecc271cd..99d6a4719e 100644 --- a/backend/src/tools/handlers/tool/find-files.ts +++ b/backend/src/tools/handlers/tool/find-files.ts @@ -6,13 +6,13 @@ import { } from '../../../find-files/request-files-prompt' import { getFileReadingUpdates } from '../../../get-file-reading-updates' import { getSearchSystemPrompt } from '../../../system-prompt/search-system-prompt' -import { logger } from '../../../util/logger' import { renderReadFilesResult } from '../../../util/parse-tool-call-xml' import { countTokens, countTokensJson } from '../../../util/token-counter' import { requestFiles } from '../../../websockets/websocket-action' import type { TextBlock } from '../../../llm-apis/claude' import type { CodebuffToolHandlerFunction } from '../handler-function-type' +import type { Logger } from '@codebuff/types/logger' import type { GetExpandedFileContextForTrainingBlobTrace } from '@codebuff/bigquery' import type { CodebuffToolCall, @@ -29,6 +29,7 @@ const COLLECT_FULL_FILE_CONTEXT = false export const handleFindFiles = ((params: { previousToolCallFinished: Promise toolCall: CodebuffToolCall<'find_files'> + logger: Logger fileContext: ProjectFileContext agentStepId: string @@ -46,6 +47,7 @@ export const handleFindFiles = ((params: { const { previousToolCallFinished, toolCall, + logger, fileContext, agentStepId, clientSessionId, @@ -71,6 +73,7 @@ export const handleFindFiles = ((params: { const system = getSearchSystemPrompt({ fileContext, messagesTokens: fileRequestMessagesTokens, + logger, options: { agentStepId, clientSessionId, @@ -83,17 +86,19 @@ export const handleFindFiles = ((params: { const triggerFindFiles: () => Promise< CodebuffToolOutput<'find_files'> > = async () => { - const requestedFiles = await requestRelevantFiles( - { messages, system }, + const requestedFiles = await requestRelevantFiles({ + messages, + system, fileContext, - prompt, + assistantPrompt: prompt, agentStepId, clientSessionId, fingerprintId, userInputId, userId, repoId, - ) + logger, + }) if (requestedFiles && requestedFiles.length > 0) { const addedFiles = await getFileReadingUpdates(ws, requestedFiles) @@ -110,6 +115,7 @@ export const handleFindFiles = ((params: { userInputId, userId, repoId, + logger, ).catch((error) => { logger.error( { error }, @@ -175,9 +181,11 @@ async function uploadExpandedFileContextForTraining( userInputId: string, userId: string | undefined, repoId: string | undefined, + logger: Logger, ) { - const files = await requestRelevantFilesForTraining( - { messages, system }, + const files = await requestRelevantFilesForTraining({ + messages, + system, fileContext, assistantPrompt, agentStepId, @@ -186,7 +194,8 @@ async function uploadExpandedFileContextForTraining( userInputId, userId, repoId, - ) + logger, + }) const loadedFiles = await requestFiles({ ws, filePaths: files }) diff --git a/backend/src/tools/handlers/tool/lookup-agent-info.ts b/backend/src/tools/handlers/tool/lookup-agent-info.ts index db50040452..7d3fb93fc7 100644 --- a/backend/src/tools/handlers/tool/lookup-agent-info.ts +++ b/backend/src/tools/handlers/tool/lookup-agent-info.ts @@ -11,10 +11,11 @@ export const handleLookupAgentInfo: CodebuffToolHandlerFunction< return { result: (async () => { - const agentTemplate = await getAgentTemplate( + const agentTemplate = await getAgentTemplate({ agentId, - params.state.localAgentTemplates || {}, - ) + localAgentTemplates: params.state.localAgentTemplates || {}, + logger: params.logger, + }) if (!agentTemplate) { return [ diff --git a/backend/src/tools/handlers/tool/read-docs.ts b/backend/src/tools/handlers/tool/read-docs.ts index cf2b9eae29..f4a925a03e 100644 --- a/backend/src/tools/handlers/tool/read-docs.ts +++ b/backend/src/tools/handlers/tool/read-docs.ts @@ -1,15 +1,16 @@ import { fetchContext7LibraryDocumentation } from '../../../llm-apis/context7-api' -import { logger } from '../../../util/logger' import type { CodebuffToolHandlerFunction } from '../handler-function-type' import type { CodebuffToolCall, CodebuffToolOutput, } from '@codebuff/common/tools/list' +import type { Logger } from '@codebuff/types/logger' export const handleReadDocs = (({ previousToolCallFinished, toolCall, + logger, agentStepId, clientSessionId, userInputId, @@ -17,6 +18,7 @@ export const handleReadDocs = (({ }: { previousToolCallFinished: Promise toolCall: CodebuffToolCall<'read_docs'> + logger: Logger agentStepId: string clientSessionId: string @@ -58,13 +60,12 @@ export const handleReadDocs = (({ const documentationPromise = (async () => { try { - const documentation = await fetchContext7LibraryDocumentation( - libraryTitle, - { - topic, - tokens: max_tokens, - }, - ) + const documentation = await fetchContext7LibraryDocumentation({ + query: libraryTitle, + topic, + tokens: max_tokens, + logger, + }) const docsDuration = Date.now() - docsStartTime const resultLength = documentation?.length || 0 diff --git a/backend/src/tools/handlers/tool/set-output.ts b/backend/src/tools/handlers/tool/set-output.ts index 83ee85a2e1..e3d416b8fc 100644 --- a/backend/src/tools/handlers/tool/set-output.ts +++ b/backend/src/tools/handlers/tool/set-output.ts @@ -1,5 +1,5 @@ import { getAgentTemplate } from '../../../templates/agent-registry' -import { logger } from '../../../util/logger' +import type { Logger } from '@codebuff/types/logger' import type { CodebuffToolHandlerFunction } from '../handler-function-type' import type { @@ -19,11 +19,12 @@ export const handleSetOutput = ((params: { agentState?: AgentState localAgentTemplates?: Record } + logger: Logger }): { result: Promise> state: { agentState: AgentState } } => { - const { previousToolCallFinished, toolCall, state } = params + const { previousToolCallFinished, toolCall, state, logger } = params const output = toolCall.input const { agentState, localAgentTemplates } = state @@ -43,10 +44,11 @@ export const handleSetOutput = ((params: { // Validate output against outputSchema if defined let agentTemplate = null if (agentState.agentType) { - agentTemplate = await getAgentTemplate( - agentState.agentType, + agentTemplate = await getAgentTemplate({ + agentId: agentState.agentType, localAgentTemplates, - ) + logger, + }) } if (agentTemplate?.outputSchema) { try { diff --git a/backend/src/tools/handlers/tool/spawn-agent-inline.ts b/backend/src/tools/handlers/tool/spawn-agent-inline.ts index d35c1c6744..7edd022ee3 100644 --- a/backend/src/tools/handlers/tool/spawn-agent-inline.ts +++ b/backend/src/tools/handlers/tool/spawn-agent-inline.ts @@ -6,6 +6,7 @@ import { executeSubagent, createAgentState, } from './spawn-agent-utils' +import type { Logger } from '@codebuff/types/logger' import type { CodebuffToolHandlerFunction } from '../handler-function-type' import type { @@ -39,6 +40,7 @@ export const handleSpawnAgentInline = ((params: { agentState?: AgentState system?: string } + logger: Logger }): { result: Promise>; state: {} } => { const { previousToolCallFinished, @@ -65,12 +67,15 @@ export const handleSpawnAgentInline = ((params: { system, } = validateSpawnState(state, 'spawn_agent_inline') + const { logger } = params + const triggerSpawnAgentInline = async () => { - const { agentTemplate, agentType } = await validateAndGetAgentTemplate( + const { agentTemplate, agentType } = await validateAndGetAgentTemplate({ agentTypeStr, parentAgentTemplate, localAgentTemplates, - ) + logger, + }) validateAgentInput(agentTemplate, agentType, prompt, agentParams) @@ -83,15 +88,16 @@ export const handleSpawnAgentInline = ((params: { parentAgentState.agentContext, ) - logAgentSpawn( + logAgentSpawn({ agentTemplate, agentType, - childAgentState.agentId, - childAgentState.parentId, + agentId: childAgentState.agentId, + parentId: childAgentState.parentId, prompt, - agentParams, - true, // inline = true - ) + params: agentParams, + inline: true, + logger, + }) const result = await executeSubagent({ ws, @@ -113,6 +119,7 @@ export const handleSpawnAgentInline = ((params: { // writeToClient(chunk) }, clearUserPromptMessagesAfterResponse: false, + logger, }) // Update parent's message history with child's final state diff --git a/backend/src/tools/handlers/tool/spawn-agent-utils.ts b/backend/src/tools/handlers/tool/spawn-agent-utils.ts index c1938a3e21..e1d228b674 100644 --- a/backend/src/tools/handlers/tool/spawn-agent-utils.ts +++ b/backend/src/tools/handlers/tool/spawn-agent-utils.ts @@ -3,7 +3,7 @@ import { parseAgentId } from '@codebuff/common/util/agent-id-parsing' import { generateCompactId } from '@codebuff/common/util/string' import { getAgentTemplate } from '../../../templates/agent-registry' -import { logger } from '../../../util/logger' +import type { Logger } from '@codebuff/types/logger' import type { AgentTemplate } from '@codebuff/common/types/agent-template' import type { Message } from '@codebuff/common/types/messages/codebuff-message' @@ -165,15 +165,19 @@ export function getMatchingSpawn( /** * Validates agent template and permissions */ -export async function validateAndGetAgentTemplate( - agentTypeStr: string, - parentAgentTemplate: AgentTemplate, - localAgentTemplates: Record, -): Promise<{ agentTemplate: AgentTemplate; agentType: string }> { - const agentTemplate = await getAgentTemplate( - agentTypeStr, +export async function validateAndGetAgentTemplate(params: { + agentTypeStr: string + parentAgentTemplate: AgentTemplate + localAgentTemplates: Record + logger: Logger +}): Promise<{ agentTemplate: AgentTemplate; agentType: string }> { + const { agentTypeStr, parentAgentTemplate, localAgentTemplates, logger } = + params + const agentTemplate = await getAgentTemplate({ + agentId: agentTypeStr, localAgentTemplates, - ) + logger, + }) if (!agentTemplate) { throw new Error(`Agent type ${agentTypeStr} not found.`) @@ -267,20 +271,31 @@ export function createAgentState( /** * Logs agent spawn information */ -export function logAgentSpawn( - agentTemplate: AgentTemplate, - agentType: string, - agentId: string, - parentId: string | undefined, - prompt?: string, - params?: any, - inline = false, -): void { +export function logAgentSpawn(params: { + agentTemplate: AgentTemplate + agentType: string + agentId: string + parentId: string | undefined + prompt?: string + params?: any + inline?: boolean + logger: Logger +}): void { + const { + agentTemplate, + agentType, + agentId, + parentId, + prompt, + params: spawnParams, + inline = false, + logger, + } = params logger.debug( { agentTemplate, prompt, - params, + params: spawnParams, agentId, parentId, }, @@ -308,6 +323,7 @@ export async function executeSubagent({ isOnlyChild = false, clearUserPromptMessagesAfterResponse = true, parentSystemPrompt, + logger, }: { ws: WebSocket userInputId: string @@ -325,6 +341,7 @@ export async function executeSubagent({ isOnlyChild?: boolean clearUserPromptMessagesAfterResponse?: boolean parentSystemPrompt?: string + logger: Logger }) { onResponseChunk({ type: 'subagent_start', @@ -336,7 +353,8 @@ export async function executeSubagent({ // Import loopAgentSteps dynamically to avoid circular dependency const { loopAgentSteps } = await import('../../../run-agent-step') - const result = await loopAgentSteps(ws, { + const result = await loopAgentSteps({ + ws, userInputId, prompt, params, @@ -350,6 +368,7 @@ export async function executeSubagent({ onResponseChunk, clearUserPromptMessagesAfterResponse, parentSystemPrompt, + logger, }) onResponseChunk({ @@ -361,8 +380,6 @@ export async function executeSubagent({ if (result.agentState.runId) { parentAgentState.childRunIds.push(result.agentState.runId) - } else { - logger.error('No runId found for agent state after executing agent') } return result diff --git a/backend/src/tools/handlers/tool/spawn-agents.ts b/backend/src/tools/handlers/tool/spawn-agents.ts index 1c9bfc9e73..75598ac5b2 100644 --- a/backend/src/tools/handlers/tool/spawn-agents.ts +++ b/backend/src/tools/handlers/tool/spawn-agents.ts @@ -6,7 +6,7 @@ import { logAgentSpawn, executeSubagent, } from './spawn-agent-utils' -import { logger } from '../../../util/logger' +import type { Logger } from '@codebuff/types/logger' import type { CodebuffToolHandlerFunction } from '../handler-function-type' import type { @@ -50,6 +50,7 @@ export const handleSpawnAgents = ((params: { agentState?: AgentState system?: string } + logger: Logger }): { result: Promise>; state: {} } => { const { previousToolCallFinished, @@ -64,6 +65,7 @@ export const handleSpawnAgents = ((params: { } = params const { agents } = toolCall.input const validatedState = validateSpawnState(state, 'spawn_agents') + const { logger } = params const { sendSubagentChunk, system: parentSystemPrompt } = state if (!sendSubagentChunk) { @@ -84,11 +86,12 @@ export const handleSpawnAgents = ((params: { const triggerSpawnAgents = async () => { const results = await Promise.allSettled( agents.map(async ({ agent_type: agentTypeStr, prompt, params }) => { - const { agentTemplate, agentType } = await validateAndGetAgentTemplate( + const { agentTemplate, agentType } = await validateAndGetAgentTemplate({ agentTypeStr, parentAgentTemplate, localAgentTemplates, - ) + logger, + }) validateAgentInput(agentTemplate, agentType, prompt, params) @@ -100,14 +103,15 @@ export const handleSpawnAgents = ((params: { {}, ) - logAgentSpawn( + logAgentSpawn({ agentTemplate, agentType, - subAgentState.agentId, - subAgentState.parentId, + agentId: subAgentState.agentId, + parentId: subAgentState.parentId, prompt, params, - ) + logger, + }) const result = await executeSubagent({ ws, @@ -140,6 +144,7 @@ export const handleSpawnAgents = ((params: { prompt, }) }, + logger, }) return { ...result, agentType, agentName: agentTemplate.displayName } }), diff --git a/backend/src/tools/handlers/tool/str-replace.ts b/backend/src/tools/handlers/tool/str-replace.ts index 1548aba157..80c11a8685 100644 --- a/backend/src/tools/handlers/tool/str-replace.ts +++ b/backend/src/tools/handlers/tool/str-replace.ts @@ -1,9 +1,9 @@ import { getFileProcessingValues, postStreamProcessing } from './write-file' import { processStrReplace } from '../../../process-str-replace' -import { logger } from '../../../util/logger' import { requestOptionalFile } from '../../../websockets/websocket-action' import type { CodebuffToolHandlerFunction } from '../handler-function-type' +import type { Logger } from '@codebuff/types/logger' import type { FileProcessingState, OptionalFileProcessingState, @@ -22,6 +22,7 @@ export const handleStrReplace = ((params: { toolCall: ClientToolCall<'str_replace'>, ) => Promise> writeToClient: (chunk: string) => void + logger: Logger getLatestState: () => FileProcessingState state: { @@ -36,6 +37,7 @@ export const handleStrReplace = ((params: { toolCall, requestClientToolCall, writeToClient, + logger, getLatestState, state, } = params diff --git a/backend/src/tools/handlers/tool/think-deeply.ts b/backend/src/tools/handlers/tool/think-deeply.ts index 8bdd8d9aad..f2aba6e5fe 100644 --- a/backend/src/tools/handlers/tool/think-deeply.ts +++ b/backend/src/tools/handlers/tool/think-deeply.ts @@ -1,6 +1,5 @@ -import { logger } from '../../../util/logger' - import type { CodebuffToolHandlerFunction } from '../handler-function-type' +import type { Logger } from '@codebuff/types/logger' import type { CodebuffToolCall, CodebuffToolOutput, @@ -9,8 +8,9 @@ import type { export const handleThinkDeeply = ((params: { previousToolCallFinished: Promise toolCall: CodebuffToolCall<'think_deeply'> + logger: Logger }): { result: Promise>; state: {} } => { - const { previousToolCallFinished, toolCall } = params + const { previousToolCallFinished, toolCall, logger } = params const { thought } = toolCall.input logger.debug( diff --git a/backend/src/tools/handlers/tool/web-search.ts b/backend/src/tools/handlers/tool/web-search.ts index ccabc76855..d91f8d5cbd 100644 --- a/backend/src/tools/handlers/tool/web-search.ts +++ b/backend/src/tools/handlers/tool/web-search.ts @@ -3,17 +3,18 @@ import { consumeCreditsWithFallback } from '@codebuff/billing' import { getRequestContext } from '../../../context/app-context' import { searchWeb } from '../../../llm-apis/linkup-api' import { PROFIT_MARGIN } from '../../../llm-apis/message-cost-tracker' -import { logger } from '../../../util/logger' import type { CodebuffToolHandlerFunction } from '../handler-function-type' import type { CodebuffToolCall, CodebuffToolOutput, } from '@codebuff/common/tools/list' +import type { Logger } from '@codebuff/types/logger' export const handleWebSearch = ((params: { previousToolCallFinished: Promise toolCall: CodebuffToolCall<'web_search'> + logger: Logger agentStepId: string clientSessionId: string @@ -28,6 +29,7 @@ export const handleWebSearch = ((params: { const { previousToolCallFinished, toolCall, + logger, agentStepId, clientSessionId, userInputId, @@ -57,7 +59,7 @@ export const handleWebSearch = ((params: { const webSearchPromise: Promise> = (async () => { try { - const searchResult = await searchWeb(query, { depth }) + const searchResult = await searchWeb({ query, depth, logger }) const searchDuration = Date.now() - searchStartTime const resultLength = searchResult?.length || 0 const hasResults = Boolean(searchResult && searchResult.trim()) diff --git a/backend/src/tools/handlers/tool/write-file.ts b/backend/src/tools/handlers/tool/write-file.ts index 0c968d8bc9..e452a82bcf 100644 --- a/backend/src/tools/handlers/tool/write-file.ts +++ b/backend/src/tools/handlers/tool/write-file.ts @@ -75,6 +75,7 @@ export const handleWriteFile = (({ getLatestState, state, + logger, }: { previousToolCallFinished: Promise toolCall: CodebuffToolCall<'write_file'> @@ -95,14 +96,14 @@ export const handleWriteFile = (({ fullResponse?: string prompt?: string messages?: Message[] - logger?: Logger } & OptionalFileProcessingState + logger: Logger }): { result: Promise> state: FileProcessingState } => { const { path, instructions, content } = toolCall.input - const { ws, fingerprintId, userId, fullResponse, prompt, logger } = state + const { ws, fingerprintId, userId, fullResponse, prompt } = state if (!ws) { throw new Error('Internal error for write_file: Missing WebSocket in state') } @@ -119,9 +120,6 @@ export const handleWriteFile = (({ if (!agentMessagesUntruncated) { throw new Error('Internal error for write_file: Missing messages in state') } - if (!logger) { - throw new Error('Internal error for write_file: Missing logger in state') - } // Initialize state for this file path if needed if (!fileProcessingPromisesByPath[path]) { diff --git a/backend/src/tools/stream-parser.ts b/backend/src/tools/stream-parser.ts index d285f360d0..ea606fb772 100644 --- a/backend/src/tools/stream-parser.ts +++ b/backend/src/tools/stream-parser.ts @@ -9,7 +9,7 @@ import { generateCompactId } from '@codebuff/common/util/string' import { cloneDeep } from 'lodash' import { executeBatchStrReplaces } from './batch-str-replace' -import { logger } from '../util/logger' +import type { Logger } from '@codebuff/types/logger' import { expireMessages } from '../util/messages' import { sendAction } from '../websockets/websocket-action' import { processStreamWithTags } from '../xml-stream-parser' @@ -56,6 +56,7 @@ export async function processStreamWithTools(options: { agentContext: Record onResponseChunk: (chunk: string | PrintModeEvent) => void fullResponse: string + logger: Logger }) { const { stream, @@ -73,6 +74,7 @@ export async function processStreamWithTools(options: { system, agentState, onResponseChunk, + logger, } = options const fullResponseChunks: string[] = [options.fullResponse] @@ -175,6 +177,7 @@ export async function processStreamWithTools(options: { onResponseChunk, state, userId, + logger, }) }, ) @@ -197,6 +200,7 @@ export async function processStreamWithTools(options: { onResponseChunk, state, userId, + logger, }) } }, @@ -224,6 +228,7 @@ export async function processStreamWithTools(options: { onResponseChunk, state, userId, + logger, }) }, } @@ -339,6 +344,7 @@ export async function processStreamWithTools(options: { onResponseChunk, state, userId, + logger, }) logger.info( { diff --git a/backend/src/tools/tool-executor.ts b/backend/src/tools/tool-executor.ts index f99e85c91e..1e446ef99e 100644 --- a/backend/src/tools/tool-executor.ts +++ b/backend/src/tools/tool-executor.ts @@ -6,7 +6,6 @@ import z from 'zod/v4' import { convertJsonSchemaToZod } from 'zod-from-json-schema' import { checkLiveUserInput } from '../live-user-inputs' -import { logger } from '../util/logger' import { requestToolCall } from '../websockets/websocket-action' import { codebuffToolDefs } from './definitions/list' import { codebuffToolHandlers } from './handlers/list' @@ -31,6 +30,7 @@ import type { customToolDefinitionsSchema, ProjectFileContext, } from '@codebuff/common/util/file' +import type { Logger } from '@codebuff/types/logger' import type { WebSocket } from 'ws' export type CustomToolCall = { @@ -131,28 +131,33 @@ export interface ExecuteToolCallParams { userId: string | undefined autoInsertEndStepParam?: boolean excludeToolFromMessageHistory?: boolean + logger: Logger } -export function executeToolCall({ - toolName, - input, - toolCalls, - toolResults, - toolResultsToAddAfterStream, - previousToolCallFinished, - ws, - agentTemplate, - fileContext, - agentStepId, - clientSessionId, - userInputId, - fullResponse, - onResponseChunk, - state, - userId, - autoInsertEndStepParam = false, - excludeToolFromMessageHistory = false, -}: ExecuteToolCallParams): Promise { +export function executeToolCall( + params: ExecuteToolCallParams, +): Promise { + const { + toolName, + input, + toolCalls, + toolResults, + toolResultsToAddAfterStream, + previousToolCallFinished, + ws, + agentTemplate, + fileContext, + agentStepId, + clientSessionId, + userInputId, + fullResponse, + onResponseChunk, + state, + userId, + autoInsertEndStepParam = false, + excludeToolFromMessageHistory = false, + logger, + } = params const toolCall: CodebuffToolCall | ToolCallError = parseRawToolCall({ rawToolCall: { toolName, @@ -226,7 +231,7 @@ export function executeToolCall({ requestClientToolCall: async ( clientToolCall: ClientToolCall, ) => { - if (!checkLiveUserInput(userId, userInputId, clientSessionId)) { + if (!checkLiveUserInput(params)) { return [] } @@ -241,6 +246,7 @@ export function executeToolCall({ toolCall, getLatestState: () => state, state, + logger, }) as ReturnType> for (const [key, value] of Object.entries(stateUpdate ?? {})) { @@ -360,24 +366,28 @@ export function parseRawCustomToolCall(params: { } } -export async function executeCustomToolCall({ - toolName, - input, - toolCalls, - toolResults, - toolResultsToAddAfterStream, - previousToolCallFinished, - ws, - agentTemplate, - fileContext, - clientSessionId, - userInputId, - onResponseChunk, - state, - userId, - autoInsertEndStepParam = false, - excludeToolFromMessageHistory = false, -}: ExecuteToolCallParams): Promise { +export async function executeCustomToolCall( + params: ExecuteToolCallParams, +): Promise { + const { + toolName, + input, + toolCalls, + toolResults, + toolResultsToAddAfterStream, + previousToolCallFinished, + ws, + agentTemplate, + fileContext, + clientSessionId, + userInputId, + onResponseChunk, + state, + userId, + autoInsertEndStepParam = false, + excludeToolFromMessageHistory = false, + logger, + } = params const toolCall: CustomToolCall | ToolCallError = parseRawCustomToolCall({ customToolDefs: await getMCPToolData({ ws, @@ -452,7 +462,7 @@ export async function executeCustomToolCall({ return previousToolCallFinished .then(async () => { - if (!checkLiveUserInput(userId, userInputId, clientSessionId)) { + if (!checkLiveUserInput(params)) { return null } diff --git a/backend/src/util/__tests__/messages.test.ts b/backend/src/util/__tests__/messages.test.ts index 3655c3fd30..a69fed9c2b 100644 --- a/backend/src/util/__tests__/messages.test.ts +++ b/backend/src/util/__tests__/messages.test.ts @@ -8,7 +8,6 @@ import { spyOn, } from 'bun:test' -import { logger } from '../logger' import { trimMessagesToFitTokenLimit, messagesWithSystem, @@ -40,6 +39,14 @@ describe('messagesWithSystem', () => { }) }) +// Mock logger for tests +const logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, +} + describe('trimMessagesToFitTokenLimit', () => { beforeEach(() => { // Mock countTokensJson to just count characters @@ -211,11 +218,12 @@ describe('trimMessagesToFitTokenLimit', () => { it('handles all features working together correctly', () => { const maxTotalTokens = 3000 const systemTokens = 0 - const result = trimMessagesToFitTokenLimit( - testMessages, + const result = trimMessagesToFitTokenLimit({ + messages: testMessages, systemTokens, maxTotalTokens, - ) + logger, + }) // Should have replacement message for omitted content expect(result.length).toBeGreaterThan(0) @@ -236,11 +244,12 @@ describe('trimMessagesToFitTokenLimit', () => { it('subtracts system tokens from total tokens', () => { const maxTotalTokens = 10_000 const systemTokens = 7_000 - const result = trimMessagesToFitTokenLimit( - testMessages, + const result = trimMessagesToFitTokenLimit({ + messages: testMessages, systemTokens, maxTotalTokens, - ) + logger, + }) // Should have replacement message for omitted content expect(result.length).toBeGreaterThan(0) @@ -261,11 +270,12 @@ describe('trimMessagesToFitTokenLimit', () => { it('does not simplify if under token limit', () => { const maxTotalTokens = 10_000 const systemTokens = 100 - const result = trimMessagesToFitTokenLimit( - testMessages, + const result = trimMessagesToFitTokenLimit({ + messages: testMessages, systemTokens, maxTotalTokens, - ) + logger, + }) // All messages should be unchanged expect(result).toHaveLength(testMessages.length) @@ -282,7 +292,7 @@ describe('trimMessagesToFitTokenLimit', () => { it('handles empty messages array', () => { const maxTotalTokens = 200 const systemTokens = 100 - const result = trimMessagesToFitTokenLimit([], systemTokens, maxTotalTokens) + const result = trimMessagesToFitTokenLimit({ messages: [], systemTokens, maxTotalTokens, logger }) expect(result).toEqual([]) }) @@ -305,7 +315,7 @@ describe('trimMessagesToFitTokenLimit', () => { }, ] - const result = trimMessagesToFitTokenLimit(messages, 0, 1000) + const result = trimMessagesToFitTokenLimit({ messages, systemTokens: 0, maxTotalTokens: 1000, logger }) // Should contain the kept messages const keptMessages = result.filter( @@ -335,7 +345,7 @@ describe('trimMessagesToFitTokenLimit', () => { }, ] as Message[] - const result = trimMessagesToFitTokenLimit(messages, 0, 10000) + const result = trimMessagesToFitTokenLimit({ messages, systemTokens: 0, maxTotalTokens: 10000, logger }) // Should be unchanged when under token limit expect(result).toHaveLength(2) @@ -351,7 +361,7 @@ describe('trimMessagesToFitTokenLimit', () => { { role: 'user', content: 'Keep this', keepDuringTruncation: true }, ] - const result = trimMessagesToFitTokenLimit(messages, 0, 1000) + const result = trimMessagesToFitTokenLimit({ messages, systemTokens: 0, maxTotalTokens: 1000, logger }) // Should only have one replacement message for consecutive removals const replacementMessages = result.filter( @@ -381,7 +391,7 @@ describe('trimMessagesToFitTokenLimit', () => { { role: 'user', content: 'C'.repeat(100) }, // Might be kept ] - const result = trimMessagesToFitTokenLimit(messages, 0, 2000) + const result = trimMessagesToFitTokenLimit({ messages, systemTokens: 0, maxTotalTokens: 2000, logger }) // Should preserve the keepDuringTruncation message const keptMessage = result.find( @@ -405,7 +415,7 @@ describe('trimMessagesToFitTokenLimit', () => { { role: 'user', content: 'C'.repeat(800) }, // Large message to force truncation ] - const result = trimMessagesToFitTokenLimit(messages, 0, 500) + const result = trimMessagesToFitTokenLimit({ messages, systemTokens: 0, maxTotalTokens: 500, logger }) // Should keep both marked messages const keptMessages = result.filter( @@ -428,7 +438,7 @@ describe('trimMessagesToFitTokenLimit', () => { describe('getPreviouslyReadFiles', () => { it('returns empty array when no messages provided', () => { - const result = getPreviouslyReadFiles([]) + const result = getPreviouslyReadFiles({ messages: [], logger }) expect(result).toEqual([]) }) @@ -449,7 +459,7 @@ describe('getPreviouslyReadFiles', () => { } satisfies CodebuffToolMessage<'write_file'>, ] - const result = getPreviouslyReadFiles(messages) + const result = getPreviouslyReadFiles({ messages, logger }) expect(result).toEqual([]) }) @@ -481,7 +491,7 @@ describe('getPreviouslyReadFiles', () => { } satisfies CodebuffToolMessage<'read_files'>, ] - const result = getPreviouslyReadFiles(messages) + const result = getPreviouslyReadFiles({ messages, logger }) expect(result).toEqual([ { path: 'src/test.ts', @@ -518,7 +528,7 @@ describe('getPreviouslyReadFiles', () => { } satisfies CodebuffToolMessage<'find_files'>, ] - const result = getPreviouslyReadFiles(messages) + const result = getPreviouslyReadFiles({ messages, logger }) expect(result).toEqual([ { path: 'components/Button.tsx', @@ -573,7 +583,7 @@ describe('getPreviouslyReadFiles', () => { }, ] - const result = getPreviouslyReadFiles(messages) + const result = getPreviouslyReadFiles({ messages, logger }) expect(result).toEqual([ { path: 'file1.ts', content: 'content 1' }, { path: 'file2.ts', content: 'content 2' }, @@ -611,7 +621,7 @@ describe('getPreviouslyReadFiles', () => { } satisfies CodebuffToolMessage<'read_files'>, ] - const result = getPreviouslyReadFiles(messages) + const result = getPreviouslyReadFiles({ messages, logger }) expect(result).toEqual([ { path: 'small-file.ts', content: 'small content' }, { path: 'another-small-file.ts', content: 'another small content' }, @@ -633,7 +643,7 @@ describe('getPreviouslyReadFiles', () => { }, ] - const result = getPreviouslyReadFiles(messages) + const result = getPreviouslyReadFiles({ messages, logger }) expect(result).toEqual([]) expect(mockLoggerError).toHaveBeenCalled() @@ -660,7 +670,7 @@ describe('getPreviouslyReadFiles', () => { } satisfies CodebuffToolMessage<'find_files'>, ] - const result = getPreviouslyReadFiles(messages) + const result = getPreviouslyReadFiles({ messages, logger }) expect(result).toEqual([]) }) @@ -690,7 +700,7 @@ describe('getPreviouslyReadFiles', () => { } satisfies CodebuffToolMessage<'read_files'>, ] - const result = getPreviouslyReadFiles(messages) + const result = getPreviouslyReadFiles({ messages, logger }) expect(result).toEqual([{ path: 'test.ts', content: 'test content' }]) }) @@ -712,7 +722,7 @@ describe('getPreviouslyReadFiles', () => { } satisfies CodebuffToolMessage<'read_files'>, ] - const result = getPreviouslyReadFiles(messages) + const result = getPreviouslyReadFiles({ messages, logger }) expect(result).toEqual([]) }) }) diff --git a/backend/src/util/__tests__/simplify-tool-results.test.ts b/backend/src/util/__tests__/simplify-tool-results.test.ts index eedb1b749c..183cc4b024 100644 --- a/backend/src/util/__tests__/simplify-tool-results.test.ts +++ b/backend/src/util/__tests__/simplify-tool-results.test.ts @@ -1,21 +1,20 @@ -import { - afterEach, - beforeEach, - describe, - expect, - it, - mock, - spyOn, -} from 'bun:test' +import { describe, expect, it } from 'bun:test' import { simplifyReadFileResults, simplifyTerminalCommandResults, } from '../simplify-tool-results' -import * as logger from '../logger' import type { CodebuffToolOutput } from '@codebuff/common/tools/list' +// Mock logger for tests +const logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, +} + describe('simplifyReadFileResults', () => { it('should simplify read file results by omitting content', () => { const input: CodebuffToolOutput<'read_files'> = [ @@ -123,15 +122,6 @@ describe('simplifyReadFileResults', () => { }) describe('simplifyTerminalCommandResults', () => { - beforeEach(() => { - // Mock the logger.error function directly - spyOn(logger.logger, 'error').mockImplementation(() => {}) - }) - - afterEach(() => { - mock.restore() - }) - it('should simplify terminal command results with stdout', () => { const input: CodebuffToolOutput<'run_terminal_command'> = [ { @@ -147,7 +137,10 @@ describe('simplifyTerminalCommandResults', () => { }, ] - const result = simplifyTerminalCommandResults(input) + const result = simplifyTerminalCommandResults({ + messageContent: input, + logger, + }) expect(result).toEqual([ { @@ -174,7 +167,10 @@ describe('simplifyTerminalCommandResults', () => { }, ] - const result = simplifyTerminalCommandResults(input) + const result = simplifyTerminalCommandResults({ + messageContent: input, + logger, + }) expect(result).toEqual([ { @@ -199,7 +195,10 @@ describe('simplifyTerminalCommandResults', () => { }, ] - const result = simplifyTerminalCommandResults(input) + const result = simplifyTerminalCommandResults({ + messageContent: input, + logger, + }) expect(result).toEqual([ { @@ -224,7 +223,10 @@ describe('simplifyTerminalCommandResults', () => { }, ] - const result = simplifyTerminalCommandResults(input) + const result = simplifyTerminalCommandResults({ + messageContent: input, + logger, + }) expect(result).toEqual(input) }) @@ -240,7 +242,10 @@ describe('simplifyTerminalCommandResults', () => { }, ] - const result = simplifyTerminalCommandResults(input) + const result = simplifyTerminalCommandResults({ + messageContent: input, + logger, + }) expect(result).toEqual(input) }) @@ -258,7 +263,10 @@ describe('simplifyTerminalCommandResults', () => { }, ] - const result = simplifyTerminalCommandResults(input) + const result = simplifyTerminalCommandResults({ + messageContent: input, + logger, + }) expect(result).toEqual([ { @@ -277,6 +285,7 @@ describe('simplifyTerminalCommandResults', () => { // Create input that will cause an error during processing const malformedInput = { invalidStructure: true, + logger, } as any const result = simplifyTerminalCommandResults(malformedInput) @@ -290,9 +299,6 @@ describe('simplifyTerminalCommandResults', () => { }, }, ]) - - // Verify error was logged - expect(logger.logger.error).toHaveBeenCalled() }) it('should not mutate the original input', () => { @@ -308,7 +314,7 @@ describe('simplifyTerminalCommandResults', () => { ] const input = structuredClone(originalInput) - simplifyTerminalCommandResults(input) + simplifyTerminalCommandResults({ messageContent: input, logger }) // Original input should be unchanged expect(input).toEqual(originalInput) @@ -327,7 +333,10 @@ describe('simplifyTerminalCommandResults', () => { }, ] - const result = simplifyTerminalCommandResults(input) + const result = simplifyTerminalCommandResults({ + messageContent: input, + logger, + }) expect(result).toEqual([ { @@ -354,7 +363,10 @@ describe('simplifyTerminalCommandResults', () => { }, ] - const result = simplifyTerminalCommandResults(input) + const result = simplifyTerminalCommandResults({ + messageContent: input, + logger, + }) expect(result).toEqual([ { diff --git a/backend/src/util/check-auth.ts b/backend/src/util/check-auth.ts index bc4a996110..e7aeb9dfe2 100644 --- a/backend/src/util/check-auth.ts +++ b/backend/src/util/check-auth.ts @@ -3,21 +3,19 @@ import * as schema from '@codebuff/common/db/schema' import { utils } from '@codebuff/internal' import { eq } from 'drizzle-orm' -import { logger } from './logger' import { extractAuthTokenFromHeader } from './auth-helpers' import type { ServerAction } from '@codebuff/common/actions' +import type { Logger } from '@codebuff/types/logger' import type { Request, Response, NextFunction } from 'express' -export const checkAuth = async ({ - fingerprintId, - authToken, - clientSessionId, -}: { +export const checkAuth = async (params: { fingerprintId?: string authToken?: string clientSessionId: string + logger: Logger }): Promise => { + const { fingerprintId, authToken, clientSessionId, logger } = params // Use shared auth check functionality const authResult = await utils.checkAuthToken({ fingerprintId, @@ -45,7 +43,7 @@ export const checkAuth = async ({ } // Express middleware for checking admin access -export const checkAdmin = async ( +export const checkAdmin = (logger: Logger) => async ( req: Request, res: Response, next: NextFunction, @@ -63,6 +61,7 @@ export const checkAdmin = async ( const authResult = await checkAuth({ authToken, clientSessionId, + logger, }) if (authResult) { diff --git a/backend/src/util/messages.ts b/backend/src/util/messages.ts index 291c861dbb..971085471d 100644 --- a/backend/src/util/messages.ts +++ b/backend/src/util/messages.ts @@ -5,7 +5,6 @@ import { getErrorObject } from '@codebuff/common/util/error' import { closeXml } from '@codebuff/common/util/xml' import { cloneDeep, isEqual } from 'lodash' -import { logger } from './logger' import { simplifyTerminalCommandResults } from './simplify-tool-results' import { countTokensJson } from './token-counter' @@ -18,6 +17,7 @@ import type { Message, ToolMessage, } from '@codebuff/common/types/messages/codebuff-message' +import type { Logger } from '@codebuff/types/logger' export function messagesWithSystem(params: { messages: Message[] @@ -98,9 +98,13 @@ const numTerminalCommandsToKeep = 5 function simplifyTerminalHelper(params: { toolResult: CodebuffToolOutput<'run_terminal_command'> numKept: number + logger: Logger }): { result: CodebuffToolOutput<'run_terminal_command'>; numKept: number } { - const { toolResult, numKept } = params - const simplified = simplifyTerminalCommandResults(toolResult) + const { toolResult, numKept, logger } = params + const simplified = simplifyTerminalCommandResults({ + messageContent: toolResult, + logger, + }) // Keep the full output for the N most recent commands if (numKept < numTerminalCommandsToKeep && !isEqual(simplified, toolResult)) { @@ -134,11 +138,13 @@ const replacementMessage = { * @param maxTotalTokens - Maximum total tokens allowed, defaults to 200k * @returns Trimmed array of messages that fits within token limit */ -export function trimMessagesToFitTokenLimit( - messages: Message[], - systemTokens: number, - maxTotalTokens: number = 190_000, -): Message[] { +export function trimMessagesToFitTokenLimit(params: { + messages: Message[] + systemTokens: number + maxTotalTokens?: number + logger: Logger +}): Message[] { + const { messages, systemTokens, maxTotalTokens = 190_000, logger } = params const maxMessageTokens = maxTotalTokens - systemTokens // Check if we're already under the limit @@ -169,6 +175,7 @@ export function trimMessagesToFitTokenLimit( const result = simplifyTerminalHelper({ toolResult: terminalResultMessage.content.output, numKept, + logger, }) terminalResultMessage.content.output = result.result numKept = result.numKept @@ -211,11 +218,17 @@ export function trimMessagesToFitTokenLimit( ) } -export function getMessagesSubset( - messages: Message[], - otherTokens: number, -): Message[] { - const messagesSubset = trimMessagesToFitTokenLimit(messages, otherTokens) +export function getMessagesSubset(params: { + messages: Message[] + otherTokens: number + logger: Logger +}): Message[] { + const { messages, otherTokens, logger } = params + const messagesSubset = trimMessagesToFitTokenLimit({ + messages, + systemTokens: otherTokens, + logger, + }) // Remove cache_control from all messages for (const message of messagesSubset) { @@ -255,7 +268,11 @@ export function expireMessages( }) } -export function getEditedFiles(messages: Message[]): string[] { +export function getEditedFiles(params: { + messages: Message[] + logger: Logger +}): string[] { + const { messages, logger } = params return buildArray( messages .filter( @@ -294,11 +311,15 @@ export function getEditedFiles(messages: Message[]): string[] { ) } -export function getPreviouslyReadFiles(messages: Message[]): { +export function getPreviouslyReadFiles(params: { + messages: Message[] + logger: Logger +}): { path: string content: string referencedBy?: Record }[] { + const { messages, logger } = params const files: ReturnType = [] for (const message of messages) { if (message.role !== 'tool') continue diff --git a/backend/src/util/quickjs-sandbox.ts b/backend/src/util/quickjs-sandbox.ts index 947e73f9b3..fc763b7c74 100644 --- a/backend/src/util/quickjs-sandbox.ts +++ b/backend/src/util/quickjs-sandbox.ts @@ -1,7 +1,6 @@ import { newQuickJSWASMModuleFromVariant } from 'quickjs-emscripten-core' -import { logger } from './logger' - +import type { Logger } from '@codebuff/types/logger' import type { QuickJSContext, QuickJSWASMModule, @@ -52,14 +51,9 @@ export class QuickJSSandbox { generatorCode: string initialInput: any config?: SandboxConfig - logger?: { - debug: (data: any, msg?: string) => void - info: (data: any, msg?: string) => void - warn: (data: any, msg?: string) => void - error: (data: any, msg?: string) => void - } + sandboxLogger?: Logger }): Promise { - const { generatorCode, initialInput, config = {}, logger } = params + const { generatorCode, initialInput, config = {}, sandboxLogger } = params const { memoryLimit = 1024 * 1024 * 20, // 20MB maxStackSize = 1024 * 512, // 512KB @@ -105,13 +99,16 @@ export class QuickJSSandbox { msgStr = msg ? context.getString(msg) : undefined // Call the appropriate logger method if available - if (logger?.[levelStr as keyof typeof logger]) { - logger[levelStr as keyof typeof logger](dataObj, msgStr) + if (sandboxLogger?.[levelStr as keyof typeof sandboxLogger]) { + sandboxLogger[levelStr as keyof typeof sandboxLogger]( + dataObj, + msgStr, + ) } } catch (err) { // Fallback for logging errors - if (logger?.error) { - logger.error({ error: err }, 'Logger handler error') + if (sandboxLogger?.error) { + sandboxLogger.error({ error: err }, 'Logger handler error') } } }, @@ -231,7 +228,8 @@ export class QuickJSSandbox { /** * Dispose of the sandbox and clean up resources */ - dispose(): void { + dispose(params: { logger: Logger }): void { + const { logger } = params if (!this.initialized) { return } @@ -267,14 +265,17 @@ export class SandboxManager { generatorCode: string initialInput: any config?: SandboxConfig - logger?: { - debug: (data: any, msg?: string) => void - info: (data: any, msg?: string) => void - warn: (data: any, msg?: string) => void - error: (data: any, msg?: string) => void - } + sandboxLogger?: Logger + logger: Logger }): Promise { - const { runId, generatorCode, initialInput, config, logger } = params + const { + runId, + generatorCode, + initialInput, + config, + sandboxLogger, + logger, + } = params const existing = this.sandboxes.get(runId) if (existing && existing.isInitialized()) { return existing @@ -282,7 +283,9 @@ export class SandboxManager { // Clean up any existing sandbox if (existing) { - existing.dispose() + existing.dispose({ + logger, + }) } // Create new sandbox @@ -290,7 +293,7 @@ export class SandboxManager { generatorCode, initialInput, config, - logger, + sandboxLogger, }) this.sandboxes.set(runId, sandbox) return sandbox @@ -307,11 +310,11 @@ export class SandboxManager { /** * Remove and dispose a sandbox */ - removeSandbox(params: { runId: string }): void { - const { runId } = params + removeSandbox(params: { runId: string; logger: Logger }): void { + const { runId, logger } = params const sandbox = this.sandboxes.get(runId) if (sandbox) { - sandbox.dispose() + sandbox.dispose({ logger }) this.sandboxes.delete(runId) } } @@ -319,10 +322,11 @@ export class SandboxManager { /** * Clean up all sandboxes */ - dispose(): void { + dispose(params: { logger: Logger }): void { + const { logger } = params for (const [runId, sandbox] of this.sandboxes) { try { - sandbox.dispose() + sandbox.dispose({ logger }) } catch (error) { logger.warn( { error, runId }, diff --git a/backend/src/util/simplify-tool-results.ts b/backend/src/util/simplify-tool-results.ts index 3062c28287..bb2c0d22f0 100644 --- a/backend/src/util/simplify-tool-results.ts +++ b/backend/src/util/simplify-tool-results.ts @@ -1,9 +1,8 @@ import { getErrorObject } from '@codebuff/common/util/error' import { cloneDeep } from 'lodash' -import { logger } from './logger' - import type { CodebuffToolOutput } from '@codebuff/common/tools/list' +import type { Logger } from '@codebuff/types/logger' export function simplifyReadFileResults( messageContent: CodebuffToolOutput<'read_files'>, @@ -21,9 +20,11 @@ export function simplifyReadFileResults( ] } -export function simplifyTerminalCommandResults( - messageContent: CodebuffToolOutput<'run_terminal_command'>, -): CodebuffToolOutput<'run_terminal_command'> { +export function simplifyTerminalCommandResults(params: { + messageContent: CodebuffToolOutput<'run_terminal_command'> + logger: Logger +}): CodebuffToolOutput<'run_terminal_command'> { + const { messageContent, logger } = params try { const clone = cloneDeep(messageContent) const content = clone[0].value diff --git a/backend/src/websockets/middleware.ts b/backend/src/websockets/middleware.ts index 132e5ea87d..45febba8ac 100644 --- a/backend/src/websockets/middleware.ts +++ b/backend/src/websockets/middleware.ts @@ -17,18 +17,19 @@ import { updateRequestContext } from './request-context' import { sendAction } from './websocket-action' import { withAppContext } from '../context/app-context' import { checkAuth } from '../util/check-auth' -import { logger } from '../util/logger' import type { UserInfo } from './auth' import type { ClientAction, ServerAction } from '@codebuff/common/actions' +import type { Logger } from '@codebuff/types/logger' import type { WebSocket } from 'ws' -type MiddlewareCallback = ( - action: ClientAction, - clientSessionId: string, - ws: WebSocket, - userInfo: UserInfo | undefined, -) => Promise +type MiddlewareCallback = (params: { + action: ClientAction + clientSessionId: string + ws: WebSocket + userInfo: UserInfo | undefined + logger: Logger +}) => Promise function getServerErrorAction( action: T, @@ -52,34 +53,39 @@ export class WebSocketMiddleware { private middlewares: Array = [] use( - callback: ( - action: ClientAction, - clientSessionId: string, - ws: WebSocket, - userInfo: UserInfo | undefined, - ) => Promise, + callback: (params: { + action: ClientAction + clientSessionId: string + ws: WebSocket + userInfo: UserInfo | undefined + logger: Logger + }) => Promise, ) { this.middlewares.push(callback as MiddlewareCallback) } - async execute( - action: ClientAction, - clientSessionId: string, - ws: WebSocket, - options: { silent?: boolean } = {}, - ): Promise { + async execute(params: { + action: ClientAction + clientSessionId: string + ws: WebSocket + silent?: boolean + logger: Logger + }): Promise { + const { action, clientSessionId, ws, silent, logger } = params + const userInfo = 'authToken' in action && action.authToken ? await getUserInfoFromAuthToken(action.authToken) : undefined for (const middleware of this.middlewares) { - const actionOrContinue = await middleware( + const actionOrContinue = await middleware({ action, clientSessionId, ws, userInfo, - ) + logger, + }) if (actionOrContinue) { logger.warn( { @@ -89,7 +95,7 @@ export class WebSocketMiddleware { }, 'Middleware execution halted.', ) - if (!options.silent) { + if (!silent) { sendAction(ws, actionOrContinue) } return false @@ -98,14 +104,18 @@ export class WebSocketMiddleware { return true } - run( - baseAction: ( - action: ClientAction, - clientSessionId: string, - ws: WebSocket, - ) => void, - options: { silent?: boolean } = {}, - ) { + run(params: { + baseAction: (params: { + action: ClientAction + clientSessionId: string + ws: WebSocket + logger: Logger + }) => void + silent?: boolean + logger: Logger + }) { + const { baseAction, silent, logger } = params + return async ( action: ClientAction, clientSessionId: string, @@ -126,14 +136,15 @@ export class WebSocketMiddleware { }, {}, // request context starts empty async () => { - const shouldContinue = await this.execute( + const shouldContinue = await this.execute({ action, clientSessionId, ws, - options, - ) + silent, + logger, + }) if (shouldContinue) { - baseAction(action, clientSessionId, ws) + baseAction({ action, clientSessionId, ws, logger }) } }, ) @@ -143,16 +154,17 @@ export class WebSocketMiddleware { export const protec = new WebSocketMiddleware() -protec.use(async (action, clientSessionId, ws, userInfo) => +protec.use(async ({ action, clientSessionId, ws, userInfo, logger }) => checkAuth({ fingerprintId: 'fingerprintId' in action ? action.fingerprintId : undefined, authToken: 'authToken' in action ? action.authToken : undefined, clientSessionId, + logger, }), ) // Organization repository coverage detection middleware -protec.use(async (action, clientSessionId, ws, userInfo) => { +protec.use(async ({ action, clientSessionId, ws, userInfo, logger }) => { const userId = userInfo?.id // Only process actions that have repoUrl as a valid string @@ -292,7 +304,7 @@ protec.use(async (action, clientSessionId, ws, userInfo) => { return undefined }) -protec.use(async (action, clientSessionId, ws, userInfo) => { +protec.use(async ({ action, clientSessionId, ws, userInfo, logger }) => { const userId = userInfo?.id const fingerprintId = 'fingerprintId' in action ? action.fingerprintId : 'unknown-fingerprint' diff --git a/backend/src/websockets/server.ts b/backend/src/websockets/server.ts index 5fe490e48b..74c9e95b5e 100644 --- a/backend/src/websockets/server.ts +++ b/backend/src/websockets/server.ts @@ -5,9 +5,9 @@ import { WebSocketServer } from 'ws' import { setSessionConnected } from '../live-user-inputs' import { Switchboard } from './switchboard' import { onWebsocketAction } from './websocket-action' -import { logger } from '../util/logger' import type { ServerMessage } from '@codebuff/common/websockets/websocket-schema' +import type { Logger } from '@codebuff/types/logger' import type { Server as HttpServer } from 'node:http' import type { RawData, WebSocket } from 'ws' @@ -29,11 +29,14 @@ function serializeError(err: unknown) { return isError(err) ? err.message : 'Unexpected error.' } -async function processMessage( - ws: WebSocket, - clientSessionId: string, - data: RawData, -): Promise> { +async function processMessage(params: { + ws: WebSocket + clientSessionId: string + data: RawData + logger: Logger +}): Promise> { + const { ws, clientSessionId, data, logger } = params + let messageObj: any try { messageObj = JSON.parse(data.toString()) @@ -62,7 +65,7 @@ async function processMessage( break } case 'action': { - onWebsocketAction(ws, clientSessionId, msg) + onWebsocketAction({ ws, clientSessionId, msg, logger }) break } default: @@ -80,8 +83,12 @@ async function processMessage( } } -export function listen(params: { server: HttpServer; path: string }) { - const { server, path } = params +export function listen(params: { + server: HttpServer + path: string + logger: Logger +}) { + const { server, path, logger } = params const wss = new WebSocketServer({ server, path }) let deadConnectionCleaner: NodeJS.Timeout | undefined wss.on('listening', () => { @@ -129,7 +136,7 @@ export function listen(params: { server: HttpServer; path: string }) { // Mark session as connected setSessionConnected(clientSessionId, true) ws.on('message', async (data: RawData) => { - const result = await processMessage(ws, clientSessionId, data) + const result = await processMessage({ ws, clientSessionId, data, logger }) // mqp: check ws.readyState before sending? ws.send(JSON.stringify(result)) }) diff --git a/backend/src/websockets/websocket-action.ts b/backend/src/websockets/websocket-action.ts index e7ec1aa643..1e8358272a 100644 --- a/backend/src/websockets/websocket-action.ts +++ b/backend/src/websockets/websocket-action.ts @@ -80,8 +80,9 @@ export async function genUsageResponse(params: { fingerprintId: string userId: string clientSessionId?: string + logger: Logger }): Promise { - const { fingerprintId, userId, clientSessionId } = params + const { fingerprintId, userId, clientSessionId, logger } = params const logContext = { fingerprintId, userId, sessionId: clientSessionId } const defaultResp = { type: 'usage-response' as const, @@ -135,11 +136,13 @@ export async function genUsageResponse(params: { * @param clientSessionId - The client's session ID * @param ws - The WebSocket connection */ -const onPrompt = async ( - action: ClientAction<'prompt'>, - clientSessionId: string, - ws: WebSocket, -) => { +const onPrompt = async (params: { + action: ClientAction<'prompt'> + clientSessionId: string + ws: WebSocket + logger: Logger +}) => { + const { action, clientSessionId, ws, logger } = params const { fingerprintId, authToken, promptId, prompt, costMode } = action await withLoggerContext( @@ -188,10 +191,11 @@ const onPrompt = async ( message: response, }) } finally { - cancelUserInput({ userId, userInputId: promptId }) + cancelUserInput({ userId, userInputId: promptId, logger }) const usageResponse = await genUsageResponse({ fingerprintId, userId, + logger, }) sendAction(ws, usageResponse) } @@ -207,7 +211,7 @@ export const callMainPrompt = async (params: { clientSessionId: string logger: Logger }) => { - const { ws, action, userId, promptId, clientSessionId } = params + const { ws, action, userId, promptId, clientSessionId, logger } = params const { fileContext } = action.sessionState // Enforce server-side state authority: reset creditsUsed to 0 @@ -217,7 +221,7 @@ export const callMainPrompt = async (params: { // Assemble local agent templates from fileContext const { agentTemplates: localAgentTemplates, validationErrors } = - assembleLocalAgentTemplates(fileContext) + assembleLocalAgentTemplates({ fileContext, logger }) if (validationErrors.length > 0) { sendAction(ws, { @@ -242,7 +246,9 @@ export const callMainPrompt = async (params: { ...params, localAgentTemplates, onResponseChunk: (chunk) => { - if (checkLiveUserInput(userId, promptId, clientSessionId)) { + if ( + checkLiveUserInput({ userId, userInputId: promptId, clientSessionId }) + ) { sendAction(ws, { type: 'response-chunk', userInputId: promptId, @@ -285,11 +291,15 @@ export const callMainPrompt = async (params: { * @param clientSessionId - The client's session ID * @param ws - The WebSocket connection */ -const onInit = async ( - { fileContext, fingerprintId, authToken }: ClientAction<'init'>, - clientSessionId: string, - ws: WebSocket, -) => { +const onInit = async (params: { + action: ClientAction<'init'> + clientSessionId: string + ws: WebSocket + logger: Logger +}) => { + const { action, clientSessionId, ws, logger } = params + const { fileContext, fingerprintId, authToken } = action + await withLoggerContext({ fingerprintId }, async () => { const userId = await getUserIdFromAuthToken({ authToken }) @@ -308,6 +318,7 @@ const onInit = async ( fingerprintId, userId, clientSessionId, + logger, }) sendAction(ws, { ...usageResponse, @@ -316,16 +327,19 @@ const onInit = async ( }) } -const onCancelUserInput = async ({ - authToken, - promptId, -}: ClientAction<'cancel-user-input'>) => { +const onCancelUserInput = async (params: { + action: ClientAction<'cancel-user-input'> + logger: Logger +}) => { + const { action, logger } = params + const { authToken, promptId } = action + const userId = await getUserIdFromAuthToken({ authToken }) if (!userId) { logger.error({ authToken }, 'User id not found for authToken') return } - cancelUserInput({ userId, userInputId: promptId }) + cancelUserInput({ userId, userInputId: promptId, logger }) } /** @@ -370,11 +384,14 @@ export const subscribeToAction = ( * @param clientSessionId - The client's session ID * @param msg - The action message from the client */ -export const onWebsocketAction = async ( - ws: WebSocket, - clientSessionId: string, - msg: ClientMessage & { type: 'action' }, -) => { +export const onWebsocketAction = async (params: { + ws: WebSocket + clientSessionId: string + msg: ClientMessage & { type: 'action' } + logger: Logger +}) => { + const { ws, clientSessionId, msg, logger } = params + await withLoggerContext({ clientSessionId }, async () => { const callbacks = callbacksByAction[msg.data.type] ?? [] try { @@ -394,9 +411,15 @@ export const onWebsocketAction = async ( } // Register action handlers -subscribeToAction('prompt', protec.run(onPrompt)) -subscribeToAction('init', protec.run(onInit, { silent: true })) -subscribeToAction('cancel-user-input', protec.run(onCancelUserInput)) +subscribeToAction('prompt', protec.run({ baseAction: onPrompt, logger })) +subscribeToAction( + 'init', + protec.run({ baseAction: onInit, silent: true, logger }), +) +subscribeToAction( + 'cancel-user-input', + protec.run({ baseAction: onCancelUserInput, logger }), +) /** * Requests multiple files from the client diff --git a/codebuff.json b/codebuff.json index 9051ed5f4f..d7839fd614 100644 --- a/codebuff.json +++ b/codebuff.json @@ -76,5 +76,6 @@ "command": "set -o pipefail && git diff --name-only --diff-filter=ACMR | grep -E '\\.(ts|tsx|js|jsx)$' | xargs -r npx eslint --fix --quiet", "filePattern": "**/*.{ts,tsx,js,jsx}" } - ] + ], + "maxAgentSteps": 100 } diff --git a/evals/git-evals/gen-evals.ts b/evals/git-evals/gen-evals.ts index d005d5e3d9..b1ef85e061 100644 --- a/evals/git-evals/gen-evals.ts +++ b/evals/git-evals/gen-evals.ts @@ -1,13 +1,13 @@ import { execSync } from 'child_process' import fs from 'fs' import path from 'path' -import { mapLimit } from 'async' +import { disableLiveUserInputCheck } from '@codebuff/backend/live-user-inputs' import { promptAiSdk } from '@codebuff/backend/llm-apis/vercel-ai-sdk/ai-sdk' import { models } from '@codebuff/common/old-constants' +import { mapLimit } from 'async' import { extractRepoNameFromUrl, setupTestRepo } from './setup-test-repo' -import { disableLiveUserInputCheck } from '@codebuff/backend/live-user-inputs' import type { EvalData, EvalInput, FileState, EvalCommit } from './types' const SPEC_GENERATION_PROMPT = `Given a set of file changes and an optional description, write a clear specification describing WHAT needs to be implemented. @@ -132,6 +132,7 @@ File Changes:\n${fileContext}` fingerprintId, userInputId, userId: undefined, + logger: console, }) // Extract spec from tags diff --git a/evals/git-evals/judge-git-eval.ts b/evals/git-evals/judge-git-eval.ts index 437edad903..a9ec0b9d65 100644 --- a/evals/git-evals/judge-git-eval.ts +++ b/evals/git-evals/judge-git-eval.ts @@ -193,6 +193,7 @@ export async function judgeEvalRun(evalRun: EvalRunLog) { userInputId: generateCompactId(), userId: undefined, timeout: 10 * 60 * 1000, // 10 minute timeout + logger: console, }).catch((error) => { console.warn(`Judge ${index + 1} failed:`, error) return null diff --git a/evals/git-evals/pick-commits.ts b/evals/git-evals/pick-commits.ts index c84bbac297..c88cad8d8a 100644 --- a/evals/git-evals/pick-commits.ts +++ b/evals/git-evals/pick-commits.ts @@ -3,14 +3,14 @@ import { execFileSync } from 'child_process' import fs from 'fs' import path from 'path' -import { mapLimit } from 'async' +import { disableLiveUserInputCheck } from '@codebuff/backend/live-user-inputs' import { promptAiSdkStructured } from '@codebuff/backend/llm-apis/vercel-ai-sdk/ai-sdk' import { models } from '@codebuff/common/old-constants' +import { mapLimit } from 'async' import { z } from 'zod/v4' import { extractRepoNameFromUrl, setupTestRepo } from './setup-test-repo' -import { disableLiveUserInputCheck } from '@codebuff/backend/live-user-inputs' // Types for commit data export interface CommitDiff { @@ -379,6 +379,7 @@ async function screenCommitsWithGpt5( fingerprintId, userInputId, userId: undefined, + logger: console, }) // Handle empty or invalid response diff --git a/evals/git-evals/post-eval-analysis.ts b/evals/git-evals/post-eval-analysis.ts index 4122ba2120..e0af166bce 100644 --- a/evals/git-evals/post-eval-analysis.ts +++ b/evals/git-evals/post-eval-analysis.ts @@ -188,5 +188,6 @@ export async function analyzeEvalResults( userInputId: generateCompactId(), userId: undefined, timeout: 10 * 60 * 1000, // 10 minute timeout + logger: console, }) } diff --git a/evals/git-evals/run-git-evals.ts b/evals/git-evals/run-git-evals.ts index 03b4ec48f7..54c92ae23f 100644 --- a/evals/git-evals/run-git-evals.ts +++ b/evals/git-evals/run-git-evals.ts @@ -151,6 +151,7 @@ Explain your reasoning in detail. Do not ask Codebuff to git commit changes.`, userInputId: generateCompactId(), userId: undefined, timeout: 5 * 60_000, // 5 minute timeout + logger: console, }) } catch (agentError) { throw new Error( diff --git a/evals/scaffolding.ts b/evals/scaffolding.ts index 5fb64f857d..7f23ff34e1 100644 --- a/evals/scaffolding.ts +++ b/evals/scaffolding.ts @@ -175,10 +175,13 @@ export async function runAgentStepScaffolding( mockWs.close = mock() let fullResponse = '' - const { agentTemplates: localAgentTemplates } = - assembleLocalAgentTemplates(fileContext) + const { agentTemplates: localAgentTemplates } = assembleLocalAgentTemplates({ + fileContext, + logger: console, + }) - const result = await runAgentStep(mockWs, { + const result = await runAgentStep({ + ws: mockWs, userId: TEST_USER_ID, userInputId: generateCompactId(), clientSessionId: sessionId, @@ -199,6 +202,7 @@ export async function runAgentStepScaffolding( prompt, params: undefined, system: 'Test system prompt', + logger: console, }) return { diff --git a/plans/logger-refactor-plan.md b/plans/logger-refactor-plan.md index 8a6ff85072..06eb6b952e 100644 --- a/plans/logger-refactor-plan.md +++ b/plans/logger-refactor-plan.md @@ -1,4 +1,14 @@ -## Plan: Refactor Logger to Pass as Parameter +## Plan: Refactor Logger to Pass as Parameter**IMPORTANT: This is a LIVE document that MUST be updated as you work!** + +As you complete each step and encounter unintuitive cases or learn new patterns: + +1. **UPDATE THIS DOCUMENT IMMEDIATELY** - Don't wait until the end +2. **Add learnings directly into the relevant step sections** - Do NOT create a dedicated "Findings" or "Learnings" section +3. **Update existing instructions inline** - Integrate new knowledge into the step descriptions themselves +4. **Split complex steps** - Feel free to break one step into multiple sub-steps (e.g., 1, [2], 3, 4 -> 1, [2, 3], 4, 5) if needed +5. Document any edge cases or special handling required within the relevant step + +**This plan serves as documentation for future engineers - keep it accurate and up-to-date!** ### Step 1: Search for logger imports @@ -6,7 +16,18 @@ - Set `cwd: "backend"` to limit search scope - Exclude websocket-action.ts files -### Step 2: For each file (except websocket-action.ts) +### Step 2: Select a file (except websocket-action.ts and library-integrated functions) + +**Exception**: Do NOT refactor functions that are directly integrated with external libraries and must maintain a specific signature required by that library. Examples include: + +- Express route handlers passed directly as middleware (e.g., `usageHandler`, `isRepoCoveredHandler`) - must maintain `(req, res, next?)` signature +- WebSocket handlers that conform to a specific library interface +- Event handlers or callbacks that match a library's expected signature +- Any function where changing the signature would break the integration with an external library + +These functions should continue to import and use the logger directly, as they cannot accept custom parameter objects without breaking their integration. + +For all other files: - Remove the `import { logger }` line - Refactor function signature to use single `params` object containing all arguments including `logger: Logger` @@ -18,14 +39,30 @@ - **Always run full `bun run typecheck`** (not head/tail!) to find ALL errors - Update function calls to pass object with named properties -- For tests: create mock logger constant called `logger` with all 4 methods (debug, info, warn, error) -- Use `allowMultiple: true` in str_replace when updating multiple calls in same file +- **For tests:** Create a no-op logger constant named `logger` (NOT `mockLogger`!) with all 4 methods: + ```typescript + const logger: Logger = { + debug: () => {}, + info: () => {}, + warn: () => {}, + error: () => {}, + } + ``` - **Check carefully** - there may be multiple call sites in the same file! - Repeat typecheck until ALL errors resolved -### Step 4: Commit changes +### Step 4: Run reviewer agent + +- After all typechecks pass, spawn the reviewer agent to review the changes +- Address any feedback from the reviewer before committing +- If you make _any_ changes, go back to Step 3. + +### Step 5: Commit changes - **Do NOT use git-committer agent** - it's too slow -- Instead, manually run: `git add ` then `git commit -m ""` +- Instead, manually run: `git add && git commit -m "" && git push` + - This will push to a branch, where I can manually review the changes. - Keep commit message concise but descriptive - Include Codebuff footer in commit message + +### Step 6: Go back to Step 1 and repeat diff --git a/scripts/ft-file-selection/grade-traces.ts b/scripts/ft-file-selection/grade-traces.ts index cdc29514a2..1d2e341a97 100644 --- a/scripts/ft-file-selection/grade-traces.ts +++ b/scripts/ft-file-selection/grade-traces.ts @@ -46,7 +46,7 @@ async function gradeTraces({ logger }: { logger: Logger }) { batch.map(async (traceAndRelabels) => { try { console.log(`Grading trace ${traceAndRelabels.trace.id}`) - const result = await gradeRun(traceAndRelabels) + const result = await gradeRun({ ...traceAndRelabels, logger }) return { traceId: traceAndRelabels.trace.id, status: 'success', diff --git a/scripts/ft-file-selection/relabel-for-offline-scoring.ts b/scripts/ft-file-selection/relabel-for-offline-scoring.ts index 0b2542ca26..6d7b71e62b 100644 --- a/scripts/ft-file-selection/relabel-for-offline-scoring.ts +++ b/scripts/ft-file-selection/relabel-for-offline-scoring.ts @@ -191,6 +191,7 @@ async function relabelTraceForModel( userInputId: 'relabel-offline-scoring', userId: TEST_USER_ID, maxOutputTokens: 1000, + logger: console, }) const newRelabel: Relabel = { diff --git a/scripts/ft-file-selection/relabel-traces-with-context.ts b/scripts/ft-file-selection/relabel-traces-with-context.ts index ad4950d1b8..144f929c0b 100644 --- a/scripts/ft-file-selection/relabel-traces-with-context.ts +++ b/scripts/ft-file-selection/relabel-traces-with-context.ts @@ -95,6 +95,7 @@ async function runTraces() { fileBlobs, model: MODEL_TO_TEST, dataset: DATASET, + logger: console, }) console.log(`Successfully stored relabel for trace ${trace.id}`) } catch (error) { diff --git a/scripts/ft-file-selection/relabel-traces.ts b/scripts/ft-file-selection/relabel-traces.ts index 96fe49262e..5ce518477f 100644 --- a/scripts/ft-file-selection/relabel-traces.ts +++ b/scripts/ft-file-selection/relabel-traces.ts @@ -67,22 +67,22 @@ async function runTraces() { fingerprintId: 'relabel-trace-run', userInputId: 'relabel-trace-run', userId: TEST_USER_ID, + logger: console, }) } else { - output = await promptFlashWithFallbacks( - messagesWithSystem({ + output = await promptFlashWithFallbacks({ + messages: messagesWithSystem({ messages: messages as Message[], system: system as System, }), - { - model: - model as typeof models.openrouter_gemini2_5_pro_preview, - clientSessionId: 'relabel-trace-run', - fingerprintId: 'relabel-trace-run', - userInputId: 'relabel-trace-run', - userId: 'relabel-trace-run', - }, - ) + model: + model as typeof models.openrouter_gemini2_5_pro_preview, + clientSessionId: 'relabel-trace-run', + fingerprintId: 'relabel-trace-run', + userInputId: 'relabel-trace-run', + userId: 'relabel-trace-run', + logger: console, + }) } // Create relabel record diff --git a/scripts/get-changelog.ts b/scripts/get-changelog.ts index 29fdc0026f..5b489db3a1 100644 --- a/scripts/get-changelog.ts +++ b/scripts/get-changelog.ts @@ -172,6 +172,7 @@ Start your response with a heading using ### (three hashes) and organize the con model: models.openrouter_claude_sonnet_4, userId: undefined, chargeUser: false, + logger: console, }) // Clean up the AI response