1818import os
1919import sys
2020from pathlib import Path
21- from typing import Any
21+ from typing import Any , Optional , Union , cast
2222from urllib .parse import urlparse , urlunparse
2323
2424import requests
@@ -46,32 +46,51 @@ def argparse_args() -> argparse.Namespace:
4646 required = True ,
4747 help = "OpenAI ChatCompletion dict as json string" ,
4848 )
49- parser .add_argument (
50- "--default_headers" ,
51- type = str ,
52- default = "{}" ,
53- help = "OpenAI default_headers as json string" ,
54- )
5549 parser .add_argument (
5650 "--custom_model_dir" ,
5751 type = str ,
5852 required = True ,
5953 help = "directory containing custom.py location" ,
6054 )
55+ parser .add_argument (
56+ "--default_headers" ,
57+ type = str ,
58+ default = "{}" ,
59+ help = "OpenAI default_headers as json string" ,
60+ )
6161 parser .add_argument ("--output_path" , type = str , default = None , help = "json output file location" )
6262 parser .add_argument ("--otlp_entity_id" , type = str , default = None , help = "Entity ID for tracing" )
6363 args = parser .parse_args ()
6464 return args
6565
6666
67- def setup_logging (logger : logging .Logger , log_level : int = logging .INFO ) -> None :
67+ def setup_logging (
68+ logger : logging .Logger ,
69+ output_path : Optional [Union [Path , str ]] = DEFAULT_OUTPUT_LOG_PATH ,
70+ log_level : Optional [int ] = logging .INFO ,
71+ update : Optional [bool ] = False ,
72+ ) -> None :
73+ log_level = cast (int , log_level )
74+ output_path = str (output_path )
75+
6876 logger .setLevel (log_level )
6977 handler_stream = logging .StreamHandler (sys .stdout )
7078 handler_stream .setLevel (log_level )
7179 formatter = logging .Formatter ("%(asctime)s - %(levelname)s - %(message)s" )
7280 handler_stream .setFormatter (formatter )
7381
82+ if os .path .exists (output_path ):
83+ os .remove (output_path )
84+ handler_file = logging .FileHandler (output_path )
85+ handler_file .setLevel (log_level )
86+ formatter_file = logging .Formatter ("%(asctime)s - %(levelname)s - %(message)s" )
87+ handler_file .setFormatter (formatter_file )
88+
89+ if update :
90+ logger .handlers .clear ()
91+
7492 logger .addHandler (handler_stream )
93+ logger .addHandler (handler_file )
7594
7695
7796def setup_otlp_env_variables (entity_id : str | None = None ) -> None :
@@ -145,7 +164,8 @@ def execute_drum(
145164
146165def construct_prompt (chat_completion : str ) -> CompletionCreateParamsBase :
147166 chat_completion_dict = json .loads (chat_completion )
148- if "model" not in chat_completion_dict :
167+ model = chat_completion_dict .get ("model" )
168+ if model is None or len (str (model )) == 0 :
149169 chat_completion_dict ["model" ] = "unknown"
150170 validator = TypeAdapter (CompletionCreateParamsBase )
151171 validator .validate_python (chat_completion_dict )
@@ -162,49 +182,54 @@ def store_result(result: ChatCompletion, output_path: Path) -> None:
162182
163183
164184def main () -> Any :
165- with open (DEFAULT_OUTPUT_LOG_PATH , "w" ) as f :
166- sys .stdout = f
167- sys .stderr = f
168- print ("Parsing args" )
169- args = argparse_args ()
170-
171- output_log_path = (
172- Path (args .output_path + ".log" ) if args .output_path else DEFAULT_OUTPUT_LOG_PATH
173- )
174- with open (output_log_path , "a" ) as f :
175- sys .stdout = f
176- sys .stderr = f
177-
185+ # During failures logs will be dumped to the default output log path
186+ setup_logging (logger = root , log_level = logging .INFO )
187+ try :
188+ root .info ("Parsing args" )
178189 try :
179- print ("Setting up logging" )
180- setup_logging (logger = root , log_level = logging .INFO )
181- root .info ("Parsing args" )
182-
183- # Parse input to fail early if it's not valid
184- chat_completion = construct_prompt (args .chat_completion )
185- default_headers = json .loads (args .default_headers )
186- root .info (f"Chat completion: { chat_completion } " )
187- root .info (f"Default headers: { default_headers } " )
188-
189- # Setup tracing
190- print ("Setting up tracing" )
191- setup_otlp_env_variables (args .otlp_entity_id )
192-
193- root .info (f"Executing request in directory { args .custom_model_dir } " )
194- result = execute_drum (
195- chat_completion = chat_completion ,
196- default_headers = default_headers ,
197- custom_model_dir = args .custom_model_dir ,
190+ # Attempt to parse arguments and setup logging
191+ args = argparse_args ()
192+
193+ root .info ("Setting up logging" )
194+ output_log_path = str (
195+ Path (args .output_path + ".log" ) if args .output_path else DEFAULT_OUTPUT_LOG_PATH
198196 )
199- root .info (f"Result: { result } " )
200- store_result (
201- result ,
202- Path (args .output_path ) if args .output_path else DEFAULT_OUTPUT_JSON_PATH ,
197+ setup_logging (
198+ logger = root ,
199+ output_path = output_log_path ,
200+ log_level = logging .INFO ,
201+ update = True ,
203202 )
204203 except Exception as e :
205- root .exception (f"Error executing agent: { e } " )
204+ root .exception (f"Error parsing arguments: { e } " )
205+ sys .exit (1 )
206+
207+ # Parse input to fail early if it's not valid
208+ chat_completion = construct_prompt (args .chat_completion )
209+ default_headers = json .loads (args .default_headers )
210+ root .info (f"Chat completion: { chat_completion } " )
211+ root .info (f"Default headers: { default_headers } " )
212+
213+ # Setup tracing
214+ root .info ("Setting up tracing" )
215+ setup_otlp_env_variables (args .otlp_entity_id )
216+
217+ root .info (f"Executing request in directory { args .custom_model_dir } " )
218+ result = execute_drum (
219+ chat_completion = chat_completion ,
220+ default_headers = default_headers ,
221+ custom_model_dir = args .custom_model_dir ,
222+ )
223+ root .info (f"Result: { result } " )
224+ store_result (
225+ result ,
226+ Path (args .output_path ) if args .output_path else DEFAULT_OUTPUT_JSON_PATH ,
227+ )
228+ except Exception as e :
229+ root .exception (f"Error executing agent: { e } " )
206230
207231
232+ # Agent execution
208233if __name__ == "__main__" :
209234 stdout = sys .stdout
210235 stderr = sys .stderr
0 commit comments