mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
[Java] Fix OnnxSequence semantics (#13012)
Previously OnnxSequence would flatten out a list of tensors into a single output array assuming they were all scalar values. This doesn't accurately represent the semantics of an ONNX sequence, but was what the semantics appeared to be years ago when I first wrote that class. This PR changes it so that the `getValue` method on `OnnxSequence` unwraps the sequence and returns `List<? extends OnnxValue>` allowing the user to process the individual ONNX values separately. It's done this way rather than returning a multidimensional array for a tensor and a Java map for a map as multidimensional arrays are very inefficient in Java and best practice when operating with a OnnxTensor in Java is to use a `java.nio.ByteBuffer`. So allowing users to access each `OnnxTensor`s individually allows them to control how the data is materialised on the Java heap.
This commit is contained in:
parent
55ae71c160
commit
388d3cf847
4 changed files with 144 additions and 491 deletions
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
|
@ -137,7 +137,7 @@ public class OnnxMap implements OnnxValue {
|
|||
* @throws OrtException If the onnx runtime failed to read the entries.
|
||||
*/
|
||||
@Override
|
||||
public Map<Object, Object> getValue() throws OrtException {
|
||||
public Map<? extends Object, ? extends Object> getValue() throws OrtException {
|
||||
Object[] keys = getMapKeys();
|
||||
Object[] values = getMapValues();
|
||||
HashMap<Object, Object> map = new HashMap<>(OrtUtil.capacityFromSize(keys.length));
|
||||
|
|
|
|||
|
|
@ -1,21 +1,27 @@
|
|||
/*
|
||||
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* A sequence of {@link OnnxValue}s all of the same type.
|
||||
*
|
||||
* <p>Supports the types mentioned in "onnxruntime_c_api.h", currently String, Long, Float, Double,
|
||||
* Map>String,Float<, Map>Long,Float<.
|
||||
* <p>Supports the types mentioned in "onnxruntime_c_api.h", currently
|
||||
*
|
||||
* <ul>
|
||||
* <li>OnnxTensor<String>
|
||||
* <li>OnnxTensor<Long>
|
||||
* <li>OnnxTensor<Float>
|
||||
* <li>OnnxTensor<Double>
|
||||
* <li>OnnxMap<String,Float>
|
||||
* <li>OnnxMap<Long,Float>
|
||||
* </ul>
|
||||
*/
|
||||
public class OnnxSequence implements OnnxValue {
|
||||
|
||||
|
|
@ -54,51 +60,33 @@ public class OnnxSequence implements OnnxValue {
|
|||
}
|
||||
|
||||
/**
|
||||
* Extracts a Java object from the native ONNX type.
|
||||
* Extracts a Java list of the {@link OnnxValue}s which can then be further unwrapped.
|
||||
*
|
||||
* <p>Returns either a {@link List} of boxed primitives, {@link String}s, or {@link
|
||||
* java.util.Map}s.
|
||||
* <p>Returns either a {@link List} of either {@link OnnxTensor} or {@link OnnxMap}.
|
||||
*
|
||||
* @return A Java object containing the value.
|
||||
* <p>Note unlike the other {@link OnnxValue#getValue()} methods, this does not copy the values
|
||||
* themselves into the Java heap, it merely exposes them as {@link OnnxValue} instances, allowing
|
||||
* users to use the faster copy methods available for {@link OnnxTensor}. This also means that
|
||||
* those values need to be closed separately from this instance, and are not closed by {@link
|
||||
* #close} on this object.
|
||||
*
|
||||
* @return A Java list containing the values.
|
||||
* @throws OrtException If the runtime failed to read an element.
|
||||
*/
|
||||
@Override
|
||||
public List<Object> getValue() throws OrtException {
|
||||
public List<? extends OnnxValue> getValue() throws OrtException {
|
||||
if (info.sequenceOfMaps) {
|
||||
List<Object> outputSequence = new ArrayList<>(info.length);
|
||||
for (int i = 0; i < info.length; i++) {
|
||||
Object[] keys = getMapKeys(i);
|
||||
Object[] values = getMapValues(i);
|
||||
HashMap<Object, Object> map = new HashMap<>(OrtUtil.capacityFromSize(keys.length));
|
||||
for (int j = 0; j < keys.length; j++) {
|
||||
map.put(keys[j], values[j]);
|
||||
}
|
||||
outputSequence.add(map);
|
||||
}
|
||||
return outputSequence;
|
||||
OnnxMap[] maps = getMaps(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle);
|
||||
return Collections.unmodifiableList(Arrays.asList(maps));
|
||||
} else {
|
||||
switch (info.sequenceType) {
|
||||
case FLOAT:
|
||||
float[] floats = getFloats(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle);
|
||||
ArrayList<Object> boxed = new ArrayList<>(floats.length);
|
||||
for (float aFloat : floats) {
|
||||
// box float to Float
|
||||
boxed.add(aFloat);
|
||||
}
|
||||
return boxed;
|
||||
case DOUBLE:
|
||||
return Arrays.stream(getDoubles(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle))
|
||||
.boxed()
|
||||
.collect(Collectors.toList());
|
||||
case INT64:
|
||||
return Arrays.stream(getLongs(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle))
|
||||
.boxed()
|
||||
.collect(Collectors.toList());
|
||||
case STRING:
|
||||
String[] strings = getStrings(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle);
|
||||
ArrayList<Object> list = new ArrayList<>(strings.length);
|
||||
list.addAll(Arrays.asList(strings));
|
||||
return list;
|
||||
case INT64:
|
||||
case FLOAT:
|
||||
case DOUBLE:
|
||||
OnnxTensor[] tensors =
|
||||
getTensors(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle);
|
||||
return Collections.unmodifiableList(Arrays.asList(tensors));
|
||||
case BOOL:
|
||||
case UINT8:
|
||||
case INT8:
|
||||
|
|
@ -127,95 +115,10 @@ public class OnnxSequence implements OnnxValue {
|
|||
close(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract the keys for the map at the specified index.
|
||||
*
|
||||
* @param index The index to extract.
|
||||
* @return The map keys as an array.
|
||||
* @throws OrtException If the native code failed to read the keys.
|
||||
*/
|
||||
private Object[] getMapKeys(int index) throws OrtException {
|
||||
if (info.mapInfo.keyType == OnnxJavaType.STRING) {
|
||||
return getStringKeys(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, index);
|
||||
} else {
|
||||
return Arrays.stream(
|
||||
getLongKeys(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, index))
|
||||
.boxed()
|
||||
.toArray();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract the values for the map at the specified index.
|
||||
*
|
||||
* @param index The index to extract.
|
||||
* @return The map values as an array.
|
||||
* @throws OrtException If the native code failed to read the values.
|
||||
*/
|
||||
private Object[] getMapValues(int index) throws OrtException {
|
||||
switch (info.mapInfo.valueType) {
|
||||
case STRING:
|
||||
{
|
||||
return getStringValues(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, index);
|
||||
}
|
||||
case INT64:
|
||||
{
|
||||
return Arrays.stream(
|
||||
getLongValues(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, index))
|
||||
.boxed()
|
||||
.toArray();
|
||||
}
|
||||
case FLOAT:
|
||||
{
|
||||
float[] floats =
|
||||
getFloatValues(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, index);
|
||||
Float[] boxed = new Float[floats.length];
|
||||
for (int i = 0; i < floats.length; i++) {
|
||||
// cast float to Float
|
||||
boxed[i] = floats[i];
|
||||
}
|
||||
return boxed;
|
||||
}
|
||||
case DOUBLE:
|
||||
{
|
||||
return Arrays.stream(
|
||||
getDoubleValues(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, index))
|
||||
.boxed()
|
||||
.toArray();
|
||||
}
|
||||
default:
|
||||
throw new RuntimeException("Invalid or unknown valueType: " + info.mapInfo.valueType);
|
||||
}
|
||||
}
|
||||
|
||||
private native String[] getStringKeys(
|
||||
long apiHandle, long nativeHandle, long allocatorHandle, int index) throws OrtException;
|
||||
|
||||
private native long[] getLongKeys(
|
||||
long apiHandle, long nativeHandle, long allocatorHandle, int index) throws OrtException;
|
||||
|
||||
private native String[] getStringValues(
|
||||
long apiHandle, long nativeHandle, long allocatorHandle, int index) throws OrtException;
|
||||
|
||||
private native long[] getLongValues(
|
||||
long apiHandle, long nativeHandle, long allocatorHandle, int index) throws OrtException;
|
||||
|
||||
private native float[] getFloatValues(
|
||||
long apiHandle, long nativeHandle, long allocatorHandle, int index) throws OrtException;
|
||||
|
||||
private native double[] getDoubleValues(
|
||||
long apiHandle, long nativeHandle, long allocatorHandle, int index) throws OrtException;
|
||||
|
||||
private native String[] getStrings(long apiHandle, long nativeHandle, long allocatorHandle)
|
||||
private native OnnxMap[] getMaps(long apiHandle, long nativeHandle, long allocatorHandle)
|
||||
throws OrtException;
|
||||
|
||||
private native long[] getLongs(long apiHandle, long nativeHandle, long allocatorHandle)
|
||||
throws OrtException;
|
||||
|
||||
private native float[] getFloats(long apiHandle, long nativeHandle, long allocatorHandle)
|
||||
throws OrtException;
|
||||
|
||||
private native double[] getDoubles(long apiHandle, long nativeHandle, long allocatorHandle)
|
||||
private native OnnxTensor[] getTensors(long apiHandle, long nativeHandle, long allocatorHandle)
|
||||
throws OrtException;
|
||||
|
||||
private native void close(long apiHandle, long nativeHandle);
|
||||
|
|
|
|||
|
|
@ -6,210 +6,55 @@
|
|||
#include "onnxruntime/core/session/onnxruntime_c_api.h"
|
||||
#include "OrtJniUtil.h"
|
||||
#include "ai_onnxruntime_OnnxSequence.h"
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OnnxSequence
|
||||
* Method: getStringKeys
|
||||
* Signature: (JJI)[Ljava/lang/String;
|
||||
* Method: getMaps
|
||||
* Signature: (JJJ)[Lai/onnxruntime/OnnxMap;
|
||||
*/
|
||||
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStringKeys(JNIEnv* jniEnv, jobject jobj,
|
||||
jlong apiHandle, jlong handle,
|
||||
jlong allocatorHandle, jint index) {
|
||||
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getMaps(JNIEnv* jniEnv, jobject jobj,
|
||||
jlong apiHandle, jlong handle,
|
||||
jlong allocatorHandle) {
|
||||
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*)apiHandle;
|
||||
OrtValue* sequence = (OrtValue*)handle;
|
||||
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
|
||||
jobjectArray output = NULL;
|
||||
// Extract element
|
||||
OrtValue* element;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, index, allocator, &element));
|
||||
|
||||
jobjectArray outputArray = NULL;
|
||||
|
||||
// Get the element count of this sequence
|
||||
size_t count;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValueCount(sequence, &count));
|
||||
if (code == ORT_OK) {
|
||||
// Extract keys from element
|
||||
OrtValue* keys;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetValue(element, 0, allocator, &keys));
|
||||
|
||||
if (code == ORT_OK) {
|
||||
// Convert to Java String array
|
||||
output = createStringArrayFromTensor(jniEnv, api, keys);
|
||||
// Release if valid
|
||||
api->ReleaseValue(element);
|
||||
jclass tensorClazz = (*jniEnv)->FindClass(jniEnv, "ai/onnxruntime/OnnxMap");
|
||||
outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(count), tensorClazz, NULL);
|
||||
for (size_t i = 0; i < count; i++) {
|
||||
// Extract element
|
||||
OrtValue* element;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element));
|
||||
if (code == ORT_OK) {
|
||||
jobject str = createJavaMapFromONNX(jniEnv, api, allocator, element);
|
||||
if (str == NULL) {
|
||||
api->ReleaseValue(element);
|
||||
// bail out as exception has been thrown
|
||||
return NULL;
|
||||
}
|
||||
(*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, (jsize)i, str);
|
||||
} else {
|
||||
// bail out as exception has been thrown
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
// Keys is valid, so release
|
||||
api->ReleaseValue(keys);
|
||||
}
|
||||
return output;
|
||||
return outputArray;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OnnxSequence
|
||||
* Method: getLongKeys
|
||||
* Signature: (JJI)[J
|
||||
* Method: getTensors
|
||||
* Signature: (JJJ)[Lai/onnxruntime/OnnxTensor;
|
||||
*/
|
||||
JNIEXPORT jlongArray JNICALL Java_ai_onnxruntime_OnnxSequence_getLongKeys(JNIEnv* jniEnv, jobject jobj, jlong apiHandle,
|
||||
jlong handle, jlong allocatorHandle,
|
||||
jint index) {
|
||||
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*)apiHandle;
|
||||
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
|
||||
jlongArray output = NULL;
|
||||
// Extract element
|
||||
OrtValue* element;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, index, allocator, &element));
|
||||
if (code == ORT_OK) {
|
||||
// Extract keys from element
|
||||
OrtValue* keys;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetValue(element, 0, allocator, &keys));
|
||||
|
||||
if (code == ORT_OK) {
|
||||
// Convert to Java long array
|
||||
output = createLongArrayFromTensor(jniEnv, api, keys);
|
||||
// Release if valid
|
||||
api->ReleaseValue(element);
|
||||
}
|
||||
|
||||
// Keys is valid, so release
|
||||
api->ReleaseValue(keys);
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OnnxSequence
|
||||
* Method: getStringValues
|
||||
* Signature: (JJI)[Ljava/lang/String;
|
||||
*/
|
||||
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStringValues(JNIEnv* jniEnv, jobject jobj,
|
||||
jlong apiHandle, jlong handle,
|
||||
jlong allocatorHandle, jint index) {
|
||||
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*)apiHandle;
|
||||
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
|
||||
jobjectArray output = NULL;
|
||||
// Extract element
|
||||
OrtValue* element;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, index, allocator, &element));
|
||||
if (code == ORT_OK) {
|
||||
// Extract values from element
|
||||
OrtValue* values;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetValue(element, 1, allocator, &values));
|
||||
|
||||
if (code == ORT_OK) {
|
||||
// Convert to Java String array
|
||||
output = createStringArrayFromTensor(jniEnv, api, values);
|
||||
// Release if valid
|
||||
api->ReleaseValue(element);
|
||||
}
|
||||
|
||||
// values is valid, so release
|
||||
api->ReleaseValue(values);
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OnnxSequence
|
||||
* Method: getLongValues
|
||||
* Signature: (JJI)[J
|
||||
*/
|
||||
JNIEXPORT jlongArray JNICALL Java_ai_onnxruntime_OnnxSequence_getLongValues(JNIEnv* jniEnv, jobject jobj,
|
||||
jlong apiHandle, jlong handle,
|
||||
jlong allocatorHandle, jint index) {
|
||||
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*)apiHandle;
|
||||
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
|
||||
jlongArray output = NULL;
|
||||
// Extract element
|
||||
OrtValue* element;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, index, allocator, &element));
|
||||
if (code == ORT_OK) {
|
||||
// Extract values from element
|
||||
OrtValue* values;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetValue(element, 1, allocator, &values));
|
||||
|
||||
if (code == ORT_OK) {
|
||||
// Convert to Java long array
|
||||
output = createLongArrayFromTensor(jniEnv, api, values);
|
||||
// Release if valid
|
||||
api->ReleaseValue(element);
|
||||
}
|
||||
|
||||
// values is valid, so release
|
||||
api->ReleaseValue(values);
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OnnxSequence
|
||||
* Method: getFloatValues
|
||||
* Signature: (JJI)[F
|
||||
*/
|
||||
JNIEXPORT jfloatArray JNICALL Java_ai_onnxruntime_OnnxSequence_getFloatValues(JNIEnv* jniEnv, jobject jobj,
|
||||
jlong apiHandle, jlong handle,
|
||||
jlong allocatorHandle, jint index) {
|
||||
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*)apiHandle;
|
||||
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
|
||||
jfloatArray output = NULL;
|
||||
// Extract element
|
||||
OrtValue* element;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, index, allocator, &element));
|
||||
if (code == ORT_OK) {
|
||||
// Extract values from element
|
||||
OrtValue* values;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetValue(element, 1, allocator, &values));
|
||||
|
||||
if (code == ORT_OK) {
|
||||
// Convert to Java float array
|
||||
output = createFloatArrayFromTensor(jniEnv, api, values);
|
||||
// Release if valid
|
||||
api->ReleaseValue(element);
|
||||
}
|
||||
|
||||
// values is valid, so release
|
||||
api->ReleaseValue(values);
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OnnxSequence
|
||||
* Method: getDoubleValues
|
||||
* Signature: (JJI)[D
|
||||
*/
|
||||
JNIEXPORT jdoubleArray JNICALL Java_ai_onnxruntime_OnnxSequence_getDoubleValues(JNIEnv* jniEnv, jobject jobj,
|
||||
jlong apiHandle, jlong handle,
|
||||
jlong allocatorHandle, jint index) {
|
||||
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*)apiHandle;
|
||||
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
|
||||
jdoubleArray output = NULL;
|
||||
// Extract element
|
||||
OrtValue* element;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, index, allocator, &element));
|
||||
if (code == ORT_OK) {
|
||||
// Extract values from element
|
||||
OrtValue* values;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetValue(element, 1, allocator, &values));
|
||||
|
||||
if (code == ORT_OK) {
|
||||
// Convert to Java double array
|
||||
output = createDoubleArrayFromTensor(jniEnv, api, values);
|
||||
// Release if valid
|
||||
api->ReleaseValue(element);
|
||||
}
|
||||
|
||||
// values is valid, so release
|
||||
api->ReleaseValue(values);
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OnnxSequence
|
||||
* Method: getStrings
|
||||
* Signature: (JJ)[Ljava/lang/String;
|
||||
*/
|
||||
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStrings(JNIEnv* jniEnv, jobject jobj,
|
||||
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getTensors(JNIEnv* jniEnv, jobject jobj,
|
||||
jlong apiHandle, jlong handle,
|
||||
jlong allocatorHandle) {
|
||||
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
|
|
@ -223,22 +68,20 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStrings(JNIEn
|
|||
size_t count;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValueCount(sequence, &count));
|
||||
if (code == ORT_OK) {
|
||||
jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String");
|
||||
outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(count), stringClazz, NULL);
|
||||
jclass tensorClazz = (*jniEnv)->FindClass(jniEnv, "ai/onnxruntime/OnnxTensor");
|
||||
outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(count), tensorClazz, NULL);
|
||||
for (size_t i = 0; i < count; i++) {
|
||||
// Extract element
|
||||
OrtValue* element;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element));
|
||||
if (code == ORT_OK) {
|
||||
jobject str = createStringFromStringTensor(jniEnv, api, element);
|
||||
jobject str = createJavaTensorFromONNX(jniEnv, api, allocator, element);
|
||||
if (str == NULL) {
|
||||
api->ReleaseValue(element);
|
||||
// bail out as exception has been thrown
|
||||
return NULL;
|
||||
}
|
||||
(*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, (jsize)i, str);
|
||||
|
||||
api->ReleaseValue(element);
|
||||
} else {
|
||||
// bail out as exception has been thrown
|
||||
return NULL;
|
||||
|
|
@ -248,169 +91,6 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStrings(JNIEn
|
|||
return outputArray;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OnnxSequence
|
||||
* Method: getLongs
|
||||
* Signature: (JJ)[J
|
||||
*/
|
||||
JNIEXPORT jlongArray JNICALL Java_ai_onnxruntime_OnnxSequence_getLongs(JNIEnv* jniEnv, jobject jobj, jlong apiHandle,
|
||||
jlong handle, jlong allocatorHandle) {
|
||||
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*)apiHandle;
|
||||
OrtValue* sequence = (OrtValue*)handle;
|
||||
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
|
||||
jlongArray outputArray = NULL;
|
||||
|
||||
// Get the element count of this sequence
|
||||
size_t count;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValueCount(sequence, &count));
|
||||
if (code == ORT_OK) {
|
||||
int64_t* values;
|
||||
code = checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator, sizeof(int64_t) * count, (void**)&values));
|
||||
if (code == ORT_OK) {
|
||||
for (size_t i = 0; i < count; i++) {
|
||||
// Extract element
|
||||
OrtValue* element;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element));
|
||||
if (code == ORT_OK) {
|
||||
// Extract the values
|
||||
int64_t* arr;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(element, (void**)&arr));
|
||||
if (code == ORT_OK) {
|
||||
values[i] = arr[0];
|
||||
} else {
|
||||
// bail out as exception has been thrown
|
||||
api->ReleaseValue(element);
|
||||
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values));
|
||||
return NULL;
|
||||
}
|
||||
|
||||
api->ReleaseValue(element);
|
||||
} else {
|
||||
// bail out as exception has been thrown
|
||||
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values));
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
outputArray = (*jniEnv)->NewLongArray(jniEnv, safecast_size_t_to_jsize(count));
|
||||
(*jniEnv)->SetLongArrayRegion(jniEnv, outputArray, 0, safecast_size_t_to_jsize(count), (jlong*)values);
|
||||
|
||||
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values));
|
||||
}
|
||||
}
|
||||
return outputArray;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OnnxSequence
|
||||
* Method: getFloats
|
||||
* Signature: (JJ)[F
|
||||
*/
|
||||
JNIEXPORT jfloatArray JNICALL Java_ai_onnxruntime_OnnxSequence_getFloats(JNIEnv* jniEnv, jobject jobj, jlong apiHandle,
|
||||
jlong handle, jlong allocatorHandle) {
|
||||
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*)apiHandle;
|
||||
OrtValue* sequence = (OrtValue*)handle;
|
||||
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
|
||||
jfloatArray outputArray = NULL;
|
||||
|
||||
// Get the element count of this sequence
|
||||
size_t count;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValueCount(sequence, &count));
|
||||
if (code == ORT_OK) {
|
||||
float* values;
|
||||
code = checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator, sizeof(float) * count, (void**)&values));
|
||||
if (code == ORT_OK) {
|
||||
for (size_t i = 0; i < count; i++) {
|
||||
// Extract element
|
||||
OrtValue* element;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element));
|
||||
if (code == ORT_OK) {
|
||||
// Extract the values
|
||||
float* arr;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(element, (void**)&arr));
|
||||
if (code == ORT_OK) {
|
||||
values[i] = arr[0];
|
||||
} else {
|
||||
// bail out as exception has been thrown
|
||||
api->ReleaseValue(element);
|
||||
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values));
|
||||
return NULL;
|
||||
}
|
||||
|
||||
api->ReleaseValue(element);
|
||||
} else {
|
||||
// bail out as exception has been thrown
|
||||
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values));
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
outputArray = (*jniEnv)->NewFloatArray(jniEnv, safecast_size_t_to_jsize(count));
|
||||
(*jniEnv)->SetFloatArrayRegion(jniEnv, outputArray, 0, safecast_size_t_to_jsize(count), values);
|
||||
|
||||
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values));
|
||||
}
|
||||
}
|
||||
return outputArray;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OnnxSequence
|
||||
* Method: getDoubles
|
||||
* Signature: (JJ)[D
|
||||
*/
|
||||
JNIEXPORT jdoubleArray JNICALL Java_ai_onnxruntime_OnnxSequence_getDoubles(JNIEnv* jniEnv, jobject jobj,
|
||||
jlong apiHandle, jlong handle,
|
||||
jlong allocatorHandle) {
|
||||
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*)apiHandle;
|
||||
OrtValue* sequence = (OrtValue*)handle;
|
||||
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
|
||||
jdoubleArray outputArray = NULL;
|
||||
|
||||
// Get the element count of this sequence
|
||||
size_t count;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValueCount(sequence, &count));
|
||||
if (code == ORT_OK) {
|
||||
double* values;
|
||||
code = checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator, sizeof(double) * count, (void**)&values));
|
||||
if (code == ORT_OK) {
|
||||
for (size_t i = 0; i < count; i++) {
|
||||
// Extract element
|
||||
OrtValue* element;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element));
|
||||
if (code == ORT_OK) {
|
||||
// Extract the values
|
||||
double* arr;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(element, (void**)&arr));
|
||||
if (code == ORT_OK) {
|
||||
values[i] = arr[0];
|
||||
} else {
|
||||
// bail out as exception has been thrown
|
||||
api->ReleaseValue(element);
|
||||
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values));
|
||||
return NULL;
|
||||
}
|
||||
|
||||
api->ReleaseValue(element);
|
||||
} else {
|
||||
// bail out as exception has been thrown
|
||||
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values));
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
outputArray = (*jniEnv)->NewDoubleArray(jniEnv, safecast_size_t_to_jsize(count));
|
||||
(*jniEnv)->SetDoubleArrayRegion(jniEnv, outputArray, 0, safecast_size_t_to_jsize(count), values);
|
||||
|
||||
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values));
|
||||
}
|
||||
}
|
||||
return outputArray;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OnnxSequence
|
||||
* Method: close
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ import java.io.IOException;
|
|||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.nio.LongBuffer;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
|
|
@ -32,6 +33,7 @@ import java.util.HashSet;
|
|||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
|
|
@ -1301,7 +1303,8 @@ public class InferenceTest {
|
|||
|
||||
// try-cast first element in sequence to map/dictionary type
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<Long, Float> map = (Map<Long, Float>) ((List<Object>) secondOutput.getValue()).get(0);
|
||||
Map<Long, Float> map =
|
||||
(Map<Long, Float>) ((List<OnnxMap>) secondOutput.getValue()).get(0).getValue();
|
||||
assertEquals(0.25938290, map.get(0L), 1e-6);
|
||||
assertEquals(0.40904793, map.get(1L), 1e-6);
|
||||
assertEquals(0.33156919, map.get(2L), 1e-6);
|
||||
|
|
@ -1368,7 +1371,7 @@ public class InferenceTest {
|
|||
// try-cast first element in sequence to map/dictionary type
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Float> map =
|
||||
(Map<String, Float>) ((List<Object>) secondOutput.getValue()).get(0);
|
||||
(Map<String, Float>) ((List<OnnxMap>) secondOutput.getValue()).get(0).getValue();
|
||||
assertEquals(0.25938290, map.get("0"), 1e-6);
|
||||
assertEquals(0.40904793, map.get("1"), 1e-6);
|
||||
assertEquals(0.33156919, map.get("2"), 1e-6);
|
||||
|
|
@ -1377,6 +1380,73 @@ public class InferenceTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testModelSequenceOfTensors() throws OrtException {
|
||||
String modelPath = TestHelpers.getResourcePath("/test_sequence_tensors.onnx").toString();
|
||||
|
||||
try (SessionOptions options = new SessionOptions();
|
||||
OrtSession session = env.createSession(modelPath, options)) {
|
||||
Map<String, NodeInfo> outputInfos = session.getOutputInfo();
|
||||
NodeInfo outputInfo = outputInfos.get("output_sequence");
|
||||
assertTrue(outputInfo.getInfo() instanceof SequenceInfo);
|
||||
|
||||
Map<String, OnnxTensor> container = new HashMap<>();
|
||||
OnnxTensor firstInputTensor =
|
||||
OnnxTensor.createTensor(
|
||||
env, OrtUtil.reshape(new long[] {1, 2, 3, 4, 5, 6}, new long[] {2, 3}));
|
||||
OnnxTensor secondInputTensor =
|
||||
OnnxTensor.createTensor(
|
||||
env, OrtUtil.reshape(new long[] {7, 8, 9, 10, 11, 12}, new long[] {2, 3}));
|
||||
|
||||
container.put("tensor1", firstInputTensor);
|
||||
container.put("tensor2", secondInputTensor);
|
||||
|
||||
try (OrtSession.Result outputs = session.run(container)) {
|
||||
// output is a sequence<tensors>
|
||||
Optional<OnnxValue> output = outputs.get("output_sequence");
|
||||
assertTrue(output.isPresent());
|
||||
assertTrue(output.get() instanceof OnnxSequence);
|
||||
|
||||
// cast to a sequence
|
||||
OnnxSequence seq = (OnnxSequence) output.get();
|
||||
|
||||
// make sure that the sequence holds only 2 elements (tensors)
|
||||
assertEquals(2, seq.getInfo().length);
|
||||
|
||||
// try-cast the elements in sequence to tensor type
|
||||
List<? extends OnnxValue> elements = seq.getValue();
|
||||
assertEquals(2, elements.size());
|
||||
assertTrue(elements.get(0) instanceof OnnxTensor);
|
||||
assertTrue(elements.get(1) instanceof OnnxTensor);
|
||||
|
||||
OnnxTensor firstTensor = (OnnxTensor) elements.get(0);
|
||||
OnnxTensor secondTensor = (OnnxTensor) elements.get(1);
|
||||
|
||||
LongBuffer outputBuf = firstTensor.getLongBuffer();
|
||||
|
||||
// make sure the tensors in the output sequence hold the correct values
|
||||
assertEquals(1, outputBuf.get(0));
|
||||
assertEquals(2, outputBuf.get(1));
|
||||
assertEquals(3, outputBuf.get(2));
|
||||
assertEquals(4, outputBuf.get(3));
|
||||
assertEquals(5, outputBuf.get(4));
|
||||
assertEquals(6, outputBuf.get(5));
|
||||
|
||||
outputBuf = secondTensor.getLongBuffer();
|
||||
|
||||
assertEquals(7, outputBuf.get(0));
|
||||
assertEquals(8, outputBuf.get(1));
|
||||
assertEquals(9, outputBuf.get(2));
|
||||
assertEquals(10, outputBuf.get(3));
|
||||
assertEquals(11, outputBuf.get(4));
|
||||
assertEquals(12, outputBuf.get(5));
|
||||
|
||||
firstTensor.close();
|
||||
secondTensor.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testModelSerialization() throws OrtException, IOException {
|
||||
String cwd = System.getProperty("user.dir");
|
||||
|
|
|
|||
Loading…
Reference in a new issue