mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
[Java] Adding missing methods on Session, SessionOptions and RunOptions (v2) (#3832)
* java - adding support for custom op libraries. * Adding support for RunOptions and additional methods for SessionOptions and OrtSession. As a result OrtEnvironment.LoggingLevel moved to be a top level enum called OrtLoggingLevel. * java - adding unit tests for RunOptions and SessionOptions. * java - removing unused releaseNamesHandle method * java - add test for custom op library. * java - adding log verbosity methods, and tests for the same. * java - fixes for custom op loading test on Windows. * Cleanup after rebase on master.
This commit is contained in:
parent
d5ec353e58
commit
d38b79c6e5
11 changed files with 942 additions and 72 deletions
|
|
@ -815,6 +815,10 @@ if (onnxruntime_BUILD_JAVA)
|
|||
message(STATUS "Running Java tests")
|
||||
# delegate to gradle's test runner
|
||||
if(WIN32)
|
||||
# If we're on windows, symlink the custom op test library somewhere we can see it
|
||||
set(JAVA_NATIVE_TEST_DIR ${JAVA_OUTPUT_DIR}/native-test)
|
||||
file(MAKE_DIRECTORY ${JAVA_NATIVE_TEST_DIR})
|
||||
add_custom_command(TARGET custom_op_library POST_BUILD COMMAND ${CMAKE_COMMAND} -E create_symlink $<TARGET_FILE:custom_op_library> ${JAVA_NATIVE_TEST_DIR}/$<TARGET_FILE_NAME:custom_op_library>)
|
||||
# On windows ctest requires a test to be an .exe(.com) file
|
||||
# So there are two options 1) Install Chocolatey and its gradle package
|
||||
# That package would install gradle.exe shim to its bin so ctest could run gradle.exe
|
||||
|
|
@ -826,8 +830,8 @@ if (onnxruntime_BUILD_JAVA)
|
|||
-DREPO_ROOT=${REPO_ROOT}
|
||||
-P ${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime_java_unittests.cmake)
|
||||
else()
|
||||
add_test(NAME onnxruntime4j_test COMMAND ${GRADLE_EXECUTABLE} cmakeCheck -DcmakeBuildDir=${CMAKE_CURRENT_BINARY_DIR}
|
||||
WORKING_DIRECTORY ${REPO_ROOT}/java)
|
||||
add_test(NAME onnxruntime4j_test COMMAND ${GRADLE_EXECUTABLE} cmakeCheck -DcmakeBuildDir=${CMAKE_CURRENT_BINARY_DIR}
|
||||
WORKING_DIRECTORY ${REPO_ROOT}/java)
|
||||
endif()
|
||||
set_property(TEST onnxruntime4j_test APPEND PROPERTY DEPENDS onnxruntime4j_jni)
|
||||
endif()
|
||||
|
|
|
|||
|
|
@ -67,6 +67,7 @@ def cmakeBuildDir = System.properties['cmakeBuildDir']
|
|||
def cmakeJavaDir = "${cmakeBuildDir}/java"
|
||||
def cmakeNativeLibDir = "${cmakeJavaDir}/native-lib"
|
||||
def cmakeNativeJniDir = "${cmakeJavaDir}/native-jni"
|
||||
def cmakeNativeTestDir = "${cmakeJavaDir}/native-test"
|
||||
def cmakeBuildOutputDir = "${cmakeJavaDir}/build"
|
||||
|
||||
compileJava {
|
||||
|
|
@ -84,7 +85,8 @@ sourceSets.test {
|
|||
// add compiled native libs
|
||||
resources.srcDirs += [
|
||||
cmakeNativeLibDir,
|
||||
cmakeNativeJniDir
|
||||
cmakeNativeJniDir,
|
||||
cmakeNativeTestDir
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
@ -144,6 +146,9 @@ dependencies {
|
|||
|
||||
test {
|
||||
useJUnitPlatform()
|
||||
if (cmakeBuildDir != null) {
|
||||
workingDir cmakeBuildDir
|
||||
}
|
||||
testLogging {
|
||||
events "passed", "skipped", "failed"
|
||||
showStandardStreams = true
|
||||
|
|
|
|||
|
|
@ -78,7 +78,12 @@ final class OnnxRuntime {
|
|||
}
|
||||
}
|
||||
|
||||
private static boolean isAndroid() {
|
||||
/**
|
||||
* Check if we're running on Android.
|
||||
*
|
||||
* @return True if the {@code android.app.Activity} class can be loaded, false otherwise.
|
||||
*/
|
||||
static boolean isAndroid() {
|
||||
try {
|
||||
Class.forName("android.app.Activity");
|
||||
return true;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2020 Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
|
@ -15,24 +15,6 @@ import java.util.logging.Logger;
|
|||
*/
|
||||
public class OrtEnvironment implements AutoCloseable {
|
||||
|
||||
/** The logging level for messages from the environment and session. */
|
||||
public enum LoggingLevel {
|
||||
ORT_LOGGING_LEVEL_VERBOSE(0),
|
||||
ORT_LOGGING_LEVEL_INFO(1),
|
||||
ORT_LOGGING_LEVEL_WARNING(2),
|
||||
ORT_LOGGING_LEVEL_ERROR(3),
|
||||
ORT_LOGGING_LEVEL_FATAL(4);
|
||||
private final int value;
|
||||
|
||||
LoggingLevel(int value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
public int getValue() {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
private static final Logger logger = Logger.getLogger(OrtEnvironment.class.getName());
|
||||
|
||||
public static final String DEFAULT_NAME = "ort-java";
|
||||
|
|
@ -49,29 +31,29 @@ public class OrtEnvironment implements AutoCloseable {
|
|||
|
||||
private static final AtomicInteger refCount = new AtomicInteger();
|
||||
|
||||
private static volatile LoggingLevel curLogLevel;
|
||||
private static volatile OrtLoggingLevel curLogLevel;
|
||||
|
||||
private static volatile String curLoggingName;
|
||||
|
||||
/**
|
||||
* Gets the OrtEnvironment. If there is not an environment currently created, it creates one using
|
||||
* {@link OrtEnvironment#DEFAULT_NAME} and {@link LoggingLevel#ORT_LOGGING_LEVEL_WARNING}.
|
||||
* {@link OrtEnvironment#DEFAULT_NAME} and {@link OrtLoggingLevel#ORT_LOGGING_LEVEL_WARNING}.
|
||||
*
|
||||
* @return An onnxruntime environment.
|
||||
*/
|
||||
public static OrtEnvironment getEnvironment() {
|
||||
return getEnvironment(LoggingLevel.ORT_LOGGING_LEVEL_WARNING, DEFAULT_NAME);
|
||||
return getEnvironment(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING, DEFAULT_NAME);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the OrtEnvironment. If there is not an environment currently created, it creates one using
|
||||
* the supplied name and {@link LoggingLevel#ORT_LOGGING_LEVEL_WARNING}.
|
||||
* the supplied name and {@link OrtLoggingLevel#ORT_LOGGING_LEVEL_WARNING}.
|
||||
*
|
||||
* @param name The logging id of the environment.
|
||||
* @return An onnxruntime environment.
|
||||
*/
|
||||
public static OrtEnvironment getEnvironment(String name) {
|
||||
return getEnvironment(LoggingLevel.ORT_LOGGING_LEVEL_WARNING, name);
|
||||
return getEnvironment(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING, name);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -81,7 +63,7 @@ public class OrtEnvironment implements AutoCloseable {
|
|||
* @param logLevel The logging level to use.
|
||||
* @return An onnxruntime environment.
|
||||
*/
|
||||
public static OrtEnvironment getEnvironment(LoggingLevel logLevel) {
|
||||
public static OrtEnvironment getEnvironment(OrtLoggingLevel logLevel) {
|
||||
return getEnvironment(logLevel, DEFAULT_NAME);
|
||||
}
|
||||
|
||||
|
|
@ -94,7 +76,8 @@ public class OrtEnvironment implements AutoCloseable {
|
|||
* @param name The log id.
|
||||
* @return The OrtEnvironment singleton.
|
||||
*/
|
||||
public static synchronized OrtEnvironment getEnvironment(LoggingLevel loggingLevel, String name) {
|
||||
public static synchronized OrtEnvironment getEnvironment(
|
||||
OrtLoggingLevel loggingLevel, String name) {
|
||||
if (INSTANCE == null) {
|
||||
try {
|
||||
INSTANCE = new OrtEnvironment(loggingLevel, name);
|
||||
|
|
@ -104,7 +87,7 @@ public class OrtEnvironment implements AutoCloseable {
|
|||
throw new IllegalStateException("Failed to create OrtEnvironment", e);
|
||||
}
|
||||
} else {
|
||||
if ((loggingLevel.value != curLogLevel.value) || (!name.equals(curLoggingName))) {
|
||||
if ((loggingLevel.getValue() != curLogLevel.getValue()) || (!name.equals(curLoggingName))) {
|
||||
logger.warning(
|
||||
"Tried to change OrtEnvironment's logging level or name while a reference exists.");
|
||||
}
|
||||
|
|
@ -125,7 +108,7 @@ public class OrtEnvironment implements AutoCloseable {
|
|||
* @throws OrtException If the environment couldn't be created.
|
||||
*/
|
||||
private OrtEnvironment() throws OrtException {
|
||||
this(LoggingLevel.ORT_LOGGING_LEVEL_WARNING, "java-default");
|
||||
this(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING, "java-default");
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -135,7 +118,7 @@ public class OrtEnvironment implements AutoCloseable {
|
|||
* @param name The logging id of the environment.
|
||||
* @throws OrtException If the environment couldn't be created.
|
||||
*/
|
||||
private OrtEnvironment(LoggingLevel loggingLevel, String name) throws OrtException {
|
||||
private OrtEnvironment(OrtLoggingLevel loggingLevel, String name) throws OrtException {
|
||||
nativeHandle = createHandle(OnnxRuntime.ortApiHandle, loggingLevel.getValue(), name);
|
||||
defaultAllocator = new OrtAllocator(getDefaultAllocator(OnnxRuntime.ortApiHandle), true);
|
||||
}
|
||||
|
|
|
|||
54
java/src/main/java/ai/onnxruntime/OrtLoggingLevel.java
Normal file
54
java/src/main/java/ai/onnxruntime/OrtLoggingLevel.java
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
/*
|
||||
* Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/** The logging level for messages from the environment and session. */
|
||||
public enum OrtLoggingLevel {
|
||||
ORT_LOGGING_LEVEL_VERBOSE(0),
|
||||
ORT_LOGGING_LEVEL_INFO(1),
|
||||
ORT_LOGGING_LEVEL_WARNING(2),
|
||||
ORT_LOGGING_LEVEL_ERROR(3),
|
||||
ORT_LOGGING_LEVEL_FATAL(4);
|
||||
private final int value;
|
||||
|
||||
private static final Logger logger = Logger.getLogger(OrtLoggingLevel.class.getName());
|
||||
private static final OrtLoggingLevel[] values = new OrtLoggingLevel[5];
|
||||
|
||||
static {
|
||||
for (OrtLoggingLevel ot : OrtLoggingLevel.values()) {
|
||||
values[ot.value] = ot;
|
||||
}
|
||||
}
|
||||
|
||||
OrtLoggingLevel(int value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the native value associated with this logging level.
|
||||
*
|
||||
* @return The native value.
|
||||
*/
|
||||
public int getValue() {
|
||||
return value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Maps from the C API's int enum to the Java enum.
|
||||
*
|
||||
* @param logLevel The index of the Java enum.
|
||||
* @return The Java enum.
|
||||
*/
|
||||
public static OrtLoggingLevel mapFromInt(int logLevel) {
|
||||
if ((logLevel > 0) && (logLevel < values.length)) {
|
||||
return values[logLevel];
|
||||
} else {
|
||||
logger.warning("Unknown logging level " + logLevel + " setting to ORT_LOGGING_LEVEL_VERBOSE");
|
||||
return ORT_LOGGING_LEVEL_VERBOSE;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
|
@ -200,6 +200,21 @@ public class OrtSession implements AutoCloseable {
|
|||
return run(inputs, outputNames);
|
||||
}
|
||||
|
||||
/**
|
||||
* Scores an input feed dict, returning the map of all inferred outputs.
|
||||
*
|
||||
* <p>The outputs are sorted based on their id number.
|
||||
*
|
||||
* @param inputs The inputs to score.
|
||||
* @param runOptions The RunOptions to control this run.
|
||||
* @return The inferred outputs.
|
||||
* @throws OrtException If there was an error in native code, the input names are invalid, or if
|
||||
* there are zero or too many inputs.
|
||||
*/
|
||||
public Result run(Map<String, OnnxTensor> inputs, RunOptions runOptions) throws OrtException {
|
||||
return run(inputs, outputNames, runOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Scores an input feed dict, returning the map of requested inferred outputs.
|
||||
*
|
||||
|
|
@ -213,6 +228,24 @@ public class OrtSession implements AutoCloseable {
|
|||
*/
|
||||
public Result run(Map<String, OnnxTensor> inputs, Set<String> requestedOutputs)
|
||||
throws OrtException {
|
||||
return run(inputs, requestedOutputs, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Scores an input feed dict, returning the map of requested inferred outputs.
|
||||
*
|
||||
* <p>The outputs are sorted based on the supplied set traveral order.
|
||||
*
|
||||
* @param inputs The inputs to score.
|
||||
* @param requestedOutputs The requested outputs.
|
||||
* @param runOptions The RunOptions to control this run.
|
||||
* @return The inferred outputs.
|
||||
* @throws OrtException If there was an error in native code, the input or output names are
|
||||
* invalid, or if there are zero or too many inputs or outputs.
|
||||
*/
|
||||
public Result run(
|
||||
Map<String, OnnxTensor> inputs, Set<String> requestedOutputs, RunOptions runOptions)
|
||||
throws OrtException {
|
||||
if (!closed) {
|
||||
if (inputs.isEmpty() || (inputs.size() > numInputs)) {
|
||||
throw new OrtException(
|
||||
|
|
@ -249,6 +282,8 @@ public class OrtSession implements AutoCloseable {
|
|||
"Unknown output name " + s + ", expected one of " + outputNames.toString());
|
||||
}
|
||||
}
|
||||
long runOptionsHandle = runOptions == null ? 0 : runOptions.nativeHandle;
|
||||
|
||||
OnnxValue[] outputValues =
|
||||
run(
|
||||
OnnxRuntime.ortApiHandle,
|
||||
|
|
@ -258,7 +293,8 @@ public class OrtSession implements AutoCloseable {
|
|||
inputHandles,
|
||||
inputNamesArray.length,
|
||||
outputNamesArray,
|
||||
outputNamesArray.length);
|
||||
outputNamesArray.length,
|
||||
runOptionsHandle);
|
||||
return new Result(outputNamesArray, outputValues);
|
||||
} else {
|
||||
throw new IllegalStateException("Trying to score a closed OrtSession.");
|
||||
|
|
@ -268,8 +304,8 @@ public class OrtSession implements AutoCloseable {
|
|||
/**
|
||||
* Gets the metadata for the currently loaded model.
|
||||
*
|
||||
* @throws OrtException on failure
|
||||
* @return The metadata.
|
||||
* @throws OrtException If the native call failed.
|
||||
*/
|
||||
public OnnxModelMetadata getMetadata() throws OrtException {
|
||||
if (metadata == null) {
|
||||
|
|
@ -278,6 +314,19 @@ public class OrtSession implements AutoCloseable {
|
|||
return metadata;
|
||||
}
|
||||
|
||||
/**
|
||||
* Ends the profiling session and returns the output of the profiler.
|
||||
*
|
||||
* <p>Profiling should be enabled in the {@link SessionOptions} used to construct this {@code
|
||||
* Session}.
|
||||
*
|
||||
* @return The profiling output.
|
||||
* @throws OrtException If the native call failed.
|
||||
*/
|
||||
public String endProfiling() throws OrtException {
|
||||
return endProfiling(OnnxRuntime.ortApiHandle, nativeHandle, allocator.handle);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "OrtSession(numInputs=" + numInputs + ",numOutputs=" + numOutputs + ")";
|
||||
|
|
@ -336,6 +385,22 @@ public class OrtSession implements AutoCloseable {
|
|||
private native NodeInfo[] getOutputInfo(long apiHandle, long nativeHandle, long allocatorHandle)
|
||||
throws OrtException;
|
||||
|
||||
/**
|
||||
* The native run call. runOptionsHandle can be zero (i.e. the null pointer), but all other
|
||||
* handles must be valid pointers.
|
||||
*
|
||||
* @param apiHandle The pointer to the api.
|
||||
* @param nativeHandle The pointer to the session.
|
||||
* @param allocatorHandle The pointer to the allocator.
|
||||
* @param inputNamesArray The input names.
|
||||
* @param inputs The input tensors.
|
||||
* @param numInputs The number of inputs.
|
||||
* @param outputNamesArray The requested output names.
|
||||
* @param numOutputs The number of requested outputs.
|
||||
* @param runOptionsHandle The (possibly null) pointer to the run options.
|
||||
* @return The OnnxValues produced by this run.
|
||||
* @throws OrtException If the native call failed in some way.
|
||||
*/
|
||||
private native OnnxValue[] run(
|
||||
long apiHandle,
|
||||
long nativeHandle,
|
||||
|
|
@ -344,7 +409,11 @@ public class OrtSession implements AutoCloseable {
|
|||
long[] inputs,
|
||||
long numInputs,
|
||||
String[] outputNamesArray,
|
||||
long numOutputs)
|
||||
long numOutputs,
|
||||
long runOptionsHandle)
|
||||
throws OrtException;
|
||||
|
||||
private native String endProfiling(long apiHandle, long nativeHandle, long allocatorHandle)
|
||||
throws OrtException;
|
||||
|
||||
private native void closeSession(long apiHandle, long nativeHandle) throws OrtException;
|
||||
|
|
@ -368,6 +437,9 @@ public class OrtSession implements AutoCloseable {
|
|||
* options.
|
||||
*
|
||||
* <p>Modifying this after the session has been constructed will have no effect.
|
||||
*
|
||||
* <p>The SessionOptions object must not be closed until all sessions which use it are closed, as
|
||||
* otherwise it could release resources that are in use.
|
||||
*/
|
||||
public static class SessionOptions implements AutoCloseable {
|
||||
|
||||
|
|
@ -421,15 +493,39 @@ public class OrtSession implements AutoCloseable {
|
|||
|
||||
private final long nativeHandle;
|
||||
|
||||
private final List<Long> customLibraryHandles;
|
||||
|
||||
private boolean closed = false;
|
||||
|
||||
/** Create an empty session options. */
|
||||
public SessionOptions() {
|
||||
nativeHandle = createOptions(OnnxRuntime.ortApiHandle);
|
||||
customLibraryHandles = new ArrayList<>();
|
||||
}
|
||||
|
||||
/** Closes the session options, releasing any memory acquired. */
|
||||
@Override
|
||||
public void close() {
|
||||
closeOptions(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
if (!closed) {
|
||||
if (customLibraryHandles.size() > 0) {
|
||||
long[] longArray = new long[customLibraryHandles.size()];
|
||||
for (int i = 0; i < customLibraryHandles.size(); i++) {
|
||||
longArray[i] = customLibraryHandles.get(i);
|
||||
}
|
||||
closeCustomLibraries(longArray);
|
||||
}
|
||||
closeOptions(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
closed = true;
|
||||
} else {
|
||||
throw new IllegalStateException("Trying to close a closed SessionOptions.");
|
||||
}
|
||||
}
|
||||
|
||||
/** Checks if the SessionOptions is closed, if so throws {@link IllegalStateException}. */
|
||||
private void checkClosed() {
|
||||
if (closed) {
|
||||
throw new IllegalStateException("Trying to use a closed SessionOptions");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -439,6 +535,7 @@ public class OrtSession implements AutoCloseable {
|
|||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void setExecutionMode(ExecutionMode mode) throws OrtException {
|
||||
checkClosed();
|
||||
setExecutionMode(OnnxRuntime.ortApiHandle, nativeHandle, mode.getID());
|
||||
}
|
||||
|
||||
|
|
@ -449,6 +546,7 @@ public class OrtSession implements AutoCloseable {
|
|||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void setOptimizationLevel(OptLevel level) throws OrtException {
|
||||
checkClosed();
|
||||
setOptimizationLevel(OnnxRuntime.ortApiHandle, nativeHandle, level.getID());
|
||||
}
|
||||
|
||||
|
|
@ -460,6 +558,7 @@ public class OrtSession implements AutoCloseable {
|
|||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void setInterOpNumThreads(int numThreads) throws OrtException {
|
||||
checkClosed();
|
||||
setInterOpNumThreads(OnnxRuntime.ortApiHandle, nativeHandle, numThreads);
|
||||
}
|
||||
|
||||
|
|
@ -471,6 +570,7 @@ public class OrtSession implements AutoCloseable {
|
|||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void setIntraOpNumThreads(int numThreads) throws OrtException {
|
||||
checkClosed();
|
||||
setIntraOpNumThreads(OnnxRuntime.ortApiHandle, nativeHandle, numThreads);
|
||||
}
|
||||
|
||||
|
|
@ -481,15 +581,107 @@ public class OrtSession implements AutoCloseable {
|
|||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void setOptimizedModelFilePath(String outputPath) throws OrtException {
|
||||
checkClosed();
|
||||
setOptimizationModelFilePath(OnnxRuntime.ortApiHandle, nativeHandle, outputPath);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the logger id to use.
|
||||
*
|
||||
* @param loggerId The logger id string.
|
||||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void setLoggerId(String loggerId) throws OrtException {
|
||||
checkClosed();
|
||||
setLoggerId(OnnxRuntime.ortApiHandle, nativeHandle, loggerId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Enables profiling in sessions using this SessionOptions.
|
||||
*
|
||||
* @param filePath The file to write profile information to.
|
||||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void enableProfiling(String filePath) throws OrtException {
|
||||
checkClosed();
|
||||
enableProfiling(OnnxRuntime.ortApiHandle, nativeHandle, filePath);
|
||||
}
|
||||
|
||||
/**
|
||||
* Disables profiling in sessions using this SessionOptions.
|
||||
*
|
||||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void disableProfiling() throws OrtException {
|
||||
checkClosed();
|
||||
disableProfiling(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
}
|
||||
|
||||
/**
|
||||
* Turns on memory pattern optimizations, where memory is preallocated if all shapes are known.
|
||||
*
|
||||
* @param memoryPatternOptimization If true enable memory pattern optimizations.
|
||||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void setMemoryPatternOptimization(boolean memoryPatternOptimization)
|
||||
throws OrtException {
|
||||
checkClosed();
|
||||
setMemoryPatternOptimization(
|
||||
OnnxRuntime.ortApiHandle, nativeHandle, memoryPatternOptimization);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the CPU to use an arena memory allocator.
|
||||
*
|
||||
* @param useArena If true use an arena memory allocator for the CPU execution provider.
|
||||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void setCPUArenaAllocator(boolean useArena) throws OrtException {
|
||||
checkClosed();
|
||||
setCPUArenaAllocator(OnnxRuntime.ortApiHandle, nativeHandle, useArena);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the Session's logging level.
|
||||
*
|
||||
* @param logLevel The log level to use.
|
||||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void setSessionLogLevel(OrtLoggingLevel logLevel) throws OrtException {
|
||||
checkClosed();
|
||||
setSessionLogLevel(OnnxRuntime.ortApiHandle, nativeHandle, logLevel.getValue());
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the Session's logging verbosity level.
|
||||
*
|
||||
* @param logLevel The logging verbosity to use.
|
||||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void setSessionLogVerbosityLevel(int logLevel) throws OrtException {
|
||||
checkClosed();
|
||||
setSessionLogVerbosityLevel(OnnxRuntime.ortApiHandle, nativeHandle, logLevel);
|
||||
}
|
||||
|
||||
/**
|
||||
* Registers a library of custom ops for use with {@link OrtSession}s using this SessionOptions.
|
||||
*
|
||||
* @param path The path to the library on disk.
|
||||
* @throws OrtException If there was an error loading the library.
|
||||
*/
|
||||
public void registerCustomOpLibrary(String path) throws OrtException {
|
||||
checkClosed();
|
||||
long customHandle = registerCustomOpLibrary(OnnxRuntime.ortApiHandle, nativeHandle, path);
|
||||
customLibraryHandles.add(customHandle);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add CUDA as an execution backend, using device 0.
|
||||
*
|
||||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void addCUDA() throws OrtException {
|
||||
checkClosed();
|
||||
addCUDA(0);
|
||||
}
|
||||
|
||||
|
|
@ -500,6 +692,7 @@ public class OrtSession implements AutoCloseable {
|
|||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void addCUDA(int deviceNum) throws OrtException {
|
||||
checkClosed();
|
||||
addCUDA(OnnxRuntime.ortApiHandle, nativeHandle, deviceNum);
|
||||
}
|
||||
|
||||
|
|
@ -513,6 +706,7 @@ public class OrtSession implements AutoCloseable {
|
|||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void addCPU(boolean useArena) throws OrtException {
|
||||
checkClosed();
|
||||
addCPU(OnnxRuntime.ortApiHandle, nativeHandle, useArena ? 1 : 0);
|
||||
}
|
||||
|
||||
|
|
@ -523,6 +717,7 @@ public class OrtSession implements AutoCloseable {
|
|||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void addDnnl(boolean useArena) throws OrtException {
|
||||
checkClosed();
|
||||
addDnnl(OnnxRuntime.ortApiHandle, nativeHandle, useArena ? 1 : 0);
|
||||
}
|
||||
|
||||
|
|
@ -535,6 +730,7 @@ public class OrtSession implements AutoCloseable {
|
|||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void addNGraph(String ngBackendType) throws OrtException {
|
||||
checkClosed();
|
||||
addNGraph(OnnxRuntime.ortApiHandle, nativeHandle, ngBackendType);
|
||||
}
|
||||
|
||||
|
|
@ -545,6 +741,7 @@ public class OrtSession implements AutoCloseable {
|
|||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void addOpenVINO(String deviceId) throws OrtException {
|
||||
checkClosed();
|
||||
addOpenVINO(OnnxRuntime.ortApiHandle, nativeHandle, deviceId);
|
||||
}
|
||||
|
||||
|
|
@ -555,6 +752,7 @@ public class OrtSession implements AutoCloseable {
|
|||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void addTensorrt(int deviceNum) throws OrtException {
|
||||
checkClosed();
|
||||
addTensorrt(OnnxRuntime.ortApiHandle, nativeHandle, deviceNum);
|
||||
}
|
||||
|
||||
|
|
@ -564,6 +762,7 @@ public class OrtSession implements AutoCloseable {
|
|||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void addNnapi() throws OrtException {
|
||||
checkClosed();
|
||||
addNnapi(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
}
|
||||
|
||||
|
|
@ -575,6 +774,7 @@ public class OrtSession implements AutoCloseable {
|
|||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void addNuphar(boolean allowUnalignedBuffers, String settings) throws OrtException {
|
||||
checkClosed();
|
||||
addNuphar(OnnxRuntime.ortApiHandle, nativeHandle, allowUnalignedBuffers ? 1 : 0, settings);
|
||||
}
|
||||
|
||||
|
|
@ -585,6 +785,7 @@ public class OrtSession implements AutoCloseable {
|
|||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void addDirectML(int deviceId) throws OrtException {
|
||||
checkClosed();
|
||||
addDirectML(OnnxRuntime.ortApiHandle, nativeHandle, deviceId);
|
||||
}
|
||||
|
||||
|
|
@ -595,6 +796,7 @@ public class OrtSession implements AutoCloseable {
|
|||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void addACL(boolean useArena) throws OrtException {
|
||||
checkClosed();
|
||||
addACL(OnnxRuntime.ortApiHandle, nativeHandle, useArena ? 1 : 0);
|
||||
}
|
||||
|
||||
|
|
@ -615,6 +817,31 @@ public class OrtSession implements AutoCloseable {
|
|||
|
||||
private native long createOptions(long apiHandle);
|
||||
|
||||
private native void setLoggerId(long apiHandle, long nativeHandle, String loggerId)
|
||||
throws OrtException;
|
||||
|
||||
private native void enableProfiling(long apiHandle, long nativeHandle, String filePrefix)
|
||||
throws OrtException;
|
||||
|
||||
private native void disableProfiling(long apiHandle, long nativeHandle) throws OrtException;
|
||||
|
||||
private native void setMemoryPatternOptimization(
|
||||
long apiHandle, long nativeHandle, boolean memoryPatternOptimization) throws OrtException;
|
||||
|
||||
private native void setCPUArenaAllocator(long apiHandle, long nativeHandle, boolean useArena)
|
||||
throws OrtException;
|
||||
|
||||
private native void setSessionLogLevel(long apiHandle, long nativeHandle, int logLevel)
|
||||
throws OrtException;
|
||||
|
||||
private native void setSessionLogVerbosityLevel(long apiHandle, long nativeHandle, int logLevel)
|
||||
throws OrtException;
|
||||
|
||||
private native long registerCustomOpLibrary(long apiHandle, long nativeHandle, String path)
|
||||
throws OrtException;
|
||||
|
||||
private native void closeCustomLibraries(long[] nativeHandle);
|
||||
|
||||
private native void closeOptions(long apiHandle, long nativeHandle);
|
||||
|
||||
/*
|
||||
|
|
@ -658,6 +885,141 @@ public class OrtSession implements AutoCloseable {
|
|||
private native void addACL(long apiHandle, long nativeHandle, int useArena) throws OrtException;
|
||||
}
|
||||
|
||||
/** Used to control logging and termination of a call to {@link OrtSession#run}. */
|
||||
public static class RunOptions implements AutoCloseable {
|
||||
|
||||
private final long nativeHandle;
|
||||
|
||||
private boolean closed = false;
|
||||
|
||||
/**
|
||||
* Creates a RunOptions.
|
||||
*
|
||||
* @throws OrtException If the construction of the native RunOptions failed.
|
||||
*/
|
||||
public RunOptions() throws OrtException {
|
||||
this.nativeHandle = createRunOptions(OnnxRuntime.ortApiHandle);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the current logging level on this RunOptions.
|
||||
*
|
||||
* @param level The new logging level.
|
||||
* @throws OrtException If the native call failed.
|
||||
*/
|
||||
public void setLogLevel(OrtLoggingLevel level) throws OrtException {
|
||||
checkClosed();
|
||||
setLogLevel(OnnxRuntime.ortApiHandle, nativeHandle, level.getValue());
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the current logging level set on this RunOptions.
|
||||
*
|
||||
* @return The logging level.
|
||||
* @throws OrtException If the native call failed.
|
||||
*/
|
||||
public OrtLoggingLevel getLogLevel() throws OrtException {
|
||||
checkClosed();
|
||||
return OrtLoggingLevel.mapFromInt(getLogLevel(OnnxRuntime.ortApiHandle, nativeHandle));
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the current logging verbosity level on this RunOptions.
|
||||
*
|
||||
* @param level The new logging verbosity level.
|
||||
* @throws OrtException If the native call failed.
|
||||
*/
|
||||
public void setLogVerbosityLevel(int level) throws OrtException {
|
||||
checkClosed();
|
||||
setLogVerbosityLevel(OnnxRuntime.ortApiHandle, nativeHandle, level);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the current logging verbosity level set on this RunOptions.
|
||||
*
|
||||
* @return The logging verbosity level.
|
||||
* @throws OrtException If the native call failed.
|
||||
*/
|
||||
public int getLogVerbosityLevel() throws OrtException {
|
||||
checkClosed();
|
||||
return getLogVerbosityLevel(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the run tag used in logging.
|
||||
*
|
||||
* @param runTag The run tag in logging output.
|
||||
* @throws OrtException If the native library call failed.
|
||||
*/
|
||||
public void setRunTag(String runTag) throws OrtException {
|
||||
checkClosed();
|
||||
setRunTag(OnnxRuntime.ortApiHandle, nativeHandle, runTag);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the String used to log information about this run.
|
||||
*
|
||||
* @return The run tag.
|
||||
* @throws OrtException If the native library call failed.
|
||||
*/
|
||||
public String getRunTag() throws OrtException {
|
||||
checkClosed();
|
||||
return getRunTag(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets a flag so that all incomplete {@link OrtSession#run} calls using this instance of {@code
|
||||
* RunOptions} will terminate as soon as possible. If the flag is false, it resets this {@code
|
||||
* RunOptions} so it can be used with other calls to {@link OrtSession#run}.
|
||||
*
|
||||
* @param terminate If true terminate all runs associated with this RunOptions.
|
||||
* @throws OrtException If the native library call failed.
|
||||
*/
|
||||
public void setTerminate(boolean terminate) throws OrtException {
|
||||
checkClosed();
|
||||
setTerminate(OnnxRuntime.ortApiHandle, nativeHandle, terminate);
|
||||
}
|
||||
|
||||
/** Checks if the RunOptions is closed, if so throws {@link IllegalStateException}. */
|
||||
private void checkClosed() {
|
||||
if (closed) {
|
||||
throw new IllegalStateException("Trying to use a closed RunOptions");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
if (!closed) {
|
||||
close(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
closed = true;
|
||||
} else {
|
||||
throw new IllegalStateException("Trying to close an already closed RunOptions");
|
||||
}
|
||||
}
|
||||
|
||||
private static native long createRunOptions(long apiHandle) throws OrtException;
|
||||
|
||||
private native void setLogLevel(long apiHandle, long nativeHandle, int logLevel)
|
||||
throws OrtException;
|
||||
|
||||
private native int getLogLevel(long apiHandle, long nativeHandle) throws OrtException;
|
||||
|
||||
private native void setLogVerbosityLevel(long apiHandle, long nativeHandle, int logLevel)
|
||||
throws OrtException;
|
||||
|
||||
private native int getLogVerbosityLevel(long apiHandle, long nativeHandle) throws OrtException;
|
||||
|
||||
private native void setRunTag(long apiHandle, long nativeHandle, String runTag)
|
||||
throws OrtException;
|
||||
|
||||
private native String getRunTag(long apiHandle, long nativeHandle) throws OrtException;
|
||||
|
||||
private native void setTerminate(long apiHandle, long nativeHandle, boolean terminate)
|
||||
throws OrtException;
|
||||
|
||||
private static native void close(long apiHandle, long nativeHandle);
|
||||
}
|
||||
|
||||
/**
|
||||
* An {@link AutoCloseable} wrapper around a {@link Map} containing {@link OnnxValue}s.
|
||||
*
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2020 Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
#include <jni.h>
|
||||
|
|
@ -16,7 +16,7 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
|
|||
}
|
||||
|
||||
/**
|
||||
* Must be kept in sync with ORT_LOGGING_LEVEL and OrtEnvironment#LoggingLevel
|
||||
* Must be kept in sync with ORT_LOGGING_LEVEL and the OrtLoggingLevel java enum
|
||||
*/
|
||||
OrtLoggingLevel convertLoggingLevel(jint level) {
|
||||
switch (level) {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2020 Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
#include <jni.h>
|
||||
|
|
@ -19,9 +19,8 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_la
|
|||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtSession* session;
|
||||
|
||||
jboolean copy;
|
||||
#ifdef _WIN32
|
||||
const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, modelPath, ©);
|
||||
const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, modelPath, NULL);
|
||||
size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, modelPath);
|
||||
wchar_t* newString = (wchar_t*)calloc(stringLength+1,sizeof(jchar));
|
||||
wcsncpy_s(newString, stringLength+1, (const wchar_t*) cPath, stringLength);
|
||||
|
|
@ -29,7 +28,7 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_la
|
|||
free(newString);
|
||||
(*jniEnv)->ReleaseStringChars(jniEnv,modelPath,cPath);
|
||||
#else
|
||||
const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, modelPath, ©);
|
||||
const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, modelPath, NULL);
|
||||
checkOrtStatus(jniEnv,api,api->CreateSession((OrtEnv*)envHandle, cPath, (OrtSessionOptions*)optsHandle, &session));
|
||||
(*jniEnv)->ReleaseStringUTFChars(jniEnv,modelPath,cPath);
|
||||
#endif
|
||||
|
|
@ -236,14 +235,16 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getOutputInfo
|
|||
/*
|
||||
* Class: ai_onnxruntime_OrtSession
|
||||
* Method: run
|
||||
* Signature: (JJJ[Ljava/lang/String;[JJ[Ljava/lang/String;J)[Lai/onnxruntime/OnnxValue;
|
||||
* Signature: (JJJ[Ljava/lang/String;[JJ[Ljava/lang/String;JJ)[Lai/onnxruntime/OnnxValue;
|
||||
* private native OnnxValue[] run(long apiHandle, long nativeHandle, long allocatorHandle, String[] inputNamesArray, long[] inputs, long numInputs, String[] outputNamesArray, long numOutputs)
|
||||
*/
|
||||
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle, jobjectArray inputNamesArr, jlongArray tensorArr, jlong numInputs, jobjectArray outputNamesArr, jlong numOutputs) {
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle, jobjectArray inputNamesArr, jlongArray tensorArr, jlong numInputs, jobjectArray outputNamesArr, jlong numOutputs, jlong runOptionsHandle) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
|
||||
OrtSession* session = (OrtSession*) sessionHandle;
|
||||
OrtRunOptions* runOptions = (OrtRunOptions*) runOptionsHandle;
|
||||
|
||||
// Create the buffers for the Java input and output strings
|
||||
const char** inputNames;
|
||||
|
|
@ -276,7 +277,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run
|
|||
// Actually score the inputs.
|
||||
//printf("inputTensors = %p, first tensor = %p, numInputs = %ld, outputValues = %p, numOutputs = %ld\n",inputTensors,(OrtValue*)inputTensors[0],numInputs,outputValues,numOutputs);
|
||||
//ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess, _In_ OrtRunOptions* run_options, _In_ const char* const* input_names, _In_ const OrtValue* const* input, size_t input_len, _In_ const char* const* output_names, size_t output_names_len, _Out_ OrtValue** output);
|
||||
checkOrtStatus(jniEnv,api,api->Run((OrtSession*)sessionHandle, NULL, (const char* const*) inputNames, (const OrtValue* const*) inputTensors, numInputs, (const char* const*) outputNames, numOutputs, outputValues));
|
||||
checkOrtStatus(jniEnv,api,api->Run(session, runOptions, (const char* const*) inputNames, (const OrtValue* const*) inputTensors, numInputs, (const char* const*) outputNames, numOutputs, outputValues));
|
||||
// Release the C array of pointers to the tensors.
|
||||
(*jniEnv)->ReleaseLongArrayElements(jniEnv,tensorArr,inputTensors,JNI_ABORT);
|
||||
|
||||
|
|
@ -309,6 +310,24 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run
|
|||
return outputArray;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession
|
||||
* Method: endProfiling
|
||||
* Signature: (JJJ)Ljava/lang/String;
|
||||
*/
|
||||
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_endProfiling
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
|
||||
|
||||
char* profileStr;
|
||||
checkOrtStatus(jniEnv,api,api->SessionEndProfiling((OrtSession*)handle,allocator,&profileStr));
|
||||
jstring profileOutput = (*jniEnv)->NewStringUTF(jniEnv,profileStr);
|
||||
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,profileStr));
|
||||
return profileOutput;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession
|
||||
* Method: closeSession
|
||||
|
|
@ -321,23 +340,6 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_closeSession
|
|||
api->ReleaseSession((OrtSession*)handle);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession
|
||||
* Method: releaseNamesHandle
|
||||
* Signature: (JJJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_releaseNamesHandle
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong allocatorHandle, jlong namesHandle, jlong length) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
|
||||
char** names = (char**) namesHandle;
|
||||
for (uint32_t i = 0; i < length; i++) {
|
||||
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,names[i]));
|
||||
}
|
||||
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,names));
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession
|
||||
* Method: constructMetadata
|
||||
|
|
|
|||
134
java/src/main/native/ai_onnxruntime_OrtSession_RunOptions.c
Normal file
134
java/src/main/native/ai_onnxruntime_OrtSession_RunOptions.c
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
#include <jni.h>
|
||||
#include <string.h>
|
||||
#include "onnxruntime/core/session/onnxruntime_c_api.h"
|
||||
#include "OrtJniUtil.h"
|
||||
#include "ai_onnxruntime_OrtSession_RunOptions.h"
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_RunOptions
|
||||
* Method: createRunOptions
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_createRunOptions
|
||||
(JNIEnv * jniEnv, jclass jclazz, jlong apiHandle) {
|
||||
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtRunOptions* opts;
|
||||
checkOrtStatus(jniEnv,api,api->CreateRunOptions(&opts));
|
||||
return (jlong) opts;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_RunOptions
|
||||
* Method: setLogLevel
|
||||
* Signature: (JJI)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_setLogLevel
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle, jint logLevel) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
checkOrtStatus(jniEnv,api,api->RunOptionsSetRunLogSeverityLevel((OrtRunOptions*) nativeHandle,logLevel));
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_RunOptions
|
||||
* Method: getLogLevel
|
||||
* Signature: (JJ)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_getLogLevel
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
jint logLevel;
|
||||
checkOrtStatus(jniEnv,api,api->RunOptionsGetRunLogSeverityLevel((OrtRunOptions*) nativeHandle,&logLevel));
|
||||
return logLevel;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_RunOptions
|
||||
* Method: setLogVerbosityLevel
|
||||
* Signature: (JJI)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_setLogVerbosityLevel
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle, jint logLevel) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
checkOrtStatus(jniEnv,api,api->RunOptionsSetRunLogVerbosityLevel((OrtRunOptions*) nativeHandle,logLevel));
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_RunOptions
|
||||
* Method: getLogVerbosityLevel
|
||||
* Signature: (JJ)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_getLogVerbosityLevel
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
jint logLevel;
|
||||
checkOrtStatus(jniEnv,api,api->RunOptionsGetRunLogVerbosityLevel((OrtRunOptions*) nativeHandle,&logLevel));
|
||||
return logLevel;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_RunOptions
|
||||
* Method: setRunTag
|
||||
* Signature: (JJLjava/lang/String;)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_setRunTag
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle, jstring runTag) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
const char* runTagStr = (*jniEnv)->GetStringUTFChars(jniEnv, runTag, NULL);
|
||||
checkOrtStatus(jniEnv,api,api->RunOptionsSetRunTag((OrtRunOptions*) nativeHandle, runTagStr));
|
||||
(*jniEnv)->ReleaseStringUTFChars(jniEnv,runTag,runTagStr);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_RunOptions
|
||||
* Method: getRunTag
|
||||
* Signature: (JJ)Ljava/lang/String;
|
||||
*/
|
||||
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_getRunTag
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
const char* runTagStr;
|
||||
// This is a reference to the C str, and should not be freed.
|
||||
checkOrtStatus(jniEnv,api,api->RunOptionsGetRunTag((OrtRunOptions*)nativeHandle,&runTagStr));
|
||||
jstring runTag = (*jniEnv)->NewStringUTF(jniEnv,runTagStr);
|
||||
return runTag;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_RunOptions
|
||||
* Method: setTerminate
|
||||
* Signature: (JJZ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_setTerminate
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle, jboolean terminate) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtRunOptions* runOptions = (OrtRunOptions*) nativeHandle;
|
||||
if (terminate) {
|
||||
checkOrtStatus(jniEnv,api,api->RunOptionsSetTerminate(runOptions));
|
||||
} else {
|
||||
checkOrtStatus(jniEnv,api,api->RunOptionsUnsetTerminate(runOptions));
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_RunOptions
|
||||
* Method: close
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_close
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
|
||||
(void) jniEnv; (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
api->ReleaseRunOptions((OrtRunOptions*) handle);
|
||||
}
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2020 Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
#include <jni.h>
|
||||
|
|
@ -7,6 +7,11 @@
|
|||
#include "onnxruntime/core/session/onnxruntime_c_api.h"
|
||||
#include "OrtJniUtil.h"
|
||||
#include "ai_onnxruntime_OrtSession_SessionOptions.h"
|
||||
#ifdef WIN32
|
||||
#include <Windows.h>
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
#endif
|
||||
|
||||
// Providers
|
||||
#include "onnxruntime/core/providers/cpu/cpu_provider_factory.h"
|
||||
|
|
@ -107,7 +112,9 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_creat
|
|||
OrtSessionOptions* opts;
|
||||
checkOrtStatus(jniEnv,api,api->CreateSessionOptions(&opts));
|
||||
checkOrtStatus(jniEnv,api,api->SetInterOpNumThreads(opts, 1));
|
||||
checkOrtStatus(jniEnv,api,api->SetIntraOpNumThreads(opts, 1));
|
||||
// Commented out due to constant OpenMP warning as this API is invalid when running with OpenMP.
|
||||
// Not sure how to detect that from within the C API though.
|
||||
//checkOrtStatus(jniEnv,api,api->SetIntraOpNumThreads(opts, 1));
|
||||
return (jlong) opts;
|
||||
}
|
||||
|
||||
|
|
@ -117,10 +124,174 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_creat
|
|||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_closeOptions
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
|
||||
(void) jniEnv; (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
api->ReleaseSessionOptions((OrtSessionOptions*) handle);
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
|
||||
(void)jniEnv; (void)jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*)apiHandle;
|
||||
api->ReleaseSessionOptions((OrtSessionOptions*)handle);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_SessionOptions
|
||||
* Method: setLoggerId
|
||||
* Signature: (JJLjava/lang/String;)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setLoggerId
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring loggerId) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*)apiHandle;
|
||||
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
|
||||
const char* loggerIdStr = (*jniEnv)->GetStringUTFChars(jniEnv, loggerId, NULL);
|
||||
checkOrtStatus(jniEnv,api,api->SetSessionLogId(options, loggerIdStr));
|
||||
(*jniEnv)->ReleaseStringUTFChars(jniEnv,loggerId,loggerIdStr);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_SessionOptions
|
||||
* Method: enableProfiling
|
||||
* Signature: (JJLjava/lang/String;)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_enableProfiling
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring pathString) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*)apiHandle;
|
||||
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
|
||||
#ifdef _WIN32
|
||||
const jchar* path = (*jniEnv)->GetStringChars(jniEnv, pathString, NULL);
|
||||
size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, pathString);
|
||||
wchar_t* newString = (wchar_t*)calloc(stringLength+1,sizeof(jchar));
|
||||
wcsncpy_s(newString, stringLength+1, (const wchar_t*) path, stringLength);
|
||||
checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,api->EnableProfiling(options, (const wchar_t*) newString));
|
||||
free(newString);
|
||||
(*jniEnv)->ReleaseStringChars(jniEnv,pathString,path);
|
||||
#else
|
||||
const char* path = (*jniEnv)->GetStringUTFChars(jniEnv, pathString, NULL);
|
||||
checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,api->EnableProfiling(options, path));
|
||||
(*jniEnv)->ReleaseStringUTFChars(jniEnv,pathString,path);
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_SessionOptions
|
||||
* Method: disableProfiling
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_disableProfiling
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*)apiHandle;
|
||||
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
|
||||
checkOrtStatus(jniEnv,api,api->DisableProfiling(options));
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_SessionOptions
|
||||
* Method: setMemoryPatternOptimization
|
||||
* Signature: (JJZ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setMemoryPatternOptimization
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jboolean memPattern) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*)apiHandle;
|
||||
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
|
||||
if (memPattern) {
|
||||
checkOrtStatus(jniEnv,api,api->EnableMemPattern(options));
|
||||
} else {
|
||||
checkOrtStatus(jniEnv,api,api->DisableMemPattern(options));
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_SessionOptions
|
||||
* Method: setCPUArenaAllocator
|
||||
* Signature: (JJZ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setCPUArenaAllocator
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jboolean useArena) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*)apiHandle;
|
||||
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
|
||||
if (useArena) {
|
||||
checkOrtStatus(jniEnv,api,api->EnableCpuMemArena(options));
|
||||
} else {
|
||||
checkOrtStatus(jniEnv,api,api->DisableCpuMemArena(options));
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_SessionOptions
|
||||
* Method: setSessionLogLevel
|
||||
* Signature: (JJI)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setSessionLogLevel
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jint logLevel) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*)apiHandle;
|
||||
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
|
||||
checkOrtStatus(jniEnv,api,api->SetSessionLogSeverityLevel(options,logLevel));
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_SessionOptions
|
||||
* Method: setSessionLogVerbosityLevel
|
||||
* Signature: (JJI)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setSessionLogVerbosityLevel
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jint logLevel) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*)apiHandle;
|
||||
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
|
||||
checkOrtStatus(jniEnv,api,api->SetSessionLogVerbosityLevel(options,logLevel));
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_SessionOptions
|
||||
* Method: registerCustomOpLibrary
|
||||
* Signature: (JJLjava/lang/String;)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_registerCustomOpLibrary
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring libraryPath) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
|
||||
// Extract the string chars
|
||||
const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, libraryPath, NULL);
|
||||
|
||||
// Load the library
|
||||
void* libraryHandle;
|
||||
checkOrtStatus(jniEnv,api,api->RegisterCustomOpsLibrary((OrtSessionOptions*)optionsHandle,cPath,&libraryHandle));
|
||||
|
||||
// Release the string chars
|
||||
(*jniEnv)->ReleaseStringUTFChars(jniEnv,libraryPath,cPath);
|
||||
|
||||
return (jlong) libraryHandle;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_SessionOptions
|
||||
* Method: closeCustomLibraries
|
||||
* Signature: ([J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_closeCustomLibraries
|
||||
(JNIEnv * jniEnv, jobject jobj, jlongArray libraryHandles) {
|
||||
(void) jniEnv; (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
|
||||
// Get the number of elements in the array
|
||||
jsize numHandles = (*jniEnv)->GetArrayLength(jniEnv, libraryHandles);
|
||||
|
||||
// Get the elements of the libraryHandles array
|
||||
jlong* handles = (*jniEnv)->GetLongArrayElements(jniEnv,libraryHandles,NULL);
|
||||
|
||||
// Iterate the handles, calling the appropriate close function
|
||||
for (jint i = 0; i < numHandles; i++) {
|
||||
#ifdef WIN32
|
||||
FreeLibrary((void*)handles[i]);
|
||||
#else
|
||||
dlclose((void*)handles[i]);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Release the long array
|
||||
(*jniEnv)->ReleaseLongArrayElements(jniEnv,libraryHandles,handles,JNI_ABORT);
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ public class InferenceTest {
|
|||
String modelPath = getResourcePath("/partial-inputs-test-2.onnx").toString();
|
||||
try (OrtEnvironment env =
|
||||
OrtEnvironment.getEnvironment(
|
||||
OrtEnvironment.LoggingLevel.ORT_LOGGING_LEVEL_FATAL, "partialInputs");
|
||||
OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "partialInputs");
|
||||
OrtSession.SessionOptions options = new SessionOptions();
|
||||
OrtSession session = env.createSession(modelPath, options)) {
|
||||
assertNotNull(session);
|
||||
|
|
@ -207,7 +207,7 @@ public class InferenceTest {
|
|||
String modelPath = getResourcePath("/partial-inputs-test.onnx").toString();
|
||||
try (OrtEnvironment env =
|
||||
OrtEnvironment.getEnvironment(
|
||||
OrtEnvironment.LoggingLevel.ORT_LOGGING_LEVEL_FATAL, "partialInputs");
|
||||
OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "partialInputs");
|
||||
OrtSession.SessionOptions options = new SessionOptions();
|
||||
OrtSession session = env.createSession(modelPath, options)) {
|
||||
assertNotNull(session);
|
||||
|
|
@ -779,6 +779,156 @@ public class InferenceTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRunOptions() throws OrtException {
|
||||
// model takes 1x5 input of fixed type, echoes back
|
||||
String modelPath = getResourcePath("/test_types_BOOL.pb").toString();
|
||||
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testRunOptions");
|
||||
SessionOptions options = new SessionOptions();
|
||||
OrtSession session = env.createSession(modelPath, options);
|
||||
OrtSession.RunOptions runOptions = new OrtSession.RunOptions()) {
|
||||
runOptions.setRunTag("monkeys");
|
||||
assertEquals("monkeys", runOptions.getRunTag());
|
||||
runOptions.setLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL);
|
||||
assertEquals(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, runOptions.getLogLevel());
|
||||
runOptions.setLogVerbosityLevel(9000);
|
||||
assertEquals(9000, runOptions.getLogVerbosityLevel());
|
||||
runOptions.setTerminate(true);
|
||||
String inputName = session.getInputNames().iterator().next();
|
||||
Map<String, OnnxTensor> container = new HashMap<>();
|
||||
boolean[] flatInput = new boolean[] {true, false, true, false, true};
|
||||
Object tensorIn = OrtUtil.reshape(flatInput, new long[] {1, 5});
|
||||
OnnxTensor ov = OnnxTensor.createTensor(env, tensorIn);
|
||||
container.put(inputName, ov);
|
||||
try (OrtSession.Result res = session.run(container, runOptions)) {
|
||||
fail("Should have terminated.");
|
||||
} catch (OrtException e) {
|
||||
assertTrue(e.getMessage().contains("Exiting due to terminate flag being set to true."));
|
||||
assertEquals(OrtException.OrtErrorCode.ORT_FAIL, e.getCode());
|
||||
}
|
||||
OnnxValue.close(container);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testExtraSessionOptions() throws OrtException, IOException {
|
||||
// model takes 1x5 input of fixed type, echoes back
|
||||
String modelPath = getResourcePath("/test_types_BOOL.pb").toString();
|
||||
File tmpPath = File.createTempFile("onnx-runtime-profiling", "file");
|
||||
tmpPath.deleteOnExit();
|
||||
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testExtraSessionOptions")) {
|
||||
try (SessionOptions options = new SessionOptions()) {
|
||||
options.setCPUArenaAllocator(true);
|
||||
options.setMemoryPatternOptimization(true);
|
||||
options.enableProfiling(tmpPath.getAbsolutePath());
|
||||
options.setLoggerId("monkeys");
|
||||
options.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL);
|
||||
options.setSessionLogVerbosityLevel(5);
|
||||
try (OrtSession session = env.createSession(modelPath, options)) {
|
||||
String inputName = session.getInputNames().iterator().next();
|
||||
Map<String, OnnxTensor> container = new HashMap<>();
|
||||
boolean[] flatInput = new boolean[] {true, false, true, false, true};
|
||||
Object tensorIn = OrtUtil.reshape(flatInput, new long[] {1, 5});
|
||||
OnnxTensor ov = OnnxTensor.createTensor(env, tensorIn);
|
||||
container.put(inputName, ov);
|
||||
try (OrtSession.Result res = session.run(container)) {
|
||||
boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue());
|
||||
assertArrayEquals(flatInput, resultArray);
|
||||
}
|
||||
String profilingOutput = session.endProfiling();
|
||||
File profilingOutputFile = new File(profilingOutput);
|
||||
profilingOutputFile.deleteOnExit();
|
||||
try (OrtSession.Result res = session.run(container)) {
|
||||
boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue());
|
||||
assertArrayEquals(flatInput, resultArray);
|
||||
}
|
||||
OnnxValue.close(container);
|
||||
}
|
||||
}
|
||||
try (SessionOptions options = new SessionOptions()) {
|
||||
options.setCPUArenaAllocator(false);
|
||||
options.setMemoryPatternOptimization(false);
|
||||
options.enableProfiling(tmpPath.getAbsolutePath());
|
||||
options.disableProfiling();
|
||||
options.setSessionLogVerbosityLevel(0);
|
||||
try (OrtSession session = env.createSession(modelPath, options)) {
|
||||
String inputName = session.getInputNames().iterator().next();
|
||||
Map<String, OnnxTensor> container = new HashMap<>();
|
||||
boolean[] flatInput = new boolean[] {true, false, true, false, true};
|
||||
Object tensorIn = OrtUtil.reshape(flatInput, new long[] {1, 5});
|
||||
OnnxTensor ov = OnnxTensor.createTensor(env, tensorIn);
|
||||
container.put(inputName, ov);
|
||||
try (OrtSession.Result res = session.run(container)) {
|
||||
boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue());
|
||||
assertArrayEquals(flatInput, resultArray);
|
||||
}
|
||||
OnnxValue.close(container);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLoadCustomLibrary() throws OrtException {
|
||||
// This test is disabled on Android.
|
||||
if (!OnnxRuntime.isAndroid()) {
|
||||
String customLibraryName = "";
|
||||
String osName = System.getProperty("os.name").toLowerCase();
|
||||
if (osName.contains("windows")) {
|
||||
// In windows we start in the wrong working directory relative to the custom_op_library.dll
|
||||
// So we look it up as a classpath resource and resolve it to a real path
|
||||
customLibraryName = getResourcePath("/custom_op_library.dll").toString();
|
||||
} else if (osName.contains("mac")) {
|
||||
customLibraryName = "libcustom_op_library.dylib";
|
||||
} else if (osName.contains("linux")) {
|
||||
customLibraryName = "./libcustom_op_library.so";
|
||||
} else {
|
||||
fail("Unknown os/platform '" + osName + "'");
|
||||
}
|
||||
String customOpLibraryTestModel =
|
||||
getResourcePath("/custom_op_library/custom_op_test.onnx").toString();
|
||||
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testLoadCustomLibrary");
|
||||
SessionOptions options = new SessionOptions()) {
|
||||
options.registerCustomOpLibrary(customLibraryName);
|
||||
try (OrtSession session = env.createSession(customOpLibraryTestModel, options)) {
|
||||
Map<String, OnnxTensor> container = new HashMap<>();
|
||||
|
||||
// prepare expected inputs and outputs
|
||||
float[] flatInputOne =
|
||||
new float[] {
|
||||
1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f, 9.9f, 10.0f, 11.1f, 12.2f, 13.3f,
|
||||
14.4f, 15.5f
|
||||
};
|
||||
Object tensorIn = OrtUtil.reshape(flatInputOne, new long[] {3, 5});
|
||||
OnnxTensor ov = OnnxTensor.createTensor(env, tensorIn);
|
||||
container.put("input_1", ov);
|
||||
|
||||
float[] flatInputTwo =
|
||||
new float[] {
|
||||
15.5f, 14.4f, 13.3f, 12.2f, 11.1f, 10.0f, 9.9f, 8.8f, 7.7f, 6.6f, 5.5f, 4.4f, 3.3f,
|
||||
2.2f, 1.1f
|
||||
};
|
||||
tensorIn = OrtUtil.reshape(flatInputTwo, new long[] {3, 5});
|
||||
ov = OnnxTensor.createTensor(env, tensorIn);
|
||||
container.put("input_2", ov);
|
||||
|
||||
int[] flatOutput = new int[] {17, 17, 17, 17, 17, 17, 18, 18, 18, 17, 17, 17, 17, 17, 17};
|
||||
|
||||
try (OrtSession.Result res = session.run(container)) {
|
||||
OnnxTensor outputTensor = (OnnxTensor) res.get(0);
|
||||
assertArrayEquals(new long[] {3, 5}, outputTensor.getInfo().shape);
|
||||
int[] resultArray = TestHelpers.flattenInteger(res.get(0).getValue());
|
||||
assertArrayEquals(flatOutput, resultArray);
|
||||
}
|
||||
OnnxValue.close(container);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testModelInputBOOL() throws OrtException {
|
||||
// model takes 1x5 input of fixed type, echoes back
|
||||
|
|
|
|||
Loading…
Reference in a new issue