Skip to content

Commit 514ee8c

Browse files
committed
perf: update dump/load paths
perf: use pickle.dumps
1 parent 5524dc3 commit 514ee8c

File tree

6 files changed

+21
-19
lines changed

6 files changed

+21
-19
lines changed

controller.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -489,8 +489,8 @@ def delete_embedding(project_id: str, embedding_id: str) -> int:
489489
org_id = organization.get_id_by_project_id(project_id)
490490
s3.delete_object(org_id, f"{project_id}/{object_name}")
491491
request_util.delete_embedding_from_neural_search(embedding_id)
492-
json_path = util.INFERENCE_DIR / project_id / embedding_id / "embedder.json"
493-
shutil.rmtree(json_path.parent)
492+
json_path = util.INFERENCE_DIR / project_id / f"embedder-{embedding_id}.json"
493+
json_path.unlink(missing_ok=True)
494494
return status.HTTP_200_OK
495495

496496

@@ -627,7 +627,7 @@ def re_embed_records(project_id: str, changes: Dict[str, List[Dict[str, str]]]):
627627

628628

629629
def __setup_tmp_embedder(project_id: str, embedder_id: str) -> Transformer:
630-
embedder_path = util.INFERENCE_DIR / project_id / embedder_id / "embedder.json"
630+
embedder_path = util.INFERENCE_DIR / project_id / f"embedder-{embedder_id}.json"
631631
if not embedder_path.exists():
632632
raise Exception(f"Embedder {embedder_id} not found")
633633
return __load_embedder_by_path(embedder_path)

src/embedders/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,7 @@ def load(self, embedder: dict) -> None:
5555
"""Loads the model configuration and weights from disk.
5656
5757
Args:
58-
project_id (str): The ID of the project.
59-
embedding_id (str): The ID of the embedding.
58+
embedder (dict): The dumped model configuration.
6059
"""
6160
raise NotImplementedError
6261

src/embedders/classification/contextual.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def to_json(self) -> dict:
4949
}
5050

5151
def dump(self, project_id: str, embedding_id: str) -> None:
52-
export_file = util.INFERENCE_DIR / project_id / embedding_id / "embedder.json"
52+
export_file = util.INFERENCE_DIR / project_id / f"embedder-{embedding_id}.json"
5353
export_file.parent.mkdir(parents=True, exist_ok=True)
5454
util.write_json(self.to_json(), export_file, indent=2)
5555

@@ -176,6 +176,7 @@ def load(embedder: dict) -> "OpenAISentenceEmbedder":
176176
model_name=embedder["model_name"],
177177
batch_size=embedder["batch_size"],
178178
openai_api_key=embedder["openai_api_key"],
179+
# only set for Azure
179180
api_base=embedder["api_base"],
180181
api_type=embedder["api_type"],
181182
api_version=embedder["api_version"],
@@ -187,13 +188,14 @@ def to_json(self) -> dict:
187188
"model_name": self.model_name,
188189
"batch_size": self.batch_size,
189190
"openai_api_key": self.openai_api_key,
191+
# only set for Azure
190192
"api_base": self.api_base,
191193
"api_type": self.api_type,
192194
"api_version": self.api_version,
193195
"use_azure": self.use_azure,
194196
}
195197

196198
def dump(self, project_id: str, embedding_id: str) -> None:
197-
export_file = util.INFERENCE_DIR / project_id / embedding_id / "embedder.json"
199+
export_file = util.INFERENCE_DIR / project_id / f"embedder-{embedding_id}.json"
198200
export_file.parent.mkdir(parents=True, exist_ok=True)
199201
util.write_json(self.to_json(), export_file, indent=2)

src/embedders/classification/reduce.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from spacy.tokens.doc import Doc
22
from typing import Union, List, Generator
33
import numpy as np
4+
import pickle
45
from src.embedders import PCAReducer, util
6+
7+
# Embedder imports are used by eval(Embedder) in load methods
58
from src.embedders.classification.contextual import (
69
OpenAISentenceEmbedder,
710
HuggingFaceSentenceEmbedder,
@@ -57,7 +60,9 @@ def _reduce(
5760

5861
@staticmethod
5962
def load(embedder: dict) -> "PCASentenceReducer":
60-
reducer = util.read_pickle(embedder["reducer_pkl"])
63+
reducer = pickle.loads(
64+
embedder["reducer_pkl_bytes"].encode("latin-1")
65+
) # latin-1-encode the JSON string back into the original pickle bytes
6166
Embedder = eval(embedder["embedder"]["cls"])
6267
return PCASentenceReducer(
6368
embedder=Embedder.load(embedder["embedder"]),
@@ -68,14 +73,12 @@ def to_json(self) -> dict:
6873
return {
6974
"cls": "PCASentenceReducer",
7075
"embedder": self.embedder.to_json(),
76+
"reducer_pkl_bytes": pickle.dumps(self.reducer).decode(
77+
"latin-1"
78+
), # latin-1-decode the pickle bytes to str so they can be stored in JSON
7179
}
7280

7381
def dump(self, project_id: str, embedding_id: str) -> None:
74-
export_file = util.INFERENCE_DIR / project_id / embedding_id / "embedder.json"
82+
export_file = util.INFERENCE_DIR / project_id / f"embedder-{embedding_id}.json"
7583
export_file.parent.mkdir(parents=True, exist_ok=True)
76-
pkl_file = util.INFERENCE_DIR / project_id / embedding_id / "reducer.pkl"
77-
util.write_pickle(self.reducer, pkl_file)
78-
79-
json_obj = self.to_json()
80-
json_obj["reducer_pkl"] = str(pkl_file)
81-
util.write_json(json_obj, export_file, indent=2)
84+
util.write_json(self.to_json(), export_file, indent=2)

src/embedders/util.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from typing import Any, Generator, List
22
from pathlib import Path
33
import numpy as np
4-
import os
54
import json
65
import pickle
76

8-
INFERENCE_DIR = Path(os.getenv("INFERENCE_DIR", "/inference"))
7+
INFERENCE_DIR = Path("/inference")
98

109

1110
def batch(documents: List[Any], batch_size: int) -> Generator[List[Any], None, None]:
@@ -28,7 +27,7 @@ def write_pickle(obj: Any, file_path: str, **kwargs) -> None:
2827
pickle.dump(obj, f, **kwargs)
2928

3029

31-
def read_json(file_path: str) -> Any:
30+
def read_json(file_path: str) -> dict[str, Any]:
3231
with open(file_path, "r") as f:
3332
return json.load(f)
3433

start

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ docker run -d --rm \
6868
-e MODEL_PROVIDER=http://refinery-model-provider:80 \
6969
-e WS_NOTIFY_ENDPOINT="http://refinery-websocket:8080" \
7070
-e NEURAL_SEARCH=http://refinery-neural-search:80 \
71-
-e INFERENCE_DIR=/inference \
7271
--mount type=bind,source="$(pwd)"/,target=/app \
7372
-v /var/run/docker.sock:/var/run/docker.sock \
7473
-v "$MODEL_DIR":/models \

0 commit comments

Comments
 (0)