From 733db31420568fd47468eb9c7bd879dddb63f92c Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 16 Aug 2022 16:43:06 -0400 Subject: [PATCH] [Java] JNI refactor for OrtSession (#12496) Refactor JNI error reporting --- .../main/native/ai_onnxruntime_OrtSession.c | 742 +++++++++++------- 1 file changed, 454 insertions(+), 288 deletions(-) diff --git a/java/src/main/native/ai_onnxruntime_OrtSession.c b/java/src/main/native/ai_onnxruntime_OrtSession.c index de24f10f56..07354ffd1e 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession.c @@ -4,40 +4,47 @@ */ #include #include +#include #include "onnxruntime/core/session/onnxruntime_c_api.h" #include "OrtJniUtil.h" #include "ai_onnxruntime_OrtSession.h" +const char * const ORTJNI_StringClassName = "java/lang/String"; +const char * const ORTJNI_OnnxValueClassName = "ai/onnxruntime/OnnxValue"; +const char * const ORTJNI_NodeInfoClassName = "ai/onnxruntime/NodeInfo"; +const char * const ORTJNI_MetadataClassName = "ai/onnxruntime/OnnxModelMetadata"; + /* * Class: ai_onnxruntime_OrtSession * Method: createSession * Signature: (JJLjava/lang/String;J)J */ -JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_lang_String_2J - (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jstring modelPath, jlong optsHandle) { - (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtSession* session; +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_lang_String_2J(JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jstring modelPath, jlong optsHandle) { + (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtSession* session = NULL; #ifdef _WIN32 - const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, modelPath, NULL); - size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, modelPath); - wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t)); - if(newString == NULL) { - throwOrtException(jniEnv, 1, "Not enough memory"); - return 0; - } - wcsncpy_s(newString, stringLength+1, (const wchar_t*) cPath, stringLength); - checkOrtStatus(jniEnv,api,api->CreateSession((OrtEnv*)envHandle, newString, (OrtSessionOptions*)optsHandle, &session)); - free(newString); - (*jniEnv)->ReleaseStringChars(jniEnv,modelPath,cPath); + const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, modelPath, NULL); + size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, modelPath); + wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t)); + if (newString == NULL) { + (*jniEnv)->ReleaseStringChars(jniEnv, modelPath, cPath); + throwOrtException(jniEnv, 1, "Not enough memory"); + return 0; + } + wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength); + checkOrtStatus(jniEnv, api, + api->CreateSession((OrtEnv*)envHandle, newString, (OrtSessionOptions*)optsHandle, &session)); + free(newString); + (*jniEnv)->ReleaseStringChars(jniEnv, modelPath, cPath); #else - const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, modelPath, NULL); - checkOrtStatus(jniEnv,api,api->CreateSession((OrtEnv*)envHandle, cPath, (OrtSessionOptions*)optsHandle, &session)); - (*jniEnv)->ReleaseStringUTFChars(jniEnv,modelPath,cPath); + const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, modelPath, NULL); + checkOrtStatus(jniEnv, api, api->CreateSession((OrtEnv*)envHandle, cPath, (OrtSessionOptions*)optsHandle, &session)); + (*jniEnv)->ReleaseStringUTFChars(jniEnv, modelPath, cPath); #endif - return (jlong) session; + return (jlong)session; } /* @@ -45,34 +52,39 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_la * Method: createSession * Signature: (JJ[BJ)J */ -JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJ_3BJ - (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jbyteArray jModelArray, jlong optsHandle) { - (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtSession* session; +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJ_3BJ(JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jbyteArray jModelArray, jlong optsHandle) { + (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtEnv* env = (OrtEnv*)envHandle; + OrtSessionOptions* opts = (OrtSessionOptions*)optsHandle; + OrtSession* session = NULL; - // Get a reference to the byte array elements - jbyte* modelArr = (*jniEnv)->GetByteArrayElements(jniEnv,jModelArray,NULL); - size_t modelLength = (*jniEnv)->GetArrayLength(jniEnv,jModelArray); - checkOrtStatus(jniEnv,api,api->CreateSessionFromArray((OrtEnv*)envHandle, modelArr, modelLength, (OrtSessionOptions*)optsHandle, &session)); - // Release the C array. - (*jniEnv)->ReleaseByteArrayElements(jniEnv,jModelArray,modelArr,JNI_ABORT); - - return (jlong) session; + size_t modelLength = (*jniEnv)->GetArrayLength(jniEnv, jModelArray); + if (modelLength == 0) { + throwOrtException(jniEnv, 2, "Invalid ONNX model, the byte array is zero length."); + return 0; } + // Get a reference to the byte array elements + jbyte* modelArr = (*jniEnv)->GetByteArrayElements(jniEnv, jModelArray, NULL); + checkOrtStatus(jniEnv, api, api->CreateSessionFromArray(env, modelArr, modelLength, opts, &session)); + // Release the C array. + (*jniEnv)->ReleaseByteArrayElements(jniEnv, jModelArray, modelArr, JNI_ABORT); + + return (jlong)session; +} + /* * Class: ai_onnxruntime_OrtSession * Method: getNumInputs * Signature: (JJ)J */ -JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_getNumInputs - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - size_t numInputs; - checkOrtStatus(jniEnv,api,api->SessionGetInputCount((OrtSession*)handle, &numInputs)); - return numInputs; +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_getNumInputs(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong handle) { + (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + size_t numInputs = 0; + checkOrtStatus(jniEnv, api, api->SessionGetInputCount((OrtSession*)handle, &numInputs)); + return numInputs; } /* @@ -80,31 +92,47 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_getNumInputs * Method: getInputNames * Signature: (JJJ)[Ljava/lang/String; */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getInputNames - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; +JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getInputNames(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle) { + (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtAllocator* allocator = (OrtAllocator*)allocatorHandle; + OrtSession* session = (OrtSession*)sessionHandle; - // Setup - char *stringClassName = "java/lang/String"; - jclass stringClazz = (*jniEnv)->FindClass(jniEnv, stringClassName); + // Setup + jclass stringClazz = (*jniEnv)->FindClass(jniEnv, ORTJNI_StringClassName); - // Get the number of inputs - size_t numInputs = Java_ai_onnxruntime_OrtSession_getNumInputs(jniEnv, jobj, apiHandle, sessionHandle); + // Get the number of inputs + size_t numInputs = 0; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->SessionGetInputCount(session, &numInputs)); + if (code != ORT_OK) { + return NULL; + } - // Allocate the return array - jobjectArray array = (*jniEnv)->NewObjectArray(jniEnv,safecast_size_t_to_jsize(numInputs),stringClazz,NULL); - for (uint32_t i = 0; i < numInputs; i++) { - // Read out the input name and convert it to a java.lang.String - char* inputName; - checkOrtStatus(jniEnv,api,api->SessionGetInputName((OrtSession*)sessionHandle, i, allocator, &inputName)); - jstring name = (*jniEnv)->NewStringUTF(jniEnv,inputName); - (*jniEnv)->SetObjectArrayElement(jniEnv, array, i, name); - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,inputName)); + int32_t numInputsInt = (int32_t) numInputs; + if (numInputs != (size_t) numInputsInt) { + throwOrtException(jniEnv, 1, "Too many inputs, expected less than 2^31"); + } + + // Allocate the return array + jobjectArray array = (*jniEnv)->NewObjectArray(jniEnv, numInputsInt, stringClazz, NULL); + for (int32_t i = 0; i < numInputsInt; i++) { + // Read out the input name and convert it to a java.lang.String + char* inputName = NULL; + code = checkOrtStatus(jniEnv, api, api->SessionGetInputName(session, i, allocator, &inputName)); + if (code != ORT_OK) { + // break out on error, return array and let Java throw the exception. + break; } + jstring name = (*jniEnv)->NewStringUTF(jniEnv, inputName); + (*jniEnv)->SetObjectArrayElement(jniEnv, array, i, name); + code = checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, inputName)); + if (code != ORT_OK) { + // break out on error, return array and let Java throw the exception. + break; + } + } - return array; + return array; } /* @@ -112,13 +140,12 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getInputNames * Method: getNumOutputs * Signature: (JJ)J */ -JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_getNumOutputs - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - size_t numOutputs; - checkOrtStatus(jniEnv,api,api->SessionGetOutputCount((OrtSession*)handle, &numOutputs)); - return numOutputs; +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_getNumOutputs(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong handle) { + (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + size_t numOutputs = 0; + checkOrtStatus(jniEnv, api, api->SessionGetOutputCount((OrtSession*)handle, &numOutputs)); + return numOutputs; } /* @@ -126,31 +153,47 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_getNumOutputs * Method: getOutputNames * Signature: (JJJ)[Ljava/lang/String; */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getOutputNames - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; +JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getOutputNames(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle) { + (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtSession* session = (OrtSession*)sessionHandle; + OrtAllocator* allocator = (OrtAllocator*)allocatorHandle; - // Setup - char *stringClassName = "java/lang/String"; - jclass stringClazz = (*jniEnv)->FindClass(jniEnv, stringClassName); + // Setup + jclass stringClazz = (*jniEnv)->FindClass(jniEnv, ORTJNI_StringClassName); - // Get the number of outputs - size_t numOutputs = Java_ai_onnxruntime_OrtSession_getNumOutputs(jniEnv, jobj, apiHandle, sessionHandle); + // Get the number of outputs + size_t numOutputs = 0; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->SessionGetOutputCount(session, &numOutputs)); + if (code != ORT_OK) { + return NULL; + } - // Allocate the return array - jobjectArray array = (*jniEnv)->NewObjectArray(jniEnv,safecast_size_t_to_jsize(numOutputs),stringClazz, NULL); - for (uint32_t i = 0; i < numOutputs; i++) { - // Read out the output name and convert it to a java.lang.String - char* outputName; - checkOrtStatus(jniEnv,api,api->SessionGetOutputName((OrtSession*)sessionHandle, i, allocator, &outputName)); - jstring name = (*jniEnv)->NewStringUTF(jniEnv,outputName); - (*jniEnv)->SetObjectArrayElement(jniEnv, array, i, name); - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,outputName)); + int32_t numOutputsInt = (int32_t) numOutputs; + if (numOutputs != (size_t) numOutputsInt) { + throwOrtException(jniEnv, 1, "Too many outputs, expected less than 2^31"); + } + + // Allocate the return array + jobjectArray array = (*jniEnv)->NewObjectArray(jniEnv, numOutputsInt, stringClazz, NULL); + for (int32_t i = 0; i < numOutputsInt; i++) { + // Read out the output name and convert it to a java.lang.String + char* outputName = NULL; + code = checkOrtStatus(jniEnv, api, api->SessionGetOutputName(session, i, allocator, &outputName)); + if (code != ORT_OK) { + // break out on error, return array and let Java throw the exception. + break; } + jstring name = (*jniEnv)->NewStringUTF(jniEnv, outputName); + (*jniEnv)->SetObjectArrayElement(jniEnv, array, i, name); + code = checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, outputName)); + if (code != ORT_OK) { + // break out on error, return array and let Java throw the exception. + break; + } + } - return array; + return array; } /* @@ -158,41 +201,57 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getOutputNames * Method: getInputInfo * Signature: (JJJ)[Lai/onnxruntime/NodeInfo; */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getInputInfo - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; +JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getInputInfo(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle) { + (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtSession* session = (OrtSession*)sessionHandle; + OrtAllocator* allocator = (OrtAllocator*)allocatorHandle; - // Setup - char *nodeInfoClassName = "ai/onnxruntime/NodeInfo"; - jclass nodeInfoClazz = (*jniEnv)->FindClass(jniEnv, nodeInfoClassName); - jmethodID nodeInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,nodeInfoClazz, "", "(Ljava/lang/String;Lai/onnxruntime/ValueInfo;)V"); + // Setup + jclass nodeInfoClazz = (*jniEnv)->FindClass(jniEnv, ORTJNI_NodeInfoClassName); + jmethodID nodeInfoConstructor = (*jniEnv)->GetMethodID(jniEnv, nodeInfoClazz, "", + "(Ljava/lang/String;Lai/onnxruntime/ValueInfo;)V"); - // Get the number of inputs - size_t numInputs = Java_ai_onnxruntime_OrtSession_getNumInputs(jniEnv, jobj, apiHandle, sessionHandle); + // Get the number of inputs + size_t numInputs = 0; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->SessionGetInputCount(session, &numInputs)); + if (code != ORT_OK) { + return NULL; + } - // Allocate the return array - jobjectArray array = (*jniEnv)->NewObjectArray(jniEnv,safecast_size_t_to_jsize(numInputs),nodeInfoClazz, NULL); - for (size_t i = 0; i < numInputs; i++) { - // Read out the input name and convert it to a java.lang.String - char* inputName; - checkOrtStatus(jniEnv,api,api->SessionGetInputName((OrtSession*)sessionHandle, i, allocator, &inputName)); - jstring name = (*jniEnv)->NewStringUTF(jniEnv,inputName); - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,inputName)); - - // Create a ValueInfo from the OrtTypeInfo - OrtTypeInfo* typeInfo; - checkOrtStatus(jniEnv,api,api->SessionGetInputTypeInfo((OrtSession*)sessionHandle, i, &typeInfo)); - jobject valueInfoJava = convertToValueInfo(jniEnv,api,typeInfo); - api->ReleaseTypeInfo(typeInfo); - - // Create a NodeInfo and assign into the array - jobject nodeInfo = (*jniEnv)->NewObject(jniEnv, nodeInfoClazz, nodeInfoConstructor, name, valueInfoJava); - (*jniEnv)->SetObjectArrayElement(jniEnv, array,safecast_size_t_to_jsize(i),nodeInfo); + // Allocate the return array + jobjectArray array = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(numInputs), nodeInfoClazz, NULL); + for (size_t i = 0; i < numInputs; i++) { + // Read out the input name and convert it to a java.lang.String + char* inputName = NULL; + code = checkOrtStatus(jniEnv, api, api->SessionGetInputName(session, i, allocator, &inputName)); + if (code != ORT_OK) { + break; + } + jstring name = (*jniEnv)->NewStringUTF(jniEnv, inputName); + code = checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, inputName)); + if (code != ORT_OK) { + break; } - return array; + // Create a ValueInfo from the OrtTypeInfo + OrtTypeInfo* typeInfo = NULL; + code = checkOrtStatus(jniEnv, api, api->SessionGetInputTypeInfo(session, i, &typeInfo)); + if (code != ORT_OK) { + break; + } + jobject valueInfoJava = convertToValueInfo(jniEnv, api, typeInfo); + api->ReleaseTypeInfo(typeInfo); + if (valueInfoJava == NULL) { + break; + } + + // Create a NodeInfo and assign into the array + jobject nodeInfo = (*jniEnv)->NewObject(jniEnv, nodeInfoClazz, nodeInfoConstructor, name, valueInfoJava); + (*jniEnv)->SetObjectArrayElement(jniEnv, array, safecast_size_t_to_jsize(i), nodeInfo); + } + + return array; } /* @@ -200,40 +259,57 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getInputInfo * Method: getOutputInfo * Signature: (JJJ)[Lai/onnxruntime/NodeInfo; */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getOutputInfo - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; - // Setup - char *nodeInfoClassName = "ai/onnxruntime/NodeInfo"; - jclass nodeInfoClazz = (*jniEnv)->FindClass(jniEnv, nodeInfoClassName); - jmethodID nodeInfoConstructor = (*jniEnv)->GetMethodID(jniEnv, nodeInfoClazz, "", "(Ljava/lang/String;Lai/onnxruntime/ValueInfo;)V"); +JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getOutputInfo(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle) { + (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtSession* session = (OrtSession*)sessionHandle; + OrtAllocator* allocator = (OrtAllocator*)allocatorHandle; - // Get the number of outputs - size_t numOutputs = Java_ai_onnxruntime_OrtSession_getNumOutputs(jniEnv, jobj, apiHandle, sessionHandle); + // Setup + jclass nodeInfoClazz = (*jniEnv)->FindClass(jniEnv, ORTJNI_NodeInfoClassName); + jmethodID nodeInfoConstructor = (*jniEnv)->GetMethodID(jniEnv, nodeInfoClazz, "", + "(Ljava/lang/String;Lai/onnxruntime/ValueInfo;)V"); - // Allocate the return array - jobjectArray array = (*jniEnv)->NewObjectArray(jniEnv,safecast_size_t_to_jsize(numOutputs),nodeInfoClazz,NULL); - for (uint32_t i = 0; i < numOutputs; i++) { - // Read out the output name and convert it to a java.lang.String - char* outputName; - checkOrtStatus(jniEnv,api,api->SessionGetOutputName((OrtSession*)sessionHandle, i, allocator, &outputName)); - jstring name = (*jniEnv)->NewStringUTF(jniEnv,outputName); - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,outputName)); + // Get the number of outputs + size_t numOutputs = 0; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->SessionGetOutputCount(session, &numOutputs)); + if (code != ORT_OK) { + return NULL; + } - // Create a ValueInfo from the OrtTypeInfo - OrtTypeInfo* typeInfo; - checkOrtStatus(jniEnv,api,api->SessionGetOutputTypeInfo((OrtSession*)sessionHandle, i, &typeInfo)); - jobject valueInfoJava = convertToValueInfo(jniEnv,api,typeInfo); - api->ReleaseTypeInfo(typeInfo); - - // Create a NodeInfo and assign into the array - jobject nodeInfo = (*jniEnv)->NewObject(jniEnv, nodeInfoClazz, nodeInfoConstructor, name, valueInfoJava); - (*jniEnv)->SetObjectArrayElement(jniEnv, array, i, nodeInfo); + // Allocate the return array + jobjectArray array = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(numOutputs), nodeInfoClazz, NULL); + for (uint32_t i = 0; i < numOutputs; i++) { + // Read out the output name and convert it to a java.lang.String + char* outputName = NULL; + code = checkOrtStatus(jniEnv, api, api->SessionGetOutputName(session, i, allocator, &outputName)); + if (code != ORT_OK) { + break; + } + jstring name = (*jniEnv)->NewStringUTF(jniEnv, outputName); + code = checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, outputName)); + if (code != ORT_OK) { + break; } - return array; + // Create a ValueInfo from the OrtTypeInfo + OrtTypeInfo* typeInfo = NULL; + code = checkOrtStatus(jniEnv, api, api->SessionGetOutputTypeInfo(session, i, &typeInfo)); + if (code != ORT_OK) { + break; + } + jobject valueInfoJava = convertToValueInfo(jniEnv, api, typeInfo); + api->ReleaseTypeInfo(typeInfo); + if (valueInfoJava == NULL) { + break; + } + + // Create a NodeInfo and assign into the array + jobject nodeInfo = (*jniEnv)->NewObject(jniEnv, nodeInfoClazz, nodeInfoConstructor, name, valueInfoJava); + (*jniEnv)->SetObjectArrayElement(jniEnv, array, i, nodeInfo); + } + + return array; } /* @@ -242,99 +318,139 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getOutputInfo * Signature: (JJJ[Ljava/lang/String;[JJ[Ljava/lang/String;JJ)[Lai/onnxruntime/OnnxValue; * private native OnnxValue[] run(long apiHandle, long nativeHandle, long allocatorHandle, String[] inputNamesArray, long[] inputs, long numInputs, String[] outputNamesArray, long numOutputs) */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle, jobjectArray inputNamesArr, jlongArray tensorArr, jlong numInputs, jobjectArray outputNamesArr, jlong numOutputs, jlong runOptionsHandle) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; - OrtSession* session = (OrtSession*) sessionHandle; - OrtRunOptions* runOptions = (OrtRunOptions*) runOptionsHandle; +JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, + jlong sessionHandle, jlong allocatorHandle, + jobjectArray inputNamesArr, jlongArray tensorArr, + jlong numInputs, jobjectArray outputNamesArr, + jlong numOutputs, jlong runOptionsHandle) { - // Create the buffers for the Java input and output strings - const char** inputNames; - checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator,sizeof(char*)*numInputs,(void**)&inputNames)); - const char** outputNames; - checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator,sizeof(char*)*numOutputs,(void**)&outputNames)); - jobject* javaInputStrings; - checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator,sizeof(jobject)*numInputs,(void**)&javaInputStrings)); - jobject* javaOutputStrings; - checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator,sizeof(jobject)*numOutputs,(void**)&javaOutputStrings)); + (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtAllocator* allocator = (OrtAllocator*)allocatorHandle; + OrtSession* session = (OrtSession*)sessionHandle; + OrtRunOptions* runOptions = (OrtRunOptions*)runOptionsHandle; - // Extract a C array of longs which are pointers to the input tensors. - // Need to convert longs to OrtValue* in case we run on non-64bit systems - jlong* inputTensors = (*jniEnv)->GetLongArrayElements(jniEnv,tensorArr,NULL); - const OrtValue** inputValues; - checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator,sizeof(OrtValue*)*numInputs,(void**)&inputValues)); - - // Extract the names and native pointers of the input values. - for (int i = 0; i < numInputs; i++) { - javaInputStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv,inputNamesArr,i); - inputNames[i] = (*jniEnv)->GetStringUTFChars(jniEnv,javaInputStrings[i],NULL); - inputValues[i] = (OrtValue*)inputTensors[i]; - } - - // Extract the names of the output values, and allocate their output array. - OrtValue** outputValues; - checkOrtStatus(jniEnv,api,api->AllocatorAlloc(allocator,sizeof(OrtValue*)*numOutputs,(void**)&outputValues)); - for (int i = 0; i < numOutputs; i++) { - javaOutputStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv,outputNamesArr,i); - outputNames[i] = (*jniEnv)->GetStringUTFChars(jniEnv,javaOutputStrings[i],NULL); - outputValues[i] = NULL; - } - - // Actually score the inputs. - //printf("inputTensors = %p, first tensor = %p, numInputs = %ld, outputValues = %p, numOutputs = %ld\n",inputTensors,(OrtValue*)inputTensors[0],numInputs,outputValues,numOutputs); - //ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess, _In_ OrtRunOptions* run_options, _In_ const char* const* input_names, _In_ const OrtValue* const* input, size_t input_len, _In_ const char* const* output_names, size_t output_names_len, _Out_ OrtValue** output); - checkOrtStatus(jniEnv,api,api->Run(session, runOptions, (const char* const*) inputNames, (const OrtValue* const*) inputValues, numInputs, (const char* const*) outputNames, numOutputs, outputValues)); - // Release the C array of pointers to the tensors. - (*jniEnv)->ReleaseLongArrayElements(jniEnv,tensorArr,inputTensors,JNI_ABORT); - - // Construct the output array of ONNXValues - char *onnxValueClassName = "ai/onnxruntime/OnnxValue"; - jclass onnxValueClass = (*jniEnv)->FindClass(jniEnv, onnxValueClassName); - jobjectArray outputArray = (*jniEnv)->NewObjectArray(jniEnv,safecast_int64_to_jsize(numOutputs), onnxValueClass, NULL); - - // Convert the output tensors into ONNXValues and release the output strings. - for (int i = 0; i < numOutputs; i++) { - if (outputValues[i] != NULL) { - jobject onnxValue = convertOrtValueToONNXValue(jniEnv,api,allocator,outputValues[i]); - (*jniEnv)->SetObjectArrayElement(jniEnv,outputArray,i,onnxValue); - } - (*jniEnv)->ReleaseStringUTFChars(jniEnv,javaOutputStrings[i],outputNames[i]); - } - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,outputValues)); - - // Release the Java input strings - for (int i = 0; i < numInputs; i++) { - (*jniEnv)->ReleaseStringUTFChars(jniEnv,javaInputStrings[i],inputNames[i]); - } - - // Release the buffers - checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, (void*)inputNames)); - checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, (void*)inputValues)); - checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, (void*)outputNames)); - checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, javaInputStrings)); - checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, javaOutputStrings)); + jobjectArray outputArray = NULL; + // Create the buffers for the Java input & output strings, and the input pointers + const char** inputNames = malloc(sizeof(char*) * numInputs); + if (inputNames == NULL) { + // Nothing to cleanup, return and throw exception return outputArray; -} + } + const char** outputNames = malloc(sizeof(char*) * numOutputs); + if (outputNames == NULL) { + goto cleanup_input_names; + } + jobject* javaInputStrings = malloc(sizeof(jobject) * numInputs); + if (javaInputStrings == NULL) { + goto cleanup_output_names; + } + jobject* javaOutputStrings = malloc(sizeof(jobject) * numOutputs); + if (javaOutputStrings == NULL) { + goto cleanup_java_input_strings; + } + const OrtValue** inputValuePtrs = malloc(sizeof(OrtValue*) * numInputs); + if (inputValuePtrs == NULL) { + goto cleanup_java_output_strings; + } + OrtValue** outputValues = malloc(sizeof(OrtValue*) * numOutputs); + if (outputValues == NULL) { + goto cleanup_input_values; + } + // Extract a C array of longs which are pointers to the input tensors. + // The Java-side objects store native pointers as 64-bit longs, and on 32-bit systems + // we cannot cast the long array to a pointer array as they are different sizes, + // so we copy the longs applying the appropriate cast. + jlong* inputValueLongs = (*jniEnv)->GetLongArrayElements(jniEnv, tensorArr, NULL); + + // Extract the names and native pointers of the input values. + for (int i = 0; i < numInputs; i++) { + javaInputStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv, inputNamesArr, i); + inputNames[i] = (*jniEnv)->GetStringUTFChars(jniEnv, javaInputStrings[i], NULL); + inputValuePtrs[i] = (OrtValue*)inputValueLongs[i]; + } + + // Release the java array copy of pointers to the tensors. + (*jniEnv)->ReleaseLongArrayElements(jniEnv, tensorArr, inputValueLongs, JNI_ABORT); + + // Extract the names of the output values. + for (int i = 0; i < numOutputs; i++) { + javaOutputStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv, outputNamesArr, i); + outputNames[i] = (*jniEnv)->GetStringUTFChars(jniEnv, javaOutputStrings[i], NULL); + outputValues[i] = NULL; + } + + // Actually score the inputs. + // ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess, _In_ OrtRunOptions* run_options, + // _In_ const char* const* input_names, _In_ const OrtValue* const* input, size_t input_len, + // _In_ const char* const* output_names, size_t output_names_len, _Out_ OrtValue** output); + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->Run(session, runOptions, (const char* const*)inputNames, + (const OrtValue* const*)inputValuePtrs, numInputs, + (const char* const*)outputNames, numOutputs, outputValues)); + if (code != ORT_OK) { + goto cleanup_output_values; + } + + // Construct the output array of ONNXValues + jclass onnxValueClass = (*jniEnv)->FindClass(jniEnv, ORTJNI_OnnxValueClassName); + outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_int64_to_jsize(numOutputs), onnxValueClass, NULL); + + // Convert the output tensors into ONNXValues + for (int i = 0; i < numOutputs; i++) { + if (outputValues[i] != NULL) { + jobject onnxValue = convertOrtValueToONNXValue(jniEnv, api, allocator, outputValues[i]); + if (onnxValue == NULL) { + break; // go to cleanup, exception thrown + } + (*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, i, onnxValue); + } + } + + // Note these gotos are in a specific order so they mirror the allocation pattern above. + // They must be changed if the allocation code is rearranged. +cleanup_output_values: + free(outputValues); + + // Release the Java output strings + for (int i = 0; i < numOutputs; i++) { + (*jniEnv)->ReleaseStringUTFChars(jniEnv, javaOutputStrings[i], outputNames[i]); + } + + // Release the Java input strings + for (int i = 0; i < numInputs; i++) { + (*jniEnv)->ReleaseStringUTFChars(jniEnv, javaInputStrings[i], inputNames[i]); + } + + // Release the buffers +cleanup_input_values: + free((void*)inputValuePtrs); +cleanup_java_output_strings: + free(javaOutputStrings); +cleanup_java_input_strings: + free(javaInputStrings); +cleanup_output_names: + free((void*)outputNames); +cleanup_input_names: + free((void*)inputNames); + + return outputArray; +} /* * Class: ai_onnxruntime_OrtSession * Method: getProfilingStartTimeInNs * Signature: (JJ)J */ -JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_getProfilingStartTimeInNs - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtSession* session = (OrtSession*) sessionHandle; +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_getProfilingStartTimeInNs(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle) { + (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtSession* session = (OrtSession*)sessionHandle; uint64_t timestamp = 0; - - checkOrtStatus(jniEnv,api,api->SessionGetProfilingStartTimeNs(session,×tamp)); - return (jlong) timestamp; + checkOrtStatus(jniEnv, api, api->SessionGetProfilingStartTimeNs(session, ×tamp)); + return (jlong)timestamp; } /* @@ -342,16 +458,19 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_getProfilingStartTimeInNs * Method: endProfiling * Signature: (JJJ)Ljava/lang/String; */ -JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_endProfiling - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle) { - (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; +JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_endProfiling(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle) { + (void)jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtAllocator* allocator = (OrtAllocator*)allocatorHandle; - char* profileStr; - checkOrtStatus(jniEnv,api,api->SessionEndProfiling((OrtSession*)handle,allocator,&profileStr)); - jstring profileOutput = (*jniEnv)->NewStringUTF(jniEnv,profileStr); - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,profileStr)); + char* profileStr = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->SessionEndProfiling((OrtSession*)handle, allocator, + &profileStr)); + if (code != ORT_OK) { + return NULL; + } + jstring profileOutput = (*jniEnv)->NewStringUTF(jniEnv, profileStr); + checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, profileStr)); return profileOutput; } @@ -360,11 +479,10 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_endProfiling * Method: closeSession * Signature: (J)V */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_closeSession - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) { - (void) jniEnv; (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - api->ReleaseSession((OrtSession*)handle); +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_closeSession(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong handle) { + (void)jniEnv; (void)jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + api->ReleaseSession((OrtSession*)handle); } /* @@ -372,91 +490,139 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_closeSession * Method: constructMetadata * Signature: (JJJ)Ljava/lang/String; */ -JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_constructMetadata - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle, jlong allocatorHandle) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; +JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_constructMetadata(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle, jlong allocatorHandle) { + (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtAllocator* allocator = (OrtAllocator*)allocatorHandle; + jobject metadataJava = NULL; + jstring producerStr = NULL; + jstring graphStr = NULL; + jstring graphDescStr = NULL; + jstring domainStr = NULL; + jstring descriptionStr = NULL; + + // macro for processing char* into a Java UTF-8 string with error handling inside this function +#define STR_PROCESS(STR_NAME) \ + if (code == ORT_OK) { \ + STR_NAME = (*jniEnv)->NewStringUTF(jniEnv, charBuffer); \ + code = checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, charBuffer)); \ + if (code != ORT_OK) { \ + goto release_metadata; \ + } \ + } else { \ + goto release_metadata; \ + } // Setup - char* stringClassName = "java/lang/String"; - jclass stringClazz = (*jniEnv)->FindClass(jniEnv, stringClassName); - char *metadataClassName = "ai/onnxruntime/OnnxModelMetadata"; - jclass metadataClazz = (*jniEnv)->FindClass(jniEnv, metadataClassName); - //OnnxModelMetadata(String producerName, String graphName, String domain, String description, long version, String[] customMetadataArray) - jmethodID metadataConstructor = (*jniEnv)->GetMethodID(jniEnv, metadataClazz, "", - "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;J[Ljava/lang/String;)V"); + jclass stringClazz = (*jniEnv)->FindClass(jniEnv, ORTJNI_StringClassName); + jclass metadataClazz = (*jniEnv)->FindClass(jniEnv, ORTJNI_MetadataClassName); + // OnnxModelMetadata(String producerName, String graphName, String domain, String description, + // long version, String[] customMetadataArray) + jmethodID metadataConstructor = (*jniEnv)->GetMethodID( + jniEnv, metadataClazz, "", + "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;J[Ljava/lang/String;)V"); // Get metadata - OrtModelMetadata* metadata; - checkOrtStatus(jniEnv,api,api->SessionGetModelMetadata((OrtSession*)nativeHandle,&metadata)); + OrtModelMetadata* metadata = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->SessionGetModelMetadata((OrtSession*)nativeHandle, &metadata)); + if (code != ORT_OK) { + // Nothing to cleanup, return null as an exception has been thrown + return NULL; + } // Read out the producer name and convert it to a java.lang.String - char* charBuffer; - checkOrtStatus(jniEnv,api,api->ModelMetadataGetProducerName(metadata, allocator, &charBuffer)); - jstring producerStr = (*jniEnv)->NewStringUTF(jniEnv,charBuffer); - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,charBuffer)); + char* charBuffer = NULL; + code = checkOrtStatus(jniEnv, api, api->ModelMetadataGetProducerName(metadata, allocator, &charBuffer)); + STR_PROCESS(producerStr) // Read out the graph name and convert it to a java.lang.String - checkOrtStatus(jniEnv,api,api->ModelMetadataGetGraphName(metadata, allocator, &charBuffer)); - jstring graphStr = (*jniEnv)->NewStringUTF(jniEnv,charBuffer); - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,charBuffer)); + code = checkOrtStatus(jniEnv, api, api->ModelMetadataGetGraphName(metadata, allocator, &charBuffer)); + STR_PROCESS(graphStr) // Read out the graph description and convert it to a java.lang.String - checkOrtStatus(jniEnv,api,api->ModelMetadataGetGraphDescription(metadata, allocator, &charBuffer)); - jstring graphDescStr = (*jniEnv)->NewStringUTF(jniEnv,charBuffer); - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,charBuffer)); + code = checkOrtStatus(jniEnv, api, api->ModelMetadataGetGraphDescription(metadata, allocator, &charBuffer)); + STR_PROCESS(graphDescStr) // Read out the domain and convert it to a java.lang.String - checkOrtStatus(jniEnv,api,api->ModelMetadataGetDomain(metadata, allocator, &charBuffer)); - jstring domainStr = (*jniEnv)->NewStringUTF(jniEnv,charBuffer); - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,charBuffer)); + code = checkOrtStatus(jniEnv, api, api->ModelMetadataGetDomain(metadata, allocator, &charBuffer)); + STR_PROCESS(domainStr) // Read out the description and convert it to a java.lang.String - checkOrtStatus(jniEnv,api,api->ModelMetadataGetDescription(metadata, allocator, &charBuffer)); - jstring descriptionStr = (*jniEnv)->NewStringUTF(jniEnv,charBuffer); - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,charBuffer)); + code = checkOrtStatus(jniEnv, api, api->ModelMetadataGetDescription(metadata, allocator, &charBuffer)); + STR_PROCESS(descriptionStr) // Read out the version - int64_t version; - checkOrtStatus(jniEnv,api,api->ModelMetadataGetVersion(metadata, &version)); + int64_t version = 0; + code = checkOrtStatus(jniEnv, api, api->ModelMetadataGetVersion(metadata, &version)); + if (code != ORT_OK) { + goto release_metadata; + } // Read out the keys, look up the values. - int64_t numKeys; - char** keys; - checkOrtStatus(jniEnv,api,api->ModelMetadataGetCustomMetadataMapKeys(metadata, allocator, &keys, &numKeys)); + int64_t numKeys = 0; + char** keys = NULL; + code = checkOrtStatus(jniEnv, api, api->ModelMetadataGetCustomMetadataMapKeys(metadata, allocator, &keys, &numKeys)); + if (code != ORT_OK) { + goto release_metadata; + } jobjectArray customArray = NULL; if (numKeys > 0) { - customArray = (*jniEnv)->NewObjectArray(jniEnv,safecast_int64_to_jsize(numKeys * 2),stringClazz, NULL); + customArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_int64_to_jsize(numKeys * 2), stringClazz, NULL); // Iterate key array to extract the values for (int64_t i = 0; i < numKeys; i++) { // Create a java.lang.String for the key - jstring keyJava = (*jniEnv)->NewStringUTF(jniEnv,keys[i]); + jstring keyJava = (*jniEnv)->NewStringUTF(jniEnv, keys[i]); // Extract the value and convert it to a java.lang.String - checkOrtStatus(jniEnv,api,api->ModelMetadataLookupCustomMetadataMap(metadata,allocator,keys[i],&charBuffer)); - jstring valueJava = (*jniEnv)->NewStringUTF(jniEnv,charBuffer); - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,charBuffer)); + code = checkOrtStatus(jniEnv, api, api->ModelMetadataLookupCustomMetadataMap(metadata, allocator, keys[i], &charBuffer)); + jstring valueJava = NULL; + if (code == ORT_OK) { + valueJava = (*jniEnv)->NewStringUTF(jniEnv, charBuffer); + code = checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, charBuffer)); + if (code != ORT_OK) { + // Signal that custom metadata extraction failed and break out + customArray = NULL; + break; + } + } else { + // Signal that custom metadata extraction failed and break out + customArray = NULL; + break; + } // Write the key and value into the array - (*jniEnv)->SetObjectArrayElement(jniEnv,customArray,safecast_int64_to_jsize(i*2),keyJava); - (*jniEnv)->SetObjectArrayElement(jniEnv,customArray,safecast_int64_to_jsize((i * 2) + 1),valueJava); + (*jniEnv)->SetObjectArrayElement(jniEnv, customArray, safecast_int64_to_jsize(i * 2), keyJava); + (*jniEnv)->SetObjectArrayElement(jniEnv, customArray, safecast_int64_to_jsize((i * 2) + 1), valueJava); } // Release key array for (int64_t i = 0; i < numKeys; i++) { - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,keys[i])); + code = checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, keys[i])); + if (code != ORT_OK) { + customArray = NULL; + break; + } + } + code = checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, keys)); + if (code != ORT_OK) { + customArray = NULL; } - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,keys)); } else { - customArray = (*jniEnv)->NewObjectArray(jniEnv,0,stringClazz,NULL); + customArray = (*jniEnv)->NewObjectArray(jniEnv, 0, stringClazz, NULL); } - // Invoke the metadata constructor - //OnnxModelMetadata(String producerName, String graphName, String graphDescription, String domain, String description, long version, String[] customMetadataArray) - jobject metadataJava = (*jniEnv)->NewObject(jniEnv, metadataClazz, metadataConstructor, producerStr, graphStr, graphDescStr, domainStr, descriptionStr, (jlong) version, customArray); + if (customArray != NULL) { + // If the array is non-null then the custom metadata extraction completed successfully so + // we invoke the metadata constructor + // OnnxModelMetadata(String producerName, String graphName, String graphDescription, String domain, + // String description, long version, String[] customMetadataArray) + metadataJava = (*jniEnv)->NewObject(jniEnv, metadataClazz, metadataConstructor, + producerStr, graphStr, graphDescStr, domainStr, descriptionStr, (jlong)version, + customArray); + } +release_metadata: // Release the metadata api->ReleaseModelMetadata(metadata);