mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
[Java] Add support for map and sequence information on output nodes (#3468)
This commit is contained in:
parent
7c89f38a34
commit
c91527235a
8 changed files with 175 additions and 35 deletions
|
|
@ -1,5 +1,6 @@
|
|||
plugins {
|
||||
id 'java'
|
||||
id 'jacoco'
|
||||
id 'com.diffplug.gradle.spotless' version '3.26.0'
|
||||
}
|
||||
|
||||
|
|
@ -105,7 +106,6 @@ if (cmakeBuildDir != null) {
|
|||
|
||||
}
|
||||
|
||||
|
||||
dependencies {
|
||||
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.1.1'
|
||||
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.1.1'
|
||||
|
|
@ -118,3 +118,11 @@ test {
|
|||
events "passed", "skipped", "failed"
|
||||
}
|
||||
}
|
||||
|
||||
jacocoTestReport {
|
||||
reports {
|
||||
xml.enabled true
|
||||
csv.enabled true
|
||||
html.destination file("${buildDir}/jacocoHtml")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ final class OnnxRuntime {
|
|||
|
||||
// The initial release of the ORT API.
|
||||
private static final int ORT_API_VERSION_1 = 1;
|
||||
// Post 1.0 builds of the ORT API.
|
||||
private static final int ORT_API_VERSION_2 = 2;
|
||||
|
||||
/** The short name of the ONNX runtime shared library */
|
||||
static final String ONNXRUNTIME_LIBRARY_NAME = "onnxruntime";
|
||||
|
|
@ -48,7 +50,7 @@ final class OnnxRuntime {
|
|||
try {
|
||||
load(tempDirectory, ONNXRUNTIME_LIBRARY_NAME);
|
||||
load(tempDirectory, ONNXRUNTIME_JNI_LIBRARY_NAME);
|
||||
ortApiHandle = initialiseAPIBase(ORT_API_VERSION_1);
|
||||
ortApiHandle = initialiseAPIBase(ORT_API_VERSION_2);
|
||||
loaded = true;
|
||||
} finally {
|
||||
if (!isAndroid()) {
|
||||
|
|
|
|||
|
|
@ -60,18 +60,9 @@ public class OrtSession implements AutoCloseable {
|
|||
*/
|
||||
OrtSession(OrtEnvironment env, String modelPath, OrtAllocator allocator, SessionOptions options)
|
||||
throws OrtException {
|
||||
nativeHandle =
|
||||
createSession(OnnxRuntime.ortApiHandle, env.nativeHandle, modelPath, options.nativeHandle);
|
||||
this.allocator = allocator;
|
||||
numInputs = getNumInputs(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
inputNames =
|
||||
new LinkedHashSet<>(
|
||||
Arrays.asList(getInputNames(OnnxRuntime.ortApiHandle, nativeHandle, allocator.handle)));
|
||||
numOutputs = getNumOutputs(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
outputNames =
|
||||
new LinkedHashSet<>(
|
||||
Arrays.asList(
|
||||
getOutputNames(OnnxRuntime.ortApiHandle, nativeHandle, allocator.handle)));
|
||||
this(
|
||||
createSession(OnnxRuntime.ortApiHandle, env.nativeHandle, modelPath, options.nativeHandle),
|
||||
allocator);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -81,12 +72,24 @@ public class OrtSession implements AutoCloseable {
|
|||
* @param modelArray The model protobuf as a byte array.
|
||||
* @param allocator The allocator to use.
|
||||
* @param options Session configuration options.
|
||||
* @throws OrtException If the mode was corrupted or some other error occurred in native code.
|
||||
* @throws OrtException If the model was corrupted or some other error occurred in native code.
|
||||
*/
|
||||
OrtSession(OrtEnvironment env, byte[] modelArray, OrtAllocator allocator, SessionOptions options)
|
||||
throws OrtException {
|
||||
nativeHandle =
|
||||
createSession(OnnxRuntime.ortApiHandle, env.nativeHandle, modelArray, options.nativeHandle);
|
||||
this(
|
||||
createSession(OnnxRuntime.ortApiHandle, env.nativeHandle, modelArray, options.nativeHandle),
|
||||
allocator);
|
||||
}
|
||||
|
||||
/**
|
||||
* Private constructor to build the Java object wrapped around a native session.
|
||||
*
|
||||
* @param nativeHandle The pointer to the native session.
|
||||
* @param allocator The allocator to use.
|
||||
* @throws OrtException If the model's inputs, outputs or metadata could not be read.
|
||||
*/
|
||||
private OrtSession(long nativeHandle, OrtAllocator allocator) throws OrtException {
|
||||
this.nativeHandle = nativeHandle;
|
||||
this.allocator = allocator;
|
||||
numInputs = getNumInputs(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
inputNames =
|
||||
|
|
@ -289,17 +292,17 @@ public class OrtSession implements AutoCloseable {
|
|||
private static Map<String, NodeInfo> wrapInMap(NodeInfo[] infos) {
|
||||
Map<String, NodeInfo> output = new LinkedHashMap<>();
|
||||
|
||||
for (int i = 0; i < infos.length; i++) {
|
||||
output.put(infos[i].getName(), infos[i]);
|
||||
for (NodeInfo info : infos) {
|
||||
output.put(info.getName(), info);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
private native long createSession(
|
||||
private static native long createSession(
|
||||
long apiHandle, long envHandle, String modelPath, long optsHandle) throws OrtException;
|
||||
|
||||
private native long createSession(
|
||||
private static native long createSession(
|
||||
long apiHandle, long envHandle, byte[] modelArray, long optsHandle) throws OrtException;
|
||||
|
||||
private native long getNumInputs(long apiHandle, long nativeHandle) throws OrtException;
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ public class SequenceInfo implements ValueInfo {
|
|||
}
|
||||
|
||||
/**
|
||||
* Constructs a sequence of known lenght containing maps.
|
||||
* Constructs a sequence of known length containing maps.
|
||||
*
|
||||
* @param length The length of the sequence.
|
||||
* @param keyType The map key type.
|
||||
|
|
|
|||
|
|
@ -205,10 +205,14 @@ jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, OrtTypeInfo * inf
|
|||
return convertToTensorInfo(jniEnv, api, (const OrtTensorTypeAndShapeInfo *) tensorInfo);
|
||||
}
|
||||
case ONNX_TYPE_SEQUENCE: {
|
||||
return createEmptySequenceInfo(jniEnv);
|
||||
const OrtSequenceTypeInfo* sequenceInfo;
|
||||
checkOrtStatus(jniEnv,api,api->CastTypeInfoToSequenceTypeInfo(info,&sequenceInfo));
|
||||
return convertToSequenceInfo(jniEnv, api, sequenceInfo);
|
||||
}
|
||||
case ONNX_TYPE_MAP: {
|
||||
return createEmptyMapInfo(jniEnv);
|
||||
const OrtMapTypeInfo* mapInfo;
|
||||
checkOrtStatus(jniEnv,api,api->CastTypeInfoToMapTypeInfo(info,&mapInfo));
|
||||
return convertToMapInfo(jniEnv, api, mapInfo);
|
||||
}
|
||||
case ONNX_TYPE_UNKNOWN:
|
||||
case ONNX_TYPE_OPAQUE:
|
||||
|
|
@ -261,8 +265,56 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT
|
|||
return tensorInfo;
|
||||
}
|
||||
|
||||
//jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo * info) {
|
||||
// As map info isn't available at this point, it creates an empty map info type.
|
||||
jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtMapTypeInfo * info) {
|
||||
// Create the java methods we need to call.
|
||||
// Get the ONNXTensorType enum static method
|
||||
char *onnxTensorTypeClassName = "ai/onnxruntime/TensorInfo$OnnxTensorType";
|
||||
jclass onnxTensorTypeClazz = (*jniEnv)->FindClass(jniEnv, onnxTensorTypeClassName);
|
||||
jmethodID onnxTensorTypeMapFromInt = (*jniEnv)->GetStaticMethodID(jniEnv,onnxTensorTypeClazz, "mapFromInt", "(I)Lai/onnxruntime/TensorInfo$OnnxTensorType;");
|
||||
|
||||
// Get the ONNXJavaType enum static method
|
||||
char *javaDataTypeClassName = "ai/onnxruntime/OnnxJavaType";
|
||||
jclass onnxJavaTypeClazz = (*jniEnv)->FindClass(jniEnv, javaDataTypeClassName);
|
||||
jmethodID onnxJavaTypeMapFromONNXTensorType = (*jniEnv)->GetStaticMethodID(jniEnv,onnxJavaTypeClazz, "mapFromOnnxTensorType", "(Lai/onnxruntime/TensorInfo$OnnxTensorType;)Lai/onnxruntime/OnnxJavaType;");
|
||||
|
||||
// Get the map info class
|
||||
char *mapInfoClassName = "ai/onnxruntime/MapInfo";
|
||||
jclass mapInfoClazz = (*jniEnv)->FindClass(jniEnv, mapInfoClassName);
|
||||
jmethodID mapInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,mapInfoClazz,"<init>","(ILai/onnxruntime/OnnxJavaType;Lai/onnxruntime/OnnxJavaType;)V");
|
||||
|
||||
// Extract the key type
|
||||
ONNXTensorElementDataType keyType;
|
||||
checkOrtStatus(jniEnv,api,api->GetMapKeyType(info,&keyType));
|
||||
|
||||
// Convert key type to java
|
||||
jint onnxTypeKey = convertFromONNXDataFormat(keyType);
|
||||
jobject onnxTensorTypeJavaKey = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeKey);
|
||||
jobject onnxJavaTypeKey = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJavaKey);
|
||||
|
||||
// according to include/onnxruntime/core/framework/data_types.h only the following values are supported.
|
||||
// string, int64, float, double
|
||||
// So extract the value type, then convert it to a tensor type so we can get it's element type.
|
||||
OrtTypeInfo* valueTypeInfo;
|
||||
checkOrtStatus(jniEnv,api,api->GetMapValueType(info,&valueTypeInfo));
|
||||
const OrtTensorTypeAndShapeInfo* tensorValueInfo;
|
||||
checkOrtStatus(jniEnv,api,api->CastTypeInfoToTensorInfo(valueTypeInfo,&tensorValueInfo));
|
||||
ONNXTensorElementDataType valueType;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorElementType(tensorValueInfo,&valueType));
|
||||
api->ReleaseTypeInfo(valueTypeInfo);
|
||||
tensorValueInfo = NULL;
|
||||
valueTypeInfo = NULL;
|
||||
|
||||
// Convert value type to java
|
||||
jint onnxTypeValue = convertFromONNXDataFormat(valueType);
|
||||
jobject onnxTensorTypeJavaValue = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeValue);
|
||||
jobject onnxJavaTypeValue = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJavaValue);
|
||||
|
||||
// Construct map info
|
||||
jobject mapInfo = (*jniEnv)->NewObject(jniEnv,mapInfoClazz,mapInfoConstructor,(jint)-1,onnxJavaTypeKey,onnxJavaTypeValue);
|
||||
|
||||
return mapInfo;
|
||||
}
|
||||
|
||||
jobject createEmptyMapInfo(JNIEnv *jniEnv) {
|
||||
// Create the ONNXJavaType enum
|
||||
char *onnxJavaTypeClassName = "ai/onnxruntime/OnnxJavaType";
|
||||
|
|
@ -278,8 +330,72 @@ jobject createEmptyMapInfo(JNIEnv *jniEnv) {
|
|||
return mapInfo;
|
||||
}
|
||||
|
||||
//jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo * info) {
|
||||
// As sequence info isn't available at this point, it creates an empty sequence info type.
|
||||
jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSequenceTypeInfo * info) {
|
||||
// Get the sequence info class
|
||||
char *sequenceInfoClassName = "ai/onnxruntime/SequenceInfo";
|
||||
jclass sequenceInfoClazz = (*jniEnv)->FindClass(jniEnv, sequenceInfoClassName);
|
||||
|
||||
// according to include/onnxruntime/core/framework/data_types.h the following values are supported.
|
||||
// tensor types, map<string,float> and map<long,float>
|
||||
OrtTypeInfo* elementTypeInfo;
|
||||
checkOrtStatus(jniEnv,api,api->GetSequenceElementType(info,&elementTypeInfo));
|
||||
ONNXType type;
|
||||
checkOrtStatus(jniEnv,api,api->GetOnnxTypeFromTypeInfo(elementTypeInfo,&type));
|
||||
|
||||
jobject sequenceInfo;
|
||||
|
||||
switch (type) {
|
||||
case ONNX_TYPE_TENSOR: {
|
||||
// Figure out element type
|
||||
const OrtTensorTypeAndShapeInfo* elementTensorInfo;
|
||||
checkOrtStatus(jniEnv,api,api->CastTypeInfoToTensorInfo(elementTypeInfo,&elementTensorInfo));
|
||||
ONNXTensorElementDataType element;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorElementType(elementTensorInfo,&element));
|
||||
|
||||
// Convert element type into ONNXTensorType
|
||||
jint onnxTypeInt = convertFromONNXDataFormat(element);
|
||||
// Get the ONNXTensorType enum static method
|
||||
char *onnxTensorTypeClassName = "ai/onnxruntime/TensorInfo$OnnxTensorType";
|
||||
jclass onnxTensorTypeClazz = (*jniEnv)->FindClass(jniEnv, onnxTensorTypeClassName);
|
||||
jmethodID onnxTensorTypeMapFromInt = (*jniEnv)->GetStaticMethodID(jniEnv,onnxTensorTypeClazz, "mapFromInt", "(I)Lai/onnxruntime/TensorInfo$OnnxTensorType;");
|
||||
jobject onnxTensorTypeJava = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeInt);
|
||||
|
||||
// Get the ONNXJavaType enum static method
|
||||
char *javaDataTypeClassName = "ai/onnxruntime/OnnxJavaType";
|
||||
jclass onnxJavaTypeClazz = (*jniEnv)->FindClass(jniEnv, javaDataTypeClassName);
|
||||
jmethodID onnxJavaTypeMapFromONNXTensorType = (*jniEnv)->GetStaticMethodID(jniEnv,onnxJavaTypeClazz, "mapFromOnnxTensorType", "(Lai/onnxruntime/TensorInfo$OnnxTensorType;)Lai/onnxruntime/OnnxJavaType;");
|
||||
jobject onnxJavaType = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJava);
|
||||
|
||||
// Construct sequence info
|
||||
jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,sequenceInfoClazz,"<init>","(ILai/onnxruntime/OnnxJavaType;)V");
|
||||
sequenceInfo = (*jniEnv)->NewObject(jniEnv,sequenceInfoClazz,sequenceInfoConstructor,(jint)-1,onnxJavaType);
|
||||
break;
|
||||
}
|
||||
case ONNX_TYPE_MAP: {
|
||||
// Extract the map info
|
||||
const OrtMapTypeInfo* mapInfo;
|
||||
checkOrtStatus(jniEnv,api,api->CastTypeInfoToMapTypeInfo(elementTypeInfo,&mapInfo));
|
||||
|
||||
// Convert it using the existing convert function
|
||||
jobject javaMapInfo = convertToMapInfo(jniEnv,api,mapInfo);
|
||||
|
||||
// Construct sequence info
|
||||
jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,sequenceInfoClazz,"<init>","(ILai/onnxruntime/MapInfo;)V");
|
||||
sequenceInfo = (*jniEnv)->NewObject(jniEnv,sequenceInfoClazz,sequenceInfoConstructor,(jint)-1,javaMapInfo);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
sequenceInfo = createEmptySequenceInfo(jniEnv);
|
||||
throwOrtException(jniEnv,convertErrorCode(ORT_INVALID_ARGUMENT),"Invalid element type found in sequence");
|
||||
break;
|
||||
}
|
||||
}
|
||||
api->ReleaseTypeInfo(elementTypeInfo);
|
||||
elementTypeInfo = NULL;
|
||||
|
||||
return sequenceInfo;
|
||||
}
|
||||
|
||||
jobject createEmptySequenceInfo(JNIEnv *jniEnv) {
|
||||
// Create the ONNXJavaType enum
|
||||
char *onnxJavaTypeClassName = "ai/onnxruntime/OnnxJavaType";
|
||||
|
|
|
|||
|
|
@ -31,9 +31,8 @@ jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, OrtTypeInfo * inf
|
|||
|
||||
jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorTypeAndShapeInfo * info);
|
||||
|
||||
//TODO when C API supports inspecting the types of map and sequence types from OutputInfos
|
||||
//jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo * info);
|
||||
//jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo * info);
|
||||
jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtMapTypeInfo * info);
|
||||
jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSequenceTypeInfo * info);
|
||||
|
||||
jobject createEmptyMapInfo(JNIEnv *jniEnv);
|
||||
jobject createEmptySequenceInfo(JNIEnv *jniEnv);
|
||||
|
|
|
|||
|
|
@ -14,8 +14,8 @@
|
|||
* Signature: (JJLjava/lang/String;J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_lang_String_2J
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong envHandle, jstring modelPath, jlong optsHandle) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
(JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jstring modelPath, jlong optsHandle) {
|
||||
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtSession* session;
|
||||
|
||||
|
|
@ -43,8 +43,8 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_la
|
|||
* Signature: (JJ[BJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJ_3BJ
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong envHandle, jbyteArray jModelArray, jlong optsHandle) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
(JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jbyteArray jModelArray, jlong optsHandle) {
|
||||
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtSession* session;
|
||||
|
||||
|
|
|
|||
|
|
@ -915,6 +915,12 @@ public class InferenceTest {
|
|||
assertTrue(firstOutputInfo.getInfo() instanceof TensorInfo);
|
||||
assertTrue(secondOutputInfo.getInfo() instanceof SequenceInfo);
|
||||
assertEquals(OnnxJavaType.INT64, ((TensorInfo) firstOutputInfo.getInfo()).type);
|
||||
assertTrue(((SequenceInfo) secondOutputInfo.getInfo()).sequenceOfMaps);
|
||||
assertEquals(OnnxJavaType.UNKNOWN, ((SequenceInfo) secondOutputInfo.getInfo()).sequenceType);
|
||||
MapInfo mapInfo = ((SequenceInfo) secondOutputInfo.getInfo()).mapInfo;
|
||||
assertNotNull(mapInfo);
|
||||
assertEquals(OnnxJavaType.INT64, mapInfo.keyType);
|
||||
assertEquals(OnnxJavaType.FLOAT, mapInfo.valueType);
|
||||
|
||||
Map<String, OnnxTensor> container = new HashMap<>();
|
||||
long[] shape = new long[] {1, 2};
|
||||
|
|
@ -975,6 +981,12 @@ public class InferenceTest {
|
|||
assertTrue(firstOutputInfo.getInfo() instanceof TensorInfo);
|
||||
assertTrue(secondOutputInfo.getInfo() instanceof SequenceInfo);
|
||||
assertEquals(OnnxJavaType.STRING, ((TensorInfo) firstOutputInfo.getInfo()).type);
|
||||
assertTrue(((SequenceInfo) secondOutputInfo.getInfo()).sequenceOfMaps);
|
||||
assertEquals(OnnxJavaType.UNKNOWN, ((SequenceInfo) secondOutputInfo.getInfo()).sequenceType);
|
||||
MapInfo mapInfo = ((SequenceInfo) secondOutputInfo.getInfo()).mapInfo;
|
||||
assertNotNull(mapInfo);
|
||||
assertEquals(OnnxJavaType.STRING, mapInfo.keyType);
|
||||
assertEquals(OnnxJavaType.FLOAT, mapInfo.valueType);
|
||||
|
||||
Map<String, OnnxTensor> container = new HashMap<>();
|
||||
long[] shape = new long[] {1, 2};
|
||||
|
|
|
|||
Loading…
Reference in a new issue