@@ -12,6 +12,7 @@
 from flask import Response, jsonify, request
 from werkzeug.exceptions import HTTPException
 
+from opentelemetry import trace
 from datarobot_drum.drum.description import version as drum_version
 from datarobot_drum.drum.enum import (
     FLASK_EXT_FILE_NAME,
@@ -38,10 +39,14 @@
     get_flask_app,
 )
 from datarobot_drum.profiler.stats_collector import StatsCollector, StatsOperation
+from datarobot_drum.drum.common import otel_context
 
 logger = logging.getLogger(LOGGER_NAME_PREFIX + "." + __name__)
 
 
+tracer = trace.get_tracer(__name__)
+
+
 class PredictionServer(PredictMixin):
     def __init__(self, params: dict):
         self._params = params
@@ -157,54 +162,50 @@ def health():
         @model_api.route("/invocations", methods=["POST"])
         def predict():
             logger.debug("Entering predict() endpoint")
-
-            self._pre_predict_and_transform()
-            try:
-                response, response_status = self.do_predict_structured(logger=logger)
-            finally:
-                self._post_predict_and_transform()
+            with otel_context(tracer, "drum.invocations", request.headers):
+                self._pre_predict_and_transform()
+                try:
+                    response, response_status = self.do_predict_structured(logger=logger)
+                finally:
+                    self._post_predict_and_transform()
 
             return response, response_status
 
         @model_api.route("/transform/", methods=["POST"])
         def transform():
             logger.debug("Entering transform() endpoint")
-
-            self._pre_predict_and_transform()
-
-            try:
-                response, response_status = self.do_transform(logger=logger)
-            finally:
-                self._post_predict_and_transform()
+            with otel_context(tracer, "drum.transform", request.headers):
+                self._pre_predict_and_transform()
+                try:
+                    response, response_status = self.do_transform(logger=logger)
+                finally:
+                    self._post_predict_and_transform()
 
             return response, response_status
 
         @model_api.route("/predictionsUnstructured/", methods=["POST"])
         @model_api.route("/predictUnstructured/", methods=["POST"])
         def predict_unstructured():
             logger.debug("Entering predict() endpoint")
-
-            self._pre_predict_and_transform()
-
-            try:
-                response, response_status = self.do_predict_unstructured(logger=logger)
-            finally:
-                self._post_predict_and_transform()
-
+            with otel_context(tracer, "drum.predictUnstructured", request.headers):
+                self._pre_predict_and_transform()
+                try:
+                    response, response_status = self.do_predict_unstructured(logger=logger)
+                finally:
+                    self._post_predict_and_transform()
             return (response, response_status)
 
         # Chat routes are defined without trailing slash because this is required by the OpenAI python client.
         @model_api.route("/chat/completions", methods=["POST"])
         @model_api.route("/v1/chat/completions", methods=["POST"])
         def chat():
             logger.debug("Entering chat endpoint")
-
-            self._pre_predict_and_transform()
-
-            try:
-                response, response_status = self.do_chat(logger=logger)
-            finally:
-                self._post_predict_and_transform()
+            with otel_context(tracer, "drum.chat.completions", request.headers):
+                self._pre_predict_and_transform()
+                try:
+                    response, response_status = self.do_chat(logger=logger)
+                finally:
+                    self._post_predict_and_transform()
 
             return response, response_status
 
@@ -226,24 +227,25 @@ def get_supported_llm_models():
         @model_api.route("/directAccess/<path:path>", methods=["GET", "POST", "PUT"])
         @model_api.route("/nim/<path:path>", methods=["GET", "POST", "PUT"])
         def forward_request(path):
-            if not hasattr(self._predictor, "openai_host") or not hasattr(
-                self._predictor, "openai_port"
-            ):
-                return {
-                    "message": "This endpoint is only supported by OpenAI based predictors"
-                }, HTTP_400_BAD_REQUEST
-
-            openai_host = self._predictor.openai_host
-            openai_port = self._predictor.openai_port
-
-            resp = requests.request(
-                method=request.method,
-                url=f"http://{openai_host}:{openai_port}/{path.rstrip('/')}",
-                headers=request.headers,
-                params=request.args,
-                data=request.get_data(),
-                allow_redirects=False,
-            )
+            with otel_context(tracer, "drum.directAccess", request.headers) as span:
+                if not hasattr(self._predictor, "openai_host") or not hasattr(
+                    self._predictor, "openai_port"
+                ):
+                    msg = "This endpoint is only supported by OpenAI based predictors"
+                    span.set_status(trace.StatusCode.ERROR, msg)
+                    return {"message": msg}, HTTP_400_BAD_REQUEST
+
+                openai_host = self._predictor.openai_host
+                openai_port = self._predictor.openai_port
+
+                resp = requests.request(
+                    method=request.method,
+                    url=f"http://{openai_host}:{openai_port}/{path.rstrip('/')}",
+                    headers=request.headers,
+                    params=request.args,
+                    data=request.get_data(),
+                    allow_redirects=False,
+                )
 
             return Response(resp.content, status=resp.status_code, headers=dict(resp.headers))
 
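The `otel_context` helper this diff imports lives in `datarobot_drum.drum.common` and is not shown in these hunks. As a reading aid only, here is a minimal sketch of what such a context manager could look like, assuming it extracts W3C trace context from the incoming request headers and runs the endpoint body inside a SERVER span; the real implementation may differ:

```python
from contextlib import contextmanager

from opentelemetry.propagate import extract
from opentelemetry.trace import SpanKind, StatusCode


@contextmanager
def otel_context(tracer, span_name, headers):
    """Run the wrapped block inside a SERVER span joined to the caller's trace."""
    # Pull W3C trace context (traceparent/tracestate) out of the incoming
    # request headers so this span becomes a child of the upstream trace,
    # if one exists; otherwise a new trace is started.
    parent_ctx = extract(dict(headers))
    with tracer.start_as_current_span(
        span_name, context=parent_ctx, kind=SpanKind.SERVER
    ) as span:
        try:
            yield span
        except Exception as exc:
            # Record the failure on the span, then let Flask handle the error.
            span.record_exception(exc)
            span.set_status(StatusCode.ERROR, str(exc))
            raise
```

With a helper of this shape, each request to a wrapped route is recorded as one span, and the routes' `finally`-based cleanup still runs inside the span before it closes.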
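On the caller side, joining the server span to an existing trace just means sending a `traceparent` header. A hypothetical client call against the `/invocations` route (host, port, payload, and the traceparent value are all illustrative, assuming a DRUM prediction server listening locally):

```python
import requests

headers = {
    "Content-Type": "text/csv",
    # Illustrative W3C traceparent: version-traceid-spanid-flags.
    # A real caller would let its own OpenTelemetry SDK inject this.
    "traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
}
resp = requests.post(
    "http://localhost:8080/invocations",  # assumed local DRUM server address
    data=b"feature_a,feature_b\n1,2\n",
    headers=headers,
)
print(resp.status_code, resp.headers.get("content-type"))
```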