1 | 1 | // @vitest-environment node |
2 | | -import OpenAI from 'openai'; |
3 | | -import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; |
| 2 | +import { ModelProvider } from '@/libs/agent-runtime'; |
| 3 | +import { testProvider } from '@/libs/agent-runtime/providerTestUtils'; |
4 | 4 |
5 | | -import { |
6 | | - ChatStreamCallbacks, |
7 | | - LobeOpenAICompatibleRuntime, |
8 | | - ModelProvider, |
9 | | -} from '@/libs/agent-runtime'; |
10 | | - |
11 | | -import * as debugStreamModule from '../utils/debugStream'; |
12 | 5 | import { LobeAi21AI } from './index'; |
13 | 6 |
14 | | -const provider = ModelProvider.Ai21; |
15 | | -const defaultBaseURL = 'https://api.ai21.com/studio/v1'; |
16 | | - |
17 | | -const bizErrorType = 'ProviderBizError'; |
18 | | -const invalidErrorType = 'InvalidProviderAPIKey'; |
19 | | - |
20 | | -// Mock the console.error to avoid polluting test output |
21 | | -vi.spyOn(console, 'error').mockImplementation(() => {}); |
22 | | - |
23 | | -let instance: LobeOpenAICompatibleRuntime; |
24 | | - |
25 | | -beforeEach(() => { |
26 | | - instance = new LobeAi21AI({ apiKey: 'test' }); |
27 | | - |
28 | | -  // Use vi.spyOn to mock the chat.completions.create method |
29 | | - vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue( |
30 | | - new ReadableStream() as any, |
31 | | - ); |
32 | | -}); |
33 | | - |
34 | | -afterEach(() => { |
35 | | - vi.clearAllMocks(); |
36 | | -}); |
37 | | - |
38 | | -describe('LobeAi21AI', () => { |
39 | | - describe('init', () => { |
40 | | - it('should correctly initialize with an API key', async () => { |
41 | | - const instance = new LobeAi21AI({ apiKey: 'test_api_key' }); |
42 | | - expect(instance).toBeInstanceOf(LobeAi21AI); |
43 | | - expect(instance.baseURL).toEqual(defaultBaseURL); |
44 | | - }); |
45 | | - }); |
46 | | - |
47 | | - describe('chat', () => { |
48 | | - describe('Error', () => { |
49 | | - it('should return OpenAIBizError with an openai error response when OpenAI.APIError is thrown', async () => { |
50 | | - // Arrange |
51 | | - const apiError = new OpenAI.APIError( |
52 | | - 400, |
53 | | - { |
54 | | - status: 400, |
55 | | - error: { |
56 | | - message: 'Bad Request', |
57 | | - }, |
58 | | - }, |
59 | | - 'Error message', |
60 | | - {}, |
61 | | - ); |
62 | | - |
63 | | - vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError); |
64 | | - |
65 | | - // Act |
66 | | - try { |
67 | | - await instance.chat({ |
68 | | - messages: [{ content: 'Hello', role: 'user' }], |
69 | | - model: 'jamba-1.5-mini', |
70 | | - temperature: 0, |
71 | | - }); |
72 | | - } catch (e) { |
73 | | - expect(e).toEqual({ |
74 | | - endpoint: defaultBaseURL, |
75 | | - error: { |
76 | | - error: { message: 'Bad Request' }, |
77 | | - status: 400, |
78 | | - }, |
79 | | - errorType: bizErrorType, |
80 | | - provider, |
81 | | - }); |
82 | | - } |
83 | | - }); |
84 | | - |
85 | | - it('should throw AgentRuntimeError with NoOpenAIAPIKey if no apiKey is provided', async () => { |
86 | | - try { |
87 | | - new LobeAi21AI({}); |
88 | | - } catch (e) { |
89 | | - expect(e).toEqual({ errorType: invalidErrorType }); |
90 | | - } |
91 | | - }); |
92 | | - |
93 | | - it('should return OpenAIBizError with the cause when OpenAI.APIError is thrown with cause', async () => { |
94 | | - // Arrange |
95 | | - const errorInfo = { |
96 | | - stack: 'abc', |
97 | | - cause: { |
98 | | - message: 'api is undefined', |
99 | | - }, |
100 | | - }; |
101 | | - const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {}); |
102 | | - |
103 | | - vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError); |
104 | | - |
105 | | - // Act |
106 | | - try { |
107 | | - await instance.chat({ |
108 | | - messages: [{ content: 'Hello', role: 'user' }], |
109 | | - model: 'jamba-1.5-mini', |
110 | | - temperature: 0, |
111 | | - }); |
112 | | - } catch (e) { |
113 | | - expect(e).toEqual({ |
114 | | - endpoint: defaultBaseURL, |
115 | | - error: { |
116 | | - cause: { message: 'api is undefined' }, |
117 | | - stack: 'abc', |
118 | | - }, |
119 | | - errorType: bizErrorType, |
120 | | - provider, |
121 | | - }); |
122 | | - } |
123 | | - }); |
124 | | - |
125 | | - it('should return OpenAIBizError with an cause response with desensitize Url', async () => { |
126 | | - // Arrange |
127 | | - const errorInfo = { |
128 | | - stack: 'abc', |
129 | | - cause: { message: 'api is undefined' }, |
130 | | - }; |
131 | | - const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {}); |
132 | | - |
133 | | - instance = new LobeAi21AI({ |
134 | | - apiKey: 'test', |
135 | | - |
136 | | - baseURL: 'https://api.abc.com/v1', |
137 | | - }); |
138 | | - |
139 | | - vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError); |
140 | | - |
141 | | - // Act |
142 | | - try { |
143 | | - await instance.chat({ |
144 | | - messages: [{ content: 'Hello', role: 'user' }], |
145 | | - model: 'jamba-1.5-mini', |
146 | | - temperature: 0, |
147 | | - }); |
148 | | - } catch (e) { |
149 | | - expect(e).toEqual({ |
150 | | - endpoint: 'https://api.***.com/v1', |
151 | | - error: { |
152 | | - cause: { message: 'api is undefined' }, |
153 | | - stack: 'abc', |
154 | | - }, |
155 | | - errorType: bizErrorType, |
156 | | - provider, |
157 | | - }); |
158 | | - } |
159 | | - }); |
160 | | - |
161 | | - it('should throw an InvalidAi21APIKey error type on 401 status code', async () => { |
162 | | - // Mock the API call to simulate a 401 error |
163 | | - const error = new Error('Unauthorized') as any; |
164 | | - error.status = 401; |
165 | | - vi.mocked(instance['client'].chat.completions.create).mockRejectedValue(error); |
166 | | - |
167 | | - try { |
168 | | - await instance.chat({ |
169 | | - messages: [{ content: 'Hello', role: 'user' }], |
170 | | - model: 'jamba-1.5-mini', |
171 | | - temperature: 0, |
172 | | - }); |
173 | | - } catch (e) { |
174 | | - // Expect the chat method to throw an error with InvalidAi21APIKey |
175 | | - expect(e).toEqual({ |
176 | | - endpoint: defaultBaseURL, |
177 | | - error: new Error('Unauthorized'), |
178 | | - errorType: invalidErrorType, |
179 | | - provider, |
180 | | - }); |
181 | | - } |
182 | | - }); |
183 | | - |
184 | | - it('should return AgentRuntimeError for non-OpenAI errors', async () => { |
185 | | - // Arrange |
186 | | - const genericError = new Error('Generic Error'); |
187 | | - |
188 | | - vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(genericError); |
189 | | - |
190 | | - // Act |
191 | | - try { |
192 | | - await instance.chat({ |
193 | | - messages: [{ content: 'Hello', role: 'user' }], |
194 | | - model: 'jamba-1.5-mini', |
195 | | - temperature: 0, |
196 | | - }); |
197 | | - } catch (e) { |
198 | | - expect(e).toEqual({ |
199 | | - endpoint: defaultBaseURL, |
200 | | - errorType: 'AgentRuntimeError', |
201 | | - provider, |
202 | | - error: { |
203 | | - name: genericError.name, |
204 | | - cause: genericError.cause, |
205 | | - message: genericError.message, |
206 | | - stack: genericError.stack, |
207 | | - }, |
208 | | - }); |
209 | | - } |
210 | | - }); |
211 | | - }); |
212 | | - |
213 | | - describe('DEBUG', () => { |
214 | | - it('should call debugStream and return StreamingTextResponse when DEBUG_AI21_CHAT_COMPLETION is 1', async () => { |
215 | | - // Arrange |
216 | | -      const mockProdStream = new ReadableStream() as any; // mocked prod stream |
217 | | - const mockDebugStream = new ReadableStream({ |
218 | | - start(controller) { |
219 | | - controller.enqueue('Debug stream content'); |
220 | | - controller.close(); |
221 | | - }, |
222 | | - }) as any; |
223 | | -      mockDebugStream.toReadableStream = () => mockDebugStream; // add a toReadableStream method |
224 | | - |
225 | | -      // Mock the chat.completions.create return value, including a mocked tee method |
226 | | - (instance['client'].chat.completions.create as Mock).mockResolvedValue({ |
227 | | - tee: () => [mockProdStream, { toReadableStream: () => mockDebugStream }], |
228 | | - }); |
229 | | - |
230 | | -      // Save the original value of the environment variable |
231 | | - const originalDebugValue = process.env.DEBUG_AI21_CHAT_COMPLETION; |
232 | | - |
233 | | -      // Mock the environment variable |
234 | | - process.env.DEBUG_AI21_CHAT_COMPLETION = '1'; |
235 | | - vi.spyOn(debugStreamModule, 'debugStream').mockImplementation(() => Promise.resolve()); |
236 | | - |
237 | | -      // Run the test |
238 | | -      // Call the function under test and make sure it invokes debugStream when the condition is met |
239 | | -      // Example invocation; adjust to your actual setup |
240 | | - await instance.chat({ |
241 | | - messages: [{ content: 'Hello', role: 'user' }], |
242 | | - model: 'jamba-1.5-mini', |
243 | | - stream: true, |
244 | | - temperature: 0, |
245 | | - }); |
246 | | - |
247 | | -      // Verify that debugStream was called |
248 | | - expect(debugStreamModule.debugStream).toHaveBeenCalled(); |
249 | | - |
250 | | -      // Restore the original value of the environment variable |
251 | | - process.env.DEBUG_AI21_CHAT_COMPLETION = originalDebugValue; |
252 | | - }); |
253 | | - }); |
254 | | - }); |
| 7 | +testProvider({ |
| 8 | + Runtime: LobeAi21AI, |
| 9 | + provider: ModelProvider.Ai21, |
| 10 | + defaultBaseURL: 'https://api.ai21.com/studio/v1', |
| 11 | + chatDebugEnv: 'DEBUG_AI21_CHAT_COMPLETION', |
| 12 | + chatModel: 'deepseek-r1', |
255 | 13 | }); |
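
For context on what the refactor delegates to: below is a minimal, hypothetical sketch of a testProvider-style helper, reconstructed from the per-provider tests deleted above and from the options passed in the new call. The option names mirror the diff, but the body is an assumption, not the actual implementation in '@/libs/agent-runtime/providerTestUtils'.

// Hypothetical sketch only: the real shared suite lives in
// '@/libs/agent-runtime/providerTestUtils'. This reconstruction is inferred from
// the deleted per-provider tests above and covers the same two core behaviours
// (init with the default baseURL, OpenAI.APIError -> ProviderBizError mapping).
import OpenAI from 'openai';
import { describe, expect, it, vi } from 'vitest';

interface TestProviderOptions {
  Runtime: new (options: { apiKey?: string; baseURL?: string }) => any;
  provider: string;
  defaultBaseURL: string;
  chatDebugEnv: string; // debug env var, e.g. DEBUG_AI21_CHAT_COMPLETION (not exercised in this sketch)
  chatModel: string;
}

export const testProvider = ({
  Runtime,
  provider,
  defaultBaseURL,
  chatModel,
}: TestProviderOptions) => {
  describe(`${provider} runtime (shared suite)`, () => {
    it('initializes with an API key and the default baseURL', () => {
      const instance = new Runtime({ apiKey: 'test_api_key' });
      expect(instance.baseURL).toEqual(defaultBaseURL);
    });

    it('maps OpenAI.APIError to a ProviderBizError payload', async () => {
      const instance = new Runtime({ apiKey: 'test' });
      const apiError = new OpenAI.APIError(
        400,
        { status: 400, error: { message: 'Bad Request' } },
        'Error message',
        {},
      );
      // The runtime wraps an OpenAI-compatible client, so reject its create() call.
      vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);

      try {
        await instance.chat({
          messages: [{ content: 'Hello', role: 'user' }],
          model: chatModel,
          temperature: 0,
        });
      } catch (e) {
        expect(e).toMatchObject({
          endpoint: defaultBaseURL,
          errorType: 'ProviderBizError',
          provider,
        });
      }
    });
  });
};

The real helper presumably also exercises the branches the deleted suite covered per provider, such as the 401 / InvalidProviderAPIKey case and the debug-stream path toggled by chatDebugEnv.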