Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,17 @@ import org.pytorch.executorch.annotations.Experimental

/**
* Callback interface for ASR (Automatic Speech Recognition) module. Users can implement this
* interface to receive the transcribed tokens and completion notification.
* interface to receive the transcribed tokens as they are generated.
*
* Warning: These APIs are experimental and subject to change without notice
*/
@Experimental
interface AsrCallback {
/**
* Called when a new token is available from JNI. Users will keep getting onToken() invocations
* until transcription finishes.
* until transcription finishes (when the transcribe method returns).
*
* @param token The decoded text token
*/
fun onToken(token: String)

/**
* Called when transcription is complete.
*
* @param transcription The complete transcription (may be empty if tokens were streamed)
*/
fun onComplete(transcription: String) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class AsrModule(
): Int
}

/** Check if the native handle is valid. */
/** Check if the native handle is valid (not yet closed). */
val isValid: Boolean
get() = nativeHandle.get() != 0L

Expand All @@ -100,18 +100,13 @@ class AsrModule(
}

/** Releases native resources. Call this when done with the module. */
fun destroy() {
override fun close() {
val handle = nativeHandle.getAndSet(0L)
if (handle != 0L) {
nativeDestroy(handle)
}
}

/** Closeable implementation for use with use {} blocks. */
override fun close() {
destroy()
}

/**
* Force loading the module. Otherwise the model is loaded during first transcribe() call.
*
Expand All @@ -125,71 +120,41 @@ class AsrModule(
}

/**
* Transcribe audio from a WAV file with default configuration.
* Transcribe audio from a WAV file.
*
* @param wavPath Path to the WAV audio file
* @param callback Callback to receive tokens, can be null
* @return 0 on success, error code otherwise
* @throws IllegalStateException if the module has been destroyed
*/
fun transcribe(wavPath: String, callback: AsrCallback? = null): Int =
transcribe(wavPath, AsrTranscribeConfig(), callback)

/**
* Transcribe audio from a WAV file with custom configuration.
* This is a blocking call that returns the complete transcription.
*
* @param wavPath Path to the WAV audio file
* @param config Configuration for transcription
* @param callback Callback to receive tokens, can be null
* @return 0 on success, error code otherwise
* @param config Configuration for transcription (null uses default configuration)
* @param callback Optional callback to receive tokens as they are generated (can be null)
* @return The complete transcribed text
* @throws IllegalStateException if the module has been destroyed
* @throws RuntimeException if transcription fails (non-zero result code)
*/
@JvmOverloads
fun transcribe(
wavPath: String,
config: AsrTranscribeConfig,
config: AsrTranscribeConfig? = null,
callback: AsrCallback? = null,
): Int {
): String {
val handle = nativeHandle.get()
check(handle != 0L) { "AsrModule has been destroyed" }
val wavFile = File(wavPath)
require(wavFile.canRead() && wavFile.isFile) { "Cannot read WAV file: $wavPath" }
return nativeTranscribe(
handle,
wavPath,
config.maxNewTokens,
config.temperature,
config.decoderStartTokenId,
callback,
)
}

/**
* Transcribe audio from a WAV file and return the full transcription.
*
* This is a blocking call that collects all tokens and returns the complete transcription.
*
* @param wavPath Path to the WAV audio file
* @param config Configuration for transcription
* @return The transcribed text
* @throws RuntimeException if transcription fails
*/
@JvmOverloads
fun transcribeBlocking(
wavPath: String,
config: AsrTranscribeConfig = AsrTranscribeConfig(),
): String {
val effectiveConfig = config ?: AsrTranscribeConfig()
val result = StringBuilder()
val status =
transcribe(
nativeTranscribe(
handle,
wavPath,
config,
effectiveConfig.maxNewTokens,
effectiveConfig.temperature,
effectiveConfig.decoderStartTokenId,
object : AsrCallback {
override fun onToken(token: String) {
result.append(token)
}

override fun onComplete(transcription: String) {
// Tokens already collected
callback?.onToken(token)
}
},
)
Expand Down
17 changes: 0 additions & 17 deletions extension/android/jni/jni_layer_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ bool utf8_check_validity(const char* str, size_t length) {
struct AsrCallbackCache {
jclass callbackClass = nullptr;
jmethodID onTokenMethod = nullptr;
jmethodID onCompleteMethod = nullptr;
};

AsrCallbackCache callbackCache;
Expand All @@ -97,10 +96,6 @@ void initCallbackCache(JNIEnv* env) {
(jclass)localEnv->NewGlobalRef(localClass);
callbackCache.onTokenMethod = localEnv->GetMethodID(
callbackCache.callbackClass, "onToken", "(Ljava/lang/String;)V");
callbackCache.onCompleteMethod = localEnv->GetMethodID(
callbackCache.callbackClass,
"onComplete",
"(Ljava/lang/String;)V");
localEnv->DeleteLocalRef(localClass);
}
},
Expand Down Expand Up @@ -411,18 +406,6 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeTranscribe(
auto result =
handle->runner->transcribe(featuresTensor, config, tokenCallback);

// Call onComplete if callback provided
if (scopedCallback) {
jstring emptyStr = env->NewStringUTF("");
env->CallVoidMethod(
scopedCallback.get(), callbackCache.onCompleteMethod, emptyStr);
if (env->ExceptionCheck()) {
ET_LOG(Error, "Exception occurred in AsrCallback.onComplete");
env->ExceptionClear();
}
env->DeleteLocalRef(emptyStr);
}

if (!result.ok()) {
return static_cast<jint>(result.error());
}
Expand Down
Loading