1717import logging
1818import os
1919import sys
20- from typing import cast
20+ from pathlib import Path
21+ from typing import Any
22+ from urllib .parse import urlparse , urlunparse
2123
2224import requests
2325from datarobot_drum .drum .enum import TargetType
2426from datarobot_drum .drum .root_predictors .drum_server_utils import DrumServerRun
2527from openai import OpenAI
2628from openai .types .chat import ChatCompletion
27-
28- root = logging .getLogger ()
29-
30- parser = argparse .ArgumentParser ()
31- parser .add_argument (
32- "--chat_completion" ,
33- type = str ,
34- default = "{}" ,
35- help = "OpenAI ChatCompletion dict as json string" ,
36- )
37- parser .add_argument (
38- "--default_headers" ,
39- type = str ,
40- default = "{}" ,
41- help = "OpenAI default_headers as json string" ,
42- )
43- parser .add_argument (
44- "--custom_model_dir" ,
45- type = str ,
46- default = "" ,
47- help = "directory containing custom.py location" ,
29+ from openai .types .chat .completion_create_params import (
30+ CompletionCreateParamsBase ,
4831)
49- parser .add_argument ("--output_path" , type = str , default = "" , help = "json output file location" )
50- args = parser .parse_args ()
32+ from pydantic import TypeAdapter
5133
# Use the root logger so records from DRUM and third-party libraries are
# captured by the handlers configured in setup_logging().
root = logging.getLogger()

# Default output locations live next to this script so results are easy to
# find when the caller does not pass --output_path.
CURRENT_DIR = Path(__file__).parent
DEFAULT_OUTPUT_LOG_PATH = CURRENT_DIR / "output.log"
DEFAULT_OUTPUT_JSON_PATH = CURRENT_DIR / "output.json"
40+
def argparse_args() -> argparse.Namespace:
    """Parse the command-line options for a single agent execution.

    Returns:
        Namespace with ``chat_completion``, ``default_headers``,
        ``custom_model_dir``, ``output_path`` and ``otlp_entity_id``
        string attributes (the last two default to ``None``).
    """
    parser = argparse.ArgumentParser()
    # (flag, extra add_argument options) — every option is a string.
    option_specs = [
        (
            "--chat_completion",
            {"required": True, "help": "OpenAI ChatCompletion dict as json string"},
        ),
        (
            "--default_headers",
            {"default": "{}", "help": "OpenAI default_headers as json string"},
        ),
        (
            "--custom_model_dir",
            {"required": True, "help": "directory containing custom.py location"},
        ),
        ("--output_path", {"default": None, "help": "json output file location"}),
        ("--otlp_entity_id", {"default": None, "help": "Entity ID for tracing"}),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, type=str, **options)
    return parser.parse_args()
65+
66+
def setup_logging(logger: logging.Logger, log_level: int = logging.INFO) -> None:
    """Configure *logger* to emit timestamped records to stdout.

    Args:
        logger: The logger to configure (this script passes the root logger).
        log_level: Minimum level applied to both the logger and its handler.
    """
    logger.setLevel(log_level)
    handler_stream = logging.StreamHandler(sys.stdout)
    handler_stream.setLevel(log_level)
    # Fix: the pattern previously read "%( message)s" — the stray space makes
    # logging look up a nonexistent " message" record attribute and raise on
    # every emit. The correct placeholder is "%(message)s".
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    handler_stream.setFormatter(formatter)

    logger.addHandler(handler_stream)
75+
76+
def setup_otlp_env_variables(entity_id: str | None = None) -> None:
    """Derive OTLP exporter env vars from DataRobot credentials.

    Leaves any pre-existing OTEL_* exporter configuration untouched, and
    quietly disables tracing when the DataRobot credentials are absent.

    Args:
        entity_id: Optional DataRobot entity id forwarded as a tracing header.
    """
    env = os.environ
    # Respect operator-provided exporter settings — never override them.
    if env.get("OTEL_EXPORTER_OTLP_ENDPOINT") or env.get("OTEL_EXPORTER_OTLP_HEADERS"):
        root.info("OTEL_EXPORTER_OTLP_ENDPOINT or OTEL_EXPORTER_OTLP_HEADERS already set, skipping")
        return

    endpoint = env.get("DATAROBOT_ENDPOINT")
    api_token = env.get("DATAROBOT_API_TOKEN")
    if not (endpoint and api_token):
        root.warning("DATAROBOT_ENDPOINT or DATAROBOT_API_TOKEN not set, tracing is disabled")
        return

    # Point the exporter at the "otel" path on the same DataRobot host,
    # dropping any path/query/fragment from the configured endpoint.
    parts = urlparse(endpoint)
    otlp_endpoint = urlunparse((parts.scheme, parts.netloc, "otel", "", "", ""))
    header_parts = [f"X-DataRobot-Api-Key={api_token}"]
    if entity_id:
        header_parts.append(f"X-DataRobot-Entity-Id={entity_id}")
    env["OTEL_EXPORTER_OTLP_ENDPOINT"] = otlp_endpoint
    env["OTEL_EXPORTER_OTLP_HEADERS"] = ",".join(header_parts)
    root.info(f"Using OTEL_EXPORTER_OTLP_ENDPOINT: {otlp_endpoint}")
74100
75101
76102def execute_drum (
77- chat_completion : str , default_headers : str , custom_model_dir : str , output_path : str
103+ chat_completion : CompletionCreateParamsBase ,
104+ default_headers : dict [str , str ],
105+ custom_model_dir : Path ,
78106) -> ChatCompletion :
79107 root .info ("Executing agent as [chat] endpoint. DRUM Executor." )
80108 root .info ("Starting DRUM server." )
@@ -97,44 +125,95 @@ def execute_drum(
97125 root .error ("Server failed to start" )
98126 try :
99127 root .error (response .text )
100- root .error (response .json ())
101128 finally :
102129 raise RuntimeError ("Server failed to start" )
103130
104- root .info ("Parsing OpenAI request" )
105- completion_params = json .loads (chat_completion )
106- header_params = json .loads (default_headers )
107-
108131 # Use a standard OpenAI client to call the DRUM server. This mirrors the behavior of a deployed agent.
109132 # Using the `chat.completions.create` method ensures the parameters are OpenAI compatible.
110133 root .info ("Executing Agent" )
111134 client = OpenAI (
112135 base_url = drum_runner .url_server_address ,
113136 api_key = "not-required" ,
114- default_headers = header_params ,
137+ default_headers = default_headers ,
115138 max_retries = 0 ,
116139 )
117- completion = client .chat .completions .create (** completion_params )
118-
140+ completion = client .chat .completions .create (** chat_completion )
119141 # Continue outside the context manager to ensure the server is stopped and logs
120142 # are flushed before we write the output
121- root .info (f"Storing result: { output_path } " )
122- if len (output_path ) == 0 :
123- output_path = os .path .join (custom_model_dir , "output.json" )
124- with open (output_path , "w" ) as fp :
125- fp .write (completion .to_json ())
143+ return completion
126144
127- root .info (completion .to_json ())
128- return cast (ChatCompletion , completion )
129145
def construct_prompt(chat_completion: str) -> CompletionCreateParamsBase:
    """Parse and validate a JSON-encoded OpenAI chat-completion request.

    Args:
        chat_completion: JSON string of chat-completion parameters.

    Returns:
        The validated parameters as a ``CompletionCreateParamsBase``.

    Raises:
        json.JSONDecodeError: If the string is not valid JSON.
        pydantic.ValidationError: If the payload is not a valid request.
    """
    payload = json.loads(chat_completion)
    # The DRUM server routes by custom.py, not model name, so a placeholder
    # is acceptable when the caller omitted "model".
    payload.setdefault("model", "unknown")
    # Validate eagerly so malformed requests fail before the server starts.
    TypeAdapter(CompletionCreateParamsBase).validate_python(payload)
    return CompletionCreateParamsBase(**payload)  # type: ignore[typeddict-item]
130156
131- # Agent execution
132- if len (args .custom_model_dir ) == 0 :
133- args .custom_model_dir = os .path .join (os .getcwd (), "custom_model" )
134- setup_logging (logger = root , output_path = args .output_path , log_level = logging .INFO )
135- result = execute_drum (
136- chat_completion = args .chat_completion ,
137- default_headers = args .default_headers ,
138- custom_model_dir = args .custom_model_dir ,
139- output_path = args .output_path ,
140- )
157+
def store_result(result: ChatCompletion, output_path: Path) -> None:
    """Serialize the chat completion to JSON at *output_path*."""
    root.info(f"Storing result: {output_path}")
    output_path.write_text(result.to_json())
162+
163+
def main() -> Any:
    """Run one agent execution: parse args, start DRUM, store the result.

    All stdout/stderr output is redirected into a log file so the caller can
    inspect it; any failure is logged rather than raised.
    """
    # Until the real output path is known, capture stdout/stderr in a default
    # log file next to this script so even argparse errors are recorded.
    with open(DEFAULT_OUTPUT_LOG_PATH, "w") as f:
        sys.stdout = f
        sys.stderr = f
        print("Parsing args")
        args = argparse_args()

    # NOTE(review): at this point the default log file is closed while
    # sys.stdout/sys.stderr still reference it; nothing appears to be written
    # before the re-redirection below — confirm ordering is intentional.
    output_log_path = (
        Path(args.output_path + ".log") if args.output_path else DEFAULT_OUTPUT_LOG_PATH
    )
    # Append so the "Parsing args" line survives when the default path is reused.
    with open(output_log_path, "a") as f:
        sys.stdout = f
        sys.stderr = f

        try:
            print("Setting up logging")
            setup_logging(logger=root, log_level=logging.INFO)
            root.info("Parsing args")

            # Parse input to fail early if it's not valid
            chat_completion = construct_prompt(args.chat_completion)
            default_headers = json.loads(args.default_headers)
            root.info(f"Chat completion: {chat_completion}")
            root.info(f"Default headers: {default_headers}")

            # Setup tracing
            print("Setting up tracing")
            setup_otlp_env_variables(args.otlp_entity_id)

            root.info(f"Executing request in directory {args.custom_model_dir}")
            result = execute_drum(
                chat_completion=chat_completion,
                default_headers=default_headers,
                custom_model_dir=args.custom_model_dir,
            )
            root.info(f"Result: {result}")
            store_result(
                result,
                Path(args.output_path) if args.output_path else DEFAULT_OUTPUT_JSON_PATH,
            )
        except Exception as e:
            # Log and swallow so the __main__ guard can still restore stdio.
            root.exception(f"Error executing agent: {e}")
206+
207+
if __name__ == "__main__":
    # Remember the process's real stdio: main() redirects both into log files.
    stdout = sys.stdout
    stderr = sys.stderr
    try:
        main()
    except Exception:
        # Deliberate best-effort: main() logs its own failures; never let an
        # exception escape before stdio is restored below.
        pass
    finally:
        # Return to original stdout and stderr otherwise the kernel will fail to flush and
        # hang
        sys.stdout = stdout
        sys.stderr = stderr
0 commit comments