Adding Java support for getAvailableProviders and other small methods (#5366)

* Adding Java support for getAvailableProviders, addFreeDimensionOverrideByName, disablePerSessionThreads and getProfilingStartTimeNs.

* Fixing copyright years, running spotless and adding javadoc and an accessor to OrtProvider.

* Renaming OrtSession.getProfilingStartTimeInNs.

* Removing ngraph as it's been deprecated.
This commit is contained in:
Adam Pocock 2020-11-25 00:42:57 -05:00 committed by GitHub
parent 40926867c3
commit fddbd8935c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 293 additions and 4 deletions

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
@ -10,6 +10,7 @@ import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.EnumSet;
import java.util.Locale;
import java.util.logging.Level;
import java.util.logging.Logger;
@ -40,6 +41,9 @@ final class OnnxRuntime {
/** The API handle. */
static long ortApiHandle;
/** The available runtime providers */
static EnumSet<OrtProvider> providers;
private OnnxRuntime() {}
/* Computes and initializes OS_ARCH_STR (such as linux-x64) */
@ -90,6 +94,7 @@ final class OnnxRuntime {
load(tempDirectory, ONNXRUNTIME_LIBRARY_NAME);
load(tempDirectory, ONNXRUNTIME_JNI_LIBRARY_NAME);
ortApiHandle = initialiseAPIBase(ORT_API_VERSION_3);
providers = initialiseProviders(ortApiHandle);
loaded = true;
} finally {
if (!isAndroid()) {
@ -202,6 +207,28 @@ final class OnnxRuntime {
}
}
/**
* Extracts the providers array from the C API, converts it into an EnumSet.
*
* <p>Throws IllegalArgumentException if a provider isn't recognised (note this exception should
* only happen during development of ONNX Runtime, if it happens at any other point, file an issue
* on Github).
*
* @param ortApiHandle The API Handle.
* @return The enum set.
*/
private static EnumSet<OrtProvider> initialiseProviders(long ortApiHandle) {
String[] providersArray = getAvailableProviders(ortApiHandle);
EnumSet<OrtProvider> providers = EnumSet.noneOf(OrtProvider.class);
for (String provider : providersArray) {
providers.add(OrtProvider.mapFromName(provider));
}
return providers;
}
/**
* Get a reference to the API struct.
*
@ -209,4 +236,12 @@ final class OnnxRuntime {
* @return A pointer to the API struct.
*/
private static native long initialiseAPIBase(int apiVersionNumber);
/**
* Gets the array of available providers.
*
* @param ortApiHandle The API handle
* @return The array of providers
*/
private static native String[] getAvailableProviders(long ortApiHandle);
}

View file

@ -6,6 +6,7 @@ package ai.onnxruntime;
import ai.onnxruntime.OrtSession.SessionOptions;
import java.io.IOException;
import java.util.EnumSet;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Logger;
@ -254,6 +255,15 @@ public class OrtEnvironment implements AutoCloseable {
}
}
/**
* Gets the providers available in this environment.
*
* @return An enum set of the available execution providers.
*/
public static EnumSet<OrtProvider> getAvailableProviders() {
return OnnxRuntime.providers.clone();
}
/**
* Creates the native object.
*

View file

@ -0,0 +1,63 @@
/*
* Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
import java.util.HashMap;
import java.util.Map;
/** The execution providers available through the Java API. */
public enum OrtProvider {
CPU("CPUExecutionProvider"),
CUDA("CUDAExecutionProvider"),
DNNL("DnnlExecutionProvider"),
OPEN_VINO("OpenVINOExecutionProvider"),
NUPHAR("NupharExecutionProvider"),
VITIS_AI("VitisAIExecutionProvider"),
TENSOR_RT("TensorrtExecutionProvider"),
NNAPI("NnapiExecutionProvider"),
RK_NPU("RknpuExecutionProvider"),
DIRECT_ML("DmlExecutionProvider"),
MI_GRAPH_X("MIGraphXExecutionProvider"),
ACL("ACLExecutionProvider"),
ARM_NN("ArmNNExecutionProvider");
private static final Map<String, OrtProvider> valueMap = new HashMap<>(values().length);
static {
for (OrtProvider p : OrtProvider.values()) {
valueMap.put(p.name, p);
}
}
private final String name;
OrtProvider(String name) {
this.name = name;
}
/**
* Accessor for the internal name of this provider.
*
* @return The internal provider name.
*/
public String getName() {
return name;
}
/**
* Maps from the name string used by ONNX Runtime into the enum.
*
* @param name The provider name string.
* @return The enum constant.
*/
public static OrtProvider mapFromName(String name) {
OrtProvider provider = valueMap.get(name);
if (provider == null) {
throw new IllegalArgumentException("Unknown execution provider - " + name);
} else {
return provider;
}
}
}

View file

@ -315,6 +315,16 @@ public class OrtSession implements AutoCloseable {
return metadata;
}
/**
* Returns the timestamp that profiling started in nanoseconds.
*
* @return the profiling start time in ns.
* @throws OrtException If the native call failed.
*/
public long getProfilingStartTimeInNs() throws OrtException {
return getProfilingStartTimeInNs(OnnxRuntime.ortApiHandle, nativeHandle);
}
/**
* Ends the profiling session and returns the output of the profiler.
*
@ -414,6 +424,9 @@ public class OrtSession implements AutoCloseable {
long runOptionsHandle)
throws OrtException;
private native long getProfilingStartTimeInNs(long apiHandle, long nativeHandle)
throws OrtException;
private native String endProfiling(long apiHandle, long nativeHandle, long allocatorHandle)
throws OrtException;
@ -679,6 +692,32 @@ public class OrtSession implements AutoCloseable {
customLibraryHandles.add(customHandle);
}
/**
* Sets the value of a symbolic dimension. Fixed dimension computations may have more
* optimizations applied to them.
*
* @param dimensionName The name of the symbolic dimension.
* @param dimensionValue The value to set that dimension to.
* @throws OrtException If there was an error in native code.
*/
public void setSymbolicDimensionValue(String dimensionName, long dimensionValue)
throws OrtException {
checkClosed();
addFreeDimensionOverrideByName(
OnnxRuntime.ortApiHandle, nativeHandle, dimensionName, dimensionValue);
}
/**
* Disables the per session thread pools. Must be used in conjunction with an environment
* containing global thread pools.
*
* @throws OrtException If there was an error in native code.
*/
public void disablePerSessionThreads() throws OrtException {
checkClosed();
disablePerSessionThreads(OnnxRuntime.ortApiHandle, nativeHandle);
}
/**
* Adds a single session configuration entry as a pair of strings.
*
@ -708,7 +747,6 @@ public class OrtSession implements AutoCloseable {
* @throws OrtException If there was an error in native code.
*/
public void addCUDA() throws OrtException {
checkClosed();
addCUDA(0);
}
@ -858,6 +896,13 @@ public class OrtSession implements AutoCloseable {
private native void closeOptions(long apiHandle, long nativeHandle);
private native void addFreeDimensionOverrideByName(
long apiHandle, long nativeHandle, String dimensionName, long dimensionValue)
throws OrtException;
private native void disablePerSessionThreads(long apiHandle, long nativeHandle)
throws OrtException;
private native void addConfigEntry(
long apiHandle, long nativeHandle, String configKey, String configValue)
throws OrtException;

View file

@ -1,10 +1,11 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2020 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
#include "onnxruntime/core/session/onnxruntime_c_api.h"
#include "ai_onnxruntime_OnnxRuntime.h"
#include "OrtJniUtil.h"
/*
* Class: ai_onnxruntime_OnnxRuntime
@ -18,3 +19,36 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxRuntime_initialiseAPIBase
return (jlong) ortPtr;
}
/*
* Class: ai_onnxruntime_OnnxRuntime
* Method: getAvailableProviders
* Signature: (J)[Ljava/lang/String;
*/
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxRuntime_getAvailableProviders
(JNIEnv * jniEnv, jclass clazz, jlong apiHandle) {
(void) jniEnv; (void) clazz; // required JNI parameters not needed by functions which don't call back into Java.
const OrtApi* api = (const OrtApi*) apiHandle;
char** providers = NULL;
int numProviders = 0;
// Extract the provider array
checkOrtStatus(jniEnv,api,api->GetAvailableProviders(&providers,&numProviders));
// Convert to Java String Array
char *stringClassName = "java/lang/String";
jclass stringClazz = (*jniEnv)->FindClass(jniEnv, stringClassName);
jobjectArray providerArray = (*jniEnv)->NewObjectArray(jniEnv,numProviders,stringClazz,NULL);
for (int i = 0; i < numProviders; i++) {
// Read out the provider name and convert it to a java.lang.String
jstring provider = (*jniEnv)->NewStringUTF(jniEnv,providers[i]);
(*jniEnv)->SetObjectArrayElement(jniEnv, providerArray, i, provider);
}
// Release providers
checkOrtStatus(jniEnv,api,api->ReleaseAvailableProviders(providers,numProviders));
providers = NULL;
return providerArray;
}

View file

@ -310,6 +310,24 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run
return outputArray;
}
/*
* Class: ai_onnxruntime_OrtSession
* Method: getProfilingStartTimeInNs
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_getProfilingStartTimeInNs
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtSession* session = (OrtSession*) sessionHandle;
uint64_t timestamp = 0;
checkOrtStatus(jniEnv,api,api->SessionGetProfilingStartTimeNs(session,&timestamp));
return (jlong) timestamp;
}
/*
* Class: ai_onnxruntime_OrtSession
* Method: endProfiling

View file

@ -294,6 +294,39 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_closeC
(*jniEnv)->ReleaseLongArrayElements(jniEnv,libraryHandles,handles,JNI_ABORT);
}
/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: addFreeDimensionOverrideByName
* Signature: (JJLjava/lang/String;J)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addFreeDimensionOverrideByName
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring dimensionName, jlong dimensionValue) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
// Extract the string chars
const char* cName = (*jniEnv)->GetStringUTFChars(jniEnv, dimensionName, NULL);
checkOrtStatus(jniEnv,api,api->AddFreeDimensionOverrideByName(options,cName,dimensionValue));
// Release the string chars
(*jniEnv)->ReleaseStringUTFChars(jniEnv,dimensionName,cName);
}
/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: disablePerSessionThreads
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_disablePerSessionThreads
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
checkOrtStatus(jniEnv,api,api->DisablePerSessionThreads(options));
}
/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: addConfigEntry

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
@ -33,6 +33,7 @@ import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
@ -575,9 +576,52 @@ public class InferenceTest {
}
}
@Test
public void testProviders() {
EnumSet<OrtProvider> providers = OrtEnvironment.getAvailableProviders();
int providersSize = providers.size();
assertTrue(providersSize > 0);
assertTrue(providers.contains(OrtProvider.CPU));
// Check that the providers are a copy of the original, note this does not enable the DNNL
// provider
providers.add(OrtProvider.DNNL);
assertEquals(providersSize, OrtEnvironment.getAvailableProviders().size());
}
@Test
public void testSymbolicDimensionAssignment() throws OrtException {
// model takes 1x5 input of fixed type, echoes back
String modelPath = getResourcePath("/capi_symbolic_dims.onnx").toString();
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testSymbolicDimensionAssignment")) {
// Check the dimension is symbolic
try (SessionOptions options = new SessionOptions()) {
try (OrtSession session = env.createSession(modelPath, options)) {
Map<String, NodeInfo> infoMap = session.getInputInfo();
TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo();
assertArrayEquals(new long[] {-1, 2}, aInfo.shape);
}
}
// Check that when the options are assigned it overrides the symbolic dimension
try (SessionOptions options = new SessionOptions()) {
options.setSymbolicDimensionValue("n", 5);
try (OrtSession session = env.createSession(modelPath, options)) {
Map<String, NodeInfo> infoMap = session.getInputInfo();
TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo();
assertArrayEquals(new long[] {5, 2}, aInfo.shape);
}
}
}
}
@Test
public void testCUDA() throws OrtException {
if (System.getProperty("USE_CUDA") != null) {
EnumSet<OrtProvider> providers = OrtEnvironment.getAvailableProviders();
assertTrue(providers.size() > 1);
assertTrue(providers.contains(OrtProvider.CPU));
assertTrue(providers.contains(OrtProvider.CUDA));
SqueezeNetTuple tuple = openSessionSqueezeNet(0);
try (OrtEnvironment env = tuple.env;
OrtSession session = tuple.session) {
@ -926,6 +970,10 @@ public class InferenceTest {
boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue());
assertArrayEquals(flatInput, resultArray);
}
// Check that the profiling start time doesn't throw
long profilingStartTime = session.getProfilingStartTimeInNs();
// Check the profiling output doesn't throw
String profilingOutput = session.endProfiling();
File profilingOutputFile = new File(profilingOutput);
profilingOutputFile.deleteOnExit();
@ -1208,6 +1256,7 @@ public class InferenceTest {
assertEquals(OnnxJavaType.FLOAT, sequenceInfo.mapInfo.valueType);
// try-cast first element in sequence to map/dictionary type
@SuppressWarnings("unchecked")
Map<Long, Float> map = (Map<Long, Float>) ((List<Object>) secondOutput.getValue()).get(0);
assertEquals(0.25938290, map.get(0L), 1e-6);
assertEquals(0.40904793, map.get(1L), 1e-6);
@ -1274,6 +1323,7 @@ public class InferenceTest {
assertEquals(OnnxJavaType.FLOAT, sequenceInfo.mapInfo.valueType);
// try-cast first element in sequence to map/dictionary type
@SuppressWarnings("unchecked")
Map<String, Float> map =
(Map<String, Float>) ((List<Object>) secondOutput.getValue()).get(0);
assertEquals(0.25938290, map.get("0"), 1e-6);

View file

@ -166,6 +166,7 @@ class TestHelpers {
}
}
@SuppressWarnings("unchecked")
static void flattenBase(Object input, List output, Class<?> primitiveClass) {
if (primitiveClass.equals(boolean.class)) {
flattenBooleanBase((boolean[]) input, output);