Merged

Changes from all commits
51 changes: 26 additions & 25 deletions controller.py
@@ -3,19 +3,28 @@
from fastapi import status
from spacy.tokens import DocBin, Doc
from spacy.vocab import Vocab
from functools import lru_cache

import pickle
import json
import torch
import traceback
import logging
import time
import zlib
import gc
import os
import openai
import pandas as pd
import shutil
from openai import APIConnectionError

from src.embedders import Transformer
from src.embedders import Transformer, util

# Embedder imports are used by eval(Embedder) in __setup_tmp_embedder
from src.embedders.classification.contextual import (
OpenAISentenceEmbedder,
HuggingFaceSentenceEmbedder,
)
from src.embedders.classification.reduce import PCASentenceReducer
from src.util import daemon, request_util
from src.util.decorator import param_throttle
from src.util.embedders import get_embedder
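The hunk below migrates the OpenAI error handling from the pre-1.0 `openai.error` namespace to the top-level exception classes that `openai>=1.0` exports (matching the `openai==1.97.1` pin in gpu-requirements.txt further down). A minimal sketch of the new pattern, assuming the 1.x client; the model name and messages are illustrative, not taken from this PR:

```python
# Sketch of the openai>=1.0 error-handling style. The pre-1.0 code used
# openai.error.APIConnectionError; 1.x exports exceptions at top level.
from openai import OpenAI, APIConnectionError, AuthenticationError

client = OpenAI()  # picks up OPENAI_API_KEY from the environment

try:
    resp = client.embeddings.create(
        model="text-embedding-3-small",  # illustrative model choice
        input=["some text to embed"],
    )
    print(len(resp.data[0].embedding))
except APIConnectionError as exc:
    # network-level failure -- the case run_encoding marks as FAILED
    print(f"connection failed: {exc}")
except AuthenticationError:
    # maps to the "Access denied due to invalid api key." notification
    print("invalid api key")
```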
@@ -339,7 +348,7 @@ def run_encoding(
enums.EmbeddingState.ENCODING.value,
initial_count,
)
except openai.error.APIConnectionError as e:
except APIConnectionError as e:
embedding.update_embedding_state_failed(
project_id,
embedding_id,
@@ -407,9 +416,6 @@ def run_encoding(
notification_message = "Access denied due to invalid api key."
elif platform == enums.EmbeddingPlatform.AZURE.value:
notification_message = "Access denied due to invalid subscription key or wrong endpoint data."
elif error_message == "invalid api token":
# cohere
notification_message = "Access denied due to invalid api token."
notification.create(
project_id,
user_id,
@@ -453,14 +459,7 @@ def run_encoding(
request_util.post_embedding_to_neural_search(project_id, embedding_id)

# now always since otherwise record edit wouldn't work for embedded columns
pickle_path = os.path.join(
"/inference", project_id, f"embedder-{embedding_id}.pkl"
)
if not os.path.exists(pickle_path):
os.makedirs(os.path.dirname(pickle_path), exist_ok=True)
with open(pickle_path, "wb") as f:
pickle.dump(embedder, f)

embedder.dump(project_id, embedding_id)
upload_embedding_as_file(project_id, embedding_id)
embedding.update_embedding_state_finished(
project_id,
@@ -490,9 +489,8 @@ def delete_embedding(project_id: str, embedding_id: str) -> int:
org_id = organization.get_id_by_project_id(project_id)
s3.delete_object(org_id, f"{project_id}/{object_name}")
request_util.delete_embedding_from_neural_search(embedding_id)
pickle_path = os.path.join("/inference", project_id, f"embedder-{embedding_id}.pkl")
if os.path.exists(pickle_path):
os.remove(pickle_path)
json_path = util.INFERENCE_DIR / project_id / f"embedder-{embedding_id}.json"
json_path.unlink(missing_ok=True)
return status.HTTP_200_OK
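The two hunks above replace pickle persistence (`embedder-{embedding_id}.pkl` under `/inference`) with a JSON artifact written by `embedder.dump(...)` and removed via `json_path.unlink(missing_ok=True)`. Since `dump`/`load` are internal to `src.embedders`, the sketch below is only an assumed shape of that contract, inferred from the `embedder["cls"]` lookup later in this diff, not the library's confirmed API:

```python
# Assumed shape of the JSON persistence contract; field names and the
# INFERENCE_DIR value are inferences from this diff, not a real API.
import json
from pathlib import Path

INFERENCE_DIR = Path("/inference")  # assumed value of util.INFERENCE_DIR


class SketchEmbedder:
    def __init__(self, config: dict):
        self.config = config

    def dump(self, project_id: str, embedding_id: str) -> None:
        # Write everything needed to rebuild the embedder, plus the
        # concrete class name that __load_embedder_by_path eval()s.
        path = INFERENCE_DIR / project_id / f"embedder-{embedding_id}.json"
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(
            json.dumps({"cls": type(self).__name__, "config": self.config})
        )

    @classmethod
    def load(cls, payload: dict) -> "SketchEmbedder":
        # Inverse of dump(); the controller passes the parsed JSON dict.
        return cls(payload["config"])
```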


@@ -629,15 +627,18 @@ def re_embed_records(project_id: str, changes: Dict[str, List[Dict[str, str]]]):


def __setup_tmp_embedder(project_id: str, embedder_id: str) -> Transformer:
embedder_path = os.path.join(
"/inference", project_id, f"embedder-{embedder_id}.pkl"
)
if not os.path.exists(embedder_path):
embedder_path = util.INFERENCE_DIR / project_id / f"embedder-{embedder_id}.json"
if not embedder_path.exists():
raise Exception(f"Embedder {embedder_id} not found")
with open(embedder_path, "rb") as f:
embedder = pickle.load(f)
return __load_embedder_by_path(embedder_path)


return embedder
@lru_cache(maxsize=32)
def __load_embedder_by_path(embedder_path: str) -> Transformer:
with open(embedder_path, "r") as f:
embedder = json.load(f)
Embedder = eval(embedder["cls"])
return Embedder.load(embedder)


def calc_tensors(project_id: str, embedding_id: str, texts: List[str]) -> List[Any]:
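`__load_embedder_by_path` is new in this PR and sits behind `functools.lru_cache`, so each embedder JSON is parsed at most once per process, for up to 32 distinct paths; the `eval(embedder["cls"])` call is also why the contextual embedder classes are imported at module top even though they look unused. A self-contained sketch of the caching behavior, with illustrative names and paths:

```python
# Sketch of the lru_cache behavior backing __load_embedder_by_path.
from functools import lru_cache


@lru_cache(maxsize=32)
def load_embedder_json(path: str) -> dict:
    print(f"parsing {path}")  # fires once per distinct path
    return {"path": path}


load_embedder_json("/inference/p1/embedder-a.json")  # parsed
load_embedder_json("/inference/p1/embedder-a.json")  # cache hit, no parse
# If the file at a path is ever rewritten in place, the cached object
# goes stale until the cache is cleared.
load_embedder_json.cache_clear()
```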
51 changes: 10 additions & 41 deletions gpu-requirements.txt
@@ -6,12 +6,6 @@
#
--extra-index-url https://download.pytorch.org/whl/cu113

aiohappyeyeballs==2.6.1
# via aiohttp
aiohttp==3.12.14
# via openai
aiosignal==1.4.0
# via aiohttp
annotated-types==0.7.0
# via
# -r requirements/torch-cuda-requirements.txt
@@ -20,6 +14,7 @@ anyio==4.9.0
# via
# -r requirements/torch-cuda-requirements.txt
# httpx
# openai
# starlette
argon2-cffi==25.1.0
# via
@@ -29,8 +24,6 @@ argon2-cffi-bindings==21.2.0
# via
# -r requirements/torch-cuda-requirements.txt
# argon2-cffi
attrs==25.3.0
# via aiohttp
blis==0.7.11
# via thinc
boto3==1.39.6
@@ -67,8 +60,6 @@ click==8.2.1
# uvicorn
cloudpathlib==0.21.1
# via weasel
cohere==5.16.1
# via -r requirements/gpu-requirements.in
confection==0.1.5
# via
# thinc
@@ -78,20 +69,16 @@ cymem==2.0.11
# preshed
# spacy
# thinc
distro==1.9.0
# via openai
fastapi==0.116.1
# via -r requirements/torch-cuda-requirements.txt
fastavro==1.11.1
# via cohere
filelock==3.18.0
# via
# -r requirements/torch-cuda-requirements.txt
# huggingface-hub
# torch
# transformers
frozenlist==1.7.0
# via
# aiohttp
# aiosignal
fsspec==2025.7.0
# via
# -r requirements/torch-cuda-requirements.txt
@@ -109,9 +96,7 @@ hf-xet==1.1.5
httpcore==1.0.9
# via httpx
httpx==0.28.1
# via cohere
httpx-sse==0.4.0
# via cohere
# via openai
huggingface-hub==0.33.4
# via
# -r requirements/torch-cuda-requirements.txt
@@ -124,12 +109,13 @@ idna==3.10
# anyio
# httpx
# requests
# yarl
jinja2==3.1.6
# via
# -r requirements/torch-cuda-requirements.txt
# spacy
# torch
jiter==0.10.0
# via openai
jmespath==1.0.1
# via
# -r requirements/torch-cuda-requirements.txt
@@ -160,10 +146,6 @@ mpmath==1.3.0
# via
# -r requirements/torch-cuda-requirements.txt
# sympy
multidict==6.6.3
# via
# aiohttp
# yarl
murmurhash==1.0.13
# via
# preshed
@@ -185,7 +167,7 @@ numpy==1.23.4
# thinc
# torchvision
# transformers
openai==0.28.1
openai==1.97.1
# via -r requirements/gpu-requirements.in
packaging==25.0
# via
@@ -205,10 +187,6 @@ preshed==3.0.10
# via
# spacy
# thinc
propcache==0.3.2
# via
# aiohttp
# yarl
psycopg2-binary==2.9.9
# via -r requirements/torch-cuda-requirements.txt
pyaml==25.7.0
@@ -226,16 +204,15 @@ pycryptodome==3.23.0
pydantic==2.7.4
# via
# -r requirements/torch-cuda-requirements.txt
# cohere
# confection
# fastapi
# openai
# spacy
# thinc
# weasel
pydantic-core==2.18.4
# via
# -r requirements/torch-cuda-requirements.txt
# cohere
# pydantic
pygments==2.19.2
# via rich
@@ -261,9 +238,7 @@ regex==2024.11.6
requests==2.32.4
# via
# -r requirements/torch-cuda-requirements.txt
# cohere
# huggingface-hub
# openai
# spacy
# transformers
# weasel
@@ -304,6 +279,7 @@ sniffio==1.3.1
# via
# -r requirements/torch-cuda-requirements.txt
# anyio
# openai
spacy==3.7.5
# via -r requirements/gpu-requirements.in
spacy-legacy==3.0.12
@@ -335,7 +311,6 @@ threadpoolctl==3.6.0
tokenizers==0.21.2
# via
# -r requirements/torch-cuda-requirements.txt
# cohere
# transformers
torch==2.7.1
# via
@@ -360,17 +335,14 @@ typer==0.16.0
# via
# spacy
# weasel
types-requests==2.32.4.20250611
# via cohere
typing-extensions==4.14.1
# via
# -r requirements/torch-cuda-requirements.txt
# aiosignal
# anyio
# cohere
# fastapi
# huggingface-hub
# minio
# openai
# pydantic
# pydantic-core
# sentence-transformers
@@ -383,7 +355,6 @@ urllib3==2.5.0
# botocore
# minio
# requests
# types-requests
uvicorn==0.35.0
# via -r requirements/torch-cuda-requirements.txt
wasabi==1.1.3
@@ -395,8 +366,6 @@ weasel==0.4.1
# via spacy
wrapt==1.17.2
# via smart-open
yarl==1.20.1
# via aiohttp

# The following packages are considered to be unsafe in a requirements file:
# setuptools