From 0ff5fc98445d7852c20c5922039ac8a6616fc725 Mon Sep 17 00:00:00 2001 From: qyou Date: Fri, 5 Dec 2025 10:51:18 +0800 Subject: [PATCH] =?UTF-8?q?feat(audio):=20=E6=B7=BB=E5=8A=A0=20ASR=20?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E6=94=AF=E6=8C=81=E4=BB=A5=E5=A2=9E=E5=BC=BA?= =?UTF-8?q?=E8=AF=AD=E9=9F=B3=E8=BD=AC=E5=BD=95=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在 AsrConfig 类中新增多个配置项,包括 enable_ddc、enable_itn 和 enable_punc,默认值设为 true - 为 userLanguage 字段设置默认值 "common" - 在 RoomConfig 中为 roomMode 设置默认空字符串并添加 Builder 默认注解 - 更新 TranscriptionsUpdateEventData 模型以支持 asrConfig 参数 - 扩展 WebSocket 客户端测试用例,验证带和不带 asrConfig 的事件处理逻辑 - 更新示例代码以演示如何传递 ASR 配置进行实时转录 - 修复测试工具类中的序列化版本 UID 缺失问题 --- .../client/audio/rooms/model/RoomConfig.java | 1 + .../websocket/event/model/AsrConfig.java | 15 +++- .../model/TranscriptionsUpdateEventData.java | 7 ++ ...ebsocketAudioTranscriptionsClientTest.java | 88 +++++++++++++++++++ .../java/com/coze/openapi/utils/Utils.java | 2 + .../WebsocketTranscriptionsExample.java | 13 ++- 6 files changed, 122 insertions(+), 4 deletions(-) diff --git a/api/src/main/java/com/coze/openapi/client/audio/rooms/model/RoomConfig.java b/api/src/main/java/com/coze/openapi/client/audio/rooms/model/RoomConfig.java index 3c14f6bd..83e303d4 100644 --- a/api/src/main/java/com/coze/openapi/client/audio/rooms/model/RoomConfig.java +++ b/api/src/main/java/com/coze/openapi/client/audio/rooms/model/RoomConfig.java @@ -20,6 +20,7 @@ public static RoomConfig of(AudioCodec codec) { } @JsonProperty("room_mode") + @Builder.Default private String roomMode = ""; @JsonProperty("translate_config") diff --git a/api/src/main/java/com/coze/openapi/client/websocket/event/model/AsrConfig.java b/api/src/main/java/com/coze/openapi/client/websocket/event/model/AsrConfig.java index 2f549cd8..80062b02 100644 --- a/api/src/main/java/com/coze/openapi/client/websocket/event/model/AsrConfig.java +++ b/api/src/main/java/com/coze/openapi/client/websocket/event/model/AsrConfig.java @@ -19,5 +19,18 @@ public class AsrConfig { private String context; @JsonProperty("user_language") - private String userLanguage; + @Builder.Default + private String userLanguage = "common"; + + @JsonProperty("enable_ddc") + @Builder.Default + private Boolean enableDdc = true; + + @JsonProperty("enable_itn") + @Builder.Default + private Boolean enableItn = true; + + @JsonProperty("enable_punc") + @Builder.Default + private Boolean enablePunc = true; } diff --git a/api/src/main/java/com/coze/openapi/client/websocket/event/model/TranscriptionsUpdateEventData.java b/api/src/main/java/com/coze/openapi/client/websocket/event/model/TranscriptionsUpdateEventData.java index f841a3a1..2fbc212d 100644 --- a/api/src/main/java/com/coze/openapi/client/websocket/event/model/TranscriptionsUpdateEventData.java +++ b/api/src/main/java/com/coze/openapi/client/websocket/event/model/TranscriptionsUpdateEventData.java @@ -12,4 +12,11 @@ public class TranscriptionsUpdateEventData { @JsonProperty("input_audio") private InputAudio inputAudio; + + @JsonProperty("asr_config") + private AsrConfig asrConfig; + + public TranscriptionsUpdateEventData(InputAudio inputAudio) { + this.inputAudio = inputAudio; + } } diff --git a/api/src/test/java/com/coze/openapi/service/service/websocket/audio/transcriptions/WebsocketAudioTranscriptionsClientTest.java b/api/src/test/java/com/coze/openapi/service/service/websocket/audio/transcriptions/WebsocketAudioTranscriptionsClientTest.java index b9b8b6df..17188ed8 100644 --- a/api/src/test/java/com/coze/openapi/service/service/websocket/audio/transcriptions/WebsocketAudioTranscriptionsClientTest.java +++ b/api/src/test/java/com/coze/openapi/service/service/websocket/audio/transcriptions/WebsocketAudioTranscriptionsClientTest.java @@ -6,6 +6,8 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; +import java.util.Arrays; + import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; @@ -15,6 +17,7 @@ import com.coze.openapi.client.websocket.event.EventType; import com.coze.openapi.client.websocket.event.downstream.*; +import com.coze.openapi.client.websocket.event.model.AsrConfig; import com.coze.openapi.client.websocket.event.model.InputAudio; import com.coze.openapi.client.websocket.event.model.TranscriptionsUpdateEventData; @@ -117,6 +120,68 @@ public void testHandleTranscriptionsUpdatedEvent() { assertEquals(1, event.getData().getInputAudio().getChannel()); assertEquals(16, event.getData().getInputAudio().getBitDepth()); + // 验证 asr_config 为 null(向后兼容) + assertNull(event.getData().getAsrConfig()); + + // 验证 detail + assertEquals("20241210152726467C48D89D6DB2F3***", event.getDetail().getLogID()); + } + + @Test + public void testHandleTranscriptionsUpdatedEventWithAsrConfig() { + String json = + "{\n" + + " \"id\": \"event_id\",\n" + + " \"event_type\": \"transcriptions.updated\",\n" + + " \"data\": {\n" + + " \"input_audio\": {\n" + + " \"format\": \"pcm\",\n" + + " \"codec\": \"pcm\",\n" + + " \"sample_rate\": 24000,\n" + + " \"channel\": 1,\n" + + " \"bit_depth\": 16\n" + + " },\n" + + " \"asr_config\": {\n" + + " \"hot_words\": [\"Coze\", \"AI\"],\n" + + " \"context\": \"Coze AI\",\n" + + " \"user_language\": \"en-US\",\n" + + " \"enable_ddc\": true,\n" + + " \"enable_itn\": true,\n" + + " \"enable_punc\": true\n" + + " }\n" + + " },\n" + + " \"detail\": {\n" + + " \"logid\": \"20241210152726467C48D89D6DB2F3***\"\n" + + " }\n" + + "}\n"; + + client.handleEvent(mockWebSocket, json); + + verify(mockCallbackHandler) + .onTranscriptionsUpdated(eq(client), transcriptionsUpdatedEventCaptor.capture()); + + TranscriptionsUpdatedEvent event = transcriptionsUpdatedEventCaptor.getValue(); + assertEquals(EventType.TRANSCRIPTIONS_UPDATED, event.getEventType()); + assertEquals("event_id", event.getId()); + + // 验证 data + assertEquals("pcm", event.getData().getInputAudio().getFormat()); + assertEquals("pcm", event.getData().getInputAudio().getCodec()); + assertEquals(24000, event.getData().getInputAudio().getSampleRate()); + assertEquals(1, event.getData().getInputAudio().getChannel()); + assertEquals(16, event.getData().getInputAudio().getBitDepth()); + + // 验证 asr_config + assertNotNull(event.getData().getAsrConfig()); + assertEquals("en-US", event.getData().getAsrConfig().getUserLanguage()); + assertEquals("Coze AI", event.getData().getAsrConfig().getContext()); + assertTrue(event.getData().getAsrConfig().getEnableDdc()); + assertTrue(event.getData().getAsrConfig().getEnableItn()); + assertTrue(event.getData().getAsrConfig().getEnablePunc()); + assertEquals(2, event.getData().getAsrConfig().getHotWords().size()); + assertTrue(event.getData().getAsrConfig().getHotWords().contains("Coze")); + assertTrue(event.getData().getAsrConfig().getHotWords().contains("AI")); + // 验证 detail assertEquals("20241210152726467C48D89D6DB2F3***", event.getDetail().getLogID()); } @@ -287,6 +352,12 @@ void testTranscriptionsUpdate() { .channel(1) .bitDepth(16) .build()) + .asrConfig( + AsrConfig.builder() + .hotWords(Arrays.asList("Coze", "AI")) + .context("Real-time transcription") + .userLanguage("en-US") + .build()) .build(); client.transcriptionsUpdate(data); @@ -294,6 +365,23 @@ void testTranscriptionsUpdate() { verify(mockWebSocket).send(anyString()); // 验证发送了消息 } + @Test + void testTranscriptionsUpdateWithoutAsrConfig() { + TranscriptionsUpdateEventData data = + new TranscriptionsUpdateEventData( + InputAudio.builder() + .format("pcm") + .codec("pcm") + .sampleRate(24000) + .channel(1) + .bitDepth(16) + .build()); + + client.transcriptionsUpdate(data); + + verify(mockWebSocket).send(anyString()); // 验证发送了消息 + } + @Test void testInputAudioBufferAppendWithString() { String audioData = "base64EncodedAudioData"; diff --git a/api/src/test/java/com/coze/openapi/utils/Utils.java b/api/src/test/java/com/coze/openapi/utils/Utils.java index fd98aa60..5e5337b3 100644 --- a/api/src/test/java/com/coze/openapi/utils/Utils.java +++ b/api/src/test/java/com/coze/openapi/utils/Utils.java @@ -10,6 +10,8 @@ public class Utils { private static final Headers commonHeader = Headers.of( new HashMap() { + private static final long serialVersionUID = 1L; + { put(LOG_HEADER, TEST_LOG_ID); } diff --git a/example/src/main/java/example/websocket/audio/transcriptions/WebsocketTranscriptionsExample.java b/example/src/main/java/example/websocket/audio/transcriptions/WebsocketTranscriptionsExample.java index afdfd51f..05cd68b8 100644 --- a/example/src/main/java/example/websocket/audio/transcriptions/WebsocketTranscriptionsExample.java +++ b/example/src/main/java/example/websocket/audio/transcriptions/WebsocketTranscriptionsExample.java @@ -2,7 +2,6 @@ import java.io.IOException; import java.io.InputStream; -import java.nio.ByteBuffer; import java.util.Arrays; import java.util.concurrent.TimeUnit; @@ -10,6 +9,7 @@ import com.coze.openapi.client.audio.speech.CreateSpeechReq; import com.coze.openapi.client.audio.speech.CreateSpeechResp; import com.coze.openapi.client.websocket.event.downstream.*; +import com.coze.openapi.client.websocket.event.model.AsrConfig; import com.coze.openapi.client.websocket.event.model.InputAudio; import com.coze.openapi.client.websocket.event.model.TranscriptionsUpdateEventData; import com.coze.openapi.service.auth.TokenAuth; @@ -27,7 +27,6 @@ public class WebsocketTranscriptionsExample { public static boolean isDone = false; private static class CallbackHandler extends WebsocketsAudioTranscriptionsCallbackHandler { - private final ByteBuffer buffer = ByteBuffer.allocate(1024 * 1024 * 10); // 分配 10MB 缓冲区 public CallbackHandler() { super(); @@ -120,7 +119,15 @@ public static void main(String[] args) throws Exception { InputAudio inputAudio = InputAudio.builder().sampleRate(24000).codec("pcm").format("wav").channel(2).build(); - client.transcriptionsUpdate(new TranscriptionsUpdateEventData(inputAudio)); + + AsrConfig asrConfig = + AsrConfig.builder() + .hotWords(Arrays.asList("Coze", "AI")) + .context("Real-time transcription") + .userLanguage("en-US") + .build(); + + client.transcriptionsUpdate(new TranscriptionsUpdateEventData(inputAudio, asrConfig)); try (InputStream inputStream = speechResp.getResponse().byteStream()) { byte[] buffer = new byte[1024];