[Java] Add support for map and sequence information on output nodes (#3468)

This commit is contained in:
Adam Pocock 2020-04-16 05:29:23 -04:00 committed by GitHub
parent 7c89f38a34
commit c91527235a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 175 additions and 35 deletions

View file

@ -1,5 +1,6 @@
plugins {
id 'java'
id 'jacoco'
id 'com.diffplug.gradle.spotless' version '3.26.0'
}
@ -105,7 +106,6 @@ if (cmakeBuildDir != null) {
}
dependencies {
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.1.1'
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.1.1'
@ -118,3 +118,11 @@ test {
events "passed", "skipped", "failed"
}
}
jacocoTestReport {
reports {
xml.enabled true
csv.enabled true
html.destination file("${buildDir}/jacocoHtml")
}
}

View file

@ -22,6 +22,8 @@ final class OnnxRuntime {
// The initial release of the ORT API.
private static final int ORT_API_VERSION_1 = 1;
// Post 1.0 builds of the ORT API.
private static final int ORT_API_VERSION_2 = 2;
/** The short name of the ONNX runtime shared library */
static final String ONNXRUNTIME_LIBRARY_NAME = "onnxruntime";
@ -48,7 +50,7 @@ final class OnnxRuntime {
try {
load(tempDirectory, ONNXRUNTIME_LIBRARY_NAME);
load(tempDirectory, ONNXRUNTIME_JNI_LIBRARY_NAME);
ortApiHandle = initialiseAPIBase(ORT_API_VERSION_1);
ortApiHandle = initialiseAPIBase(ORT_API_VERSION_2);
loaded = true;
} finally {
if (!isAndroid()) {

View file

@ -60,18 +60,9 @@ public class OrtSession implements AutoCloseable {
*/
OrtSession(OrtEnvironment env, String modelPath, OrtAllocator allocator, SessionOptions options)
throws OrtException {
nativeHandle =
createSession(OnnxRuntime.ortApiHandle, env.nativeHandle, modelPath, options.nativeHandle);
this.allocator = allocator;
numInputs = getNumInputs(OnnxRuntime.ortApiHandle, nativeHandle);
inputNames =
new LinkedHashSet<>(
Arrays.asList(getInputNames(OnnxRuntime.ortApiHandle, nativeHandle, allocator.handle)));
numOutputs = getNumOutputs(OnnxRuntime.ortApiHandle, nativeHandle);
outputNames =
new LinkedHashSet<>(
Arrays.asList(
getOutputNames(OnnxRuntime.ortApiHandle, nativeHandle, allocator.handle)));
this(
createSession(OnnxRuntime.ortApiHandle, env.nativeHandle, modelPath, options.nativeHandle),
allocator);
}
/**
@ -81,12 +72,24 @@ public class OrtSession implements AutoCloseable {
* @param modelArray The model protobuf as a byte array.
* @param allocator The allocator to use.
* @param options Session configuration options.
* @throws OrtException If the mode was corrupted or some other error occurred in native code.
* @throws OrtException If the model was corrupted or some other error occurred in native code.
*/
OrtSession(OrtEnvironment env, byte[] modelArray, OrtAllocator allocator, SessionOptions options)
throws OrtException {
nativeHandle =
createSession(OnnxRuntime.ortApiHandle, env.nativeHandle, modelArray, options.nativeHandle);
this(
createSession(OnnxRuntime.ortApiHandle, env.nativeHandle, modelArray, options.nativeHandle),
allocator);
}
/**
* Private constructor to build the Java object wrapped around a native session.
*
* @param nativeHandle The pointer to the native session.
* @param allocator The allocator to use.
* @throws OrtException If the model's inputs, outputs or metadata could not be read.
*/
private OrtSession(long nativeHandle, OrtAllocator allocator) throws OrtException {
this.nativeHandle = nativeHandle;
this.allocator = allocator;
numInputs = getNumInputs(OnnxRuntime.ortApiHandle, nativeHandle);
inputNames =
@ -289,17 +292,17 @@ public class OrtSession implements AutoCloseable {
private static Map<String, NodeInfo> wrapInMap(NodeInfo[] infos) {
Map<String, NodeInfo> output = new LinkedHashMap<>();
for (int i = 0; i < infos.length; i++) {
output.put(infos[i].getName(), infos[i]);
for (NodeInfo info : infos) {
output.put(info.getName(), info);
}
return output;
}
private native long createSession(
private static native long createSession(
long apiHandle, long envHandle, String modelPath, long optsHandle) throws OrtException;
private native long createSession(
private static native long createSession(
long apiHandle, long envHandle, byte[] modelArray, long optsHandle) throws OrtException;
private native long getNumInputs(long apiHandle, long nativeHandle) throws OrtException;

View file

@ -49,7 +49,7 @@ public class SequenceInfo implements ValueInfo {
}
/**
* Constructs a sequence of known lenght containing maps.
* Constructs a sequence of known length containing maps.
*
* @param length The length of the sequence.
* @param keyType The map key type.

View file

@ -205,10 +205,14 @@ jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, OrtTypeInfo * inf
return convertToTensorInfo(jniEnv, api, (const OrtTensorTypeAndShapeInfo *) tensorInfo);
}
case ONNX_TYPE_SEQUENCE: {
return createEmptySequenceInfo(jniEnv);
const OrtSequenceTypeInfo* sequenceInfo;
checkOrtStatus(jniEnv,api,api->CastTypeInfoToSequenceTypeInfo(info,&sequenceInfo));
return convertToSequenceInfo(jniEnv, api, sequenceInfo);
}
case ONNX_TYPE_MAP: {
return createEmptyMapInfo(jniEnv);
const OrtMapTypeInfo* mapInfo;
checkOrtStatus(jniEnv,api,api->CastTypeInfoToMapTypeInfo(info,&mapInfo));
return convertToMapInfo(jniEnv, api, mapInfo);
}
case ONNX_TYPE_UNKNOWN:
case ONNX_TYPE_OPAQUE:
@ -261,8 +265,56 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT
return tensorInfo;
}
//jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo * info) {
// As map info isn't available at this point, it creates an empty map info type.
jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtMapTypeInfo * info) {
// Create the java methods we need to call.
// Get the ONNXTensorType enum static method
char *onnxTensorTypeClassName = "ai/onnxruntime/TensorInfo$OnnxTensorType";
jclass onnxTensorTypeClazz = (*jniEnv)->FindClass(jniEnv, onnxTensorTypeClassName);
jmethodID onnxTensorTypeMapFromInt = (*jniEnv)->GetStaticMethodID(jniEnv,onnxTensorTypeClazz, "mapFromInt", "(I)Lai/onnxruntime/TensorInfo$OnnxTensorType;");
// Get the ONNXJavaType enum static method
char *javaDataTypeClassName = "ai/onnxruntime/OnnxJavaType";
jclass onnxJavaTypeClazz = (*jniEnv)->FindClass(jniEnv, javaDataTypeClassName);
jmethodID onnxJavaTypeMapFromONNXTensorType = (*jniEnv)->GetStaticMethodID(jniEnv,onnxJavaTypeClazz, "mapFromOnnxTensorType", "(Lai/onnxruntime/TensorInfo$OnnxTensorType;)Lai/onnxruntime/OnnxJavaType;");
// Get the map info class
char *mapInfoClassName = "ai/onnxruntime/MapInfo";
jclass mapInfoClazz = (*jniEnv)->FindClass(jniEnv, mapInfoClassName);
jmethodID mapInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,mapInfoClazz,"<init>","(ILai/onnxruntime/OnnxJavaType;Lai/onnxruntime/OnnxJavaType;)V");
// Extract the key type
ONNXTensorElementDataType keyType;
checkOrtStatus(jniEnv,api,api->GetMapKeyType(info,&keyType));
// Convert key type to java
jint onnxTypeKey = convertFromONNXDataFormat(keyType);
jobject onnxTensorTypeJavaKey = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeKey);
jobject onnxJavaTypeKey = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJavaKey);
// according to include/onnxruntime/core/framework/data_types.h only the following values are supported.
// string, int64, float, double
// So extract the value type, then convert it to a tensor type so we can get it's element type.
OrtTypeInfo* valueTypeInfo;
checkOrtStatus(jniEnv,api,api->GetMapValueType(info,&valueTypeInfo));
const OrtTensorTypeAndShapeInfo* tensorValueInfo;
checkOrtStatus(jniEnv,api,api->CastTypeInfoToTensorInfo(valueTypeInfo,&tensorValueInfo));
ONNXTensorElementDataType valueType;
checkOrtStatus(jniEnv,api,api->GetTensorElementType(tensorValueInfo,&valueType));
api->ReleaseTypeInfo(valueTypeInfo);
tensorValueInfo = NULL;
valueTypeInfo = NULL;
// Convert value type to java
jint onnxTypeValue = convertFromONNXDataFormat(valueType);
jobject onnxTensorTypeJavaValue = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeValue);
jobject onnxJavaTypeValue = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJavaValue);
// Construct map info
jobject mapInfo = (*jniEnv)->NewObject(jniEnv,mapInfoClazz,mapInfoConstructor,(jint)-1,onnxJavaTypeKey,onnxJavaTypeValue);
return mapInfo;
}
jobject createEmptyMapInfo(JNIEnv *jniEnv) {
// Create the ONNXJavaType enum
char *onnxJavaTypeClassName = "ai/onnxruntime/OnnxJavaType";
@ -278,8 +330,72 @@ jobject createEmptyMapInfo(JNIEnv *jniEnv) {
return mapInfo;
}
//jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo * info) {
// As sequence info isn't available at this point, it creates an empty sequence info type.
jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSequenceTypeInfo * info) {
// Get the sequence info class
char *sequenceInfoClassName = "ai/onnxruntime/SequenceInfo";
jclass sequenceInfoClazz = (*jniEnv)->FindClass(jniEnv, sequenceInfoClassName);
// according to include/onnxruntime/core/framework/data_types.h the following values are supported.
// tensor types, map<string,float> and map<long,float>
OrtTypeInfo* elementTypeInfo;
checkOrtStatus(jniEnv,api,api->GetSequenceElementType(info,&elementTypeInfo));
ONNXType type;
checkOrtStatus(jniEnv,api,api->GetOnnxTypeFromTypeInfo(elementTypeInfo,&type));
jobject sequenceInfo;
switch (type) {
case ONNX_TYPE_TENSOR: {
// Figure out element type
const OrtTensorTypeAndShapeInfo* elementTensorInfo;
checkOrtStatus(jniEnv,api,api->CastTypeInfoToTensorInfo(elementTypeInfo,&elementTensorInfo));
ONNXTensorElementDataType element;
checkOrtStatus(jniEnv,api,api->GetTensorElementType(elementTensorInfo,&element));
// Convert element type into ONNXTensorType
jint onnxTypeInt = convertFromONNXDataFormat(element);
// Get the ONNXTensorType enum static method
char *onnxTensorTypeClassName = "ai/onnxruntime/TensorInfo$OnnxTensorType";
jclass onnxTensorTypeClazz = (*jniEnv)->FindClass(jniEnv, onnxTensorTypeClassName);
jmethodID onnxTensorTypeMapFromInt = (*jniEnv)->GetStaticMethodID(jniEnv,onnxTensorTypeClazz, "mapFromInt", "(I)Lai/onnxruntime/TensorInfo$OnnxTensorType;");
jobject onnxTensorTypeJava = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeInt);
// Get the ONNXJavaType enum static method
char *javaDataTypeClassName = "ai/onnxruntime/OnnxJavaType";
jclass onnxJavaTypeClazz = (*jniEnv)->FindClass(jniEnv, javaDataTypeClassName);
jmethodID onnxJavaTypeMapFromONNXTensorType = (*jniEnv)->GetStaticMethodID(jniEnv,onnxJavaTypeClazz, "mapFromOnnxTensorType", "(Lai/onnxruntime/TensorInfo$OnnxTensorType;)Lai/onnxruntime/OnnxJavaType;");
jobject onnxJavaType = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJava);
// Construct sequence info
jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,sequenceInfoClazz,"<init>","(ILai/onnxruntime/OnnxJavaType;)V");
sequenceInfo = (*jniEnv)->NewObject(jniEnv,sequenceInfoClazz,sequenceInfoConstructor,(jint)-1,onnxJavaType);
break;
}
case ONNX_TYPE_MAP: {
// Extract the map info
const OrtMapTypeInfo* mapInfo;
checkOrtStatus(jniEnv,api,api->CastTypeInfoToMapTypeInfo(elementTypeInfo,&mapInfo));
// Convert it using the existing convert function
jobject javaMapInfo = convertToMapInfo(jniEnv,api,mapInfo);
// Construct sequence info
jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,sequenceInfoClazz,"<init>","(ILai/onnxruntime/MapInfo;)V");
sequenceInfo = (*jniEnv)->NewObject(jniEnv,sequenceInfoClazz,sequenceInfoConstructor,(jint)-1,javaMapInfo);
break;
}
default: {
sequenceInfo = createEmptySequenceInfo(jniEnv);
throwOrtException(jniEnv,convertErrorCode(ORT_INVALID_ARGUMENT),"Invalid element type found in sequence");
break;
}
}
api->ReleaseTypeInfo(elementTypeInfo);
elementTypeInfo = NULL;
return sequenceInfo;
}
jobject createEmptySequenceInfo(JNIEnv *jniEnv) {
// Create the ONNXJavaType enum
char *onnxJavaTypeClassName = "ai/onnxruntime/OnnxJavaType";

View file

@ -31,9 +31,8 @@ jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, OrtTypeInfo * inf
jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorTypeAndShapeInfo * info);
//TODO when C API supports inspecting the types of map and sequence types from OutputInfos
//jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo * info);
//jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo * info);
jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtMapTypeInfo * info);
jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSequenceTypeInfo * info);
jobject createEmptyMapInfo(JNIEnv *jniEnv);
jobject createEmptySequenceInfo(JNIEnv *jniEnv);

View file

@ -14,8 +14,8 @@
* Signature: (JJLjava/lang/String;J)J
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_lang_String_2J
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong envHandle, jstring modelPath, jlong optsHandle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
(JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jstring modelPath, jlong optsHandle) {
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtSession* session;
@ -43,8 +43,8 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_la
* Signature: (JJ[BJ)J
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJ_3BJ
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong envHandle, jbyteArray jModelArray, jlong optsHandle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
(JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jbyteArray jModelArray, jlong optsHandle) {
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtSession* session;

View file

@ -915,6 +915,12 @@ public class InferenceTest {
assertTrue(firstOutputInfo.getInfo() instanceof TensorInfo);
assertTrue(secondOutputInfo.getInfo() instanceof SequenceInfo);
assertEquals(OnnxJavaType.INT64, ((TensorInfo) firstOutputInfo.getInfo()).type);
assertTrue(((SequenceInfo) secondOutputInfo.getInfo()).sequenceOfMaps);
assertEquals(OnnxJavaType.UNKNOWN, ((SequenceInfo) secondOutputInfo.getInfo()).sequenceType);
MapInfo mapInfo = ((SequenceInfo) secondOutputInfo.getInfo()).mapInfo;
assertNotNull(mapInfo);
assertEquals(OnnxJavaType.INT64, mapInfo.keyType);
assertEquals(OnnxJavaType.FLOAT, mapInfo.valueType);
Map<String, OnnxTensor> container = new HashMap<>();
long[] shape = new long[] {1, 2};
@ -975,6 +981,12 @@ public class InferenceTest {
assertTrue(firstOutputInfo.getInfo() instanceof TensorInfo);
assertTrue(secondOutputInfo.getInfo() instanceof SequenceInfo);
assertEquals(OnnxJavaType.STRING, ((TensorInfo) firstOutputInfo.getInfo()).type);
assertTrue(((SequenceInfo) secondOutputInfo.getInfo()).sequenceOfMaps);
assertEquals(OnnxJavaType.UNKNOWN, ((SequenceInfo) secondOutputInfo.getInfo()).sequenceType);
MapInfo mapInfo = ((SequenceInfo) secondOutputInfo.getInfo()).mapInfo;
assertNotNull(mapInfo);
assertEquals(OnnxJavaType.STRING, mapInfo.keyType);
assertEquals(OnnxJavaType.FLOAT, mapInfo.valueType);
Map<String, OnnxTensor> container = new HashMap<>();
long[] shape = new long[] {1, 2};