Skip to content

Commit 30a23ec

Browse files
authored
⚡️ perf: enable smoothing for all providers (lobehub#7062)
* update * 尝试修正 tests * fix tests config * refactor agent-runtime tests * refactor agent-runtime tests * refactor agent-runtime tests * fix tests * fix tests
1 parent 3450714 commit 30a23ec

File tree

30 files changed

+314
-4658
lines changed

30 files changed

+314
-4658
lines changed

package.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@
286286
"@types/unist": "^3.0.3",
287287
"@types/uuid": "^10.0.0",
288288
"@types/ws": "^8.5.13",
289-
"@vitest/coverage-v8": "~1.2.2",
289+
"@vitest/coverage-v8": "^3.0.9",
290290
"ajv-keywords": "^5.1.0",
291291
"commitlint": "^19.6.1",
292292
"consola": "^3.3.3",
@@ -327,7 +327,7 @@
327327
"unified": "^11.0.5",
328328
"unist-util-visit": "^5.0.0",
329329
"vite": "^5.4.11",
330-
"vitest": "~1.2.2",
330+
"vitest": "^3.0.9",
331331
"vitest-canvas-mock": "^0.3.3"
332332
},
333333
"packageManager": "pnpm@9.15.9",

packages/web-crawler/src/utils/htmlToMarkdown.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,6 @@ describe('htmlToMarkdown', () => {
3030
const data = htmlToMarkdown(html, { url: item.url, filterOptions: item.filterOptions || {} });
3131

3232
expect(data).toMatchSnapshot();
33-
});
33+
}, 10000);
3434
});
3535
});

src/database/models/__tests__/knowledgeBase.test.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { and, eq } from 'drizzle-orm/expressions';
33
import { afterEach, beforeEach, describe, expect, it } from 'vitest';
44

55
import { LobeChatDatabase } from '@/database/type';
6+
import { sleep } from '@/utils/sleep';
67

78
import {
89
NewKnowledgeBase,
@@ -93,6 +94,7 @@ describe('KnowledgeBaseModel', () => {
9394
describe('query', () => {
9495
it('should query knowledge bases for the user', async () => {
9596
await knowledgeBaseModel.create({ name: 'Test Group 1' });
97+
await sleep(50);
9698
await knowledgeBaseModel.create({ name: 'Test Group 2' });
9799

98100
const userGroups = await knowledgeBaseModel.query();
Lines changed: 8 additions & 250 deletions
Original file line numberDiff line numberDiff line change
@@ -1,255 +1,13 @@
11
// @vitest-environment node
2-
import OpenAI from 'openai';
3-
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
2+
import { ModelProvider } from '@/libs/agent-runtime';
3+
import { testProvider } from '@/libs/agent-runtime/providerTestUtils';
44

5-
import {
6-
ChatStreamCallbacks,
7-
LobeOpenAICompatibleRuntime,
8-
ModelProvider,
9-
} from '@/libs/agent-runtime';
10-
11-
import * as debugStreamModule from '../utils/debugStream';
125
import { LobeAi21AI } from './index';
136

14-
const provider = ModelProvider.Ai21;
15-
const defaultBaseURL = 'https://api.ai21.com/studio/v1';
16-
17-
const bizErrorType = 'ProviderBizError';
18-
const invalidErrorType = 'InvalidProviderAPIKey';
19-
20-
// Mock the console.error to avoid polluting test output
21-
vi.spyOn(console, 'error').mockImplementation(() => {});
22-
23-
let instance: LobeOpenAICompatibleRuntime;
24-
25-
beforeEach(() => {
26-
instance = new LobeAi21AI({ apiKey: 'test' });
27-
28-
// 使用 vi.spyOn 来模拟 chat.completions.create 方法
29-
vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
30-
new ReadableStream() as any,
31-
);
32-
});
33-
34-
afterEach(() => {
35-
vi.clearAllMocks();
36-
});
37-
38-
describe('LobeAi21AI', () => {
39-
describe('init', () => {
40-
it('should correctly initialize with an API key', async () => {
41-
const instance = new LobeAi21AI({ apiKey: 'test_api_key' });
42-
expect(instance).toBeInstanceOf(LobeAi21AI);
43-
expect(instance.baseURL).toEqual(defaultBaseURL);
44-
});
45-
});
46-
47-
describe('chat', () => {
48-
describe('Error', () => {
49-
it('should return OpenAIBizError with an openai error response when OpenAI.APIError is thrown', async () => {
50-
// Arrange
51-
const apiError = new OpenAI.APIError(
52-
400,
53-
{
54-
status: 400,
55-
error: {
56-
message: 'Bad Request',
57-
},
58-
},
59-
'Error message',
60-
{},
61-
);
62-
63-
vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);
64-
65-
// Act
66-
try {
67-
await instance.chat({
68-
messages: [{ content: 'Hello', role: 'user' }],
69-
model: 'jamba-1.5-mini',
70-
temperature: 0,
71-
});
72-
} catch (e) {
73-
expect(e).toEqual({
74-
endpoint: defaultBaseURL,
75-
error: {
76-
error: { message: 'Bad Request' },
77-
status: 400,
78-
},
79-
errorType: bizErrorType,
80-
provider,
81-
});
82-
}
83-
});
84-
85-
it('should throw AgentRuntimeError with NoOpenAIAPIKey if no apiKey is provided', async () => {
86-
try {
87-
new LobeAi21AI({});
88-
} catch (e) {
89-
expect(e).toEqual({ errorType: invalidErrorType });
90-
}
91-
});
92-
93-
it('should return OpenAIBizError with the cause when OpenAI.APIError is thrown with cause', async () => {
94-
// Arrange
95-
const errorInfo = {
96-
stack: 'abc',
97-
cause: {
98-
message: 'api is undefined',
99-
},
100-
};
101-
const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {});
102-
103-
vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);
104-
105-
// Act
106-
try {
107-
await instance.chat({
108-
messages: [{ content: 'Hello', role: 'user' }],
109-
model: 'jamba-1.5-mini',
110-
temperature: 0,
111-
});
112-
} catch (e) {
113-
expect(e).toEqual({
114-
endpoint: defaultBaseURL,
115-
error: {
116-
cause: { message: 'api is undefined' },
117-
stack: 'abc',
118-
},
119-
errorType: bizErrorType,
120-
provider,
121-
});
122-
}
123-
});
124-
125-
it('should return OpenAIBizError with an cause response with desensitize Url', async () => {
126-
// Arrange
127-
const errorInfo = {
128-
stack: 'abc',
129-
cause: { message: 'api is undefined' },
130-
};
131-
const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {});
132-
133-
instance = new LobeAi21AI({
134-
apiKey: 'test',
135-
136-
baseURL: 'https://api.abc.com/v1',
137-
});
138-
139-
vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);
140-
141-
// Act
142-
try {
143-
await instance.chat({
144-
messages: [{ content: 'Hello', role: 'user' }],
145-
model: 'jamba-1.5-mini',
146-
temperature: 0,
147-
});
148-
} catch (e) {
149-
expect(e).toEqual({
150-
endpoint: 'https://api.***.com/v1',
151-
error: {
152-
cause: { message: 'api is undefined' },
153-
stack: 'abc',
154-
},
155-
errorType: bizErrorType,
156-
provider,
157-
});
158-
}
159-
});
160-
161-
it('should throw an InvalidAi21APIKey error type on 401 status code', async () => {
162-
// Mock the API call to simulate a 401 error
163-
const error = new Error('Unauthorized') as any;
164-
error.status = 401;
165-
vi.mocked(instance['client'].chat.completions.create).mockRejectedValue(error);
166-
167-
try {
168-
await instance.chat({
169-
messages: [{ content: 'Hello', role: 'user' }],
170-
model: 'jamba-1.5-mini',
171-
temperature: 0,
172-
});
173-
} catch (e) {
174-
// Expect the chat method to throw an error with InvalidAi21APIKey
175-
expect(e).toEqual({
176-
endpoint: defaultBaseURL,
177-
error: new Error('Unauthorized'),
178-
errorType: invalidErrorType,
179-
provider,
180-
});
181-
}
182-
});
183-
184-
it('should return AgentRuntimeError for non-OpenAI errors', async () => {
185-
// Arrange
186-
const genericError = new Error('Generic Error');
187-
188-
vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(genericError);
189-
190-
// Act
191-
try {
192-
await instance.chat({
193-
messages: [{ content: 'Hello', role: 'user' }],
194-
model: 'jamba-1.5-mini',
195-
temperature: 0,
196-
});
197-
} catch (e) {
198-
expect(e).toEqual({
199-
endpoint: defaultBaseURL,
200-
errorType: 'AgentRuntimeError',
201-
provider,
202-
error: {
203-
name: genericError.name,
204-
cause: genericError.cause,
205-
message: genericError.message,
206-
stack: genericError.stack,
207-
},
208-
});
209-
}
210-
});
211-
});
212-
213-
describe('DEBUG', () => {
214-
it('should call debugStream and return StreamingTextResponse when DEBUG_AI21_CHAT_COMPLETION is 1', async () => {
215-
// Arrange
216-
const mockProdStream = new ReadableStream() as any; // 模拟的 prod 流
217-
const mockDebugStream = new ReadableStream({
218-
start(controller) {
219-
controller.enqueue('Debug stream content');
220-
controller.close();
221-
},
222-
}) as any;
223-
mockDebugStream.toReadableStream = () => mockDebugStream; // 添加 toReadableStream 方法
224-
225-
// 模拟 chat.completions.create 返回值,包括模拟的 tee 方法
226-
(instance['client'].chat.completions.create as Mock).mockResolvedValue({
227-
tee: () => [mockProdStream, { toReadableStream: () => mockDebugStream }],
228-
});
229-
230-
// 保存原始环境变量值
231-
const originalDebugValue = process.env.DEBUG_AI21_CHAT_COMPLETION;
232-
233-
// 模拟环境变量
234-
process.env.DEBUG_AI21_CHAT_COMPLETION = '1';
235-
vi.spyOn(debugStreamModule, 'debugStream').mockImplementation(() => Promise.resolve());
236-
237-
// 执行测试
238-
// 运行你的测试函数,确保它会在条件满足时调用 debugStream
239-
// 假设的测试函数调用,你可能需要根据实际情况调整
240-
await instance.chat({
241-
messages: [{ content: 'Hello', role: 'user' }],
242-
model: 'jamba-1.5-mini',
243-
stream: true,
244-
temperature: 0,
245-
});
246-
247-
// 验证 debugStream 被调用
248-
expect(debugStreamModule.debugStream).toHaveBeenCalled();
249-
250-
// 恢复原始环境变量值
251-
process.env.DEBUG_AI21_CHAT_COMPLETION = originalDebugValue;
252-
});
253-
});
254-
});
7+
testProvider({
8+
Runtime: LobeAi21AI,
9+
provider: ModelProvider.Ai21,
10+
defaultBaseURL: 'https://api.ai21.com/studio/v1',
11+
chatDebugEnv: 'DEBUG_AI21_CHAT_COMPLETION',
12+
chatModel: 'deepseek-r1',
25513
});

0 commit comments

Comments
 (0)