diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrCallback.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrCallback.kt index e2012d84c26..51a220167c0 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrCallback.kt +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrCallback.kt @@ -12,7 +12,7 @@ import org.pytorch.executorch.annotations.Experimental /** * Callback interface for ASR (Automatic Speech Recognition) module. Users can implement this - * interface to receive the transcribed tokens and completion notification. + * interface to receive the transcribed tokens as they are generated. * * Warning: These APIs are experimental and subject to change without notice */ @@ -20,16 +20,9 @@ import org.pytorch.executorch.annotations.Experimental interface AsrCallback { /** * Called when a new token is available from JNI. Users will keep getting onToken() invocations - * until transcription finishes. + * until transcription finishes (when the transcribe method returns). * * @param token The decoded text token */ fun onToken(token: String) - - /** - * Called when transcription is complete. - * - * @param transcription The complete transcription (may be empty if tokens were streamed) - */ - fun onComplete(transcription: String) {} } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt index b875fc05353..987cb3ec3be 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt @@ -88,7 +88,7 @@ class AsrModule( ): Int } - /** Check if the native handle is valid. 
*/ + /** Check if the native handle is valid (not yet closed). */ val isValid: Boolean get() = nativeHandle.get() != 0L @@ -100,18 +100,13 @@ class AsrModule( } /** Releases native resources. Call this when done with the module. */ - fun destroy() { + override fun close() { val handle = nativeHandle.getAndSet(0L) if (handle != 0L) { nativeDestroy(handle) } } - /** Closeable implementation for use with use {} blocks. */ - override fun close() { - destroy() - } - /** * Force loading the module. Otherwise the model is loaded during first transcribe() call. * @@ -125,71 +120,41 @@ class AsrModule( } /** - * Transcribe audio from a WAV file with default configuration. + * Transcribe audio from a WAV file. * - * @param wavPath Path to the WAV audio file - * @param callback Callback to receive tokens, can be null - * @return 0 on success, error code otherwise - * @throws IllegalStateException if the module has been destroyed - */ - fun transcribe(wavPath: String, callback: AsrCallback? = null): Int = - transcribe(wavPath, AsrTranscribeConfig(), callback) - - /** - * Transcribe audio from a WAV file with custom configuration. + * This is a blocking call that returns the complete transcription. * * @param wavPath Path to the WAV audio file - * @param config Configuration for transcription - * @param callback Callback to receive tokens, can be null - * @return 0 on success, error code otherwise + * @param config Configuration for transcription (null uses default configuration) + * @param callback Optional callback to receive tokens as they are generated (can be null) + * @return The complete transcribed text * @throws IllegalStateException if the module has been destroyed + * @throws RuntimeException if transcription fails (non-zero result code) */ + @JvmOverloads fun transcribe( wavPath: String, - config: AsrTranscribeConfig, + config: AsrTranscribeConfig? = null, callback: AsrCallback? 
= null, - ): Int { + ): String { val handle = nativeHandle.get() check(handle != 0L) { "AsrModule has been destroyed" } val wavFile = File(wavPath) require(wavFile.canRead() && wavFile.isFile) { "Cannot read WAV file: $wavPath" } - return nativeTranscribe( - handle, - wavPath, - config.maxNewTokens, - config.temperature, - config.decoderStartTokenId, - callback, - ) - } - /** - * Transcribe audio from a WAV file and return the full transcription. - * - * This is a blocking call that collects all tokens and returns the complete transcription. - * - * @param wavPath Path to the WAV audio file - * @param config Configuration for transcription - * @return The transcribed text - * @throws RuntimeException if transcription fails - */ - @JvmOverloads - fun transcribeBlocking( - wavPath: String, - config: AsrTranscribeConfig = AsrTranscribeConfig(), - ): String { + val effectiveConfig = config ?: AsrTranscribeConfig() val result = StringBuilder() val status = - transcribe( + nativeTranscribe( + handle, wavPath, - config, + effectiveConfig.maxNewTokens, + effectiveConfig.temperature, + effectiveConfig.decoderStartTokenId, object : AsrCallback { override fun onToken(token: String) { result.append(token) - } - - override fun onComplete(transcription: String) { - // Tokens already collected + callback?.onToken(token) } }, ) diff --git a/extension/android/jni/jni_layer_asr.cpp b/extension/android/jni/jni_layer_asr.cpp index 50a3656b437..2eb5f0a7968 100644 --- a/extension/android/jni/jni_layer_asr.cpp +++ b/extension/android/jni/jni_layer_asr.cpp @@ -80,7 +80,6 @@ bool utf8_check_validity(const char* str, size_t length) { struct AsrCallbackCache { jclass callbackClass = nullptr; jmethodID onTokenMethod = nullptr; - jmethodID onCompleteMethod = nullptr; }; AsrCallbackCache callbackCache; @@ -97,10 +96,6 @@ void initCallbackCache(JNIEnv* env) { (jclass)localEnv->NewGlobalRef(localClass); callbackCache.onTokenMethod = localEnv->GetMethodID( callbackCache.callbackClass, "onToken", 
"(Ljava/lang/String;)V"); - callbackCache.onCompleteMethod = localEnv->GetMethodID( - callbackCache.callbackClass, - "onComplete", - "(Ljava/lang/String;)V"); localEnv->DeleteLocalRef(localClass); } }, @@ -411,18 +406,6 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeTranscribe( auto result = handle->runner->transcribe(featuresTensor, config, tokenCallback); - // Call onComplete if callback provided - if (scopedCallback) { - jstring emptyStr = env->NewStringUTF(""); - env->CallVoidMethod( - scopedCallback.get(), callbackCache.onCompleteMethod, emptyStr); - if (env->ExceptionCheck()) { - ET_LOG(Error, "Exception occurred in AsrCallback.onComplete"); - env->ExceptionClear(); - } - env->DeleteLocalRef(emptyStr); - } - if (!result.ok()) { return static_cast<jint>(result.error()); }