Skip to content

Commit cfffd23

Browse files
committed
refactor(model): integrate ModelAPI into model proxy and service classes
This change introduces ModelAPI as a mixin to both ModelProxy and ModelService classes, consolidating common API functionality. The completions and responses methods have been removed from these classes as they are now inherited from ModelAPI, reducing code duplication and improving maintainability. The refactoring affects multiple files including __model_proxy_async_template.py, __model_service_async_template.py, model_proxy.py, and model_service.py, along with the new model_api.py module definition. Two new files were created: agentrun/model/api/model_api.py which defines the ModelAPI base class, and examples/embedding.py which provides usage examples for embedding functionality. // 此更改将ModelAPI作为mixin引入到ModelProxy和ModelService类中, 整合了通用的API功能。completions和responses方法已从这些类中移除, 因为它们现在继承自ModelAPI,减少了代码重复并提高了可维护性。 重构影响了多个文件,包括__model_proxy_async_template.py、 __model_service_async_template.py、model_proxy.py和model_service.py, 以及新的model_api.py模块定义。 创建了两个新文件:agentrun/model/api/model_api.py定义了ModelAPI基类, examples/embedding.py提供了嵌入功能的使用示例。 Change-Id: I81045039bef65df9d7daab6a56ff2e1442fddcb7 Signed-off-by: OhYee <oyohyee@oyohyee.com>
1 parent a7e906b commit cfffd23

File tree

6 files changed

+307
-148
lines changed

6 files changed

+307
-148
lines changed

agentrun/model/__model_proxy_async_template.py

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pydash
1010

1111
from agentrun.model.api.data import BaseInfo, ModelDataAPI
12+
from agentrun.model.api.model_api import ModelAPI
1213
from agentrun.utils.config import Config
1314
from agentrun.utils.model import Status
1415
from agentrun.utils.resource import ResourceBase
@@ -30,6 +31,7 @@ class ModelProxy(
3031
ModelProxyImmutableProps,
3132
ModelProxyMutableProps,
3233
ModelProxySystemProps,
34+
ModelAPI,
3335
ResourceBase,
3436
):
3537
"""模型服务"""
@@ -230,41 +232,3 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
230232
)
231233

232234
return self._data_client.model_info()
233-
234-
def completions(
235-
self,
236-
messages: list,
237-
model: Optional[str] = None,
238-
stream: bool = False,
239-
config: Optional[Config] = None,
240-
**kwargs,
241-
):
242-
self.model_info(config)
243-
assert self._data_client
244-
245-
return self._data_client.completions(
246-
**kwargs,
247-
messages=messages,
248-
model=model,
249-
stream=stream,
250-
config=config,
251-
)
252-
253-
def responses(
254-
self,
255-
messages: list,
256-
model: Optional[str] = None,
257-
stream: bool = False,
258-
config: Optional[Config] = None,
259-
**kwargs,
260-
):
261-
self.model_info(config)
262-
assert self._data_client
263-
264-
return self._data_client.responses(
265-
**kwargs,
266-
messages=messages,
267-
model=model,
268-
stream=stream,
269-
config=config,
270-
)

agentrun/model/__model_service_async_template.py

Lines changed: 3 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
from typing import List, Optional
88

9-
from agentrun.model.api.data import BaseInfo, ModelCompletionAPI
9+
from agentrun.model.api.data import BaseInfo
10+
from agentrun.model.api.model_api import ModelAPI
1011
from agentrun.utils.config import Config
1112
from agentrun.utils.model import PageableInput
1213
from agentrun.utils.resource import ResourceBase
@@ -27,6 +28,7 @@ class ModelService(
2728
ModelServiceImmutableProps,
2829
ModelServiceMutableProps,
2930
ModelServicesSystemProps,
31+
ModelAPI,
3032
ResourceBase,
3133
):
3234
"""模型服务"""
@@ -230,38 +232,3 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
230232
model=default_model,
231233
headers=cfg.get_headers(),
232234
)
233-
234-
def completions(
235-
self,
236-
messages: list,
237-
model: Optional[str] = None,
238-
stream: bool = False,
239-
**kwargs,
240-
):
241-
info = self.model_info(config=kwargs.get("config"))
242-
243-
m = ModelCompletionAPI(
244-
api_key=info.api_key or "",
245-
base_url=info.base_url or "",
246-
model=model or info.model or self.model_service_name or "",
247-
)
248-
249-
return m.completions(**kwargs, messages=messages, stream=stream)
250-
251-
def responses(
252-
self,
253-
messages: list,
254-
model: Optional[str] = None,
255-
stream: bool = False,
256-
**kwargs,
257-
):
258-
info = self.model_info(config=kwargs.get("config"))
259-
260-
m = ModelCompletionAPI(
261-
api_key=info.api_key or "",
262-
base_url=info.base_url or "",
263-
model=model or info.model or self.model_service_name or "",
264-
provider=(self.provider or "openai").lower(),
265-
)
266-
267-
return m.responses(**kwargs, messages=messages, stream=stream)

agentrun/model/api/model_api.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Optional, TYPE_CHECKING, Union
3+
4+
from .data import BaseInfo
5+
6+
if TYPE_CHECKING:
7+
from litellm import ResponseInputParam
8+
9+
10+
class ModelAPI(ABC):
11+
12+
@abstractmethod
13+
def model_info(self) -> BaseInfo:
14+
...
15+
16+
def completions(
17+
self,
18+
**kwargs,
19+
):
20+
"""
21+
Deprecated. Use completion() instead.
22+
"""
23+
import warnings
24+
25+
warnings.warn(
26+
"completions() is deprecated, use completion() instead",
27+
DeprecationWarning,
28+
stacklevel=2,
29+
)
30+
return self.completion(**kwargs)
31+
32+
def completion(
33+
self,
34+
messages=[],
35+
model: Optional[str] = None,
36+
custom_llm_provider: Optional[str] = None,
37+
**kwargs,
38+
):
39+
from litellm import completion
40+
41+
info = self.model_info()
42+
return completion(
43+
**kwargs,
44+
api_key=info.api_key,
45+
base_url=info.base_url,
46+
model=model or info.model or "",
47+
custom_llm_provider=custom_llm_provider
48+
or info.provider
49+
or "openai",
50+
messages=messages,
51+
)
52+
53+
async def acompletion(
54+
self,
55+
messages=[],
56+
model: Optional[str] = None,
57+
custom_llm_provider: Optional[str] = None,
58+
**kwargs,
59+
):
60+
from litellm import acompletion
61+
62+
info = self.model_info()
63+
return await acompletion(
64+
**kwargs,
65+
api_key=info.api_key,
66+
base_url=info.base_url,
67+
model=model or info.model or "",
68+
custom_llm_provider=custom_llm_provider
69+
or info.provider
70+
or "openai",
71+
messages=messages,
72+
)
73+
74+
def responses(
75+
self,
76+
input: Union[str, "ResponseInputParam"],
77+
model: Optional[str] = None,
78+
custom_llm_provider: Optional[str] = None,
79+
**kwargs,
80+
):
81+
from litellm import responses
82+
83+
info = self.model_info()
84+
return responses(
85+
**kwargs,
86+
api_key=info.api_key,
87+
base_url=info.base_url,
88+
model=model or info.model or "",
89+
custom_llm_provider=custom_llm_provider
90+
or info.provider
91+
or "openai",
92+
input=input,
93+
)
94+
95+
async def aresponses(
96+
self,
97+
input: Union[str, "ResponseInputParam"],
98+
model: Optional[str] = None,
99+
custom_llm_provider: Optional[str] = None,
100+
**kwargs,
101+
):
102+
from litellm import aresponses
103+
104+
info = self.model_info()
105+
return await aresponses(
106+
**kwargs,
107+
api_key=info.api_key,
108+
base_url=info.base_url,
109+
model=model or info.model or "",
110+
custom_llm_provider=custom_llm_provider
111+
or info.provider
112+
or "openai",
113+
input=input,
114+
)
115+
116+
def embedding(
117+
self,
118+
input=[],
119+
model: Optional[str] = None,
120+
custom_llm_provider: Optional[str] = None,
121+
**kwargs,
122+
):
123+
from litellm import embedding
124+
125+
info = self.model_info()
126+
return embedding(
127+
**kwargs,
128+
api_key=info.api_key,
129+
api_base=info.base_url,
130+
model=model or info.model or "",
131+
custom_llm_provider=custom_llm_provider
132+
or info.provider
133+
or "openai",
134+
input=input,
135+
)
136+
137+
def aembedding(
138+
self,
139+
input=[],
140+
model: Optional[str] = None,
141+
custom_llm_provider: Optional[str] = None,
142+
**kwargs,
143+
):
144+
from litellm import aembedding
145+
146+
info = self.model_info()
147+
return aembedding(
148+
**kwargs,
149+
api_key=info.api_key,
150+
api_base=info.base_url,
151+
model=model or info.model or "",
152+
custom_llm_provider=custom_llm_provider
153+
or info.provider
154+
or "openai",
155+
input=input,
156+
)

agentrun/model/model_proxy.py

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pydash
2020

2121
from agentrun.model.api.data import BaseInfo, ModelDataAPI
22+
from agentrun.model.api.model_api import ModelAPI
2223
from agentrun.utils.config import Config
2324
from agentrun.utils.model import Status
2425
from agentrun.utils.resource import ResourceBase
@@ -40,6 +41,7 @@ class ModelProxy(
4041
ModelProxyImmutableProps,
4142
ModelProxyMutableProps,
4243
ModelProxySystemProps,
44+
ModelAPI,
4345
ResourceBase,
4446
):
4547
"""模型服务"""
@@ -399,41 +401,3 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
399401
)
400402

401403
return self._data_client.model_info()
402-
403-
def completions(
404-
self,
405-
messages: list,
406-
model: Optional[str] = None,
407-
stream: bool = False,
408-
config: Optional[Config] = None,
409-
**kwargs,
410-
):
411-
self.model_info(config)
412-
assert self._data_client
413-
414-
return self._data_client.completions(
415-
**kwargs,
416-
messages=messages,
417-
model=model,
418-
stream=stream,
419-
config=config,
420-
)
421-
422-
def responses(
423-
self,
424-
messages: list,
425-
model: Optional[str] = None,
426-
stream: bool = False,
427-
config: Optional[Config] = None,
428-
**kwargs,
429-
):
430-
self.model_info(config)
431-
assert self._data_client
432-
433-
return self._data_client.responses(
434-
**kwargs,
435-
messages=messages,
436-
model=model,
437-
stream=stream,
438-
config=config,
439-
)

agentrun/model/model_service.py

Lines changed: 3 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
from typing import List, Optional
1818

19-
from agentrun.model.api.data import BaseInfo, ModelCompletionAPI
19+
from agentrun.model.api.data import BaseInfo
20+
from agentrun.model.api.model_api import ModelAPI
2021
from agentrun.utils.config import Config
2122
from agentrun.utils.model import PageableInput
2223
from agentrun.utils.resource import ResourceBase
@@ -37,6 +38,7 @@ class ModelService(
3738
ModelServiceImmutableProps,
3839
ModelServiceMutableProps,
3940
ModelServicesSystemProps,
41+
ModelAPI,
4042
ResourceBase,
4143
):
4244
"""模型服务"""
@@ -401,38 +403,3 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
401403
model=default_model,
402404
headers=cfg.get_headers(),
403405
)
404-
405-
def completions(
406-
self,
407-
messages: list,
408-
model: Optional[str] = None,
409-
stream: bool = False,
410-
**kwargs,
411-
):
412-
info = self.model_info(config=kwargs.get("config"))
413-
414-
m = ModelCompletionAPI(
415-
api_key=info.api_key or "",
416-
base_url=info.base_url or "",
417-
model=model or info.model or self.model_service_name or "",
418-
)
419-
420-
return m.completions(**kwargs, messages=messages, stream=stream)
421-
422-
def responses(
423-
self,
424-
messages: list,
425-
model: Optional[str] = None,
426-
stream: bool = False,
427-
**kwargs,
428-
):
429-
info = self.model_info(config=kwargs.get("config"))
430-
431-
m = ModelCompletionAPI(
432-
api_key=info.api_key or "",
433-
base_url=info.base_url or "",
434-
model=model or info.model or self.model_service_name or "",
435-
provider=(self.provider or "openai").lower(),
436-
)
437-
438-
return m.responses(**kwargs, messages=messages, stream=stream)

0 commit comments

Comments
 (0)