mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-19 21:32:23 +00:00
Java build system enhancements (#2866)
This commit is contained in:
parent
ecdcd682bb
commit
411b3aa801
14 changed files with 414 additions and 300 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -41,6 +41,8 @@ onnxprofile_profile_test_*.json
|
|||
/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.props
|
||||
cmake/external/FeaturizersLibrary/
|
||||
# Java specific ignores
|
||||
java/src/main/native/ai_onnxruntime_*.h
|
||||
java/gradlew
|
||||
java/gradlew.bat
|
||||
java/gradle
|
||||
java/.gradle
|
||||
|
||||
|
|
|
|||
|
|
@ -12,99 +12,61 @@ include_directories(${JNI_INCLUDE_DIRS})
|
|||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=c11")
|
||||
|
||||
set(JAVA_ROOT ${REPO_ROOT}/java)
|
||||
set(CMAKE_JAVA_COMPILE_FLAGS "-source" "1.8" "-target" "1.8" "-encoding" "UTF-8")
|
||||
set(JAVA_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/java)
|
||||
if (onnxruntime_RUN_ONNX_TESTS)
|
||||
set(JAVA_DEPENDS onnxruntime ${test_data_target})
|
||||
else()
|
||||
set(JAVA_DEPENDS onnxruntime)
|
||||
endif()
|
||||
|
||||
# use the gradle wrapper if it exists
|
||||
if(EXISTS "${JAVA_ROOT}/gradlew")
|
||||
set(GRADLE_EXECUTABLE "${JAVA_ROOT}/gradlew")
|
||||
else()
|
||||
# fall back to gradle on our PATH
|
||||
find_program(GRADLE_EXECUTABLE gradle)
|
||||
if(NOT GRADLE_EXECUTABLE)
|
||||
message(SEND_ERROR "Gradle installation not found")
|
||||
endif()
|
||||
endif()
|
||||
message(STATUS "Using gradle: ${GRADLE_EXECUTABLE}")
|
||||
|
||||
# Specify the Java source files
|
||||
set(onnxruntime4j_src
|
||||
${REPO_ROOT}/java/src/main/java/ai/onnxruntime/MapInfo.java
|
||||
${REPO_ROOT}/java/src/main/java/ai/onnxruntime/NodeInfo.java
|
||||
${REPO_ROOT}/java/src/main/java/ai/onnxruntime/OnnxRuntime.java
|
||||
${REPO_ROOT}/java/src/main/java/ai/onnxruntime/OnnxJavaType.java
|
||||
${REPO_ROOT}/java/src/main/java/ai/onnxruntime/OnnxMap.java
|
||||
${REPO_ROOT}/java/src/main/java/ai/onnxruntime/OnnxSequence.java
|
||||
${REPO_ROOT}/java/src/main/java/ai/onnxruntime/OnnxTensor.java
|
||||
${REPO_ROOT}/java/src/main/java/ai/onnxruntime/OnnxValue.java
|
||||
${REPO_ROOT}/java/src/main/java/ai/onnxruntime/OrtAllocator.java
|
||||
${REPO_ROOT}/java/src/main/java/ai/onnxruntime/OrtEnvironment.java
|
||||
${REPO_ROOT}/java/src/main/java/ai/onnxruntime/OrtException.java
|
||||
${REPO_ROOT}/java/src/main/java/ai/onnxruntime/OrtSession.java
|
||||
${REPO_ROOT}/java/src/main/java/ai/onnxruntime/OrtUtil.java
|
||||
${REPO_ROOT}/java/src/main/java/ai/onnxruntime/package-info.java
|
||||
${REPO_ROOT}/java/src/main/java/ai/onnxruntime/SequenceInfo.java
|
||||
${REPO_ROOT}/java/src/main/java/ai/onnxruntime/TensorInfo.java
|
||||
${REPO_ROOT}/java/src/main/java/ai/onnxruntime/ValueInfo.java
|
||||
)
|
||||
file(GLOB_RECURSE onnxruntime4j_gradle_files "${JAVA_ROOT}/*.gradle")
|
||||
file(GLOB_RECURSE onnxruntime4j_src "${JAVA_ROOT}/src/main/java/ai/onnxruntime/*.java")
|
||||
set(JAVA_OUTPUT_JAR ${JAVA_ROOT}/build/libs/onnxruntime.jar)
|
||||
# this jar is solely used to signalling mechanism for dependency management in CMake
|
||||
# if any of the Java sources change, the jar (and generated headers) will be regenerated and the onnxruntime4j_jni target will be rebuilt
|
||||
add_custom_command(OUTPUT ${JAVA_OUTPUT_JAR} COMMAND ${GRADLE_EXECUTABLE} clean jar WORKING_DIRECTORY ${JAVA_ROOT} DEPENDS ${onnxruntime4j_gradle_files} ${onnxruntime4j_src})
|
||||
add_custom_target(onnxruntime4j DEPENDS ${JAVA_OUTPUT_JAR})
|
||||
set_source_files_properties(${JAVA_OUTPUT_JAR} PROPERTIES GENERATED TRUE)
|
||||
set_property(TARGET onnxruntime4j APPEND PROPERTY ADDITIONAL_CLEAN_FILES "${JAVA_OUTPUT_DIR}")
|
||||
|
||||
# Build the jar and generate the native headers
|
||||
add_jar(onnxruntime4j SOURCES ${onnxruntime4j_src} VERSION ${ORT_VERSION} GENERATE_NATIVE_HEADERS onnxruntime4j_generated DESTINATION ${REPO_ROOT}/java/src/main/native/)
|
||||
|
||||
# Specify the native sources (without the generated headers)
|
||||
# Specify the native sources
|
||||
file(GLOB onnxruntime4j_native_src
|
||||
"${REPO_ROOT}/java/src/main/native/*.c"
|
||||
"${REPO_ROOT}/java/src/main/native/OrtJniUtil.h"
|
||||
"${JAVA_ROOT}/src/main/native/*.c"
|
||||
"${JAVA_ROOT}/src/main/native/*.h"
|
||||
"${REPO_ROOT}/include/onnxruntime/core/session/*.h"
|
||||
)
|
||||
|
||||
# Build the JNI library
|
||||
add_library(onnxruntime4j_jni SHARED ${onnxruntime4j_native_src} ${onnxruntime4j_generated})
|
||||
add_library(onnxruntime4j_jni SHARED ${onnxruntime4j_native_src})
|
||||
# depend on java sources. if they change, the JNI should recompile
|
||||
add_dependencies(onnxruntime4j_jni onnxruntime4j)
|
||||
onnxruntime_add_include_to_target(onnxruntime4j_jni onnxruntime_session)
|
||||
target_include_directories(onnxruntime4j_jni PRIVATE ${REPO_ROOT}/include ${REPO_ROOT}/java/src/main/native)
|
||||
target_link_libraries(onnxruntime4j_jni PUBLIC onnxruntime onnxruntime4j_generated)
|
||||
# the JNI headers are generated in the onnxruntime4j target
|
||||
target_include_directories(onnxruntime4j_jni PRIVATE ${REPO_ROOT}/include ${JAVA_ROOT}/build/headers)
|
||||
target_link_libraries(onnxruntime4j_jni PUBLIC onnxruntime)
|
||||
|
||||
# Now the jar, jni binary and shared lib binary have been built, now to build the jar with the binaries added.
|
||||
|
||||
# This blob creates the new jar name
|
||||
get_property(onnxruntime_jar_name TARGET onnxruntime4j PROPERTY JAR_FILE)
|
||||
get_filename_component(onnxruntime_jar_abs ${onnxruntime_jar_name} ABSOLUTE)
|
||||
get_filename_component(jar_path ${onnxruntime_jar_abs} DIRECTORY)
|
||||
set(onnxruntime_jar_binaries_name "${jar_path}/onnxruntime4j-${ORT_VERSION}-with-binaries.jar")
|
||||
set(onnxruntime_jar_binaries_platform "$<SHELL_PATH:${onnxruntime_jar_binaries_name}>")
|
||||
|
||||
# Copy the current jar
|
||||
add_custom_command(TARGET onnxruntime4j_jni PRE_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_jar_name}
|
||||
${onnxruntime_jar_binaries_platform})
|
||||
|
||||
# Make a temp directory to store the binaries
|
||||
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_CURRENT_BINARY_DIR}/java-libs/lib")
|
||||
|
||||
# Copy the binaries
|
||||
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy "$<TARGET_FILE:onnxruntime4j_jni>" ${CMAKE_CURRENT_BINARY_DIR}/java-libs/lib/)
|
||||
|
||||
if (WIN32)
|
||||
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy "$<TARGET_FILE:onnxruntime>" ${CMAKE_CURRENT_BINARY_DIR}/java-libs/lib/)
|
||||
# Update the with-binaries jar so it includes the binaries
|
||||
add_custom_command(
|
||||
TARGET onnxruntime4j_jni POST_BUILD
|
||||
COMMAND ${Java_JAR_EXECUTABLE} -uf ${onnxruntime_jar_binaries_platform} -C ${CMAKE_CURRENT_BINARY_DIR}/java-libs lib/$<TARGET_FILE_NAME:onnxruntime4j_jni> -C ${CMAKE_CURRENT_BINARY_DIR}/java-libs lib/$<TARGET_FILE_NAME:onnxruntime>
|
||||
DEPENDS onnxruntime4j
|
||||
COMMENT "Rebuilding Java archive ${_JAVA_TARGET_OUTPUT_NAME}"
|
||||
VERBATIM
|
||||
)
|
||||
else ()
|
||||
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy "$<TARGET_LINKER_FILE:onnxruntime>" ${CMAKE_CURRENT_BINARY_DIR}/java-libs/lib/)
|
||||
# Update the with-binaries jar so it includes the binaries
|
||||
add_custom_command(
|
||||
TARGET onnxruntime4j_jni POST_BUILD
|
||||
COMMAND ${Java_JAR_EXECUTABLE} -uf ${onnxruntime_jar_binaries_platform} -C ${CMAKE_CURRENT_BINARY_DIR}/java-libs lib/$<TARGET_FILE_NAME:onnxruntime4j_jni> -C ${CMAKE_CURRENT_BINARY_DIR}/java-libs lib/$<TARGET_LINKER_FILE_NAME:onnxruntime>
|
||||
DEPENDS onnxruntime4j
|
||||
COMMENT "Rebuilding Java archive ${_JAVA_TARGET_OUTPUT_NAME}"
|
||||
VERBATIM
|
||||
)
|
||||
endif()
|
||||
|
||||
create_javadoc(onnxruntime4j_javadoc
|
||||
FILES ${onnxruntime4j_src}
|
||||
DOCTITLE "Onnx Runtime Java API"
|
||||
WINDOWTITLE "OnnxRuntime-Java-API"
|
||||
AUTHOR FALSE
|
||||
USE TRUE
|
||||
VERSION FALSE
|
||||
)
|
||||
# expose native libraries to the gradle build process
|
||||
file(MAKE_DIRECTORY ${JAVA_OUTPUT_DIR}/build)
|
||||
set(JAVA_PACKAGE_DIR ai/onnxruntime/native/)
|
||||
set(JAVA_NATIVE_LIB_DIR ${JAVA_OUTPUT_DIR}/native-lib)
|
||||
set(JAVA_NATIVE_JNI_DIR ${JAVA_OUTPUT_DIR}/native-jni)
|
||||
set(JAVA_PACKAGE_LIB_DIR ${JAVA_NATIVE_LIB_DIR}/${JAVA_PACKAGE_DIR})
|
||||
set(JAVA_PACKAGE_JNI_DIR ${JAVA_NATIVE_JNI_DIR}/${JAVA_PACKAGE_DIR})
|
||||
file(MAKE_DIRECTORY ${JAVA_PACKAGE_LIB_DIR})
|
||||
file(MAKE_DIRECTORY ${JAVA_PACKAGE_JNI_DIR})
|
||||
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E create_symlink $<TARGET_FILE:onnxruntime> ${JAVA_PACKAGE_LIB_DIR}/$<TARGET_LINKER_FILE_NAME:onnxruntime>)
|
||||
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E create_symlink $<TARGET_FILE:onnxruntime4j_jni> ${JAVA_PACKAGE_JNI_DIR}/$<TARGET_LINKER_FILE_NAME:onnxruntime4j_jni>)
|
||||
# run the build process (this copies the results back into CMAKE_CURRENT_BINARY_DIR)
|
||||
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${GRADLE_EXECUTABLE} cmakeBuild -DcmakeBuildDir=${CMAKE_CURRENT_BINARY_DIR} WORKING_DIRECTORY ${JAVA_ROOT})
|
||||
|
|
@ -691,35 +691,7 @@ set_property(TARGET custom_op_library APPEND_STRING PROPERTY LINK_FLAGS ${ONNXRU
|
|||
|
||||
if (onnxruntime_BUILD_JAVA)
|
||||
message(STATUS "Running Java tests")
|
||||
# Build and run tests
|
||||
set(onnxruntime4j_test_src
|
||||
${REPO_ROOT}/java/src/test/java/ai/onnxruntime/InferenceTest.java
|
||||
${REPO_ROOT}/java/src/test/java/ai/onnxruntime/TestHelpers.java
|
||||
${REPO_ROOT}/java/src/test/java/ai/onnxruntime/OnnxMl.java
|
||||
${REPO_ROOT}/java/src/test/java/ai/onnxruntime/UtilTest.java
|
||||
)
|
||||
|
||||
# Create test directories
|
||||
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/java-tests/")
|
||||
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/java-tests/results")
|
||||
|
||||
# Download test dependencies
|
||||
if (NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/java-tests/junit-platform-console-standalone-1.5.2.jar)
|
||||
message("Downloading JUnit 5")
|
||||
file(DOWNLOAD https://repo1.maven.org/maven2/org/junit/platform/junit-platform-console-standalone/1.5.2/junit-platform-console-standalone-1.5.2.jar ${CMAKE_CURRENT_BINARY_DIR}/java-tests/junit-platform-console-standalone-1.5.2.jar EXPECTED_HASH SHA1=8d937d2b461018a876836362b256629f4da5feb1)
|
||||
endif()
|
||||
|
||||
if (NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/java-tests/protobuf-java-3.10.0.jar)
|
||||
message("Downloading protobuf-java 3.10.0")
|
||||
file(DOWNLOAD https://repo1.maven.org/maven2/com/google/protobuf/protobuf-java/3.10.0/protobuf-java-3.10.0.jar ${CMAKE_CURRENT_BINARY_DIR}/java-tests/protobuf-java-3.10.0.jar EXPECTED_HASH SHA1=410b61dd0088aab4caa05739558d43df248958c9)
|
||||
endif()
|
||||
|
||||
# Build the test jar
|
||||
add_jar(onnxruntime4j_test SOURCES ${onnxruntime4j_test_src} VERSION ${ORT_VERSION} INCLUDE_JARS ${onnxruntime_jar_name} ${CMAKE_CURRENT_BINARY_DIR}/java-tests/junit-platform-console-standalone-1.5.2.jar ${CMAKE_CURRENT_BINARY_DIR}/java-tests/protobuf-java-3.10.0.jar)
|
||||
|
||||
add_dependencies(onnxruntime4j_test onnxruntime4j_jni onnxruntime4j)
|
||||
get_property(onnxruntime_test_jar_name TARGET onnxruntime4j_test PROPERTY JAR_FILE)
|
||||
|
||||
# Run the tests with JUnit's console launcher
|
||||
add_test(NAME java-api COMMAND ${Java_JAVA_EXECUTABLE} -jar ${CMAKE_CURRENT_BINARY_DIR}/java-tests/junit-platform-console-standalone-1.5.2.jar -cp ${CMAKE_CURRENT_BINARY_DIR}/java-tests/protobuf-java-3.10.0.jar -cp ${onnxruntime_test_jar_name} -cp ${onnxruntime_jar_binaries_platform} --scan-class-path --fail-if-no-tests --reports-dir=${CMAKE_CURRENT_BINARY_DIR}/java-tests/results --disable-banner WORKING_DIRECTORY ${REPO_ROOT})
|
||||
# delegate to gradle's test runner
|
||||
add_test(NAME onnxruntime4j_test COMMAND ${GRADLE_EXECUTABLE} cmakeCheck -DcmakeBuildDir=${CMAKE_CURRENT_BINARY_DIR} WORKING_DIRECTORY ${REPO_ROOT}/java)
|
||||
set_property(TEST onnxruntime4j_test APPEND PROPERTY DEPENDS onnxruntime4j_jni)
|
||||
endif()
|
||||
|
|
|
|||
80
java/README.md
Normal file
80
java/README.md
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
# ONNX Runtime Java API
|
||||
|
||||
This directory contains the Java language binding for the ONNX runtime.
|
||||
Java Native Interface (JNI) is used to allow for seamless calls to ONNX runtime from Java.
|
||||
|
||||
## Usage
|
||||
|
||||
TBD: maven distribution
|
||||
|
||||
The minimum supported Java Runtime is version 8.
|
||||
|
||||
An example implementation is located in [src/test/java/sample/ScoreMNIST.java](src/test/java/sample/ScoreMNIST.java)
|
||||
|
||||
This project can be built manually using the instructions below.
|
||||
|
||||
### Building
|
||||
|
||||
Use the main project's [build instructions](../BUILD.md) with the `--build_java` option.
|
||||
|
||||
#### Requirements
|
||||
|
||||
JDK version 8 or later is required.
|
||||
The [Gradle](https://gradle.org/) build system is required and used here to manage the Java project's dependency management, compilation, testing, and assembly.
|
||||
You may use your system Gradle installation installed on your PATH.
|
||||
Version 6 or newer is recommended.
|
||||
Optionally, you may use your own Gradle [wrapper](https://docs.gradle.org/current/userguide/gradle_wrapper.html) which will be locked to a version specified in the `build.gradle` configuration.
|
||||
This can be done once by using system Gradle installation to invoke the wrapper task in the java project's directory: `cd REPO_ROOT/java && gradle wrapper`
|
||||
Any installed wrapper is gitignored.
|
||||
|
||||
#### Build Output
|
||||
|
||||
The build will generate output in `$REPO_ROOT/build/$OS/$CONFIGURATION/java/build`:
|
||||
|
||||
* `docs/javadoc/` - HTML javadoc
|
||||
* `reports/` - detailed test results and other reports
|
||||
* `libs/onnxruntime.jar` - JAR with classes, depends on `onnxruntime-jni.jar` and `onnxruntime-lib.jar `
|
||||
* `libs/onnxruntime-jni.jar`- JAR with JNI shared library
|
||||
* `libs/onnxruntime-lib.jar` - JAR with onnxruntime shared library
|
||||
* `libs/onnxruntime-all.jar` - the 3 preceding jars all combined: JAR with classes, JNI shared library, and onnxruntime shared library
|
||||
|
||||
The reason the shared libraries are split out like that is that users can mix and match to suit their use case:
|
||||
|
||||
* To support a single OS/Architecture without any dependencies, use `libs/onnxruntime-all.jar`.
|
||||
* To support cross-platform: bundle a single `libs/onnxruntime.jar` and with all of the respective `libs/onnxruntime-jni.jar` and `libs/onnxruntime-lib.jar` for all of the desired OS/Architectures.
|
||||
* To support use case where an onnxruntime shared library will reside in the system's library search path: bundle a single `libs/onnxruntime.jar` and with all of the `libs/onnxruntime-jni.jar`. The onnxruntime shared library should be loaded using one of the other methods described in the "Advanced Loading" section below.
|
||||
|
||||
#### Build System Overview
|
||||
|
||||
The main CMake build system delegates building and testing to Gradle.
|
||||
This allows the CMake system to ensure all of the C/C++ compilation is achieved prior to the Java build.
|
||||
The Java build depends on C/C++ onnxruntime shared library and a C JNI shared library (source located in the `src/main/native` directory).
|
||||
The JNI shared library is the glue that allows for Java to call functions in onnxruntime shared library.
|
||||
Given the fact that CMake injects native dependencies during CMake builds, some gradle tasks (primarily, `build`, `test`, and `check`) may fail.
|
||||
|
||||
When running the build script, CMake will compile the `onnxruntime` target and the JNI glue `onnxruntime4j_jni` target and expose the resulting libraries in a place where Gradle can ingest them.
|
||||
Upon successful compilation of those targets, a special Gradle task to build will be executed. The results will be placed in the output directory stated above.
|
||||
|
||||
### Advanced Loading
|
||||
|
||||
The default behavior is to load the shared libraries using classpath resources.
|
||||
If your use case requires custom loading of the shared libraries, please consult the javadoc in the [package-info.java](src/main/java/ai/onnxruntime/package-info.java) or [OnnxRuntime.java](src/main/java/ai/onnxruntime/OnnxRuntime.java) files.
|
||||
|
||||
## Development
|
||||
|
||||
### Code Formatting
|
||||
|
||||
[Spotless](https://github.com/diffplug/spotless/tree/master/plugin-gradle) is used to keep the code properly formatted.
|
||||
Gradle's `spotlessCheck` task will show any misformatted code.
|
||||
Gradle's `spotlessApply` task will try to fix the formatting.
|
||||
Misformatted code will raise failures when checks are ran during test run.
|
||||
|
||||
### JNI Headers
|
||||
|
||||
When adding or updating native methods in the Java files, it may be necessary to examine the relevant JNI headers in `build/headers/ai_onnxruntime*.h`.
|
||||
These files can be manually generated using Gradle's `compileJava` task which will compile the Java and update the header files accordingly.
|
||||
Then the corresponding C files in `./src/main/native/ai_onnxruntime*.c` may be updated and the build can be ran.
|
||||
|
||||
### Dependencies
|
||||
|
||||
The Java API does not have any runtime or compile dependencies currently.
|
||||
117
java/build.gradle
Normal file
117
java/build.gradle
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
plugins {
|
||||
id 'java'
|
||||
id 'com.diffplug.gradle.spotless' version '3.26.0'
|
||||
}
|
||||
|
||||
allprojects {
|
||||
repositories {
|
||||
mavenCentral()
|
||||
}
|
||||
}
|
||||
|
||||
java {
|
||||
sourceCompatibility = JavaVersion.VERSION_1_8
|
||||
targetCompatibility = JavaVersion.VERSION_1_8
|
||||
withJavadocJar()
|
||||
withSourcesJar()
|
||||
}
|
||||
|
||||
wrapper {
|
||||
gradleVersion = '6.1.1'
|
||||
}
|
||||
|
||||
spotless {
|
||||
//java {
|
||||
// removeUnusedImports()
|
||||
// googleJavaFormat()
|
||||
//}
|
||||
format 'gradle', {
|
||||
target '**/*.gradle'
|
||||
trimTrailingWhitespace()
|
||||
indentWithTabs()
|
||||
}
|
||||
}
|
||||
|
||||
// cmake runs will inform us of the build directory of the current run
|
||||
def cmakeBuildDir = System.properties['cmakeBuildDir']
|
||||
def cmakeJavaDir = "${cmakeBuildDir}/java"
|
||||
def cmakeNativeLibDir = "${cmakeJavaDir}/native-lib"
|
||||
def cmakeNativeJniDir = "${cmakeJavaDir}/native-jni"
|
||||
def cmakeBuildOutputDir = "${cmakeJavaDir}/build"
|
||||
|
||||
compileJava {
|
||||
options.compilerArgs += ["-h", "${project.buildDir}/headers/"]
|
||||
}
|
||||
|
||||
sourceSets.test {
|
||||
// add test resource files
|
||||
resources.srcDirs += [
|
||||
"${rootProject.projectDir}/../csharp/testdata",
|
||||
"${rootProject.projectDir}/../onnxruntime/test/testdata"
|
||||
]
|
||||
if (cmakeBuildDir != null) {
|
||||
// add compiled native libs
|
||||
resources.srcDirs += [
|
||||
cmakeNativeLibDir,
|
||||
cmakeNativeJniDir
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
if (cmakeBuildDir != null) {
|
||||
// generate tasks to be called from cmake
|
||||
|
||||
task jniJar(type: Jar) {
|
||||
classifier = 'jni'
|
||||
from cmakeNativeJniDir
|
||||
}
|
||||
|
||||
task libJar(type: Jar) {
|
||||
classifier = 'lib'
|
||||
from cmakeNativeLibDir
|
||||
}
|
||||
|
||||
task allJar(type: Jar) {
|
||||
classifier = 'all'
|
||||
from sourceSets.main.output
|
||||
from cmakeNativeJniDir
|
||||
from cmakeNativeLibDir
|
||||
}
|
||||
|
||||
task cmakeBuild(type: Copy) {
|
||||
from project.buildDir
|
||||
include 'libs/**'
|
||||
include 'docs/**'
|
||||
into cmakeBuildOutputDir
|
||||
}
|
||||
cmakeBuild.dependsOn jar
|
||||
cmakeBuild.dependsOn jniJar
|
||||
cmakeBuild.dependsOn libJar
|
||||
cmakeBuild.dependsOn allJar
|
||||
cmakeBuild.dependsOn sourcesJar
|
||||
cmakeBuild.dependsOn javadocJar
|
||||
cmakeBuild.dependsOn javadoc
|
||||
|
||||
|
||||
task cmakeCheck(type: Copy) {
|
||||
from project.buildDir
|
||||
include 'reports/**'
|
||||
into cmakeBuildOutputDir
|
||||
}
|
||||
cmakeCheck.dependsOn check
|
||||
|
||||
}
|
||||
|
||||
|
||||
dependencies {
|
||||
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.1.1'
|
||||
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.1.1'
|
||||
testImplementation 'com.google.protobuf:protobuf-java:3.10.0'
|
||||
}
|
||||
|
||||
test {
|
||||
useJUnitPlatform()
|
||||
testLogging {
|
||||
events "passed", "skipped", "failed"
|
||||
}
|
||||
}
|
||||
1
java/settings.gradle
Normal file
1
java/settings.gradle
Normal file
|
|
@ -0,0 +1 @@
|
|||
rootProject.name = 'onnxruntime'
|
||||
|
|
@ -5,175 +5,135 @@
|
|||
package ai.onnxruntime;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Properties;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.logging.Level;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* Static loader for the JNI binding.
|
||||
*/
|
||||
/** Static loader for the JNI binding. No public API, but called from various classes in this package to ensure shared libraries are properly loaded. */
|
||||
final class OnnxRuntime {
|
||||
private static final Logger logger = Logger.getLogger(OnnxRuntime.class.getName());
|
||||
private static final Logger logger = Logger.getLogger(OnnxRuntime.class.getName());
|
||||
|
||||
// The initial release of the ORT API.
|
||||
private static final int ORT_API_VERSION_1 = 1;
|
||||
// The initial release of the ORT API.
|
||||
private static final int ORT_API_VERSION_1 = 1;
|
||||
|
||||
/**
|
||||
* Turns on debug logging during library loading.
|
||||
*/
|
||||
public static final String LIBRARY_LOAD_LOGGING = "ORT_LOAD_LOGGING";
|
||||
/** The short name of the ONNX runtime shared library */
|
||||
static final String ONNXRUNTIME_LIBRARY_NAME = "onnxruntime";
|
||||
/** The short name of the ONNX runtime JNI shared library */
|
||||
static final String ONNXRUNTIME_JNI_LIBRARY_NAME = "onnxruntime4j_jni";
|
||||
|
||||
/**
|
||||
* Specifies that the libraries should be loaded from java.library.path rather than unzipped from the jar file.
|
||||
*/
|
||||
public static final String LOAD_LIBRARY_PATH = "ORT_LOAD_FROM_LIBRARY_PATH";
|
||||
private static boolean loaded = false;
|
||||
|
||||
private static boolean loaded = false;
|
||||
/** The API handle. */
|
||||
static long ortApiHandle;
|
||||
|
||||
/**
|
||||
* The API handle.
|
||||
*/
|
||||
static long ortApiHandle;
|
||||
private OnnxRuntime() {}
|
||||
|
||||
/**
|
||||
* Library names stored in the jar.
|
||||
*/
|
||||
private static final List<String> libraryNames = Arrays.asList("onnxruntime","onnxruntime4j_jni");
|
||||
/**
|
||||
* Loads the native C library.
|
||||
*
|
||||
* @throws IOException If it can't write to disk to copy out the library from the jar file.
|
||||
*/
|
||||
static synchronized void init() throws IOException {
|
||||
if (loaded) {
|
||||
return;
|
||||
}
|
||||
Path tempDirectory = Files.createTempDirectory("onnxruntime-java");
|
||||
try {
|
||||
load(tempDirectory, ONNXRUNTIME_LIBRARY_NAME);
|
||||
load(tempDirectory, ONNXRUNTIME_JNI_LIBRARY_NAME);
|
||||
ortApiHandle = initialiseAPIBase(ORT_API_VERSION_1);
|
||||
loaded = true;
|
||||
} finally {
|
||||
cleanUp(tempDirectory.toFile());
|
||||
}
|
||||
}
|
||||
|
||||
private OnnxRuntime() {}
|
||||
/**
|
||||
* Attempt to remove a file and then mark for delete on exit if it cannot be deleted at this point in time.
|
||||
*
|
||||
* @param file The file to remove.
|
||||
*/
|
||||
private static void cleanUp(File file) {
|
||||
if (!file.exists()) {
|
||||
return;
|
||||
}
|
||||
logger.log(Level.FINE, "Deleting " + file);
|
||||
if (!file.delete()) {
|
||||
logger.log(Level.FINE, "Deleting " + file + " on exit");
|
||||
file.deleteOnExit();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads the native C library.
|
||||
* @throws IOException If it can't write to disk to copy out the library from the jar file.
|
||||
*/
|
||||
static synchronized void init() throws IOException {
|
||||
if (!loaded) {
|
||||
// Check system properties for load time configuration.
|
||||
Properties props = System.getProperties();
|
||||
boolean debug = props.containsKey(LIBRARY_LOAD_LOGGING);
|
||||
boolean loadLibraryPath = props.containsKey(LOAD_LIBRARY_PATH);
|
||||
if (loadLibraryPath) {
|
||||
if (debug) {
|
||||
logger.info("Loading from java.library.path");
|
||||
}
|
||||
try {
|
||||
for (String libraryName : libraryNames) {
|
||||
if (debug) {
|
||||
logger.info("Loading " + libraryName + " from java.library.path");
|
||||
}
|
||||
System.loadLibrary(libraryName);
|
||||
}
|
||||
} catch (UnsatisfiedLinkError e) {
|
||||
logger.log(Level.SEVERE, "Failed to load onnx-runtime library from library path.");
|
||||
throw e;
|
||||
}
|
||||
} else {
|
||||
if (debug) {
|
||||
logger.info("Loading from classpath resource");
|
||||
}
|
||||
try {
|
||||
for (String libraryName : libraryNames) {
|
||||
try {
|
||||
// This code path is used during testing.
|
||||
String libraryFromJar = "/" + System.mapLibraryName(libraryName);
|
||||
if (debug) {
|
||||
logger.info("Attempting to load library from classpath using " + libraryFromJar);
|
||||
}
|
||||
String tempLibraryPath = createTempFileFromResource(libraryFromJar, debug);
|
||||
if (debug) {
|
||||
logger.info("Copied resource " + libraryFromJar + " to location " + tempLibraryPath);
|
||||
}
|
||||
System.load(tempLibraryPath);
|
||||
} catch (Exception e) {
|
||||
if (debug) {
|
||||
logger.info("Failed to load from testing location, looking for /lib/<library-name>");
|
||||
}
|
||||
String libraryFromJar = "/lib/" + System.mapLibraryName(libraryName);
|
||||
if (debug) {
|
||||
logger.info("Attempting to load library from classpath using " + libraryFromJar);
|
||||
}
|
||||
String tempLibraryPath = createTempFileFromResource(libraryFromJar, debug);
|
||||
if (debug) {
|
||||
logger.info("Copied resource " + libraryFromJar + " to location " + tempLibraryPath);
|
||||
}
|
||||
System.load(tempLibraryPath);
|
||||
}
|
||||
}
|
||||
} catch (IOException e) {
|
||||
logger.log(Level.SEVERE, "Failed to load onnx-runtime library from jar");
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
ortApiHandle = initialiseAPIBase(ORT_API_VERSION_1);
|
||||
loaded = true;
|
||||
}
|
||||
/**
|
||||
* Load a shared library by name.
|
||||
*
|
||||
* @param tempDirectory The temp directory to write the library resource to.
|
||||
* @param library The bare name of the library.
|
||||
* @throws IOException If the file failed to read or write.
|
||||
*/
|
||||
private static void load(Path tempDirectory, String library) throws IOException {
|
||||
// 1) The user may skip loading of this library:
|
||||
String skip = System.getProperty("onnxruntime.native." + library + ".skip");
|
||||
if (Boolean.TRUE.toString().equalsIgnoreCase(skip)) {
|
||||
logger.log(Level.FINE, "Skipping load of native library '" + library + "'");
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Copies out the named file from the class path into a temporary directory so it can be loaded
|
||||
* by {@link System#load}.
|
||||
* <p>
|
||||
* The file is marked delete on exit. Throws {@link IllegalArgumentException} if the
|
||||
* supplied path is not absolute.
|
||||
* @param path The path to the file in the classpath.
|
||||
* @param debugLogging If true turn on debug logging.
|
||||
* @return The path to the extracted file on disk.
|
||||
* @throws IOException If the file failed to read or write.
|
||||
*/
|
||||
private static String createTempFileFromResource(String path, boolean debugLogging) throws IOException {
|
||||
if (!path.startsWith("/")) {
|
||||
throw new IllegalArgumentException("The path has to be absolute (start with '/').");
|
||||
} else {
|
||||
String[] parts = path.split("/");
|
||||
String filename = parts.length > 1 ? parts[parts.length - 1] : null;
|
||||
String prefix = "";
|
||||
String suffix = null;
|
||||
if (filename != null) {
|
||||
parts = filename.split("\\.", 2);
|
||||
prefix = parts[0];
|
||||
suffix = parts.length > 1 ? "." + parts[parts.length - 1] : null;
|
||||
}
|
||||
|
||||
if (filename != null && prefix.length() >= 3) {
|
||||
File temp = File.createTempFile(prefix, suffix);
|
||||
if (debugLogging) {
|
||||
logger.info("Writing " + path + " out to " + temp.getAbsolutePath());
|
||||
}
|
||||
temp.deleteOnExit();
|
||||
if (!temp.exists()) {
|
||||
throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist.");
|
||||
} else {
|
||||
byte[] buffer = new byte[1024];
|
||||
try (InputStream is = OnnxRuntime.class.getResourceAsStream(path)) {
|
||||
if (is == null) {
|
||||
throw new FileNotFoundException("File " + path + " was not found inside JAR.");
|
||||
} else {
|
||||
int readBytes;
|
||||
try (FileOutputStream os = new FileOutputStream(temp)) {
|
||||
while ((readBytes = is.read(buffer)) != -1) {
|
||||
os.write(buffer, 0, readBytes);
|
||||
}
|
||||
}
|
||||
return temp.getAbsolutePath();
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
|
||||
}
|
||||
}
|
||||
// 2) The user may explicitly specify the path to their shared library:
|
||||
String libraryPathProperty = System.getProperty("onnxruntime.native." + library + ".path");
|
||||
if (libraryPathProperty != null) {
|
||||
logger.log(Level.FINE, "Attempting to load native library '" + library + "' from specified path: " + libraryPathProperty);
|
||||
File libraryFile = new File(libraryPathProperty);
|
||||
String libraryFilePath = libraryFile.getAbsolutePath();
|
||||
if (!libraryFile.exists()) {
|
||||
throw new IOException("Native library '" + library + "' not found at "+ libraryFilePath);
|
||||
}
|
||||
System.load(libraryFilePath);
|
||||
logger.log(Level.FINE, "Loaded native library '" + library + "' from specified path");
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a reference to the API struct.
|
||||
* @param apiVersionNumber The API version to use.
|
||||
* @return A pointer to the API struct.
|
||||
*/
|
||||
private static native long initialiseAPIBase(int apiVersionNumber);
|
||||
// 3) try loading from resources or library path:
|
||||
// generate a platform specific library name
|
||||
// replace Mac's jnilib extension to dylib
|
||||
String libraryFileName = System.mapLibraryName(library).replace("jnilib", "dylib");
|
||||
String resourcePath = "/ai/onnxruntime/native/" + libraryFileName;
|
||||
File tempFile = tempDirectory.resolve(libraryFileName).toFile();
|
||||
try(InputStream is = OnnxRuntime.class.getResourceAsStream(resourcePath)){
|
||||
if (is == null) {
|
||||
// 3a) Not found in resources, load from library path
|
||||
logger.log(Level.FINE, "Attempting to load native library '" + library + "' from library path");
|
||||
System.loadLibrary(library);
|
||||
logger.log(Level.FINE, "Loaded native library '" + library + "' from library path");
|
||||
} else {
|
||||
// 3b) Found in resources, load via temporary file
|
||||
logger.log(
|
||||
Level.FINE,
|
||||
"Attempting to load native library '" + library + "' from resource path " + resourcePath + " copying to " + tempFile);
|
||||
byte[] buffer = new byte[1024];
|
||||
int readBytes;
|
||||
try (FileOutputStream os = new FileOutputStream(tempFile)) {
|
||||
while ((readBytes = is.read(buffer)) != -1) {
|
||||
os.write(buffer, 0, readBytes);
|
||||
}
|
||||
}
|
||||
System.load(tempFile.getAbsolutePath());
|
||||
logger.log(Level.FINE, "Loaded native library '" + library + "' from resource path");
|
||||
}
|
||||
} finally {
|
||||
cleanUp(tempFile);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a reference to the API struct.
|
||||
*
|
||||
* @param apiVersionNumber The API version to use.
|
||||
* @return A pointer to the API struct.
|
||||
*/
|
||||
private static native long initialiseAPIBase(int apiVersionNumber);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,10 +4,23 @@
|
|||
*/
|
||||
|
||||
/**
|
||||
* A Java interface to the onnxruntime.
|
||||
* A Java interface to the ONNX Runtime.
|
||||
* <p>
|
||||
* Provides access to the same execution backends as the C library.
|
||||
* Non-representable types in Java (such as fp16) are converted
|
||||
* into the nearest Java primitive type when accessed through this API.
|
||||
*/
|
||||
* <p>
|
||||
* There are two shared libraries required: <code>onnxruntime</code> and <code>onnxruntime4j_jni</code>. The loader is in {@link ai.onnxruntime.OnnxRuntime} and the logic is in this order:
|
||||
* <ol>
|
||||
* <li>The user may signal to skip loading of a shared library using a property in the form <code>onnxruntime.native.LIB_NAME.skip</code> with a value of <code>true</code>. This means the user has decided to load the library by some other means.
|
||||
* <li>The user may specify an explicit location of the shared library file using a property in the form <code>onnxruntime.native.LIB_NAME.path</code>. This uses {@link java.lang.System#load}.
|
||||
* <li>The shared library is autodiscovered:<ol>
|
||||
* <li>If the shared library is present in the classpath resources, load using {@link java.lang.System#load} via a temporary file.
|
||||
* Ideally, this should be the default use case when adding JAR's/dependencies containing the shared libraries to your classpath.
|
||||
* <li>If the shared library is not present in the classpath resources, then load using {@link java.lang.System#loadLibrary}, which usually looks elsewhere on the filesystem for the library.
|
||||
* The semantics and behavior of that method are system/JVM dependent.
|
||||
* Typically, the <code>java.library.path</code> property is used to specify the location of native libraries.
|
||||
* </ol></ol>
|
||||
* For troubleshooting, all shared library loading events are reported to Java logging at the level FINE.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
|
@ -53,22 +53,14 @@ import java.util.stream.Stream;
|
|||
*/
|
||||
public class InferenceTest {
|
||||
private static final Pattern LOAD_PATTERN = Pattern.compile("[,\\[\\] ]");
|
||||
private static Path resourcePath;
|
||||
private static Path otherTestPath;
|
||||
|
||||
private static String propertiesFile = "Properties.txt";
|
||||
|
||||
private static Pattern inputPBPattern = Pattern.compile("input_*.pb");
|
||||
private static Pattern outputPBPattern = Pattern.compile("output_*.pb");
|
||||
|
||||
static {
|
||||
if (System.getProperty("GRADLE_TEST") != null) {
|
||||
resourcePath = Paths.get("..","csharp","testdata");
|
||||
otherTestPath = Paths.get("..","onnxruntime","test", "testdata");
|
||||
} else {
|
||||
resourcePath = Paths.get("csharp","testdata");
|
||||
otherTestPath = Paths.get("onnxruntime","test", "testdata");
|
||||
}
|
||||
private static Path getResourcePath(String path) {
|
||||
return new File(InferenceTest.class.getResource(path).getFile()).toPath();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
@ -84,7 +76,7 @@ public class InferenceTest {
|
|||
|
||||
@Test
|
||||
public void createSessionFromPath() throws OrtException {
|
||||
String modelPath = resourcePath.resolve("squeezenet.onnx").toString();
|
||||
String modelPath = getResourcePath("/squeezenet.onnx").toString();
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath");
|
||||
OrtSession.SessionOptions options = new SessionOptions()) {
|
||||
try (OrtSession session = env.createSession(modelPath,options)) {
|
||||
|
|
@ -123,7 +115,7 @@ public class InferenceTest {
|
|||
}
|
||||
@Test
|
||||
public void createSessionFromByteArray() throws IOException, OrtException {
|
||||
Path modelPath = resourcePath.resolve("squeezenet.onnx");
|
||||
Path modelPath = getResourcePath("/squeezenet.onnx");
|
||||
byte[] modelBytes = Files.readAllBytes(modelPath);
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromByteArray");
|
||||
OrtSession.SessionOptions options = new SessionOptions()) {
|
||||
|
|
@ -171,7 +163,7 @@ public class InferenceTest {
|
|||
}
|
||||
|
||||
private void canRunInferenceOnAModel(OptLevel graphOptimizationLevel, ExecutionMode exectionMode) throws OrtException {
|
||||
String modelPath = resourcePath.resolve("squeezenet.onnx").toString();
|
||||
String modelPath = getResourcePath("/squeezenet.onnx").toString();
|
||||
|
||||
// Set the graph optimization level for this session.
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment("canRunInferenceOnAModel");
|
||||
|
|
@ -184,7 +176,7 @@ public class InferenceTest {
|
|||
Map<String,OnnxTensor> container = new HashMap<>();
|
||||
NodeInfo inputMeta = inputMetaMap.values().iterator().next();
|
||||
|
||||
float[] inputData = loadTensorFromFile(resourcePath.resolve("bench.in"));
|
||||
float[] inputData = loadTensorFromFile(getResourcePath("/bench.in"));
|
||||
// this is the data for only one input tensor for this model
|
||||
Object tensorData = OrtUtil.reshape(inputData,((TensorInfo) inputMeta.getInfo()).getShape());
|
||||
OnnxTensor inputTensor = OnnxTensor.createTensor(env,tensorData);
|
||||
|
|
@ -194,7 +186,7 @@ public class InferenceTest {
|
|||
try (OrtSession.Result results = session.run(container)) {
|
||||
assertEquals(1, results.size());
|
||||
|
||||
float[] expectedOutput = loadTensorFromFile(resourcePath.resolve("bench.expected_out"));
|
||||
float[] expectedOutput = loadTensorFromFile(getResourcePath("/bench.expected_out"));
|
||||
// validate the results
|
||||
// Only iterates once
|
||||
for (Map.Entry<String, OnnxValue> r : results) {
|
||||
|
|
@ -485,7 +477,7 @@ public class InferenceTest {
|
|||
@Test
|
||||
public void testModelInputFLOAT() throws OrtException {
|
||||
// model takes 1x5 input of fixed type, echoes back
|
||||
String modelPath = resourcePath.resolve("test_types_FLOAT.pb").toString();
|
||||
String modelPath = getResourcePath("/test_types_FLOAT.pb").toString();
|
||||
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputFLOAT");
|
||||
SessionOptions options = new SessionOptions();
|
||||
|
|
@ -523,7 +515,7 @@ public class InferenceTest {
|
|||
@Test
|
||||
public void testModelInputBOOL() throws OrtException {
|
||||
// model takes 1x5 input of fixed type, echoes back
|
||||
String modelPath = resourcePath.resolve("test_types_BOOL.pb").toString();
|
||||
String modelPath = getResourcePath("/test_types_BOOL.pb").toString();
|
||||
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputBOOL");
|
||||
SessionOptions options = new SessionOptions();
|
||||
|
|
@ -545,7 +537,7 @@ public class InferenceTest {
|
|||
@Test
|
||||
public void testModelInputINT32() throws OrtException {
|
||||
// model takes 1x5 input of fixed type, echoes back
|
||||
String modelPath = resourcePath.resolve("test_types_INT32.pb").toString();
|
||||
String modelPath = getResourcePath("/test_types_INT32.pb").toString();
|
||||
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputINT32");
|
||||
SessionOptions options = new SessionOptions();
|
||||
|
|
@ -567,7 +559,7 @@ public class InferenceTest {
|
|||
@Test
|
||||
public void testModelInputDOUBLE() throws OrtException {
|
||||
// model takes 1x5 input of fixed type, echoes back
|
||||
String modelPath = resourcePath.resolve("test_types_DOUBLE.pb").toString();
|
||||
String modelPath = getResourcePath("/test_types_DOUBLE.pb").toString();
|
||||
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputDOUBLE");
|
||||
SessionOptions options = new SessionOptions();
|
||||
|
|
@ -589,7 +581,7 @@ public class InferenceTest {
|
|||
@Test
|
||||
public void testModelInputINT8() throws OrtException {
|
||||
// model takes 1x5 input of fixed type, echoes back
|
||||
String modelPath = resourcePath.resolve("test_types_INT8.pb").toString();
|
||||
String modelPath = getResourcePath("/test_types_INT8.pb").toString();
|
||||
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputINT8");
|
||||
SessionOptions options = new SessionOptions();
|
||||
|
|
@ -611,7 +603,7 @@ public class InferenceTest {
|
|||
@Test
|
||||
public void testModelInputINT16() throws OrtException {
|
||||
// model takes 1x5 input of fixed type, echoes back
|
||||
String modelPath = resourcePath.resolve("test_types_INT16.pb").toString();
|
||||
String modelPath = getResourcePath("/test_types_INT16.pb").toString();
|
||||
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputINT16");
|
||||
SessionOptions options = new SessionOptions();
|
||||
|
|
@ -633,7 +625,7 @@ public class InferenceTest {
|
|||
@Test
|
||||
public void testModelInputINT64() throws OrtException {
|
||||
// model takes 1x5 input of fixed type, echoes back
|
||||
String modelPath = resourcePath.resolve("test_types_INT64.pb").toString();
|
||||
String modelPath = getResourcePath("/test_types_INT64.pb").toString();
|
||||
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputINT64");
|
||||
SessionOptions options = new SessionOptions();
|
||||
|
|
@ -660,7 +652,7 @@ public class InferenceTest {
|
|||
// "probabilities" is a sequence<map<int64, float>>
|
||||
// https://github.com/onnx/sklearn-onnx/blob/master/docs/examples/plot_pipeline_lightgbm.py
|
||||
|
||||
String modelPath = resourcePath.resolve("test_sequence_map_int_float.pb").toString();
|
||||
String modelPath = getResourcePath("/test_sequence_map_int_float.pb").toString();
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelSequenceOfMapIntFloat");
|
||||
SessionOptions options = new SessionOptions();
|
||||
OrtSession session = env.createSession(modelPath, options)) {
|
||||
|
|
@ -720,7 +712,7 @@ public class InferenceTest {
|
|||
// "label" is a tensor,
|
||||
// "probabilities" is a sequence<map<int64, float>>
|
||||
// https://github.com/onnx/sklearn-onnx/blob/master/docs/examples/plot_pipeline_lightgbm.py
|
||||
String modelPath = resourcePath.resolve("test_sequence_map_string_float.pb").toString();
|
||||
String modelPath = getResourcePath("/test_sequence_map_string_float.pb").toString();
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelSequenceOfMapStringFloat");
|
||||
SessionOptions options = new SessionOptions();
|
||||
OrtSession session = env.createSession(modelPath, options)) {
|
||||
|
|
@ -774,11 +766,13 @@ public class InferenceTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testModelSerialization() throws OrtException {
|
||||
public void testModelSerialization() throws OrtException, IOException {
|
||||
String cwd = System.getProperty("user.dir");
|
||||
Path squeezeNet = resourcePath.resolve("squeezenet.onnx");
|
||||
Path squeezeNet = getResourcePath("/squeezenet.onnx");
|
||||
String modelPath = squeezeNet.toString();
|
||||
String modelOutputPath = Paths.get(cwd, "optimized-squeezenet.onnx").toString();
|
||||
File tmpFile = File.createTempFile("optimized-squeezenet", ".onnx");
|
||||
String modelOutputPath = tmpFile.getAbsolutePath();
|
||||
Assertions.assertEquals(0, tmpFile.length());
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
|
||||
// Set the optimized model file path to assert that no exception are thrown.
|
||||
SessionOptions options = new SessionOptions();
|
||||
|
|
@ -786,14 +780,17 @@ public class InferenceTest {
|
|||
options.setOptimizationLevel(OptLevel.BASIC_OPT);
|
||||
try (OrtSession session = env.createSession(modelPath, options)) {
|
||||
Assertions.assertNotNull(session);
|
||||
Assertions.assertTrue((new File(modelOutputPath)).exists());
|
||||
} finally {
|
||||
Assertions.assertTrue(tmpFile.length() > 0);
|
||||
}
|
||||
} finally {
|
||||
tmpFile.delete();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStringIdentity() throws OrtException {
|
||||
String modelPath = otherTestPath.resolve("identity_string.onnx").toString();
|
||||
String modelPath = getResourcePath("/identity_string.onnx").toString();
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testStringIdentity");
|
||||
SessionOptions options = new SessionOptions();
|
||||
OrtSession session = env.createSession(modelPath, options)) {
|
||||
|
|
@ -871,7 +868,7 @@ public class InferenceTest {
|
|||
}
|
||||
|
||||
private static SqueezeNetTuple openSessionSqueezeNet(int cudaDeviceId) throws OrtException {
|
||||
Path squeezeNet = resourcePath.resolve("squeezenet.onnx");
|
||||
Path squeezeNet = getResourcePath("/squeezenet.onnx");
|
||||
String modelPath = squeezeNet.toString();
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
SessionOptions options = new SessionOptions();
|
||||
|
|
@ -879,8 +876,8 @@ public class InferenceTest {
|
|||
options.addCUDA(cudaDeviceId);
|
||||
}
|
||||
OrtSession session = env.createSession(modelPath,options);
|
||||
float[] inputData = loadTensorFromFile(resourcePath.resolve("bench.in"));
|
||||
float[] expectedOutput = loadTensorFromFile(resourcePath.resolve("bench.expected_out"));
|
||||
float[] inputData = loadTensorFromFile(getResourcePath("/bench.in"));
|
||||
float[] expectedOutput = loadTensorFromFile(getResourcePath("/bench.expected_out"));
|
||||
return new SqueezeNetTuple(env, session, inputData, expectedOutput);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package sample;
|
||||
import ai.onnxruntime.NodeInfo;
|
||||
import ai.onnxruntime.OnnxTensor;
|
||||
import ai.onnxruntime.OnnxValue;
|
||||
|
|
@ -92,6 +92,10 @@ else
|
|||
popd
|
||||
fi
|
||||
|
||||
GetFile https://downloads.gradle-dn.com/distributions/gradle-6.2-bin.zip /tmp/src/gradle-6.2-bin.zip
|
||||
cd /tmp/src
|
||||
unzip gradle-6.2-bin.zip
|
||||
mv /tmp/src/gradle-6.2 /usr/local/gradle
|
||||
|
||||
|
||||
if ! [ -x "$(command -v protoc)" ]; then
|
||||
|
|
|
|||
|
|
@ -17,6 +17,8 @@ y) YOCTO_VERSION=${OPTARG};;
|
|||
esac
|
||||
done
|
||||
|
||||
export PATH=$PATH:/usr/local/gradle/bin
|
||||
|
||||
if [ $BUILD_OS = "android" ]; then
|
||||
pushd /onnxruntime_src
|
||||
mkdir build-android && cd build-android
|
||||
|
|
@ -80,6 +82,7 @@ else
|
|||
--cuda_home /usr/local/cuda \
|
||||
--cudnn_home /usr/local/cuda $BUILD_EXTR_PAR
|
||||
else #cpu, ngraph and openvino
|
||||
export JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64
|
||||
python3 $SCRIPT_DIR/../../build.py --build_dir /build \
|
||||
--config Debug Release $COMMON_BUILD_ARGS $BUILD_EXTR_PAR
|
||||
fi
|
||||
|
|
|
|||
|
|
@ -1 +1,2 @@
|
|||
set PATH=C:\azcopy;%BUILD_BINARIESDIRECTORY%\packages\python;%BUILD_BINARIESDIRECTORY%\packages\python\DLLs;%BUILD_BINARIESDIRECTORY%\packages\python\Library\bin;%BUILD_BINARIESDIRECTORY%\packages\python\script;%PATH%
|
||||
set GRADLE_OPTS=-Dorg.gradle.daemon=false
|
||||
|
|
|
|||
|
|
@ -1 +1,2 @@
|
|||
set PATH=C:\azcopy;C:\Program Files (x86)\dotnet;%BUILD_BINARIESDIRECTORY%\packages\python;%BUILD_BINARIESDIRECTORY%\packages\python\DLLs;%BUILD_BINARIESDIRECTORY%\packages\python\Library\bin;%BUILD_BINARIESDIRECTORY%\packages\python\script;%PATH%;C:\local\systools
|
||||
set GRADLE_OPTS=-Dorg.gradle.daemon=false
|
||||
|
|
|
|||
Loading…
Reference in a new issue