From 8404a2d011749c277d89ddac324b99f6ed3ea8ba Mon Sep 17 00:00:00 2001 From: Guoyu Wang <62914304+gwang-msft@users.noreply.github.com> Date: Tue, 31 Aug 2021 17:34:33 -0700 Subject: [PATCH] Add NNAPI E2E test for Android java package (#8912) * Add NNAPI E2E test for Android java package * address cr comment --- .../example/javavalidator/SimpleTest.kt | 42 +++++++++++++++---- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/java/src/test/android/app/src/androidTest/java/ai/onnxruntime/example/javavalidator/SimpleTest.kt b/java/src/test/android/app/src/androidTest/java/ai/onnxruntime/example/javavalidator/SimpleTest.kt index aa86e91789..8b2b459882 100644 --- a/java/src/test/android/app/src/androidTest/java/ai/onnxruntime/example/javavalidator/SimpleTest.kt +++ b/java/src/test/android/app/src/androidTest/java/ai/onnxruntime/example/javavalidator/SimpleTest.kt @@ -3,7 +3,9 @@ package ai.onnxruntime.example.javavalidator import ai.onnxruntime.OnnxTensor import ai.onnxruntime.OrtEnvironment import ai.onnxruntime.OrtException +import ai.onnxruntime.OrtProvider import ai.onnxruntime.OrtSession.SessionOptions +import android.util.Log import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.platform.app.InstrumentationRegistry import org.junit.Assert @@ -12,14 +14,44 @@ import org.junit.runner.RunWith import java.io.IOException import java.util.* +private const val TAG = "ORTAndroidTest" + @RunWith(AndroidJUnit4::class) class SimpleTest { @Test - @Throws(OrtException::class, IOException::class) fun runSigmoidModelTest() { + for (intraOpNumThreads in 1..4) { + runSigmoidModelTestImpl(intraOpNumThreads) + } + } + + @Test + fun runSigmoidModelTestNNAPI() { + runSigmoidModelTestImpl(1, true) + } + + @Throws(IOException::class) + private fun readModel(fileName: String): ByteArray { + return InstrumentationRegistry.getInstrumentation().context.assets.open(fileName) + .readBytes() + } + + @Throws(OrtException::class, IOException::class) + fun runSigmoidModelTestImpl(intraOpNumThreads: Int, useNNAPI: Boolean = false) { + Log.println(Log.INFO, TAG, "Testing with intraOpNumThreads=$intraOpNumThreads") + Log.println(Log.INFO, TAG, "Testing with useNNAPI=$useNNAPI") val env = OrtEnvironment.getEnvironment() env.use { val opts = SessionOptions() + opts.setIntraOpNumThreads(intraOpNumThreads) + if (useNNAPI) { + if (OrtEnvironment.getAvailableProviders().contains(OrtProvider.NNAPI)) { + opts.addNnapi() + } else { + Log.println(Log.INFO, TAG, "NO NNAPI EP available, skip the test") + return + } + } opts.use { val session = env.createSession(readModel("sigmoid.ort"), opts) session.use { @@ -40,6 +72,7 @@ class SimpleTest { inputTensor.use { val output = session.run(Collections.singletonMap(inputName, inputTensor)) output.use { + @Suppress("UNCHECKED_CAST") val rawOutput = output[0].value as Array> for (i in 0..2) { for (j in 0..3) { @@ -58,11 +91,4 @@ class SimpleTest { } } } - - - @Throws(IOException::class) - private fun readModel(fileName: String): ByteArray { - return InstrumentationRegistry.getInstrumentation().context.assets.open(fileName) - .readBytes() - } }