1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- # pylint: skip-file
16-
1715import argparse
1816import json
1917import logging
2018import os
2119import sys
22- from typing import Any , cast
20+ from typing import cast
2321
2422import requests
2523from datarobot_drum .drum .enum import TargetType
2624from datarobot_drum .drum .root_predictors .drum_server_utils import DrumServerRun
2725from openai import OpenAI
28- from openai .types .chat import (
29- ChatCompletion ,
30- ChatCompletionSystemMessageParam ,
31- ChatCompletionUserMessageParam ,
32- )
33- from openai .types .chat .completion_create_params import (
34- CompletionCreateParamsNonStreaming ,
35- )
26+ from openai .types .chat import ChatCompletion
3627
3728root = logging .getLogger ()
3829
3930parser = argparse .ArgumentParser ()
40- parser .add_argument ("--user_prompt" , type = str , default = "" , help = "user_prompt for chat endpoint" )
41- parser .add_argument ("--extra_body" , type = str , default = "" , help = "extra_body for chat endpoint" )
31+ parser .add_argument (
32+ "--chat_completion" ,
33+ type = str ,
34+ default = "{}" ,
35+ help = "OpenAI ChatCompletion dict as json string" ,
36+ )
37+ parser .add_argument (
38+ "--default_headers" ,
39+ type = str ,
40+ default = "{}" ,
41+ help = "OpenAI default_headers as json string" ,
42+ )
4243parser .add_argument (
4344 "--custom_model_dir" ,
4445 type = str ,
@@ -72,29 +73,8 @@ def setup_logging(logger: logging.Logger, output_path: str, log_level: int = log
7273 logger .addHandler (handler_file )
7374
7475
75- def construct_prompt (user_prompt : str , extra_body : str ) -> Any :
76- extra_body_params = json .loads (extra_body ) if extra_body else {}
77- completion_create_params = CompletionCreateParamsNonStreaming (
78- model = "datarobot-deployed-llm" ,
79- messages = [
80- ChatCompletionSystemMessageParam (
81- content = "You are a helpful assistant" ,
82- role = "system" ,
83- ),
84- ChatCompletionUserMessageParam (
85- content = user_prompt ,
86- role = "user" ,
87- ),
88- ],
89- n = 1 ,
90- temperature = 0.01 ,
91- extra_body = extra_body_params , # type: ignore[typeddict-unknown-key]
92- )
93- return completion_create_params
94-
95-
9676def execute_drum (
97- user_prompt : str , extra_body : str , custom_model_dir : str , output_path : str
77+ chat_completion : str , default_headers : str , custom_model_dir : str , output_path : str
9878) -> ChatCompletion :
9979 root .info ("Executing agent as [chat] endpoint. DRUM Executor." )
10080 root .info ("Starting DRUM server." )
@@ -111,22 +91,30 @@ def execute_drum(
11191 port = 8191 ,
11292 stream_output = True ,
11393 ) as drum_runner :
114- root .info ("Verifying DRUM server. " )
94+ root .info ("Verifying DRUM server" )
11595 response = requests .get (drum_runner .url_server_address )
11696 if not response .ok :
117- raise RuntimeError ("Server failed to start" )
97+ root .error ("Server failed to start" )
98+ try :
99+ root .error (response .text )
100+ root .error (response .json ())
101+ finally :
102+ raise RuntimeError ("Server failed to start" )
103+
104+ root .info ("Parsing OpenAI request" )
105+ completion_params = json .loads (chat_completion )
106+ header_params = json .loads (default_headers )
118107
119108 # Use a standard OpenAI client to call the DRUM server. This mirrors the behavior of a deployed agent.
120- root .info ("Building prompt." )
109+ # Using the `chat.completions.create` method ensures the parameters are OpenAI compatible.
110+ root .info ("Executing Agent" )
121111 client = OpenAI (
122112 base_url = drum_runner .url_server_address ,
123113 api_key = "not-required" ,
114+ default_headers = header_params ,
124115 max_retries = 0 ,
125116 )
126- completion_create_params = construct_prompt (user_prompt , extra_body )
127-
128- root .info ("Executing Agent." )
129- completion = client .chat .completions .create (** completion_create_params )
117+ completion = client .chat .completions .create (** completion_params )
130118
131119 # Continue outside the context manager to ensure the server is stopped and logs
132120 # are flushed before we write the output
@@ -145,8 +133,8 @@ def execute_drum(
145133 args .custom_model_dir = os .path .join (os .getcwd (), "custom_model" )
146134setup_logging (logger = root , output_path = args .output_path , log_level = logging .INFO )
147135result = execute_drum (
148- user_prompt = args .user_prompt ,
149- extra_body = args .extra_body ,
136+ chat_completion = args .chat_completion ,
137+ default_headers = args .default_headers ,
150138 custom_model_dir = args .custom_model_dir ,
151139 output_path = args .output_path ,
152140)
0 commit comments