[Java] Fix OnnxSequence semantics (#13012)

Previously OnnxSequence would flatten out a list of tensors into a
single output array assuming they were all scalar values. This doesn't
accurately represent the semantics of an ONNX sequence, but was what the
semantics appeared to be years ago when I first wrote that class. This
PR changes it so that the `getValue` method on `OnnxSequence` unwraps
the sequence and returns `List<? extends OnnxValue>` allowing the user
to process the individual ONNX values separately. It's done this way
rather than returning a multidimensional array for a tensor and a Java
map for a map as multidimensional arrays are very inefficient in Java
and best practice when operating with a OnnxTensor in Java is to use a
`java.nio.ByteBuffer`. So allowing users to access each `OnnxTensor`s
individually allows them to control how the data is materialised on the
Java heap.
This commit is contained in:
Adam Pocock 2022-09-28 18:53:30 -04:00 committed by GitHub
parent 55ae71c160
commit 388d3cf847
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 144 additions and 491 deletions

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
@ -137,7 +137,7 @@ public class OnnxMap implements OnnxValue {
* @throws OrtException If the onnx runtime failed to read the entries.
*/
@Override
public Map<Object, Object> getValue() throws OrtException {
public Map<? extends Object, ? extends Object> getValue() throws OrtException {
Object[] keys = getMapKeys();
Object[] values = getMapValues();
HashMap<Object, Object> map = new HashMap<>(OrtUtil.capacityFromSize(keys.length));

View file

@ -1,21 +1,27 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
/**
* A sequence of {@link OnnxValue}s all of the same type.
*
* <p>Supports the types mentioned in "onnxruntime_c_api.h", currently String, Long, Float, Double,
* Map&gt;String,Float&lt;, Map&gt;Long,Float&lt;.
* <p>Supports the types mentioned in "onnxruntime_c_api.h", currently
*
* <ul>
* <li>OnnxTensor&lt;String&gt;
* <li>OnnxTensor&lt;Long&gt;
* <li>OnnxTensor&lt;Float&gt;
* <li>OnnxTensor&lt;Double&gt;
* <li>OnnxMap&lt;String,Float&gt;
* <li>OnnxMap&lt;Long,Float&gt;
* </ul>
*/
public class OnnxSequence implements OnnxValue {
@ -54,51 +60,33 @@ public class OnnxSequence implements OnnxValue {
}
/**
* Extracts a Java object from the native ONNX type.
* Extracts a Java list of the {@link OnnxValue}s which can then be further unwrapped.
*
* <p>Returns either a {@link List} of boxed primitives, {@link String}s, or {@link
* java.util.Map}s.
* <p>Returns either a {@link List} of either {@link OnnxTensor} or {@link OnnxMap}.
*
* @return A Java object containing the value.
* <p>Note unlike the other {@link OnnxValue#getValue()} methods, this does not copy the values
* themselves into the Java heap, it merely exposes them as {@link OnnxValue} instances, allowing
* users to use the faster copy methods available for {@link OnnxTensor}. This also means that
* those values need to be closed separately from this instance, and are not closed by {@link
* #close} on this object.
*
* @return A Java list containing the values.
* @throws OrtException If the runtime failed to read an element.
*/
@Override
public List<Object> getValue() throws OrtException {
public List<? extends OnnxValue> getValue() throws OrtException {
if (info.sequenceOfMaps) {
List<Object> outputSequence = new ArrayList<>(info.length);
for (int i = 0; i < info.length; i++) {
Object[] keys = getMapKeys(i);
Object[] values = getMapValues(i);
HashMap<Object, Object> map = new HashMap<>(OrtUtil.capacityFromSize(keys.length));
for (int j = 0; j < keys.length; j++) {
map.put(keys[j], values[j]);
}
outputSequence.add(map);
}
return outputSequence;
OnnxMap[] maps = getMaps(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle);
return Collections.unmodifiableList(Arrays.asList(maps));
} else {
switch (info.sequenceType) {
case FLOAT:
float[] floats = getFloats(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle);
ArrayList<Object> boxed = new ArrayList<>(floats.length);
for (float aFloat : floats) {
// box float to Float
boxed.add(aFloat);
}
return boxed;
case DOUBLE:
return Arrays.stream(getDoubles(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle))
.boxed()
.collect(Collectors.toList());
case INT64:
return Arrays.stream(getLongs(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle))
.boxed()
.collect(Collectors.toList());
case STRING:
String[] strings = getStrings(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle);
ArrayList<Object> list = new ArrayList<>(strings.length);
list.addAll(Arrays.asList(strings));
return list;
case INT64:
case FLOAT:
case DOUBLE:
OnnxTensor[] tensors =
getTensors(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle);
return Collections.unmodifiableList(Arrays.asList(tensors));
case BOOL:
case UINT8:
case INT8:
@ -127,95 +115,10 @@ public class OnnxSequence implements OnnxValue {
close(OnnxRuntime.ortApiHandle, nativeHandle);
}
/**
* Extract the keys for the map at the specified index.
*
* @param index The index to extract.
* @return The map keys as an array.
* @throws OrtException If the native code failed to read the keys.
*/
private Object[] getMapKeys(int index) throws OrtException {
if (info.mapInfo.keyType == OnnxJavaType.STRING) {
return getStringKeys(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, index);
} else {
return Arrays.stream(
getLongKeys(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, index))
.boxed()
.toArray();
}
}
/**
* Extract the values for the map at the specified index.
*
* @param index The index to extract.
* @return The map values as an array.
* @throws OrtException If the native code failed to read the values.
*/
private Object[] getMapValues(int index) throws OrtException {
switch (info.mapInfo.valueType) {
case STRING:
{
return getStringValues(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, index);
}
case INT64:
{
return Arrays.stream(
getLongValues(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, index))
.boxed()
.toArray();
}
case FLOAT:
{
float[] floats =
getFloatValues(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, index);
Float[] boxed = new Float[floats.length];
for (int i = 0; i < floats.length; i++) {
// cast float to Float
boxed[i] = floats[i];
}
return boxed;
}
case DOUBLE:
{
return Arrays.stream(
getDoubleValues(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, index))
.boxed()
.toArray();
}
default:
throw new RuntimeException("Invalid or unknown valueType: " + info.mapInfo.valueType);
}
}
private native String[] getStringKeys(
long apiHandle, long nativeHandle, long allocatorHandle, int index) throws OrtException;
private native long[] getLongKeys(
long apiHandle, long nativeHandle, long allocatorHandle, int index) throws OrtException;
private native String[] getStringValues(
long apiHandle, long nativeHandle, long allocatorHandle, int index) throws OrtException;
private native long[] getLongValues(
long apiHandle, long nativeHandle, long allocatorHandle, int index) throws OrtException;
private native float[] getFloatValues(
long apiHandle, long nativeHandle, long allocatorHandle, int index) throws OrtException;
private native double[] getDoubleValues(
long apiHandle, long nativeHandle, long allocatorHandle, int index) throws OrtException;
private native String[] getStrings(long apiHandle, long nativeHandle, long allocatorHandle)
private native OnnxMap[] getMaps(long apiHandle, long nativeHandle, long allocatorHandle)
throws OrtException;
private native long[] getLongs(long apiHandle, long nativeHandle, long allocatorHandle)
throws OrtException;
private native float[] getFloats(long apiHandle, long nativeHandle, long allocatorHandle)
throws OrtException;
private native double[] getDoubles(long apiHandle, long nativeHandle, long allocatorHandle)
private native OnnxTensor[] getTensors(long apiHandle, long nativeHandle, long allocatorHandle)
throws OrtException;
private native void close(long apiHandle, long nativeHandle);

View file

@ -6,210 +6,55 @@
#include "onnxruntime/core/session/onnxruntime_c_api.h"
#include "OrtJniUtil.h"
#include "ai_onnxruntime_OnnxSequence.h"
/*
* Class: ai_onnxruntime_OnnxSequence
* Method: getStringKeys
* Signature: (JJI)[Ljava/lang/String;
* Method: getMaps
* Signature: (JJJ)[Lai/onnxruntime/OnnxMap;
*/
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStringKeys(JNIEnv* jniEnv, jobject jobj,
jlong apiHandle, jlong handle,
jlong allocatorHandle, jint index) {
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getMaps(JNIEnv* jniEnv, jobject jobj,
jlong apiHandle, jlong handle,
jlong allocatorHandle) {
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtValue* sequence = (OrtValue*)handle;
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
jobjectArray output = NULL;
// Extract element
OrtValue* element;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, index, allocator, &element));
jobjectArray outputArray = NULL;
// Get the element count of this sequence
size_t count;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValueCount(sequence, &count));
if (code == ORT_OK) {
// Extract keys from element
OrtValue* keys;
code = checkOrtStatus(jniEnv, api, api->GetValue(element, 0, allocator, &keys));
if (code == ORT_OK) {
// Convert to Java String array
output = createStringArrayFromTensor(jniEnv, api, keys);
// Release if valid
api->ReleaseValue(element);
jclass tensorClazz = (*jniEnv)->FindClass(jniEnv, "ai/onnxruntime/OnnxMap");
outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(count), tensorClazz, NULL);
for (size_t i = 0; i < count; i++) {
// Extract element
OrtValue* element;
code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element));
if (code == ORT_OK) {
jobject str = createJavaMapFromONNX(jniEnv, api, allocator, element);
if (str == NULL) {
api->ReleaseValue(element);
// bail out as exception has been thrown
return NULL;
}
(*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, (jsize)i, str);
} else {
// bail out as exception has been thrown
return NULL;
}
}
// Keys is valid, so release
api->ReleaseValue(keys);
}
return output;
return outputArray;
}
/*
* Class: ai_onnxruntime_OnnxSequence
* Method: getLongKeys
* Signature: (JJI)[J
* Method: getTensors
* Signature: (JJJ)[Lai/onnxruntime/OnnxTensor;
*/
JNIEXPORT jlongArray JNICALL Java_ai_onnxruntime_OnnxSequence_getLongKeys(JNIEnv* jniEnv, jobject jobj, jlong apiHandle,
jlong handle, jlong allocatorHandle,
jint index) {
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
jlongArray output = NULL;
// Extract element
OrtValue* element;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, index, allocator, &element));
if (code == ORT_OK) {
// Extract keys from element
OrtValue* keys;
code = checkOrtStatus(jniEnv, api, api->GetValue(element, 0, allocator, &keys));
if (code == ORT_OK) {
// Convert to Java long array
output = createLongArrayFromTensor(jniEnv, api, keys);
// Release if valid
api->ReleaseValue(element);
}
// Keys is valid, so release
api->ReleaseValue(keys);
}
return output;
}
/*
* Class: ai_onnxruntime_OnnxSequence
* Method: getStringValues
* Signature: (JJI)[Ljava/lang/String;
*/
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStringValues(JNIEnv* jniEnv, jobject jobj,
jlong apiHandle, jlong handle,
jlong allocatorHandle, jint index) {
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
jobjectArray output = NULL;
// Extract element
OrtValue* element;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, index, allocator, &element));
if (code == ORT_OK) {
// Extract values from element
OrtValue* values;
code = checkOrtStatus(jniEnv, api, api->GetValue(element, 1, allocator, &values));
if (code == ORT_OK) {
// Convert to Java String array
output = createStringArrayFromTensor(jniEnv, api, values);
// Release if valid
api->ReleaseValue(element);
}
// values is valid, so release
api->ReleaseValue(values);
}
return output;
}
/*
* Class: ai_onnxruntime_OnnxSequence
* Method: getLongValues
* Signature: (JJI)[J
*/
JNIEXPORT jlongArray JNICALL Java_ai_onnxruntime_OnnxSequence_getLongValues(JNIEnv* jniEnv, jobject jobj,
jlong apiHandle, jlong handle,
jlong allocatorHandle, jint index) {
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
jlongArray output = NULL;
// Extract element
OrtValue* element;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, index, allocator, &element));
if (code == ORT_OK) {
// Extract values from element
OrtValue* values;
code = checkOrtStatus(jniEnv, api, api->GetValue(element, 1, allocator, &values));
if (code == ORT_OK) {
// Convert to Java long array
output = createLongArrayFromTensor(jniEnv, api, values);
// Release if valid
api->ReleaseValue(element);
}
// values is valid, so release
api->ReleaseValue(values);
}
return output;
}
/*
* Class: ai_onnxruntime_OnnxSequence
* Method: getFloatValues
* Signature: (JJI)[F
*/
JNIEXPORT jfloatArray JNICALL Java_ai_onnxruntime_OnnxSequence_getFloatValues(JNIEnv* jniEnv, jobject jobj,
jlong apiHandle, jlong handle,
jlong allocatorHandle, jint index) {
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
jfloatArray output = NULL;
// Extract element
OrtValue* element;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, index, allocator, &element));
if (code == ORT_OK) {
// Extract values from element
OrtValue* values;
code = checkOrtStatus(jniEnv, api, api->GetValue(element, 1, allocator, &values));
if (code == ORT_OK) {
// Convert to Java float array
output = createFloatArrayFromTensor(jniEnv, api, values);
// Release if valid
api->ReleaseValue(element);
}
// values is valid, so release
api->ReleaseValue(values);
}
return output;
}
/*
* Class: ai_onnxruntime_OnnxSequence
* Method: getDoubleValues
* Signature: (JJI)[D
*/
JNIEXPORT jdoubleArray JNICALL Java_ai_onnxruntime_OnnxSequence_getDoubleValues(JNIEnv* jniEnv, jobject jobj,
jlong apiHandle, jlong handle,
jlong allocatorHandle, jint index) {
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
jdoubleArray output = NULL;
// Extract element
OrtValue* element;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, index, allocator, &element));
if (code == ORT_OK) {
// Extract values from element
OrtValue* values;
code = checkOrtStatus(jniEnv, api, api->GetValue(element, 1, allocator, &values));
if (code == ORT_OK) {
// Convert to Java double array
output = createDoubleArrayFromTensor(jniEnv, api, values);
// Release if valid
api->ReleaseValue(element);
}
// values is valid, so release
api->ReleaseValue(values);
}
return output;
}
/*
* Class: ai_onnxruntime_OnnxSequence
* Method: getStrings
* Signature: (JJ)[Ljava/lang/String;
*/
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStrings(JNIEnv* jniEnv, jobject jobj,
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getTensors(JNIEnv* jniEnv, jobject jobj,
jlong apiHandle, jlong handle,
jlong allocatorHandle) {
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
@ -223,22 +68,20 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStrings(JNIEn
size_t count;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValueCount(sequence, &count));
if (code == ORT_OK) {
jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String");
outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(count), stringClazz, NULL);
jclass tensorClazz = (*jniEnv)->FindClass(jniEnv, "ai/onnxruntime/OnnxTensor");
outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(count), tensorClazz, NULL);
for (size_t i = 0; i < count; i++) {
// Extract element
OrtValue* element;
code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element));
if (code == ORT_OK) {
jobject str = createStringFromStringTensor(jniEnv, api, element);
jobject str = createJavaTensorFromONNX(jniEnv, api, allocator, element);
if (str == NULL) {
api->ReleaseValue(element);
// bail out as exception has been thrown
return NULL;
}
(*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, (jsize)i, str);
api->ReleaseValue(element);
} else {
// bail out as exception has been thrown
return NULL;
@ -248,169 +91,6 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStrings(JNIEn
return outputArray;
}
/*
* Class: ai_onnxruntime_OnnxSequence
* Method: getLongs
* Signature: (JJ)[J
*/
JNIEXPORT jlongArray JNICALL Java_ai_onnxruntime_OnnxSequence_getLongs(JNIEnv* jniEnv, jobject jobj, jlong apiHandle,
jlong handle, jlong allocatorHandle) {
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtValue* sequence = (OrtValue*)handle;
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
jlongArray outputArray = NULL;
// Get the element count of this sequence
size_t count;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValueCount(sequence, &count));
if (code == ORT_OK) {
int64_t* values;
code = checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator, sizeof(int64_t) * count, (void**)&values));
if (code == ORT_OK) {
for (size_t i = 0; i < count; i++) {
// Extract element
OrtValue* element;
code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element));
if (code == ORT_OK) {
// Extract the values
int64_t* arr;
code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(element, (void**)&arr));
if (code == ORT_OK) {
values[i] = arr[0];
} else {
// bail out as exception has been thrown
api->ReleaseValue(element);
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values));
return NULL;
}
api->ReleaseValue(element);
} else {
// bail out as exception has been thrown
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values));
return NULL;
}
}
outputArray = (*jniEnv)->NewLongArray(jniEnv, safecast_size_t_to_jsize(count));
(*jniEnv)->SetLongArrayRegion(jniEnv, outputArray, 0, safecast_size_t_to_jsize(count), (jlong*)values);
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values));
}
}
return outputArray;
}
/*
* Class: ai_onnxruntime_OnnxSequence
* Method: getFloats
* Signature: (JJ)[F
*/
JNIEXPORT jfloatArray JNICALL Java_ai_onnxruntime_OnnxSequence_getFloats(JNIEnv* jniEnv, jobject jobj, jlong apiHandle,
jlong handle, jlong allocatorHandle) {
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtValue* sequence = (OrtValue*)handle;
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
jfloatArray outputArray = NULL;
// Get the element count of this sequence
size_t count;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValueCount(sequence, &count));
if (code == ORT_OK) {
float* values;
code = checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator, sizeof(float) * count, (void**)&values));
if (code == ORT_OK) {
for (size_t i = 0; i < count; i++) {
// Extract element
OrtValue* element;
code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element));
if (code == ORT_OK) {
// Extract the values
float* arr;
code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(element, (void**)&arr));
if (code == ORT_OK) {
values[i] = arr[0];
} else {
// bail out as exception has been thrown
api->ReleaseValue(element);
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values));
return NULL;
}
api->ReleaseValue(element);
} else {
// bail out as exception has been thrown
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values));
return NULL;
}
}
outputArray = (*jniEnv)->NewFloatArray(jniEnv, safecast_size_t_to_jsize(count));
(*jniEnv)->SetFloatArrayRegion(jniEnv, outputArray, 0, safecast_size_t_to_jsize(count), values);
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values));
}
}
return outputArray;
}
/*
* Class: ai_onnxruntime_OnnxSequence
* Method: getDoubles
* Signature: (JJ)[D
*/
JNIEXPORT jdoubleArray JNICALL Java_ai_onnxruntime_OnnxSequence_getDoubles(JNIEnv* jniEnv, jobject jobj,
jlong apiHandle, jlong handle,
jlong allocatorHandle) {
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtValue* sequence = (OrtValue*)handle;
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
jdoubleArray outputArray = NULL;
// Get the element count of this sequence
size_t count;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValueCount(sequence, &count));
if (code == ORT_OK) {
double* values;
code = checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator, sizeof(double) * count, (void**)&values));
if (code == ORT_OK) {
for (size_t i = 0; i < count; i++) {
// Extract element
OrtValue* element;
code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element));
if (code == ORT_OK) {
// Extract the values
double* arr;
code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(element, (void**)&arr));
if (code == ORT_OK) {
values[i] = arr[0];
} else {
// bail out as exception has been thrown
api->ReleaseValue(element);
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values));
return NULL;
}
api->ReleaseValue(element);
} else {
// bail out as exception has been thrown
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values));
return NULL;
}
}
outputArray = (*jniEnv)->NewDoubleArray(jniEnv, safecast_size_t_to_jsize(count));
(*jniEnv)->SetDoubleArrayRegion(jniEnv, outputArray, 0, safecast_size_t_to_jsize(count), values);
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values));
}
}
return outputArray;
}
/*
* Class: ai_onnxruntime_OnnxSequence
* Method: close

View file

@ -20,6 +20,7 @@ import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.LongBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
@ -32,6 +33,7 @@ import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@ -1301,7 +1303,8 @@ public class InferenceTest {
// try-cast first element in sequence to map/dictionary type
@SuppressWarnings("unchecked")
Map<Long, Float> map = (Map<Long, Float>) ((List<Object>) secondOutput.getValue()).get(0);
Map<Long, Float> map =
(Map<Long, Float>) ((List<OnnxMap>) secondOutput.getValue()).get(0).getValue();
assertEquals(0.25938290, map.get(0L), 1e-6);
assertEquals(0.40904793, map.get(1L), 1e-6);
assertEquals(0.33156919, map.get(2L), 1e-6);
@ -1368,7 +1371,7 @@ public class InferenceTest {
// try-cast first element in sequence to map/dictionary type
@SuppressWarnings("unchecked")
Map<String, Float> map =
(Map<String, Float>) ((List<Object>) secondOutput.getValue()).get(0);
(Map<String, Float>) ((List<OnnxMap>) secondOutput.getValue()).get(0).getValue();
assertEquals(0.25938290, map.get("0"), 1e-6);
assertEquals(0.40904793, map.get("1"), 1e-6);
assertEquals(0.33156919, map.get("2"), 1e-6);
@ -1377,6 +1380,73 @@ public class InferenceTest {
}
}
@Test
public void testModelSequenceOfTensors() throws OrtException {
String modelPath = TestHelpers.getResourcePath("/test_sequence_tensors.onnx").toString();
try (SessionOptions options = new SessionOptions();
OrtSession session = env.createSession(modelPath, options)) {
Map<String, NodeInfo> outputInfos = session.getOutputInfo();
NodeInfo outputInfo = outputInfos.get("output_sequence");
assertTrue(outputInfo.getInfo() instanceof SequenceInfo);
Map<String, OnnxTensor> container = new HashMap<>();
OnnxTensor firstInputTensor =
OnnxTensor.createTensor(
env, OrtUtil.reshape(new long[] {1, 2, 3, 4, 5, 6}, new long[] {2, 3}));
OnnxTensor secondInputTensor =
OnnxTensor.createTensor(
env, OrtUtil.reshape(new long[] {7, 8, 9, 10, 11, 12}, new long[] {2, 3}));
container.put("tensor1", firstInputTensor);
container.put("tensor2", secondInputTensor);
try (OrtSession.Result outputs = session.run(container)) {
// output is a sequence<tensors>
Optional<OnnxValue> output = outputs.get("output_sequence");
assertTrue(output.isPresent());
assertTrue(output.get() instanceof OnnxSequence);
// cast to a sequence
OnnxSequence seq = (OnnxSequence) output.get();
// make sure that the sequence holds only 2 elements (tensors)
assertEquals(2, seq.getInfo().length);
// try-cast the elements in sequence to tensor type
List<? extends OnnxValue> elements = seq.getValue();
assertEquals(2, elements.size());
assertTrue(elements.get(0) instanceof OnnxTensor);
assertTrue(elements.get(1) instanceof OnnxTensor);
OnnxTensor firstTensor = (OnnxTensor) elements.get(0);
OnnxTensor secondTensor = (OnnxTensor) elements.get(1);
LongBuffer outputBuf = firstTensor.getLongBuffer();
// make sure the tensors in the output sequence hold the correct values
assertEquals(1, outputBuf.get(0));
assertEquals(2, outputBuf.get(1));
assertEquals(3, outputBuf.get(2));
assertEquals(4, outputBuf.get(3));
assertEquals(5, outputBuf.get(4));
assertEquals(6, outputBuf.get(5));
outputBuf = secondTensor.getLongBuffer();
assertEquals(7, outputBuf.get(0));
assertEquals(8, outputBuf.get(1));
assertEquals(9, outputBuf.get(2));
assertEquals(10, outputBuf.get(3));
assertEquals(11, outputBuf.get(4));
assertEquals(12, outputBuf.get(5));
firstTensor.close();
secondTensor.close();
}
}
}
@Test
public void testModelSerialization() throws OrtException, IOException {
String cwd = System.getProperty("user.dir");