Skip to content

Commit d3e4234

Browse files
committed
perf: hf caching + openai alignment
1 parent bba274c commit d3e4234

File tree

3 files changed

+11
-13
lines changed

3 files changed

+11
-13
lines changed

controller.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from fastapi import status
44
from spacy.tokens import DocBin, Doc
55
from spacy.vocab import Vocab
6+
from functools import lru_cache
67

78
import json
89
import torch
@@ -636,6 +637,11 @@ def __setup_tmp_embedder(project_id: str, embedder_id: str) -> Transformer:
636637
embedder_path = util.INFERENCE_DIR / project_id / embedder_id / "embedder.json"
637638
if not embedder_path.exists():
638639
raise Exception(f"Embedder {embedder_id} not found")
640+
return __load_embedder_by_path(embedder_path)
641+
642+
643+
@lru_cache(maxsize=32)
644+
def __load_embedder_by_path(embedder_path: str) -> Transformer:
639645
with open(embedder_path, "r") as f:
640646
embedder = json.load(f)
641647
Embedder = eval(embedder["cls"])

src/embedders/classification/contextual.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def load(embedder: dict) -> "HuggingFaceSentenceEmbedder":
4444
def to_json(self) -> dict:
4545
return {
4646
"cls": "HuggingFaceSentenceEmbedder",
47-
"model_name": self.model.model_card_data.base_model,
47+
"config_string": self.model.model_card_data.base_model,
4848
"batch_size": self.batch_size,
4949
}
5050

@@ -121,19 +121,13 @@ def __init__(
121121
and api_version is not None
122122
and api_base is not None
123123
), "If you want to use Azure, you need to provide api_type, api_version and api_base."
124-
125-
@property
126-
def openai_client(self):
127-
if self.use_azure:
128-
return AzureOpenAI(
124+
self.openai_client = AzureOpenAI(
129125
api_key=self.openai_api_key,
130126
azure_endpoint=self.api_base,
131127
api_version=self.api_version,
132128
)
133-
return OpenAI(
134-
api_key=self.openai_api_key,
135-
base_url=self.api_base,
136-
)
129+
else:
130+
self.openai_client = OpenAI(api_key=self.openai_api_key)
137131

138132
def _encode(
139133
self, documents: List[Union[str, Doc]], fit_model: bool

src/embedders/util.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55
import json
66
import pickle
77

8-
INFERENCE_DIR = Path(
9-
os.getenv("INFERENCE_DIR", "/Users/andhrelja/Projects/dev-setup/inference")
10-
)
8+
INFERENCE_DIR = Path(os.getenv("INFERENCE_DIR", "/inference"))
119

1210

1311
def batch(documents: List[Any], batch_size: int) -> Generator[List[Any], None, None]:

0 commit comments

Comments
 (0)