From c91527235abff439dbaa2158bedbe45b3daeda3c Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Thu, 16 Apr 2020 05:29:23 -0400 Subject: [PATCH] [Java] Add support for map and sequence information on output nodes (#3468) --- java/build.gradle | 10 +- .../main/java/ai/onnxruntime/OnnxRuntime.java | 4 +- .../main/java/ai/onnxruntime/OrtSession.java | 41 +++--- .../java/ai/onnxruntime/SequenceInfo.java | 2 +- java/src/main/native/OrtJniUtil.c | 128 +++++++++++++++++- java/src/main/native/OrtJniUtil.h | 5 +- .../main/native/ai_onnxruntime_OrtSession.c | 8 +- .../java/ai/onnxruntime/InferenceTest.java | 12 ++ 8 files changed, 175 insertions(+), 35 deletions(-) diff --git a/java/build.gradle b/java/build.gradle index 528867c1de..dd6c7b3b38 100644 --- a/java/build.gradle +++ b/java/build.gradle @@ -1,5 +1,6 @@ plugins { id 'java' + id 'jacoco' id 'com.diffplug.gradle.spotless' version '3.26.0' } @@ -105,7 +106,6 @@ if (cmakeBuildDir != null) { } - dependencies { testImplementation 'org.junit.jupiter:junit-jupiter-api:5.1.1' testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.1.1' @@ -118,3 +118,11 @@ test { events "passed", "skipped", "failed" } } + +jacocoTestReport { + reports { + xml.enabled true + csv.enabled true + html.destination file("${buildDir}/jacocoHtml") + } +} diff --git a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java index a82bc00a02..18f9c397bb 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java +++ b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java @@ -22,6 +22,8 @@ final class OnnxRuntime { // The initial release of the ORT API. private static final int ORT_API_VERSION_1 = 1; + // Post 1.0 builds of the ORT API. + private static final int ORT_API_VERSION_2 = 2; /** The short name of the ONNX runtime shared library */ static final String ONNXRUNTIME_LIBRARY_NAME = "onnxruntime"; @@ -48,7 +50,7 @@ final class OnnxRuntime { try { load(tempDirectory, ONNXRUNTIME_LIBRARY_NAME); load(tempDirectory, ONNXRUNTIME_JNI_LIBRARY_NAME); - ortApiHandle = initialiseAPIBase(ORT_API_VERSION_1); + ortApiHandle = initialiseAPIBase(ORT_API_VERSION_2); loaded = true; } finally { if (!isAndroid()) { diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index f9136a1d3b..80a1d78fa4 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -60,18 +60,9 @@ public class OrtSession implements AutoCloseable { */ OrtSession(OrtEnvironment env, String modelPath, OrtAllocator allocator, SessionOptions options) throws OrtException { - nativeHandle = - createSession(OnnxRuntime.ortApiHandle, env.nativeHandle, modelPath, options.nativeHandle); - this.allocator = allocator; - numInputs = getNumInputs(OnnxRuntime.ortApiHandle, nativeHandle); - inputNames = - new LinkedHashSet<>( - Arrays.asList(getInputNames(OnnxRuntime.ortApiHandle, nativeHandle, allocator.handle))); - numOutputs = getNumOutputs(OnnxRuntime.ortApiHandle, nativeHandle); - outputNames = - new LinkedHashSet<>( - Arrays.asList( - getOutputNames(OnnxRuntime.ortApiHandle, nativeHandle, allocator.handle))); + this( + createSession(OnnxRuntime.ortApiHandle, env.nativeHandle, modelPath, options.nativeHandle), + allocator); } /** @@ -81,12 +72,24 @@ public class OrtSession implements AutoCloseable { * @param modelArray The model protobuf as a byte array. * @param allocator The allocator to use. * @param options Session configuration options. - * @throws OrtException If the mode was corrupted or some other error occurred in native code. + * @throws OrtException If the model was corrupted or some other error occurred in native code. */ OrtSession(OrtEnvironment env, byte[] modelArray, OrtAllocator allocator, SessionOptions options) throws OrtException { - nativeHandle = - createSession(OnnxRuntime.ortApiHandle, env.nativeHandle, modelArray, options.nativeHandle); + this( + createSession(OnnxRuntime.ortApiHandle, env.nativeHandle, modelArray, options.nativeHandle), + allocator); + } + + /** + * Private constructor to build the Java object wrapped around a native session. + * + * @param nativeHandle The pointer to the native session. + * @param allocator The allocator to use. + * @throws OrtException If the model's inputs, outputs or metadata could not be read. + */ + private OrtSession(long nativeHandle, OrtAllocator allocator) throws OrtException { + this.nativeHandle = nativeHandle; this.allocator = allocator; numInputs = getNumInputs(OnnxRuntime.ortApiHandle, nativeHandle); inputNames = @@ -289,17 +292,17 @@ public class OrtSession implements AutoCloseable { private static Map wrapInMap(NodeInfo[] infos) { Map output = new LinkedHashMap<>(); - for (int i = 0; i < infos.length; i++) { - output.put(infos[i].getName(), infos[i]); + for (NodeInfo info : infos) { + output.put(info.getName(), info); } return output; } - private native long createSession( + private static native long createSession( long apiHandle, long envHandle, String modelPath, long optsHandle) throws OrtException; - private native long createSession( + private static native long createSession( long apiHandle, long envHandle, byte[] modelArray, long optsHandle) throws OrtException; private native long getNumInputs(long apiHandle, long nativeHandle) throws OrtException; diff --git a/java/src/main/java/ai/onnxruntime/SequenceInfo.java b/java/src/main/java/ai/onnxruntime/SequenceInfo.java index ded856bc04..a417634b72 100644 --- a/java/src/main/java/ai/onnxruntime/SequenceInfo.java +++ b/java/src/main/java/ai/onnxruntime/SequenceInfo.java @@ -49,7 +49,7 @@ public class SequenceInfo implements ValueInfo { } /** - * Constructs a sequence of known lenght containing maps. + * Constructs a sequence of known length containing maps. * * @param length The length of the sequence. * @param keyType The map key type. diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index f2d97dd24a..e89e0a70c2 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -205,10 +205,14 @@ jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, OrtTypeInfo * inf return convertToTensorInfo(jniEnv, api, (const OrtTensorTypeAndShapeInfo *) tensorInfo); } case ONNX_TYPE_SEQUENCE: { - return createEmptySequenceInfo(jniEnv); + const OrtSequenceTypeInfo* sequenceInfo; + checkOrtStatus(jniEnv,api,api->CastTypeInfoToSequenceTypeInfo(info,&sequenceInfo)); + return convertToSequenceInfo(jniEnv, api, sequenceInfo); } case ONNX_TYPE_MAP: { - return createEmptyMapInfo(jniEnv); + const OrtMapTypeInfo* mapInfo; + checkOrtStatus(jniEnv,api,api->CastTypeInfoToMapTypeInfo(info,&mapInfo)); + return convertToMapInfo(jniEnv, api, mapInfo); } case ONNX_TYPE_UNKNOWN: case ONNX_TYPE_OPAQUE: @@ -261,8 +265,56 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT return tensorInfo; } -//jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo * info) { -// As map info isn't available at this point, it creates an empty map info type. +jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtMapTypeInfo * info) { + // Create the java methods we need to call. + // Get the ONNXTensorType enum static method + char *onnxTensorTypeClassName = "ai/onnxruntime/TensorInfo$OnnxTensorType"; + jclass onnxTensorTypeClazz = (*jniEnv)->FindClass(jniEnv, onnxTensorTypeClassName); + jmethodID onnxTensorTypeMapFromInt = (*jniEnv)->GetStaticMethodID(jniEnv,onnxTensorTypeClazz, "mapFromInt", "(I)Lai/onnxruntime/TensorInfo$OnnxTensorType;"); + + // Get the ONNXJavaType enum static method + char *javaDataTypeClassName = "ai/onnxruntime/OnnxJavaType"; + jclass onnxJavaTypeClazz = (*jniEnv)->FindClass(jniEnv, javaDataTypeClassName); + jmethodID onnxJavaTypeMapFromONNXTensorType = (*jniEnv)->GetStaticMethodID(jniEnv,onnxJavaTypeClazz, "mapFromOnnxTensorType", "(Lai/onnxruntime/TensorInfo$OnnxTensorType;)Lai/onnxruntime/OnnxJavaType;"); + + // Get the map info class + char *mapInfoClassName = "ai/onnxruntime/MapInfo"; + jclass mapInfoClazz = (*jniEnv)->FindClass(jniEnv, mapInfoClassName); + jmethodID mapInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,mapInfoClazz,"","(ILai/onnxruntime/OnnxJavaType;Lai/onnxruntime/OnnxJavaType;)V"); + + // Extract the key type + ONNXTensorElementDataType keyType; + checkOrtStatus(jniEnv,api,api->GetMapKeyType(info,&keyType)); + + // Convert key type to java + jint onnxTypeKey = convertFromONNXDataFormat(keyType); + jobject onnxTensorTypeJavaKey = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeKey); + jobject onnxJavaTypeKey = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJavaKey); + + // according to include/onnxruntime/core/framework/data_types.h only the following values are supported. + // string, int64, float, double + // So extract the value type, then convert it to a tensor type so we can get it's element type. + OrtTypeInfo* valueTypeInfo; + checkOrtStatus(jniEnv,api,api->GetMapValueType(info,&valueTypeInfo)); + const OrtTensorTypeAndShapeInfo* tensorValueInfo; + checkOrtStatus(jniEnv,api,api->CastTypeInfoToTensorInfo(valueTypeInfo,&tensorValueInfo)); + ONNXTensorElementDataType valueType; + checkOrtStatus(jniEnv,api,api->GetTensorElementType(tensorValueInfo,&valueType)); + api->ReleaseTypeInfo(valueTypeInfo); + tensorValueInfo = NULL; + valueTypeInfo = NULL; + + // Convert value type to java + jint onnxTypeValue = convertFromONNXDataFormat(valueType); + jobject onnxTensorTypeJavaValue = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeValue); + jobject onnxJavaTypeValue = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJavaValue); + + // Construct map info + jobject mapInfo = (*jniEnv)->NewObject(jniEnv,mapInfoClazz,mapInfoConstructor,(jint)-1,onnxJavaTypeKey,onnxJavaTypeValue); + + return mapInfo; +} + jobject createEmptyMapInfo(JNIEnv *jniEnv) { // Create the ONNXJavaType enum char *onnxJavaTypeClassName = "ai/onnxruntime/OnnxJavaType"; @@ -278,8 +330,72 @@ jobject createEmptyMapInfo(JNIEnv *jniEnv) { return mapInfo; } -//jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo * info) { -// As sequence info isn't available at this point, it creates an empty sequence info type. +jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSequenceTypeInfo * info) { + // Get the sequence info class + char *sequenceInfoClassName = "ai/onnxruntime/SequenceInfo"; + jclass sequenceInfoClazz = (*jniEnv)->FindClass(jniEnv, sequenceInfoClassName); + + // according to include/onnxruntime/core/framework/data_types.h the following values are supported. + // tensor types, map and map + OrtTypeInfo* elementTypeInfo; + checkOrtStatus(jniEnv,api,api->GetSequenceElementType(info,&elementTypeInfo)); + ONNXType type; + checkOrtStatus(jniEnv,api,api->GetOnnxTypeFromTypeInfo(elementTypeInfo,&type)); + + jobject sequenceInfo; + + switch (type) { + case ONNX_TYPE_TENSOR: { + // Figure out element type + const OrtTensorTypeAndShapeInfo* elementTensorInfo; + checkOrtStatus(jniEnv,api,api->CastTypeInfoToTensorInfo(elementTypeInfo,&elementTensorInfo)); + ONNXTensorElementDataType element; + checkOrtStatus(jniEnv,api,api->GetTensorElementType(elementTensorInfo,&element)); + + // Convert element type into ONNXTensorType + jint onnxTypeInt = convertFromONNXDataFormat(element); + // Get the ONNXTensorType enum static method + char *onnxTensorTypeClassName = "ai/onnxruntime/TensorInfo$OnnxTensorType"; + jclass onnxTensorTypeClazz = (*jniEnv)->FindClass(jniEnv, onnxTensorTypeClassName); + jmethodID onnxTensorTypeMapFromInt = (*jniEnv)->GetStaticMethodID(jniEnv,onnxTensorTypeClazz, "mapFromInt", "(I)Lai/onnxruntime/TensorInfo$OnnxTensorType;"); + jobject onnxTensorTypeJava = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeInt); + + // Get the ONNXJavaType enum static method + char *javaDataTypeClassName = "ai/onnxruntime/OnnxJavaType"; + jclass onnxJavaTypeClazz = (*jniEnv)->FindClass(jniEnv, javaDataTypeClassName); + jmethodID onnxJavaTypeMapFromONNXTensorType = (*jniEnv)->GetStaticMethodID(jniEnv,onnxJavaTypeClazz, "mapFromOnnxTensorType", "(Lai/onnxruntime/TensorInfo$OnnxTensorType;)Lai/onnxruntime/OnnxJavaType;"); + jobject onnxJavaType = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJava); + + // Construct sequence info + jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,sequenceInfoClazz,"","(ILai/onnxruntime/OnnxJavaType;)V"); + sequenceInfo = (*jniEnv)->NewObject(jniEnv,sequenceInfoClazz,sequenceInfoConstructor,(jint)-1,onnxJavaType); + break; + } + case ONNX_TYPE_MAP: { + // Extract the map info + const OrtMapTypeInfo* mapInfo; + checkOrtStatus(jniEnv,api,api->CastTypeInfoToMapTypeInfo(elementTypeInfo,&mapInfo)); + + // Convert it using the existing convert function + jobject javaMapInfo = convertToMapInfo(jniEnv,api,mapInfo); + + // Construct sequence info + jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,sequenceInfoClazz,"","(ILai/onnxruntime/MapInfo;)V"); + sequenceInfo = (*jniEnv)->NewObject(jniEnv,sequenceInfoClazz,sequenceInfoConstructor,(jint)-1,javaMapInfo); + break; + } + default: { + sequenceInfo = createEmptySequenceInfo(jniEnv); + throwOrtException(jniEnv,convertErrorCode(ORT_INVALID_ARGUMENT),"Invalid element type found in sequence"); + break; + } + } + api->ReleaseTypeInfo(elementTypeInfo); + elementTypeInfo = NULL; + + return sequenceInfo; +} + jobject createEmptySequenceInfo(JNIEnv *jniEnv) { // Create the ONNXJavaType enum char *onnxJavaTypeClassName = "ai/onnxruntime/OnnxJavaType"; diff --git a/java/src/main/native/OrtJniUtil.h b/java/src/main/native/OrtJniUtil.h index 4f05096215..e42af5dd6c 100644 --- a/java/src/main/native/OrtJniUtil.h +++ b/java/src/main/native/OrtJniUtil.h @@ -31,9 +31,8 @@ jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, OrtTypeInfo * inf jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorTypeAndShapeInfo * info); -//TODO when C API supports inspecting the types of map and sequence types from OutputInfos -//jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo * info); -//jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo * info); +jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtMapTypeInfo * info); +jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSequenceTypeInfo * info); jobject createEmptyMapInfo(JNIEnv *jniEnv); jobject createEmptySequenceInfo(JNIEnv *jniEnv); diff --git a/java/src/main/native/ai_onnxruntime_OrtSession.c b/java/src/main/native/ai_onnxruntime_OrtSession.c index 6e1c7ea488..e0eeec63dd 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession.c @@ -14,8 +14,8 @@ * Signature: (JJLjava/lang/String;J)J */ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_lang_String_2J - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong envHandle, jstring modelPath, jlong optsHandle) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jstring modelPath, jlong optsHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; OrtSession* session; @@ -43,8 +43,8 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_la * Signature: (JJ[BJ)J */ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJ_3BJ - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong envHandle, jbyteArray jModelArray, jlong optsHandle) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jbyteArray jModelArray, jlong optsHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; OrtSession* session; diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index c3ed213329..d0521d9429 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -915,6 +915,12 @@ public class InferenceTest { assertTrue(firstOutputInfo.getInfo() instanceof TensorInfo); assertTrue(secondOutputInfo.getInfo() instanceof SequenceInfo); assertEquals(OnnxJavaType.INT64, ((TensorInfo) firstOutputInfo.getInfo()).type); + assertTrue(((SequenceInfo) secondOutputInfo.getInfo()).sequenceOfMaps); + assertEquals(OnnxJavaType.UNKNOWN, ((SequenceInfo) secondOutputInfo.getInfo()).sequenceType); + MapInfo mapInfo = ((SequenceInfo) secondOutputInfo.getInfo()).mapInfo; + assertNotNull(mapInfo); + assertEquals(OnnxJavaType.INT64, mapInfo.keyType); + assertEquals(OnnxJavaType.FLOAT, mapInfo.valueType); Map container = new HashMap<>(); long[] shape = new long[] {1, 2}; @@ -975,6 +981,12 @@ public class InferenceTest { assertTrue(firstOutputInfo.getInfo() instanceof TensorInfo); assertTrue(secondOutputInfo.getInfo() instanceof SequenceInfo); assertEquals(OnnxJavaType.STRING, ((TensorInfo) firstOutputInfo.getInfo()).type); + assertTrue(((SequenceInfo) secondOutputInfo.getInfo()).sequenceOfMaps); + assertEquals(OnnxJavaType.UNKNOWN, ((SequenceInfo) secondOutputInfo.getInfo()).sequenceType); + MapInfo mapInfo = ((SequenceInfo) secondOutputInfo.getInfo()).mapInfo; + assertNotNull(mapInfo); + assertEquals(OnnxJavaType.STRING, mapInfo.keyType); + assertEquals(OnnxJavaType.FLOAT, mapInfo.valueType); Map container = new HashMap<>(); long[] shape = new long[] {1, 2};