Skip to content

Commit af8c341

Browse files
authored
[MMM-19441] Missing chat hook error message for agentic workflow (#1439)
* message
* test
1 parent c27999c commit af8c341

File tree

2 files changed

+19
-19
lines changed

2 files changed

+19
-19
lines changed

custom_model_runner/datarobot_drum/drum/root_predictors/predict_mixin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ def do_chat(self, logger=None):
385385
# _predictor is a BaseLanguagePredictor attribute of PredictionServer;
386386
# see PredictionServer.__init__()
387387
if not self._predictor.supports_chat():
388-
if self._target_type == TargetType.TEXT_GENERATION:
388+
if self._target_type in [TargetType.TEXT_GENERATION, TargetType.AGENTIC_WORKFLOW]:
389389
message = undefined_chat_message
390390
else:
391391
message = unsupported_chat_message

tests/unit/datarobot_drum/drum/test_prediction_server.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -73,22 +73,6 @@ def list_models_prediction_server(test_flask_app, list_models_python_model_adapt
7373
server.materialize()
7474

7575

76-
@pytest.fixture
77-
def non_chat_prediction_server(test_flask_app, non_chat_python_model_adapter):
78-
with patch.dict(os.environ, {"TARGET_NAME": "target"}), patch(
79-
"datarobot_drum.drum.language_predictors.python_predictor.python_predictor.PythonPredictor._init_mlops"
80-
), patch.object(LazyLoadingHandler, "download_lazy_loading_files"):
81-
params = {
82-
"run_language": RunLanguage.PYTHON,
83-
"target_type": TargetType.TEXT_GENERATION,
84-
"deployment_config": None,
85-
"__custom_model_path__": "/non-existing-path-to-avoid-loading-unwanted-artifacts",
86-
}
87-
server = PredictionServer(params)
88-
server._predictor._mlops = Mock()
89-
server.materialize()
90-
91-
9276
@pytest.fixture
9377
def non_textgen_prediction_server(test_flask_app, non_chat_python_model_adapter):
9478
with patch.dict(os.environ, {"TARGET_NAME": "target"}), patch(
@@ -153,9 +137,25 @@ def test_prediction_server_chat_unsupported(openai_client):
153137
)
154138

155139

156-
@pytest.mark.usefixtures("non_chat_prediction_server")
157-
def test_prediction_server_chat_unimplemented(openai_client):
140+
@pytest.mark.parametrize("target_type", [TargetType.TEXT_GENERATION, TargetType.AGENTIC_WORKFLOW])
141+
def test_prediction_server_chat_unimplemented(
142+
test_flask_app, non_chat_python_model_adapter, openai_client, target_type
143+
):
158144
"""Attempt to chat when a textgen model does not implement chat()."""
145+
146+
with patch.dict(os.environ, {"TARGET_NAME": "target"}), patch(
147+
"datarobot_drum.drum.language_predictors.python_predictor.python_predictor.PythonPredictor._init_mlops"
148+
), patch.object(LazyLoadingHandler, "download_lazy_loading_files"):
149+
params = {
150+
"run_language": RunLanguage.PYTHON,
151+
"target_type": target_type,
152+
"deployment_config": None,
153+
"__custom_model_path__": "/non-existing-path-to-avoid-loading-unwanted-artifacts",
154+
}
155+
server = PredictionServer(params)
156+
server._predictor._mlops = Mock()
157+
server.materialize()
158+
159159
with pytest.raises(NotFoundError, match=r"but chat\(\) is not implemented"):
160160
_ = openai_client.chat.completions.create(
161161
model="any",

0 commit comments

Comments (0)