Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ circleimageview = "3.1.0"
simpleMvp = "1.0.2"
dtoast = "1.1.5"
preferenceKtx = "1.2.1"
tensorflowLite = "2.14.0"
tensorflowLiteSupport = "0.4.4"

[libraries]
androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" }
Expand Down Expand Up @@ -177,6 +179,8 @@ circleimageview = { module = "de.hdodenhof:circleimageview", version.ref = "circ
jaredrummler-simple-mvp = { module = "com.jaredrummler:simple-mvp", version.ref = "simpleMvp" }
dtoast = { module = "com.github.Dovar66:DToast", version.ref = "dtoast" }
androidx-preference-ktx = { module = "androidx.preference:preference-ktx", version.ref = "preferenceKtx" }
tensorflow-lite = { module = "org.tensorflow:tensorflow-lite", version.ref = "tensorflowLite" }
tensorflow-lite-support = { module = "org.tensorflow:tensorflow-lite-support", version.ref = "tensorflowLiteSupport" }

[plugins]
androidApplication = { id = "com.android.application", version.ref = "agp" }
Expand Down
10 changes: 7 additions & 3 deletions subs/ai/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ android {
}

compileOptions {
sourceCompatibility = JavaVersion.VERSION_21
targetCompatibility = JavaVersion.VERSION_21
sourceCompatibility = JavaVersion.VERSION_17
targetCompatibility = JavaVersion.VERSION_17
}
kotlinOptions {
jvmTarget = "21"
jvmTarget = "17"
}
androidResources {
noCompress("tflite")
Expand All @@ -58,4 +58,8 @@ dependencies {

implementation(libs.play.services.tflite.java)
implementation(libs.play.services.tflite.support)

implementation(libs.tensorflow.lite)
implementation(libs.tensorflow.lite.support)

}
10 changes: 7 additions & 3 deletions subs/ai/src/main/java/com/engineer/ai/GanActivity.kt
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,13 @@ class GanActivity : AppCompatActivity() {
}

private fun genBitmap() {
TensorFlowLiteHelper.init(this) {
val interpreterApi = TensorFlowLiteHelper.createInterpreterApi(this, "dcgan.tflite")
interpreterApi?.let {
TensorFlowLiteHelper.init(this) { playServicesOk ->
val interpreterApi = TensorFlowLiteHelper.createInterpreterApi(
context = this,
modelName = "dcgan.tflite",
preferPlayServices = playServicesOk
)
interpreterApi.let {
Log.d(TAG, interpreterApi.getInputTensor(0).shape().contentToString())
Log.d(TAG, interpreterApi.getOutputTensor(0).shape().contentToString())

Expand Down
42 changes: 28 additions & 14 deletions subs/ai/src/main/java/com/engineer/ai/util/DigitClassifier.kt
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,35 @@ class DigitClassifier(private val context: Context) {
private var interpreter: InterpreterApi? = null

fun initialize(cb: (Boolean) -> Unit) {
TensorFlowLiteHelper.init(context) {
cb(it)
if (it) {
interpreter = TensorFlowLiteHelper.createInterpreterApi(context, "mnist.tflite")
// Read input shape from model file
interpreter?.let { inter ->
val inputShape = inter.getInputTensor(0).shape()
Log.d(TAG, "input shape = ${inputShape.contentToString()}")
Log.d(TAG, "elem shape = ${inter.getInputTensor(0).numElements()}")
Log.d(TAG, "output shape = ${inter.getOutputTensor(0).shape().contentToString()}")
inputImageWidth = inputShape[1]
inputImageHeight = inputShape[2]
modelInputSize = FLOAT_TYPE_SIZE * inputImageWidth * inputImageHeight * PIXEL_SIZE
isInitialized = true
TensorFlowLiteHelper.init(context) { playServicesOk ->
try {
interpreter = TensorFlowLiteHelper.createInterpreterApi(
context = context,
modelName = "mnist.tflite",
preferPlayServices = playServicesOk
)

val inter = interpreter
if (inter == null) {
isInitialized = false
cb(false)
return@init
}

val inputShape = inter.getInputTensor(0).shape()
Log.d(TAG, "input shape = ${inputShape.contentToString()}")
Log.d(TAG, "elem shape = ${inter.getInputTensor(0).numElements()}")
Log.d(TAG, "output shape = ${inter.getOutputTensor(0).shape().contentToString()}")

inputImageWidth = inputShape[1]
inputImageHeight = inputShape[2]
modelInputSize = FLOAT_TYPE_SIZE * inputImageWidth * inputImageHeight * PIXEL_SIZE
isInitialized = true
cb(true)
} catch (t: Throwable) {
Log.e(TAG, "Failed to initialize DigitClassifier.", t)
isInitialized = false
cb(false)
}
}
}
Expand Down
52 changes: 34 additions & 18 deletions subs/ai/src/main/java/com/engineer/ai/util/TensorFlowLiteHelper.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,45 @@ import java.nio.channels.FileChannel

object TensorFlowLiteHelper {
private const val TAG = "TensorFlowLiteHelper"
private lateinit var initializeTask: Task<Void>
private var interpreter: InterpreterApi? = null

/**
* Try to init Google Play Services TFLite (dynamite module).
*
* This will fail on devices without Google Play Services (e.g. many CN ROMs).
* We treat failure as non-fatal and fall back to bundled TFLite runtime.
*/
fun init(context: Context, cb: (Boolean) -> Unit) {
initializeTask = TfLite.initialize(context)
initializeTask.addOnSuccessListener {
Log.d(TAG, "Initialized TFLite interpreter.")
Log.d(TAG, "ver ${TensorFlowLite.schemaVersion(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)}")
Log.d(TAG, "ver ${TensorFlowLite.runtimeVersion(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)}")
cb(true)
}.addOnFailureListener {
Log.d(TAG, "Initialized TFLite fail")
cb(false)
Log.e(TAG, "error ", it)
}
TfLite.initialize(context)
.addOnSuccessListener {
Log.d(TAG, "Initialized Play Services TFLite.")
try {
Log.d(
TAG,
"schema=${TensorFlowLite.schemaVersion(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)} runtime=${TensorFlowLite.runtimeVersion(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)}"
)
} catch (t: Throwable) {
Log.w(TAG, "Unable to query system-only TFLite version.", t)
}
cb(true)
}
.addOnFailureListener { e ->
Log.w(TAG, "Play Services TFLite init failed; will fall back to bundled runtime.", e)
cb(false)
}
}

fun createInterpreterApi(context: Context, modelName: String): InterpreterApi? {
fun createInterpreterApi(context: Context, modelName: String, preferPlayServices: Boolean): InterpreterApi {
val model = loadModelFile(context.assets, modelName)
val interpreterOption =
InterpreterApi.Options().setRuntime(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)
interpreter = InterpreterApi.create(model, interpreterOption)
return interpreter

val runtime = if (preferPlayServices) {
InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY
} else {
// Bundled runtime provided by org.tensorflow:tensorflow-lite
InterpreterApi.Options.TfLiteRuntime.FROM_APPLICATION_ONLY
}

val options = InterpreterApi.Options().setRuntime(runtime)
return InterpreterApi.create(model, options)
}

@Throws(IOException::class)
Expand Down