Skip to content

Commit ddb93ca

Browse files
authored
feat: add updated batch inference types (#220)
See llamastack/llama-stack#1945
1 parent cc072c8 commit ddb93ca

34 files changed

+840
-1029
lines changed

src/llama_stack_client/_client.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
benchmarks,
4242
toolgroups,
4343
vector_dbs,
44-
batch_inference,
4544
scoring_functions,
4645
synthetic_data_generation,
4746
)
@@ -74,7 +73,6 @@ class LlamaStackClient(SyncAPIClient):
7473
tools: tools.ToolsResource
7574
tool_runtime: tool_runtime.ToolRuntimeResource
7675
agents: agents.AgentsResource
77-
batch_inference: batch_inference.BatchInferenceResource
7876
datasets: datasets.DatasetsResource
7977
eval: eval.EvalResource
8078
inspect: inspect.InspectResource
@@ -155,7 +153,6 @@ def __init__(
155153
self.tools = tools.ToolsResource(self)
156154
self.tool_runtime = tool_runtime.ToolRuntimeResource(self)
157155
self.agents = agents.AgentsResource(self)
158-
self.batch_inference = batch_inference.BatchInferenceResource(self)
159156
self.datasets = datasets.DatasetsResource(self)
160157
self.eval = eval.EvalResource(self)
161158
self.inspect = inspect.InspectResource(self)
@@ -288,7 +285,6 @@ class AsyncLlamaStackClient(AsyncAPIClient):
288285
tools: tools.AsyncToolsResource
289286
tool_runtime: tool_runtime.AsyncToolRuntimeResource
290287
agents: agents.AsyncAgentsResource
291-
batch_inference: batch_inference.AsyncBatchInferenceResource
292288
datasets: datasets.AsyncDatasetsResource
293289
eval: eval.AsyncEvalResource
294290
inspect: inspect.AsyncInspectResource
@@ -369,7 +365,6 @@ def __init__(
369365
self.tools = tools.AsyncToolsResource(self)
370366
self.tool_runtime = tool_runtime.AsyncToolRuntimeResource(self)
371367
self.agents = agents.AsyncAgentsResource(self)
372-
self.batch_inference = batch_inference.AsyncBatchInferenceResource(self)
373368
self.datasets = datasets.AsyncDatasetsResource(self)
374369
self.eval = eval.AsyncEvalResource(self)
375370
self.inspect = inspect.AsyncInspectResource(self)
@@ -503,7 +498,6 @@ def __init__(self, client: LlamaStackClient) -> None:
503498
self.tools = tools.ToolsResourceWithRawResponse(client.tools)
504499
self.tool_runtime = tool_runtime.ToolRuntimeResourceWithRawResponse(client.tool_runtime)
505500
self.agents = agents.AgentsResourceWithRawResponse(client.agents)
506-
self.batch_inference = batch_inference.BatchInferenceResourceWithRawResponse(client.batch_inference)
507501
self.datasets = datasets.DatasetsResourceWithRawResponse(client.datasets)
508502
self.eval = eval.EvalResourceWithRawResponse(client.eval)
509503
self.inspect = inspect.InspectResourceWithRawResponse(client.inspect)
@@ -531,7 +525,6 @@ def __init__(self, client: AsyncLlamaStackClient) -> None:
531525
self.tools = tools.AsyncToolsResourceWithRawResponse(client.tools)
532526
self.tool_runtime = tool_runtime.AsyncToolRuntimeResourceWithRawResponse(client.tool_runtime)
533527
self.agents = agents.AsyncAgentsResourceWithRawResponse(client.agents)
534-
self.batch_inference = batch_inference.AsyncBatchInferenceResourceWithRawResponse(client.batch_inference)
535528
self.datasets = datasets.AsyncDatasetsResourceWithRawResponse(client.datasets)
536529
self.eval = eval.AsyncEvalResourceWithRawResponse(client.eval)
537530
self.inspect = inspect.AsyncInspectResourceWithRawResponse(client.inspect)
@@ -561,7 +554,6 @@ def __init__(self, client: LlamaStackClient) -> None:
561554
self.tools = tools.ToolsResourceWithStreamingResponse(client.tools)
562555
self.tool_runtime = tool_runtime.ToolRuntimeResourceWithStreamingResponse(client.tool_runtime)
563556
self.agents = agents.AgentsResourceWithStreamingResponse(client.agents)
564-
self.batch_inference = batch_inference.BatchInferenceResourceWithStreamingResponse(client.batch_inference)
565557
self.datasets = datasets.DatasetsResourceWithStreamingResponse(client.datasets)
566558
self.eval = eval.EvalResourceWithStreamingResponse(client.eval)
567559
self.inspect = inspect.InspectResourceWithStreamingResponse(client.inspect)
@@ -591,7 +583,6 @@ def __init__(self, client: AsyncLlamaStackClient) -> None:
591583
self.tools = tools.AsyncToolsResourceWithStreamingResponse(client.tools)
592584
self.tool_runtime = tool_runtime.AsyncToolRuntimeResourceWithStreamingResponse(client.tool_runtime)
593585
self.agents = agents.AsyncAgentsResourceWithStreamingResponse(client.agents)
594-
self.batch_inference = batch_inference.AsyncBatchInferenceResourceWithStreamingResponse(client.batch_inference)
595586
self.datasets = datasets.AsyncDatasetsResourceWithStreamingResponse(client.datasets)
596587
self.eval = eval.AsyncEvalResourceWithStreamingResponse(client.eval)
597588
self.inspect = inspect.AsyncInspectResourceWithStreamingResponse(client.inspect)

src/llama_stack_client/_decoders/jsonl.py

Lines changed: 0 additions & 123 deletions
This file was deleted.

src/llama_stack_client/_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ def set_pydantic_config(typ: Any, config: pydantic.ConfigDict) -> None:
681681
setattr(typ, "__pydantic_config__", config) # noqa: B010
682682

683683

684-
# our use of subclasssing here causes weirdness for type checkers,
684+
# our use of subclassing here causes weirdness for type checkers,
685685
# so we just pretend that we don't subclass
686686
if TYPE_CHECKING:
687687
GenericModel = BaseModel

src/llama_stack_client/_response.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER
3131
from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
3232
from ._exceptions import LlamaStackClientError, APIResponseValidationError
33-
from ._decoders.jsonl import JSONLDecoder, AsyncJSONLDecoder
3433

3534
if TYPE_CHECKING:
3635
from ._models import FinalRequestOptions
@@ -139,27 +138,6 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T:
139138

140139
origin = get_origin(cast_to) or cast_to
141140

142-
if inspect.isclass(origin):
143-
if issubclass(cast(Any, origin), JSONLDecoder):
144-
return cast(
145-
R,
146-
cast("type[JSONLDecoder[Any]]", cast_to)(
147-
raw_iterator=self.http_response.iter_bytes(chunk_size=64),
148-
line_type=extract_type_arg(cast_to, 0),
149-
http_response=self.http_response,
150-
),
151-
)
152-
153-
if issubclass(cast(Any, origin), AsyncJSONLDecoder):
154-
return cast(
155-
R,
156-
cast("type[AsyncJSONLDecoder[Any]]", cast_to)(
157-
raw_iterator=self.http_response.aiter_bytes(chunk_size=64),
158-
line_type=extract_type_arg(cast_to, 0),
159-
http_response=self.http_response,
160-
),
161-
)
162-
163141
if self._is_sse_stream:
164142
if to:
165143
if not is_stream_class_type(to):

src/llama_stack_client/_utils/_transform.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
import pathlib
66
from typing import Any, Mapping, TypeVar, cast
77
from datetime import date, datetime
8-
from typing_extensions import Literal, get_args, override, get_type_hints
8+
from typing_extensions import Literal, get_args, override, get_type_hints as _get_type_hints
99

1010
import anyio
1111
import pydantic
1212

1313
from ._utils import (
1414
is_list,
15+
is_given,
16+
lru_cache,
1517
is_mapping,
1618
is_iterable,
1719
)
@@ -108,6 +110,7 @@ class Params(TypedDict, total=False):
108110
return cast(_T, transformed)
109111

110112

113+
@lru_cache(maxsize=8096)
111114
def _get_annotated_type(type_: type) -> type | None:
112115
"""If the given type is an `Annotated` type then it is returned, if not `None` is returned.
113116
@@ -126,7 +129,7 @@ def _get_annotated_type(type_: type) -> type | None:
126129
def _maybe_transform_key(key: str, type_: type) -> str:
127130
"""Transform the given `data` based on the annotations provided in `type_`.
128131
129-
Note: this function only looks at `Annotated` types that contain `PropertInfo` metadata.
132+
Note: this function only looks at `Annotated` types that contain `PropertyInfo` metadata.
130133
"""
131134
annotated_type = _get_annotated_type(type_)
132135
if annotated_type is None:
@@ -142,6 +145,10 @@ def _maybe_transform_key(key: str, type_: type) -> str:
142145
return key
143146

144147

148+
def _no_transform_needed(annotation: type) -> bool:
149+
return annotation == float or annotation == int
150+
151+
145152
def _transform_recursive(
146153
data: object,
147154
*,
@@ -184,6 +191,15 @@ def _transform_recursive(
184191
return cast(object, data)
185192

186193
inner_type = extract_type_arg(stripped_type, 0)
194+
if _no_transform_needed(inner_type):
195+
# for some types there is no need to transform anything, so we can get a small
196+
# perf boost from skipping that work.
197+
#
198+
# but we still need to convert to a list to ensure the data is json-serializable
199+
if is_list(data):
200+
return data
201+
return list(data)
202+
187203
return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
188204

189205
if is_union_type(stripped_type):
@@ -245,6 +261,11 @@ def _transform_typeddict(
245261
result: dict[str, object] = {}
246262
annotations = get_type_hints(expected_type, include_extras=True)
247263
for key, value in data.items():
264+
if not is_given(value):
265+
# we don't need to include `NotGiven` values here as they'll
266+
# be stripped out before the request is sent anyway
267+
continue
268+
248269
type_ = annotations.get(key)
249270
if type_ is None:
250271
# we do not have a type annotation for this field, leave it as is
@@ -332,6 +353,15 @@ async def _async_transform_recursive(
332353
return cast(object, data)
333354

334355
inner_type = extract_type_arg(stripped_type, 0)
356+
if _no_transform_needed(inner_type):
357+
# for some types there is no need to transform anything, so we can get a small
358+
# perf boost from skipping that work.
359+
#
360+
# but we still need to convert to a list to ensure the data is json-serializable
361+
if is_list(data):
362+
return data
363+
return list(data)
364+
335365
return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
336366

337367
if is_union_type(stripped_type):
@@ -393,10 +423,25 @@ async def _async_transform_typeddict(
393423
result: dict[str, object] = {}
394424
annotations = get_type_hints(expected_type, include_extras=True)
395425
for key, value in data.items():
426+
if not is_given(value):
427+
# we don't need to include `NotGiven` values here as they'll
428+
# be stripped out before the request is sent anyway
429+
continue
430+
396431
type_ = annotations.get(key)
397432
if type_ is None:
398433
# we do not have a type annotation for this field, leave it as is
399434
result[key] = value
400435
else:
401436
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
402437
return result
438+
439+
440+
@lru_cache(maxsize=8096)
441+
def get_type_hints(
442+
obj: Any,
443+
globalns: dict[str, Any] | None = None,
444+
localns: Mapping[str, Any] | None = None,
445+
include_extras: bool = False,
446+
) -> dict[str, Any]:
447+
return _get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras)

src/llama_stack_client/_utils/_typing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
get_origin,
1414
)
1515

16+
from ._utils import lru_cache
1617
from .._types import InheritsGeneric
1718
from .._compat import is_union as _is_union
1819

@@ -66,6 +67,7 @@ def is_type_alias_type(tp: Any, /) -> TypeIs[typing_extensions.TypeAliasType]:
6667

6768

6869
# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
70+
@lru_cache(maxsize=8096)
6971
def strip_annotated_type(typ: type) -> type:
7072
if is_required_type(typ) or is_annotated_type(typ):
7173
return strip_annotated_type(cast(type, get_args(typ)[0]))

src/llama_stack_client/lib/inference/event_logger.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,24 @@ def print(self, flush=True):
2222

2323

2424
class InferenceStreamLogEventPrinter:
25+
def __init__(self):
26+
self.is_thinking = False
27+
2528
def yield_printable_events(self, chunk):
2629
event = chunk.event
2730
if event.event_type == "start":
2831
yield InferenceStreamPrintableEvent("Assistant> ", color="cyan", end="")
2932
elif event.event_type == "progress":
30-
yield InferenceStreamPrintableEvent(event.delta.text, color="yellow", end="")
33+
if event.delta.type == "reasoning":
34+
if not self.is_thinking:
35+
yield InferenceStreamPrintableEvent("<thinking> ", color="magenta", end="")
36+
self.is_thinking = True
37+
yield InferenceStreamPrintableEvent(event.delta.reasoning, color="magenta", end="")
38+
else:
39+
if self.is_thinking:
40+
yield InferenceStreamPrintableEvent("</thinking>", color="magenta", end="")
41+
self.is_thinking = False
42+
yield InferenceStreamPrintableEvent(event.delta.text, color="yellow", end="")
3143
elif event.event_type == "complete":
3244
yield InferenceStreamPrintableEvent("")
3345

Comments: 0 commit comments