diff --git a/src/together/resources/audio/transcriptions.py b/src/together/resources/audio/transcriptions.py
index 766d417..49aea2a 100644
--- a/src/together/resources/audio/transcriptions.py
+++ b/src/together/resources/audio/transcriptions.py
@@ -104,7 +104,12 @@ def create(
         )
 
         # Add any additional kwargs
-        params_data.update(kwargs)
+        # Convert boolean values to lowercase strings for proper form encoding
+        for key, value in kwargs.items():
+            if isinstance(value, bool):
+                params_data[key] = str(value).lower()
+            else:
+                params_data[key] = value
 
         try:
             response, _, _ = requestor.request(
@@ -131,7 +136,8 @@ def create(
             response_format == "verbose_json"
             or response_format == AudioTranscriptionResponseFormat.VERBOSE_JSON
         ):
-            return AudioTranscriptionVerboseResponse(**response.data)
+            # Create response with model validation that preserves extra fields
+            return AudioTranscriptionVerboseResponse.model_validate(response.data)
         else:
             return AudioTranscriptionResponse(**response.data)
 
@@ -234,7 +240,12 @@ async def create(
         )
 
         # Add any additional kwargs
-        params_data.update(kwargs)
+        # Convert boolean values to lowercase strings for proper form encoding
+        for key, value in kwargs.items():
+            if isinstance(value, bool):
+                params_data[key] = str(value).lower()
+            else:
+                params_data[key] = value
 
         try:
             response, _, _ = await requestor.arequest(
@@ -261,6 +272,7 @@ async def create(
             response_format == "verbose_json"
             or response_format == AudioTranscriptionResponseFormat.VERBOSE_JSON
         ):
-            return AudioTranscriptionVerboseResponse(**response.data)
+            # Create response with model validation that preserves extra fields
+            return AudioTranscriptionVerboseResponse.model_validate(response.data)
         else:
             return AudioTranscriptionResponse(**response.data)
diff --git a/src/together/types/audio_speech.py b/src/together/types/audio_speech.py
index b3c110f..bb54cc7 100644
--- a/src/together/types/audio_speech.py
+++ b/src/together/types/audio_speech.py
@@ -158,6 +158,17 @@ class AudioTranscriptionWord(BaseModel):
     word: str
     start: float
     end: float
+    id: Optional[int] = None
+    speaker_id: Optional[str] = None
+
+
+class AudioSpeakerSegment(BaseModel):
+    id: int
+    speaker_id: str
+    start: float
+    end: float
+    text: str
+    words: List[AudioTranscriptionWord]
 
 
 class AudioTranscriptionResponse(BaseModel):
@@ -165,11 +176,13 @@ class AudioTranscriptionResponse(BaseModel):
 
 
 class AudioTranscriptionVerboseResponse(BaseModel):
+    id: Optional[str] = None
     language: Optional[str] = None
     duration: Optional[float] = None
     text: str
     segments: Optional[List[AudioTranscriptionSegment]] = None
     words: Optional[List[AudioTranscriptionWord]] = None
+    speaker_segments: Optional[List[AudioSpeakerSegment]] = None
 
 
 class AudioTranslationResponse(BaseModel):
diff --git a/tests/integration/resources/test_transcriptions.py b/tests/integration/resources/test_transcriptions.py
index 3852ebe..d8afd36 100644
--- a/tests/integration/resources/test_transcriptions.py
+++ b/tests/integration/resources/test_transcriptions.py
@@ -9,6 +9,47 @@
 )
 
 
+def validate_diarization_response(response_dict):
+    """
+    Helper function to validate diarization response structure
+    """
+    # Validate top-level speaker_segments field
+    assert "speaker_segments" in response_dict
+    assert isinstance(response_dict["speaker_segments"], list)
+    assert len(response_dict["speaker_segments"]) > 0
+
+    # Validate each speaker segment structure
+    for segment in response_dict["speaker_segments"]:
+        assert "text" in segment
+        assert "id" in segment
+        assert "speaker_id" in segment
+        assert "start" in segment
+        assert "end" in segment
+        assert "words" in segment
+
+        # Validate nested words in speaker segments
+        assert isinstance(segment["words"], list)
+        for word in segment["words"]:
+            assert "id" in word
+            assert "word" in word
+            assert "start" in word
+            assert "end" in word
+            assert "speaker_id" in word
+
+    # Validate top-level words field
+    assert "words" in response_dict
+    assert isinstance(response_dict["words"], list)
+    assert len(response_dict["words"]) > 0
+
+    # Validate each word in top-level words
+    for word in response_dict["words"]:
+        assert "id" in word
+        assert "word" in word
+        assert "start" in word
+        assert "end" in word
+        assert "speaker_id" in word
+
+
 class TestTogetherTranscriptions:
     @pytest.fixture
     def sync_together_client(self) -> Together:
@@ -116,3 +157,96 @@ def test_language_detection_hindi(self, sync_together_client):
         assert len(response.text) > 0
         assert hasattr(response, "language")
         assert response.language == "hi"
+
+    def test_diarization_default(self, sync_together_client):
+        """
+        Test diarization with default model in verbose JSON format
+        """
+        audio_url = "https://together-public-test-data.s3.us-west-2.amazonaws.com/audio/2-speaker-conversation.wav"
+
+        response = sync_together_client.audio.transcriptions.create(
+            file=audio_url,
+            model="openai/whisper-large-v3",
+            response_format="verbose_json",
+            diarize=True,
+        )
+
+        assert isinstance(response, AudioTranscriptionVerboseResponse)
+        assert isinstance(response.text, str)
+        assert len(response.text) > 0
+
+        # Validate diarization fields
+        response_dict = response.model_dump()
+        validate_diarization_response(response_dict)
+
+    def test_diarization_nvidia(self, sync_together_client):
+        """
+        Test diarization with nvidia model in verbose JSON format
+        """
+        audio_url = "https://together-public-test-data.s3.us-west-2.amazonaws.com/audio/2-speaker-conversation.wav"
+
+        response = sync_together_client.audio.transcriptions.create(
+            file=audio_url,
+            model="openai/whisper-large-v3",
+            response_format="verbose_json",
+            diarize=True,
+            diarization_model="nvidia",
+        )
+
+        assert isinstance(response, AudioTranscriptionVerboseResponse)
+        assert isinstance(response.text, str)
+        assert len(response.text) > 0
+
+        # Validate diarization fields
+        response_dict = response.model_dump()
+        validate_diarization_response(response_dict)
+
+    def test_diarization_pyannote(self, sync_together_client):
+        """
+        Test diarization with pyannote model in verbose JSON format
+        """
+        audio_url = "https://together-public-test-data.s3.us-west-2.amazonaws.com/audio/2-speaker-conversation.wav"
+
+        response = sync_together_client.audio.transcriptions.create(
+            file=audio_url,
+            model="openai/whisper-large-v3",
+            response_format="verbose_json",
+            diarize=True,
+            diarization_model="pyannote",
+        )
+
+        assert isinstance(response, AudioTranscriptionVerboseResponse)
+        assert isinstance(response.text, str)
+        assert len(response.text) > 0
+
+        # Validate diarization fields
+        response_dict = response.model_dump()
+        validate_diarization_response(response_dict)
+
+    def test_no_diarization(self, sync_together_client):
+        """
+        Test with diarize=false should not have speaker segments
+        """
+        audio_url = "https://together-public-test-data.s3.us-west-2.amazonaws.com/audio/2-speaker-conversation.wav"
+
+        response = sync_together_client.audio.transcriptions.create(
+            file=audio_url,
+            model="openai/whisper-large-v3",
+            response_format="verbose_json",
+            diarize=False,
+        )
+
+        assert isinstance(response, AudioTranscriptionVerboseResponse)
+        assert isinstance(response.text, str)
+        assert len(response.text) > 0
+
+        # Verify no diarization fields
+        response_dict = response.model_dump()
+        assert response_dict.get("speaker_segments") is None
+        assert response_dict.get("words") is None
+
+        # Should still have standard fields
+        assert "text" in response_dict
+        assert "language" in response_dict
+        assert "duration" in response_dict
+        assert "segments" in response_dict
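
Note for reviewers: a minimal, self-contained sketch of the form-encoding problem the new kwargs loop in transcriptions.py addresses. The dict and variable names here are illustrative, not part of the patch; the diarize/diarization_model values are taken from the tests above.

    # Python's str(True) is "True", which a multipart/form-data parser will
    # not accept as a boolean; the conventional form encoding is lowercase
    # "true"/"false", hence the explicit conversion instead of dict.update().
    kwargs = {"diarize": True, "diarization_model": "pyannote"}

    params_data = {}
    for key, value in kwargs.items():
        if isinstance(value, bool):
            params_data[key] = str(value).lower()  # True -> "true"
        else:
            params_data[key] = value

    assert params_data == {"diarize": "true", "diarization_model": "pyannote"}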
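
For context, a usage sketch of the diarization flow the new integration tests exercise (not part of the patch). The audio URL, model, and parameters come from the tests above; the assumption that Together() picks up TOGETHER_API_KEY from the environment follows the SDK's usual client setup.

    from together import Together

    client = Together()  # assumes TOGETHER_API_KEY is set in the environment

    response = client.audio.transcriptions.create(
        file="https://together-public-test-data.s3.us-west-2.amazonaws.com/audio/2-speaker-conversation.wav",
        model="openai/whisper-large-v3",
        response_format="verbose_json",
        diarize=True,
        diarization_model="pyannote",  # optional; the tests also exercise "nvidia"
    )

    # speaker_segments is the new Optional field on AudioTranscriptionVerboseResponse
    for segment in response.speaker_segments or []:
        print(f"[{segment.speaker_id}] {segment.start:.2f}-{segment.end:.2f}: {segment.text}")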