onnxruntime/java/src/main/native/ai_onnxruntime_OnnxSequence.c
Adam Pocock 388d3cf847
[Java] Fix OnnxSequence semantics (#13012)
Previously OnnxSequence would flatten out a list of tensors into a
single output array assuming they were all scalar values. This doesn't
accurately represent the semantics of an ONNX sequence, but was what the
semantics appeared to be years ago when I first wrote that class. This
PR changes it so that the `getValue` method on `OnnxSequence` unwraps
the sequence and returns `List<? extends OnnxValue>` allowing the user
to process the individual ONNX values separately. It's done this way
rather than returning a multidimensional array for a tensor and a Java
map for a map as multidimensional arrays are very inefficient in Java
and best practice when operating with an OnnxTensor in Java is to use a
`java.nio.ByteBuffer`. So allowing users to access each `OnnxTensor`
individually allows them to control how the data is materialised on the
Java heap.
2022-09-28 15:53:30 -07:00

105 lines
4.1 KiB
C

/*
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
#include "onnxruntime/core/session/onnxruntime_c_api.h"
#include "OrtJniUtil.h"
#include "ai_onnxruntime_OnnxSequence.h"
/*
 * Class:     ai_onnxruntime_OnnxSequence
 * Method:    getMaps
 * Signature: (JJJ)[Lai/onnxruntime/OnnxMap;
 *
 * Unpacks every element of the native sequence into a Java ai.onnxruntime.OnnxMap and returns them as an array.
 * Each element is fetched with OrtApi::GetValue and converted by createJavaMapFromONNX. On success the element is
 * not released here, so ownership presumably passes to the Java object (confirm against createJavaMapFromONNX).
 * If the conversion fails, the element is released here.
 *
 * Returns NULL with a pending Java exception if any ORT or JNI call fails.
 * NOTE(review): on a mid-loop failure, the OnnxMap objects already created for earlier elements are dropped
 * without being closed; confirm whether their native handles are reclaimed elsewhere.
 */
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getMaps(JNIEnv* jniEnv, jobject jobj,
                                                                        jlong apiHandle, jlong handle,
                                                                        jlong allocatorHandle) {
  (void)jobj;  // Required JNI parameter not needed by functions which don't need to access their host object.
  const OrtApi* api = (const OrtApi*)apiHandle;
  OrtValue* sequence = (OrtValue*)handle;
  OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;

  // Get the element count of this sequence.
  size_t count = 0;
  if (checkOrtStatus(jniEnv, api, api->GetValueCount(sequence, &count)) != ORT_OK) {
    // bail out as exception has been thrown
    return NULL;
  }

  jclass mapClazz = (*jniEnv)->FindClass(jniEnv, "ai/onnxruntime/OnnxMap");
  if (mapClazz == NULL) {
    // FindClass failed and left an exception (e.g. NoClassDefFoundError) pending.
    return NULL;
  }
  jobjectArray outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(count), mapClazz, NULL);
  // The class reference is only needed to create the array.
  (*jniEnv)->DeleteLocalRef(jniEnv, mapClazz);
  if (outputArray == NULL) {
    // NewObjectArray failed (e.g. OutOfMemoryError) and left an exception pending.
    return NULL;
  }

  for (size_t i = 0; i < count; i++) {
    // Extract element
    OrtValue* element = NULL;
    if (checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element)) != ORT_OK) {
      // bail out as exception has been thrown
      return NULL;
    }
    jobject javaMap = createJavaMapFromONNX(jniEnv, api, allocator, element);
    if (javaMap == NULL) {
      api->ReleaseValue(element);
      // bail out as exception has been thrown
      return NULL;
    }
    (*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, (jsize)i, javaMap);
    // The array now keeps the object reachable. Drop the local ref so long sequences
    // cannot overflow the JNI local reference table.
    (*jniEnv)->DeleteLocalRef(jniEnv, javaMap);
  }
  return outputArray;
}
/*
 * Class:     ai_onnxruntime_OnnxSequence
 * Method:    getTensors
 * Signature: (JJJ)[Lai/onnxruntime/OnnxTensor;
 *
 * Unpacks every element of the native sequence into a Java ai.onnxruntime.OnnxTensor and returns them as an array.
 * Each element is fetched with OrtApi::GetValue and converted by createJavaTensorFromONNX. On success the element is
 * not released here, so ownership presumably passes to the Java object (confirm against createJavaTensorFromONNX).
 * If the conversion fails, the element is released here.
 *
 * Returns NULL with a pending Java exception if any ORT or JNI call fails.
 * NOTE(review): on a mid-loop failure, the OnnxTensor objects already created for earlier elements are dropped
 * without being closed; confirm whether their native handles are reclaimed elsewhere.
 */
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getTensors(JNIEnv* jniEnv, jobject jobj,
                                                                           jlong apiHandle, jlong handle,
                                                                           jlong allocatorHandle) {
  (void)jobj;  // Required JNI parameter not needed by functions which don't need to access their host object.
  const OrtApi* api = (const OrtApi*)apiHandle;
  OrtValue* sequence = (OrtValue*)handle;
  OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;

  // Get the element count of this sequence.
  size_t count = 0;
  if (checkOrtStatus(jniEnv, api, api->GetValueCount(sequence, &count)) != ORT_OK) {
    // bail out as exception has been thrown
    return NULL;
  }

  jclass tensorClazz = (*jniEnv)->FindClass(jniEnv, "ai/onnxruntime/OnnxTensor");
  if (tensorClazz == NULL) {
    // FindClass failed and left an exception (e.g. NoClassDefFoundError) pending.
    return NULL;
  }
  jobjectArray outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(count), tensorClazz, NULL);
  // The class reference is only needed to create the array.
  (*jniEnv)->DeleteLocalRef(jniEnv, tensorClazz);
  if (outputArray == NULL) {
    // NewObjectArray failed (e.g. OutOfMemoryError) and left an exception pending.
    return NULL;
  }

  for (size_t i = 0; i < count; i++) {
    // Extract element
    OrtValue* element = NULL;
    if (checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element)) != ORT_OK) {
      // bail out as exception has been thrown
      return NULL;
    }
    jobject javaTensor = createJavaTensorFromONNX(jniEnv, api, allocator, element);
    if (javaTensor == NULL) {
      api->ReleaseValue(element);
      // bail out as exception has been thrown
      return NULL;
    }
    (*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, (jsize)i, javaTensor);
    // The array now keeps the object reachable. Drop the local ref so long sequences
    // cannot overflow the JNI local reference table.
    (*jniEnv)->DeleteLocalRef(jniEnv, javaTensor);
  }
  return outputArray;
}
/*
 * Class:     ai_onnxruntime_OnnxSequence
 * Method:    close
 * Signature: (J)V
 *
 * Frees the native OrtValue backing this sequence by passing it to OrtApi::ReleaseValue.
 */
JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxSequence_close(JNIEnv* jniEnv, jobject jobj, jlong apiHandle,
                                                              jlong handle) {
  // The JNI environment and the receiver object are mandated by the JNI signature but unused here.
  (void)jniEnv;
  (void)jobj;
  OrtValue* const sequenceValue = (OrtValue*)handle;
  ((const OrtApi*)apiHandle)->ReleaseValue(sequenceValue);
}