Skip to content

Commit 1fb0002

Browse files
committed
Add tests in CI
1 parent 74f7d30 commit 1fb0002

File tree

2 files changed

+51
-10
lines changed

2 files changed

+51
-10
lines changed

langchain/langchain_vectorize/retrievers.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import TYPE_CHECKING, Any, Optional
5+
from typing import TYPE_CHECKING, Any, Literal, Optional
66

77
import vectorize_client
88
from langchain_core.documents import Document
@@ -40,6 +40,8 @@ class VectorizeRetriever(BaseRetriever):
4040

4141
api_token: str
4242
"""The Vectorize API token."""
43+
environment: Literal["prod", "dev", "local", "staging"] = "prod"
44+
"""The Vectorize API environment."""
4345
organization: Optional[str] = None # noqa: UP007
4446
"""The Vectorize organization ID."""
4547
pipeline_id: Optional[str] = None # noqa: UP007
@@ -55,7 +57,23 @@ class VectorizeRetriever(BaseRetriever):
5557

5658
@override
5759
def model_post_init(self, /, context: Any) -> None:
58-
api = ApiClient(Configuration(access_token=self.api_token))
60+
header_name = None
61+
header_value = None
62+
if self.environment == "prod":
63+
host = "https://api.vectorize.io/v1"
64+
elif self.environment == "dev":
65+
host = "https://api-dev.vectorize.io/v1"
66+
elif self.environment == "local":
67+
host = "http://localhost:3000/api"
68+
header_name = "x-lambda-api-key"
69+
header_value = self.api_token
70+
else:
71+
host = "https://api-staging.vectorize.io/v1"
72+
api = ApiClient(
73+
Configuration(host=host, access_token=self.api_token, debug=True),
74+
header_name,
75+
header_value,
76+
)
5977
self._pipelines = PipelinesApi(api)
6078

6179
@staticmethod

langchain/tests/test_retrievers.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
from collections.abc import Iterator
55
from pathlib import Path
6+
from typing import Literal
67

78
import pytest
89
import urllib3
@@ -31,15 +32,23 @@ def org_id() -> str:
3132

3233

3334
@pytest.fixture(scope="session")
34-
def api_client(api_token: str) -> Iterator[ApiClient]:
35+
def environment() -> Literal["prod", "dev", "local", "staging"]:
3536
env = os.getenv("VECTORIZE_ENV", "prod")
37+
if env not in ["prod", "dev", "local", "staging"]:
38+
msg = "Invalid VECTORIZE_ENV environment variable."
39+
raise ValueError(msg)
40+
return env
41+
42+
43+
@pytest.fixture(scope="session")
44+
def api_client(api_token: str, environment: str) -> Iterator[ApiClient]:
3645
header_name = None
3746
header_value = None
38-
if env == "prod":
47+
if environment == "prod":
3948
host = "https://api.vectorize.io/v1"
40-
elif env == "dev":
49+
elif environment == "dev":
4150
host = "https://api-dev.vectorize.io/v1"
42-
elif env == "local":
51+
elif environment == "local":
4352
host = "http://localhost:3000/api"
4453
header_name = "x-lambda-api-key"
4554
header_value = api_token
@@ -148,16 +157,30 @@ def pipeline_id(api_client: v.ApiClient, org_id: str) -> Iterator[str]:
148157
logging.exception("Failed to delete pipeline %s", pipeline_id)
149158

150159

151-
def test_retrieve_init_args(api_token: str, org_id: str, pipeline_id: str) -> None:
160+
def test_retrieve_init_args(
161+
environment: Literal["prod", "dev", "local", "staging"],
162+
api_token: str,
163+
org_id: str,
164+
pipeline_id: str,
165+
) -> None:
152166
retriever = VectorizeRetriever(
153-
api_token=api_token, organization=org_id, pipeline_id=pipeline_id, num_results=2
167+
environment=environment,
168+
api_token=api_token,
169+
organization=org_id,
170+
pipeline_id=pipeline_id,
171+
num_results=2,
154172
)
155173
docs = retriever.invoke(input="What are you?")
156174
assert len(docs) == 2
157175

158176

159-
def test_retrieve_invoke_args(api_token: str, org_id: str, pipeline_id: str) -> None:
160-
retriever = VectorizeRetriever(api_token=api_token)
177+
def test_retrieve_invoke_args(
178+
environment: Literal["prod", "dev", "local", "staging"],
179+
api_token: str,
180+
org_id: str,
181+
pipeline_id: str,
182+
) -> None:
183+
retriever = VectorizeRetriever(environment=environment, api_token=api_token)
161184
docs = retriever.invoke(
162185
input="What are you?",
163186
organization=org_id,

0 commit comments

Comments
 (0)