Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 45 additions & 48 deletions src/agents/voice/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,24 +84,20 @@ async def _process_audio_input(self, audio_input: AudioInput) -> str:
)

async def _run_single_turn(self, audio_input: AudioInput) -> StreamedAudioResult:
# Since this is single turn, we can use the TraceCtxManager to manage starting/ending the
# trace
with TraceCtxManager(
workflow_name=self.config.workflow_name or "Voice Agent",
trace_id=None, # Automatically generated
group_id=self.config.group_id,
metadata=self.config.trace_metadata,
tracing=self.config.tracing,
disabled=self.config.tracing_disabled,
):
input_text = await self._process_audio_input(audio_input)

output = StreamedAudioResult(
self._get_tts_model(), self.config.tts_settings, self.config
)

async def stream_events():
output = StreamedAudioResult(self._get_tts_model(), self.config.tts_settings, self.config)

async def stream_events():
# Keep the trace scope active for the entire async processing lifecycle.
with TraceCtxManager(
workflow_name=self.config.workflow_name or "Voice Agent",
trace_id=None, # Automatically generated
group_id=self.config.group_id,
metadata=self.config.trace_metadata,
tracing=self.config.tracing,
disabled=self.config.tracing_disabled,
):
try:
input_text = await self._process_audio_input(audio_input)
async for text_event in self.workflow.run(input_text):
await output._add_text(text_event)
await output._turn_done()
Expand All @@ -111,37 +107,37 @@ async def stream_events():
await output._add_error(e)
raise e

output._set_task(asyncio.create_task(stream_events()))
return output
output._set_task(asyncio.create_task(stream_events()))
return output

async def _run_multi_turn(self, audio_input: StreamedAudioInput) -> StreamedAudioResult:
with TraceCtxManager(
workflow_name=self.config.workflow_name or "Voice Agent",
trace_id=None,
group_id=self.config.group_id,
metadata=self.config.trace_metadata,
tracing=self.config.tracing,
disabled=self.config.tracing_disabled,
):
output = StreamedAudioResult(
self._get_tts_model(), self.config.tts_settings, self.config
)

try:
async for intro_text in self.workflow.on_start():
await output._add_text(intro_text)
except Exception as e:
logger.warning(f"on_start() failed: {e}")

transcription_session = await self._get_stt_model().create_session(
audio_input,
self.config.stt_settings,
self.config.trace_include_sensitive_data,
self.config.trace_include_sensitive_audio_data,
)

async def process_turns():
output = StreamedAudioResult(self._get_tts_model(), self.config.tts_settings, self.config)

async def process_turns():
# Keep the trace scope active for the full streamed session.
with TraceCtxManager(
workflow_name=self.config.workflow_name or "Voice Agent",
trace_id=None,
group_id=self.config.group_id,
metadata=self.config.trace_metadata,
tracing=self.config.tracing,
disabled=self.config.tracing_disabled,
):
transcription_session = None
try:
try:
async for intro_text in self.workflow.on_start():
await output._add_text(intro_text)
except Exception as e:
logger.warning(f"on_start() failed: {e}")

transcription_session = await self._get_stt_model().create_session(
audio_input,
self.config.stt_settings,
self.config.trace_include_sensitive_data,
self.config.trace_include_sensitive_audio_data,
)

async for input_text in transcription_session.transcribe_turns():
result = self.workflow.run(input_text)
async for text_event in result:
Expand All @@ -152,8 +148,9 @@ async def process_turns():
await output._add_error(e)
raise e
finally:
await transcription_session.close()
if transcription_session is not None:
await transcription_session.close()
await output._done()

output._set_task(asyncio.create_task(process_turns()))
return output
output._set_task(asyncio.create_task(process_turns()))
return output
11 changes: 11 additions & 0 deletions src/agents/voice/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def _check_errors(self):

async def stream(self) -> AsyncIterator[VoiceStreamEvent]:
"""Stream the events and audio data as they're generated."""
saw_session_end = False
while True:
try:
event = await self._queue.get()
Expand All @@ -278,8 +279,18 @@ async def stream(self) -> AsyncIterator[VoiceStreamEvent]:
break
yield event
if event.type == "voice_stream_event_lifecycle" and event.event == "session_ended":
saw_session_end = True
break

# On the normal completion path, let the producer task finish gracefully so any active
# trace context can emit `trace_end` before we run cleanup.
if (
saw_session_end
and self.text_generation_task is not None
and not self.text_generation_task.done()
):
await asyncio.shield(self.text_generation_task)

self._check_errors()
self._cleanup_tasks()

Expand Down
72 changes: 72 additions & 0 deletions tests/voice/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import annotations

import asyncio

import numpy as np
import numpy.typing as npt
import pytest

from tests.testing_processor import fetch_events

try:
from agents.voice import AudioInput, TTSModelSettings, VoicePipeline, VoicePipelineConfig

Expand Down Expand Up @@ -177,3 +181,71 @@ def _transform_data(
"session_ended",
]
await fake_tts.verify_audio("out_1", audio_chunks[0], dtype=np.int16)


class _BlockingWorkflow(FakeWorkflow):
def __init__(self, gate: asyncio.Event):
super().__init__()
self._gate = gate

async def run(self, _: str):
await self._gate.wait()
yield "out_1"


class _OnStartYieldThenFailWorkflow(FakeWorkflow):
async def on_start(self):
yield "intro"
raise RuntimeError("boom")


@pytest.mark.asyncio
async def test_voicepipeline_trace_not_finished_before_single_turn_completes() -> None:
fake_stt = FakeSTT(["first"])
fake_tts = FakeTTS()
gate = asyncio.Event()
workflow = _BlockingWorkflow(gate)
config = VoicePipelineConfig(tts_settings=TTSModelSettings(buffer_size=1))
pipeline = VoicePipeline(
workflow=workflow, stt_model=fake_stt, tts_model=fake_tts, config=config
)

audio_input = AudioInput(buffer=np.zeros(2, dtype=np.int16))
result = await pipeline.run(audio_input)
await asyncio.sleep(0)

events_before_unblock = fetch_events()
assert "trace_start" in events_before_unblock
assert "trace_end" not in events_before_unblock

gate.set()
await extract_events(result)
assert fetch_events()[-1] == "trace_end"


@pytest.mark.asyncio
async def test_voicepipeline_trace_finishes_after_multi_turn_processing() -> None:
fake_stt = FakeSTT(["first", "second"])
workflow = FakeWorkflow([["out_1"], ["out_2"]])
fake_tts = FakeTTS()
pipeline = VoicePipeline(workflow=workflow, stt_model=fake_stt, tts_model=fake_tts)

streamed_audio_input = await FakeStreamedAudioInput.get(count=2)
result = await pipeline.run(streamed_audio_input)
await extract_events(result)
assert fetch_events()[-1] == "trace_end"


@pytest.mark.asyncio
async def test_voicepipeline_multi_turn_on_start_exception_does_not_abort() -> None:
fake_stt = FakeSTT(["first"])
workflow = _OnStartYieldThenFailWorkflow([["out_1"]])
fake_tts = FakeTTS()
pipeline = VoicePipeline(workflow=workflow, stt_model=fake_stt, tts_model=fake_tts)

streamed_audio_input = await FakeStreamedAudioInput.get(count=1)
result = await pipeline.run(streamed_audio_input)
events, _ = await extract_events(result)

assert events[-1] == "session_ended"
assert "error" not in events