|
3 | 3 | import os |
4 | 4 | from collections.abc import Iterator |
5 | 5 | from pathlib import Path |
| 6 | +from typing import Literal |
6 | 7 |
|
7 | 8 | import pytest |
8 | 9 | import urllib3 |
@@ -31,15 +32,23 @@ def org_id() -> str: |
31 | 32 |
|
32 | 33 |
|
33 | 34 | @pytest.fixture(scope="session") |
34 | | -def api_client(api_token: str) -> Iterator[ApiClient]: |
| 35 | +def environment() -> Literal["prod", "dev", "local", "staging"]: |
35 | 36 | env = os.getenv("VECTORIZE_ENV", "prod") |
| 37 | + if env not in ["prod", "dev", "local", "staging"]: |
| 38 | + msg = "Invalid VECTORIZE_ENV environment variable." |
| 39 | + raise ValueError(msg) |
| 40 | + return env |
| 41 | + |
| 42 | + |
| 43 | +@pytest.fixture(scope="session") |
| 44 | +def api_client(api_token: str, environment: str) -> Iterator[ApiClient]: |
36 | 45 | header_name = None |
37 | 46 | header_value = None |
38 | | - if env == "prod": |
| 47 | + if environment == "prod": |
39 | 48 | host = "https://api.vectorize.io/v1" |
40 | | - elif env == "dev": |
| 49 | + elif environment == "dev": |
41 | 50 | host = "https://api-dev.vectorize.io/v1" |
42 | | - elif env == "local": |
| 51 | + elif environment == "local": |
43 | 52 | host = "http://localhost:3000/api" |
44 | 53 | header_name = "x-lambda-api-key" |
45 | 54 | header_value = api_token |
@@ -148,16 +157,30 @@ def pipeline_id(api_client: v.ApiClient, org_id: str) -> Iterator[str]: |
148 | 157 | logging.exception("Failed to delete pipeline %s", pipeline_id) |
149 | 158 |
|
150 | 159 |
|
151 | | -def test_retrieve_init_args(api_token: str, org_id: str, pipeline_id: str) -> None: |
| 160 | +def test_retrieve_init_args( |
| 161 | + environment: Literal["prod", "dev", "local", "staging"], |
| 162 | + api_token: str, |
| 163 | + org_id: str, |
| 164 | + pipeline_id: str, |
| 165 | +) -> None: |
152 | 166 | retriever = VectorizeRetriever( |
153 | | - api_token=api_token, organization=org_id, pipeline_id=pipeline_id, num_results=2 |
| 167 | + environment=environment, |
| 168 | + api_token=api_token, |
| 169 | + organization=org_id, |
| 170 | + pipeline_id=pipeline_id, |
| 171 | + num_results=2, |
154 | 172 | ) |
155 | 173 | docs = retriever.invoke(input="What are you?") |
156 | 174 | assert len(docs) == 2 |
157 | 175 |
|
158 | 176 |
|
159 | | -def test_retrieve_invoke_args(api_token: str, org_id: str, pipeline_id: str) -> None: |
160 | | - retriever = VectorizeRetriever(api_token=api_token) |
| 177 | +def test_retrieve_invoke_args( |
| 178 | + environment: Literal["prod", "dev", "local", "staging"], |
| 179 | + api_token: str, |
| 180 | + org_id: str, |
| 181 | + pipeline_id: str, |
| 182 | +) -> None: |
| 183 | + retriever = VectorizeRetriever(environment=environment, api_token=api_token) |
161 | 184 | docs = retriever.invoke( |
162 | 185 | input="What are you?", |
163 | 186 | organization=org_id, |
|
0 commit comments