diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 4ede116c..210c8c86 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -84,6 +84,8 @@ circleimageview = "3.1.0" simpleMvp = "1.0.2" dtoast = "1.1.5" preferenceKtx = "1.2.1" +tensorflowLite = "2.14.0" +tensorflowLiteSupport = "0.4.4" [libraries] androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" } @@ -177,6 +179,8 @@ circleimageview = { module = "de.hdodenhof:circleimageview", version.ref = "circ jaredrummler-simple-mvp = { module = "com.jaredrummler:simple-mvp", version.ref = "simpleMvp" } dtoast = { module = "com.github.Dovar66:DToast", version.ref = "dtoast" } androidx-preference-ktx = { module = "androidx.preference:preference-ktx", version.ref = "preferenceKtx" } +tensorflow-lite = { module = "org.tensorflow:tensorflow-lite", version.ref = "tensorflowLite" } +tensorflow-lite-support = { module = "org.tensorflow:tensorflow-lite-support", version.ref = "tensorflowLiteSupport" } [plugins] androidApplication = { id = "com.android.application", version.ref = "agp" } diff --git a/subs/ai/build.gradle.kts b/subs/ai/build.gradle.kts index d265d3e4..637dce66 100644 --- a/subs/ai/build.gradle.kts +++ b/subs/ai/build.gradle.kts @@ -30,11 +30,11 @@ android { } compileOptions { - sourceCompatibility = JavaVersion.VERSION_21 - targetCompatibility = JavaVersion.VERSION_21 + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 } kotlinOptions { - jvmTarget = "21" + jvmTarget = "17" } androidResources { noCompress("tflite") @@ -58,4 +58,8 @@ dependencies { implementation(libs.play.services.tflite.java) implementation(libs.play.services.tflite.support) + + implementation(libs.tensorflow.lite) + implementation(libs.tensorflow.lite.support) + } diff --git a/subs/ai/src/main/java/com/engineer/ai/GanActivity.kt b/subs/ai/src/main/java/com/engineer/ai/GanActivity.kt index ce43d4f4..6cb6ca7a 100644 --- a/subs/ai/src/main/java/com/engineer/ai/GanActivity.kt +++ b/subs/ai/src/main/java/com/engineer/ai/GanActivity.kt @@ -60,9 +60,13 @@ class GanActivity : AppCompatActivity() { } private fun genBitmap() { - TensorFlowLiteHelper.init(this) { - val interpreterApi = TensorFlowLiteHelper.createInterpreterApi(this, "dcgan.tflite") - interpreterApi?.let { + TensorFlowLiteHelper.init(this) { playServicesOk -> + val interpreterApi = TensorFlowLiteHelper.createInterpreterApi( + context = this, + modelName = "dcgan.tflite", + preferPlayServices = playServicesOk + ) + interpreterApi.let { Log.d(TAG, interpreterApi.getInputTensor(0).shape().contentToString()) Log.d(TAG, interpreterApi.getOutputTensor(0).shape().contentToString()) diff --git a/subs/ai/src/main/java/com/engineer/ai/util/DigitClassifier.kt b/subs/ai/src/main/java/com/engineer/ai/util/DigitClassifier.kt index e3b88992..ca5190ee 100644 --- a/subs/ai/src/main/java/com/engineer/ai/util/DigitClassifier.kt +++ b/subs/ai/src/main/java/com/engineer/ai/util/DigitClassifier.kt @@ -40,21 +40,35 @@ class DigitClassifier(private val context: Context) { private var interpreter: InterpreterApi? = null fun initialize(cb: (Boolean) -> Unit) { - TensorFlowLiteHelper.init(context) { - cb(it) - if (it) { - interpreter = TensorFlowLiteHelper.createInterpreterApi(context, "mnist.tflite") - // Read input shape from model file - interpreter?.let { inter -> - val inputShape = inter.getInputTensor(0).shape() - Log.d(TAG, "input shape = ${inputShape.contentToString()}") - Log.d(TAG, "elem shape = ${inter.getInputTensor(0).numElements()}") - Log.d(TAG, "output shape = ${inter.getOutputTensor(0).shape().contentToString()}") - inputImageWidth = inputShape[1] - inputImageHeight = inputShape[2] - modelInputSize = FLOAT_TYPE_SIZE * inputImageWidth * inputImageHeight * PIXEL_SIZE - isInitialized = true + TensorFlowLiteHelper.init(context) { playServicesOk -> + try { + interpreter = TensorFlowLiteHelper.createInterpreterApi( + context = context, + modelName = "mnist.tflite", + preferPlayServices = playServicesOk + ) + + val inter = interpreter + if (inter == null) { + isInitialized = false + cb(false) + return@init } + + val inputShape = inter.getInputTensor(0).shape() + Log.d(TAG, "input shape = ${inputShape.contentToString()}") + Log.d(TAG, "elem shape = ${inter.getInputTensor(0).numElements()}") + Log.d(TAG, "output shape = ${inter.getOutputTensor(0).shape().contentToString()}") + + inputImageWidth = inputShape[1] + inputImageHeight = inputShape[2] + modelInputSize = FLOAT_TYPE_SIZE * inputImageWidth * inputImageHeight * PIXEL_SIZE + isInitialized = true + cb(true) + } catch (t: Throwable) { + Log.e(TAG, "Failed to initialize DigitClassifier.", t) + isInitialized = false + cb(false) } } } diff --git a/subs/ai/src/main/java/com/engineer/ai/util/TensorFlowLiteHelper.kt b/subs/ai/src/main/java/com/engineer/ai/util/TensorFlowLiteHelper.kt index e2fbd014..e379b9b4 100644 --- a/subs/ai/src/main/java/com/engineer/ai/util/TensorFlowLiteHelper.kt +++ b/subs/ai/src/main/java/com/engineer/ai/util/TensorFlowLiteHelper.kt @@ -16,29 +16,45 @@ import java.nio.channels.FileChannel object TensorFlowLiteHelper { private const val TAG = "TensorFlowLiteHelper" - private lateinit var initializeTask: Task - private var interpreter: InterpreterApi? = null + /** + * Try to init Google Play Services TFLite (dynamite module). + * + * This will fail on devices without Google Play Services (e.g. many CN ROMs). + * We treat failure as non-fatal and fall back to bundled TFLite runtime. + */ fun init(context: Context, cb: (Boolean) -> Unit) { - initializeTask = TfLite.initialize(context) - initializeTask.addOnSuccessListener { - Log.d(TAG, "Initialized TFLite interpreter.") - Log.d(TAG, "ver ${TensorFlowLite.schemaVersion(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)}") - Log.d(TAG, "ver ${TensorFlowLite.runtimeVersion(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)}") - cb(true) - }.addOnFailureListener { - Log.d(TAG, "Initialized TFLite fail") - cb(false) - Log.e(TAG, "error ", it) - } + TfLite.initialize(context) + .addOnSuccessListener { + Log.d(TAG, "Initialized Play Services TFLite.") + try { + Log.d( + TAG, + "schema=${TensorFlowLite.schemaVersion(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)} runtime=${TensorFlowLite.runtimeVersion(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)}" + ) + } catch (t: Throwable) { + Log.w(TAG, "Unable to query system-only TFLite version.", t) + } + cb(true) + } + .addOnFailureListener { e -> + Log.w(TAG, "Play Services TFLite init failed; will fall back to bundled runtime.", e) + cb(false) + } } - fun createInterpreterApi(context: Context, modelName: String): InterpreterApi? { + fun createInterpreterApi(context: Context, modelName: String, preferPlayServices: Boolean): InterpreterApi { val model = loadModelFile(context.assets, modelName) - val interpreterOption = - InterpreterApi.Options().setRuntime(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY) - interpreter = InterpreterApi.create(model, interpreterOption) - return interpreter + + val runtime = if (preferPlayServices) { + InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY + } else { + // Bundled runtime provided by org.tensorflow:tensorflow-lite + InterpreterApi.Options.TfLiteRuntime.FROM_APPLICATION_ONLY + } + + val options = InterpreterApi.Options().setRuntime(runtime) + return InterpreterApi.create(model, options) } @Throws(IOException::class)