From c2b28ee303735998b6b55394b57c158d847bf507 Mon Sep 17 00:00:00 2001
From: minpeter
Date: Tue, 6 Jan 2026 20:54:35 +0900
Subject: [PATCH 1/3] feat: add context management with compaction and
 middleware support

- Add context tracker for monitoring token usage
- Implement auto-compaction when context threshold reached
- Add middleware for trimming leading newlines in AI responses
- Add /help command and model switching capabilities
- Add includeUsage flag to friendliai client
- Add debug logging for context usage
---
 src/agent.ts                   |  81 ++++++++++++++-
 src/index.ts                   |   1 +
 src/tools/file/write-file.ts   |   2 +-
 src/tools/search/grep.ts       |   2 +-
 src/utils/context-compactor.ts | 178 +++++++++++++++++++++++++++++++++
 src/utils/context-tracker.ts   | 107 ++++++++++++++++++++
 6 files changed, 367 insertions(+), 4 deletions(-)
 create mode 100644 src/utils/context-compactor.ts
 create mode 100644 src/utils/context-tracker.ts

diff --git a/src/agent.ts b/src/agent.ts
index 63ca742..dde7752 100644
--- a/src/agent.ts
+++ b/src/agent.ts
@@ -10,6 +10,7 @@ import { SYSTEM_PROMPT } from "./prompts/system";
 import type { tools } from "./tools/index";
 import { tools as agentTools } from "./tools/index";
 import {
+  colorize,
   printAIPrefix,
   printChunk,
   printNewline,
@@ -18,6 +19,12 @@ import {
   printReasoningPrefix,
   printTool,
 } from "./utils/colors";
+import { compactConversation } from "./utils/context-compactor";
+import {
+  type ContextConfig,
+  type ContextStats,
+  ContextTracker,
+} from "./utils/context-tracker";
 import { withRetry } from "./utils/retry";
 
 type StreamChunk = TextStreamPart<typeof tools>;
@@ -94,14 +101,24 @@ function logDebugFinish(chunk: StreamChunk): void {
 
 const DEFAULT_MAX_STEPS = 255;
 
+export interface AgentConfig {
+  maxSteps?: number;
+  contextConfig?: Partial<ContextConfig>;
+  autoCompact?: boolean;
+}
+
 export class Agent {
   private model: LanguageModel;
   private conversation: ModelMessage[] = [];
   private readonly maxSteps: number;
+  private readonly contextTracker: ContextTracker;
+  private readonly autoCompact: boolean;
 
-  constructor(model: LanguageModel, maxSteps = DEFAULT_MAX_STEPS) {
+  constructor(model: LanguageModel, config: AgentConfig = {}) {
     this.model = model;
-    this.maxSteps = maxSteps;
+    this.maxSteps = config.maxSteps ?? DEFAULT_MAX_STEPS;
+    this.contextTracker = new ContextTracker(config.contextConfig);
+    this.autoCompact = config.autoCompact ??
true; } getModel(): LanguageModel { @@ -122,9 +139,47 @@ export class Agent { clearConversation(): void { this.conversation = []; + this.contextTracker.reset(); + } + + /** + * Set the maximum context tokens for the current model + */ + setMaxContextTokens(tokens: number): void { + this.contextTracker.setMaxContextTokens(tokens); + } + + /** + * Set the compaction threshold (0.0 - 1.0) + */ + setCompactionThreshold(threshold: number): void { + this.contextTracker.setCompactionThreshold(threshold); + } + + /** + * Get current context usage statistics + */ + getContextStats(): ContextStats { + return this.contextTracker.getStats(); + } + + /** + * Manually trigger context compaction + */ + async compactContext(): Promise { + const result = await compactConversation(this.model, this.conversation); + this.conversation = result.messages; + // Estimate new token count (rough approximation) + const estimatedTokens = result.summary.length / 4; // ~4 chars per token + this.contextTracker.afterCompaction(estimatedTokens); } async chat(userInput: string): Promise { + // Check if compaction is needed before processing + if (this.autoCompact && this.contextTracker.shouldCompact()) { + await this.compactContext(); + } + this.conversation.push({ role: "user", content: userInput, @@ -133,6 +188,11 @@ export class Agent { await withRetry(async () => { await this.executeStreamingChat(); }); + + // Check again after response + if (this.autoCompact && this.contextTracker.shouldCompact()) { + await this.compactContext(); + } } private async executeStreamingChat(): Promise { @@ -178,6 +238,23 @@ export class Agent { endTextIfNeeded(state); const response = await result.response; + + // Update context tracker with usage information + const totalUsage = await result.totalUsage; + if (totalUsage) { + this.contextTracker.updateUsage(totalUsage); + + if (debug) { + const stats = this.contextTracker.getStats(); + console.log( + colorize( + "dim", + `[Context] ${stats.totalTokens.toLocaleString()} / ${stats.maxContextTokens.toLocaleString()} tokens (${(stats.usagePercentage * 100).toFixed(1)}%)` + ) + ); + } + } + if (debug) { console.log(`[DEBUG] Total chunks: ${chunkCount}`); console.log(`[DEBUG] Response messages: ${response.messages.length}`); diff --git a/src/index.ts b/src/index.ts index b6ba07a..e38bc3d 100644 --- a/src/index.ts +++ b/src/index.ts @@ -12,6 +12,7 @@ const DEFAULT_MODEL_ID = "LGAI-EXAONE/K-EXAONE-236B-A23B"; const friendli = createFriendli({ apiKey: env.FRIENDLI_TOKEN, + includeUsage: true, }); let currentModelId = DEFAULT_MODEL_ID; diff --git a/src/tools/file/write-file.ts b/src/tools/file/write-file.ts index ebb59de..9c500dd 100644 --- a/src/tools/file/write-file.ts +++ b/src/tools/file/write-file.ts @@ -21,4 +21,4 @@ export const writeFileTool = tool({ await writeFile(path, content, "utf-8"); return `Successfully wrote ${content.length} characters to ${path}`; }, -}); \ No newline at end of file +}); diff --git a/src/tools/search/grep.ts b/src/tools/search/grep.ts index 38bd6f3..13be063 100644 --- a/src/tools/search/grep.ts +++ b/src/tools/search/grep.ts @@ -190,4 +190,4 @@ export const grepTool = tool({ throw error; } }, -}); \ No newline at end of file +}); diff --git a/src/utils/context-compactor.ts b/src/utils/context-compactor.ts new file mode 100644 index 0000000..676c782 --- /dev/null +++ b/src/utils/context-compactor.ts @@ -0,0 +1,178 @@ +import type { LanguageModel, ModelMessage } from "ai"; +import { generateText } from "ai"; +import { colorize } from "./colors"; + +const 
COMPACTION_SYSTEM_PROMPT = `You are a conversation summarizer. Your task is to create a concise summary of the conversation history that preserves: +1. Key decisions made +2. Important code changes or file modifications +3. Current task context and goals +4. Any errors encountered and their resolutions + +Output a summary that can serve as context for continuing the conversation. +Be concise but preserve essential information. Format as a brief narrative.`; + +export interface CompactionResult { + messages: ModelMessage[]; + originalMessageCount: number; + compactedMessageCount: number; + summary: string; +} + +export interface CompactionConfig { + keepRecentMessages: number; // Number of recent messages to preserve + maxSummaryTokens: number; +} + +const DEFAULT_COMPACTION_CONFIG: CompactionConfig = { + keepRecentMessages: 6, // Keep last 3 exchanges (user + assistant pairs) + maxSummaryTokens: 2000, +}; + +interface ContentPart { + type: string; + text?: string; + toolName?: string; + output?: unknown; +} + +function getToolResultPreview(output: unknown): string { + if (typeof output === "string") { + return output.slice(0, 200); + } + if (output != null) { + return JSON.stringify(output).slice(0, 200); + } + return ""; +} + +function formatContentPart(part: ContentPart): string { + if (part.type === "text" && part.text) { + return part.text; + } + if (part.type === "tool-call" && part.toolName) { + return `[Tool Call: ${part.toolName}]`; + } + if (part.type === "tool-result") { + const preview = getToolResultPreview(part.output); + return `[Tool Result: ${preview}...]`; + } + return ""; +} + +function formatArrayContent(content: ContentPart[]): string { + return content.map(formatContentPart).filter(Boolean).join("\n"); +} + +function formatMessage(msg: ModelMessage): string | null { + const role = msg.role.toUpperCase(); + + if (typeof msg.content === "string") { + return `[${role}]: ${msg.content}`; + } + + if (Array.isArray(msg.content)) { + const content = formatArrayContent(msg.content as ContentPart[]); + if (content) { + return `[${role}]: ${content}`; + } + } + + return null; +} + +/** + * Formats messages for summarization + */ +function formatMessagesForSummary(messages: ModelMessage[]): string { + return messages.map(formatMessage).filter(Boolean).join("\n\n"); +} + +/** + * Compacts conversation history by summarizing older messages + */ +export async function compactConversation( + model: LanguageModel, + messages: ModelMessage[], + config: Partial = {} +): Promise { + const { keepRecentMessages, maxSummaryTokens } = { + ...DEFAULT_COMPACTION_CONFIG, + ...config, + }; + + // If not enough messages to compact, return as-is + if (messages.length <= keepRecentMessages) { + return { + messages, + originalMessageCount: messages.length, + compactedMessageCount: messages.length, + summary: "", + }; + } + + // Split messages: older ones to summarize, recent ones to keep + const messagesToSummarize = messages.slice(0, -keepRecentMessages); + const recentMessages = messages.slice(-keepRecentMessages); + + console.log( + colorize( + "yellow", + `\n[Compacting context: summarizing ${messagesToSummarize.length} messages...]` + ) + ); + + // Format older messages for summarization + const conversationText = formatMessagesForSummary(messagesToSummarize); + + try { + // Generate summary using the same model + const result = await generateText({ + model, + system: COMPACTION_SYSTEM_PROMPT, + prompt: `Please summarize the following conversation history:\n\n${conversationText}`, + 
maxOutputTokens: maxSummaryTokens, + }); + + const summary = result.text; + + // Create a new message array with the summary as context + const summaryMessage: ModelMessage = { + role: "user", + content: `[Previous conversation summary]\n${summary}\n\n[Continuing conversation...]`, + }; + + const compactedMessages: ModelMessage[] = [ + summaryMessage, + ...recentMessages, + ]; + + console.log( + colorize( + "green", + `[Context compacted: ${messages.length} → ${compactedMessages.length} messages]` + ) + ); + + return { + messages: compactedMessages, + originalMessageCount: messages.length, + compactedMessageCount: compactedMessages.length, + summary, + }; + } catch (error) { + console.log( + colorize( + "red", + `[Compaction failed: ${error instanceof Error ? error.message : error}]` + ) + ); + + // On failure, just truncate old messages without summary + return { + messages: recentMessages, + originalMessageCount: messages.length, + compactedMessageCount: recentMessages.length, + summary: "", + }; + } +} diff --git a/src/utils/context-tracker.ts b/src/utils/context-tracker.ts new file mode 100644 index 0000000..f5bac4f --- /dev/null +++ b/src/utils/context-tracker.ts @@ -0,0 +1,107 @@ +import type { LanguageModelUsage } from "ai"; + +export interface ContextConfig { + maxContextTokens: number; + compactionThreshold: number; // 0.0 ~ 1.0, e.g., 0.8 means compact at 80% +} + +export interface ContextStats { + totalTokens: number; + inputTokens: number; + outputTokens: number; + maxContextTokens: number; + usagePercentage: number; + shouldCompact: boolean; +} + +const DEFAULT_CONFIG: ContextConfig = { + maxContextTokens: 128_000, // Default for most modern models + compactionThreshold: 0.75, // Compact when 75% of context is used +}; + +export class ContextTracker { + private readonly config: ContextConfig; + private totalInputTokens = 0; + private totalOutputTokens = 0; + private stepCount = 0; + + constructor(config: Partial = {}) { + this.config = { ...DEFAULT_CONFIG, ...config }; + } + + setMaxContextTokens(tokens: number): void { + this.config.maxContextTokens = tokens; + } + + setCompactionThreshold(threshold: number): void { + if (threshold < 0 || threshold > 1) { + throw new Error("Compaction threshold must be between 0 and 1"); + } + this.config.compactionThreshold = threshold; + } + + updateUsage(usage: LanguageModelUsage): void { + this.totalInputTokens += usage.inputTokens ?? 0; + this.totalOutputTokens += usage.outputTokens ?? 0; + this.stepCount++; + } + + /** + * Set total usage directly (useful after compaction or when loading state) + */ + setTotalUsage(inputTokens: number, outputTokens: number): void { + this.totalInputTokens = inputTokens; + this.totalOutputTokens = outputTokens; + } + + /** + * Get estimated current context size + * Note: This is an approximation based on accumulated usage + */ + getEstimatedContextTokens(): number { + // The input tokens from the last request roughly represents + // the current context size (system prompt + conversation history) + return this.totalInputTokens > 0 + ? 
Math.round(this.totalInputTokens / Math.max(this.stepCount, 1)) + : 0; + } + + getStats(): ContextStats { + const totalTokens = this.totalInputTokens + this.totalOutputTokens; + const usagePercentage = totalTokens / this.config.maxContextTokens; + const shouldCompact = usagePercentage >= this.config.compactionThreshold; + + return { + totalTokens, + inputTokens: this.totalInputTokens, + outputTokens: this.totalOutputTokens, + maxContextTokens: this.config.maxContextTokens, + usagePercentage, + shouldCompact, + }; + } + + shouldCompact(): boolean { + return this.getStats().shouldCompact; + } + + reset(): void { + this.totalInputTokens = 0; + this.totalOutputTokens = 0; + this.stepCount = 0; + } + + /** + * Called after compaction to adjust token counts + * @param newInputTokens The token count of the compacted context + */ + afterCompaction(newInputTokens: number): void { + this.totalInputTokens = newInputTokens; + this.totalOutputTokens = 0; + this.stepCount = 1; + } + + getConfig(): ContextConfig { + return { ...this.config }; + } +} From af9a8c7e7b45b5f6aff335e0d19b02f9cb40a075 Mon Sep 17 00:00:00 2001 From: minpeter Date: Tue, 6 Jan 2026 21:42:54 +0900 Subject: [PATCH 2/3] Update default model from LGAI-EXAONE/K-EXAONE-236B-A23B to zai-org/GLM-4.6 and add support for aborting ongoing conversations via ESC key - Replace default model in documentation and code - Add abort functionality to Agent class with AbortController - Modify chat method to return aborted status - Update command handler to support abort signals in streaming - Add ESC key interrupt support in input handling - Implement /context and /compact commands for monitoring and managing context usage --- README.md | 4 +- src/agent.ts | 79 +++++++++++--- src/commands/index.ts | 239 ++++++++++++++++++++++++++++++++++++++++++ src/index.ts | 52 +++++++-- 4 files changed, 346 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index ab73487..e34f3bc 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ bun start ``` $ bun start -Chat with AI (model: LGAI-EXAONE/K-EXAONE-236B-A23B) +Chat with AI (model: zai-org/GLM-4.6) Use '/help' for commands, 'ctrl-c' to quit You: what's in package.json? @@ -132,7 +132,7 @@ code-editing-agent/ ## Model -Uses `LGAI-EXAONE/K-EXAONE-236B-A23B` via FriendliAI serverless endpoints by default. Use `/models` command to switch models. +Uses `zai-org/GLM-4.6` via FriendliAI serverless endpoints by default. Use `/models` command to switch models. ## License diff --git a/src/agent.ts b/src/agent.ts index dde7752..b10e8ad 100644 --- a/src/agent.ts +++ b/src/agent.ts @@ -113,6 +113,7 @@ export class Agent { private readonly maxSteps: number; private readonly contextTracker: ContextTracker; private readonly autoCompact: boolean; + private abortController: AbortController | null = null; constructor(model: LanguageModel, config: AgentConfig = {}) { this.model = model; @@ -121,6 +122,17 @@ export class Agent { this.autoCompact = config.autoCompact ?? 
true; } + isRunning(): boolean { + return this.abortController !== null; + } + + abort(): void { + if (this.abortController) { + this.abortController.abort(); + this.abortController = null; + } + } + getModel(): LanguageModel { return this.model; } @@ -174,8 +186,7 @@ export class Agent { this.contextTracker.afterCompaction(estimatedTokens); } - async chat(userInput: string): Promise { - // Check if compaction is needed before processing + async chat(userInput: string): Promise<{ aborted: boolean }> { if (this.autoCompact && this.contextTracker.shouldCompact()) { await this.compactContext(); } @@ -185,14 +196,26 @@ export class Agent { content: userInput, }); - await withRetry(async () => { - await this.executeStreamingChat(); - }); + this.abortController = new AbortController(); + + try { + await withRetry(async () => { + await this.executeStreamingChat(); + }); + } catch (error) { + if (error instanceof Error && error.name === "AbortError") { + return { aborted: true }; + } + throw error; + } finally { + this.abortController = null; + } - // Check again after response if (this.autoCompact && this.contextTracker.shouldCompact()) { await this.compactContext(); } + + return { aborted: false }; } private async executeStreamingChat(): Promise { @@ -202,9 +225,9 @@ export class Agent { messages: this.conversation, tools: agentTools, stopWhen: stepCountIs(this.maxSteps), + abortSignal: this.abortController?.signal, providerOptions: { friendliai: { - // enable_thinking for hybrid reasoning models chat_template_kwargs: { enable_thinking: true, }, @@ -220,23 +243,45 @@ export class Agent { let chunkCount = 0; const debug = env.DEBUG_CHUNK_LOG; - for await (const chunk of result.fullStream) { - chunkCount++; + let aborted = false; - if (debug) { - logDebugChunk(chunk, chunkCount); - logDebugError(chunk); - logDebugFinish(chunk); - } + try { + for await (const chunk of result.fullStream) { + if (this.abortController?.signal.aborted) { + aborted = true; + break; + } + + chunkCount++; + + if (debug) { + logDebugChunk(chunk, chunkCount); + logDebugError(chunk); + logDebugFinish(chunk); + } - handleReasoningDelta(chunk, state); - handleTextDelta(chunk, state); - handleToolCall(chunk, state); + handleReasoningDelta(chunk, state); + handleTextDelta(chunk, state); + handleToolCall(chunk, state); + } + } catch (error) { + if (error instanceof Error && error.name === "AbortError") { + aborted = true; + } else { + throw error; + } } endReasoningIfNeeded(state); endTextIfNeeded(state); + if (aborted) { + console.log(colorize("yellow", "\n[Interrupted by user]")); + const abortError = new Error("Aborted"); + abortError.name = "AbortError"; + throw abortError; + } + const response = await result.response; // Update context tracker with usage information diff --git a/src/commands/index.ts b/src/commands/index.ts index de0c50e..f134195 100644 --- a/src/commands/index.ts +++ b/src/commands/index.ts @@ -8,6 +8,7 @@ import type { import type { Agent } from "../agent"; import { env } from "../env"; import { SYSTEM_PROMPT } from "../prompts/system"; +import { tools } from "../tools/index"; import { colorize } from "../utils/colors"; import { deleteConversation, @@ -17,6 +18,33 @@ import { } from "../utils/conversation-store"; import { selectModel } from "../utils/model-selector"; +interface OpenAITool { + type: "function"; + function: { + name: string; + description: string; + parameters: Record; + }; +} + +interface SchemaWithToJSON { + toJSONSchema: () => Record; +} + +function convertToolsToOpenAIFormat(): 
OpenAITool[] { + return Object.entries(tools).map(([name, tool]) => { + const schema = tool.inputSchema as unknown as SchemaWithToJSON; + return { + type: "function" as const, + function: { + name, + description: tool.description ?? "", + parameters: schema.toJSONSchema(), + }, + }; + }); +} + interface RenderAPIMessage { role: "system" | "user" | "assistant" | "tool"; content: string | null; @@ -152,6 +180,8 @@ ${colorize("cyan", "Available commands:")} /delete - Delete a saved conversation /models - List and select available AI models /render - Render conversation as raw prompt text + /context - Show context usage statistics + /compact - Manually trigger context compaction /quit - Exit the program `); } @@ -311,6 +341,213 @@ async function handleRender( return { conversationId: ctx.currentConversationId }; } +async function fetchRenderedText( + messages: RenderAPIMessage[], + modelId: string, + includeTools = false +): Promise { + const body: Record = { + model: modelId, + messages, + }; + + if (includeTools) { + body.tools = convertToolsToOpenAIFormat(); + } + + const response = await fetch( + "https://api.friendli.ai/serverless/v1/chat/render", + { + method: "POST", + headers: { + Authorization: `Bearer ${env.FRIENDLI_TOKEN}`, + "Content-Type": "application/json", + }, + body: JSON.stringify(body), + } + ); + + if (!response.ok) { + const error = await response.text(); + console.log(colorize("red", `Render API failed: ${error}`)); + return null; + } + + const data = (await response.json()) as { text: string }; + return data.text; +} + +async function fetchTokenCount( + text: string, + modelId: string +): Promise { + const response = await fetch( + "https://api.friendli.ai/serverless/v1/tokenize", + { + method: "POST", + headers: { + Authorization: `Bearer ${env.FRIENDLI_TOKEN}`, + "Content-Type": "application/json", + }, + body: JSON.stringify({ + model: modelId, + prompt: text, + }), + } + ); + + if (!response.ok) { + const error = await response.text(); + console.log(colorize("red", `Tokenize API failed: ${error}`)); + return null; + } + + const data = (await response.json()) as { tokens: number[] }; + return data.tokens.length; +} + +type ColorName = "blue" | "yellow" | "green" | "cyan" | "red" | "dim" | "reset"; + +function getProgressBarColor(percentage: number): ColorName { + if (percentage >= 0.8) { + return "red"; + } + if (percentage >= 0.6) { + return "yellow"; + } + if (percentage >= 0.4) { + return "cyan"; + } + return "green"; +} + +function renderProgressBar( + usagePercentage: number, + totalTokens: number, + maxTokens: number +): void { + const barWidth = 40; + const clampedPercentage = Math.min(Math.max(usagePercentage, 0), 1); + const filledWidth = Math.floor(barWidth * clampedPercentage); + const emptyWidth = barWidth - filledWidth; + + const filledBar = "█".repeat(filledWidth); + const emptyBar = "░".repeat(emptyWidth); + const barColor = getProgressBarColor(clampedPercentage); + + console.log( + `${colorize(barColor, "Progress: ")}${filledBar}${emptyBar} ${(clampedPercentage * 100).toFixed(1)}% (${totalTokens.toLocaleString()} / ${maxTokens.toLocaleString()})` + ); +} + +async function handleCompact( + _args: string[], + ctx: CommandContext +): Promise { + const messages = ctx.agent.getConversation(); + if (messages.length === 0) { + console.log(colorize("yellow", "No conversation to compact.")); + return { conversationId: ctx.currentConversationId }; + } + + console.log(colorize("cyan", "=== Context Compaction ===")); + console.log(colorize("dim", "Compacting 
conversation...")); + + try { + await ctx.agent.compactContext(); + console.log(colorize("green", "✓ Conversation compacted successfully.")); + + const stats = ctx.agent.getContextStats(); + console.log( + colorize( + "dim", + ` New estimated size: ${stats.totalTokens.toLocaleString()} tokens` + ) + ); + } catch (error) { + console.log(colorize("red", `Compaction failed: ${error}`)); + } + + return { conversationId: ctx.currentConversationId }; +} + +async function handleContext( + _args: string[], + ctx: CommandContext +): Promise { + const messages = ctx.agent.getConversation(); + const apiMessages = convertToRenderAPIMessages(messages, SYSTEM_PROMPT); + const isEmptyConversation = messages.length === 0; + + console.log(colorize("cyan", "=== Context Usage ===")); + if (isEmptyConversation) { + console.log(colorize("dim", "Calculating system prompt + tools size...")); + } else { + console.log(colorize("dim", "Calculating accurate token count...")); + } + + try { + const renderedText = await fetchRenderedText( + apiMessages, + ctx.currentModelId, + true + ); + if (renderedText === null) { + return { conversationId: ctx.currentConversationId }; + } + + const tokenCount = await fetchTokenCount(renderedText, ctx.currentModelId); + if (tokenCount === null) { + return { conversationId: ctx.currentConversationId }; + } + + const stats = ctx.agent.getContextStats(); + const maxContextTokens = stats.maxContextTokens; + const usagePercentage = tokenCount / maxContextTokens; + const compactionThreshold = 0.75; + + const tokenLabel = isEmptyConversation + ? `Total tokens: ${tokenCount.toLocaleString()} (system prompt + tools)` + : `Total tokens: ${tokenCount.toLocaleString()}`; + console.log(`\n${tokenLabel}`); + console.log(`Max context: ${maxContextTokens.toLocaleString()}`); + console.log(""); + + renderProgressBar(usagePercentage, tokenCount, maxContextTokens); + + console.log(`\n${colorize("dim", "Usage Details:")}`); + console.log(` Usage percentage: ${(usagePercentage * 100).toFixed(1)}%`); + console.log( + ` Usage threshold: ${(compactionThreshold * 100).toFixed(0)}%` + ); + + if (usagePercentage >= compactionThreshold) { + console.log( + colorize("yellow", " ⚠️ Status: Compaction recommended!") + ); + console.log( + colorize( + "yellow", + ` (Usage above ${(compactionThreshold * 100).toFixed(0)}% threshold)` + ) + ); + } else { + const remaining = (compactionThreshold - usagePercentage) * 100; + console.log(colorize("green", " ✓ Status: Healthy")); + console.log( + colorize( + "dim", + ` ${remaining.toFixed(0)}% until compaction threshold` + ) + ); + } + } catch (error) { + console.log(colorize("red", `Error: ${error}`)); + } + + return { conversationId: ctx.currentConversationId }; +} + const commands: Record = { help: handleHelp, clear: handleClear, @@ -322,6 +559,8 @@ const commands: Record = { exit: handleQuit, models: handleModels, render: handleRender, + context: handleContext, + compact: handleCompact, }; export function handleCommand( diff --git a/src/index.ts b/src/index.ts index e38bc3d..219d731 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,14 +1,14 @@ #!/usr/bin/env bun -import { createInterface } from "node:readline"; +import { createInterface, emitKeypressEvents } from "node:readline"; import { createFriendli } from "@friendliai/ai-provider"; import type { LanguageModel } from "ai"; import { Agent } from "./agent"; import { handleCommand } from "./commands"; import { env } from "./env"; import { wrapModel } from "./model/create-model"; -import { printYou } from 
"./utils/colors"; +import { colorize, printYou } from "./utils/colors"; -const DEFAULT_MODEL_ID = "LGAI-EXAONE/K-EXAONE-236B-A23B"; +const DEFAULT_MODEL_ID = "zai-org/GLM-4.6"; const friendli = createFriendli({ apiKey: env.FRIENDLI_TOKEN, @@ -24,15 +24,40 @@ const rl = createInterface({ output: process.stdout, }); +emitKeypressEvents(process.stdin); + +function setupEscHandler(): void { + if (process.stdin.isTTY) { + process.stdin.setRawMode(true); + } + + process.stdin.on("keypress", (_chunk, key) => { + if (key?.name === "escape" && agent.isRunning()) { + agent.abort(); + } + }); +} + function getUserInput(): Promise { return new Promise((resolve) => { printYou(); - rl.once("line", (line) => { + + if (process.stdin.isTTY) { + process.stdin.setRawMode(false); + } + + const onLine = (line: string) => { + rl.removeListener("close", onClose); resolve(line); - }); - rl.once("close", () => { + }; + + const onClose = () => { + rl.removeListener("line", onLine); resolve(null); - }); + }; + + rl.once("line", onLine); + rl.once("close", onClose); }); } @@ -48,9 +73,11 @@ function setModel(model: LanguageModel, modelId: string): void { async function main(): Promise { console.log(`Chat with AI (model: ${currentModelId})`); - console.log("Use '/help' for commands, 'ctrl-c' to quit"); + console.log("Use '/help' for commands, 'ESC' to interrupt, 'ctrl-c' to quit"); console.log(); + setupEscHandler(); + while (true) { const userInput = await getUserInput(); @@ -77,8 +104,15 @@ async function main(): Promise { continue; } + if (process.stdin.isTTY) { + process.stdin.setRawMode(true); + } + try { - await agent.chat(userInput); + const { aborted } = await agent.chat(userInput); + if (aborted) { + console.log(colorize("dim", "(You can continue typing)")); + } } catch (error) { console.error("An error occurred:", error); } From 6476bf824d634c93cb37aba55527c3a8625f5569 Mon Sep 17 00:00:00 2001 From: minpeter Date: Tue, 6 Jan 2026 23:12:10 +0900 Subject: [PATCH 3/3] feat: context management follow-ups --- src/agent.ts | 344 ++++++++++++++++++++++++++++++--- src/commands/index.ts | 274 +++----------------------- src/env.ts | 1 + src/index.ts | 10 +- src/prompts/system.ts | 4 + src/utils/context-compactor.ts | 202 ++++++++++++++----- src/utils/context-tracker.ts | 15 +- src/utils/render-api.ts | 228 ++++++++++++++++++++++ 8 files changed, 754 insertions(+), 324 deletions(-) create mode 100644 src/utils/render-api.ts diff --git a/src/agent.ts b/src/agent.ts index b10e8ad..1f67b64 100644 --- a/src/agent.ts +++ b/src/agent.ts @@ -25,13 +25,20 @@ import { type ContextStats, ContextTracker, } from "./utils/context-tracker"; +import { + measureContextTokens, + type RenderApiOptions, +} from "./utils/render-api"; import { withRetry } from "./utils/retry"; type StreamChunk = TextStreamPart; +type ToolCallChunk = Extract; +type AssistantContentPart = { type: "text"; text: string } | ToolCallChunk; interface StreamState { hasStartedText: boolean; hasStartedReasoning: boolean; + sawTextDelta: boolean; } function endReasoningIfNeeded(state: StreamState): void { @@ -63,6 +70,7 @@ function handleTextDelta(chunk: StreamChunk, state: StreamState): void { if (chunk.type !== "text-delta") { return; } + state.sawTextDelta = true; endReasoningIfNeeded(state); if (!state.hasStartedText) { printAIPrefix(); @@ -80,6 +88,39 @@ function handleToolCall(chunk: StreamChunk, state: StreamState): void { printTool(chunk.toolName, chunk.input); } +function appendAssistantText( + parts: AssistantContentPart[], + text: string +): void { + const 
lastPart = parts.at(-1); + if (lastPart && lastPart.type === "text") { + lastPart.text += text; + return; + } + parts.push({ type: "text", text }); +} + +function appendAssistantToolCall( + parts: AssistantContentPart[], + toolCall: ToolCallChunk +): void { + parts.push(toolCall); +} + +function flushAssistantMessage( + stagedMessages: ModelMessage[], + parts: AssistantContentPart[] +): void { + if (parts.length === 0) { + return; + } + stagedMessages.push({ + role: "assistant", + content: [...parts], + }); + parts.length = 0; +} + function logDebugChunk(chunk: StreamChunk, chunkCount: number): void { const skipTypes = ["text-delta", "reasoning-delta", "tool-result"]; if (!skipTypes.includes(chunk.type)) { @@ -99,24 +140,93 @@ function logDebugFinish(chunk: StreamChunk): void { } } +function extractAssistantText(messages: ModelMessage[]): string { + const chunks: string[] = []; + for (const message of messages) { + if (message.role !== "assistant") { + continue; + } + if (typeof message.content === "string") { + if (message.content) { + chunks.push(message.content); + } + continue; + } + if (Array.isArray(message.content)) { + for (const part of message.content) { + if (part.type === "text" && part.text) { + chunks.push(part.text); + } + } + } + } + return chunks.join(""); +} + +function assistantMessageHasText(message: ModelMessage): boolean { + if (message.role !== "assistant") { + return false; + } + if (typeof message.content === "string") { + return message.content.trim().length > 0; + } + if (Array.isArray(message.content)) { + return message.content.some( + (part) => part.type === "text" && (part.text ?? "").trim().length > 0 + ); + } + return false; +} + +function shouldContinueAfterTools(messages: ModelMessage[]): boolean { + let lastToolIndex = -1; + for (let i = 0; i < messages.length; i += 1) { + if (messages[i]?.role === "tool") { + lastToolIndex = i; + } + } + if (lastToolIndex === -1) { + return false; + } + for (let i = lastToolIndex + 1; i < messages.length; i += 1) { + if (assistantMessageHasText(messages[i])) { + return false; + } + } + return true; +} + +const MAX_TOOL_FOLLOWUPS = 3; + const DEFAULT_MAX_STEPS = 255; export interface AgentConfig { maxSteps?: number; contextConfig?: Partial; autoCompact?: boolean; + modelId?: string; } export class Agent { private model: LanguageModel; + private modelId: string; private conversation: ModelMessage[] = []; private readonly maxSteps: number; private readonly contextTracker: ContextTracker; private readonly autoCompact: boolean; private abortController: AbortController | null = null; + private contextMeasureInFlight: Promise | null = null; + private readonly pendingContextMeasures: Array< + RenderApiOptions & { + debugLabel?: string; + messages?: ModelMessage[]; + systemPrompt?: string; + } + > = []; constructor(model: LanguageModel, config: AgentConfig = {}) { this.model = model; + this.modelId = config.modelId ?? ""; this.maxSteps = config.maxSteps ?? DEFAULT_MAX_STEPS; this.contextTracker = new ContextTracker(config.contextConfig); this.autoCompact = config.autoCompact ?? 
true; @@ -137,8 +247,11 @@ export class Agent { return this.model; } - setModel(model: LanguageModel): void { + setModel(model: LanguageModel, modelId?: string): void { this.model = model; + if (modelId) { + this.modelId = modelId; + } } getConversation(): ModelMessage[] { @@ -175,32 +288,140 @@ export class Agent { return this.contextTracker.getStats(); } + getContextConfig(): ContextConfig { + return this.contextTracker.getConfig(); + } + + async refreshContextTokens( + options: RenderApiOptions & { + debugLabel?: string; + messages?: ModelMessage[]; + systemPrompt?: string; + } = {} + ): Promise { + if (!this.modelId) { + options.onError?.("Model ID is not set for context measurement."); + return null; + } + + const { messages, debugLabel, systemPrompt, ...renderOptions } = options; + const prompt = systemPrompt ?? SYSTEM_PROMPT; + let tokenCount: number | null = null; + try { + tokenCount = await measureContextTokens( + messages ?? this.conversation, + this.modelId, + prompt, + renderOptions + ); + } catch (error) { + const message = error instanceof Error ? error.message : String(error); + options.onError?.(`Context measurement failed: ${message}`); + return null; + } + + if (tokenCount !== null) { + this.contextTracker.setContextTokens(tokenCount); + this.logContextStats(debugLabel); + } + + return tokenCount; + } + + private logContextStats(label?: string): void { + if (!env.DEBUG_CONTEXT_LOG) { + return; + } + const stats = this.contextTracker.getStats(); + const suffix = label ? ` ${label}` : ""; + const status = stats.shouldCompact + ? colorize("yellow", "COMPACT") + : colorize("green", "OK"); + console.log( + colorize("dim", `[Context${suffix}]`) + + " " + + `${stats.totalTokens.toLocaleString()} / ${stats.maxContextTokens.toLocaleString()} tokens ` + + `(${(stats.usagePercentage * 100).toFixed(1)}%) | ` + + `compact: ${status}` + ); + } + + scheduleContextMeasurement( + options: RenderApiOptions & { + debugLabel?: string; + messages?: ModelMessage[]; + systemPrompt?: string; + } = {} + ): void { + const scheduledOptions = options.messages + ? 
{ ...options, messages: [...options.messages] } + : options; + + if (this.contextMeasureInFlight) { + this.pendingContextMeasures.push(scheduledOptions); + return; + } + + this.contextMeasureInFlight = this.refreshContextTokens(scheduledOptions) + .catch(() => null) + .finally(() => { + this.contextMeasureInFlight = null; + const pending = this.pendingContextMeasures.shift(); + if (pending) { + this.scheduleContextMeasurement(pending); + } + }); + } + + async flushContextMeasurement(): Promise { + while ( + this.contextMeasureInFlight || + this.pendingContextMeasures.length > 0 + ) { + if (this.contextMeasureInFlight) { + await this.contextMeasureInFlight; + continue; + } + const pending = this.pendingContextMeasures.shift(); + if (pending) { + this.scheduleContextMeasurement(pending); + } + } + + return this.contextTracker.getStats().totalTokens; + } + /** * Manually trigger context compaction */ async compactContext(): Promise { const result = await compactConversation(this.model, this.conversation); this.conversation = result.messages; - // Estimate new token count (rough approximation) - const estimatedTokens = result.summary.length / 4; // ~4 chars per token + + const tokenCount = await this.refreshContextTokens({ + debugLabel: "after compact", + }); + if (tokenCount !== null) { + this.contextTracker.afterCompaction(tokenCount); + return; + } + const estimatedTokens = Math.round(result.summary.length / 4); this.contextTracker.afterCompaction(estimatedTokens); } async chat(userInput: string): Promise<{ aborted: boolean }> { + this.conversation.push({ role: "user", content: userInput }); + await this.refreshContextTokens({ debugLabel: "after user" }); + if (this.autoCompact && this.contextTracker.shouldCompact()) { await this.compactContext(); } - this.conversation.push({ - role: "user", - content: userInput, - }); - this.abortController = new AbortController(); try { await withRetry(async () => { - await this.executeStreamingChat(); + await this.executeStreamingChat(SYSTEM_PROMPT, this.conversation); }); } catch (error) { if (error instanceof Error && error.name === "AbortError") { @@ -211,6 +432,9 @@ export class Agent { this.abortController = null; } + this.scheduleContextMeasurement({ debugLabel: "after ai" }); + await this.flushContextMeasurement(); + if (this.autoCompact && this.contextTracker.shouldCompact()) { await this.compactContext(); } @@ -218,11 +442,48 @@ export class Agent { return { aborted: false }; } - private async executeStreamingChat(): Promise { + private handleStreamChunk( + chunk: StreamChunk, + state: StreamState, + stagedMessages: ModelMessage[], + assistantParts: AssistantContentPart[], + systemPrompt: string + ): void { + if (chunk.type === "text-delta") { + appendAssistantText(assistantParts, chunk.text); + } + + if (chunk.type === "tool-call") { + appendAssistantToolCall(assistantParts, chunk); + } + + if (chunk.type === "tool-result") { + flushAssistantMessage(stagedMessages, assistantParts); + const toolMessage: ModelMessage = { + role: "tool", + content: [chunk], + }; + stagedMessages.push(toolMessage); + this.scheduleContextMeasurement({ + debugLabel: "after tool", + messages: stagedMessages, + systemPrompt, + }); + } + + handleReasoningDelta(chunk, state); + handleTextDelta(chunk, state); + handleToolCall(chunk, state); + } + + private async runStreamingStep( + systemPrompt: string, + messages: ModelMessage[] + ): Promise { const result = streamText({ model: this.model, - system: SYSTEM_PROMPT, - messages: this.conversation, + system: systemPrompt, + 
messages, tools: agentTools, stopWhen: stepCountIs(this.maxSteps), abortSignal: this.abortController?.signal, @@ -238,8 +499,12 @@ export class Agent { const state: StreamState = { hasStartedText: false, hasStartedReasoning: false, + sawTextDelta: false, }; + const stagedMessages: ModelMessage[] = [...messages]; + const assistantParts: AssistantContentPart[] = []; + let chunkCount = 0; const debug = env.DEBUG_CHUNK_LOG; @@ -260,9 +525,13 @@ export class Agent { logDebugFinish(chunk); } - handleReasoningDelta(chunk, state); - handleTextDelta(chunk, state); - handleToolCall(chunk, state); + this.handleStreamChunk( + chunk, + state, + stagedMessages, + assistantParts, + systemPrompt + ); } } catch (error) { if (error instanceof Error && error.name === "AbortError") { @@ -272,6 +541,7 @@ export class Agent { } } + flushAssistantMessage(stagedMessages, assistantParts); endReasoningIfNeeded(state); endTextIfNeeded(state); @@ -284,26 +554,46 @@ export class Agent { const response = await result.response; - // Update context tracker with usage information + // Update context tracker with usage information (fallback for estimation). const totalUsage = await result.totalUsage; if (totalUsage) { this.contextTracker.updateUsage(totalUsage); - - if (debug) { - const stats = this.contextTracker.getStats(); - console.log( - colorize( - "dim", - `[Context] ${stats.totalTokens.toLocaleString()} / ${stats.maxContextTokens.toLocaleString()} tokens (${(stats.usagePercentage * 100).toFixed(1)}%)` - ) - ); - } } if (debug) { console.log(`[DEBUG] Total chunks: ${chunkCount}`); console.log(`[DEBUG] Response messages: ${response.messages.length}`); } - this.conversation.push(...response.messages); + if (!state.sawTextDelta) { + const fallbackText = extractAssistantText(response.messages); + if (fallbackText) { + printAIPrefix(); + printChunk(fallbackText); + printNewline(); + } + } + + return response.messages; + } + + private async executeStreamingChat( + systemPrompt: string, + messages: ModelMessage[] + ): Promise { + let currentMessages = messages; + + for (let attempt = 0; attempt < MAX_TOOL_FOLLOWUPS; attempt += 1) { + const responseMessages = await this.runStreamingStep( + systemPrompt, + currentMessages + ); + this.conversation.push(...responseMessages); + + if (!shouldContinueAfterTools(responseMessages)) { + return; + } + + currentMessages = this.conversation; + } } } diff --git a/src/commands/index.ts b/src/commands/index.ts index f134195..5427bea 100644 --- a/src/commands/index.ts +++ b/src/commands/index.ts @@ -1,14 +1,7 @@ import type { Interface as ReadlineInterface } from "node:readline"; -import type { - LanguageModel, - ModelMessage, - ToolModelMessage, - ToolResultPart, -} from "ai"; +import type { LanguageModel } from "ai"; import type { Agent } from "../agent"; -import { env } from "../env"; import { SYSTEM_PROMPT } from "../prompts/system"; -import { tools } from "../tools/index"; import { colorize } from "../utils/colors"; import { deleteConversation, @@ -17,139 +10,14 @@ import { saveConversation, } from "../utils/conversation-store"; import { selectModel } from "../utils/model-selector"; +import { + convertToRenderAPIMessages, + fetchRenderedText, +} from "../utils/render-api"; -interface OpenAITool { - type: "function"; - function: { - name: string; - description: string; - parameters: Record; - }; -} - -interface SchemaWithToJSON { - toJSONSchema: () => Record; -} - -function convertToolsToOpenAIFormat(): OpenAITool[] { - return Object.entries(tools).map(([name, tool]) => { - const 
schema = tool.inputSchema as unknown as SchemaWithToJSON; - return { - type: "function" as const, - function: { - name, - description: tool.description ?? "", - parameters: schema.toJSONSchema(), - }, - }; - }); -} - -interface RenderAPIMessage { - role: "system" | "user" | "assistant" | "tool"; - content: string | null; - name?: string; - tool_calls?: Array<{ - id: string; - type: "function"; - function: { name: string; arguments: string }; - }>; - tool_call_id?: string; -} - -function extractTextContent( - parts: Array<{ type: string; text?: string }> -): string { - return parts - .filter((p) => p.type === "text") - .map((p) => p.text ?? "") - .join(""); -} - -function determineAssistantContent( - textParts: Array<{ type: string; text?: string }>, - hasToolCalls: boolean -): string | null { - if (textParts.length > 0) { - return extractTextContent(textParts); - } - if (hasToolCalls) { - return null; - } - return ""; -} - -function convertUserMessage(msg: ModelMessage): RenderAPIMessage { - const content = Array.isArray(msg.content) - ? extractTextContent(msg.content) - : msg.content; - return { role: "user", content }; -} - -function convertAssistantMessage(msg: ModelMessage): RenderAPIMessage { - const contentArray = Array.isArray(msg.content) ? msg.content : []; - const textParts = contentArray.filter((p) => p.type === "text"); - const toolCallParts = contentArray.filter((p) => p.type === "tool-call"); - - const content = determineAssistantContent( - textParts, - toolCallParts.length > 0 - ); - const assistantMsg: RenderAPIMessage = { role: "assistant", content }; - - if (toolCallParts.length > 0) { - assistantMsg.tool_calls = toolCallParts.map((tc) => ({ - id: tc.toolCallId, - type: "function" as const, - function: { - name: tc.toolName, - arguments: JSON.stringify(tc.input), - }, - })); - } - - return assistantMsg; -} - -function convertToolMessages(msg: ToolModelMessage): RenderAPIMessage[] { - const results: RenderAPIMessage[] = []; - for (const part of msg.content) { - if (part.type === "tool-result") { - const resultPart = part as ToolResultPart; - const content = - typeof resultPart.output === "string" - ? 
resultPart.output - : JSON.stringify(resultPart.output); - results.push({ - role: "tool", - content, - tool_call_id: resultPart.toolCallId, - }); - } - } - return results; -} - -function convertToRenderAPIMessages( - messages: ModelMessage[], - systemPrompt: string -): RenderAPIMessage[] { - const result: RenderAPIMessage[] = [ - { role: "system", content: systemPrompt }, - ]; - - for (const msg of messages) { - if (msg.role === "user") { - result.push(convertUserMessage(msg)); - } else if (msg.role === "assistant") { - result.push(convertAssistantMessage(msg)); - } else if (msg.role === "tool") { - result.push(...convertToolMessages(msg as ToolModelMessage)); - } - } - - return result; -} +const apiErrorHandler = (message: string): void => { + console.log(colorize("red", message)); +}; export interface CommandContext { agent: Agent; @@ -191,9 +59,13 @@ function handleHelp(_args: string[], ctx: CommandContext): CommandResult { return { conversationId: ctx.currentConversationId }; } -function handleClear(_args: string[], ctx: CommandContext): CommandResult { +async function handleClear( + _args: string[], + ctx: CommandContext +): Promise { ctx.agent.clearConversation(); console.log(colorize("green", "Conversation cleared.")); + await ctx.agent.refreshContextTokens({ onError: apiErrorHandler }); return { conversationId: undefined }; } @@ -226,6 +98,7 @@ async function handleLoad( return { conversationId: ctx.currentConversationId }; } ctx.agent.loadConversation(stored.messages); + await ctx.agent.refreshContextTokens({ onError: apiErrorHandler }); console.log( colorize( "green", @@ -291,6 +164,7 @@ async function handleModels( if (selection) { ctx.setModel(selection.model, selection.modelId); console.log(colorize("green", `Model changed to: ${selection.modelId}`)); + await ctx.agent.refreshContextTokens({ onError: apiErrorHandler }); } return { conversationId: ctx.currentConversationId }; @@ -309,30 +183,17 @@ async function handleRender( const apiMessages = convertToRenderAPIMessages(messages, SYSTEM_PROMPT); try { - const response = await fetch( - "https://api.friendli.ai/serverless/v1/chat/render", - { - method: "POST", - headers: { - Authorization: `Bearer ${env.FRIENDLI_TOKEN}`, - "Content-Type": "application/json", - }, - body: JSON.stringify({ - model: ctx.currentModelId, - messages: apiMessages, - }), - } + const renderedText = await fetchRenderedText( + apiMessages, + ctx.currentModelId, + false, + { onError: apiErrorHandler } ); - - if (!response.ok) { - const error = await response.text(); - console.log(colorize("red", `Render failed: ${error}`)); + if (renderedText === null) { return { conversationId: ctx.currentConversationId }; } - - const data = (await response.json()) as { text: string }; console.log(colorize("cyan", "=== Rendered Prompt ===")); - console.log(data.text); + console.log(renderedText); console.log(colorize("cyan", "=======================")); } catch (error) { console.log(colorize("red", `Error: ${error}`)); @@ -341,71 +202,6 @@ async function handleRender( return { conversationId: ctx.currentConversationId }; } -async function fetchRenderedText( - messages: RenderAPIMessage[], - modelId: string, - includeTools = false -): Promise { - const body: Record = { - model: modelId, - messages, - }; - - if (includeTools) { - body.tools = convertToolsToOpenAIFormat(); - } - - const response = await fetch( - "https://api.friendli.ai/serverless/v1/chat/render", - { - method: "POST", - headers: { - Authorization: `Bearer ${env.FRIENDLI_TOKEN}`, - "Content-Type": 
"application/json", - }, - body: JSON.stringify(body), - } - ); - - if (!response.ok) { - const error = await response.text(); - console.log(colorize("red", `Render API failed: ${error}`)); - return null; - } - - const data = (await response.json()) as { text: string }; - return data.text; -} - -async function fetchTokenCount( - text: string, - modelId: string -): Promise { - const response = await fetch( - "https://api.friendli.ai/serverless/v1/tokenize", - { - method: "POST", - headers: { - Authorization: `Bearer ${env.FRIENDLI_TOKEN}`, - "Content-Type": "application/json", - }, - body: JSON.stringify({ - model: modelId, - prompt: text, - }), - } - ); - - if (!response.ok) { - const error = await response.text(); - console.log(colorize("red", `Tokenize API failed: ${error}`)); - return null; - } - - const data = (await response.json()) as { tokens: number[] }; - return data.tokens.length; -} - type ColorName = "blue" | "yellow" | "green" | "cyan" | "red" | "dim" | "reset"; function getProgressBarColor(percentage: number): ColorName { @@ -457,12 +253,14 @@ async function handleCompact( await ctx.agent.compactContext(); console.log(colorize("green", "✓ Conversation compacted successfully.")); + const tokenCount = await ctx.agent.refreshContextTokens({ + onError: apiErrorHandler, + }); const stats = ctx.agent.getContextStats(); + const displayTokens = tokenCount ?? stats.totalTokens; + const label = tokenCount === null ? " New estimated size:" : " New size:"; console.log( - colorize( - "dim", - ` New estimated size: ${stats.totalTokens.toLocaleString()} tokens` - ) + colorize("dim", `${label} ${displayTokens.toLocaleString()} tokens`) ); } catch (error) { console.log(colorize("red", `Compaction failed: ${error}`)); @@ -476,7 +274,6 @@ async function handleContext( ctx: CommandContext ): Promise { const messages = ctx.agent.getConversation(); - const apiMessages = convertToRenderAPIMessages(messages, SYSTEM_PROMPT); const isEmptyConversation = messages.length === 0; console.log(colorize("cyan", "=== Context Usage ===")); @@ -487,16 +284,9 @@ async function handleContext( } try { - const renderedText = await fetchRenderedText( - apiMessages, - ctx.currentModelId, - true - ); - if (renderedText === null) { - return { conversationId: ctx.currentConversationId }; - } - - const tokenCount = await fetchTokenCount(renderedText, ctx.currentModelId); + const tokenCount = await ctx.agent.refreshContextTokens({ + onError: apiErrorHandler, + }); if (tokenCount === null) { return { conversationId: ctx.currentConversationId }; } @@ -504,7 +294,7 @@ async function handleContext( const stats = ctx.agent.getContextStats(); const maxContextTokens = stats.maxContextTokens; const usagePercentage = tokenCount / maxContextTokens; - const compactionThreshold = 0.75; + const { compactionThreshold } = ctx.agent.getContextConfig(); const tokenLabel = isEmptyConversation ? 
`Total tokens: ${tokenCount.toLocaleString()} (system prompt + tools)` diff --git a/src/env.ts b/src/env.ts index a77178f..0e00122 100644 --- a/src/env.ts +++ b/src/env.ts @@ -5,6 +5,7 @@ export const env = createEnv({ server: { FRIENDLI_TOKEN: z.string().min(1), DEBUG_CHUNK_LOG: z.stringbool().default(false), + DEBUG_CONTEXT_LOG: z.stringbool().default(false), }, runtimeEnv: process.env, emptyStringAsUndefined: true, diff --git a/src/index.ts b/src/index.ts index 219d731..3e334a5 100644 --- a/src/index.ts +++ b/src/index.ts @@ -16,12 +16,15 @@ const friendli = createFriendli({ }); let currentModelId = DEFAULT_MODEL_ID; -const agent = new Agent(wrapModel(friendli(currentModelId))); +const agent = new Agent(wrapModel(friendli(currentModelId)), { + modelId: currentModelId, +}); let currentConversationId: string | undefined; const rl = createInterface({ input: process.stdin, output: process.stdout, + terminal: false, }); emitKeypressEvents(process.stdin); @@ -40,6 +43,7 @@ function setupEscHandler(): void { function getUserInput(): Promise { return new Promise((resolve) => { + rl.resume(); printYou(); if (process.stdin.isTTY) { @@ -48,11 +52,13 @@ function getUserInput(): Promise { const onLine = (line: string) => { rl.removeListener("close", onClose); + rl.pause(); resolve(line); }; const onClose = () => { rl.removeListener("line", onLine); + rl.pause(); resolve(null); }; @@ -67,7 +73,7 @@ function exitProgram(): void { } function setModel(model: LanguageModel, modelId: string): void { - agent.setModel(wrapModel(model)); + agent.setModel(wrapModel(model), modelId); currentModelId = modelId; } diff --git a/src/prompts/system.ts b/src/prompts/system.ts index 8bf66b2..2b42b9e 100644 --- a/src/prompts/system.ts +++ b/src/prompts/system.ts @@ -5,7 +5,11 @@ You have access to the following tools: - **read_file**: Read the contents of a file - **list_files**: List files and directories (respects .gitignore) - **edit_file**: Edit files by replacing text, or create new files +- **write_file**: Create or overwrite files +- **delete_file**: Delete files - **run_command**: Execute safe shell commands +- **glob**: Find files by glob pattern +- **grep**: Search file contents quickly ## Guidelines diff --git a/src/utils/context-compactor.ts b/src/utils/context-compactor.ts index 676c782..b21f25e 100644 --- a/src/utils/context-compactor.ts +++ b/src/utils/context-compactor.ts @@ -2,14 +2,32 @@ import type { LanguageModel, ModelMessage } from "ai"; import { generateText } from "ai"; import { colorize } from "./colors"; -const COMPACTION_SYSTEM_PROMPT = `You are a conversation summarizer. Your task is to create a concise summary of the conversation history that preserves: -1. Key decisions made -2. Important code changes or file modifications -3. Current task context and goals -4. Any errors encountered and their resolutions +const COMPACTION_SYSTEM_PROMPT = `You compress a long coding-agent conversation into a single context block that will replace old messages. +Preserve ONLY information needed to continue the task correctly. -Output a summary that can serve as context for continuing the conversation. -Be concise but preserve essential information. Format as a brief narrative.`; +Hard rules: +- Keep exact identifiers: file paths, function/class names, CLI commands, flags, config keys, URLs, branch names. +- Keep exact error messages or their essential lines (do NOT paraphrase them away). +- Keep decisions that constrain future actions. 
+- If information is uncertain, put it in OpenQuestions instead of stating as fact. +- Do not include chit-chat or redundant deliberation. +- Output must follow the exact format below. + +OUTPUT FORMAT (exact): +[COMPRESSED_CONTEXT] +Goal: +Constraints: +Repo/Env: +Current state: +Key decisions: +Work completed: +Files touched: +- : +Current errors / failing tests: +- : +OpenQuestions: +NextSteps: +[/COMPRESSED_CONTEXT]`; export interface CompactionResult { messages: ModelMessage[]; @@ -19,13 +37,13 @@ export interface CompactionResult { } export interface CompactionConfig { - keepRecentMessages: number; // Number of recent messages to preserve + keepRecentMessages: number; maxSummaryTokens: number; } const DEFAULT_COMPACTION_CONFIG: CompactionConfig = { - keepRecentMessages: 6, // Keep last 3 exchanges (user + assistant pairs) - maxSummaryTokens: 2000, + keepRecentMessages: 8, + maxSummaryTokens: 1600, }; interface ContentPart { @@ -33,16 +51,41 @@ interface ContentPart { text?: string; toolName?: string; output?: unknown; + input?: Record; +} + +const SUMMARY_TAG = "[COMPRESSED_CONTEXT]"; +const SUMMARY_TAG_END = "[/COMPRESSED_CONTEXT]"; +const MAX_TOOL_PREVIEW_CHARS = 1600; +const TOOL_PREVIEW_HEAD_LINES = 8; +const TOOL_PREVIEW_TAIL_LINES = 4; + +function isCompressedSummaryMessage(message: ModelMessage): boolean { + return ( + typeof message.content === "string" && + message.content.includes(SUMMARY_TAG) && + message.content.includes(SUMMARY_TAG_END) + ); } function getToolResultPreview(output: unknown): string { - if (typeof output === "string") { - return output.slice(0, 200); + const text = + typeof output === "string" ? output : JSON.stringify(output, null, 2); + + if (text.length <= MAX_TOOL_PREVIEW_CHARS) { + return text; } - if (output != null) { - return JSON.stringify(output).slice(0, 200); + + const lines = text.split("\n"); + if (lines.length <= TOOL_PREVIEW_HEAD_LINES + TOOL_PREVIEW_TAIL_LINES) { + return text.slice(0, MAX_TOOL_PREVIEW_CHARS); } - return ""; + + const head = lines.slice(0, TOOL_PREVIEW_HEAD_LINES).join("\n"); + const tail = lines + .slice(-TOOL_PREVIEW_TAIL_LINES) + .join("\n"); + return `${head}\n...\n${tail}`; } function formatContentPart(part: ContentPart): string { @@ -50,11 +93,16 @@ function formatContentPart(part: ContentPart): string { return part.text; } if (part.type === "tool-call" && part.toolName) { - return `[Tool Call: ${part.toolName}]`; + const input = part.input ? JSON.stringify(part.input) : ""; + const trimmedInput = + input.length > 400 ? `${input.slice(0, 400)}…` : input; + return trimmedInput + ? 
`[Tool Call: ${part.toolName} ${trimmedInput}]` + : `[Tool Call: ${part.toolName}]`; } if (part.type === "tool-result") { const preview = getToolResultPreview(part.output); - return `[Tool Result: ${preview}...]`; + return `[Tool Result]\n${preview}`; } return ""; } @@ -80,16 +128,52 @@ function formatMessage(msg: ModelMessage): string | null { return null; } -/** - * Formats messages for summarization - */ function formatMessagesForSummary(messages: ModelMessage[]): string { return messages.map(formatMessage).filter(Boolean).join("\n\n"); } -/** - * Compacts conversation history by summarizing older messages - */ +function hasValidSummaryFormat(text: string): boolean { + const requiredSections = [ + SUMMARY_TAG, + "Goal:", + "Constraints:", + "Repo/Env:", + "Current state:", + "Key decisions:", + "Work completed:", + "Files touched:", + "Current errors / failing tests:", + "OpenQuestions:", + "NextSteps:", + SUMMARY_TAG_END, + ]; + return requiredSections.every((section) => text.includes(section)); +} + +async function generateSummary( + model: LanguageModel, + previousSummary: string | null, + conversationText: string, + maxSummaryTokens: number +): Promise { + const previousText = previousSummary ?? "NONE"; + const prompt = `Previous compressed context (if any):\n${previousText}\n\nNew conversation to incorporate:\n${conversationText}`; + const result = await generateText({ + model, + system: COMPACTION_SYSTEM_PROMPT, + prompt, + maxOutputTokens: maxSummaryTokens, + }); + return result.text.trim(); +} + +function buildSummaryMessage(summary: string): ModelMessage { + return { + role: "user", + content: summary, + }; +} + export async function compactConversation( model: LanguageModel, messages: ModelMessage[], @@ -100,19 +184,27 @@ export async function compactConversation( ...config, }; - // If not enough messages to compact, return as-is - if (messages.length <= keepRecentMessages) { + const summaryMessages = messages.filter(isCompressedSummaryMessage); + const previousSummary = summaryMessages.at(0)?.content; + const filteredMessages = messages.filter( + (message) => !isCompressedSummaryMessage(message) + ); + + if (filteredMessages.length <= keepRecentMessages) { + const preserved = previousSummary + ? [buildSummaryMessage(String(previousSummary)), ...filteredMessages] + : [...filteredMessages]; + return { - messages, + messages: preserved, originalMessageCount: messages.length, - compactedMessageCount: messages.length, - summary: "", + compactedMessageCount: preserved.length, + summary: previousSummary ? String(previousSummary) : "", }; } - // Split messages: older ones to summarize, recent ones to keep - const messagesToSummarize = messages.slice(0, -keepRecentMessages); - const recentMessages = messages.slice(-keepRecentMessages); + const messagesToSummarize = filteredMessages.slice(0, -keepRecentMessages); + const recentMessages = filteredMessages.slice(-keepRecentMessages); console.log( colorize( @@ -121,30 +213,35 @@ export async function compactConversation( ) ); - // Format older messages for summarization const conversationText = formatMessagesForSummary(messagesToSummarize); try { - // Generate summary using the same model - const result = await generateText({ + let summary = await generateSummary( model, - system: COMPACTION_SYSTEM_PROMPT, - prompt: `Please summarize the following conversation history:\n\n${conversationText}`, - maxOutputTokens: maxSummaryTokens, - }); + typeof previousSummary === "string" ? 
+      typeof previousSummary === "string" ? previousSummary : null,
+      conversationText,
+      maxSummaryTokens
+    );
 
-    const summary = result.text;
+    if (!hasValidSummaryFormat(summary)) {
+      const previousText =
+        typeof previousSummary === "string" ? previousSummary : "NONE";
+      const retryPrompt = `Previous compressed context (if any):\n${previousText}\n\nNew conversation to incorporate:\n${conversationText}\n\nYour output did not match the required format. Please retry with the exact format only.`;
+      const retryResult = await generateText({
+        model,
+        system: COMPACTION_SYSTEM_PROMPT,
+        prompt: retryPrompt,
+        maxOutputTokens: maxSummaryTokens,
+      });
+      summary = retryResult.text.trim();
+    }
 
-    // Create a new message array with the summary as context
-    const summaryMessage: ModelMessage = {
-      role: "user",
-      content: `[Previous conversation summary]\n${summary}\n\n[Continuing conversation...]`,
-    };
+    if (!hasValidSummaryFormat(summary)) {
+      throw new Error("Compaction summary format invalid after retry");
+    }
 
-    const compactedMessages: ModelMessage[] = [
-      summaryMessage,
-      ...recentMessages,
-    ];
+    const summaryMessage = buildSummaryMessage(summary);
+    const compactedMessages: ModelMessage[] = [summaryMessage, ...recentMessages];
 
     console.log(
       colorize(
@@ -167,11 +264,14 @@
       )
     );
 
-    // On failure, just truncate old messages without summary
+    const fallbackMessages = previousSummary
+      ? [buildSummaryMessage(String(previousSummary)), ...recentMessages]
+      : recentMessages;
+
     return {
-      messages: recentMessages,
+      messages: fallbackMessages,
       originalMessageCount: messages.length,
-      compactedMessageCount: recentMessages.length,
+      compactedMessageCount: fallbackMessages.length,
       summary: "",
     };
   }
diff --git a/src/utils/context-tracker.ts b/src/utils/context-tracker.ts
index f5bac4f..17a47d5 100644
--- a/src/utils/context-tracker.ts
+++ b/src/utils/context-tracker.ts
@@ -16,7 +16,7 @@ export interface ContextStats {
 
 const DEFAULT_CONFIG: ContextConfig = {
   maxContextTokens: 128_000, // Default for most modern models
-  compactionThreshold: 0.75, // Compact when 75% of context is used
+  compactionThreshold: 0.85, // Compact when 85% of context is used
 };
 
 export class ContextTracker {
@@ -24,6 +24,7 @@
   private totalInputTokens = 0;
   private totalOutputTokens = 0;
   private stepCount = 0;
+  private currentContextTokens: number | null = null;
 
   constructor(config: Partial<ContextConfig> = {}) {
     this.config = { ...DEFAULT_CONFIG, ...config };
@@ -46,6 +47,13 @@
     this.stepCount++;
   }
 
+  /**
+   * Set the exact current context token count.
+   */
+  setContextTokens(tokens: number): void {
+    this.currentContextTokens = Math.max(0, Math.round(tokens));
+  }
+
   /**
    * Set total usage directly (useful after compaction or when loading state)
    */
@@ -67,7 +75,8 @@
   }
 
   getStats(): ContextStats {
-    const totalTokens = this.totalInputTokens + this.totalOutputTokens;
+    const totalTokens =
+      this.currentContextTokens ?? this.getEstimatedContextTokens();
     const usagePercentage = totalTokens / this.config.maxContextTokens;
     const shouldCompact = usagePercentage >= this.config.compactionThreshold;
 
@@ -89,6 +98,7 @@
     this.totalInputTokens = 0;
     this.totalOutputTokens = 0;
     this.stepCount = 0;
+    this.currentContextTokens = 0;
   }
 
   /**
@@ -99,6 +109,7 @@
     this.totalInputTokens = newInputTokens;
     this.totalOutputTokens = 0;
     this.stepCount = 1;
+    this.currentContextTokens = Math.max(0, Math.round(newInputTokens));
   }
 
   getConfig(): ContextConfig {
diff --git a/src/utils/render-api.ts b/src/utils/render-api.ts
new file mode 100644
index 0000000..9cf4a1c
--- /dev/null
+++ b/src/utils/render-api.ts
@@ -0,0 +1,228 @@
+import type { ModelMessage, ToolModelMessage, ToolResultPart } from "ai";
+import { env } from "../env";
+import { tools } from "../tools/index";
+
+export interface OpenAITool {
+  type: "function";
+  function: {
+    name: string;
+    description: string;
+    parameters: Record<string, unknown>;
+  };
+}
+
+interface SchemaWithToJSON {
+  toJSONSchema: () => Record<string, unknown>;
+}
+
+export interface RenderAPIMessage {
+  role: "system" | "user" | "assistant" | "tool";
+  content: string | null;
+  name?: string;
+  tool_calls?: Array<{
+    id: string;
+    type: "function";
+    function: { name: string; arguments: string };
+  }>;
+  tool_call_id?: string;
+}
+
+export interface RenderApiOptions {
+  onError?: (message: string) => void;
+}
+
+const extractTextContent = (
+  parts: Array<{ type: string; text?: string }>
+): string => {
+  return parts
+    .filter((part) => part.type === "text")
+    .map((part) => part.text ?? "")
+    .join("");
+};
+
+const determineAssistantContent = (
+  textParts: Array<{ type: string; text?: string }>,
+  hasToolCalls: boolean
+): string | null => {
+  if (textParts.length > 0) {
+    return extractTextContent(textParts);
+  }
+  if (hasToolCalls) {
+    return null;
+  }
+  return "";
+};
+
+const convertUserMessage = (msg: ModelMessage): RenderAPIMessage => {
+  const content = Array.isArray(msg.content)
+    ? extractTextContent(msg.content)
+    : msg.content;
+  return { role: "user", content };
+};
+
+const convertAssistantMessage = (msg: ModelMessage): RenderAPIMessage => {
+  const contentArray = Array.isArray(msg.content) ? msg.content : [];
+  const textParts = contentArray.filter((part) => part.type === "text");
+  const toolCallParts = contentArray.filter(
+    (part) => part.type === "tool-call"
+  );
+
+  const content = determineAssistantContent(
+    textParts,
+    toolCallParts.length > 0
+  );
+  const assistantMsg: RenderAPIMessage = { role: "assistant", content };
+
+  if (toolCallParts.length > 0) {
+    assistantMsg.tool_calls = toolCallParts.map((toolCall) => ({
+      id: toolCall.toolCallId,
+      type: "function" as const,
+      function: {
+        name: toolCall.toolName,
+        arguments: JSON.stringify(toolCall.input),
+      },
+    }));
+  }
+
+  return assistantMsg;
+};
+
+const convertToolMessages = (msg: ToolModelMessage): RenderAPIMessage[] => {
+  const results: RenderAPIMessage[] = [];
+  for (const part of msg.content) {
+    if (part.type === "tool-result") {
+      const resultPart = part as ToolResultPart;
+      const content =
+        typeof resultPart.output === "string"
+          ? resultPart.output
+          : JSON.stringify(resultPart.output);
+      results.push({
+        role: "tool",
+        content,
+        tool_call_id: resultPart.toolCallId,
+      });
+    }
+  }
+  return results;
+};
+
+export const convertToRenderAPIMessages = (
+  messages: ModelMessage[],
+  systemPrompt: string
+): RenderAPIMessage[] => {
+  const result: RenderAPIMessage[] = [
+    { role: "system", content: systemPrompt },
+  ];
+
+  for (const msg of messages) {
+    if (msg.role === "user") {
+      result.push(convertUserMessage(msg));
+    } else if (msg.role === "assistant") {
+      result.push(convertAssistantMessage(msg));
+    } else if (msg.role === "tool") {
+      result.push(...convertToolMessages(msg as ToolModelMessage));
+    }
+  }
+
+  return result;
+};
+
+export const convertToolsToOpenAIFormat = (): OpenAITool[] => {
+  return Object.entries(tools).map(([name, tool]) => {
+    const schema = tool.inputSchema as unknown as SchemaWithToJSON;
+    return {
+      type: "function" as const,
+      function: {
+        name,
+        description: tool.description ?? "",
+        parameters: schema.toJSONSchema(),
+      },
+    };
+  });
+};
+
+export const fetchRenderedText = async (
+  messages: RenderAPIMessage[],
+  modelId: string,
+  includeTools = false,
+  options: RenderApiOptions = {}
+): Promise<string | null> => {
+  const body: Record<string, unknown> = {
+    model: modelId,
+    messages,
+  };
+
+  if (includeTools) {
+    body.tools = convertToolsToOpenAIFormat();
+  }
+
+  const response = await fetch(
+    "https://api.friendli.ai/serverless/v1/chat/render",
+    {
+      method: "POST",
+      headers: {
+        Authorization: `Bearer ${env.FRIENDLI_TOKEN}`,
+        "Content-Type": "application/json",
+      },
+      body: JSON.stringify(body),
+    }
+  );
+
+  if (!response.ok) {
+    const error = await response.text();
+    options.onError?.(`Render API failed: ${error}`);
+    return null;
+  }
+
+  const data = (await response.json()) as { text: string };
+  return data.text;
+};
+
+export const fetchTokenCount = async (
+  text: string,
+  modelId: string,
+  options: RenderApiOptions = {}
+): Promise<number | null> => {
+  const response = await fetch(
+    "https://api.friendli.ai/serverless/v1/tokenize",
+    {
+      method: "POST",
+      headers: {
+        Authorization: `Bearer ${env.FRIENDLI_TOKEN}`,
+        "Content-Type": "application/json",
+      },
+      body: JSON.stringify({
+        model: modelId,
+        prompt: text,
+      }),
+    }
+  );
+
+  if (!response.ok) {
+    const error = await response.text();
+    options.onError?.(`Tokenize API failed: ${error}`);
+    return null;
+  }
+
+  const data = (await response.json()) as { tokens: number[] };
+  return data.tokens.length;
+};
+
+export const measureContextTokens = async (
+  messages: ModelMessage[],
+  modelId: string,
+  systemPrompt: string,
+  options: RenderApiOptions = {}
+): Promise<number | null> => {
+  const apiMessages = convertToRenderAPIMessages(messages, systemPrompt);
+  const renderedText = await fetchRenderedText(
+    apiMessages,
+    modelId,
+    true,
+    options
+  );
+
+  if (renderedText === null) {
+    return null;
+  }
+
+  return fetchTokenCount(renderedText, modelId, options);
+};
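
Usage sketch (illustrative, not part of the diff above): render-api.ts together with ContextTracker.setContextTokens appears intended to replace the usage-based token estimate with an exact count of the rendered conversation. A minimal sketch of that wiring, assuming the modules are imported from the paths added in this series; the helper name syncMeasuredContext, its parameters, and the call site are hypothetical and not identifiers from the patch.

// Illustrative only: names other than the imported ones are hypothetical.
import type { ModelMessage } from "ai";
import { ContextTracker } from "./utils/context-tracker";
import { measureContextTokens } from "./utils/render-api";

async function syncMeasuredContext(
  tracker: ContextTracker,
  conversation: ModelMessage[],
  modelId: string,
  systemPrompt: string
): Promise<void> {
  // Ask the render + tokenize endpoints for an exact token count of the
  // rendered conversation; when the API is unavailable, keep the tracker's
  // own usage-based estimate untouched.
  const measured = await measureContextTokens(conversation, modelId, systemPrompt, {
    onError: (message) => console.warn(message),
  });
  if (measured !== null) {
    tracker.setContextTokens(measured);
  }
}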