2525)
2626
2727import asyncio
28+ from collections .abc import Awaitable
2829import queue
2930import threading
3031import warnings
@@ -231,9 +232,10 @@ def _warn(msg: str):
231232 _warn ._LOGGER .warning (msg ) # pyright: ignore[reportFunctionMemberAccess]
232233
233234
234- def _force_flush_traces ( ):
235+ async def _force_flush_otel ( tracing_enabled : bool , logging_enabled : bool ):
235236 try :
236237 import opentelemetry .trace
238+ import opentelemetry ._logs
237239 except (ImportError , AttributeError ):
238240 _warn (
239241 "Could not force flush traces. opentelemetry-api is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
@@ -242,15 +244,26 @@ def _force_flush_traces():
242244
243245 try :
244246 import opentelemetry .sdk .trace
247+ import opentelemetry .sdk ._logs
245248 except (ImportError , AttributeError ):
246249 _warn (
247250 "Could not force flush traces. opentelemetry-sdk is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
248251 )
249252 return None
250253
251- provider = opentelemetry .trace .get_tracer_provider ()
252- if isinstance (provider , opentelemetry .sdk .trace .TracerProvider ):
253- _ = provider .force_flush ()
254+ coros : List [Awaitable [bool ]] = []
255+
256+ if tracing_enabled :
257+ tracer_provider = opentelemetry .trace .get_tracer_provider ()
258+ if isinstance (tracer_provider , opentelemetry .sdk .trace .TracerProvider ):
259+ coros .append (asyncio .to_thread (tracer_provider .force_flush ))
260+
261+ if logging_enabled :
262+ logger_provider = opentelemetry ._logs .get_logger_provider ()
263+ if isinstance (logger_provider , opentelemetry .sdk ._logs .LoggerProvider ):
264+ coros .append (asyncio .to_thread (logger_provider .force_flush ))
265+
266+ await asyncio .gather (* coros , return_exceptions = True )
254267
255268
256269def _default_instrumentor_builder (
@@ -900,9 +913,11 @@ async def async_stream_query(
900913 # Yield the event data as a dictionary
901914 yield _utils .dump_event_for_json (event )
902915 finally :
903- # Avoid trace data loss having to do with CPU throttling on instance turndown
904- if self ._tracing_enabled ():
905- _ = await asyncio .to_thread (_force_flush_traces )
916+ # Avoid telemetry data loss having to do with CPU throttling on instance turndown
917+ _ = await _force_flush_otel (
918+ tracing_enabled = self ._tracing_enabled (),
919+ logging_enabled = self ._telemetry_enabled (),
920+ )
906921
907922 def stream_query (
908923 self ,
@@ -1054,9 +1069,11 @@ async def streaming_agent_run_with_events(self, request_json: str):
10541069 user_id = request .user_id ,
10551070 session_id = session .id ,
10561071 )
1057- # Avoid trace data loss having to do with CPU throttling on instance turndown
1058- if self ._tracing_enabled ():
1059- _ = await asyncio .to_thread (_force_flush_traces )
1072+ # Avoid telemetry data loss having to do with CPU throttling on instance turndown
1073+ _ = await _force_flush_otel (
1074+ tracing_enabled = self ._tracing_enabled (),
1075+ logging_enabled = self ._telemetry_enabled (),
1076+ )
10601077
10611078 async def async_get_session (
10621079 self ,
0 commit comments