mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Add NNAPI E2E test for Android java package (#8912)
* Add NNAPI E2E test for Android java package * address cr comment
This commit is contained in:
parent
a9a0d3f6fa
commit
8404a2d011
1 changed files with 34 additions and 8 deletions
|
|
@ -3,7 +3,9 @@ package ai.onnxruntime.example.javavalidator
|
|||
import ai.onnxruntime.OnnxTensor
|
||||
import ai.onnxruntime.OrtEnvironment
|
||||
import ai.onnxruntime.OrtException
|
||||
import ai.onnxruntime.OrtProvider
|
||||
import ai.onnxruntime.OrtSession.SessionOptions
|
||||
import android.util.Log
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4
|
||||
import androidx.test.platform.app.InstrumentationRegistry
|
||||
import org.junit.Assert
|
||||
|
|
@ -12,14 +14,44 @@ import org.junit.runner.RunWith
|
|||
import java.io.IOException
|
||||
import java.util.*
|
||||
|
||||
private const val TAG = "ORTAndroidTest"
|
||||
|
||||
@RunWith(AndroidJUnit4::class)
|
||||
class SimpleTest {
|
||||
@Test
|
||||
@Throws(OrtException::class, IOException::class)
|
||||
fun runSigmoidModelTest() {
|
||||
for (intraOpNumThreads in 1..4) {
|
||||
runSigmoidModelTestImpl(intraOpNumThreads)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun runSigmoidModelTestNNAPI() {
|
||||
runSigmoidModelTestImpl(1, true)
|
||||
}
|
||||
|
||||
@Throws(IOException::class)
|
||||
private fun readModel(fileName: String): ByteArray {
|
||||
return InstrumentationRegistry.getInstrumentation().context.assets.open(fileName)
|
||||
.readBytes()
|
||||
}
|
||||
|
||||
@Throws(OrtException::class, IOException::class)
|
||||
fun runSigmoidModelTestImpl(intraOpNumThreads: Int, useNNAPI: Boolean = false) {
|
||||
Log.println(Log.INFO, TAG, "Testing with intraOpNumThreads=$intraOpNumThreads")
|
||||
Log.println(Log.INFO, TAG, "Testing with useNNAPI=$useNNAPI")
|
||||
val env = OrtEnvironment.getEnvironment()
|
||||
env.use {
|
||||
val opts = SessionOptions()
|
||||
opts.setIntraOpNumThreads(intraOpNumThreads)
|
||||
if (useNNAPI) {
|
||||
if (OrtEnvironment.getAvailableProviders().contains(OrtProvider.NNAPI)) {
|
||||
opts.addNnapi()
|
||||
} else {
|
||||
Log.println(Log.INFO, TAG, "NO NNAPI EP available, skip the test")
|
||||
return
|
||||
}
|
||||
}
|
||||
opts.use {
|
||||
val session = env.createSession(readModel("sigmoid.ort"), opts)
|
||||
session.use {
|
||||
|
|
@ -40,6 +72,7 @@ class SimpleTest {
|
|||
inputTensor.use {
|
||||
val output = session.run(Collections.singletonMap(inputName, inputTensor))
|
||||
output.use {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
val rawOutput = output[0].value as Array<Array<FloatArray>>
|
||||
for (i in 0..2) {
|
||||
for (j in 0..3) {
|
||||
|
|
@ -58,11 +91,4 @@ class SimpleTest {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Throws(IOException::class)
|
||||
private fun readModel(fileName: String): ByteArray {
|
||||
return InstrumentationRegistry.getInstrumentation().context.assets.open(fileName)
|
||||
.readBytes()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue