Add NNAPI E2E test for Android java package (#8912)

* Add NNAPI E2E test for Android java package

* address cr comment
This commit is contained in:
Guoyu Wang 2021-08-31 17:34:33 -07:00 committed by GitHub
parent a9a0d3f6fa
commit 8404a2d011
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -3,7 +3,9 @@ package ai.onnxruntime.example.javavalidator
import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtException
import ai.onnxruntime.OrtProvider
import ai.onnxruntime.OrtSession.SessionOptions
import android.util.Log
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.platform.app.InstrumentationRegistry
import org.junit.Assert
@ -12,14 +14,44 @@ import org.junit.runner.RunWith
import java.io.IOException
import java.util.*
private const val TAG = "ORTAndroidTest"
@RunWith(AndroidJUnit4::class)
class SimpleTest {
@Test
@Throws(OrtException::class, IOException::class)
fun runSigmoidModelTest() {
for (intraOpNumThreads in 1..4) {
runSigmoidModelTestImpl(intraOpNumThreads)
}
}
@Test
fun runSigmoidModelTestNNAPI() {
runSigmoidModelTestImpl(1, true)
}
@Throws(IOException::class)
private fun readModel(fileName: String): ByteArray {
return InstrumentationRegistry.getInstrumentation().context.assets.open(fileName)
.readBytes()
}
@Throws(OrtException::class, IOException::class)
fun runSigmoidModelTestImpl(intraOpNumThreads: Int, useNNAPI: Boolean = false) {
Log.println(Log.INFO, TAG, "Testing with intraOpNumThreads=$intraOpNumThreads")
Log.println(Log.INFO, TAG, "Testing with useNNAPI=$useNNAPI")
val env = OrtEnvironment.getEnvironment()
env.use {
val opts = SessionOptions()
opts.setIntraOpNumThreads(intraOpNumThreads)
if (useNNAPI) {
if (OrtEnvironment.getAvailableProviders().contains(OrtProvider.NNAPI)) {
opts.addNnapi()
} else {
Log.println(Log.INFO, TAG, "NO NNAPI EP available, skip the test")
return
}
}
opts.use {
val session = env.createSession(readModel("sigmoid.ort"), opts)
session.use {
@ -40,6 +72,7 @@ class SimpleTest {
inputTensor.use {
val output = session.run(Collections.singletonMap(inputName, inputTensor))
output.use {
@Suppress("UNCHECKED_CAST")
val rawOutput = output[0].value as Array<Array<FloatArray>>
for (i in 0..2) {
for (j in 0..3) {
@ -58,11 +91,4 @@ class SimpleTest {
}
}
}
@Throws(IOException::class)
private fun readModel(fileName: String): ByteArray {
return InstrumentationRegistry.getInstrumentation().context.assets.open(fileName)
.readBytes()
}
}