Commit 121a0c3

[RAPTOR-13895] Implement inline predictor based on DRUM score (#1504)
* Demo
* Add moderations as dep to tests
* type annotation
* Move moderations to a separate file and bump resources for harness
* No moderation test for now, it triples the duration of the test suite
* Proper skip
1 parent 03f9d14 commit 121a0c3

File tree

9 files changed: +336 -30 lines changed


custom_model_runner/datarobot_drum/drum/common.py

Lines changed: 12 additions & 0 deletions
@@ -4,6 +4,7 @@
 This is proprietary source code of DataRobot, Inc. and its affiliates.
 Released under the terms of DataRobot Tool and Utility Agreement.
 """
+
 import logging
 import os
 import sys
@@ -21,6 +22,8 @@
     PayloadFormat,
 )
 from datarobot_drum.drum.exceptions import DrumCommonException
+from datarobot_drum.drum.lazy_loading.lazy_loading_handler import LazyLoadingHandler
+from datarobot_drum.runtime_parameters.runtime_parameters import RuntimeParametersLoader
 from opentelemetry import trace, context
 from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
 from opentelemetry.sdk.resources import Resource
@@ -229,3 +232,12 @@ def extract_chat_response_attributes(response):
         # last completion wins
         attributes["gen_ai.completion"] = m.get("content")
     return attributes
+
+
+def setup_required_environment_variables(options):
+    if "runtime_params_file" in options and options.runtime_params_file:
+        loader = RuntimeParametersLoader(options.runtime_params_file, options.code_dir)
+        loader.setup_environment_variables()
+
+    if "lazy_loading_file" in options and options.lazy_loading_file:
+        LazyLoadingHandler.setup_environment_variables_from_values_file(options.lazy_loading_file)
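
As a quick illustration of the relocated helper (not part of the diff): a minimal sketch of calling it with a parsed-options object. The file paths and the use of argparse.Namespace are assumptions made for this example.

# Illustrative sketch only -- not part of this commit.
# argparse.Namespace supports the `"key" in options` membership checks used by the helper.
import argparse

from datarobot_drum.drum.common import setup_required_environment_variables

options = argparse.Namespace(
    runtime_params_file="runtime_params.yaml",  # hypothetical runtime-parameter values file
    code_dir="/path/to/custom_model",           # hypothetical custom model code directory
    lazy_loading_file=None,                     # skip lazy loading in this sketch
)
setup_required_environment_variables(options)   # exports runtime params as environment variables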

custom_model_runner/datarobot_drum/drum/drum.py

Lines changed: 7 additions & 5 deletions
@@ -803,10 +803,7 @@ def _prepare_prediction_server_or_batch_pipeline(self, run_language):
 
         return DrumUtils.render_file(functional_pipeline_filepath, replace_data)
 
-    def _run_predictions(self, stats_collector: Optional[StatsCollector] = None):
-        if self.run_mode not in [RunMode.SCORE, RunMode.SERVER]:
-            raise NotImplemented(f"The given run mode is supported here: {self.run_mode}")
-
+    def get_predictor_params(self):
         run_language = self._check_artifacts_and_get_run_language()
         infra_pipeline_str = self._prepare_prediction_server_or_batch_pipeline(run_language)
 
@@ -815,12 +812,17 @@ def _run_predictions(self, stats_collector: Optional[StatsCollector] = None):
             raise DrumCommonException("Pipeline is empty")
         if "arguments" not in pipeline["pipe"][0]:
             raise DrumCommonException("Arguments are missing in the pipeline")
+        return pipeline["pipe"][0]["arguments"]
+
+    def _run_predictions(self, stats_collector: Optional[StatsCollector] = None):
+        if self.run_mode not in [RunMode.SCORE, RunMode.SERVER]:
+            raise NotImplemented(f"The given run mode is supported here: {self.run_mode}")
 
         self.logger.info(
             f">>> Start {ArgumentsOptions.MAIN_COMMAND} in the {self.run_mode.value} mode"
         )
 
-        params = pipeline["pipe"][0]["arguments"]
+        params = self.get_predictor_params()
         predictor = None
         try:
             if stats_collector:
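
The point of splitting out get_predictor_params: callers outside the CLI score/server flow can reuse the same pipeline arguments to build a predictor, which is what the new inline predictor module below does. A hedged sketch (the runtime object is assumed to be an already-initialized DrumRuntime with parsed options):

# Illustrative sketch only -- mirrors how the inline predictor below reuses the refactored method.
from datarobot_drum.drum.drum import CMRunner
from datarobot_drum.drum.root_predictors.generic_predictor import GenericPredictorComponent

cm_runner = CMRunner(runtime)                  # runtime: assumed pre-built DrumRuntime with options set
params = cm_runner.get_predictor_params()      # the pipeline["pipe"][0]["arguments"] dict
component = GenericPredictorComponent(params)  # builds and exposes the language predictor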

custom_model_runner/datarobot_drum/drum/main.py

Lines changed: 12 additions & 22 deletions
@@ -4,6 +4,7 @@
 This is proprietary source code of DataRobot, Inc. and its affiliates.
 Released under the terms of DataRobot Tool and Utility Agreement.
 """
+
 from datarobot_drum.drum.lazy_loading.lazy_loading_handler import LazyLoadingHandler
 
 #!/usr/bin/env python3
@@ -43,12 +44,15 @@
 import sys
 
 from datarobot_drum.drum.args_parser import CMRunnerArgsRegistry
-from datarobot_drum.drum.common import config_logging, setup_tracer
+from datarobot_drum.drum.common import (
+    config_logging,
+    setup_tracer,
+    setup_required_environment_variables,
+)
 from datarobot_drum.drum.enum import RunMode
 from datarobot_drum.drum.enum import ExitCodes
 from datarobot_drum.drum.exceptions import DrumSchemaValidationException
 from datarobot_drum.drum.runtime import DrumRuntime
-from datarobot_drum.runtime_parameters.exceptions import RuntimeParameterException
 from datarobot_drum.runtime_parameters.runtime_parameters import (
     RuntimeParametersLoader,
     RuntimeParameters,
@@ -92,7 +96,12 @@ def signal_handler(sig, frame):
 
         options = arg_parser.parse_args()
         CMRunnerArgsRegistry.verify_options(options)
-        _setup_required_environment_variables(options)
+
+        try:
+            setup_required_environment_variables(options)
+        except Exception as exc:
+            print(str(exc))
+            exit(255)
 
         if RuntimeParameters.has("CUSTOM_MODEL_WORKERS"):
             options.max_workers = RuntimeParameters.get("CUSTOM_MODEL_WORKERS")
@@ -112,24 +121,5 @@ def signal_handler(sig, frame):
             sys.exit(ExitCodes.SCHEMA_VALIDATION_ERROR.value)
 
 
-def _setup_required_environment_variables(options):
-    if "runtime_params_file" in options and options.runtime_params_file:
-        try:
-            loader = RuntimeParametersLoader(options.runtime_params_file, options.code_dir)
-            loader.setup_environment_variables()
-        except RuntimeParameterException as exc:
-            print(str(exc))
-            exit(255)
-
-    if "lazy_loading_file" in options and options.lazy_loading_file:
-        try:
-            LazyLoadingHandler.setup_environment_variables_from_values_file(
-                options.lazy_loading_file
-            )
-        except Exception as exc:
-            print(str(exc))
-            exit(255)
-
-
 if __name__ == "__main__":
     main()
Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+"""
+Copyright 2025 DataRobot, Inc. and its affiliates.
+All rights reserved.
+This is proprietary source code of DataRobot, Inc. and its affiliates.
+Released under the terms of DataRobot Tool and Utility Agreement.
+
+Example:
+
+    import json
+
+    payload = json.loads(open("input.json", "r").read())
+    code_dir = (
+        '/datarobot-user-models/model_templates/python3_dummy_chat'
+    )
+
+    with drum_inline_predictor(target_type=TargetType.AGENTIC_WORKFLOW.value, custom_model_dir=code_dir,
+                               target_name='response') as predictor:
+        result = predictor.chat(payload)
+        print(result)
+
+"""
+
+import contextlib
+import os
+import tempfile
+from typing import Generator, List
+
+from datarobot_drum.drum.args_parser import CMRunnerArgsRegistry
+from datarobot_drum.drum.common import setup_required_environment_variables, setup_tracer
+from datarobot_drum.drum.drum import CMRunner
+from datarobot_drum.drum.language_predictors.base_language_predictor import BaseLanguagePredictor
+from datarobot_drum.drum.runtime import DrumRuntime
+from datarobot_drum.drum.root_predictors.generic_predictor import GenericPredictorComponent
+from datarobot_drum.runtime_parameters.runtime_parameters import RuntimeParameters
+
+
+@contextlib.contextmanager
+def drum_inline_predictor(
+    target_type: str, custom_model_dir: str, target_name: str, *cmd_args: List[str]
+) -> Generator[BaseLanguagePredictor, None, None]:
+    """
+    Drum run for a custom model code definition. Yields a predictor, ready to work with.
+    Caller can work with the predictor directly.
+
+    :param target_type: Target type.
+    :param custom_model_dir: Directory where the custom model code artifacts are located.
+    :param target_name: Name of the target
+    :param cmd_args: Extra command line arguments
+    :return:
+    """
+    with DrumRuntime() as runtime, tempfile.NamedTemporaryFile(mode="wb") as tf:
+        # setup
+
+        os.environ["TARGET_NAME"] = target_name
+        arg_parser = CMRunnerArgsRegistry.get_arg_parser()
+        CMRunnerArgsRegistry.extend_sys_argv_with_env_vars()
+        args = [
+            "score",
+            "--code-dir",
+            custom_model_dir,
+            # regular score is actually a CLI thing, so it expects input/output;
+            # we can ignore these as we hand over the predictor directly to the caller to do I/O.
+            "--input",
+            tf.name,
+            "--output",
+            tf.name,
+            "--target-type",
+            target_type,
+            *cmd_args,
+        ]
+        options = arg_parser.parse_args(args)
+        CMRunnerArgsRegistry.verify_options(options)
+        setup_required_environment_variables(options)
+
+        runtime.options = options
+        setup_tracer(RuntimeParameters, options)
+        runtime.cm_runner = CMRunner(runtime)
+        params = runtime.cm_runner.get_predictor_params()
+        predictor = GenericPredictorComponent(params)
+
+        yield predictor.predictor
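
Beyond the docstring example above, a hedged end-to-end sketch of driving the new context manager against a chat-capable model such as the echo custom.py added later in this commit. The import path for drum_inline_predictor is assumed (the new file's path is not shown here), and the code directory and messages are placeholders.

# Illustrative sketch only -- import path, code_dir, and payload are assumptions.
from datarobot_drum.drum.enum import TargetType
from datarobot_drum.drum.root_predictors.drum_inline_utils import drum_inline_predictor  # assumed module path

code_dir = "/path/to/python3_dummy_chat"  # placeholder custom model directory
payload = {
    "model": "datarobot_llm_id",
    "messages": [{"role": "user", "content": "Hello, DRUM"}],
}

with drum_inline_predictor(
    target_type=TargetType.TEXT_GENERATION.value,
    custom_model_dir=code_dir,
    target_name="response",
) as predictor:
    completion = predictor.chat(payload)  # assumes the hook returns a non-streaming ChatCompletion
    print(completion.choices[0].message.content)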

custom_model_runner/datarobot_drum/drum/root_predictors/generic_predictor.py

Lines changed: 4 additions & 2 deletions
@@ -5,7 +5,6 @@
 Released under the terms of DataRobot Tool and Utility Agreement.
 """
 import urllib
-from typing import Optional
 
 import werkzeug
 from datarobot_drum.drum.adapters.cli.drum_score_adapter import DrumScoreAdapter
@@ -39,6 +38,10 @@ def __init__(self, params: dict):
         )
         self._predictor = self._setup_predictor()
 
+    @property
+    def predictor(self):
+        return self._predictor
+
     def _setup_predictor(self):
         if self._run_language == RunLanguage.PYTHON:
             from datarobot_drum.drum.language_predictors.python_predictor.python_predictor import (
@@ -92,7 +95,6 @@ def _setup_predictor(self):
 
     def materialize(self):
         output_filename = self._params.get("output_filename")
-
         if self.cli_adapter.target_type == TargetType.UNSTRUCTURED:
             # TODO: add support to use cli_adapter for unstructured
             return self._materialize_unstructured(

requirements_test.txt

Lines changed: 2 additions & 1 deletion
@@ -9,5 +9,6 @@ retry
 scikit-learn==1.3.2
 scipy>=1.1,<2
 urllib3>=1.25.0,<2.0.0
+openai>=1.55.3
 # strictly not needed for testing, but used when updating environment
-bson
+bson
Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+"""
+Copyright 2025 DataRobot, Inc. and its affiliates.
+All rights reserved.
+This is proprietary source code of DataRobot, Inc. and its affiliates.
+Released under the terms of DataRobot Tool and Utility Agreement.
+"""
+import calendar
+import time
+from typing import Iterator
+
+from openai.types.chat import ChatCompletion
+from openai.types.chat import ChatCompletionChunk
+from openai.types.chat import ChatCompletionMessage
+from openai.types.chat import CompletionCreateParams
+from openai.types.chat.chat_completion import Choice
+from openai.types.model import Model
+
+from datarobot_drum import RuntimeParameters
+
+"""
+This example shows how to create a text generation model supporting OpenAI chat
+"""
+
+from typing import Any, Dict
+
+
+def get_supported_llm_models(model: Any):
+    """
+    Return a list of supported LLM models; response to /v1/models and OpenAI models.list().
+    If custom.py does not define this function, DRUM will return a list of either:
+    * the model defined in the LLM_ID runtime parameter, if that exists, or:
+    * an empty list
+
+    Parameters
+    ----------
+    model: a model ID to compare against; optional
+
+    Returns: list of openai.types.model.Model
+    -------
+
+    """
+    return [
+        Model(
+            id="datarobot_llm_id",
+            created=1744854432,
+            object="model",
+            owned_by="tester@datarobot.com",
+        )
+    ]
+
+
+def load_model(code_dir: str) -> Any:
+    """
+    Can be used to load supported models if your model has multiple artifacts, or for loading
+    models that **drum** does not natively support
+
+    Parameters
+    ----------
+    code_dir : is the directory where model artifact and additional code are provided, passed in
+
+    Returns
+    -------
+    If used, this hook must return a non-None value
+    """
+    return "dummy"
+
+
+def chat(
+    completion_create_params: CompletionCreateParams, model: Any
+) -> ChatCompletion | Iterator[ChatCompletionChunk]:
+    """
+    This hook supports chat completions; see https://platform.openai.com/docs/api-reference/chat/create.
+    In this non-streaming example, the "LLM" echoes back the user's prompt,
+    acting as the model specified in the chat completion request.
+
+    Parameters
+    ----------
+    completion_create_params: the chat completion request.
+    model: the deserialized model loaded by DRUM or by `load_model`, if supplied
+
+    Returns: a chat completion.
+    -------
+
+    """
+    model = completion_create_params["model"]
+    message_content = "Echo: " + completion_create_params["messages"][0]["content"]
+
+    return ChatCompletion(
+        id="association_id",
+        choices=[
+            Choice(
+                finish_reason="stop",
+                index=0,
+                message=ChatCompletionMessage(role="assistant", content=message_content),
+            )
+        ],
+        created=calendar.timegm(time.gmtime()),
+        model=model,
+        object="chat.completion",
+    )
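
For reference, a hedged sketch of the kind of request this echo hook expects and what it returns; importing the hook module directly as custom, and the payload values, are assumptions made for this example.

# Illustrative sketch only -- exercising the echo chat hook directly.
import custom  # the custom.py shown above, assumed importable from the model's code dir

request = {
    "model": "datarobot_llm_id",                        # placeholder model id
    "messages": [{"role": "user", "content": "ping"}],  # single user message
}
completion = custom.chat(request, model=custom.load_model("."))
print(completion.choices[0].message.content)  # -> "Echo: ping"
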
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+guards:
+- description: Track the number of tokens associated with the input to the LLM, and/or retrieved
+    text from the vector database.
+  name: prompt_tokens
+  ootb_type: token_count
+  stage: prompt
+  type: ootb
+- description: Track the number of tokens associated with the input to the LLM, and/or retrieved
+    text from the vector database.
+  name: response_tokens
+  ootb_type: token_count
+  stage: response
+  type: ootb
+timeout_action: score
+timeout_sec: 60
