diff --git a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java index 1f972fadc0..52e992dac5 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java +++ b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -10,6 +10,7 @@ import java.io.IOException; import java.io.InputStream; import java.nio.file.Files; import java.nio.file.Path; +import java.util.EnumSet; import java.util.Locale; import java.util.logging.Level; import java.util.logging.Logger; @@ -40,6 +41,9 @@ final class OnnxRuntime { /** The API handle. */ static long ortApiHandle; + /** The available runtime providers */ + static EnumSet providers; + private OnnxRuntime() {} /* Computes and initializes OS_ARCH_STR (such as linux-x64) */ @@ -90,6 +94,7 @@ final class OnnxRuntime { load(tempDirectory, ONNXRUNTIME_LIBRARY_NAME); load(tempDirectory, ONNXRUNTIME_JNI_LIBRARY_NAME); ortApiHandle = initialiseAPIBase(ORT_API_VERSION_3); + providers = initialiseProviders(ortApiHandle); loaded = true; } finally { if (!isAndroid()) { @@ -202,6 +207,28 @@ final class OnnxRuntime { } } + /** + * Extracts the providers array from the C API, converts it into an EnumSet. + * + *

Throws IllegalArgumentException if a provider isn't recognised (note this exception should + * only happen during development of ONNX Runtime, if it happens at any other point, file an issue + * on Github). + * + * @param ortApiHandle The API Handle. + * @return The enum set. + */ + private static EnumSet initialiseProviders(long ortApiHandle) { + String[] providersArray = getAvailableProviders(ortApiHandle); + + EnumSet providers = EnumSet.noneOf(OrtProvider.class); + + for (String provider : providersArray) { + providers.add(OrtProvider.mapFromName(provider)); + } + + return providers; + } + /** * Get a reference to the API struct. * @@ -209,4 +236,12 @@ final class OnnxRuntime { * @return A pointer to the API struct. */ private static native long initialiseAPIBase(int apiVersionNumber); + + /** + * Gets the array of available providers. + * + * @param ortApiHandle The API handle + * @return The array of providers + */ + private static native String[] getAvailableProviders(long ortApiHandle); } diff --git a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java index adbe7d0b4b..9e40d039f4 100644 --- a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java +++ b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java @@ -6,6 +6,7 @@ package ai.onnxruntime; import ai.onnxruntime.OrtSession.SessionOptions; import java.io.IOException; +import java.util.EnumSet; import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Logger; @@ -254,6 +255,15 @@ public class OrtEnvironment implements AutoCloseable { } } + /** + * Gets the providers available in this environment. + * + * @return An enum set of the available execution providers. + */ + public static EnumSet getAvailableProviders() { + return OnnxRuntime.providers.clone(); + } + /** * Creates the native object. * diff --git a/java/src/main/java/ai/onnxruntime/OrtProvider.java b/java/src/main/java/ai/onnxruntime/OrtProvider.java new file mode 100644 index 0000000000..abcb9594ee --- /dev/null +++ b/java/src/main/java/ai/onnxruntime/OrtProvider.java @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import java.util.HashMap; +import java.util.Map; + +/** The execution providers available through the Java API. */ +public enum OrtProvider { + CPU("CPUExecutionProvider"), + CUDA("CUDAExecutionProvider"), + DNNL("DnnlExecutionProvider"), + OPEN_VINO("OpenVINOExecutionProvider"), + NUPHAR("NupharExecutionProvider"), + VITIS_AI("VitisAIExecutionProvider"), + TENSOR_RT("TensorrtExecutionProvider"), + NNAPI("NnapiExecutionProvider"), + RK_NPU("RknpuExecutionProvider"), + DIRECT_ML("DmlExecutionProvider"), + MI_GRAPH_X("MIGraphXExecutionProvider"), + ACL("ACLExecutionProvider"), + ARM_NN("ArmNNExecutionProvider"); + + private static final Map valueMap = new HashMap<>(values().length); + + static { + for (OrtProvider p : OrtProvider.values()) { + valueMap.put(p.name, p); + } + } + + private final String name; + + OrtProvider(String name) { + this.name = name; + } + + /** + * Accessor for the internal name of this provider. + * + * @return The internal provider name. + */ + public String getName() { + return name; + } + + /** + * Maps from the name string used by ONNX Runtime into the enum. + * + * @param name The provider name string. + * @return The enum constant. + */ + public static OrtProvider mapFromName(String name) { + OrtProvider provider = valueMap.get(name); + if (provider == null) { + throw new IllegalArgumentException("Unknown execution provider - " + name); + } else { + return provider; + } + } +} diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 5911303066..dd72dd5dea 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -315,6 +315,16 @@ public class OrtSession implements AutoCloseable { return metadata; } + /** + * Returns the timestamp that profiling started in nanoseconds. + * + * @return the profiling start time in ns. + * @throws OrtException If the native call failed. + */ + public long getProfilingStartTimeInNs() throws OrtException { + return getProfilingStartTimeInNs(OnnxRuntime.ortApiHandle, nativeHandle); + } + /** * Ends the profiling session and returns the output of the profiler. * @@ -414,6 +424,9 @@ public class OrtSession implements AutoCloseable { long runOptionsHandle) throws OrtException; + private native long getProfilingStartTimeInNs(long apiHandle, long nativeHandle) + throws OrtException; + private native String endProfiling(long apiHandle, long nativeHandle, long allocatorHandle) throws OrtException; @@ -679,6 +692,32 @@ public class OrtSession implements AutoCloseable { customLibraryHandles.add(customHandle); } + /** + * Sets the value of a symbolic dimension. Fixed dimension computations may have more + * optimizations applied to them. + * + * @param dimensionName The name of the symbolic dimension. + * @param dimensionValue The value to set that dimension to. + * @throws OrtException If there was an error in native code. + */ + public void setSymbolicDimensionValue(String dimensionName, long dimensionValue) + throws OrtException { + checkClosed(); + addFreeDimensionOverrideByName( + OnnxRuntime.ortApiHandle, nativeHandle, dimensionName, dimensionValue); + } + + /** + * Disables the per session thread pools. Must be used in conjunction with an environment + * containing global thread pools. + * + * @throws OrtException If there was an error in native code. + */ + public void disablePerSessionThreads() throws OrtException { + checkClosed(); + disablePerSessionThreads(OnnxRuntime.ortApiHandle, nativeHandle); + } + /** * Adds a single session configuration entry as a pair of strings. * @@ -708,7 +747,6 @@ public class OrtSession implements AutoCloseable { * @throws OrtException If there was an error in native code. */ public void addCUDA() throws OrtException { - checkClosed(); addCUDA(0); } @@ -858,6 +896,13 @@ public class OrtSession implements AutoCloseable { private native void closeOptions(long apiHandle, long nativeHandle); + private native void addFreeDimensionOverrideByName( + long apiHandle, long nativeHandle, String dimensionName, long dimensionValue) + throws OrtException; + + private native void disablePerSessionThreads(long apiHandle, long nativeHandle) + throws OrtException; + private native void addConfigEntry( long apiHandle, long nativeHandle, String configKey, String configValue) throws OrtException; diff --git a/java/src/main/native/ai_onnxruntime_OnnxRuntime.c b/java/src/main/native/ai_onnxruntime_OnnxRuntime.c index 5e47f37271..7842300075 100644 --- a/java/src/main/native/ai_onnxruntime_OnnxRuntime.c +++ b/java/src/main/native/ai_onnxruntime_OnnxRuntime.c @@ -1,10 +1,11 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2020 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ #include #include "onnxruntime/core/session/onnxruntime_c_api.h" #include "ai_onnxruntime_OnnxRuntime.h" +#include "OrtJniUtil.h" /* * Class: ai_onnxruntime_OnnxRuntime @@ -18,3 +19,36 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxRuntime_initialiseAPIBase return (jlong) ortPtr; } +/* + * Class: ai_onnxruntime_OnnxRuntime + * Method: getAvailableProviders + * Signature: (J)[Ljava/lang/String; + */ +JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxRuntime_getAvailableProviders + (JNIEnv * jniEnv, jclass clazz, jlong apiHandle) { + (void) jniEnv; (void) clazz; // required JNI parameters not needed by functions which don't call back into Java. + const OrtApi* api = (const OrtApi*) apiHandle; + + char** providers = NULL; + int numProviders = 0; + + // Extract the provider array + checkOrtStatus(jniEnv,api,api->GetAvailableProviders(&providers,&numProviders)); + + // Convert to Java String Array + char *stringClassName = "java/lang/String"; + jclass stringClazz = (*jniEnv)->FindClass(jniEnv, stringClassName); + jobjectArray providerArray = (*jniEnv)->NewObjectArray(jniEnv,numProviders,stringClazz,NULL); + + for (int i = 0; i < numProviders; i++) { + // Read out the provider name and convert it to a java.lang.String + jstring provider = (*jniEnv)->NewStringUTF(jniEnv,providers[i]); + (*jniEnv)->SetObjectArrayElement(jniEnv, providerArray, i, provider); + } + + // Release providers + checkOrtStatus(jniEnv,api,api->ReleaseAvailableProviders(providers,numProviders)); + providers = NULL; + + return providerArray; +} diff --git a/java/src/main/native/ai_onnxruntime_OrtSession.c b/java/src/main/native/ai_onnxruntime_OrtSession.c index 79a2abc2e1..6ac19c9e07 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession.c @@ -310,6 +310,24 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run return outputArray; } + +/* + * Class: ai_onnxruntime_OrtSession + * Method: getProfilingStartTimeInNs + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_getProfilingStartTimeInNs + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtSession* session = (OrtSession*) sessionHandle; + + uint64_t timestamp = 0; + + checkOrtStatus(jniEnv,api,api->SessionGetProfilingStartTimeNs(session,×tamp)); + return (jlong) timestamp; +} + /* * Class: ai_onnxruntime_OrtSession * Method: endProfiling diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c index 522d3233cb..21f8af6930 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c @@ -294,6 +294,39 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_closeC (*jniEnv)->ReleaseLongArrayElements(jniEnv,libraryHandles,handles,JNI_ABORT); } +/* + * Class: ai_onnxruntime_OrtSession_SessionOptions + * Method: addFreeDimensionOverrideByName + * Signature: (JJLjava/lang/String;J)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addFreeDimensionOverrideByName + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring dimensionName, jlong dimensionValue) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle; + + // Extract the string chars + const char* cName = (*jniEnv)->GetStringUTFChars(jniEnv, dimensionName, NULL); + + checkOrtStatus(jniEnv,api,api->AddFreeDimensionOverrideByName(options,cName,dimensionValue)); + + // Release the string chars + (*jniEnv)->ReleaseStringUTFChars(jniEnv,dimensionName,cName); +} + +/* + * Class: ai_onnxruntime_OrtSession_SessionOptions + * Method: disablePerSessionThreads + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_disablePerSessionThreads + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle; + checkOrtStatus(jniEnv,api,api->DisablePerSessionThreads(options)); +} + /* * Class: ai_onnxruntime_OrtSession_SessionOptions * Method: addConfigEntry diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index a2315bf7de..c0313dc0d2 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -33,6 +33,7 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; import java.util.Arrays; +import java.util.EnumSet; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -575,9 +576,52 @@ public class InferenceTest { } } + @Test + public void testProviders() { + EnumSet providers = OrtEnvironment.getAvailableProviders(); + int providersSize = providers.size(); + assertTrue(providersSize > 0); + assertTrue(providers.contains(OrtProvider.CPU)); + + // Check that the providers are a copy of the original, note this does not enable the DNNL + // provider + providers.add(OrtProvider.DNNL); + assertEquals(providersSize, OrtEnvironment.getAvailableProviders().size()); + } + + @Test + public void testSymbolicDimensionAssignment() throws OrtException { + // model takes 1x5 input of fixed type, echoes back + String modelPath = getResourcePath("/capi_symbolic_dims.onnx").toString(); + + try (OrtEnvironment env = OrtEnvironment.getEnvironment("testSymbolicDimensionAssignment")) { + // Check the dimension is symbolic + try (SessionOptions options = new SessionOptions()) { + try (OrtSession session = env.createSession(modelPath, options)) { + Map infoMap = session.getInputInfo(); + TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo(); + assertArrayEquals(new long[] {-1, 2}, aInfo.shape); + } + } + // Check that when the options are assigned it overrides the symbolic dimension + try (SessionOptions options = new SessionOptions()) { + options.setSymbolicDimensionValue("n", 5); + try (OrtSession session = env.createSession(modelPath, options)) { + Map infoMap = session.getInputInfo(); + TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo(); + assertArrayEquals(new long[] {5, 2}, aInfo.shape); + } + } + } + } + @Test public void testCUDA() throws OrtException { if (System.getProperty("USE_CUDA") != null) { + EnumSet providers = OrtEnvironment.getAvailableProviders(); + assertTrue(providers.size() > 1); + assertTrue(providers.contains(OrtProvider.CPU)); + assertTrue(providers.contains(OrtProvider.CUDA)); SqueezeNetTuple tuple = openSessionSqueezeNet(0); try (OrtEnvironment env = tuple.env; OrtSession session = tuple.session) { @@ -926,6 +970,10 @@ public class InferenceTest { boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue()); assertArrayEquals(flatInput, resultArray); } + // Check that the profiling start time doesn't throw + long profilingStartTime = session.getProfilingStartTimeInNs(); + + // Check the profiling output doesn't throw String profilingOutput = session.endProfiling(); File profilingOutputFile = new File(profilingOutput); profilingOutputFile.deleteOnExit(); @@ -1208,6 +1256,7 @@ public class InferenceTest { assertEquals(OnnxJavaType.FLOAT, sequenceInfo.mapInfo.valueType); // try-cast first element in sequence to map/dictionary type + @SuppressWarnings("unchecked") Map map = (Map) ((List) secondOutput.getValue()).get(0); assertEquals(0.25938290, map.get(0L), 1e-6); assertEquals(0.40904793, map.get(1L), 1e-6); @@ -1274,6 +1323,7 @@ public class InferenceTest { assertEquals(OnnxJavaType.FLOAT, sequenceInfo.mapInfo.valueType); // try-cast first element in sequence to map/dictionary type + @SuppressWarnings("unchecked") Map map = (Map) ((List) secondOutput.getValue()).get(0); assertEquals(0.25938290, map.get("0"), 1e-6); diff --git a/java/src/test/java/ai/onnxruntime/TestHelpers.java b/java/src/test/java/ai/onnxruntime/TestHelpers.java index 237a549e97..d255f04ecb 100644 --- a/java/src/test/java/ai/onnxruntime/TestHelpers.java +++ b/java/src/test/java/ai/onnxruntime/TestHelpers.java @@ -166,6 +166,7 @@ class TestHelpers { } } + @SuppressWarnings("unchecked") static void flattenBase(Object input, List output, Class primitiveClass) { if (primitiveClass.equals(boolean.class)) { flattenBooleanBase((boolean[]) input, output);