Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bun.lock
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
"lockfileVersion": 1,
"configVersion": 0,
"workspaces": {
"": {
"name": "mux",
Expand Down
2 changes: 1 addition & 1 deletion src/browser/stores/WorkspaceStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ export class WorkspaceStore {
if (msg.metadata?.compacted) {
continue;
}
const rawUsage = msg.metadata?.contextUsage ?? msg.metadata?.usage;
const rawUsage = msg.metadata?.contextUsage;
const providerMeta =
msg.metadata?.contextProviderMetadata ?? msg.metadata?.providerMetadata;
if (rawUsage) {
Expand Down
1 change: 1 addition & 0 deletions src/browser/stories/mockFactory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ export function createAssistantMessage(
timestamp: opts.timestamp ?? STABLE_TIMESTAMP,
model: opts.model ?? DEFAULT_MODEL,
usage: { inputTokens: 100, outputTokens: 50, totalTokens: 150 },
contextUsage: { inputTokens: 100, outputTokens: 50, totalTokens: 150 },
duration: 1000,
},
};
Expand Down
2 changes: 2 additions & 0 deletions src/common/orpc/schemas/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ export const MuxMessageSchema = z.object({
timestamp: z.number().optional(),
model: z.string().optional(),
usage: z.any().optional(),
contextUsage: z.any().optional(),
providerMetadata: z.record(z.string(), z.unknown()).optional(),
contextProviderMetadata: z.record(z.string(), z.unknown()).optional(),
duration: z.number().optional(),
systemMessageTokens: z.number().optional(),
muxMetadata: z.any().optional(),
Expand Down
63 changes: 43 additions & 20 deletions src/node/services/streamManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -334,22 +334,42 @@ export class StreamManager extends EventEmitter {
private async getStreamMetadata(
streamInfo: WorkspaceStreamInfo,
timeoutMs = 1000
): Promise<{ usage?: LanguageModelV2Usage; duration: number }> {
let usage = undefined;
): Promise<{
totalUsage?: LanguageModelV2Usage;
contextUsage?: LanguageModelV2Usage;
contextProviderMetadata?: Record<string, unknown>;
duration: number;
}> {
let totalUsage: LanguageModelV2Usage | undefined;
let contextUsage: LanguageModelV2Usage | undefined;
let contextProviderMetadata: Record<string, unknown> | undefined;

try {
// Race usage retrieval against timeout to prevent hanging on abort
// CRITICAL: Use totalUsage (sum of all steps) not usage (last step only)
// For multi-step tool calls, usage would severely undercount actual token consumption
usage = await Promise.race([
streamInfo.streamResult.totalUsage,
new Promise<undefined>((resolve) => setTimeout(() => resolve(undefined), timeoutMs)),
// Fetch all metadata in parallel with timeout
// - totalUsage: sum of all steps (for cost calculation)
// - usage: last step only (for context window display)
// - providerMetadata: last step (for context window cache display)
const [total, context, contextMeta] = await Promise.race([
Promise.all([
streamInfo.streamResult.totalUsage,
streamInfo.streamResult.usage,
streamInfo.streamResult.providerMetadata,
]),
new Promise<[undefined, undefined, undefined]>((resolve) =>
setTimeout(() => resolve([undefined, undefined, undefined]), timeoutMs)
),
]);
totalUsage = total;
contextUsage = context;
contextProviderMetadata = contextMeta;
} catch (error) {
log.debug("Could not retrieve usage:", error);
log.debug("Could not retrieve stream metadata:", error);
}

return {
usage,
totalUsage,
contextUsage,
contextProviderMetadata,
duration: Date.now() - streamInfo.startTime,
};
}
Expand Down Expand Up @@ -1071,17 +1091,20 @@ export class StreamManager extends EventEmitter {

// Check if stream completed successfully
if (!streamInfo.abortController.signal.aborted) {
// Get usage, duration, and provider metadata from stream result
// CRITICAL: Use totalUsage (via getStreamMetadata) and aggregated providerMetadata
// to correctly account for all steps in multi-tool-call conversations
const { usage, duration } = await this.getStreamMetadata(streamInfo);
// Get all metadata from stream result in one call
// - totalUsage: sum of all steps (for cost calculation)
// - contextUsage: last step only (for context window display)
// - contextProviderMetadata: last step (for context window cache tokens)
// Falls back to tracked values from finish-step if streamResult fails/times out
const streamMeta = await this.getStreamMetadata(streamInfo);
const totalUsage = streamMeta.totalUsage;
const contextUsage = streamMeta.contextUsage ?? streamInfo.lastStepUsage;
const contextProviderMetadata =
streamMeta.contextProviderMetadata ?? streamInfo.lastStepProviderMetadata;
const duration = streamMeta.duration;
// Aggregated provider metadata across all steps (for cost calculation with cache tokens)
const providerMetadata = await this.getAggregatedProviderMetadata(streamInfo);

// For context window display, use last step's usage (inputTokens = current context size)
// This is stored in streamInfo during finish-step handling
const contextUsage = streamInfo.lastStepUsage;
const contextProviderMetadata = streamInfo.lastStepProviderMetadata;

// Emit stream end event with parts preserved in temporal order
const streamEndEvent: StreamEndEvent = {
type: "stream-end",
Expand All @@ -1090,7 +1113,7 @@ export class StreamManager extends EventEmitter {
metadata: {
...streamInfo.initialMetadata, // AIService-provided metadata (systemMessageTokens, etc)
model: streamInfo.model,
usage, // Total across all steps (for cost calculation)
usage: totalUsage, // Total across all steps (for cost calculation)
contextUsage, // Last step only (for context window display)
providerMetadata, // Aggregated (for cost calculation)
contextProviderMetadata, // Last step (for context window display)
Expand Down