mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Adding Java support for getAvailableProviders and other small methods (#5366)
* Adding Java support for getAvailableProviders, addFreeDimensionOverrideByName, disablePerSessionThreads and getProfilingStartTimeNs. * Fixing copyright years, running spotless and adding javadoc and an accessor to OrtProvider. * Renaming OrtSession.getProfilingStartTimeInNs. * Removing ngraph as it's been deprecated.
This commit is contained in:
parent
40926867c3
commit
fddbd8935c
9 changed files with 293 additions and 4 deletions
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
|
@ -10,6 +10,7 @@ import java.io.IOException;
|
|||
import java.io.InputStream;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.EnumSet;
|
||||
import java.util.Locale;
|
||||
import java.util.logging.Level;
|
||||
import java.util.logging.Logger;
|
||||
|
|
@ -40,6 +41,9 @@ final class OnnxRuntime {
|
|||
/** The API handle. */
|
||||
static long ortApiHandle;
|
||||
|
||||
/** The available runtime providers */
|
||||
static EnumSet<OrtProvider> providers;
|
||||
|
||||
private OnnxRuntime() {}
|
||||
|
||||
/* Computes and initializes OS_ARCH_STR (such as linux-x64) */
|
||||
|
|
@ -90,6 +94,7 @@ final class OnnxRuntime {
|
|||
load(tempDirectory, ONNXRUNTIME_LIBRARY_NAME);
|
||||
load(tempDirectory, ONNXRUNTIME_JNI_LIBRARY_NAME);
|
||||
ortApiHandle = initialiseAPIBase(ORT_API_VERSION_3);
|
||||
providers = initialiseProviders(ortApiHandle);
|
||||
loaded = true;
|
||||
} finally {
|
||||
if (!isAndroid()) {
|
||||
|
|
@ -202,6 +207,28 @@ final class OnnxRuntime {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts the providers array from the C API, converts it into an EnumSet.
|
||||
*
|
||||
* <p>Throws IllegalArgumentException if a provider isn't recognised (note this exception should
|
||||
* only happen during development of ONNX Runtime, if it happens at any other point, file an issue
|
||||
* on Github).
|
||||
*
|
||||
* @param ortApiHandle The API Handle.
|
||||
* @return The enum set.
|
||||
*/
|
||||
private static EnumSet<OrtProvider> initialiseProviders(long ortApiHandle) {
|
||||
String[] providersArray = getAvailableProviders(ortApiHandle);
|
||||
|
||||
EnumSet<OrtProvider> providers = EnumSet.noneOf(OrtProvider.class);
|
||||
|
||||
for (String provider : providersArray) {
|
||||
providers.add(OrtProvider.mapFromName(provider));
|
||||
}
|
||||
|
||||
return providers;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a reference to the API struct.
|
||||
*
|
||||
|
|
@ -209,4 +236,12 @@ final class OnnxRuntime {
|
|||
* @return A pointer to the API struct.
|
||||
*/
|
||||
private static native long initialiseAPIBase(int apiVersionNumber);
|
||||
|
||||
/**
|
||||
* Gets the array of available providers.
|
||||
*
|
||||
* @param ortApiHandle The API handle
|
||||
* @return The array of providers
|
||||
*/
|
||||
private static native String[] getAvailableProviders(long ortApiHandle);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ package ai.onnxruntime;
|
|||
|
||||
import ai.onnxruntime.OrtSession.SessionOptions;
|
||||
import java.io.IOException;
|
||||
import java.util.EnumSet;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
|
|
@ -254,6 +255,15 @@ public class OrtEnvironment implements AutoCloseable {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the providers available in this environment.
|
||||
*
|
||||
* @return An enum set of the available execution providers.
|
||||
*/
|
||||
public static EnumSet<OrtProvider> getAvailableProviders() {
|
||||
return OnnxRuntime.providers.clone();
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates the native object.
|
||||
*
|
||||
|
|
|
|||
63
java/src/main/java/ai/onnxruntime/OrtProvider.java
Normal file
63
java/src/main/java/ai/onnxruntime/OrtProvider.java
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
/*
|
||||
* Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/** The execution providers available through the Java API. */
|
||||
public enum OrtProvider {
|
||||
CPU("CPUExecutionProvider"),
|
||||
CUDA("CUDAExecutionProvider"),
|
||||
DNNL("DnnlExecutionProvider"),
|
||||
OPEN_VINO("OpenVINOExecutionProvider"),
|
||||
NUPHAR("NupharExecutionProvider"),
|
||||
VITIS_AI("VitisAIExecutionProvider"),
|
||||
TENSOR_RT("TensorrtExecutionProvider"),
|
||||
NNAPI("NnapiExecutionProvider"),
|
||||
RK_NPU("RknpuExecutionProvider"),
|
||||
DIRECT_ML("DmlExecutionProvider"),
|
||||
MI_GRAPH_X("MIGraphXExecutionProvider"),
|
||||
ACL("ACLExecutionProvider"),
|
||||
ARM_NN("ArmNNExecutionProvider");
|
||||
|
||||
private static final Map<String, OrtProvider> valueMap = new HashMap<>(values().length);
|
||||
|
||||
static {
|
||||
for (OrtProvider p : OrtProvider.values()) {
|
||||
valueMap.put(p.name, p);
|
||||
}
|
||||
}
|
||||
|
||||
private final String name;
|
||||
|
||||
OrtProvider(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
/**
|
||||
* Accessor for the internal name of this provider.
|
||||
*
|
||||
* @return The internal provider name.
|
||||
*/
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
/**
|
||||
* Maps from the name string used by ONNX Runtime into the enum.
|
||||
*
|
||||
* @param name The provider name string.
|
||||
* @return The enum constant.
|
||||
*/
|
||||
public static OrtProvider mapFromName(String name) {
|
||||
OrtProvider provider = valueMap.get(name);
|
||||
if (provider == null) {
|
||||
throw new IllegalArgumentException("Unknown execution provider - " + name);
|
||||
} else {
|
||||
return provider;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -315,6 +315,16 @@ public class OrtSession implements AutoCloseable {
|
|||
return metadata;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the timestamp that profiling started in nanoseconds.
|
||||
*
|
||||
* @return the profiling start time in ns.
|
||||
* @throws OrtException If the native call failed.
|
||||
*/
|
||||
public long getProfilingStartTimeInNs() throws OrtException {
|
||||
return getProfilingStartTimeInNs(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
}
|
||||
|
||||
/**
|
||||
* Ends the profiling session and returns the output of the profiler.
|
||||
*
|
||||
|
|
@ -414,6 +424,9 @@ public class OrtSession implements AutoCloseable {
|
|||
long runOptionsHandle)
|
||||
throws OrtException;
|
||||
|
||||
private native long getProfilingStartTimeInNs(long apiHandle, long nativeHandle)
|
||||
throws OrtException;
|
||||
|
||||
private native String endProfiling(long apiHandle, long nativeHandle, long allocatorHandle)
|
||||
throws OrtException;
|
||||
|
||||
|
|
@ -679,6 +692,32 @@ public class OrtSession implements AutoCloseable {
|
|||
customLibraryHandles.add(customHandle);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the value of a symbolic dimension. Fixed dimension computations may have more
|
||||
* optimizations applied to them.
|
||||
*
|
||||
* @param dimensionName The name of the symbolic dimension.
|
||||
* @param dimensionValue The value to set that dimension to.
|
||||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void setSymbolicDimensionValue(String dimensionName, long dimensionValue)
|
||||
throws OrtException {
|
||||
checkClosed();
|
||||
addFreeDimensionOverrideByName(
|
||||
OnnxRuntime.ortApiHandle, nativeHandle, dimensionName, dimensionValue);
|
||||
}
|
||||
|
||||
/**
|
||||
* Disables the per session thread pools. Must be used in conjunction with an environment
|
||||
* containing global thread pools.
|
||||
*
|
||||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void disablePerSessionThreads() throws OrtException {
|
||||
checkClosed();
|
||||
disablePerSessionThreads(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a single session configuration entry as a pair of strings.
|
||||
*
|
||||
|
|
@ -708,7 +747,6 @@ public class OrtSession implements AutoCloseable {
|
|||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void addCUDA() throws OrtException {
|
||||
checkClosed();
|
||||
addCUDA(0);
|
||||
}
|
||||
|
||||
|
|
@ -858,6 +896,13 @@ public class OrtSession implements AutoCloseable {
|
|||
|
||||
private native void closeOptions(long apiHandle, long nativeHandle);
|
||||
|
||||
private native void addFreeDimensionOverrideByName(
|
||||
long apiHandle, long nativeHandle, String dimensionName, long dimensionValue)
|
||||
throws OrtException;
|
||||
|
||||
private native void disablePerSessionThreads(long apiHandle, long nativeHandle)
|
||||
throws OrtException;
|
||||
|
||||
private native void addConfigEntry(
|
||||
long apiHandle, long nativeHandle, String configKey, String configValue)
|
||||
throws OrtException;
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
/*
|
||||
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2020 Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
#include <jni.h>
|
||||
#include "onnxruntime/core/session/onnxruntime_c_api.h"
|
||||
#include "ai_onnxruntime_OnnxRuntime.h"
|
||||
#include "OrtJniUtil.h"
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OnnxRuntime
|
||||
|
|
@ -18,3 +19,36 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxRuntime_initialiseAPIBase
|
|||
return (jlong) ortPtr;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OnnxRuntime
|
||||
* Method: getAvailableProviders
|
||||
* Signature: (J)[Ljava/lang/String;
|
||||
*/
|
||||
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxRuntime_getAvailableProviders
|
||||
(JNIEnv * jniEnv, jclass clazz, jlong apiHandle) {
|
||||
(void) jniEnv; (void) clazz; // required JNI parameters not needed by functions which don't call back into Java.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
|
||||
char** providers = NULL;
|
||||
int numProviders = 0;
|
||||
|
||||
// Extract the provider array
|
||||
checkOrtStatus(jniEnv,api,api->GetAvailableProviders(&providers,&numProviders));
|
||||
|
||||
// Convert to Java String Array
|
||||
char *stringClassName = "java/lang/String";
|
||||
jclass stringClazz = (*jniEnv)->FindClass(jniEnv, stringClassName);
|
||||
jobjectArray providerArray = (*jniEnv)->NewObjectArray(jniEnv,numProviders,stringClazz,NULL);
|
||||
|
||||
for (int i = 0; i < numProviders; i++) {
|
||||
// Read out the provider name and convert it to a java.lang.String
|
||||
jstring provider = (*jniEnv)->NewStringUTF(jniEnv,providers[i]);
|
||||
(*jniEnv)->SetObjectArrayElement(jniEnv, providerArray, i, provider);
|
||||
}
|
||||
|
||||
// Release providers
|
||||
checkOrtStatus(jniEnv,api,api->ReleaseAvailableProviders(providers,numProviders));
|
||||
providers = NULL;
|
||||
|
||||
return providerArray;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -310,6 +310,24 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run
|
|||
return outputArray;
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession
|
||||
* Method: getProfilingStartTimeInNs
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_getProfilingStartTimeInNs
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtSession* session = (OrtSession*) sessionHandle;
|
||||
|
||||
uint64_t timestamp = 0;
|
||||
|
||||
checkOrtStatus(jniEnv,api,api->SessionGetProfilingStartTimeNs(session,×tamp));
|
||||
return (jlong) timestamp;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession
|
||||
* Method: endProfiling
|
||||
|
|
|
|||
|
|
@ -294,6 +294,39 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_closeC
|
|||
(*jniEnv)->ReleaseLongArrayElements(jniEnv,libraryHandles,handles,JNI_ABORT);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_SessionOptions
|
||||
* Method: addFreeDimensionOverrideByName
|
||||
* Signature: (JJLjava/lang/String;J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addFreeDimensionOverrideByName
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring dimensionName, jlong dimensionValue) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*)apiHandle;
|
||||
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
|
||||
|
||||
// Extract the string chars
|
||||
const char* cName = (*jniEnv)->GetStringUTFChars(jniEnv, dimensionName, NULL);
|
||||
|
||||
checkOrtStatus(jniEnv,api,api->AddFreeDimensionOverrideByName(options,cName,dimensionValue));
|
||||
|
||||
// Release the string chars
|
||||
(*jniEnv)->ReleaseStringUTFChars(jniEnv,dimensionName,cName);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_SessionOptions
|
||||
* Method: disablePerSessionThreads
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_disablePerSessionThreads
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*)apiHandle;
|
||||
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
|
||||
checkOrtStatus(jniEnv,api,api->DisablePerSessionThreads(options));
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_SessionOptions
|
||||
* Method: addConfigEntry
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
|
@ -33,6 +33,7 @@ import java.nio.file.Path;
|
|||
import java.nio.file.Paths;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.EnumSet;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Iterator;
|
||||
|
|
@ -575,9 +576,52 @@ public class InferenceTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testProviders() {
|
||||
EnumSet<OrtProvider> providers = OrtEnvironment.getAvailableProviders();
|
||||
int providersSize = providers.size();
|
||||
assertTrue(providersSize > 0);
|
||||
assertTrue(providers.contains(OrtProvider.CPU));
|
||||
|
||||
// Check that the providers are a copy of the original, note this does not enable the DNNL
|
||||
// provider
|
||||
providers.add(OrtProvider.DNNL);
|
||||
assertEquals(providersSize, OrtEnvironment.getAvailableProviders().size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSymbolicDimensionAssignment() throws OrtException {
|
||||
// model takes 1x5 input of fixed type, echoes back
|
||||
String modelPath = getResourcePath("/capi_symbolic_dims.onnx").toString();
|
||||
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testSymbolicDimensionAssignment")) {
|
||||
// Check the dimension is symbolic
|
||||
try (SessionOptions options = new SessionOptions()) {
|
||||
try (OrtSession session = env.createSession(modelPath, options)) {
|
||||
Map<String, NodeInfo> infoMap = session.getInputInfo();
|
||||
TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo();
|
||||
assertArrayEquals(new long[] {-1, 2}, aInfo.shape);
|
||||
}
|
||||
}
|
||||
// Check that when the options are assigned it overrides the symbolic dimension
|
||||
try (SessionOptions options = new SessionOptions()) {
|
||||
options.setSymbolicDimensionValue("n", 5);
|
||||
try (OrtSession session = env.createSession(modelPath, options)) {
|
||||
Map<String, NodeInfo> infoMap = session.getInputInfo();
|
||||
TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo();
|
||||
assertArrayEquals(new long[] {5, 2}, aInfo.shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCUDA() throws OrtException {
|
||||
if (System.getProperty("USE_CUDA") != null) {
|
||||
EnumSet<OrtProvider> providers = OrtEnvironment.getAvailableProviders();
|
||||
assertTrue(providers.size() > 1);
|
||||
assertTrue(providers.contains(OrtProvider.CPU));
|
||||
assertTrue(providers.contains(OrtProvider.CUDA));
|
||||
SqueezeNetTuple tuple = openSessionSqueezeNet(0);
|
||||
try (OrtEnvironment env = tuple.env;
|
||||
OrtSession session = tuple.session) {
|
||||
|
|
@ -926,6 +970,10 @@ public class InferenceTest {
|
|||
boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue());
|
||||
assertArrayEquals(flatInput, resultArray);
|
||||
}
|
||||
// Check that the profiling start time doesn't throw
|
||||
long profilingStartTime = session.getProfilingStartTimeInNs();
|
||||
|
||||
// Check the profiling output doesn't throw
|
||||
String profilingOutput = session.endProfiling();
|
||||
File profilingOutputFile = new File(profilingOutput);
|
||||
profilingOutputFile.deleteOnExit();
|
||||
|
|
@ -1208,6 +1256,7 @@ public class InferenceTest {
|
|||
assertEquals(OnnxJavaType.FLOAT, sequenceInfo.mapInfo.valueType);
|
||||
|
||||
// try-cast first element in sequence to map/dictionary type
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<Long, Float> map = (Map<Long, Float>) ((List<Object>) secondOutput.getValue()).get(0);
|
||||
assertEquals(0.25938290, map.get(0L), 1e-6);
|
||||
assertEquals(0.40904793, map.get(1L), 1e-6);
|
||||
|
|
@ -1274,6 +1323,7 @@ public class InferenceTest {
|
|||
assertEquals(OnnxJavaType.FLOAT, sequenceInfo.mapInfo.valueType);
|
||||
|
||||
// try-cast first element in sequence to map/dictionary type
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Float> map =
|
||||
(Map<String, Float>) ((List<Object>) secondOutput.getValue()).get(0);
|
||||
assertEquals(0.25938290, map.get("0"), 1e-6);
|
||||
|
|
|
|||
|
|
@ -166,6 +166,7 @@ class TestHelpers {
|
|||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
static void flattenBase(Object input, List output, Class<?> primitiveClass) {
|
||||
if (primitiveClass.equals(boolean.class)) {
|
||||
flattenBooleanBase((boolean[]) input, output);
|
||||
|
|
|
|||
Loading…
Reference in a new issue