[Java] Adding missing methods on Session, SessionOptions and RunOptions (v2) (#3832)

* java - adding support for custom op libraries.

* Adding support for RunOptions and additional methods for SessionOptions and OrtSession.

As a result OrtEnvironment.LoggingLevel moved to be a top level enum
called OrtLoggingLevel.

* java - adding unit tests for RunOptions and SessionOptions.

* java - removing unused releaseNamesHandle method

* java - add test for custom op library.

* java - adding log verbosity methods, and tests for the same.

* java - fixes for custom op loading test on Windows.

* Cleanup after rebase on master.
This commit is contained in:
Adam Pocock 2020-05-06 04:19:46 -04:00 committed by GitHub
parent d5ec353e58
commit d38b79c6e5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 942 additions and 72 deletions

View file

@ -815,6 +815,10 @@ if (onnxruntime_BUILD_JAVA)
message(STATUS "Running Java tests")
# delegate to gradle's test runner
if(WIN32)
# If we're on windows, symlink the custom op test library somewhere we can see it
set(JAVA_NATIVE_TEST_DIR ${JAVA_OUTPUT_DIR}/native-test)
file(MAKE_DIRECTORY ${JAVA_NATIVE_TEST_DIR})
add_custom_command(TARGET custom_op_library POST_BUILD COMMAND ${CMAKE_COMMAND} -E create_symlink $<TARGET_FILE:custom_op_library> ${JAVA_NATIVE_TEST_DIR}/$<TARGET_FILE_NAME:custom_op_library>)
# On windows ctest requires a test to be an .exe(.com) file
# So there are two options 1) Install Chocolatey and its gradle package
# That package would install gradle.exe shim to its bin so ctest could run gradle.exe
@ -826,8 +830,8 @@ if (onnxruntime_BUILD_JAVA)
-DREPO_ROOT=${REPO_ROOT}
-P ${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime_java_unittests.cmake)
else()
add_test(NAME onnxruntime4j_test COMMAND ${GRADLE_EXECUTABLE} cmakeCheck -DcmakeBuildDir=${CMAKE_CURRENT_BINARY_DIR}
WORKING_DIRECTORY ${REPO_ROOT}/java)
add_test(NAME onnxruntime4j_test COMMAND ${GRADLE_EXECUTABLE} cmakeCheck -DcmakeBuildDir=${CMAKE_CURRENT_BINARY_DIR}
WORKING_DIRECTORY ${REPO_ROOT}/java)
endif()
set_property(TEST onnxruntime4j_test APPEND PROPERTY DEPENDS onnxruntime4j_jni)
endif()

View file

@ -67,6 +67,7 @@ def cmakeBuildDir = System.properties['cmakeBuildDir']
def cmakeJavaDir = "${cmakeBuildDir}/java"
def cmakeNativeLibDir = "${cmakeJavaDir}/native-lib"
def cmakeNativeJniDir = "${cmakeJavaDir}/native-jni"
def cmakeNativeTestDir = "${cmakeJavaDir}/native-test"
def cmakeBuildOutputDir = "${cmakeJavaDir}/build"
compileJava {
@ -84,7 +85,8 @@ sourceSets.test {
// add compiled native libs
resources.srcDirs += [
cmakeNativeLibDir,
cmakeNativeJniDir
cmakeNativeJniDir,
cmakeNativeTestDir
]
}
}
@ -144,6 +146,9 @@ dependencies {
test {
useJUnitPlatform()
if (cmakeBuildDir != null) {
workingDir cmakeBuildDir
}
testLogging {
events "passed", "skipped", "failed"
showStandardStreams = true

View file

@ -78,7 +78,12 @@ final class OnnxRuntime {
}
}
private static boolean isAndroid() {
/**
* Check if we're running on Android.
*
* @return True if the {@code android.app.Activity} class can be loaded, false otherwise.
*/
static boolean isAndroid() {
try {
Class.forName("android.app.Activity");
return true;

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2020 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
@ -15,24 +15,6 @@ import java.util.logging.Logger;
*/
public class OrtEnvironment implements AutoCloseable {
/** The logging level for messages from the environment and session. */
public enum LoggingLevel {
ORT_LOGGING_LEVEL_VERBOSE(0),
ORT_LOGGING_LEVEL_INFO(1),
ORT_LOGGING_LEVEL_WARNING(2),
ORT_LOGGING_LEVEL_ERROR(3),
ORT_LOGGING_LEVEL_FATAL(4);
private final int value;
LoggingLevel(int value) {
this.value = value;
}
public int getValue() {
return value;
}
}
private static final Logger logger = Logger.getLogger(OrtEnvironment.class.getName());
public static final String DEFAULT_NAME = "ort-java";
@ -49,29 +31,29 @@ public class OrtEnvironment implements AutoCloseable {
private static final AtomicInteger refCount = new AtomicInteger();
private static volatile LoggingLevel curLogLevel;
private static volatile OrtLoggingLevel curLogLevel;
private static volatile String curLoggingName;
/**
* Gets the OrtEnvironment. If there is not an environment currently created, it creates one using
* {@link OrtEnvironment#DEFAULT_NAME} and {@link LoggingLevel#ORT_LOGGING_LEVEL_WARNING}.
* {@link OrtEnvironment#DEFAULT_NAME} and {@link OrtLoggingLevel#ORT_LOGGING_LEVEL_WARNING}.
*
* @return An onnxruntime environment.
*/
public static OrtEnvironment getEnvironment() {
return getEnvironment(LoggingLevel.ORT_LOGGING_LEVEL_WARNING, DEFAULT_NAME);
return getEnvironment(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING, DEFAULT_NAME);
}
/**
* Gets the OrtEnvironment. If there is not an environment currently created, it creates one using
* the supplied name and {@link LoggingLevel#ORT_LOGGING_LEVEL_WARNING}.
* the supplied name and {@link OrtLoggingLevel#ORT_LOGGING_LEVEL_WARNING}.
*
* @param name The logging id of the environment.
* @return An onnxruntime environment.
*/
public static OrtEnvironment getEnvironment(String name) {
return getEnvironment(LoggingLevel.ORT_LOGGING_LEVEL_WARNING, name);
return getEnvironment(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING, name);
}
/**
@ -81,7 +63,7 @@ public class OrtEnvironment implements AutoCloseable {
* @param logLevel The logging level to use.
* @return An onnxruntime environment.
*/
public static OrtEnvironment getEnvironment(LoggingLevel logLevel) {
public static OrtEnvironment getEnvironment(OrtLoggingLevel logLevel) {
return getEnvironment(logLevel, DEFAULT_NAME);
}
@ -94,7 +76,8 @@ public class OrtEnvironment implements AutoCloseable {
* @param name The log id.
* @return The OrtEnvironment singleton.
*/
public static synchronized OrtEnvironment getEnvironment(LoggingLevel loggingLevel, String name) {
public static synchronized OrtEnvironment getEnvironment(
OrtLoggingLevel loggingLevel, String name) {
if (INSTANCE == null) {
try {
INSTANCE = new OrtEnvironment(loggingLevel, name);
@ -104,7 +87,7 @@ public class OrtEnvironment implements AutoCloseable {
throw new IllegalStateException("Failed to create OrtEnvironment", e);
}
} else {
if ((loggingLevel.value != curLogLevel.value) || (!name.equals(curLoggingName))) {
if ((loggingLevel.getValue() != curLogLevel.getValue()) || (!name.equals(curLoggingName))) {
logger.warning(
"Tried to change OrtEnvironment's logging level or name while a reference exists.");
}
@ -125,7 +108,7 @@ public class OrtEnvironment implements AutoCloseable {
* @throws OrtException If the environment couldn't be created.
*/
private OrtEnvironment() throws OrtException {
this(LoggingLevel.ORT_LOGGING_LEVEL_WARNING, "java-default");
this(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING, "java-default");
}
/**
@ -135,7 +118,7 @@ public class OrtEnvironment implements AutoCloseable {
* @param name The logging id of the environment.
* @throws OrtException If the environment couldn't be created.
*/
private OrtEnvironment(LoggingLevel loggingLevel, String name) throws OrtException {
private OrtEnvironment(OrtLoggingLevel loggingLevel, String name) throws OrtException {
nativeHandle = createHandle(OnnxRuntime.ortApiHandle, loggingLevel.getValue(), name);
defaultAllocator = new OrtAllocator(getDefaultAllocator(OnnxRuntime.ortApiHandle), true);
}

View file

@ -0,0 +1,54 @@
/*
* Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
import java.util.logging.Logger;
/** The logging level for messages from the environment and session. */
public enum OrtLoggingLevel {
ORT_LOGGING_LEVEL_VERBOSE(0),
ORT_LOGGING_LEVEL_INFO(1),
ORT_LOGGING_LEVEL_WARNING(2),
ORT_LOGGING_LEVEL_ERROR(3),
ORT_LOGGING_LEVEL_FATAL(4);
private final int value;
private static final Logger logger = Logger.getLogger(OrtLoggingLevel.class.getName());
private static final OrtLoggingLevel[] values = new OrtLoggingLevel[5];
static {
for (OrtLoggingLevel ot : OrtLoggingLevel.values()) {
values[ot.value] = ot;
}
}
OrtLoggingLevel(int value) {
this.value = value;
}
/**
* Gets the native value associated with this logging level.
*
* @return The native value.
*/
public int getValue() {
return value;
}
/**
* Maps from the C API's int enum to the Java enum.
*
* @param logLevel The index of the Java enum.
* @return The Java enum.
*/
public static OrtLoggingLevel mapFromInt(int logLevel) {
if ((logLevel > 0) && (logLevel < values.length)) {
return values[logLevel];
} else {
logger.warning("Unknown logging level " + logLevel + " setting to ORT_LOGGING_LEVEL_VERBOSE");
return ORT_LOGGING_LEVEL_VERBOSE;
}
}
}

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
@ -200,6 +200,21 @@ public class OrtSession implements AutoCloseable {
return run(inputs, outputNames);
}
/**
* Scores an input feed dict, returning the map of all inferred outputs.
*
* <p>The outputs are sorted based on their id number.
*
* @param inputs The inputs to score.
* @param runOptions The RunOptions to control this run.
* @return The inferred outputs.
* @throws OrtException If there was an error in native code, the input names are invalid, or if
* there are zero or too many inputs.
*/
public Result run(Map<String, OnnxTensor> inputs, RunOptions runOptions) throws OrtException {
return run(inputs, outputNames, runOptions);
}
/**
* Scores an input feed dict, returning the map of requested inferred outputs.
*
@ -213,6 +228,24 @@ public class OrtSession implements AutoCloseable {
*/
public Result run(Map<String, OnnxTensor> inputs, Set<String> requestedOutputs)
throws OrtException {
return run(inputs, requestedOutputs, null);
}
/**
* Scores an input feed dict, returning the map of requested inferred outputs.
*
* <p>The outputs are sorted based on the supplied set traveral order.
*
* @param inputs The inputs to score.
* @param requestedOutputs The requested outputs.
* @param runOptions The RunOptions to control this run.
* @return The inferred outputs.
* @throws OrtException If there was an error in native code, the input or output names are
* invalid, or if there are zero or too many inputs or outputs.
*/
public Result run(
Map<String, OnnxTensor> inputs, Set<String> requestedOutputs, RunOptions runOptions)
throws OrtException {
if (!closed) {
if (inputs.isEmpty() || (inputs.size() > numInputs)) {
throw new OrtException(
@ -249,6 +282,8 @@ public class OrtSession implements AutoCloseable {
"Unknown output name " + s + ", expected one of " + outputNames.toString());
}
}
long runOptionsHandle = runOptions == null ? 0 : runOptions.nativeHandle;
OnnxValue[] outputValues =
run(
OnnxRuntime.ortApiHandle,
@ -258,7 +293,8 @@ public class OrtSession implements AutoCloseable {
inputHandles,
inputNamesArray.length,
outputNamesArray,
outputNamesArray.length);
outputNamesArray.length,
runOptionsHandle);
return new Result(outputNamesArray, outputValues);
} else {
throw new IllegalStateException("Trying to score a closed OrtSession.");
@ -268,8 +304,8 @@ public class OrtSession implements AutoCloseable {
/**
* Gets the metadata for the currently loaded model.
*
* @throws OrtException on failure
* @return The metadata.
* @throws OrtException If the native call failed.
*/
public OnnxModelMetadata getMetadata() throws OrtException {
if (metadata == null) {
@ -278,6 +314,19 @@ public class OrtSession implements AutoCloseable {
return metadata;
}
/**
* Ends the profiling session and returns the output of the profiler.
*
* <p>Profiling should be enabled in the {@link SessionOptions} used to construct this {@code
* Session}.
*
* @return The profiling output.
* @throws OrtException If the native call failed.
*/
public String endProfiling() throws OrtException {
return endProfiling(OnnxRuntime.ortApiHandle, nativeHandle, allocator.handle);
}
@Override
public String toString() {
return "OrtSession(numInputs=" + numInputs + ",numOutputs=" + numOutputs + ")";
@ -336,6 +385,22 @@ public class OrtSession implements AutoCloseable {
private native NodeInfo[] getOutputInfo(long apiHandle, long nativeHandle, long allocatorHandle)
throws OrtException;
/**
* The native run call. runOptionsHandle can be zero (i.e. the null pointer), but all other
* handles must be valid pointers.
*
* @param apiHandle The pointer to the api.
* @param nativeHandle The pointer to the session.
* @param allocatorHandle The pointer to the allocator.
* @param inputNamesArray The input names.
* @param inputs The input tensors.
* @param numInputs The number of inputs.
* @param outputNamesArray The requested output names.
* @param numOutputs The number of requested outputs.
* @param runOptionsHandle The (possibly null) pointer to the run options.
* @return The OnnxValues produced by this run.
* @throws OrtException If the native call failed in some way.
*/
private native OnnxValue[] run(
long apiHandle,
long nativeHandle,
@ -344,7 +409,11 @@ public class OrtSession implements AutoCloseable {
long[] inputs,
long numInputs,
String[] outputNamesArray,
long numOutputs)
long numOutputs,
long runOptionsHandle)
throws OrtException;
private native String endProfiling(long apiHandle, long nativeHandle, long allocatorHandle)
throws OrtException;
private native void closeSession(long apiHandle, long nativeHandle) throws OrtException;
@ -368,6 +437,9 @@ public class OrtSession implements AutoCloseable {
* options.
*
* <p>Modifying this after the session has been constructed will have no effect.
*
* <p>The SessionOptions object must not be closed until all sessions which use it are closed, as
* otherwise it could release resources that are in use.
*/
public static class SessionOptions implements AutoCloseable {
@ -421,15 +493,39 @@ public class OrtSession implements AutoCloseable {
private final long nativeHandle;
private final List<Long> customLibraryHandles;
private boolean closed = false;
/** Create an empty session options. */
public SessionOptions() {
nativeHandle = createOptions(OnnxRuntime.ortApiHandle);
customLibraryHandles = new ArrayList<>();
}
/** Closes the session options, releasing any memory acquired. */
@Override
public void close() {
closeOptions(OnnxRuntime.ortApiHandle, nativeHandle);
if (!closed) {
if (customLibraryHandles.size() > 0) {
long[] longArray = new long[customLibraryHandles.size()];
for (int i = 0; i < customLibraryHandles.size(); i++) {
longArray[i] = customLibraryHandles.get(i);
}
closeCustomLibraries(longArray);
}
closeOptions(OnnxRuntime.ortApiHandle, nativeHandle);
closed = true;
} else {
throw new IllegalStateException("Trying to close a closed SessionOptions.");
}
}
/** Checks if the SessionOptions is closed, if so throws {@link IllegalStateException}. */
private void checkClosed() {
if (closed) {
throw new IllegalStateException("Trying to use a closed SessionOptions");
}
}
/**
@ -439,6 +535,7 @@ public class OrtSession implements AutoCloseable {
* @throws OrtException If there was an error in native code.
*/
public void setExecutionMode(ExecutionMode mode) throws OrtException {
checkClosed();
setExecutionMode(OnnxRuntime.ortApiHandle, nativeHandle, mode.getID());
}
@ -449,6 +546,7 @@ public class OrtSession implements AutoCloseable {
* @throws OrtException If there was an error in native code.
*/
public void setOptimizationLevel(OptLevel level) throws OrtException {
checkClosed();
setOptimizationLevel(OnnxRuntime.ortApiHandle, nativeHandle, level.getID());
}
@ -460,6 +558,7 @@ public class OrtSession implements AutoCloseable {
* @throws OrtException If there was an error in native code.
*/
public void setInterOpNumThreads(int numThreads) throws OrtException {
checkClosed();
setInterOpNumThreads(OnnxRuntime.ortApiHandle, nativeHandle, numThreads);
}
@ -471,6 +570,7 @@ public class OrtSession implements AutoCloseable {
* @throws OrtException If there was an error in native code.
*/
public void setIntraOpNumThreads(int numThreads) throws OrtException {
checkClosed();
setIntraOpNumThreads(OnnxRuntime.ortApiHandle, nativeHandle, numThreads);
}
@ -481,15 +581,107 @@ public class OrtSession implements AutoCloseable {
* @throws OrtException If there was an error in native code.
*/
public void setOptimizedModelFilePath(String outputPath) throws OrtException {
checkClosed();
setOptimizationModelFilePath(OnnxRuntime.ortApiHandle, nativeHandle, outputPath);
}
/**
* Sets the logger id to use.
*
* @param loggerId The logger id string.
* @throws OrtException If there was an error in native code.
*/
public void setLoggerId(String loggerId) throws OrtException {
checkClosed();
setLoggerId(OnnxRuntime.ortApiHandle, nativeHandle, loggerId);
}
/**
* Enables profiling in sessions using this SessionOptions.
*
* @param filePath The file to write profile information to.
* @throws OrtException If there was an error in native code.
*/
public void enableProfiling(String filePath) throws OrtException {
checkClosed();
enableProfiling(OnnxRuntime.ortApiHandle, nativeHandle, filePath);
}
/**
* Disables profiling in sessions using this SessionOptions.
*
* @throws OrtException If there was an error in native code.
*/
public void disableProfiling() throws OrtException {
checkClosed();
disableProfiling(OnnxRuntime.ortApiHandle, nativeHandle);
}
/**
* Turns on memory pattern optimizations, where memory is preallocated if all shapes are known.
*
* @param memoryPatternOptimization If true enable memory pattern optimizations.
* @throws OrtException If there was an error in native code.
*/
public void setMemoryPatternOptimization(boolean memoryPatternOptimization)
throws OrtException {
checkClosed();
setMemoryPatternOptimization(
OnnxRuntime.ortApiHandle, nativeHandle, memoryPatternOptimization);
}
/**
* Sets the CPU to use an arena memory allocator.
*
* @param useArena If true use an arena memory allocator for the CPU execution provider.
* @throws OrtException If there was an error in native code.
*/
public void setCPUArenaAllocator(boolean useArena) throws OrtException {
checkClosed();
setCPUArenaAllocator(OnnxRuntime.ortApiHandle, nativeHandle, useArena);
}
/**
* Sets the Session's logging level.
*
* @param logLevel The log level to use.
* @throws OrtException If there was an error in native code.
*/
public void setSessionLogLevel(OrtLoggingLevel logLevel) throws OrtException {
checkClosed();
setSessionLogLevel(OnnxRuntime.ortApiHandle, nativeHandle, logLevel.getValue());
}
/**
* Sets the Session's logging verbosity level.
*
* @param logLevel The logging verbosity to use.
* @throws OrtException If there was an error in native code.
*/
public void setSessionLogVerbosityLevel(int logLevel) throws OrtException {
checkClosed();
setSessionLogVerbosityLevel(OnnxRuntime.ortApiHandle, nativeHandle, logLevel);
}
/**
* Registers a library of custom ops for use with {@link OrtSession}s using this SessionOptions.
*
* @param path The path to the library on disk.
* @throws OrtException If there was an error loading the library.
*/
public void registerCustomOpLibrary(String path) throws OrtException {
checkClosed();
long customHandle = registerCustomOpLibrary(OnnxRuntime.ortApiHandle, nativeHandle, path);
customLibraryHandles.add(customHandle);
}
/**
* Add CUDA as an execution backend, using device 0.
*
* @throws OrtException If there was an error in native code.
*/
public void addCUDA() throws OrtException {
checkClosed();
addCUDA(0);
}
@ -500,6 +692,7 @@ public class OrtSession implements AutoCloseable {
* @throws OrtException If there was an error in native code.
*/
public void addCUDA(int deviceNum) throws OrtException {
checkClosed();
addCUDA(OnnxRuntime.ortApiHandle, nativeHandle, deviceNum);
}
@ -513,6 +706,7 @@ public class OrtSession implements AutoCloseable {
* @throws OrtException If there was an error in native code.
*/
public void addCPU(boolean useArena) throws OrtException {
checkClosed();
addCPU(OnnxRuntime.ortApiHandle, nativeHandle, useArena ? 1 : 0);
}
@ -523,6 +717,7 @@ public class OrtSession implements AutoCloseable {
* @throws OrtException If there was an error in native code.
*/
public void addDnnl(boolean useArena) throws OrtException {
checkClosed();
addDnnl(OnnxRuntime.ortApiHandle, nativeHandle, useArena ? 1 : 0);
}
@ -535,6 +730,7 @@ public class OrtSession implements AutoCloseable {
* @throws OrtException If there was an error in native code.
*/
public void addNGraph(String ngBackendType) throws OrtException {
checkClosed();
addNGraph(OnnxRuntime.ortApiHandle, nativeHandle, ngBackendType);
}
@ -545,6 +741,7 @@ public class OrtSession implements AutoCloseable {
* @throws OrtException If there was an error in native code.
*/
public void addOpenVINO(String deviceId) throws OrtException {
checkClosed();
addOpenVINO(OnnxRuntime.ortApiHandle, nativeHandle, deviceId);
}
@ -555,6 +752,7 @@ public class OrtSession implements AutoCloseable {
* @throws OrtException If there was an error in native code.
*/
public void addTensorrt(int deviceNum) throws OrtException {
checkClosed();
addTensorrt(OnnxRuntime.ortApiHandle, nativeHandle, deviceNum);
}
@ -564,6 +762,7 @@ public class OrtSession implements AutoCloseable {
* @throws OrtException If there was an error in native code.
*/
public void addNnapi() throws OrtException {
checkClosed();
addNnapi(OnnxRuntime.ortApiHandle, nativeHandle);
}
@ -575,6 +774,7 @@ public class OrtSession implements AutoCloseable {
* @throws OrtException If there was an error in native code.
*/
public void addNuphar(boolean allowUnalignedBuffers, String settings) throws OrtException {
checkClosed();
addNuphar(OnnxRuntime.ortApiHandle, nativeHandle, allowUnalignedBuffers ? 1 : 0, settings);
}
@ -585,6 +785,7 @@ public class OrtSession implements AutoCloseable {
* @throws OrtException If there was an error in native code.
*/
public void addDirectML(int deviceId) throws OrtException {
checkClosed();
addDirectML(OnnxRuntime.ortApiHandle, nativeHandle, deviceId);
}
@ -595,6 +796,7 @@ public class OrtSession implements AutoCloseable {
* @throws OrtException If there was an error in native code.
*/
public void addACL(boolean useArena) throws OrtException {
checkClosed();
addACL(OnnxRuntime.ortApiHandle, nativeHandle, useArena ? 1 : 0);
}
@ -615,6 +817,31 @@ public class OrtSession implements AutoCloseable {
private native long createOptions(long apiHandle);
private native void setLoggerId(long apiHandle, long nativeHandle, String loggerId)
throws OrtException;
private native void enableProfiling(long apiHandle, long nativeHandle, String filePrefix)
throws OrtException;
private native void disableProfiling(long apiHandle, long nativeHandle) throws OrtException;
private native void setMemoryPatternOptimization(
long apiHandle, long nativeHandle, boolean memoryPatternOptimization) throws OrtException;
private native void setCPUArenaAllocator(long apiHandle, long nativeHandle, boolean useArena)
throws OrtException;
private native void setSessionLogLevel(long apiHandle, long nativeHandle, int logLevel)
throws OrtException;
private native void setSessionLogVerbosityLevel(long apiHandle, long nativeHandle, int logLevel)
throws OrtException;
private native long registerCustomOpLibrary(long apiHandle, long nativeHandle, String path)
throws OrtException;
private native void closeCustomLibraries(long[] nativeHandle);
private native void closeOptions(long apiHandle, long nativeHandle);
/*
@ -658,6 +885,141 @@ public class OrtSession implements AutoCloseable {
private native void addACL(long apiHandle, long nativeHandle, int useArena) throws OrtException;
}
/** Used to control logging and termination of a call to {@link OrtSession#run}. */
public static class RunOptions implements AutoCloseable {
private final long nativeHandle;
private boolean closed = false;
/**
* Creates a RunOptions.
*
* @throws OrtException If the construction of the native RunOptions failed.
*/
public RunOptions() throws OrtException {
this.nativeHandle = createRunOptions(OnnxRuntime.ortApiHandle);
}
/**
* Sets the current logging level on this RunOptions.
*
* @param level The new logging level.
* @throws OrtException If the native call failed.
*/
public void setLogLevel(OrtLoggingLevel level) throws OrtException {
checkClosed();
setLogLevel(OnnxRuntime.ortApiHandle, nativeHandle, level.getValue());
}
/**
* Gets the current logging level set on this RunOptions.
*
* @return The logging level.
* @throws OrtException If the native call failed.
*/
public OrtLoggingLevel getLogLevel() throws OrtException {
checkClosed();
return OrtLoggingLevel.mapFromInt(getLogLevel(OnnxRuntime.ortApiHandle, nativeHandle));
}
/**
* Sets the current logging verbosity level on this RunOptions.
*
* @param level The new logging verbosity level.
* @throws OrtException If the native call failed.
*/
public void setLogVerbosityLevel(int level) throws OrtException {
checkClosed();
setLogVerbosityLevel(OnnxRuntime.ortApiHandle, nativeHandle, level);
}
/**
* Gets the current logging verbosity level set on this RunOptions.
*
* @return The logging verbosity level.
* @throws OrtException If the native call failed.
*/
public int getLogVerbosityLevel() throws OrtException {
checkClosed();
return getLogVerbosityLevel(OnnxRuntime.ortApiHandle, nativeHandle);
}
/**
* Sets the run tag used in logging.
*
* @param runTag The run tag in logging output.
* @throws OrtException If the native library call failed.
*/
public void setRunTag(String runTag) throws OrtException {
checkClosed();
setRunTag(OnnxRuntime.ortApiHandle, nativeHandle, runTag);
}
/**
* Gets the String used to log information about this run.
*
* @return The run tag.
* @throws OrtException If the native library call failed.
*/
public String getRunTag() throws OrtException {
checkClosed();
return getRunTag(OnnxRuntime.ortApiHandle, nativeHandle);
}
/**
* Sets a flag so that all incomplete {@link OrtSession#run} calls using this instance of {@code
* RunOptions} will terminate as soon as possible. If the flag is false, it resets this {@code
* RunOptions} so it can be used with other calls to {@link OrtSession#run}.
*
* @param terminate If true terminate all runs associated with this RunOptions.
* @throws OrtException If the native library call failed.
*/
public void setTerminate(boolean terminate) throws OrtException {
checkClosed();
setTerminate(OnnxRuntime.ortApiHandle, nativeHandle, terminate);
}
/** Checks if the RunOptions is closed, if so throws {@link IllegalStateException}. */
private void checkClosed() {
if (closed) {
throw new IllegalStateException("Trying to use a closed RunOptions");
}
}
@Override
public void close() {
if (!closed) {
close(OnnxRuntime.ortApiHandle, nativeHandle);
closed = true;
} else {
throw new IllegalStateException("Trying to close an already closed RunOptions");
}
}
private static native long createRunOptions(long apiHandle) throws OrtException;
private native void setLogLevel(long apiHandle, long nativeHandle, int logLevel)
throws OrtException;
private native int getLogLevel(long apiHandle, long nativeHandle) throws OrtException;
private native void setLogVerbosityLevel(long apiHandle, long nativeHandle, int logLevel)
throws OrtException;
private native int getLogVerbosityLevel(long apiHandle, long nativeHandle) throws OrtException;
private native void setRunTag(long apiHandle, long nativeHandle, String runTag)
throws OrtException;
private native String getRunTag(long apiHandle, long nativeHandle) throws OrtException;
private native void setTerminate(long apiHandle, long nativeHandle, boolean terminate)
throws OrtException;
private static native void close(long apiHandle, long nativeHandle);
}
/**
* An {@link AutoCloseable} wrapper around a {@link Map} containing {@link OnnxValue}s.
*

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2020 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
@ -16,7 +16,7 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
}
/**
* Must be kept in sync with ORT_LOGGING_LEVEL and OrtEnvironment#LoggingLevel
* Must be kept in sync with ORT_LOGGING_LEVEL and the OrtLoggingLevel java enum
*/
OrtLoggingLevel convertLoggingLevel(jint level) {
switch (level) {

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2020 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
@ -19,9 +19,8 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_la
const OrtApi* api = (const OrtApi*) apiHandle;
OrtSession* session;
jboolean copy;
#ifdef _WIN32
const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, modelPath, &copy);
const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, modelPath, NULL);
size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, modelPath);
wchar_t* newString = (wchar_t*)calloc(stringLength+1,sizeof(jchar));
wcsncpy_s(newString, stringLength+1, (const wchar_t*) cPath, stringLength);
@ -29,7 +28,7 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_la
free(newString);
(*jniEnv)->ReleaseStringChars(jniEnv,modelPath,cPath);
#else
const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, modelPath, &copy);
const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, modelPath, NULL);
checkOrtStatus(jniEnv,api,api->CreateSession((OrtEnv*)envHandle, cPath, (OrtSessionOptions*)optsHandle, &session));
(*jniEnv)->ReleaseStringUTFChars(jniEnv,modelPath,cPath);
#endif
@ -236,14 +235,16 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getOutputInfo
/*
* Class: ai_onnxruntime_OrtSession
* Method: run
* Signature: (JJJ[Ljava/lang/String;[JJ[Ljava/lang/String;J)[Lai/onnxruntime/OnnxValue;
* Signature: (JJJ[Ljava/lang/String;[JJ[Ljava/lang/String;JJ)[Lai/onnxruntime/OnnxValue;
* private native OnnxValue[] run(long apiHandle, long nativeHandle, long allocatorHandle, String[] inputNamesArray, long[] inputs, long numInputs, String[] outputNamesArray, long numOutputs)
*/
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle, jobjectArray inputNamesArr, jlongArray tensorArr, jlong numInputs, jobjectArray outputNamesArr, jlong numOutputs) {
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle, jobjectArray inputNamesArr, jlongArray tensorArr, jlong numInputs, jobjectArray outputNamesArr, jlong numOutputs, jlong runOptionsHandle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
OrtSession* session = (OrtSession*) sessionHandle;
OrtRunOptions* runOptions = (OrtRunOptions*) runOptionsHandle;
// Create the buffers for the Java input and output strings
const char** inputNames;
@ -276,7 +277,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run
// Actually score the inputs.
//printf("inputTensors = %p, first tensor = %p, numInputs = %ld, outputValues = %p, numOutputs = %ld\n",inputTensors,(OrtValue*)inputTensors[0],numInputs,outputValues,numOutputs);
//ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess, _In_ OrtRunOptions* run_options, _In_ const char* const* input_names, _In_ const OrtValue* const* input, size_t input_len, _In_ const char* const* output_names, size_t output_names_len, _Out_ OrtValue** output);
checkOrtStatus(jniEnv,api,api->Run((OrtSession*)sessionHandle, NULL, (const char* const*) inputNames, (const OrtValue* const*) inputTensors, numInputs, (const char* const*) outputNames, numOutputs, outputValues));
checkOrtStatus(jniEnv,api,api->Run(session, runOptions, (const char* const*) inputNames, (const OrtValue* const*) inputTensors, numInputs, (const char* const*) outputNames, numOutputs, outputValues));
// Release the C array of pointers to the tensors.
(*jniEnv)->ReleaseLongArrayElements(jniEnv,tensorArr,inputTensors,JNI_ABORT);
@ -309,6 +310,24 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run
return outputArray;
}
/*
* Class: ai_onnxruntime_OrtSession
* Method: endProfiling
* Signature: (JJJ)Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_endProfiling
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
char* profileStr;
checkOrtStatus(jniEnv,api,api->SessionEndProfiling((OrtSession*)handle,allocator,&profileStr));
jstring profileOutput = (*jniEnv)->NewStringUTF(jniEnv,profileStr);
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,profileStr));
return profileOutput;
}
/*
* Class: ai_onnxruntime_OrtSession
* Method: closeSession
@ -321,23 +340,6 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_closeSession
api->ReleaseSession((OrtSession*)handle);
}
/*
* Class: ai_onnxruntime_OrtSession
* Method: releaseNamesHandle
* Signature: (JJJ)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_releaseNamesHandle
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong allocatorHandle, jlong namesHandle, jlong length) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
char** names = (char**) namesHandle;
for (uint32_t i = 0; i < length; i++) {
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,names[i]));
}
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,names));
}
/*
* Class: ai_onnxruntime_OrtSession
* Method: constructMetadata

View file

@ -0,0 +1,134 @@
/*
* Copyright (c) 2020 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
#include <string.h>
#include "onnxruntime/core/session/onnxruntime_c_api.h"
#include "OrtJniUtil.h"
#include "ai_onnxruntime_OrtSession_RunOptions.h"
/*
* Class: ai_onnxruntime_OrtSession_RunOptions
* Method: createRunOptions
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_createRunOptions
(JNIEnv * jniEnv, jclass jclazz, jlong apiHandle) {
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtRunOptions* opts;
checkOrtStatus(jniEnv,api,api->CreateRunOptions(&opts));
return (jlong) opts;
}
/*
* Class: ai_onnxruntime_OrtSession_RunOptions
* Method: setLogLevel
* Signature: (JJI)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_setLogLevel
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle, jint logLevel) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
checkOrtStatus(jniEnv,api,api->RunOptionsSetRunLogSeverityLevel((OrtRunOptions*) nativeHandle,logLevel));
}
/*
* Class: ai_onnxruntime_OrtSession_RunOptions
* Method: getLogLevel
* Signature: (JJ)I
*/
JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_getLogLevel
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
jint logLevel;
checkOrtStatus(jniEnv,api,api->RunOptionsGetRunLogSeverityLevel((OrtRunOptions*) nativeHandle,&logLevel));
return logLevel;
}
/*
* Class: ai_onnxruntime_OrtSession_RunOptions
* Method: setLogVerbosityLevel
* Signature: (JJI)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_setLogVerbosityLevel
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle, jint logLevel) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
checkOrtStatus(jniEnv,api,api->RunOptionsSetRunLogVerbosityLevel((OrtRunOptions*) nativeHandle,logLevel));
}
/*
* Class: ai_onnxruntime_OrtSession_RunOptions
* Method: getLogVerbosityLevel
* Signature: (JJ)I
*/
JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_getLogVerbosityLevel
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
jint logLevel;
checkOrtStatus(jniEnv,api,api->RunOptionsGetRunLogVerbosityLevel((OrtRunOptions*) nativeHandle,&logLevel));
return logLevel;
}
/*
* Class: ai_onnxruntime_OrtSession_RunOptions
* Method: setRunTag
* Signature: (JJLjava/lang/String;)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_setRunTag
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle, jstring runTag) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
const char* runTagStr = (*jniEnv)->GetStringUTFChars(jniEnv, runTag, NULL);
checkOrtStatus(jniEnv,api,api->RunOptionsSetRunTag((OrtRunOptions*) nativeHandle, runTagStr));
(*jniEnv)->ReleaseStringUTFChars(jniEnv,runTag,runTagStr);
}
/*
* Class: ai_onnxruntime_OrtSession_RunOptions
* Method: getRunTag
* Signature: (JJ)Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_getRunTag
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
const char* runTagStr;
// This is a reference to the C str, and should not be freed.
checkOrtStatus(jniEnv,api,api->RunOptionsGetRunTag((OrtRunOptions*)nativeHandle,&runTagStr));
jstring runTag = (*jniEnv)->NewStringUTF(jniEnv,runTagStr);
return runTag;
}
/*
* Class: ai_onnxruntime_OrtSession_RunOptions
* Method: setTerminate
* Signature: (JJZ)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_setTerminate
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle, jboolean terminate) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtRunOptions* runOptions = (OrtRunOptions*) nativeHandle;
if (terminate) {
checkOrtStatus(jniEnv,api,api->RunOptionsSetTerminate(runOptions));
} else {
checkOrtStatus(jniEnv,api,api->RunOptionsUnsetTerminate(runOptions));
}
}
/*
* Class: ai_onnxruntime_OrtSession_RunOptions
* Method: close
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_close
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(void) jniEnv; (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
api->ReleaseRunOptions((OrtRunOptions*) handle);
}

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2020 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
@ -7,6 +7,11 @@
#include "onnxruntime/core/session/onnxruntime_c_api.h"
#include "OrtJniUtil.h"
#include "ai_onnxruntime_OrtSession_SessionOptions.h"
#ifdef WIN32
#include <Windows.h>
#else
#include <dlfcn.h>
#endif
// Providers
#include "onnxruntime/core/providers/cpu/cpu_provider_factory.h"
@ -107,7 +112,9 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_creat
OrtSessionOptions* opts;
checkOrtStatus(jniEnv,api,api->CreateSessionOptions(&opts));
checkOrtStatus(jniEnv,api,api->SetInterOpNumThreads(opts, 1));
checkOrtStatus(jniEnv,api,api->SetIntraOpNumThreads(opts, 1));
// Commented out due to constant OpenMP warning as this API is invalid when running with OpenMP.
// Not sure how to detect that from within the C API though.
//checkOrtStatus(jniEnv,api,api->SetIntraOpNumThreads(opts, 1));
return (jlong) opts;
}
@ -117,10 +124,174 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_creat
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_closeOptions
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(void) jniEnv; (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
api->ReleaseSessionOptions((OrtSessionOptions*) handle);
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(void)jniEnv; (void)jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
api->ReleaseSessionOptions((OrtSessionOptions*)handle);
}
/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: setLoggerId
* Signature: (JJLjava/lang/String;)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setLoggerId
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring loggerId) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
const char* loggerIdStr = (*jniEnv)->GetStringUTFChars(jniEnv, loggerId, NULL);
checkOrtStatus(jniEnv,api,api->SetSessionLogId(options, loggerIdStr));
(*jniEnv)->ReleaseStringUTFChars(jniEnv,loggerId,loggerIdStr);
}
/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: enableProfiling
* Signature: (JJLjava/lang/String;)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_enableProfiling
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring pathString) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
#ifdef _WIN32
const jchar* path = (*jniEnv)->GetStringChars(jniEnv, pathString, NULL);
size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, pathString);
wchar_t* newString = (wchar_t*)calloc(stringLength+1,sizeof(jchar));
wcsncpy_s(newString, stringLength+1, (const wchar_t*) path, stringLength);
checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,api->EnableProfiling(options, (const wchar_t*) newString));
free(newString);
(*jniEnv)->ReleaseStringChars(jniEnv,pathString,path);
#else
const char* path = (*jniEnv)->GetStringUTFChars(jniEnv, pathString, NULL);
checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,api->EnableProfiling(options, path));
(*jniEnv)->ReleaseStringUTFChars(jniEnv,pathString,path);
#endif
}
/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: disableProfiling
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_disableProfiling
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
checkOrtStatus(jniEnv,api,api->DisableProfiling(options));
}
/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: setMemoryPatternOptimization
* Signature: (JJZ)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setMemoryPatternOptimization
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jboolean memPattern) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
if (memPattern) {
checkOrtStatus(jniEnv,api,api->EnableMemPattern(options));
} else {
checkOrtStatus(jniEnv,api,api->DisableMemPattern(options));
}
}
/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: setCPUArenaAllocator
* Signature: (JJZ)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setCPUArenaAllocator
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jboolean useArena) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
if (useArena) {
checkOrtStatus(jniEnv,api,api->EnableCpuMemArena(options));
} else {
checkOrtStatus(jniEnv,api,api->DisableCpuMemArena(options));
}
}
/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: setSessionLogLevel
* Signature: (JJI)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setSessionLogLevel
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jint logLevel) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
checkOrtStatus(jniEnv,api,api->SetSessionLogSeverityLevel(options,logLevel));
}
/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: setSessionLogVerbosityLevel
* Signature: (JJI)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setSessionLogVerbosityLevel
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jint logLevel) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
checkOrtStatus(jniEnv,api,api->SetSessionLogVerbosityLevel(options,logLevel));
}
/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: registerCustomOpLibrary
* Signature: (JJLjava/lang/String;)J
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_registerCustomOpLibrary
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring libraryPath) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
// Extract the string chars
const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, libraryPath, NULL);
// Load the library
void* libraryHandle;
checkOrtStatus(jniEnv,api,api->RegisterCustomOpsLibrary((OrtSessionOptions*)optionsHandle,cPath,&libraryHandle));
// Release the string chars
(*jniEnv)->ReleaseStringUTFChars(jniEnv,libraryPath,cPath);
return (jlong) libraryHandle;
}
/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: closeCustomLibraries
* Signature: ([J)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_closeCustomLibraries
(JNIEnv * jniEnv, jobject jobj, jlongArray libraryHandles) {
(void) jniEnv; (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
// Get the number of elements in the array
jsize numHandles = (*jniEnv)->GetArrayLength(jniEnv, libraryHandles);
// Get the elements of the libraryHandles array
jlong* handles = (*jniEnv)->GetLongArrayElements(jniEnv,libraryHandles,NULL);
// Iterate the handles, calling the appropriate close function
for (jint i = 0; i < numHandles; i++) {
#ifdef WIN32
FreeLibrary((void*)handles[i]);
#else
dlclose((void*)handles[i]);
#endif
}
// Release the long array
(*jniEnv)->ReleaseLongArrayElements(jniEnv,libraryHandles,handles,JNI_ABORT);
}
/*

View file

@ -120,7 +120,7 @@ public class InferenceTest {
String modelPath = getResourcePath("/partial-inputs-test-2.onnx").toString();
try (OrtEnvironment env =
OrtEnvironment.getEnvironment(
OrtEnvironment.LoggingLevel.ORT_LOGGING_LEVEL_FATAL, "partialInputs");
OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "partialInputs");
OrtSession.SessionOptions options = new SessionOptions();
OrtSession session = env.createSession(modelPath, options)) {
assertNotNull(session);
@ -207,7 +207,7 @@ public class InferenceTest {
String modelPath = getResourcePath("/partial-inputs-test.onnx").toString();
try (OrtEnvironment env =
OrtEnvironment.getEnvironment(
OrtEnvironment.LoggingLevel.ORT_LOGGING_LEVEL_FATAL, "partialInputs");
OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "partialInputs");
OrtSession.SessionOptions options = new SessionOptions();
OrtSession session = env.createSession(modelPath, options)) {
assertNotNull(session);
@ -779,6 +779,156 @@ public class InferenceTest {
}
}
@Test
public void testRunOptions() throws OrtException {
// model takes 1x5 input of fixed type, echoes back
String modelPath = getResourcePath("/test_types_BOOL.pb").toString();
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testRunOptions");
SessionOptions options = new SessionOptions();
OrtSession session = env.createSession(modelPath, options);
OrtSession.RunOptions runOptions = new OrtSession.RunOptions()) {
runOptions.setRunTag("monkeys");
assertEquals("monkeys", runOptions.getRunTag());
runOptions.setLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL);
assertEquals(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, runOptions.getLogLevel());
runOptions.setLogVerbosityLevel(9000);
assertEquals(9000, runOptions.getLogVerbosityLevel());
runOptions.setTerminate(true);
String inputName = session.getInputNames().iterator().next();
Map<String, OnnxTensor> container = new HashMap<>();
boolean[] flatInput = new boolean[] {true, false, true, false, true};
Object tensorIn = OrtUtil.reshape(flatInput, new long[] {1, 5});
OnnxTensor ov = OnnxTensor.createTensor(env, tensorIn);
container.put(inputName, ov);
try (OrtSession.Result res = session.run(container, runOptions)) {
fail("Should have terminated.");
} catch (OrtException e) {
assertTrue(e.getMessage().contains("Exiting due to terminate flag being set to true."));
assertEquals(OrtException.OrtErrorCode.ORT_FAIL, e.getCode());
}
OnnxValue.close(container);
}
}
@Test
public void testExtraSessionOptions() throws OrtException, IOException {
// model takes 1x5 input of fixed type, echoes back
String modelPath = getResourcePath("/test_types_BOOL.pb").toString();
File tmpPath = File.createTempFile("onnx-runtime-profiling", "file");
tmpPath.deleteOnExit();
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testExtraSessionOptions")) {
try (SessionOptions options = new SessionOptions()) {
options.setCPUArenaAllocator(true);
options.setMemoryPatternOptimization(true);
options.enableProfiling(tmpPath.getAbsolutePath());
options.setLoggerId("monkeys");
options.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL);
options.setSessionLogVerbosityLevel(5);
try (OrtSession session = env.createSession(modelPath, options)) {
String inputName = session.getInputNames().iterator().next();
Map<String, OnnxTensor> container = new HashMap<>();
boolean[] flatInput = new boolean[] {true, false, true, false, true};
Object tensorIn = OrtUtil.reshape(flatInput, new long[] {1, 5});
OnnxTensor ov = OnnxTensor.createTensor(env, tensorIn);
container.put(inputName, ov);
try (OrtSession.Result res = session.run(container)) {
boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue());
assertArrayEquals(flatInput, resultArray);
}
String profilingOutput = session.endProfiling();
File profilingOutputFile = new File(profilingOutput);
profilingOutputFile.deleteOnExit();
try (OrtSession.Result res = session.run(container)) {
boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue());
assertArrayEquals(flatInput, resultArray);
}
OnnxValue.close(container);
}
}
try (SessionOptions options = new SessionOptions()) {
options.setCPUArenaAllocator(false);
options.setMemoryPatternOptimization(false);
options.enableProfiling(tmpPath.getAbsolutePath());
options.disableProfiling();
options.setSessionLogVerbosityLevel(0);
try (OrtSession session = env.createSession(modelPath, options)) {
String inputName = session.getInputNames().iterator().next();
Map<String, OnnxTensor> container = new HashMap<>();
boolean[] flatInput = new boolean[] {true, false, true, false, true};
Object tensorIn = OrtUtil.reshape(flatInput, new long[] {1, 5});
OnnxTensor ov = OnnxTensor.createTensor(env, tensorIn);
container.put(inputName, ov);
try (OrtSession.Result res = session.run(container)) {
boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue());
assertArrayEquals(flatInput, resultArray);
}
OnnxValue.close(container);
}
}
}
}
@Test
public void testLoadCustomLibrary() throws OrtException {
// This test is disabled on Android.
if (!OnnxRuntime.isAndroid()) {
String customLibraryName = "";
String osName = System.getProperty("os.name").toLowerCase();
if (osName.contains("windows")) {
// In windows we start in the wrong working directory relative to the custom_op_library.dll
// So we look it up as a classpath resource and resolve it to a real path
customLibraryName = getResourcePath("/custom_op_library.dll").toString();
} else if (osName.contains("mac")) {
customLibraryName = "libcustom_op_library.dylib";
} else if (osName.contains("linux")) {
customLibraryName = "./libcustom_op_library.so";
} else {
fail("Unknown os/platform '" + osName + "'");
}
String customOpLibraryTestModel =
getResourcePath("/custom_op_library/custom_op_test.onnx").toString();
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testLoadCustomLibrary");
SessionOptions options = new SessionOptions()) {
options.registerCustomOpLibrary(customLibraryName);
try (OrtSession session = env.createSession(customOpLibraryTestModel, options)) {
Map<String, OnnxTensor> container = new HashMap<>();
// prepare expected inputs and outputs
float[] flatInputOne =
new float[] {
1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f, 9.9f, 10.0f, 11.1f, 12.2f, 13.3f,
14.4f, 15.5f
};
Object tensorIn = OrtUtil.reshape(flatInputOne, new long[] {3, 5});
OnnxTensor ov = OnnxTensor.createTensor(env, tensorIn);
container.put("input_1", ov);
float[] flatInputTwo =
new float[] {
15.5f, 14.4f, 13.3f, 12.2f, 11.1f, 10.0f, 9.9f, 8.8f, 7.7f, 6.6f, 5.5f, 4.4f, 3.3f,
2.2f, 1.1f
};
tensorIn = OrtUtil.reshape(flatInputTwo, new long[] {3, 5});
ov = OnnxTensor.createTensor(env, tensorIn);
container.put("input_2", ov);
int[] flatOutput = new int[] {17, 17, 17, 17, 17, 17, 18, 18, 18, 17, 17, 17, 17, 17, 17};
try (OrtSession.Result res = session.run(container)) {
OnnxTensor outputTensor = (OnnxTensor) res.get(0);
assertArrayEquals(new long[] {3, 5}, outputTensor.getInfo().shape);
int[] resultArray = TestHelpers.flattenInteger(res.get(0).getValue());
assertArrayEquals(flatOutput, resultArray);
}
OnnxValue.close(container);
}
}
}
}
@Test
public void testModelInputBOOL() throws OrtException {
// model takes 1x5 input of fixed type, echoes back