[Java] JNI refactor for OrtSession (#12496)

Refactor JNI error reporting
This commit is contained in:
Adam Pocock 2022-08-16 16:43:06 -04:00 committed by GitHub
parent eb6aa861cf
commit 733db31420
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -4,40 +4,47 @@
*/
#include <jni.h>
#include <string.h>
#include <stdlib.h>
#include "onnxruntime/core/session/onnxruntime_c_api.h"
#include "OrtJniUtil.h"
#include "ai_onnxruntime_OrtSession.h"
const char * const ORTJNI_StringClassName = "java/lang/String";
const char * const ORTJNI_OnnxValueClassName = "ai/onnxruntime/OnnxValue";
const char * const ORTJNI_NodeInfoClassName = "ai/onnxruntime/NodeInfo";
const char * const ORTJNI_MetadataClassName = "ai/onnxruntime/OnnxModelMetadata";
/*
* Class: ai_onnxruntime_OrtSession
* Method: createSession
* Signature: (JJLjava/lang/String;J)J
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_lang_String_2J
(JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jstring modelPath, jlong optsHandle) {
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtSession* session;
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_lang_String_2J(JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jstring modelPath, jlong optsHandle) {
(void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtSession* session = NULL;
#ifdef _WIN32
const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, modelPath, NULL);
size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, modelPath);
wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t));
if(newString == NULL) {
throwOrtException(jniEnv, 1, "Not enough memory");
return 0;
}
wcsncpy_s(newString, stringLength+1, (const wchar_t*) cPath, stringLength);
checkOrtStatus(jniEnv,api,api->CreateSession((OrtEnv*)envHandle, newString, (OrtSessionOptions*)optsHandle, &session));
free(newString);
(*jniEnv)->ReleaseStringChars(jniEnv,modelPath,cPath);
const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, modelPath, NULL);
size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, modelPath);
wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t));
if (newString == NULL) {
(*jniEnv)->ReleaseStringChars(jniEnv, modelPath, cPath);
throwOrtException(jniEnv, 1, "Not enough memory");
return 0;
}
wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength);
checkOrtStatus(jniEnv, api,
api->CreateSession((OrtEnv*)envHandle, newString, (OrtSessionOptions*)optsHandle, &session));
free(newString);
(*jniEnv)->ReleaseStringChars(jniEnv, modelPath, cPath);
#else
const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, modelPath, NULL);
checkOrtStatus(jniEnv,api,api->CreateSession((OrtEnv*)envHandle, cPath, (OrtSessionOptions*)optsHandle, &session));
(*jniEnv)->ReleaseStringUTFChars(jniEnv,modelPath,cPath);
const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, modelPath, NULL);
checkOrtStatus(jniEnv, api, api->CreateSession((OrtEnv*)envHandle, cPath, (OrtSessionOptions*)optsHandle, &session));
(*jniEnv)->ReleaseStringUTFChars(jniEnv, modelPath, cPath);
#endif
return (jlong) session;
return (jlong)session;
}
/*
@ -45,34 +52,39 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_la
* Method: createSession
* Signature: (JJ[BJ)J
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJ_3BJ
(JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jbyteArray jModelArray, jlong optsHandle) {
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtSession* session;
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJ_3BJ(JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jbyteArray jModelArray, jlong optsHandle) {
(void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtEnv* env = (OrtEnv*)envHandle;
OrtSessionOptions* opts = (OrtSessionOptions*)optsHandle;
OrtSession* session = NULL;
// Get a reference to the byte array elements
jbyte* modelArr = (*jniEnv)->GetByteArrayElements(jniEnv,jModelArray,NULL);
size_t modelLength = (*jniEnv)->GetArrayLength(jniEnv,jModelArray);
checkOrtStatus(jniEnv,api,api->CreateSessionFromArray((OrtEnv*)envHandle, modelArr, modelLength, (OrtSessionOptions*)optsHandle, &session));
// Release the C array.
(*jniEnv)->ReleaseByteArrayElements(jniEnv,jModelArray,modelArr,JNI_ABORT);
return (jlong) session;
size_t modelLength = (*jniEnv)->GetArrayLength(jniEnv, jModelArray);
if (modelLength == 0) {
throwOrtException(jniEnv, 2, "Invalid ONNX model, the byte array is zero length.");
return 0;
}
// Get a reference to the byte array elements
jbyte* modelArr = (*jniEnv)->GetByteArrayElements(jniEnv, jModelArray, NULL);
checkOrtStatus(jniEnv, api, api->CreateSessionFromArray(env, modelArr, modelLength, opts, &session));
// Release the C array.
(*jniEnv)->ReleaseByteArrayElements(jniEnv, jModelArray, modelArr, JNI_ABORT);
return (jlong)session;
}
/*
* Class: ai_onnxruntime_OrtSession
* Method: getNumInputs
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_getNumInputs
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
size_t numInputs;
checkOrtStatus(jniEnv,api,api->SessionGetInputCount((OrtSession*)handle, &numInputs));
return numInputs;
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_getNumInputs(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
size_t numInputs = 0;
checkOrtStatus(jniEnv, api, api->SessionGetInputCount((OrtSession*)handle, &numInputs));
return numInputs;
}
/*
@ -80,31 +92,47 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_getNumInputs
* Method: getInputNames
* Signature: (JJJ)[Ljava/lang/String;
*/
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getInputNames
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getInputNames(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle) {
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
OrtSession* session = (OrtSession*)sessionHandle;
// Setup
char *stringClassName = "java/lang/String";
jclass stringClazz = (*jniEnv)->FindClass(jniEnv, stringClassName);
// Setup
jclass stringClazz = (*jniEnv)->FindClass(jniEnv, ORTJNI_StringClassName);
// Get the number of inputs
size_t numInputs = Java_ai_onnxruntime_OrtSession_getNumInputs(jniEnv, jobj, apiHandle, sessionHandle);
// Get the number of inputs
size_t numInputs = 0;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->SessionGetInputCount(session, &numInputs));
if (code != ORT_OK) {
return NULL;
}
// Allocate the return array
jobjectArray array = (*jniEnv)->NewObjectArray(jniEnv,safecast_size_t_to_jsize(numInputs),stringClazz,NULL);
for (uint32_t i = 0; i < numInputs; i++) {
// Read out the input name and convert it to a java.lang.String
char* inputName;
checkOrtStatus(jniEnv,api,api->SessionGetInputName((OrtSession*)sessionHandle, i, allocator, &inputName));
jstring name = (*jniEnv)->NewStringUTF(jniEnv,inputName);
(*jniEnv)->SetObjectArrayElement(jniEnv, array, i, name);
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,inputName));
int32_t numInputsInt = (int32_t) numInputs;
if (numInputs != (size_t) numInputsInt) {
throwOrtException(jniEnv, 1, "Too many inputs, expected less than 2^31");
}
// Allocate the return array
jobjectArray array = (*jniEnv)->NewObjectArray(jniEnv, numInputsInt, stringClazz, NULL);
for (int32_t i = 0; i < numInputsInt; i++) {
// Read out the input name and convert it to a java.lang.String
char* inputName = NULL;
code = checkOrtStatus(jniEnv, api, api->SessionGetInputName(session, i, allocator, &inputName));
if (code != ORT_OK) {
// break out on error, return array and let Java throw the exception.
break;
}
jstring name = (*jniEnv)->NewStringUTF(jniEnv, inputName);
(*jniEnv)->SetObjectArrayElement(jniEnv, array, i, name);
code = checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, inputName));
if (code != ORT_OK) {
// break out on error, return array and let Java throw the exception.
break;
}
}
return array;
return array;
}
/*
@ -112,13 +140,12 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getInputNames
* Method: getNumOutputs
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_getNumOutputs
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
size_t numOutputs;
checkOrtStatus(jniEnv,api,api->SessionGetOutputCount((OrtSession*)handle, &numOutputs));
return numOutputs;
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_getNumOutputs(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
size_t numOutputs = 0;
checkOrtStatus(jniEnv, api, api->SessionGetOutputCount((OrtSession*)handle, &numOutputs));
return numOutputs;
}
/*
@ -126,31 +153,47 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_getNumOutputs
* Method: getOutputNames
* Signature: (JJJ)[Ljava/lang/String;
*/
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getOutputNames
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getOutputNames(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle) {
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtSession* session = (OrtSession*)sessionHandle;
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
// Setup
char *stringClassName = "java/lang/String";
jclass stringClazz = (*jniEnv)->FindClass(jniEnv, stringClassName);
// Setup
jclass stringClazz = (*jniEnv)->FindClass(jniEnv, ORTJNI_StringClassName);
// Get the number of outputs
size_t numOutputs = Java_ai_onnxruntime_OrtSession_getNumOutputs(jniEnv, jobj, apiHandle, sessionHandle);
// Get the number of outputs
size_t numOutputs = 0;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->SessionGetOutputCount(session, &numOutputs));
if (code != ORT_OK) {
return NULL;
}
// Allocate the return array
jobjectArray array = (*jniEnv)->NewObjectArray(jniEnv,safecast_size_t_to_jsize(numOutputs),stringClazz, NULL);
for (uint32_t i = 0; i < numOutputs; i++) {
// Read out the output name and convert it to a java.lang.String
char* outputName;
checkOrtStatus(jniEnv,api,api->SessionGetOutputName((OrtSession*)sessionHandle, i, allocator, &outputName));
jstring name = (*jniEnv)->NewStringUTF(jniEnv,outputName);
(*jniEnv)->SetObjectArrayElement(jniEnv, array, i, name);
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,outputName));
int32_t numOutputsInt = (int32_t) numOutputs;
if (numOutputs != (size_t) numOutputsInt) {
throwOrtException(jniEnv, 1, "Too many outputs, expected less than 2^31");
}
// Allocate the return array
jobjectArray array = (*jniEnv)->NewObjectArray(jniEnv, numOutputsInt, stringClazz, NULL);
for (int32_t i = 0; i < numOutputsInt; i++) {
// Read out the output name and convert it to a java.lang.String
char* outputName = NULL;
code = checkOrtStatus(jniEnv, api, api->SessionGetOutputName(session, i, allocator, &outputName));
if (code != ORT_OK) {
// break out on error, return array and let Java throw the exception.
break;
}
jstring name = (*jniEnv)->NewStringUTF(jniEnv, outputName);
(*jniEnv)->SetObjectArrayElement(jniEnv, array, i, name);
code = checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, outputName));
if (code != ORT_OK) {
// break out on error, return array and let Java throw the exception.
break;
}
}
return array;
return array;
}
/*
@ -158,41 +201,57 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getOutputNames
* Method: getInputInfo
* Signature: (JJJ)[Lai/onnxruntime/NodeInfo;
*/
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getInputInfo
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getInputInfo(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle) {
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtSession* session = (OrtSession*)sessionHandle;
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
// Setup
char *nodeInfoClassName = "ai/onnxruntime/NodeInfo";
jclass nodeInfoClazz = (*jniEnv)->FindClass(jniEnv, nodeInfoClassName);
jmethodID nodeInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,nodeInfoClazz, "<init>", "(Ljava/lang/String;Lai/onnxruntime/ValueInfo;)V");
// Setup
jclass nodeInfoClazz = (*jniEnv)->FindClass(jniEnv, ORTJNI_NodeInfoClassName);
jmethodID nodeInfoConstructor = (*jniEnv)->GetMethodID(jniEnv, nodeInfoClazz, "<init>",
"(Ljava/lang/String;Lai/onnxruntime/ValueInfo;)V");
// Get the number of inputs
size_t numInputs = Java_ai_onnxruntime_OrtSession_getNumInputs(jniEnv, jobj, apiHandle, sessionHandle);
// Get the number of inputs
size_t numInputs = 0;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->SessionGetInputCount(session, &numInputs));
if (code != ORT_OK) {
return NULL;
}
// Allocate the return array
jobjectArray array = (*jniEnv)->NewObjectArray(jniEnv,safecast_size_t_to_jsize(numInputs),nodeInfoClazz, NULL);
for (size_t i = 0; i < numInputs; i++) {
// Read out the input name and convert it to a java.lang.String
char* inputName;
checkOrtStatus(jniEnv,api,api->SessionGetInputName((OrtSession*)sessionHandle, i, allocator, &inputName));
jstring name = (*jniEnv)->NewStringUTF(jniEnv,inputName);
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,inputName));
// Create a ValueInfo from the OrtTypeInfo
OrtTypeInfo* typeInfo;
checkOrtStatus(jniEnv,api,api->SessionGetInputTypeInfo((OrtSession*)sessionHandle, i, &typeInfo));
jobject valueInfoJava = convertToValueInfo(jniEnv,api,typeInfo);
api->ReleaseTypeInfo(typeInfo);
// Create a NodeInfo and assign into the array
jobject nodeInfo = (*jniEnv)->NewObject(jniEnv, nodeInfoClazz, nodeInfoConstructor, name, valueInfoJava);
(*jniEnv)->SetObjectArrayElement(jniEnv, array,safecast_size_t_to_jsize(i),nodeInfo);
// Allocate the return array
jobjectArray array = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(numInputs), nodeInfoClazz, NULL);
for (size_t i = 0; i < numInputs; i++) {
// Read out the input name and convert it to a java.lang.String
char* inputName = NULL;
code = checkOrtStatus(jniEnv, api, api->SessionGetInputName(session, i, allocator, &inputName));
if (code != ORT_OK) {
break;
}
jstring name = (*jniEnv)->NewStringUTF(jniEnv, inputName);
code = checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, inputName));
if (code != ORT_OK) {
break;
}
return array;
// Create a ValueInfo from the OrtTypeInfo
OrtTypeInfo* typeInfo = NULL;
code = checkOrtStatus(jniEnv, api, api->SessionGetInputTypeInfo(session, i, &typeInfo));
if (code != ORT_OK) {
break;
}
jobject valueInfoJava = convertToValueInfo(jniEnv, api, typeInfo);
api->ReleaseTypeInfo(typeInfo);
if (valueInfoJava == NULL) {
break;
}
// Create a NodeInfo and assign into the array
jobject nodeInfo = (*jniEnv)->NewObject(jniEnv, nodeInfoClazz, nodeInfoConstructor, name, valueInfoJava);
(*jniEnv)->SetObjectArrayElement(jniEnv, array, safecast_size_t_to_jsize(i), nodeInfo);
}
return array;
}
/*
@ -200,40 +259,57 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getInputInfo
* Method: getOutputInfo
* Signature: (JJJ)[Lai/onnxruntime/NodeInfo;
*/
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getOutputInfo
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
// Setup
char *nodeInfoClassName = "ai/onnxruntime/NodeInfo";
jclass nodeInfoClazz = (*jniEnv)->FindClass(jniEnv, nodeInfoClassName);
jmethodID nodeInfoConstructor = (*jniEnv)->GetMethodID(jniEnv, nodeInfoClazz, "<init>", "(Ljava/lang/String;Lai/onnxruntime/ValueInfo;)V");
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getOutputInfo(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle) {
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtSession* session = (OrtSession*)sessionHandle;
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
// Get the number of outputs
size_t numOutputs = Java_ai_onnxruntime_OrtSession_getNumOutputs(jniEnv, jobj, apiHandle, sessionHandle);
// Setup
jclass nodeInfoClazz = (*jniEnv)->FindClass(jniEnv, ORTJNI_NodeInfoClassName);
jmethodID nodeInfoConstructor = (*jniEnv)->GetMethodID(jniEnv, nodeInfoClazz, "<init>",
"(Ljava/lang/String;Lai/onnxruntime/ValueInfo;)V");
// Allocate the return array
jobjectArray array = (*jniEnv)->NewObjectArray(jniEnv,safecast_size_t_to_jsize(numOutputs),nodeInfoClazz,NULL);
for (uint32_t i = 0; i < numOutputs; i++) {
// Read out the output name and convert it to a java.lang.String
char* outputName;
checkOrtStatus(jniEnv,api,api->SessionGetOutputName((OrtSession*)sessionHandle, i, allocator, &outputName));
jstring name = (*jniEnv)->NewStringUTF(jniEnv,outputName);
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,outputName));
// Get the number of outputs
size_t numOutputs = 0;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->SessionGetOutputCount(session, &numOutputs));
if (code != ORT_OK) {
return NULL;
}
// Create a ValueInfo from the OrtTypeInfo
OrtTypeInfo* typeInfo;
checkOrtStatus(jniEnv,api,api->SessionGetOutputTypeInfo((OrtSession*)sessionHandle, i, &typeInfo));
jobject valueInfoJava = convertToValueInfo(jniEnv,api,typeInfo);
api->ReleaseTypeInfo(typeInfo);
// Create a NodeInfo and assign into the array
jobject nodeInfo = (*jniEnv)->NewObject(jniEnv, nodeInfoClazz, nodeInfoConstructor, name, valueInfoJava);
(*jniEnv)->SetObjectArrayElement(jniEnv, array, i, nodeInfo);
// Allocate the return array
jobjectArray array = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(numOutputs), nodeInfoClazz, NULL);
for (uint32_t i = 0; i < numOutputs; i++) {
// Read out the output name and convert it to a java.lang.String
char* outputName = NULL;
code = checkOrtStatus(jniEnv, api, api->SessionGetOutputName(session, i, allocator, &outputName));
if (code != ORT_OK) {
break;
}
jstring name = (*jniEnv)->NewStringUTF(jniEnv, outputName);
code = checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, outputName));
if (code != ORT_OK) {
break;
}
return array;
// Create a ValueInfo from the OrtTypeInfo
OrtTypeInfo* typeInfo = NULL;
code = checkOrtStatus(jniEnv, api, api->SessionGetOutputTypeInfo(session, i, &typeInfo));
if (code != ORT_OK) {
break;
}
jobject valueInfoJava = convertToValueInfo(jniEnv, api, typeInfo);
api->ReleaseTypeInfo(typeInfo);
if (valueInfoJava == NULL) {
break;
}
// Create a NodeInfo and assign into the array
jobject nodeInfo = (*jniEnv)->NewObject(jniEnv, nodeInfoClazz, nodeInfoConstructor, name, valueInfoJava);
(*jniEnv)->SetObjectArrayElement(jniEnv, array, i, nodeInfo);
}
return array;
}
/*
@ -242,99 +318,139 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getOutputInfo
* Signature: (JJJ[Ljava/lang/String;[JJ[Ljava/lang/String;JJ)[Lai/onnxruntime/OnnxValue;
* private native OnnxValue[] run(long apiHandle, long nativeHandle, long allocatorHandle, String[] inputNamesArray, long[] inputs, long numInputs, String[] outputNamesArray, long numOutputs)
*/
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle, jobjectArray inputNamesArr, jlongArray tensorArr, jlong numInputs, jobjectArray outputNamesArr, jlong numOutputs, jlong runOptionsHandle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
OrtSession* session = (OrtSession*) sessionHandle;
OrtRunOptions* runOptions = (OrtRunOptions*) runOptionsHandle;
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run(JNIEnv* jniEnv, jobject jobj, jlong apiHandle,
jlong sessionHandle, jlong allocatorHandle,
jobjectArray inputNamesArr, jlongArray tensorArr,
jlong numInputs, jobjectArray outputNamesArr,
jlong numOutputs, jlong runOptionsHandle) {
// Create the buffers for the Java input and output strings
const char** inputNames;
checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator,sizeof(char*)*numInputs,(void**)&inputNames));
const char** outputNames;
checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator,sizeof(char*)*numOutputs,(void**)&outputNames));
jobject* javaInputStrings;
checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator,sizeof(jobject)*numInputs,(void**)&javaInputStrings));
jobject* javaOutputStrings;
checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator,sizeof(jobject)*numOutputs,(void**)&javaOutputStrings));
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
OrtSession* session = (OrtSession*)sessionHandle;
OrtRunOptions* runOptions = (OrtRunOptions*)runOptionsHandle;
// Extract a C array of longs which are pointers to the input tensors.
// Need to convert longs to OrtValue* in case we run on non-64bit systems
jlong* inputTensors = (*jniEnv)->GetLongArrayElements(jniEnv,tensorArr,NULL);
const OrtValue** inputValues;
checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator,sizeof(OrtValue*)*numInputs,(void**)&inputValues));
// Extract the names and native pointers of the input values.
for (int i = 0; i < numInputs; i++) {
javaInputStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv,inputNamesArr,i);
inputNames[i] = (*jniEnv)->GetStringUTFChars(jniEnv,javaInputStrings[i],NULL);
inputValues[i] = (OrtValue*)inputTensors[i];
}
// Extract the names of the output values, and allocate their output array.
OrtValue** outputValues;
checkOrtStatus(jniEnv,api,api->AllocatorAlloc(allocator,sizeof(OrtValue*)*numOutputs,(void**)&outputValues));
for (int i = 0; i < numOutputs; i++) {
javaOutputStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv,outputNamesArr,i);
outputNames[i] = (*jniEnv)->GetStringUTFChars(jniEnv,javaOutputStrings[i],NULL);
outputValues[i] = NULL;
}
// Actually score the inputs.
//printf("inputTensors = %p, first tensor = %p, numInputs = %ld, outputValues = %p, numOutputs = %ld\n",inputTensors,(OrtValue*)inputTensors[0],numInputs,outputValues,numOutputs);
//ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess, _In_ OrtRunOptions* run_options, _In_ const char* const* input_names, _In_ const OrtValue* const* input, size_t input_len, _In_ const char* const* output_names, size_t output_names_len, _Out_ OrtValue** output);
checkOrtStatus(jniEnv,api,api->Run(session, runOptions, (const char* const*) inputNames, (const OrtValue* const*) inputValues, numInputs, (const char* const*) outputNames, numOutputs, outputValues));
// Release the C array of pointers to the tensors.
(*jniEnv)->ReleaseLongArrayElements(jniEnv,tensorArr,inputTensors,JNI_ABORT);
// Construct the output array of ONNXValues
char *onnxValueClassName = "ai/onnxruntime/OnnxValue";
jclass onnxValueClass = (*jniEnv)->FindClass(jniEnv, onnxValueClassName);
jobjectArray outputArray = (*jniEnv)->NewObjectArray(jniEnv,safecast_int64_to_jsize(numOutputs), onnxValueClass, NULL);
// Convert the output tensors into ONNXValues and release the output strings.
for (int i = 0; i < numOutputs; i++) {
if (outputValues[i] != NULL) {
jobject onnxValue = convertOrtValueToONNXValue(jniEnv,api,allocator,outputValues[i]);
(*jniEnv)->SetObjectArrayElement(jniEnv,outputArray,i,onnxValue);
}
(*jniEnv)->ReleaseStringUTFChars(jniEnv,javaOutputStrings[i],outputNames[i]);
}
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,outputValues));
// Release the Java input strings
for (int i = 0; i < numInputs; i++) {
(*jniEnv)->ReleaseStringUTFChars(jniEnv,javaInputStrings[i],inputNames[i]);
}
// Release the buffers
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, (void*)inputNames));
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, (void*)inputValues));
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, (void*)outputNames));
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, javaInputStrings));
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, javaOutputStrings));
jobjectArray outputArray = NULL;
// Create the buffers for the Java input & output strings, and the input pointers
const char** inputNames = malloc(sizeof(char*) * numInputs);
if (inputNames == NULL) {
// Nothing to cleanup, return and throw exception
return outputArray;
}
}
const char** outputNames = malloc(sizeof(char*) * numOutputs);
if (outputNames == NULL) {
goto cleanup_input_names;
}
jobject* javaInputStrings = malloc(sizeof(jobject) * numInputs);
if (javaInputStrings == NULL) {
goto cleanup_output_names;
}
jobject* javaOutputStrings = malloc(sizeof(jobject) * numOutputs);
if (javaOutputStrings == NULL) {
goto cleanup_java_input_strings;
}
const OrtValue** inputValuePtrs = malloc(sizeof(OrtValue*) * numInputs);
if (inputValuePtrs == NULL) {
goto cleanup_java_output_strings;
}
OrtValue** outputValues = malloc(sizeof(OrtValue*) * numOutputs);
if (outputValues == NULL) {
goto cleanup_input_values;
}
// Extract a C array of longs which are pointers to the input tensors.
// The Java-side objects store native pointers as 64-bit longs, and on 32-bit systems
// we cannot cast the long array to a pointer array as they are different sizes,
// so we copy the longs applying the appropriate cast.
jlong* inputValueLongs = (*jniEnv)->GetLongArrayElements(jniEnv, tensorArr, NULL);
// Extract the names and native pointers of the input values.
for (int i = 0; i < numInputs; i++) {
javaInputStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv, inputNamesArr, i);
inputNames[i] = (*jniEnv)->GetStringUTFChars(jniEnv, javaInputStrings[i], NULL);
inputValuePtrs[i] = (OrtValue*)inputValueLongs[i];
}
// Release the java array copy of pointers to the tensors.
(*jniEnv)->ReleaseLongArrayElements(jniEnv, tensorArr, inputValueLongs, JNI_ABORT);
// Extract the names of the output values.
for (int i = 0; i < numOutputs; i++) {
javaOutputStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv, outputNamesArr, i);
outputNames[i] = (*jniEnv)->GetStringUTFChars(jniEnv, javaOutputStrings[i], NULL);
outputValues[i] = NULL;
}
// Actually score the inputs.
// ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess, _In_ OrtRunOptions* run_options,
// _In_ const char* const* input_names, _In_ const OrtValue* const* input, size_t input_len,
// _In_ const char* const* output_names, size_t output_names_len, _Out_ OrtValue** output);
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->Run(session, runOptions, (const char* const*)inputNames,
(const OrtValue* const*)inputValuePtrs, numInputs,
(const char* const*)outputNames, numOutputs, outputValues));
if (code != ORT_OK) {
goto cleanup_output_values;
}
// Construct the output array of ONNXValues
jclass onnxValueClass = (*jniEnv)->FindClass(jniEnv, ORTJNI_OnnxValueClassName);
outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_int64_to_jsize(numOutputs), onnxValueClass, NULL);
// Convert the output tensors into ONNXValues
for (int i = 0; i < numOutputs; i++) {
if (outputValues[i] != NULL) {
jobject onnxValue = convertOrtValueToONNXValue(jniEnv, api, allocator, outputValues[i]);
if (onnxValue == NULL) {
break; // go to cleanup, exception thrown
}
(*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, i, onnxValue);
}
}
// Note these gotos are in a specific order so they mirror the allocation pattern above.
// They must be changed if the allocation code is rearranged.
cleanup_output_values:
free(outputValues);
// Release the Java output strings
for (int i = 0; i < numOutputs; i++) {
(*jniEnv)->ReleaseStringUTFChars(jniEnv, javaOutputStrings[i], outputNames[i]);
}
// Release the Java input strings
for (int i = 0; i < numInputs; i++) {
(*jniEnv)->ReleaseStringUTFChars(jniEnv, javaInputStrings[i], inputNames[i]);
}
// Release the buffers
cleanup_input_values:
free((void*)inputValuePtrs);
cleanup_java_output_strings:
free(javaOutputStrings);
cleanup_java_input_strings:
free(javaInputStrings);
cleanup_output_names:
free((void*)outputNames);
cleanup_input_names:
free((void*)inputNames);
return outputArray;
}
/*
* Class: ai_onnxruntime_OrtSession
* Method: getProfilingStartTimeInNs
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_getProfilingStartTimeInNs
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtSession* session = (OrtSession*) sessionHandle;
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_getProfilingStartTimeInNs(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle) {
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtSession* session = (OrtSession*)sessionHandle;
uint64_t timestamp = 0;
checkOrtStatus(jniEnv,api,api->SessionGetProfilingStartTimeNs(session,&timestamp));
return (jlong) timestamp;
checkOrtStatus(jniEnv, api, api->SessionGetProfilingStartTimeNs(session, &timestamp));
return (jlong)timestamp;
}
/*
@ -342,16 +458,19 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_getProfilingStartTimeInNs
* Method: endProfiling
* Signature: (JJJ)Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_endProfiling
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_endProfiling(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle) {
(void)jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
char* profileStr;
checkOrtStatus(jniEnv,api,api->SessionEndProfiling((OrtSession*)handle,allocator,&profileStr));
jstring profileOutput = (*jniEnv)->NewStringUTF(jniEnv,profileStr);
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,profileStr));
char* profileStr = NULL;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->SessionEndProfiling((OrtSession*)handle, allocator,
&profileStr));
if (code != ORT_OK) {
return NULL;
}
jstring profileOutput = (*jniEnv)->NewStringUTF(jniEnv, profileStr);
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, profileStr));
return profileOutput;
}
@ -360,11 +479,10 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_endProfiling
* Method: closeSession
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_closeSession
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(void) jniEnv; (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
api->ReleaseSession((OrtSession*)handle);
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_closeSession(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(void)jniEnv; (void)jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
api->ReleaseSession((OrtSession*)handle);
}
/*
@ -372,91 +490,139 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_closeSession
* Method: constructMetadata
* Signature: (JJJ)Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_constructMetadata
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle, jlong allocatorHandle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_constructMetadata(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle, jlong allocatorHandle) {
(void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtAllocator* allocator = (OrtAllocator*)allocatorHandle;
jobject metadataJava = NULL;
jstring producerStr = NULL;
jstring graphStr = NULL;
jstring graphDescStr = NULL;
jstring domainStr = NULL;
jstring descriptionStr = NULL;
// macro for processing char* into a Java UTF-8 string with error handling inside this function
#define STR_PROCESS(STR_NAME) \
if (code == ORT_OK) { \
STR_NAME = (*jniEnv)->NewStringUTF(jniEnv, charBuffer); \
code = checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, charBuffer)); \
if (code != ORT_OK) { \
goto release_metadata; \
} \
} else { \
goto release_metadata; \
}
// Setup
char* stringClassName = "java/lang/String";
jclass stringClazz = (*jniEnv)->FindClass(jniEnv, stringClassName);
char *metadataClassName = "ai/onnxruntime/OnnxModelMetadata";
jclass metadataClazz = (*jniEnv)->FindClass(jniEnv, metadataClassName);
//OnnxModelMetadata(String producerName, String graphName, String domain, String description, long version, String[] customMetadataArray)
jmethodID metadataConstructor = (*jniEnv)->GetMethodID(jniEnv, metadataClazz, "<init>",
"(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;J[Ljava/lang/String;)V");
jclass stringClazz = (*jniEnv)->FindClass(jniEnv, ORTJNI_StringClassName);
jclass metadataClazz = (*jniEnv)->FindClass(jniEnv, ORTJNI_MetadataClassName);
// OnnxModelMetadata(String producerName, String graphName, String domain, String description,
// long version, String[] customMetadataArray)
jmethodID metadataConstructor = (*jniEnv)->GetMethodID(
jniEnv, metadataClazz, "<init>",
"(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;J[Ljava/lang/String;)V");
// Get metadata
OrtModelMetadata* metadata;
checkOrtStatus(jniEnv,api,api->SessionGetModelMetadata((OrtSession*)nativeHandle,&metadata));
OrtModelMetadata* metadata = NULL;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->SessionGetModelMetadata((OrtSession*)nativeHandle, &metadata));
if (code != ORT_OK) {
// Nothing to cleanup, return null as an exception has been thrown
return NULL;
}
// Read out the producer name and convert it to a java.lang.String
char* charBuffer;
checkOrtStatus(jniEnv,api,api->ModelMetadataGetProducerName(metadata, allocator, &charBuffer));
jstring producerStr = (*jniEnv)->NewStringUTF(jniEnv,charBuffer);
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,charBuffer));
char* charBuffer = NULL;
code = checkOrtStatus(jniEnv, api, api->ModelMetadataGetProducerName(metadata, allocator, &charBuffer));
STR_PROCESS(producerStr)
// Read out the graph name and convert it to a java.lang.String
checkOrtStatus(jniEnv,api,api->ModelMetadataGetGraphName(metadata, allocator, &charBuffer));
jstring graphStr = (*jniEnv)->NewStringUTF(jniEnv,charBuffer);
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,charBuffer));
code = checkOrtStatus(jniEnv, api, api->ModelMetadataGetGraphName(metadata, allocator, &charBuffer));
STR_PROCESS(graphStr)
// Read out the graph description and convert it to a java.lang.String
checkOrtStatus(jniEnv,api,api->ModelMetadataGetGraphDescription(metadata, allocator, &charBuffer));
jstring graphDescStr = (*jniEnv)->NewStringUTF(jniEnv,charBuffer);
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,charBuffer));
code = checkOrtStatus(jniEnv, api, api->ModelMetadataGetGraphDescription(metadata, allocator, &charBuffer));
STR_PROCESS(graphDescStr)
// Read out the domain and convert it to a java.lang.String
checkOrtStatus(jniEnv,api,api->ModelMetadataGetDomain(metadata, allocator, &charBuffer));
jstring domainStr = (*jniEnv)->NewStringUTF(jniEnv,charBuffer);
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,charBuffer));
code = checkOrtStatus(jniEnv, api, api->ModelMetadataGetDomain(metadata, allocator, &charBuffer));
STR_PROCESS(domainStr)
// Read out the description and convert it to a java.lang.String
checkOrtStatus(jniEnv,api,api->ModelMetadataGetDescription(metadata, allocator, &charBuffer));
jstring descriptionStr = (*jniEnv)->NewStringUTF(jniEnv,charBuffer);
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,charBuffer));
code = checkOrtStatus(jniEnv, api, api->ModelMetadataGetDescription(metadata, allocator, &charBuffer));
STR_PROCESS(descriptionStr)
// Read out the version
int64_t version;
checkOrtStatus(jniEnv,api,api->ModelMetadataGetVersion(metadata, &version));
int64_t version = 0;
code = checkOrtStatus(jniEnv, api, api->ModelMetadataGetVersion(metadata, &version));
if (code != ORT_OK) {
goto release_metadata;
}
// Read out the keys, look up the values.
int64_t numKeys;
char** keys;
checkOrtStatus(jniEnv,api,api->ModelMetadataGetCustomMetadataMapKeys(metadata, allocator, &keys, &numKeys));
int64_t numKeys = 0;
char** keys = NULL;
code = checkOrtStatus(jniEnv, api, api->ModelMetadataGetCustomMetadataMapKeys(metadata, allocator, &keys, &numKeys));
if (code != ORT_OK) {
goto release_metadata;
}
jobjectArray customArray = NULL;
if (numKeys > 0) {
customArray = (*jniEnv)->NewObjectArray(jniEnv,safecast_int64_to_jsize(numKeys * 2),stringClazz, NULL);
customArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_int64_to_jsize(numKeys * 2), stringClazz, NULL);
// Iterate key array to extract the values
for (int64_t i = 0; i < numKeys; i++) {
// Create a java.lang.String for the key
jstring keyJava = (*jniEnv)->NewStringUTF(jniEnv,keys[i]);
jstring keyJava = (*jniEnv)->NewStringUTF(jniEnv, keys[i]);
// Extract the value and convert it to a java.lang.String
checkOrtStatus(jniEnv,api,api->ModelMetadataLookupCustomMetadataMap(metadata,allocator,keys[i],&charBuffer));
jstring valueJava = (*jniEnv)->NewStringUTF(jniEnv,charBuffer);
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,charBuffer));
code = checkOrtStatus(jniEnv, api, api->ModelMetadataLookupCustomMetadataMap(metadata, allocator, keys[i], &charBuffer));
jstring valueJava = NULL;
if (code == ORT_OK) {
valueJava = (*jniEnv)->NewStringUTF(jniEnv, charBuffer);
code = checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, charBuffer));
if (code != ORT_OK) {
// Signal that custom metadata extraction failed and break out
customArray = NULL;
break;
}
} else {
// Signal that custom metadata extraction failed and break out
customArray = NULL;
break;
}
// Write the key and value into the array
(*jniEnv)->SetObjectArrayElement(jniEnv,customArray,safecast_int64_to_jsize(i*2),keyJava);
(*jniEnv)->SetObjectArrayElement(jniEnv,customArray,safecast_int64_to_jsize((i * 2) + 1),valueJava);
(*jniEnv)->SetObjectArrayElement(jniEnv, customArray, safecast_int64_to_jsize(i * 2), keyJava);
(*jniEnv)->SetObjectArrayElement(jniEnv, customArray, safecast_int64_to_jsize((i * 2) + 1), valueJava);
}
// Release key array
for (int64_t i = 0; i < numKeys; i++) {
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,keys[i]));
code = checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, keys[i]));
if (code != ORT_OK) {
customArray = NULL;
break;
}
}
code = checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, keys));
if (code != ORT_OK) {
customArray = NULL;
}
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,keys));
} else {
customArray = (*jniEnv)->NewObjectArray(jniEnv,0,stringClazz,NULL);
customArray = (*jniEnv)->NewObjectArray(jniEnv, 0, stringClazz, NULL);
}
// Invoke the metadata constructor
//OnnxModelMetadata(String producerName, String graphName, String graphDescription, String domain, String description, long version, String[] customMetadataArray)
jobject metadataJava = (*jniEnv)->NewObject(jniEnv, metadataClazz, metadataConstructor, producerStr, graphStr, graphDescStr, domainStr, descriptionStr, (jlong) version, customArray);
if (customArray != NULL) {
// If the array is non-null then the custom metadata extraction completed successfully so
// we invoke the metadata constructor
// OnnxModelMetadata(String producerName, String graphName, String graphDescription, String domain,
// String description, long version, String[] customMetadataArray)
metadataJava = (*jniEnv)->NewObject(jniEnv, metadataClazz, metadataConstructor,
producerStr, graphStr, graphDescStr, domainStr, descriptionStr, (jlong)version,
customArray);
}
release_metadata:
// Release the metadata
api->ReleaseModelMetadata(metadata);