Commit ab21429

[RAPTOR-12838] add streaming to chat example (#1517)
1 parent e50595d commit ab21429

File tree

4 files changed: +157 -77 lines changed
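
This commit makes the python3_dummy_chat template honor the OpenAI stream flag, yielding one ChatCompletionChunk per whitespace-delimited token instead of a single ChatCompletion. A minimal client-side sketch of consuming the new behavior, assuming a DRUM prediction server is already listening on localhost:8080 (the address the test fixtures below use) and "<KEY>" stands in for any API key:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080", api_key="<KEY>")
response = client.chat.completions.create(
    model="datarobot-deployed-llm",  # the dummy hook echoes regardless of model
    messages=[{"role": "user", "content": "Tell me a story"}],
    stream=True,  # request ChatCompletionChunk objects instead of one ChatCompletion
)
for chunk in response:
    # the final chunk carries finish_reason="stop" and an empty delta
    if chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end=" ")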

model_templates/python3_dummy_chat/custom.py

Lines changed: 58 additions & 27 deletions
@@ -6,22 +6,19 @@
 """
 import calendar
 import time
-from typing import Iterator
+from typing import Any
+import uuid
 
 from openai.types.chat import ChatCompletion
 from openai.types.chat import ChatCompletionChunk
 from openai.types.chat import ChatCompletionMessage
 from openai.types.chat import CompletionCreateParams
 from openai.types.chat.chat_completion import Choice
+from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
+from openai.types.chat.chat_completion_chunk import ChoiceDelta
 from openai.types.model import Model
 
-from datarobot_drum import RuntimeParameters
-
-"""
-This example shows how to create a text generation model supporting OpenAI chat
-"""
-
-from typing import Any, Dict
+# This example shows how to create a text generation model supporting OpenAI chat
 
 
 def get_supported_llm_models(model: Any):
@@ -35,10 +32,11 @@ def get_supported_llm_models(model: Any):
     ----------
     model: a model ID to compare against; optional
 
-    Returns: list of openai.types.model.Model
+    Returns
     -------
-
+    List of openai.types.model.Model
     """
+    _ = model
     return [
         Model(
             id="datarobot_llm_id",
@@ -62,29 +60,62 @@ def load_model(code_dir: str) -> Any:
     -------
     If used, this hook must return a non-None value
     """
+    _ = code_dir
     return "dummy"
 
 
-def chat(
-    completion_create_params: CompletionCreateParams, model: Any
-) -> ChatCompletion | Iterator[ChatCompletionChunk]:
+def chat(completion_create_params: CompletionCreateParams, model: Any):
     """
-    This hook supports chat completions; see https://platform.openai.com/docs/api-reference/chat/create.
-    In this non-streaming example, the "LLM" echoes back the user's prompt,
-    acting as the model specified in the chat completion request.
-
-    Parameters
-    ----------
-    completion_create_params: the chat completion request.
-    model: the deserialized model loaded by DRUM or by `load_model`, if supplied
-
-    Returns: a chat completion.
-    -------
-
+    This hook supports chat completions;
+    see https://platform.openai.com/docs/api-reference/chat/create.
+    In this example, the "LLM" echoes back the user's prompt,
+    acting as the model specified in the chat completion request.
+    If streaming is requested, yields ChatCompletionChunk objects
+    for each "token" (word) in the response.
+    Returns ChatCompletion or Iterator[ChatCompletionChunk]
     """
-    model = completion_create_params["model"]
+    _ = model
+    inter_token_latency_seconds = 0.25
+    model_id = completion_create_params["model"]
     message_content = "Echo: " + completion_create_params["messages"][0]["content"]
+    stream = completion_create_params.get("stream", False)
+
+    if stream:
+        # Mimic OpenAI streaming: yield one chunk at a time, split by whitespace
+        def gen_chunks():
+            chunk_id = str(uuid.uuid4())
+            for token in message_content.split():
+                yield ChatCompletionChunk(
+                    id=chunk_id,
+                    object="chat.completion.chunk",
+                    created=calendar.timegm(time.gmtime()),
+                    model=model_id,
+                    choices=[
+                        ChunkChoice(
+                            finish_reason=None,
+                            index=0,
+                            delta=ChoiceDelta(content=token),
+                        )
+                    ],
+                )
+                time.sleep(inter_token_latency_seconds)
+            # Send a final chunk with finish_reason
+            yield ChatCompletionChunk(
+                id=chunk_id,
+                object="chat.completion.chunk",
+                created=calendar.timegm(time.gmtime()),
+                model=model_id,
+                choices=[
+                    ChunkChoice(
+                        finish_reason="stop",
+                        index=0,
+                        delta=ChoiceDelta(),
+                    )
+                ],
+            )
 
+        return gen_chunks()
+    # non-streaming
     return ChatCompletion(
         id="association_id",
         choices=[
@@ -95,6 +126,6 @@ def chat(
             )
         ],
         created=calendar.timegm(time.gmtime()),
-        model=model,
+        model=model_id,
         object="chat.completion",
     )
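
The hook can also be exercised without a server. A quick sketch, assuming a plain dict is an acceptable stand-in for CompletionCreateParams (it is a TypedDict, and the hook only reads the keys shown):

from model_templates.python3_dummy_chat.custom import chat, load_model

model = load_model(".")  # returns the "dummy" placeholder
params = {
    "model": "datarobot-deployed-llm",
    "messages": [{"role": "user", "content": "Tell me a story"}],
    "stream": True,
}
for chunk in chat(params, model):
    print(chunk.choices[0].delta.content)  # None on the final "stop" chunk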

tests/unit/datarobot_drum/drum/conftest.py

Lines changed: 52 additions & 3 deletions
@@ -1,11 +1,17 @@
+import os
 from typing import Optional
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
+import httpx
 import pytest
-
-from datarobot_drum.drum.enum import CustomHooks
+from httpx import WSGITransport
+from openai import OpenAI
 
 from datarobot_drum.drum.adapters.model_adapters.python_model_adapter import PythonModelAdapter
+from datarobot_drum.drum.enum import CustomHooks, RunLanguage, TargetType
+from datarobot_drum.drum.lazy_loading.lazy_loading_handler import LazyLoadingHandler
+from datarobot_drum.drum.root_predictors.prediction_server import PredictionServer
+from datarobot_drum.drum.server import create_flask_app
 from tests.unit.datarobot_drum.drum.helpers import MODEL_ID_FROM_RUNTIME_PARAMETER
 from tests.unit.datarobot_drum.drum.helpers import inject_runtime_parameter
 from tests.unit.datarobot_drum.drum.helpers import unset_runtime_parameter
@@ -114,3 +120,46 @@ def llm_id_parameter():
     inject_runtime_parameter(parameter_name, MODEL_ID_FROM_RUNTIME_PARAMETER)
     yield
     unset_runtime_parameter(parameter_name)
+
+
+@pytest.fixture
+def test_flask_app():
+    with patch("datarobot_drum.drum.server.create_flask_app") as mock_create_flask_app, patch(
+        "datarobot_drum.drum.root_predictors.prediction_server.PredictionServer._run_flask_app"
+    ):
+        app = create_flask_app()
+        app.config.update(
+            {
+                "TESTING": True,
+            }
+        )
+
+        mock_create_flask_app.return_value = app
+
+        yield app
+
+
+@pytest.fixture
+def openai_client(test_flask_app):
+    return OpenAI(
+        base_url="http://localhost:8080",
+        api_key="<KEY>",
+        http_client=httpx.Client(transport=WSGITransport(app=test_flask_app)),
+    )
+
+
+@pytest.fixture
+def prediction_server(test_flask_app, chat_python_model_adapter):
+    _, _ = test_flask_app, chat_python_model_adapter  # depends on fixture side effects
+    with patch.dict(os.environ, {"TARGET_NAME": "target"}), patch(
+        "datarobot_drum.drum.language_predictors.python_predictor.python_predictor.PythonPredictor._init_mlops"
+    ), patch.object(LazyLoadingHandler, "download_lazy_loading_files"):
+        params = {
+            "run_language": RunLanguage.PYTHON,
+            "target_type": TargetType.TEXT_GENERATION,
+            "deployment_config": None,
+            "__custom_model_path__": "/non-existing-path-to-avoid-loading-unwanted-artifacts",
+        }
+        server = PredictionServer(params)
+        server._predictor._mlops = Mock()
+        server.materialize()
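
Note that the openai_client fixture never opens a socket: httpx's WSGITransport hands each request directly to the Flask app in-process, so the base_url host is never resolved. A standalone sketch of the same pattern, with a hypothetical /ping route for illustration:

import httpx
from flask import Flask

app = Flask(__name__)

@app.get("/ping")  # hypothetical route, not part of DRUM
def ping():
    return {"message": "OK"}

# WSGITransport dispatches requests to the app in-process; no network I/O happens
client = httpx.Client(transport=httpx.WSGITransport(app=app), base_url="http://testserver")
assert client.get("/ping").json() == {"message": "OK"}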
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+"""
+Copyright 2025 DataRobot, Inc. and its affiliates.
+All rights reserved.
+This is proprietary source code of DataRobot, Inc. and its affiliates.
+Released under the terms of DataRobot Tool and Utility Agreement.
+"""
+
+import pytest
+from openai import Stream
+from openai.types.chat import ChatCompletion
+
+# This module tests score and chat hooks from selected model templates;
+# not by direct function call, but via testing flask app, prediction server, and model adapter.
+# Rename the imported hooks to keep them distinct
+from model_templates.python3_dummy_chat.custom import chat as dummy_chat_chat
+
+# The particular model usually doesn't matter for the example hooks
+CHAT_COMPLETIONS_MODEL = "datarobot-deployed-llm"
+
+
+@pytest.mark.usefixtures("prediction_server")
+@pytest.mark.parametrize("is_streaming", [True, False])
+def test_dummy_chat_chat(openai_client, chat_python_model_adapter, is_streaming):
+    """Test the "python3 dummy chat" hook."""
+    chat_python_model_adapter.chat_hook = dummy_chat_chat
+    prompt = "Tell me a story"
+
+    completion = openai_client.chat.completions.create(
+        model=CHAT_COMPLETIONS_MODEL,
+        messages=[
+            {"role": "user", "content": prompt},
+        ],
+        stream=is_streaming,
+    )
+
+    if is_streaming:
+        assert isinstance(completion, Stream)
+        chunk_messages = [
+            chunk.choices[0].delta.content for chunk in completion if chunk.choices[0].delta.content
+        ]
+        expected_messages = ["Echo:"] + prompt.split()
+        assert chunk_messages == expected_messages
+    else:
+        assert isinstance(completion, ChatCompletion)
+        assert completion.choices[0].message.content == "Echo: " + prompt

tests/unit/datarobot_drum/drum/test_prediction_server.py

Lines changed: 2 additions & 47 deletions
@@ -3,12 +3,10 @@
 from unittest.mock import ANY
 from unittest.mock import Mock, patch
 
-import httpx
 import openai
 import pytest
-from httpx import WSGITransport
 from openai import NotFoundError
-from openai import OpenAI, Stream
+from openai import Stream
 from openai.types.chat import (
     ChatCompletion,
 )
@@ -18,45 +16,11 @@
 from datarobot_drum.drum.enum import RunLanguage, TargetType
 from datarobot_drum.drum.lazy_loading.lazy_loading_handler import LazyLoadingHandler
 from datarobot_drum.drum.root_predictors.prediction_server import PredictionServer
-from datarobot_drum.drum.server import create_flask_app, HEADER_REQUEST_ID
-from datarobot_drum.drum.server import get_flask_app
+from datarobot_drum.drum.server import HEADER_REQUEST_ID
 from tests.unit.datarobot_drum.drum.chat_utils import create_completion, create_completion_chunks
 from tests.unit.datarobot_drum.drum.helpers import MODEL_ID_FROM_RUNTIME_PARAMETER
 
 
-@pytest.fixture
-def test_flask_app():
-    with patch("datarobot_drum.drum.server.create_flask_app") as mock_create_flask_app, patch(
-        "datarobot_drum.drum.root_predictors.prediction_server.PredictionServer._run_flask_app"
-    ):
-        app = create_flask_app()
-        app.config.update(
-            {
-                "TESTING": True,
-            }
-        )
-
-        mock_create_flask_app.return_value = app
-
-        yield app
-
-
-@pytest.fixture
-def prediction_server(test_flask_app, chat_python_model_adapter):
-    with patch.dict(os.environ, {"TARGET_NAME": "target"}), patch(
-        "datarobot_drum.drum.language_predictors.python_predictor.python_predictor.PythonPredictor._init_mlops"
-    ), patch.object(LazyLoadingHandler, "download_lazy_loading_files"):
-        params = {
-            "run_language": RunLanguage.PYTHON,
-            "target_type": TargetType.TEXT_GENERATION,
-            "deployment_config": None,
-            "__custom_model_path__": "/non-existing-path-to-avoid-loading-unwanted-artifacts",
-        }
-        server = PredictionServer(params)
-        server._predictor._mlops = Mock()
-        server.materialize()
-
-
 @pytest.fixture
 def list_models_prediction_server(test_flask_app, list_models_python_model_adapter):
     with patch.dict(os.environ, {"TARGET_NAME": "target"}), patch(
@@ -89,15 +53,6 @@ def non_textgen_prediction_server(test_flask_app, non_chat_python_model_adapter)
         server.materialize()
 
 
-@pytest.fixture
-def openai_client(test_flask_app):
-    return OpenAI(
-        base_url="http://localhost:8080",
-        api_key="<KEY>",
-        http_client=httpx.Client(transport=WSGITransport(app=test_flask_app)),
-    )
-
-
 @pytest.mark.usefixtures("prediction_server")
 def test_prediction_server(openai_client, chat_python_model_adapter):
     def chat_hook(completion_request, model):
