mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
Previously OnnxSequence would flatten out a list of tensors into a single output array, assuming they were all scalar values. This doesn't accurately represent the semantics of an ONNX sequence, but was what the semantics appeared to be years ago when I first wrote that class. This PR changes it so that the `getValue` method on `OnnxSequence` unwraps the sequence and returns `List<? extends OnnxValue>`, allowing the user to process the individual ONNX values separately. It's done this way, rather than returning a multidimensional array for a tensor and a Java map for a map, because multidimensional arrays are very inefficient in Java and best practice when operating with an `OnnxTensor` in Java is to use a `java.nio.ByteBuffer`. So allowing users to access each `OnnxTensor` individually lets them control how the data is materialised on the Java heap.
105 lines
4.1 KiB
C
105 lines
4.1 KiB
C
/*
|
|
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
|
|
* Licensed under the MIT License.
|
|
*/
|
|
#include <jni.h>
|
|
#include "onnxruntime/core/session/onnxruntime_c_api.h"
|
|
#include "OrtJniUtil.h"
|
|
#include "ai_onnxruntime_OnnxSequence.h"
|
|
|
|
/*
|
|
* Class: ai_onnxruntime_OnnxSequence
|
|
* Method: getMaps
|
|
* Signature: (JJJ)[Lai/onnxruntime/OnnxMap;
|
|
*/
|
|
/*
 * Unpacks every element of an ONNX sequence into a Java OnnxMap[].
 *
 * apiHandle       - reinterpreted as a const OrtApi*.
 * handle          - reinterpreted as the OrtValue* holding the sequence.
 * allocatorHandle - reinterpreted as the OrtAllocator* passed to GetValue and createJavaMapFromONNX.
 *
 * Returns the populated array, or NULL if any step fails. A NULL return is expected to be
 * accompanied by a pending Java exception (raised via checkOrtStatus, by the failing JNI call,
 * or by createJavaMapFromONNX).
 */
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getMaps(JNIEnv* jniEnv, jobject jobj,
                                                                        jlong apiHandle, jlong handle,
                                                                        jlong allocatorHandle) {
  (void)jobj;  // Required JNI parameter not needed by functions which don't need to access their host object.
  const OrtApi* api = (const OrtApi*)apiHandle;
  OrtValue* sequence = (OrtValue*)handle;
  OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;

  // Get the element count of this sequence
  size_t count;
  OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValueCount(sequence, &count));
  if (code != ORT_OK) {
    // bail out, checkOrtStatus has already reported the failure
    return NULL;
  }

  jclass mapClazz = (*jniEnv)->FindClass(jniEnv, "ai/onnxruntime/OnnxMap");
  if (mapClazz == NULL) {
    // FindClass failed and left an exception pending; making further JNI calls would be unsafe.
    return NULL;
  }

  jobjectArray outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(count), mapClazz, NULL);
  if (outputArray == NULL) {
    // Array allocation failed with an exception pending.
    return NULL;
  }

  for (size_t i = 0; i < count; i++) {
    // Extract element
    OrtValue* element = NULL;
    code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element));
    if (code != ORT_OK) {
      // bail out as exception has been thrown
      return NULL;
    }

    jobject javaMap = createJavaMapFromONNX(jniEnv, api, allocator, element);
    if (javaMap == NULL) {
      // No Java wrapper was created for the native value, so release it here before bailing out.
      api->ReleaseValue(element);
      return NULL;
    }
    // NOTE(review): on success the element is not released here; presumably the Java object
    // owns it from now on - confirm against createJavaMapFromONNX.
    (*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, (jsize)i, javaMap);
    // The array keeps the object reachable; drop the local reference so that long sequences
    // do not accumulate one live local reference per element.
    (*jniEnv)->DeleteLocalRef(jniEnv, javaMap);
  }

  return outputArray;
}
|
|
|
|
/*
|
|
* Class: ai_onnxruntime_OnnxSequence
|
|
* Method: getTensors
|
|
* Signature: (JJJ)[Lai/onnxruntime/OnnxTensor;
|
|
*/
|
|
/*
 * Unpacks every element of an ONNX sequence into a Java OnnxTensor[].
 *
 * apiHandle       - reinterpreted as a const OrtApi*.
 * handle          - reinterpreted as the OrtValue* holding the sequence.
 * allocatorHandle - reinterpreted as the OrtAllocator* passed to GetValue and createJavaTensorFromONNX.
 *
 * Returns the populated array, or NULL if any step fails. A NULL return is expected to be
 * accompanied by a pending Java exception (raised via checkOrtStatus, by the failing JNI call,
 * or by createJavaTensorFromONNX).
 */
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getTensors(JNIEnv* jniEnv, jobject jobj,
                                                                           jlong apiHandle, jlong handle,
                                                                           jlong allocatorHandle) {
  (void)jobj;  // Required JNI parameter not needed by functions which don't need to access their host object.
  const OrtApi* api = (const OrtApi*)apiHandle;
  OrtValue* sequence = (OrtValue*)handle;
  OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;

  // Get the element count of this sequence
  size_t count;
  OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValueCount(sequence, &count));
  if (code != ORT_OK) {
    // bail out, checkOrtStatus has already reported the failure
    return NULL;
  }

  jclass tensorClazz = (*jniEnv)->FindClass(jniEnv, "ai/onnxruntime/OnnxTensor");
  if (tensorClazz == NULL) {
    // FindClass failed and left an exception pending; making further JNI calls would be unsafe.
    return NULL;
  }

  jobjectArray outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(count), tensorClazz, NULL);
  if (outputArray == NULL) {
    // Array allocation failed with an exception pending.
    return NULL;
  }

  for (size_t i = 0; i < count; i++) {
    // Extract element
    OrtValue* element = NULL;
    code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element));
    if (code != ORT_OK) {
      // bail out as exception has been thrown
      return NULL;
    }

    jobject javaTensor = createJavaTensorFromONNX(jniEnv, api, allocator, element);
    if (javaTensor == NULL) {
      // No Java wrapper was created for the native value, so release it here before bailing out.
      api->ReleaseValue(element);
      return NULL;
    }
    // NOTE(review): on success the element is not released here; presumably the Java object
    // owns it from now on - confirm against createJavaTensorFromONNX.
    (*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, (jsize)i, javaTensor);
    // The array keeps the object reachable; drop the local reference so that long sequences
    // do not accumulate one live local reference per element.
    (*jniEnv)->DeleteLocalRef(jniEnv, javaTensor);
  }

  return outputArray;
}
|
|
|
|
/*
|
|
* Class: ai_onnxruntime_OnnxSequence
|
|
* Method: close
|
|
* Signature: (J)V
|
|
*/
|
|
JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxSequence_close(JNIEnv* jniEnv, jobject jobj, jlong apiHandle,
                                                              jlong handle) {
  // Neither the JNI environment nor the host object is used; the casts silence unused-parameter warnings.
  (void)jobj;
  (void)jniEnv;

  // Reinterpret the raw handle as the sequence's OrtValue and release it through the supplied API table.
  OrtValue* sequence = (OrtValue*)handle;
  ((const OrtApi*)apiHandle)->ReleaseValue(sequence);
}
|