From 7c421f5ac2a562ad4377288e5d9e0c4596cbe5bb Mon Sep 17 00:00:00 2001 From: "Kick.snare" Date: Fri, 16 Jan 2026 21:55:06 +0900 Subject: [PATCH 1/3] [AI] Add LiveServerGoAway message type Servers send goAway messages to gracefully disconnect sessions. The timeLeft field uses protobuf Duration string format. --- .../firebase/ai/type/LiveServerMessage.kt | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveServerMessage.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveServerMessage.kt index a250f4a13c9..ece023d5486 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveServerMessage.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveServerMessage.kt @@ -16,6 +16,7 @@ package com.google.firebase.ai.type +import kotlin.time.Duration.Companion.seconds import kotlinx.serialization.DeserializationStrategy import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.Serializable @@ -32,6 +33,7 @@ import kotlinx.serialization.json.jsonObject * @see LiveServerToolCall * @see LiveServerToolCallCancellation * @see LiveServerSetupComplete + * @see LiveServerGoAway */ @PublicPreviewAPI public interface LiveServerMessage @@ -182,6 +184,70 @@ public class LiveServerToolCallCancellation(public val functionIds: List } } +/** + * Notification that the server is initiating a disconnect of the session. + * + * This message is sent by the server when it needs to close the connection, typically due to + * session timeout, resource constraints, or other server-side reasons. + * + * When this message is received, the client should gracefully close the [LiveSession] by calling + * [LiveSession.close]. + * + * @property timeLeft The time remaining before the connection terminates as a duration string + * (e.g., "57s", "1.5s"). If null, the connection will terminate immediately. Use [parseTimeLeft] to + * convert this to a [kotlin.time.Duration]. + */ +@PublicPreviewAPI +public class LiveServerGoAway(public val timeLeft: String?) : LiveServerMessage { + /** + * Parses the [timeLeft] string into a [kotlin.time.Duration]. + * + * Supports protobuf Duration format: "57s", "1.5s", "0.001s", etc. Nanoseconds are expressed as + * fractional seconds (e.g., "1.000000001s"). + * + * @return The parsed duration, or null if [timeLeft] is null or cannot be parsed. + */ + public fun parseTimeLeft(): kotlin.time.Duration? { + return timeLeft?.let { parseDurationString(it) } + } + + @Serializable internal data class Internal(val timeLeft: String? = null) + @Serializable + internal data class InternalWrapper(val goAway: Internal) : InternalLiveServerMessage { + override fun toPublic() = LiveServerGoAway(goAway.timeLeft) + } +} + +/** + * Parses a protobuf Duration string (e.g., "57s", "1.5s") into a [kotlin.time.Duration]. + * + * According to the protobuf specification, the JSON representation for Duration is a String that + * ends in 's' to indicate seconds, with nanoseconds expressed as fractional seconds. + * + * @param durationString The duration string to parse (must end with 's'). + * @return The parsed duration, or null if the string cannot be parsed. + * @see Protobuf + * Duration + */ +private fun parseDurationString(durationString: String): kotlin.time.Duration? { + return try { + val trimmed = durationString.trim() + + // Protobuf Duration format: always ends with 's' (seconds) + if (!trimmed.endsWith("s")) { + return null + } + + // Remove 's' suffix and parse as double + val secondsStr = trimmed.dropLast(1) + val seconds = secondsStr.toDoubleOrNull() ?: return null + + seconds.seconds + } catch (e: Exception) { + null + } +} + @PublicPreviewAPI @Serializable(LiveServerMessageSerializer::class) internal sealed interface InternalLiveServerMessage { @@ -202,6 +268,7 @@ internal object LiveServerMessageSerializer : "toolCall" in jsonObject -> LiveServerToolCall.InternalWrapper.serializer() "toolCallCancellation" in jsonObject -> LiveServerToolCallCancellation.InternalWrapper.serializer() + "goAway" in jsonObject -> LiveServerGoAway.InternalWrapper.serializer() else -> throw SerializationException( "Unknown LiveServerMessage response type. Keys found: ${jsonObject.keys}" From 5536de9ed40e36d4f8863a2815571e059d519a72 Mon Sep 17 00:00:00 2001 From: "Kick.snare" Date: Fri, 16 Jan 2026 21:56:01 +0900 Subject: [PATCH 2/3] [AI] Handle goAway in LiveSession and add goAwayHandler Add exception handling and goAwayHandler support: - Add CoroutineExceptionHandler to prevent crashes - Handle LiveServerGoAway in processModelResponses - Add goAwayHandler to LiveAudioConversationConfig --- .../ai/type/LiveAudioConversationConfig.kt | 13 ++++ .../google/firebase/ai/type/LiveSession.kt | 70 ++++++++++++++++--- 2 files changed, 75 insertions(+), 8 deletions(-) diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveAudioConversationConfig.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveAudioConversationConfig.kt index a2360e62f46..8365ae94e5c 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveAudioConversationConfig.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveAudioConversationConfig.kt @@ -35,6 +35,10 @@ import android.media.AudioTrack * offers a final opportunity to configure these objects, which will remain valid and effective for * the duration of the current audio session. * + * @property goAwayHandler A callback that is invoked when the server initiates a disconnect via a + * [LiveServerGoAway] message. This allows the application to handle server-initiated session + * termination gracefully, such as displaying a message to the user or attempting to reconnect. + * * @property enableInterruptions If enabled, allows the user to speak over or interrupt the model's * ongoing reply. * @@ -47,6 +51,7 @@ private constructor( internal val functionCallHandler: ((FunctionCallPart) -> FunctionResponsePart)?, internal val initializationHandler: ((AudioRecord.Builder, AudioTrack.Builder) -> Unit)?, internal val transcriptHandler: ((Transcription?, Transcription?) -> Unit)?, + internal val goAwayHandler: ((LiveServerGoAway) -> Unit)?, internal val enableInterruptions: Boolean ) { @@ -62,6 +67,8 @@ private constructor( * * @property transcriptHandler See [LiveAudioConversationConfig.transcriptHandler]. * + * @property goAwayHandler See [LiveAudioConversationConfig.goAwayHandler]. + * * @property enableInterruptions See [LiveAudioConversationConfig.enableInterruptions]. */ public class Builder { @@ -69,6 +76,7 @@ private constructor( @JvmField public var initializationHandler: ((AudioRecord.Builder, AudioTrack.Builder) -> Unit)? = null @JvmField public var transcriptHandler: ((Transcription?, Transcription?) -> Unit)? = null + @JvmField public var goAwayHandler: ((LiveServerGoAway) -> Unit)? = null @JvmField public var enableInterruptions: Boolean = false public fun setFunctionCallHandler( @@ -83,6 +91,10 @@ private constructor( transcriptHandler: ((Transcription?, Transcription?) -> Unit)? ): Builder = apply { this.transcriptHandler = transcriptHandler } + public fun setGoAwayHandler(goAwayHandler: ((LiveServerGoAway) -> Unit)?): Builder = apply { + this.goAwayHandler = goAwayHandler + } + public fun setEnableInterruptions(enableInterruptions: Boolean): Builder = apply { this.enableInterruptions = enableInterruptions } @@ -93,6 +105,7 @@ private constructor( functionCallHandler = functionCallHandler, initializationHandler = initializationHandler, transcriptHandler = transcriptHandler, + goAwayHandler = goAwayHandler, enableInterruptions = enableInterruptions ) } diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt index deaf2aba079..f3cc8538837 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt @@ -44,6 +44,7 @@ import java.util.concurrent.ThreadFactory import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicLong import kotlin.coroutines.CoroutineContext +import kotlinx.coroutines.CoroutineExceptionHandler import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.asCoroutineDispatcher @@ -92,6 +93,21 @@ internal constructor( */ private var audioScope = CancelledCoroutineScope + /** + * Exception handler for unhandled exceptions in background coroutines. + * + * Logs the exception and attempts to clean up resources to prevent app crashes. + */ + private val exceptionHandler = CoroutineExceptionHandler { _, throwable -> + Log.e(TAG, "Unhandled exception in LiveSession", throwable) + // Clean up resources to prevent resource leaks + try { + stopAudioConversation() + } catch (e: Exception) { + Log.e(TAG, "Error during cleanup in exception handler", e) + } + } + /** * Playback audio data sent from the model. * @@ -118,7 +134,12 @@ internal constructor( public suspend fun startAudioConversation( functionCallHandler: ((FunctionCallPart) -> FunctionResponsePart)? = null ) { - startAudioConversation(functionCallHandler, false) + startAudioConversation( + functionCallHandler = functionCallHandler, + transcriptHandler = null, + goAwayHandler = null, + enableInterruptions = false + ) } /** @@ -143,6 +164,7 @@ internal constructor( startAudioConversation( functionCallHandler = functionCallHandler, transcriptHandler = null, + goAwayHandler = null, enableInterruptions = enableInterruptions ) } @@ -159,6 +181,10 @@ internal constructor( * transcript. The first [Transcription] object is the input transcription, and the second is the * output transcription. * + * @param goAwayHandler A callback function that is invoked when the server initiates a disconnect + * via a [LiveServerGoAway] message. This allows the application to handle server-initiated + * session termination gracefully. + * * @param enableInterruptions If enabled, allows the user to speak over or interrupt the model's * ongoing reply. * @@ -169,12 +195,14 @@ internal constructor( public suspend fun startAudioConversation( functionCallHandler: ((FunctionCallPart) -> FunctionResponsePart)? = null, transcriptHandler: ((Transcription?, Transcription?) -> Unit)? = null, + goAwayHandler: ((LiveServerGoAway) -> Unit)? = null, enableInterruptions: Boolean = false, ) { startAudioConversation( liveAudioConversationConfig { this.functionCallHandler = functionCallHandler this.transcriptHandler = transcriptHandler + this.goAwayHandler = goAwayHandler this.enableInterruptions = enableInterruptions } ) @@ -209,14 +237,20 @@ internal constructor( return@catchAsync } networkScope = - CoroutineScope(blockingDispatcher + childJob() + CoroutineName("LiveSession Network")) - audioScope = CoroutineScope(audioDispatcher + childJob() + CoroutineName("LiveSession Audio")) + CoroutineScope( + blockingDispatcher + childJob() + CoroutineName("LiveSession Network") + exceptionHandler + ) + audioScope = + CoroutineScope( + audioDispatcher + childJob() + CoroutineName("LiveSession Audio") + exceptionHandler + ) audioHelper = AudioHelper.build(liveAudioConversationConfig.initializationHandler) recordUserAudio() processModelResponses( liveAudioConversationConfig.functionCallHandler, - liveAudioConversationConfig.transcriptHandler + liveAudioConversationConfig.transcriptHandler, + liveAudioConversationConfig.goAwayHandler ) listenForModelPlayback(liveAudioConversationConfig.enableInterruptions) } @@ -272,9 +306,14 @@ internal constructor( response .getOrNull() ?.let { - JSON.decodeFromString( - it.readBytes().toString(Charsets.UTF_8) - ) + try { + JSON.decodeFromString( + it.readBytes().toString(Charsets.UTF_8) + ) + } catch (e: SerializationException) { + Log.w(TAG, "Failed to deserialize server message: ${e.message}") + null // Skip unknown messages instead of crashing + } } ?.let { emit(it.toPublic()) } // delay uses a different scheduler in the backend, so it's "stickier" in its @@ -481,10 +520,15 @@ internal constructor( * * @param functionCallHandler A callback function that is invoked whenever the server receives a * function call. + * @param transcriptHandler A callback function that is invoked whenever the server receives a + * transcript. + * @param goAwayHandler A callback function that is invoked when the server initiates a + * disconnect. */ private fun processModelResponses( functionCallHandler: ((FunctionCallPart) -> FunctionResponsePart)?, - transcriptHandler: ((Transcription?, Transcription?) -> Unit)? + transcriptHandler: ((Transcription?, Transcription?) -> Unit)?, + goAwayHandler: ((LiveServerGoAway) -> Unit)? ) { receive() .onEach { @@ -532,6 +576,16 @@ internal constructor( "The model sent LiveServerSetupComplete after the connection was established." ) } + is LiveServerGoAway -> { + val timeLeftMsg = it.timeLeft?.let { duration -> " (time left: $duration)" } ?: "" + Log.i(TAG, "Server initiated disconnect$timeLeftMsg") + + // Notify the application + goAwayHandler?.invoke(it) + + // Close the session gracefully + close() + } } } .launchIn(networkScope) From e0abffa17fbc94db93f556d2086893bb635c2ad3 Mon Sep 17 00:00:00 2001 From: "Kick.snare" Date: Fri, 16 Jan 2026 21:56:54 +0900 Subject: [PATCH 3/3] [AI] Add tests for LiveServerGoAway Add unit tests for LiveServerGoAway message handling: - Duration parsing tests (protobuf format) - Deserialization tests - Serialization schema test --- .../google/firebase/ai/SerializationTests.kt | 20 ++ .../ai/type/LiveServerMessageTests.kt | 229 ++++++++++++++++++ 2 files changed, 249 insertions(+) create mode 100644 firebase-ai/src/test/java/com/google/firebase/ai/type/LiveServerMessageTests.kt diff --git a/firebase-ai/src/test/java/com/google/firebase/ai/SerializationTests.kt b/firebase-ai/src/test/java/com/google/firebase/ai/SerializationTests.kt index 215b1eca9eb..db92327c496 100644 --- a/firebase-ai/src/test/java/com/google/firebase/ai/SerializationTests.kt +++ b/firebase-ai/src/test/java/com/google/firebase/ai/SerializationTests.kt @@ -28,6 +28,7 @@ import com.google.firebase.ai.type.GroundingChunk import com.google.firebase.ai.type.GroundingMetadata import com.google.firebase.ai.type.GroundingSupport import com.google.firebase.ai.type.ImagenReferenceImage +import com.google.firebase.ai.type.LiveServerGoAway import com.google.firebase.ai.type.ModalityTokenCount import com.google.firebase.ai.type.PublicPreviewAPI import com.google.firebase.ai.type.Schema @@ -593,4 +594,23 @@ internal class SerializationTests { val actualJson = descriptorToJson(UrlContext.Internal.serializer().descriptor) expectedJsonAsString shouldEqualJson actualJson.toString() } + + @Test + fun `test LiveServerGoAway serialization as Json`() { + val expectedJsonAsString = + """ + { + "id": "LiveServerGoAway", + "type": "object", + "properties": { + "timeLeft": { + "type": "string" + } + } + } + """ + .trimIndent() + val actualJson = descriptorToJson(LiveServerGoAway.Internal.serializer().descriptor) + expectedJsonAsString shouldEqualJson actualJson.toString() + } } diff --git a/firebase-ai/src/test/java/com/google/firebase/ai/type/LiveServerMessageTests.kt b/firebase-ai/src/test/java/com/google/firebase/ai/type/LiveServerMessageTests.kt new file mode 100644 index 00000000000..ae43784b3cf --- /dev/null +++ b/firebase-ai/src/test/java/com/google/firebase/ai/type/LiveServerMessageTests.kt @@ -0,0 +1,229 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.ai.type + +import com.google.firebase.ai.common.JSON +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.matchers.nulls.shouldBeNull +import io.kotest.matchers.nulls.shouldNotBeNull +import io.kotest.matchers.shouldBe +import io.kotest.matchers.types.shouldBeInstanceOf +import org.junit.Test + +@OptIn(PublicPreviewAPI::class) +internal class LiveServerMessageTests { + + // ===== Duration Parsing Tests ===== + + @Test + fun `parseTimeLeft with integer seconds`() { + val goAway = LiveServerGoAway("57s") + val duration = goAway.parseTimeLeft() + + duration.shouldNotBeNull() + duration.inWholeSeconds shouldBe 57 + } + + @Test + fun `parseTimeLeft with fractional seconds`() { + val goAway = LiveServerGoAway("1.5s") + val duration = goAway.parseTimeLeft() + + duration.shouldNotBeNull() + duration.inWholeMilliseconds shouldBe 1500 + } + + @Test + fun `parseTimeLeft with small fractional seconds (nanoseconds)`() { + val goAway = LiveServerGoAway("0.000000001s") + val duration = goAway.parseTimeLeft() + + duration.shouldNotBeNull() + duration.inWholeNanoseconds shouldBe 1 + } + + @Test + fun `parseTimeLeft with zero seconds`() { + val goAway = LiveServerGoAway("0s") + val duration = goAway.parseTimeLeft() + + duration.shouldNotBeNull() + duration.inWholeSeconds shouldBe 0 + } + + @Test + fun `parseTimeLeft with null timeLeft`() { + val goAway = LiveServerGoAway(null) + val duration = goAway.parseTimeLeft() + + duration.shouldBeNull() + } + + @Test + fun `parseTimeLeft with invalid format returns null`() { + val goAway = LiveServerGoAway("invalid") + val duration = goAway.parseTimeLeft() + + duration.shouldBeNull() + } + + @Test + fun `parseTimeLeft with non-second units returns null`() { + val goAway = LiveServerGoAway("100ms") + val duration = goAway.parseTimeLeft() + + // Protobuf Duration only uses 's' suffix + duration.shouldBeNull() + } + + // ===== LiveServerGoAway Deserialization Tests ===== + + @Test + fun `LiveServerGoAway with timeLeft as string`() { + val json = + """ + { + "goAway": { + "timeLeft": "57s" + } + } + """ + .trimIndent() + + val message = JSON.decodeFromString(json) + val goAway = message.toPublic() + + goAway.shouldBeInstanceOf() + goAway.timeLeft shouldBe "57s" + + // Test parsing + val duration = goAway.parseTimeLeft() + duration.shouldNotBeNull() + duration.inWholeSeconds shouldBe 57 + } + + @Test + fun `LiveServerGoAway with fractional seconds string`() { + val json = + """ + { + "goAway": { + "timeLeft": "1.5s" + } + } + """ + .trimIndent() + + val message = JSON.decodeFromString(json) + val goAway = message.toPublic() as LiveServerGoAway + + goAway.timeLeft shouldBe "1.5s" + + val duration = goAway.parseTimeLeft() + duration.shouldNotBeNull() + duration.inWholeMilliseconds shouldBe 1500 + } + + @Test + fun `LiveServerGoAway with null timeLeft`() { + val json = """{"goAway": {"timeLeft": null}}""" + + val message = JSON.decodeFromString(json) + val goAway = message.toPublic() + + goAway.shouldBeInstanceOf() + (goAway as LiveServerGoAway).timeLeft.shouldBeNull() + } + + @Test + fun `LiveServerGoAway with missing timeLeft field`() { + val json = """{"goAway": {}}""" + + val message = JSON.decodeFromString(json) + val goAway = message.toPublic() as LiveServerGoAway + + goAway.timeLeft.shouldBeNull() + } + + // ===== Polymorphic Serializer Tests ===== + + @Test + fun `LiveServerMessageSerializer recognizes goAway message`() { + val json = + """ + { + "goAway": { + "timeLeft": "30s" + } + } + """ + .trimIndent() + + // Should not throw SerializationException + val message = JSON.decodeFromString(json) + message.toPublic().shouldBeInstanceOf() + } + + @Test + fun `LiveServerMessageSerializer throws on unknown message type`() { + val json = """{"unknownType": {"data": "value"}}""" + + shouldThrow { JSON.decodeFromString(json) } + } + + @Test + fun `LiveServerMessageSerializer recognizes serverContent message`() { + val json = + """ + { + "serverContent": { + "modelTurn": null, + "interrupted": false, + "turnComplete": false + } + } + """ + .trimIndent() + + val message = JSON.decodeFromString(json) + message.toPublic().shouldBeInstanceOf() + } + + @Test + fun `LiveServerMessageSerializer recognizes setupComplete message`() { + val json = """{"setupComplete": {}}""" + + val message = JSON.decodeFromString(json) + message.toPublic().shouldBeInstanceOf() + } + + @Test + fun `LiveServerMessageSerializer recognizes toolCall message`() { + val json = """{"toolCall": {"functionCalls": []}}""" + + val message = JSON.decodeFromString(json) + message.toPublic().shouldBeInstanceOf() + } + + @Test + fun `LiveServerMessageSerializer recognizes toolCallCancellation message`() { + val json = """{"toolCallCancellation": {"functionIds": []}}""" + + val message = JSON.decodeFromString(json) + message.toPublic().shouldBeInstanceOf() + } +}