diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index cf3d9e3e70..d2b0ebd1f3 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -815,6 +815,10 @@ if (onnxruntime_BUILD_JAVA) message(STATUS "Running Java tests") # delegate to gradle's test runner if(WIN32) + # If we're on windows, symlink the custom op test library somewhere we can see it + set(JAVA_NATIVE_TEST_DIR ${JAVA_OUTPUT_DIR}/native-test) + file(MAKE_DIRECTORY ${JAVA_NATIVE_TEST_DIR}) + add_custom_command(TARGET custom_op_library POST_BUILD COMMAND ${CMAKE_COMMAND} -E create_symlink $ ${JAVA_NATIVE_TEST_DIR}/$) # On windows ctest requires a test to be an .exe(.com) file # So there are two options 1) Install Chocolatey and its gradle package # That package would install gradle.exe shim to its bin so ctest could run gradle.exe @@ -826,8 +830,8 @@ if (onnxruntime_BUILD_JAVA) -DREPO_ROOT=${REPO_ROOT} -P ${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime_java_unittests.cmake) else() - add_test(NAME onnxruntime4j_test COMMAND ${GRADLE_EXECUTABLE} cmakeCheck -DcmakeBuildDir=${CMAKE_CURRENT_BINARY_DIR} - WORKING_DIRECTORY ${REPO_ROOT}/java) + add_test(NAME onnxruntime4j_test COMMAND ${GRADLE_EXECUTABLE} cmakeCheck -DcmakeBuildDir=${CMAKE_CURRENT_BINARY_DIR} + WORKING_DIRECTORY ${REPO_ROOT}/java) endif() set_property(TEST onnxruntime4j_test APPEND PROPERTY DEPENDS onnxruntime4j_jni) endif() diff --git a/java/build.gradle b/java/build.gradle index 7daee81bc1..1119a2c7d7 100644 --- a/java/build.gradle +++ b/java/build.gradle @@ -67,6 +67,7 @@ def cmakeBuildDir = System.properties['cmakeBuildDir'] def cmakeJavaDir = "${cmakeBuildDir}/java" def cmakeNativeLibDir = "${cmakeJavaDir}/native-lib" def cmakeNativeJniDir = "${cmakeJavaDir}/native-jni" +def cmakeNativeTestDir = "${cmakeJavaDir}/native-test" def cmakeBuildOutputDir = "${cmakeJavaDir}/build" compileJava { @@ -84,7 +85,8 @@ sourceSets.test { // add compiled native libs resources.srcDirs += [ cmakeNativeLibDir, - cmakeNativeJniDir + cmakeNativeJniDir, + cmakeNativeTestDir ] } } @@ -144,6 +146,9 @@ dependencies { test { useJUnitPlatform() + if (cmakeBuildDir != null) { + workingDir cmakeBuildDir + } testLogging { events "passed", "skipped", "failed" showStandardStreams = true diff --git a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java index 83ea79dd4e..12a8c42881 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java +++ b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java @@ -78,7 +78,12 @@ final class OnnxRuntime { } } - private static boolean isAndroid() { + /** + * Check if we're running on Android. + * + * @return True if the {@code android.app.Activity} class can be loaded, false otherwise. + */ + static boolean isAndroid() { try { Class.forName("android.app.Activity"); return true; diff --git a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java index 9d6710c4f2..adbe7d0b4b 100644 --- a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java +++ b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2020 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -15,24 +15,6 @@ import java.util.logging.Logger; */ public class OrtEnvironment implements AutoCloseable { - /** The logging level for messages from the environment and session. */ - public enum LoggingLevel { - ORT_LOGGING_LEVEL_VERBOSE(0), - ORT_LOGGING_LEVEL_INFO(1), - ORT_LOGGING_LEVEL_WARNING(2), - ORT_LOGGING_LEVEL_ERROR(3), - ORT_LOGGING_LEVEL_FATAL(4); - private final int value; - - LoggingLevel(int value) { - this.value = value; - } - - public int getValue() { - return value; - } - } - private static final Logger logger = Logger.getLogger(OrtEnvironment.class.getName()); public static final String DEFAULT_NAME = "ort-java"; @@ -49,29 +31,29 @@ public class OrtEnvironment implements AutoCloseable { private static final AtomicInteger refCount = new AtomicInteger(); - private static volatile LoggingLevel curLogLevel; + private static volatile OrtLoggingLevel curLogLevel; private static volatile String curLoggingName; /** * Gets the OrtEnvironment. If there is not an environment currently created, it creates one using - * {@link OrtEnvironment#DEFAULT_NAME} and {@link LoggingLevel#ORT_LOGGING_LEVEL_WARNING}. + * {@link OrtEnvironment#DEFAULT_NAME} and {@link OrtLoggingLevel#ORT_LOGGING_LEVEL_WARNING}. * * @return An onnxruntime environment. */ public static OrtEnvironment getEnvironment() { - return getEnvironment(LoggingLevel.ORT_LOGGING_LEVEL_WARNING, DEFAULT_NAME); + return getEnvironment(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING, DEFAULT_NAME); } /** * Gets the OrtEnvironment. If there is not an environment currently created, it creates one using - * the supplied name and {@link LoggingLevel#ORT_LOGGING_LEVEL_WARNING}. + * the supplied name and {@link OrtLoggingLevel#ORT_LOGGING_LEVEL_WARNING}. * * @param name The logging id of the environment. * @return An onnxruntime environment. */ public static OrtEnvironment getEnvironment(String name) { - return getEnvironment(LoggingLevel.ORT_LOGGING_LEVEL_WARNING, name); + return getEnvironment(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING, name); } /** @@ -81,7 +63,7 @@ public class OrtEnvironment implements AutoCloseable { * @param logLevel The logging level to use. * @return An onnxruntime environment. */ - public static OrtEnvironment getEnvironment(LoggingLevel logLevel) { + public static OrtEnvironment getEnvironment(OrtLoggingLevel logLevel) { return getEnvironment(logLevel, DEFAULT_NAME); } @@ -94,7 +76,8 @@ public class OrtEnvironment implements AutoCloseable { * @param name The log id. * @return The OrtEnvironment singleton. */ - public static synchronized OrtEnvironment getEnvironment(LoggingLevel loggingLevel, String name) { + public static synchronized OrtEnvironment getEnvironment( + OrtLoggingLevel loggingLevel, String name) { if (INSTANCE == null) { try { INSTANCE = new OrtEnvironment(loggingLevel, name); @@ -104,7 +87,7 @@ public class OrtEnvironment implements AutoCloseable { throw new IllegalStateException("Failed to create OrtEnvironment", e); } } else { - if ((loggingLevel.value != curLogLevel.value) || (!name.equals(curLoggingName))) { + if ((loggingLevel.getValue() != curLogLevel.getValue()) || (!name.equals(curLoggingName))) { logger.warning( "Tried to change OrtEnvironment's logging level or name while a reference exists."); } @@ -125,7 +108,7 @@ public class OrtEnvironment implements AutoCloseable { * @throws OrtException If the environment couldn't be created. */ private OrtEnvironment() throws OrtException { - this(LoggingLevel.ORT_LOGGING_LEVEL_WARNING, "java-default"); + this(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING, "java-default"); } /** @@ -135,7 +118,7 @@ public class OrtEnvironment implements AutoCloseable { * @param name The logging id of the environment. * @throws OrtException If the environment couldn't be created. */ - private OrtEnvironment(LoggingLevel loggingLevel, String name) throws OrtException { + private OrtEnvironment(OrtLoggingLevel loggingLevel, String name) throws OrtException { nativeHandle = createHandle(OnnxRuntime.ortApiHandle, loggingLevel.getValue(), name); defaultAllocator = new OrtAllocator(getDefaultAllocator(OnnxRuntime.ortApiHandle), true); } diff --git a/java/src/main/java/ai/onnxruntime/OrtLoggingLevel.java b/java/src/main/java/ai/onnxruntime/OrtLoggingLevel.java new file mode 100644 index 0000000000..1e0fc60cbe --- /dev/null +++ b/java/src/main/java/ai/onnxruntime/OrtLoggingLevel.java @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import java.util.logging.Logger; + +/** The logging level for messages from the environment and session. */ +public enum OrtLoggingLevel { + ORT_LOGGING_LEVEL_VERBOSE(0), + ORT_LOGGING_LEVEL_INFO(1), + ORT_LOGGING_LEVEL_WARNING(2), + ORT_LOGGING_LEVEL_ERROR(3), + ORT_LOGGING_LEVEL_FATAL(4); + private final int value; + + private static final Logger logger = Logger.getLogger(OrtLoggingLevel.class.getName()); + private static final OrtLoggingLevel[] values = new OrtLoggingLevel[5]; + + static { + for (OrtLoggingLevel ot : OrtLoggingLevel.values()) { + values[ot.value] = ot; + } + } + + OrtLoggingLevel(int value) { + this.value = value; + } + + /** + * Gets the native value associated with this logging level. + * + * @return The native value. + */ + public int getValue() { + return value; + } + + /** + * Maps from the C API's int enum to the Java enum. + * + * @param logLevel The index of the Java enum. + * @return The Java enum. + */ + public static OrtLoggingLevel mapFromInt(int logLevel) { + if ((logLevel > 0) && (logLevel < values.length)) { + return values[logLevel]; + } else { + logger.warning("Unknown logging level " + logLevel + " setting to ORT_LOGGING_LEVEL_VERBOSE"); + return ORT_LOGGING_LEVEL_VERBOSE; + } + } +} diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 5393fae7da..8a30124187 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -200,6 +200,21 @@ public class OrtSession implements AutoCloseable { return run(inputs, outputNames); } + /** + * Scores an input feed dict, returning the map of all inferred outputs. + * + *

The outputs are sorted based on their id number. + * + * @param inputs The inputs to score. + * @param runOptions The RunOptions to control this run. + * @return The inferred outputs. + * @throws OrtException If there was an error in native code, the input names are invalid, or if + * there are zero or too many inputs. + */ + public Result run(Map inputs, RunOptions runOptions) throws OrtException { + return run(inputs, outputNames, runOptions); + } + /** * Scores an input feed dict, returning the map of requested inferred outputs. * @@ -213,6 +228,24 @@ public class OrtSession implements AutoCloseable { */ public Result run(Map inputs, Set requestedOutputs) throws OrtException { + return run(inputs, requestedOutputs, null); + } + + /** + * Scores an input feed dict, returning the map of requested inferred outputs. + * + *

The outputs are sorted based on the supplied set traveral order. + * + * @param inputs The inputs to score. + * @param requestedOutputs The requested outputs. + * @param runOptions The RunOptions to control this run. + * @return The inferred outputs. + * @throws OrtException If there was an error in native code, the input or output names are + * invalid, or if there are zero or too many inputs or outputs. + */ + public Result run( + Map inputs, Set requestedOutputs, RunOptions runOptions) + throws OrtException { if (!closed) { if (inputs.isEmpty() || (inputs.size() > numInputs)) { throw new OrtException( @@ -249,6 +282,8 @@ public class OrtSession implements AutoCloseable { "Unknown output name " + s + ", expected one of " + outputNames.toString()); } } + long runOptionsHandle = runOptions == null ? 0 : runOptions.nativeHandle; + OnnxValue[] outputValues = run( OnnxRuntime.ortApiHandle, @@ -258,7 +293,8 @@ public class OrtSession implements AutoCloseable { inputHandles, inputNamesArray.length, outputNamesArray, - outputNamesArray.length); + outputNamesArray.length, + runOptionsHandle); return new Result(outputNamesArray, outputValues); } else { throw new IllegalStateException("Trying to score a closed OrtSession."); @@ -268,8 +304,8 @@ public class OrtSession implements AutoCloseable { /** * Gets the metadata for the currently loaded model. * - * @throws OrtException on failure * @return The metadata. + * @throws OrtException If the native call failed. */ public OnnxModelMetadata getMetadata() throws OrtException { if (metadata == null) { @@ -278,6 +314,19 @@ public class OrtSession implements AutoCloseable { return metadata; } + /** + * Ends the profiling session and returns the output of the profiler. + * + *

Profiling should be enabled in the {@link SessionOptions} used to construct this {@code + * Session}. + * + * @return The profiling output. + * @throws OrtException If the native call failed. + */ + public String endProfiling() throws OrtException { + return endProfiling(OnnxRuntime.ortApiHandle, nativeHandle, allocator.handle); + } + @Override public String toString() { return "OrtSession(numInputs=" + numInputs + ",numOutputs=" + numOutputs + ")"; @@ -336,6 +385,22 @@ public class OrtSession implements AutoCloseable { private native NodeInfo[] getOutputInfo(long apiHandle, long nativeHandle, long allocatorHandle) throws OrtException; + /** + * The native run call. runOptionsHandle can be zero (i.e. the null pointer), but all other + * handles must be valid pointers. + * + * @param apiHandle The pointer to the api. + * @param nativeHandle The pointer to the session. + * @param allocatorHandle The pointer to the allocator. + * @param inputNamesArray The input names. + * @param inputs The input tensors. + * @param numInputs The number of inputs. + * @param outputNamesArray The requested output names. + * @param numOutputs The number of requested outputs. + * @param runOptionsHandle The (possibly null) pointer to the run options. + * @return The OnnxValues produced by this run. + * @throws OrtException If the native call failed in some way. + */ private native OnnxValue[] run( long apiHandle, long nativeHandle, @@ -344,7 +409,11 @@ public class OrtSession implements AutoCloseable { long[] inputs, long numInputs, String[] outputNamesArray, - long numOutputs) + long numOutputs, + long runOptionsHandle) + throws OrtException; + + private native String endProfiling(long apiHandle, long nativeHandle, long allocatorHandle) throws OrtException; private native void closeSession(long apiHandle, long nativeHandle) throws OrtException; @@ -368,6 +437,9 @@ public class OrtSession implements AutoCloseable { * options. * *

Modifying this after the session has been constructed will have no effect. + * + *

The SessionOptions object must not be closed until all sessions which use it are closed, as + * otherwise it could release resources that are in use. */ public static class SessionOptions implements AutoCloseable { @@ -421,15 +493,39 @@ public class OrtSession implements AutoCloseable { private final long nativeHandle; + private final List customLibraryHandles; + + private boolean closed = false; + /** Create an empty session options. */ public SessionOptions() { nativeHandle = createOptions(OnnxRuntime.ortApiHandle); + customLibraryHandles = new ArrayList<>(); } /** Closes the session options, releasing any memory acquired. */ @Override public void close() { - closeOptions(OnnxRuntime.ortApiHandle, nativeHandle); + if (!closed) { + if (customLibraryHandles.size() > 0) { + long[] longArray = new long[customLibraryHandles.size()]; + for (int i = 0; i < customLibraryHandles.size(); i++) { + longArray[i] = customLibraryHandles.get(i); + } + closeCustomLibraries(longArray); + } + closeOptions(OnnxRuntime.ortApiHandle, nativeHandle); + closed = true; + } else { + throw new IllegalStateException("Trying to close a closed SessionOptions."); + } + } + + /** Checks if the SessionOptions is closed, if so throws {@link IllegalStateException}. */ + private void checkClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed SessionOptions"); + } } /** @@ -439,6 +535,7 @@ public class OrtSession implements AutoCloseable { * @throws OrtException If there was an error in native code. */ public void setExecutionMode(ExecutionMode mode) throws OrtException { + checkClosed(); setExecutionMode(OnnxRuntime.ortApiHandle, nativeHandle, mode.getID()); } @@ -449,6 +546,7 @@ public class OrtSession implements AutoCloseable { * @throws OrtException If there was an error in native code. */ public void setOptimizationLevel(OptLevel level) throws OrtException { + checkClosed(); setOptimizationLevel(OnnxRuntime.ortApiHandle, nativeHandle, level.getID()); } @@ -460,6 +558,7 @@ public class OrtSession implements AutoCloseable { * @throws OrtException If there was an error in native code. */ public void setInterOpNumThreads(int numThreads) throws OrtException { + checkClosed(); setInterOpNumThreads(OnnxRuntime.ortApiHandle, nativeHandle, numThreads); } @@ -471,6 +570,7 @@ public class OrtSession implements AutoCloseable { * @throws OrtException If there was an error in native code. */ public void setIntraOpNumThreads(int numThreads) throws OrtException { + checkClosed(); setIntraOpNumThreads(OnnxRuntime.ortApiHandle, nativeHandle, numThreads); } @@ -481,15 +581,107 @@ public class OrtSession implements AutoCloseable { * @throws OrtException If there was an error in native code. */ public void setOptimizedModelFilePath(String outputPath) throws OrtException { + checkClosed(); setOptimizationModelFilePath(OnnxRuntime.ortApiHandle, nativeHandle, outputPath); } + /** + * Sets the logger id to use. + * + * @param loggerId The logger id string. + * @throws OrtException If there was an error in native code. + */ + public void setLoggerId(String loggerId) throws OrtException { + checkClosed(); + setLoggerId(OnnxRuntime.ortApiHandle, nativeHandle, loggerId); + } + + /** + * Enables profiling in sessions using this SessionOptions. + * + * @param filePath The file to write profile information to. + * @throws OrtException If there was an error in native code. + */ + public void enableProfiling(String filePath) throws OrtException { + checkClosed(); + enableProfiling(OnnxRuntime.ortApiHandle, nativeHandle, filePath); + } + + /** + * Disables profiling in sessions using this SessionOptions. + * + * @throws OrtException If there was an error in native code. + */ + public void disableProfiling() throws OrtException { + checkClosed(); + disableProfiling(OnnxRuntime.ortApiHandle, nativeHandle); + } + + /** + * Turns on memory pattern optimizations, where memory is preallocated if all shapes are known. + * + * @param memoryPatternOptimization If true enable memory pattern optimizations. + * @throws OrtException If there was an error in native code. + */ + public void setMemoryPatternOptimization(boolean memoryPatternOptimization) + throws OrtException { + checkClosed(); + setMemoryPatternOptimization( + OnnxRuntime.ortApiHandle, nativeHandle, memoryPatternOptimization); + } + + /** + * Sets the CPU to use an arena memory allocator. + * + * @param useArena If true use an arena memory allocator for the CPU execution provider. + * @throws OrtException If there was an error in native code. + */ + public void setCPUArenaAllocator(boolean useArena) throws OrtException { + checkClosed(); + setCPUArenaAllocator(OnnxRuntime.ortApiHandle, nativeHandle, useArena); + } + + /** + * Sets the Session's logging level. + * + * @param logLevel The log level to use. + * @throws OrtException If there was an error in native code. + */ + public void setSessionLogLevel(OrtLoggingLevel logLevel) throws OrtException { + checkClosed(); + setSessionLogLevel(OnnxRuntime.ortApiHandle, nativeHandle, logLevel.getValue()); + } + + /** + * Sets the Session's logging verbosity level. + * + * @param logLevel The logging verbosity to use. + * @throws OrtException If there was an error in native code. + */ + public void setSessionLogVerbosityLevel(int logLevel) throws OrtException { + checkClosed(); + setSessionLogVerbosityLevel(OnnxRuntime.ortApiHandle, nativeHandle, logLevel); + } + + /** + * Registers a library of custom ops for use with {@link OrtSession}s using this SessionOptions. + * + * @param path The path to the library on disk. + * @throws OrtException If there was an error loading the library. + */ + public void registerCustomOpLibrary(String path) throws OrtException { + checkClosed(); + long customHandle = registerCustomOpLibrary(OnnxRuntime.ortApiHandle, nativeHandle, path); + customLibraryHandles.add(customHandle); + } + /** * Add CUDA as an execution backend, using device 0. * * @throws OrtException If there was an error in native code. */ public void addCUDA() throws OrtException { + checkClosed(); addCUDA(0); } @@ -500,6 +692,7 @@ public class OrtSession implements AutoCloseable { * @throws OrtException If there was an error in native code. */ public void addCUDA(int deviceNum) throws OrtException { + checkClosed(); addCUDA(OnnxRuntime.ortApiHandle, nativeHandle, deviceNum); } @@ -513,6 +706,7 @@ public class OrtSession implements AutoCloseable { * @throws OrtException If there was an error in native code. */ public void addCPU(boolean useArena) throws OrtException { + checkClosed(); addCPU(OnnxRuntime.ortApiHandle, nativeHandle, useArena ? 1 : 0); } @@ -523,6 +717,7 @@ public class OrtSession implements AutoCloseable { * @throws OrtException If there was an error in native code. */ public void addDnnl(boolean useArena) throws OrtException { + checkClosed(); addDnnl(OnnxRuntime.ortApiHandle, nativeHandle, useArena ? 1 : 0); } @@ -535,6 +730,7 @@ public class OrtSession implements AutoCloseable { * @throws OrtException If there was an error in native code. */ public void addNGraph(String ngBackendType) throws OrtException { + checkClosed(); addNGraph(OnnxRuntime.ortApiHandle, nativeHandle, ngBackendType); } @@ -545,6 +741,7 @@ public class OrtSession implements AutoCloseable { * @throws OrtException If there was an error in native code. */ public void addOpenVINO(String deviceId) throws OrtException { + checkClosed(); addOpenVINO(OnnxRuntime.ortApiHandle, nativeHandle, deviceId); } @@ -555,6 +752,7 @@ public class OrtSession implements AutoCloseable { * @throws OrtException If there was an error in native code. */ public void addTensorrt(int deviceNum) throws OrtException { + checkClosed(); addTensorrt(OnnxRuntime.ortApiHandle, nativeHandle, deviceNum); } @@ -564,6 +762,7 @@ public class OrtSession implements AutoCloseable { * @throws OrtException If there was an error in native code. */ public void addNnapi() throws OrtException { + checkClosed(); addNnapi(OnnxRuntime.ortApiHandle, nativeHandle); } @@ -575,6 +774,7 @@ public class OrtSession implements AutoCloseable { * @throws OrtException If there was an error in native code. */ public void addNuphar(boolean allowUnalignedBuffers, String settings) throws OrtException { + checkClosed(); addNuphar(OnnxRuntime.ortApiHandle, nativeHandle, allowUnalignedBuffers ? 1 : 0, settings); } @@ -585,6 +785,7 @@ public class OrtSession implements AutoCloseable { * @throws OrtException If there was an error in native code. */ public void addDirectML(int deviceId) throws OrtException { + checkClosed(); addDirectML(OnnxRuntime.ortApiHandle, nativeHandle, deviceId); } @@ -595,6 +796,7 @@ public class OrtSession implements AutoCloseable { * @throws OrtException If there was an error in native code. */ public void addACL(boolean useArena) throws OrtException { + checkClosed(); addACL(OnnxRuntime.ortApiHandle, nativeHandle, useArena ? 1 : 0); } @@ -615,6 +817,31 @@ public class OrtSession implements AutoCloseable { private native long createOptions(long apiHandle); + private native void setLoggerId(long apiHandle, long nativeHandle, String loggerId) + throws OrtException; + + private native void enableProfiling(long apiHandle, long nativeHandle, String filePrefix) + throws OrtException; + + private native void disableProfiling(long apiHandle, long nativeHandle) throws OrtException; + + private native void setMemoryPatternOptimization( + long apiHandle, long nativeHandle, boolean memoryPatternOptimization) throws OrtException; + + private native void setCPUArenaAllocator(long apiHandle, long nativeHandle, boolean useArena) + throws OrtException; + + private native void setSessionLogLevel(long apiHandle, long nativeHandle, int logLevel) + throws OrtException; + + private native void setSessionLogVerbosityLevel(long apiHandle, long nativeHandle, int logLevel) + throws OrtException; + + private native long registerCustomOpLibrary(long apiHandle, long nativeHandle, String path) + throws OrtException; + + private native void closeCustomLibraries(long[] nativeHandle); + private native void closeOptions(long apiHandle, long nativeHandle); /* @@ -658,6 +885,141 @@ public class OrtSession implements AutoCloseable { private native void addACL(long apiHandle, long nativeHandle, int useArena) throws OrtException; } + /** Used to control logging and termination of a call to {@link OrtSession#run}. */ + public static class RunOptions implements AutoCloseable { + + private final long nativeHandle; + + private boolean closed = false; + + /** + * Creates a RunOptions. + * + * @throws OrtException If the construction of the native RunOptions failed. + */ + public RunOptions() throws OrtException { + this.nativeHandle = createRunOptions(OnnxRuntime.ortApiHandle); + } + + /** + * Sets the current logging level on this RunOptions. + * + * @param level The new logging level. + * @throws OrtException If the native call failed. + */ + public void setLogLevel(OrtLoggingLevel level) throws OrtException { + checkClosed(); + setLogLevel(OnnxRuntime.ortApiHandle, nativeHandle, level.getValue()); + } + + /** + * Gets the current logging level set on this RunOptions. + * + * @return The logging level. + * @throws OrtException If the native call failed. + */ + public OrtLoggingLevel getLogLevel() throws OrtException { + checkClosed(); + return OrtLoggingLevel.mapFromInt(getLogLevel(OnnxRuntime.ortApiHandle, nativeHandle)); + } + + /** + * Sets the current logging verbosity level on this RunOptions. + * + * @param level The new logging verbosity level. + * @throws OrtException If the native call failed. + */ + public void setLogVerbosityLevel(int level) throws OrtException { + checkClosed(); + setLogVerbosityLevel(OnnxRuntime.ortApiHandle, nativeHandle, level); + } + + /** + * Gets the current logging verbosity level set on this RunOptions. + * + * @return The logging verbosity level. + * @throws OrtException If the native call failed. + */ + public int getLogVerbosityLevel() throws OrtException { + checkClosed(); + return getLogVerbosityLevel(OnnxRuntime.ortApiHandle, nativeHandle); + } + + /** + * Sets the run tag used in logging. + * + * @param runTag The run tag in logging output. + * @throws OrtException If the native library call failed. + */ + public void setRunTag(String runTag) throws OrtException { + checkClosed(); + setRunTag(OnnxRuntime.ortApiHandle, nativeHandle, runTag); + } + + /** + * Gets the String used to log information about this run. + * + * @return The run tag. + * @throws OrtException If the native library call failed. + */ + public String getRunTag() throws OrtException { + checkClosed(); + return getRunTag(OnnxRuntime.ortApiHandle, nativeHandle); + } + + /** + * Sets a flag so that all incomplete {@link OrtSession#run} calls using this instance of {@code + * RunOptions} will terminate as soon as possible. If the flag is false, it resets this {@code + * RunOptions} so it can be used with other calls to {@link OrtSession#run}. + * + * @param terminate If true terminate all runs associated with this RunOptions. + * @throws OrtException If the native library call failed. + */ + public void setTerminate(boolean terminate) throws OrtException { + checkClosed(); + setTerminate(OnnxRuntime.ortApiHandle, nativeHandle, terminate); + } + + /** Checks if the RunOptions is closed, if so throws {@link IllegalStateException}. */ + private void checkClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed RunOptions"); + } + } + + @Override + public void close() { + if (!closed) { + close(OnnxRuntime.ortApiHandle, nativeHandle); + closed = true; + } else { + throw new IllegalStateException("Trying to close an already closed RunOptions"); + } + } + + private static native long createRunOptions(long apiHandle) throws OrtException; + + private native void setLogLevel(long apiHandle, long nativeHandle, int logLevel) + throws OrtException; + + private native int getLogLevel(long apiHandle, long nativeHandle) throws OrtException; + + private native void setLogVerbosityLevel(long apiHandle, long nativeHandle, int logLevel) + throws OrtException; + + private native int getLogVerbosityLevel(long apiHandle, long nativeHandle) throws OrtException; + + private native void setRunTag(long apiHandle, long nativeHandle, String runTag) + throws OrtException; + + private native String getRunTag(long apiHandle, long nativeHandle) throws OrtException; + + private native void setTerminate(long apiHandle, long nativeHandle, boolean terminate) + throws OrtException; + + private static native void close(long apiHandle, long nativeHandle); + } + /** * An {@link AutoCloseable} wrapper around a {@link Map} containing {@link OnnxValue}s. * diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index e89e0a70c2..53ed61c671 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2020 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ #include @@ -16,7 +16,7 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { } /** - * Must be kept in sync with ORT_LOGGING_LEVEL and OrtEnvironment#LoggingLevel + * Must be kept in sync with ORT_LOGGING_LEVEL and the OrtLoggingLevel java enum */ OrtLoggingLevel convertLoggingLevel(jint level) { switch (level) { diff --git a/java/src/main/native/ai_onnxruntime_OrtSession.c b/java/src/main/native/ai_onnxruntime_OrtSession.c index 2111d16746..79a2abc2e1 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2020 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ #include @@ -19,9 +19,8 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_la const OrtApi* api = (const OrtApi*) apiHandle; OrtSession* session; - jboolean copy; #ifdef _WIN32 - const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, modelPath, ©); + const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, modelPath, NULL); size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, modelPath); wchar_t* newString = (wchar_t*)calloc(stringLength+1,sizeof(jchar)); wcsncpy_s(newString, stringLength+1, (const wchar_t*) cPath, stringLength); @@ -29,7 +28,7 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_la free(newString); (*jniEnv)->ReleaseStringChars(jniEnv,modelPath,cPath); #else - const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, modelPath, ©); + const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, modelPath, NULL); checkOrtStatus(jniEnv,api,api->CreateSession((OrtEnv*)envHandle, cPath, (OrtSessionOptions*)optsHandle, &session)); (*jniEnv)->ReleaseStringUTFChars(jniEnv,modelPath,cPath); #endif @@ -236,14 +235,16 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getOutputInfo /* * Class: ai_onnxruntime_OrtSession * Method: run - * Signature: (JJJ[Ljava/lang/String;[JJ[Ljava/lang/String;J)[Lai/onnxruntime/OnnxValue; + * Signature: (JJJ[Ljava/lang/String;[JJ[Ljava/lang/String;JJ)[Lai/onnxruntime/OnnxValue; * private native OnnxValue[] run(long apiHandle, long nativeHandle, long allocatorHandle, String[] inputNamesArray, long[] inputs, long numInputs, String[] outputNamesArray, long numOutputs) */ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle, jobjectArray inputNamesArr, jlongArray tensorArr, jlong numInputs, jobjectArray outputNamesArr, jlong numOutputs) { + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle, jobjectArray inputNamesArr, jlongArray tensorArr, jlong numInputs, jobjectArray outputNamesArr, jlong numOutputs, jlong runOptionsHandle) { (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; + OrtSession* session = (OrtSession*) sessionHandle; + OrtRunOptions* runOptions = (OrtRunOptions*) runOptionsHandle; // Create the buffers for the Java input and output strings const char** inputNames; @@ -276,7 +277,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run // Actually score the inputs. //printf("inputTensors = %p, first tensor = %p, numInputs = %ld, outputValues = %p, numOutputs = %ld\n",inputTensors,(OrtValue*)inputTensors[0],numInputs,outputValues,numOutputs); //ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess, _In_ OrtRunOptions* run_options, _In_ const char* const* input_names, _In_ const OrtValue* const* input, size_t input_len, _In_ const char* const* output_names, size_t output_names_len, _Out_ OrtValue** output); - checkOrtStatus(jniEnv,api,api->Run((OrtSession*)sessionHandle, NULL, (const char* const*) inputNames, (const OrtValue* const*) inputTensors, numInputs, (const char* const*) outputNames, numOutputs, outputValues)); + checkOrtStatus(jniEnv,api,api->Run(session, runOptions, (const char* const*) inputNames, (const OrtValue* const*) inputTensors, numInputs, (const char* const*) outputNames, numOutputs, outputValues)); // Release the C array of pointers to the tensors. (*jniEnv)->ReleaseLongArrayElements(jniEnv,tensorArr,inputTensors,JNI_ABORT); @@ -309,6 +310,24 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run return outputArray; } +/* + * Class: ai_onnxruntime_OrtSession + * Method: endProfiling + * Signature: (JJJ)Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_endProfiling + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle) { + (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; + + char* profileStr; + checkOrtStatus(jniEnv,api,api->SessionEndProfiling((OrtSession*)handle,allocator,&profileStr)); + jstring profileOutput = (*jniEnv)->NewStringUTF(jniEnv,profileStr); + checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,profileStr)); + return profileOutput; +} + /* * Class: ai_onnxruntime_OrtSession * Method: closeSession @@ -321,23 +340,6 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_closeSession api->ReleaseSession((OrtSession*)handle); } -/* - * Class: ai_onnxruntime_OrtSession - * Method: releaseNamesHandle - * Signature: (JJJ)V - */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_releaseNamesHandle - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong allocatorHandle, jlong namesHandle, jlong length) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; - char** names = (char**) namesHandle; - for (uint32_t i = 0; i < length; i++) { - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,names[i])); - } - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,names)); -} - /* * Class: ai_onnxruntime_OrtSession * Method: constructMetadata diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_RunOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_RunOptions.c new file mode 100644 index 0000000000..eda097a28e --- /dev/null +++ b/java/src/main/native/ai_onnxruntime_OrtSession_RunOptions.c @@ -0,0 +1,134 @@ +/* + * Copyright (c) 2020 Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +#include +#include +#include "onnxruntime/core/session/onnxruntime_c_api.h" +#include "OrtJniUtil.h" +#include "ai_onnxruntime_OrtSession_RunOptions.h" + +/* + * Class: ai_onnxruntime_OrtSession_RunOptions + * Method: createRunOptions + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_createRunOptions + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtRunOptions* opts; + checkOrtStatus(jniEnv,api,api->CreateRunOptions(&opts)); + return (jlong) opts; +} + +/* + * Class: ai_onnxruntime_OrtSession_RunOptions + * Method: setLogLevel + * Signature: (JJI)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_setLogLevel + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle, jint logLevel) { + (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + checkOrtStatus(jniEnv,api,api->RunOptionsSetRunLogSeverityLevel((OrtRunOptions*) nativeHandle,logLevel)); +} + +/* + * Class: ai_onnxruntime_OrtSession_RunOptions + * Method: getLogLevel + * Signature: (JJ)I + */ +JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_getLogLevel + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle) { + (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + jint logLevel; + checkOrtStatus(jniEnv,api,api->RunOptionsGetRunLogSeverityLevel((OrtRunOptions*) nativeHandle,&logLevel)); + return logLevel; +} + +/* + * Class: ai_onnxruntime_OrtSession_RunOptions + * Method: setLogVerbosityLevel + * Signature: (JJI)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_setLogVerbosityLevel + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle, jint logLevel) { + (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + checkOrtStatus(jniEnv,api,api->RunOptionsSetRunLogVerbosityLevel((OrtRunOptions*) nativeHandle,logLevel)); +} + +/* + * Class: ai_onnxruntime_OrtSession_RunOptions + * Method: getLogVerbosityLevel + * Signature: (JJ)I + */ +JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_getLogVerbosityLevel + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle) { + (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + jint logLevel; + checkOrtStatus(jniEnv,api,api->RunOptionsGetRunLogVerbosityLevel((OrtRunOptions*) nativeHandle,&logLevel)); + return logLevel; +} + +/* + * Class: ai_onnxruntime_OrtSession_RunOptions + * Method: setRunTag + * Signature: (JJLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_setRunTag + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle, jstring runTag) { + (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + const char* runTagStr = (*jniEnv)->GetStringUTFChars(jniEnv, runTag, NULL); + checkOrtStatus(jniEnv,api,api->RunOptionsSetRunTag((OrtRunOptions*) nativeHandle, runTagStr)); + (*jniEnv)->ReleaseStringUTFChars(jniEnv,runTag,runTagStr); +} + +/* + * Class: ai_onnxruntime_OrtSession_RunOptions + * Method: getRunTag + * Signature: (JJ)Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_getRunTag + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle) { + (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + const char* runTagStr; + // This is a reference to the C str, and should not be freed. + checkOrtStatus(jniEnv,api,api->RunOptionsGetRunTag((OrtRunOptions*)nativeHandle,&runTagStr)); + jstring runTag = (*jniEnv)->NewStringUTF(jniEnv,runTagStr); + return runTag; +} + +/* + * Class: ai_onnxruntime_OrtSession_RunOptions + * Method: setTerminate + * Signature: (JJZ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_setTerminate + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle, jboolean terminate) { + (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtRunOptions* runOptions = (OrtRunOptions*) nativeHandle; + if (terminate) { + checkOrtStatus(jniEnv,api,api->RunOptionsSetTerminate(runOptions)); + } else { + checkOrtStatus(jniEnv,api,api->RunOptionsUnsetTerminate(runOptions)); + } +} + +/* + * Class: ai_onnxruntime_OrtSession_RunOptions + * Method: close + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_close + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) { + (void) jniEnv; (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + api->ReleaseRunOptions((OrtRunOptions*) handle); +} diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c index 739ab68357..47db160487 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2020 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ #include @@ -7,6 +7,11 @@ #include "onnxruntime/core/session/onnxruntime_c_api.h" #include "OrtJniUtil.h" #include "ai_onnxruntime_OrtSession_SessionOptions.h" +#ifdef WIN32 +#include +#else +#include +#endif // Providers #include "onnxruntime/core/providers/cpu/cpu_provider_factory.h" @@ -107,7 +112,9 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_creat OrtSessionOptions* opts; checkOrtStatus(jniEnv,api,api->CreateSessionOptions(&opts)); checkOrtStatus(jniEnv,api,api->SetInterOpNumThreads(opts, 1)); - checkOrtStatus(jniEnv,api,api->SetIntraOpNumThreads(opts, 1)); + // Commented out due to constant OpenMP warning as this API is invalid when running with OpenMP. + // Not sure how to detect that from within the C API though. + //checkOrtStatus(jniEnv,api,api->SetIntraOpNumThreads(opts, 1)); return (jlong) opts; } @@ -117,10 +124,174 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_creat * Signature: (JJ)V */ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_closeOptions - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) { - (void) jniEnv; (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - api->ReleaseSessionOptions((OrtSessionOptions*) handle); + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) { + (void)jniEnv; (void)jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + api->ReleaseSessionOptions((OrtSessionOptions*)handle); +} + +/* + * Class: ai_onnxruntime_OrtSession_SessionOptions + * Method: setLoggerId + * Signature: (JJLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setLoggerId + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring loggerId) { + (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle; + const char* loggerIdStr = (*jniEnv)->GetStringUTFChars(jniEnv, loggerId, NULL); + checkOrtStatus(jniEnv,api,api->SetSessionLogId(options, loggerIdStr)); + (*jniEnv)->ReleaseStringUTFChars(jniEnv,loggerId,loggerIdStr); +} + +/* + * Class: ai_onnxruntime_OrtSession_SessionOptions + * Method: enableProfiling + * Signature: (JJLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_enableProfiling + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring pathString) { + (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle; +#ifdef _WIN32 + const jchar* path = (*jniEnv)->GetStringChars(jniEnv, pathString, NULL); + size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, pathString); + wchar_t* newString = (wchar_t*)calloc(stringLength+1,sizeof(jchar)); + wcsncpy_s(newString, stringLength+1, (const wchar_t*) path, stringLength); + checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,api->EnableProfiling(options, (const wchar_t*) newString)); + free(newString); + (*jniEnv)->ReleaseStringChars(jniEnv,pathString,path); +#else + const char* path = (*jniEnv)->GetStringUTFChars(jniEnv, pathString, NULL); + checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,api->EnableProfiling(options, path)); + (*jniEnv)->ReleaseStringUTFChars(jniEnv,pathString,path); +#endif +} + +/* + * Class: ai_onnxruntime_OrtSession_SessionOptions + * Method: disableProfiling + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_disableProfiling + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle) { + (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle; + checkOrtStatus(jniEnv,api,api->DisableProfiling(options)); +} + +/* + * Class: ai_onnxruntime_OrtSession_SessionOptions + * Method: setMemoryPatternOptimization + * Signature: (JJZ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setMemoryPatternOptimization + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jboolean memPattern) { + (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle; + if (memPattern) { + checkOrtStatus(jniEnv,api,api->EnableMemPattern(options)); + } else { + checkOrtStatus(jniEnv,api,api->DisableMemPattern(options)); + } +} + +/* + * Class: ai_onnxruntime_OrtSession_SessionOptions + * Method: setCPUArenaAllocator + * Signature: (JJZ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setCPUArenaAllocator + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jboolean useArena) { + (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle; + if (useArena) { + checkOrtStatus(jniEnv,api,api->EnableCpuMemArena(options)); + } else { + checkOrtStatus(jniEnv,api,api->DisableCpuMemArena(options)); + } +} + +/* + * Class: ai_onnxruntime_OrtSession_SessionOptions + * Method: setSessionLogLevel + * Signature: (JJI)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setSessionLogLevel + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jint logLevel) { + (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle; + checkOrtStatus(jniEnv,api,api->SetSessionLogSeverityLevel(options,logLevel)); +} + +/* + * Class: ai_onnxruntime_OrtSession_SessionOptions + * Method: setSessionLogVerbosityLevel + * Signature: (JJI)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setSessionLogVerbosityLevel + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jint logLevel) { + (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle; + checkOrtStatus(jniEnv,api,api->SetSessionLogVerbosityLevel(options,logLevel)); +} + +/* + * Class: ai_onnxruntime_OrtSession_SessionOptions + * Method: registerCustomOpLibrary + * Signature: (JJLjava/lang/String;)J + */ +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_registerCustomOpLibrary + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring libraryPath) { + (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + + // Extract the string chars + const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, libraryPath, NULL); + + // Load the library + void* libraryHandle; + checkOrtStatus(jniEnv,api,api->RegisterCustomOpsLibrary((OrtSessionOptions*)optionsHandle,cPath,&libraryHandle)); + + // Release the string chars + (*jniEnv)->ReleaseStringUTFChars(jniEnv,libraryPath,cPath); + + return (jlong) libraryHandle; +} + +/* + * Class: ai_onnxruntime_OrtSession_SessionOptions + * Method: closeCustomLibraries + * Signature: ([J)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_closeCustomLibraries + (JNIEnv * jniEnv, jobject jobj, jlongArray libraryHandles) { + (void) jniEnv; (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + + // Get the number of elements in the array + jsize numHandles = (*jniEnv)->GetArrayLength(jniEnv, libraryHandles); + + // Get the elements of the libraryHandles array + jlong* handles = (*jniEnv)->GetLongArrayElements(jniEnv,libraryHandles,NULL); + + // Iterate the handles, calling the appropriate close function + for (jint i = 0; i < numHandles; i++) { +#ifdef WIN32 + FreeLibrary((void*)handles[i]); +#else + dlclose((void*)handles[i]); +#endif + } + + // Release the long array + (*jniEnv)->ReleaseLongArrayElements(jniEnv,libraryHandles,handles,JNI_ABORT); } /* diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 5423df20a1..81088ee7ce 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -120,7 +120,7 @@ public class InferenceTest { String modelPath = getResourcePath("/partial-inputs-test-2.onnx").toString(); try (OrtEnvironment env = OrtEnvironment.getEnvironment( - OrtEnvironment.LoggingLevel.ORT_LOGGING_LEVEL_FATAL, "partialInputs"); + OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "partialInputs"); OrtSession.SessionOptions options = new SessionOptions(); OrtSession session = env.createSession(modelPath, options)) { assertNotNull(session); @@ -207,7 +207,7 @@ public class InferenceTest { String modelPath = getResourcePath("/partial-inputs-test.onnx").toString(); try (OrtEnvironment env = OrtEnvironment.getEnvironment( - OrtEnvironment.LoggingLevel.ORT_LOGGING_LEVEL_FATAL, "partialInputs"); + OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "partialInputs"); OrtSession.SessionOptions options = new SessionOptions(); OrtSession session = env.createSession(modelPath, options)) { assertNotNull(session); @@ -779,6 +779,156 @@ public class InferenceTest { } } + @Test + public void testRunOptions() throws OrtException { + // model takes 1x5 input of fixed type, echoes back + String modelPath = getResourcePath("/test_types_BOOL.pb").toString(); + + try (OrtEnvironment env = OrtEnvironment.getEnvironment("testRunOptions"); + SessionOptions options = new SessionOptions(); + OrtSession session = env.createSession(modelPath, options); + OrtSession.RunOptions runOptions = new OrtSession.RunOptions()) { + runOptions.setRunTag("monkeys"); + assertEquals("monkeys", runOptions.getRunTag()); + runOptions.setLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL); + assertEquals(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, runOptions.getLogLevel()); + runOptions.setLogVerbosityLevel(9000); + assertEquals(9000, runOptions.getLogVerbosityLevel()); + runOptions.setTerminate(true); + String inputName = session.getInputNames().iterator().next(); + Map container = new HashMap<>(); + boolean[] flatInput = new boolean[] {true, false, true, false, true}; + Object tensorIn = OrtUtil.reshape(flatInput, new long[] {1, 5}); + OnnxTensor ov = OnnxTensor.createTensor(env, tensorIn); + container.put(inputName, ov); + try (OrtSession.Result res = session.run(container, runOptions)) { + fail("Should have terminated."); + } catch (OrtException e) { + assertTrue(e.getMessage().contains("Exiting due to terminate flag being set to true.")); + assertEquals(OrtException.OrtErrorCode.ORT_FAIL, e.getCode()); + } + OnnxValue.close(container); + } + } + + @Test + public void testExtraSessionOptions() throws OrtException, IOException { + // model takes 1x5 input of fixed type, echoes back + String modelPath = getResourcePath("/test_types_BOOL.pb").toString(); + File tmpPath = File.createTempFile("onnx-runtime-profiling", "file"); + tmpPath.deleteOnExit(); + + try (OrtEnvironment env = OrtEnvironment.getEnvironment("testExtraSessionOptions")) { + try (SessionOptions options = new SessionOptions()) { + options.setCPUArenaAllocator(true); + options.setMemoryPatternOptimization(true); + options.enableProfiling(tmpPath.getAbsolutePath()); + options.setLoggerId("monkeys"); + options.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL); + options.setSessionLogVerbosityLevel(5); + try (OrtSession session = env.createSession(modelPath, options)) { + String inputName = session.getInputNames().iterator().next(); + Map container = new HashMap<>(); + boolean[] flatInput = new boolean[] {true, false, true, false, true}; + Object tensorIn = OrtUtil.reshape(flatInput, new long[] {1, 5}); + OnnxTensor ov = OnnxTensor.createTensor(env, tensorIn); + container.put(inputName, ov); + try (OrtSession.Result res = session.run(container)) { + boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue()); + assertArrayEquals(flatInput, resultArray); + } + String profilingOutput = session.endProfiling(); + File profilingOutputFile = new File(profilingOutput); + profilingOutputFile.deleteOnExit(); + try (OrtSession.Result res = session.run(container)) { + boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue()); + assertArrayEquals(flatInput, resultArray); + } + OnnxValue.close(container); + } + } + try (SessionOptions options = new SessionOptions()) { + options.setCPUArenaAllocator(false); + options.setMemoryPatternOptimization(false); + options.enableProfiling(tmpPath.getAbsolutePath()); + options.disableProfiling(); + options.setSessionLogVerbosityLevel(0); + try (OrtSession session = env.createSession(modelPath, options)) { + String inputName = session.getInputNames().iterator().next(); + Map container = new HashMap<>(); + boolean[] flatInput = new boolean[] {true, false, true, false, true}; + Object tensorIn = OrtUtil.reshape(flatInput, new long[] {1, 5}); + OnnxTensor ov = OnnxTensor.createTensor(env, tensorIn); + container.put(inputName, ov); + try (OrtSession.Result res = session.run(container)) { + boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue()); + assertArrayEquals(flatInput, resultArray); + } + OnnxValue.close(container); + } + } + } + } + + @Test + public void testLoadCustomLibrary() throws OrtException { + // This test is disabled on Android. + if (!OnnxRuntime.isAndroid()) { + String customLibraryName = ""; + String osName = System.getProperty("os.name").toLowerCase(); + if (osName.contains("windows")) { + // In windows we start in the wrong working directory relative to the custom_op_library.dll + // So we look it up as a classpath resource and resolve it to a real path + customLibraryName = getResourcePath("/custom_op_library.dll").toString(); + } else if (osName.contains("mac")) { + customLibraryName = "libcustom_op_library.dylib"; + } else if (osName.contains("linux")) { + customLibraryName = "./libcustom_op_library.so"; + } else { + fail("Unknown os/platform '" + osName + "'"); + } + String customOpLibraryTestModel = + getResourcePath("/custom_op_library/custom_op_test.onnx").toString(); + + try (OrtEnvironment env = OrtEnvironment.getEnvironment("testLoadCustomLibrary"); + SessionOptions options = new SessionOptions()) { + options.registerCustomOpLibrary(customLibraryName); + try (OrtSession session = env.createSession(customOpLibraryTestModel, options)) { + Map container = new HashMap<>(); + + // prepare expected inputs and outputs + float[] flatInputOne = + new float[] { + 1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f, 9.9f, 10.0f, 11.1f, 12.2f, 13.3f, + 14.4f, 15.5f + }; + Object tensorIn = OrtUtil.reshape(flatInputOne, new long[] {3, 5}); + OnnxTensor ov = OnnxTensor.createTensor(env, tensorIn); + container.put("input_1", ov); + + float[] flatInputTwo = + new float[] { + 15.5f, 14.4f, 13.3f, 12.2f, 11.1f, 10.0f, 9.9f, 8.8f, 7.7f, 6.6f, 5.5f, 4.4f, 3.3f, + 2.2f, 1.1f + }; + tensorIn = OrtUtil.reshape(flatInputTwo, new long[] {3, 5}); + ov = OnnxTensor.createTensor(env, tensorIn); + container.put("input_2", ov); + + int[] flatOutput = new int[] {17, 17, 17, 17, 17, 17, 18, 18, 18, 17, 17, 17, 17, 17, 17}; + + try (OrtSession.Result res = session.run(container)) { + OnnxTensor outputTensor = (OnnxTensor) res.get(0); + assertArrayEquals(new long[] {3, 5}, outputTensor.getInfo().shape); + int[] resultArray = TestHelpers.flattenInteger(res.get(0).getValue()); + assertArrayEquals(flatOutput, resultArray); + } + OnnxValue.close(container); + } + } + } + } + @Test public void testModelInputBOOL() throws OrtException { // model takes 1x5 input of fixed type, echoes back