diff --git a/java/src/main/java/ai/onnxruntime/OnnxMap.java b/java/src/main/java/ai/onnxruntime/OnnxMap.java index c88d1cdaf7..53e9e64432 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxMap.java +++ b/java/src/main/java/ai/onnxruntime/OnnxMap.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -137,7 +137,7 @@ public class OnnxMap implements OnnxValue { * @throws OrtException If the onnx runtime failed to read the entries. */ @Override - public Map getValue() throws OrtException { + public Map getValue() throws OrtException { Object[] keys = getMapKeys(); Object[] values = getMapValues(); HashMap map = new HashMap<>(OrtUtil.capacityFromSize(keys.length)); diff --git a/java/src/main/java/ai/onnxruntime/OnnxSequence.java b/java/src/main/java/ai/onnxruntime/OnnxSequence.java index efa719ef02..ff58a862cb 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxSequence.java +++ b/java/src/main/java/ai/onnxruntime/OnnxSequence.java @@ -1,21 +1,27 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; -import java.util.HashMap; +import java.util.Collections; import java.util.List; -import java.util.stream.Collectors; /** * A sequence of {@link OnnxValue}s all of the same type. * - *

Supports the types mentioned in "onnxruntime_c_api.h", currently String, Long, Float, Double, - * Map>String,Float<, Map>Long,Float<. + *

Supports the types mentioned in "onnxruntime_c_api.h", currently + * + *

*/ public class OnnxSequence implements OnnxValue { @@ -54,51 +60,33 @@ public class OnnxSequence implements OnnxValue { } /** - * Extracts a Java object from the native ONNX type. + * Extracts a Java list of the {@link OnnxValue}s which can then be further unwrapped. * - *

Returns either a {@link List} of boxed primitives, {@link String}s, or {@link - * java.util.Map}s. + *

Returns either a {@link List} of either {@link OnnxTensor} or {@link OnnxMap}. * - * @return A Java object containing the value. + *

Note unlike the other {@link OnnxValue#getValue()} methods, this does not copy the values + * themselves into the Java heap, it merely exposes them as {@link OnnxValue} instances, allowing + * users to use the faster copy methods available for {@link OnnxTensor}. This also means that + * those values need to be closed separately from this instance, and are not closed by {@link + * #close} on this object. + * + * @return A Java list containing the values. * @throws OrtException If the runtime failed to read an element. */ @Override - public List getValue() throws OrtException { + public List getValue() throws OrtException { if (info.sequenceOfMaps) { - List outputSequence = new ArrayList<>(info.length); - for (int i = 0; i < info.length; i++) { - Object[] keys = getMapKeys(i); - Object[] values = getMapValues(i); - HashMap map = new HashMap<>(OrtUtil.capacityFromSize(keys.length)); - for (int j = 0; j < keys.length; j++) { - map.put(keys[j], values[j]); - } - outputSequence.add(map); - } - return outputSequence; + OnnxMap[] maps = getMaps(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle); + return Collections.unmodifiableList(Arrays.asList(maps)); } else { switch (info.sequenceType) { - case FLOAT: - float[] floats = getFloats(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle); - ArrayList boxed = new ArrayList<>(floats.length); - for (float aFloat : floats) { - // box float to Float - boxed.add(aFloat); - } - return boxed; - case DOUBLE: - return Arrays.stream(getDoubles(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle)) - .boxed() - .collect(Collectors.toList()); - case INT64: - return Arrays.stream(getLongs(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle)) - .boxed() - .collect(Collectors.toList()); case STRING: - String[] strings = getStrings(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle); - ArrayList list = new ArrayList<>(strings.length); - list.addAll(Arrays.asList(strings)); - return list; + case INT64: + case FLOAT: + case DOUBLE: + OnnxTensor[] tensors = + getTensors(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle); + return Collections.unmodifiableList(Arrays.asList(tensors)); case BOOL: case UINT8: case INT8: @@ -127,95 +115,10 @@ public class OnnxSequence implements OnnxValue { close(OnnxRuntime.ortApiHandle, nativeHandle); } - /** - * Extract the keys for the map at the specified index. - * - * @param index The index to extract. - * @return The map keys as an array. - * @throws OrtException If the native code failed to read the keys. - */ - private Object[] getMapKeys(int index) throws OrtException { - if (info.mapInfo.keyType == OnnxJavaType.STRING) { - return getStringKeys(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, index); - } else { - return Arrays.stream( - getLongKeys(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, index)) - .boxed() - .toArray(); - } - } - - /** - * Extract the values for the map at the specified index. - * - * @param index The index to extract. - * @return The map values as an array. - * @throws OrtException If the native code failed to read the values. - */ - private Object[] getMapValues(int index) throws OrtException { - switch (info.mapInfo.valueType) { - case STRING: - { - return getStringValues(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, index); - } - case INT64: - { - return Arrays.stream( - getLongValues(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, index)) - .boxed() - .toArray(); - } - case FLOAT: - { - float[] floats = - getFloatValues(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, index); - Float[] boxed = new Float[floats.length]; - for (int i = 0; i < floats.length; i++) { - // cast float to Float - boxed[i] = floats[i]; - } - return boxed; - } - case DOUBLE: - { - return Arrays.stream( - getDoubleValues(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, index)) - .boxed() - .toArray(); - } - default: - throw new RuntimeException("Invalid or unknown valueType: " + info.mapInfo.valueType); - } - } - - private native String[] getStringKeys( - long apiHandle, long nativeHandle, long allocatorHandle, int index) throws OrtException; - - private native long[] getLongKeys( - long apiHandle, long nativeHandle, long allocatorHandle, int index) throws OrtException; - - private native String[] getStringValues( - long apiHandle, long nativeHandle, long allocatorHandle, int index) throws OrtException; - - private native long[] getLongValues( - long apiHandle, long nativeHandle, long allocatorHandle, int index) throws OrtException; - - private native float[] getFloatValues( - long apiHandle, long nativeHandle, long allocatorHandle, int index) throws OrtException; - - private native double[] getDoubleValues( - long apiHandle, long nativeHandle, long allocatorHandle, int index) throws OrtException; - - private native String[] getStrings(long apiHandle, long nativeHandle, long allocatorHandle) + private native OnnxMap[] getMaps(long apiHandle, long nativeHandle, long allocatorHandle) throws OrtException; - private native long[] getLongs(long apiHandle, long nativeHandle, long allocatorHandle) - throws OrtException; - - private native float[] getFloats(long apiHandle, long nativeHandle, long allocatorHandle) - throws OrtException; - - private native double[] getDoubles(long apiHandle, long nativeHandle, long allocatorHandle) + private native OnnxTensor[] getTensors(long apiHandle, long nativeHandle, long allocatorHandle) throws OrtException; private native void close(long apiHandle, long nativeHandle); diff --git a/java/src/main/native/ai_onnxruntime_OnnxSequence.c b/java/src/main/native/ai_onnxruntime_OnnxSequence.c index 22cfca6315..522f3f0839 100644 --- a/java/src/main/native/ai_onnxruntime_OnnxSequence.c +++ b/java/src/main/native/ai_onnxruntime_OnnxSequence.c @@ -6,210 +6,55 @@ #include "onnxruntime/core/session/onnxruntime_c_api.h" #include "OrtJniUtil.h" #include "ai_onnxruntime_OnnxSequence.h" + /* * Class: ai_onnxruntime_OnnxSequence - * Method: getStringKeys - * Signature: (JJI)[Ljava/lang/String; + * Method: getMaps + * Signature: (JJJ)[Lai/onnxruntime/OnnxMap; */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStringKeys(JNIEnv* jniEnv, jobject jobj, - jlong apiHandle, jlong handle, - jlong allocatorHandle, jint index) { +JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getMaps(JNIEnv* jniEnv, jobject jobj, + jlong apiHandle, jlong handle, + jlong allocatorHandle) { (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*)apiHandle; + OrtValue* sequence = (OrtValue*)handle; OrtAllocator* allocator = (OrtAllocator*)allocatorHandle; - jobjectArray output = NULL; - // Extract element - OrtValue* element; - OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, index, allocator, &element)); + + jobjectArray outputArray = NULL; + + // Get the element count of this sequence + size_t count; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValueCount(sequence, &count)); if (code == ORT_OK) { - // Extract keys from element - OrtValue* keys; - code = checkOrtStatus(jniEnv, api, api->GetValue(element, 0, allocator, &keys)); - - if (code == ORT_OK) { - // Convert to Java String array - output = createStringArrayFromTensor(jniEnv, api, keys); - // Release if valid - api->ReleaseValue(element); + jclass tensorClazz = (*jniEnv)->FindClass(jniEnv, "ai/onnxruntime/OnnxMap"); + outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(count), tensorClazz, NULL); + for (size_t i = 0; i < count; i++) { + // Extract element + OrtValue* element; + code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element)); + if (code == ORT_OK) { + jobject str = createJavaMapFromONNX(jniEnv, api, allocator, element); + if (str == NULL) { + api->ReleaseValue(element); + // bail out as exception has been thrown + return NULL; + } + (*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, (jsize)i, str); + } else { + // bail out as exception has been thrown + return NULL; + } } - - // Keys is valid, so release - api->ReleaseValue(keys); } - return output; + return outputArray; } /* * Class: ai_onnxruntime_OnnxSequence - * Method: getLongKeys - * Signature: (JJI)[J + * Method: getTensors + * Signature: (JJJ)[Lai/onnxruntime/OnnxTensor; */ -JNIEXPORT jlongArray JNICALL Java_ai_onnxruntime_OnnxSequence_getLongKeys(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, - jlong handle, jlong allocatorHandle, - jint index) { - (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*)apiHandle; - OrtAllocator* allocator = (OrtAllocator*)allocatorHandle; - jlongArray output = NULL; - // Extract element - OrtValue* element; - OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, index, allocator, &element)); - if (code == ORT_OK) { - // Extract keys from element - OrtValue* keys; - code = checkOrtStatus(jniEnv, api, api->GetValue(element, 0, allocator, &keys)); - - if (code == ORT_OK) { - // Convert to Java long array - output = createLongArrayFromTensor(jniEnv, api, keys); - // Release if valid - api->ReleaseValue(element); - } - - // Keys is valid, so release - api->ReleaseValue(keys); - } - return output; -} - -/* - * Class: ai_onnxruntime_OnnxSequence - * Method: getStringValues - * Signature: (JJI)[Ljava/lang/String; - */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStringValues(JNIEnv* jniEnv, jobject jobj, - jlong apiHandle, jlong handle, - jlong allocatorHandle, jint index) { - (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*)apiHandle; - OrtAllocator* allocator = (OrtAllocator*)allocatorHandle; - jobjectArray output = NULL; - // Extract element - OrtValue* element; - OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, index, allocator, &element)); - if (code == ORT_OK) { - // Extract values from element - OrtValue* values; - code = checkOrtStatus(jniEnv, api, api->GetValue(element, 1, allocator, &values)); - - if (code == ORT_OK) { - // Convert to Java String array - output = createStringArrayFromTensor(jniEnv, api, values); - // Release if valid - api->ReleaseValue(element); - } - - // values is valid, so release - api->ReleaseValue(values); - } - return output; -} - -/* - * Class: ai_onnxruntime_OnnxSequence - * Method: getLongValues - * Signature: (JJI)[J - */ -JNIEXPORT jlongArray JNICALL Java_ai_onnxruntime_OnnxSequence_getLongValues(JNIEnv* jniEnv, jobject jobj, - jlong apiHandle, jlong handle, - jlong allocatorHandle, jint index) { - (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*)apiHandle; - OrtAllocator* allocator = (OrtAllocator*)allocatorHandle; - jlongArray output = NULL; - // Extract element - OrtValue* element; - OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, index, allocator, &element)); - if (code == ORT_OK) { - // Extract values from element - OrtValue* values; - code = checkOrtStatus(jniEnv, api, api->GetValue(element, 1, allocator, &values)); - - if (code == ORT_OK) { - // Convert to Java long array - output = createLongArrayFromTensor(jniEnv, api, values); - // Release if valid - api->ReleaseValue(element); - } - - // values is valid, so release - api->ReleaseValue(values); - } - return output; -} - -/* - * Class: ai_onnxruntime_OnnxSequence - * Method: getFloatValues - * Signature: (JJI)[F - */ -JNIEXPORT jfloatArray JNICALL Java_ai_onnxruntime_OnnxSequence_getFloatValues(JNIEnv* jniEnv, jobject jobj, - jlong apiHandle, jlong handle, - jlong allocatorHandle, jint index) { - (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*)apiHandle; - OrtAllocator* allocator = (OrtAllocator*)allocatorHandle; - jfloatArray output = NULL; - // Extract element - OrtValue* element; - OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, index, allocator, &element)); - if (code == ORT_OK) { - // Extract values from element - OrtValue* values; - code = checkOrtStatus(jniEnv, api, api->GetValue(element, 1, allocator, &values)); - - if (code == ORT_OK) { - // Convert to Java float array - output = createFloatArrayFromTensor(jniEnv, api, values); - // Release if valid - api->ReleaseValue(element); - } - - // values is valid, so release - api->ReleaseValue(values); - } - return output; -} - -/* - * Class: ai_onnxruntime_OnnxSequence - * Method: getDoubleValues - * Signature: (JJI)[D - */ -JNIEXPORT jdoubleArray JNICALL Java_ai_onnxruntime_OnnxSequence_getDoubleValues(JNIEnv* jniEnv, jobject jobj, - jlong apiHandle, jlong handle, - jlong allocatorHandle, jint index) { - (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*)apiHandle; - OrtAllocator* allocator = (OrtAllocator*)allocatorHandle; - jdoubleArray output = NULL; - // Extract element - OrtValue* element; - OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, index, allocator, &element)); - if (code == ORT_OK) { - // Extract values from element - OrtValue* values; - code = checkOrtStatus(jniEnv, api, api->GetValue(element, 1, allocator, &values)); - - if (code == ORT_OK) { - // Convert to Java double array - output = createDoubleArrayFromTensor(jniEnv, api, values); - // Release if valid - api->ReleaseValue(element); - } - - // values is valid, so release - api->ReleaseValue(values); - } - return output; -} - -/* - * Class: ai_onnxruntime_OnnxSequence - * Method: getStrings - * Signature: (JJ)[Ljava/lang/String; - */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStrings(JNIEnv* jniEnv, jobject jobj, +JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getTensors(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle) { (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. @@ -223,22 +68,20 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStrings(JNIEn size_t count; OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValueCount(sequence, &count)); if (code == ORT_OK) { - jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String"); - outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(count), stringClazz, NULL); + jclass tensorClazz = (*jniEnv)->FindClass(jniEnv, "ai/onnxruntime/OnnxTensor"); + outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(count), tensorClazz, NULL); for (size_t i = 0; i < count; i++) { // Extract element OrtValue* element; code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element)); if (code == ORT_OK) { - jobject str = createStringFromStringTensor(jniEnv, api, element); + jobject str = createJavaTensorFromONNX(jniEnv, api, allocator, element); if (str == NULL) { api->ReleaseValue(element); // bail out as exception has been thrown return NULL; } (*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, (jsize)i, str); - - api->ReleaseValue(element); } else { // bail out as exception has been thrown return NULL; @@ -248,169 +91,6 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStrings(JNIEn return outputArray; } -/* - * Class: ai_onnxruntime_OnnxSequence - * Method: getLongs - * Signature: (JJ)[J - */ -JNIEXPORT jlongArray JNICALL Java_ai_onnxruntime_OnnxSequence_getLongs(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, - jlong handle, jlong allocatorHandle) { - (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*)apiHandle; - OrtValue* sequence = (OrtValue*)handle; - OrtAllocator* allocator = (OrtAllocator*)allocatorHandle; - jlongArray outputArray = NULL; - - // Get the element count of this sequence - size_t count; - OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValueCount(sequence, &count)); - if (code == ORT_OK) { - int64_t* values; - code = checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator, sizeof(int64_t) * count, (void**)&values)); - if (code == ORT_OK) { - for (size_t i = 0; i < count; i++) { - // Extract element - OrtValue* element; - code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element)); - if (code == ORT_OK) { - // Extract the values - int64_t* arr; - code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(element, (void**)&arr)); - if (code == ORT_OK) { - values[i] = arr[0]; - } else { - // bail out as exception has been thrown - api->ReleaseValue(element); - checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values)); - return NULL; - } - - api->ReleaseValue(element); - } else { - // bail out as exception has been thrown - checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values)); - return NULL; - } - } - - outputArray = (*jniEnv)->NewLongArray(jniEnv, safecast_size_t_to_jsize(count)); - (*jniEnv)->SetLongArrayRegion(jniEnv, outputArray, 0, safecast_size_t_to_jsize(count), (jlong*)values); - - checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values)); - } - } - return outputArray; -} - -/* - * Class: ai_onnxruntime_OnnxSequence - * Method: getFloats - * Signature: (JJ)[F - */ -JNIEXPORT jfloatArray JNICALL Java_ai_onnxruntime_OnnxSequence_getFloats(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, - jlong handle, jlong allocatorHandle) { - (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*)apiHandle; - OrtValue* sequence = (OrtValue*)handle; - OrtAllocator* allocator = (OrtAllocator*)allocatorHandle; - jfloatArray outputArray = NULL; - - // Get the element count of this sequence - size_t count; - OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValueCount(sequence, &count)); - if (code == ORT_OK) { - float* values; - code = checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator, sizeof(float) * count, (void**)&values)); - if (code == ORT_OK) { - for (size_t i = 0; i < count; i++) { - // Extract element - OrtValue* element; - code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element)); - if (code == ORT_OK) { - // Extract the values - float* arr; - code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(element, (void**)&arr)); - if (code == ORT_OK) { - values[i] = arr[0]; - } else { - // bail out as exception has been thrown - api->ReleaseValue(element); - checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values)); - return NULL; - } - - api->ReleaseValue(element); - } else { - // bail out as exception has been thrown - checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values)); - return NULL; - } - } - - outputArray = (*jniEnv)->NewFloatArray(jniEnv, safecast_size_t_to_jsize(count)); - (*jniEnv)->SetFloatArrayRegion(jniEnv, outputArray, 0, safecast_size_t_to_jsize(count), values); - - checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values)); - } - } - return outputArray; -} - -/* - * Class: ai_onnxruntime_OnnxSequence - * Method: getDoubles - * Signature: (JJ)[D - */ -JNIEXPORT jdoubleArray JNICALL Java_ai_onnxruntime_OnnxSequence_getDoubles(JNIEnv* jniEnv, jobject jobj, - jlong apiHandle, jlong handle, - jlong allocatorHandle) { - (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*)apiHandle; - OrtValue* sequence = (OrtValue*)handle; - OrtAllocator* allocator = (OrtAllocator*)allocatorHandle; - jdoubleArray outputArray = NULL; - - // Get the element count of this sequence - size_t count; - OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValueCount(sequence, &count)); - if (code == ORT_OK) { - double* values; - code = checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator, sizeof(double) * count, (void**)&values)); - if (code == ORT_OK) { - for (size_t i = 0; i < count; i++) { - // Extract element - OrtValue* element; - code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element)); - if (code == ORT_OK) { - // Extract the values - double* arr; - code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(element, (void**)&arr)); - if (code == ORT_OK) { - values[i] = arr[0]; - } else { - // bail out as exception has been thrown - api->ReleaseValue(element); - checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values)); - return NULL; - } - - api->ReleaseValue(element); - } else { - // bail out as exception has been thrown - checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values)); - return NULL; - } - } - - outputArray = (*jniEnv)->NewDoubleArray(jniEnv, safecast_size_t_to_jsize(count)); - (*jniEnv)->SetDoubleArrayRegion(jniEnv, outputArray, 0, safecast_size_t_to_jsize(count), values); - - checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, values)); - } - } - return outputArray; -} - /* * Class: ai_onnxruntime_OnnxSequence * Method: close diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index f1dd3c57c0..d657654865 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; +import java.nio.LongBuffer; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; @@ -32,6 +33,7 @@ import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -1301,7 +1303,8 @@ public class InferenceTest { // try-cast first element in sequence to map/dictionary type @SuppressWarnings("unchecked") - Map map = (Map) ((List) secondOutput.getValue()).get(0); + Map map = + (Map) ((List) secondOutput.getValue()).get(0).getValue(); assertEquals(0.25938290, map.get(0L), 1e-6); assertEquals(0.40904793, map.get(1L), 1e-6); assertEquals(0.33156919, map.get(2L), 1e-6); @@ -1368,7 +1371,7 @@ public class InferenceTest { // try-cast first element in sequence to map/dictionary type @SuppressWarnings("unchecked") Map map = - (Map) ((List) secondOutput.getValue()).get(0); + (Map) ((List) secondOutput.getValue()).get(0).getValue(); assertEquals(0.25938290, map.get("0"), 1e-6); assertEquals(0.40904793, map.get("1"), 1e-6); assertEquals(0.33156919, map.get("2"), 1e-6); @@ -1377,6 +1380,73 @@ public class InferenceTest { } } + @Test + public void testModelSequenceOfTensors() throws OrtException { + String modelPath = TestHelpers.getResourcePath("/test_sequence_tensors.onnx").toString(); + + try (SessionOptions options = new SessionOptions(); + OrtSession session = env.createSession(modelPath, options)) { + Map outputInfos = session.getOutputInfo(); + NodeInfo outputInfo = outputInfos.get("output_sequence"); + assertTrue(outputInfo.getInfo() instanceof SequenceInfo); + + Map container = new HashMap<>(); + OnnxTensor firstInputTensor = + OnnxTensor.createTensor( + env, OrtUtil.reshape(new long[] {1, 2, 3, 4, 5, 6}, new long[] {2, 3})); + OnnxTensor secondInputTensor = + OnnxTensor.createTensor( + env, OrtUtil.reshape(new long[] {7, 8, 9, 10, 11, 12}, new long[] {2, 3})); + + container.put("tensor1", firstInputTensor); + container.put("tensor2", secondInputTensor); + + try (OrtSession.Result outputs = session.run(container)) { + // output is a sequence + Optional output = outputs.get("output_sequence"); + assertTrue(output.isPresent()); + assertTrue(output.get() instanceof OnnxSequence); + + // cast to a sequence + OnnxSequence seq = (OnnxSequence) output.get(); + + // make sure that the sequence holds only 2 elements (tensors) + assertEquals(2, seq.getInfo().length); + + // try-cast the elements in sequence to tensor type + List elements = seq.getValue(); + assertEquals(2, elements.size()); + assertTrue(elements.get(0) instanceof OnnxTensor); + assertTrue(elements.get(1) instanceof OnnxTensor); + + OnnxTensor firstTensor = (OnnxTensor) elements.get(0); + OnnxTensor secondTensor = (OnnxTensor) elements.get(1); + + LongBuffer outputBuf = firstTensor.getLongBuffer(); + + // make sure the tensors in the output sequence hold the correct values + assertEquals(1, outputBuf.get(0)); + assertEquals(2, outputBuf.get(1)); + assertEquals(3, outputBuf.get(2)); + assertEquals(4, outputBuf.get(3)); + assertEquals(5, outputBuf.get(4)); + assertEquals(6, outputBuf.get(5)); + + outputBuf = secondTensor.getLongBuffer(); + + assertEquals(7, outputBuf.get(0)); + assertEquals(8, outputBuf.get(1)); + assertEquals(9, outputBuf.get(2)); + assertEquals(10, outputBuf.get(3)); + assertEquals(11, outputBuf.get(4)); + assertEquals(12, outputBuf.get(5)); + + firstTensor.close(); + secondTensor.close(); + } + } + } + @Test public void testModelSerialization() throws OrtException, IOException { String cwd = System.getProperty("user.dir");