diff --git a/src/llama_stack_client/_client.py b/src/llama_stack_client/_client.py index 1f0c1d4e..718407cc 100644 --- a/src/llama_stack_client/_client.py +++ b/src/llama_stack_client/_client.py @@ -103,6 +103,7 @@ def __init__( self, *, base_url: str | httpx.URL | None = None, + api_key: str | None = None, timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, @@ -125,17 +126,22 @@ def __init__( """Construct a new synchronous llama-stack-client client instance.""" if base_url is None: base_url = os.environ.get("LLAMA_STACK_CLIENT_BASE_URL") + if api_key is None: + api_key = os.environ.get("LLAMA_STACK_CLIENT_API_KEY") if base_url is None: base_url = f"http://any-hosted-llama-stack.com" custom_headers = default_headers or {} custom_headers["X-LlamaStack-Client-Version"] = __version__ + if api_key is not None: + custom_headers["Authorization"] = f"Bearer {api_key}" if provider_data is not None: custom_headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data) super().__init__( version=__version__, base_url=base_url, + api_key=api_key, max_retries=max_retries, timeout=timeout, http_client=http_client, @@ -188,6 +194,7 @@ def copy( self, *, base_url: str | httpx.URL | None = None, + api_key: str | None = None, timeout: float | Timeout | None | NotGiven = NOT_GIVEN, http_client: httpx.Client | None = None, max_retries: int | NotGiven = NOT_GIVEN, @@ -221,6 +228,7 @@ def copy( http_client = http_client or self._client return self.__class__( base_url=base_url or self.base_url, + api_key=api_key or self.api_key, timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, http_client=http_client, max_retries=max_retries if is_given(max_retries) else self.max_retries, @@ -300,6 +308,7 @@ def __init__( self, *, base_url: str | httpx.URL | None = None, + api_key: str | None = None, timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, @@ -322,17 +331,22 @@ def __init__( """Construct a new async llama-stack-client client instance.""" if base_url is None: base_url = os.environ.get("LLAMA_STACK_CLIENT_BASE_URL") + if api_key is None: + api_key = os.environ.get("LLAMA_STACK_CLIENT_API_KEY") if base_url is None: base_url = f"http://any-hosted-llama-stack.com" custom_headers = default_headers or {} custom_headers["X-LlamaStack-Client-Version"] = __version__ + if api_key is not None: + custom_headers["Authorization"] = f"Bearer {api_key}" if provider_data is not None: custom_headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data) super().__init__( version=__version__, base_url=base_url, + api_key=api_key, max_retries=max_retries, timeout=timeout, http_client=http_client, @@ -385,6 +399,7 @@ def copy( self, *, base_url: str | httpx.URL | None = None, + api_key: str | None = None, timeout: float | Timeout | None | NotGiven = NOT_GIVEN, http_client: httpx.AsyncClient | None = None, max_retries: int | NotGiven = NOT_GIVEN, @@ -418,6 +433,7 @@ def copy( http_client = http_client or self._client return self.__class__( base_url=base_url or self.base_url, + api_key=api_key or self.api_key, timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, http_client=http_client, max_retries=max_retries if is_given(max_retries) else self.max_retries, diff --git a/src/llama_stack_client/resources/inference.py b/src/llama_stack_client/resources/inference.py index 8971a921..c837a6e4 100644 --- a/src/llama_stack_client/resources/inference.py +++ b/src/llama_stack_client/resources/inference.py @@ -272,6 +272,8 @@ def chat_completion( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ChatCompletionResponse | Stream[ChatCompletionResponseStreamChunk]: + if stream: + extra_headers = {"Accept": "text/event-stream", **(extra_headers or {})} return self._post( "/v1/inference/chat-completion", body=maybe_transform( @@ -451,6 +453,8 @@ def completion( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> CompletionResponse | Stream[CompletionResponse]: + if stream: + extra_headers = {"Accept": "text/event-stream", **(extra_headers or {})} return self._post( "/v1/inference/completion", body=maybe_transform( @@ -751,6 +755,8 @@ async def chat_completion( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ChatCompletionResponse | AsyncStream[ChatCompletionResponseStreamChunk]: + if stream: + extra_headers = {"Accept": "text/event-stream", **(extra_headers or {})} return await self._post( "/v1/inference/chat-completion", body=await async_maybe_transform( @@ -930,6 +936,8 @@ async def completion( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> CompletionResponse | AsyncStream[CompletionResponse]: + if stream: + extra_headers = {"Accept": "text/event-stream", **(extra_headers or {})} return await self._post( "/v1/inference/completion", body=await async_maybe_transform(