diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 43a2ecb0e5..81d546ffa5 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -86,7 +86,7 @@ jobs: github_token: ${{ secrets.github_token }} reporter: github-pr-check level: warning - flags: --linelength=120 + flags: --linelength=120 --exclude=java/src/main/native/*.c filter: "-runtime/references" lint-js: diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index 1f5b9f109a..55f26f35b4 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2020 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2022 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ #include @@ -183,6 +183,31 @@ size_t onnxTypeSize(ONNXTensorElementDataType type) { } } +OrtErrorCode getTensorTypeShape(JNIEnv * jniEnv, JavaTensorTypeShape* output, const OrtApi * api, const OrtValue * value) { + OrtTensorTypeAndShapeInfo* info; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorTypeAndShape(value, &info)); + if (code != ORT_OK) { + return code; + } + code = checkOrtStatus(jniEnv, api, api->GetDimensionsCount(info, &output->dimensions)); + if (code != ORT_OK) { + api->ReleaseTensorTypeAndShapeInfo(info); + return code; + } + code = checkOrtStatus(jniEnv, api, api->GetTensorShapeElementCount(info, &output->elementCount)); + if (code != ORT_OK) { + api->ReleaseTensorTypeAndShapeInfo(info); + return code; + } + code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(info, &output->onnxTypeEnum)); + api->ReleaseTensorTypeAndShapeInfo(info); + if (code != ORT_OK) { + return code; + } + + return ORT_OK; +} + typedef union FP32 { int intVal; float floatVal; @@ -1038,6 +1063,7 @@ OrtErrorCode checkOrtStatus(JNIEnv *jniEnv, const OrtApi * api, OrtStatus * stat size_t len = strlen(message)+1; char* copy = malloc(sizeof(char)*len); if (copy == NULL) { + api->ReleaseStatus(status); throwOrtException(jniEnv, 1, "Not enough memory"); return ORT_FAIL; } diff --git a/java/src/main/native/OrtJniUtil.h b/java/src/main/native/OrtJniUtil.h index f9cbeeca82..075202333e 100644 --- a/java/src/main/native/OrtJniUtil.h +++ b/java/src/main/native/OrtJniUtil.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ #include @@ -11,9 +11,14 @@ extern "C" { #endif -jsize safecast_size_t_to_jsize(size_t v); - -jsize safecast_int64_to_jsize(int64_t v); +typedef struct { + /* The number of dimensions in the Tensor */ + size_t dimensions; + /* The number of elements in the Tensor */ + size_t elementCount; + /* The type of the Tensor */ + ONNXTensorElementDataType onnxTypeEnum; +} JavaTensorTypeShape; jint JNI_OnLoad(JavaVM *vm, void *reserved); @@ -29,6 +34,8 @@ ONNXTensorElementDataType convertToONNXDataFormat(jint type); size_t onnxTypeSize(ONNXTensorElementDataType type); +OrtErrorCode getTensorTypeShape(JNIEnv * jniEnv, JavaTensorTypeShape * output, const OrtApi * api, const OrtValue * value); + jfloat convertHalfToFloat(uint16_t half); jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, OrtTypeInfo * info); @@ -75,6 +82,10 @@ jint convertErrorCode(OrtErrorCode code); OrtErrorCode checkOrtStatus(JNIEnv * env, const OrtApi * api, OrtStatus * status); +jsize safecast_size_t_to_jsize(size_t v); + +jsize safecast_int64_to_jsize(int64_t v); + #ifdef __cplusplus } #endif diff --git a/java/src/main/native/ai_onnxruntime_OnnxTensor.c b/java/src/main/native/ai_onnxruntime_OnnxTensor.c index 33f41fad44..06a8ad42e7 100644 --- a/java/src/main/native/ai_onnxruntime_OnnxTensor.c +++ b/java/src/main/native/ai_onnxruntime_OnnxTensor.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ #include @@ -14,45 +14,60 @@ * Signature: (JJLjava/lang/Object;[JI)J */ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor - (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject dataObj, jlongArray shape, jint onnxTypeJava) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject dataObj, + jlongArray shape, jint onnxTypeJava) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; // Convert type to ONNX C enum ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava); // Extract the shape information - jboolean mkCopy; - jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv,shape,&mkCopy); - jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv,shape); + jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, shape, NULL); + jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, shape); // Create the OrtValue - OrtValue* ortValue; - checkOrtStatus(jniEnv, api, api->CreateTensorAsOrtValue(allocator,(int64_t*)shapeArr,shapeLen,onnxType,&ortValue)); - (*jniEnv)->ReleaseLongArrayElements(jniEnv,shape,shapeArr,JNI_ABORT); + OrtValue* ortValue = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv, api, + api->CreateTensorAsOrtValue( + allocator, (int64_t*)shapeArr, shapeLen, onnxType, &ortValue + ) + ); + (*jniEnv)->ReleaseLongArrayElements(jniEnv, shape, shapeArr, JNI_ABORT); - // Get a reference to the OrtValue's data - uint8_t* tensorData; - checkOrtStatus(jniEnv, api, api->GetTensorMutableData(ortValue, (void**) &tensorData)); + int failed = 0; + if (code == ORT_OK) { + // Get a reference to the OrtValue's data + uint8_t* tensorData = NULL; + code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(ortValue, (void**)&tensorData)); + if (code == ORT_OK) { + // Check if we're copying a scalar or not + if (shapeLen == 0) { + // Scalars are passed in as a single element array + size_t copied = copyJavaToPrimitiveArray(jniEnv, onnxType, tensorData, dataObj); + failed = copied == 0 ? 1 : failed; + } else { + // Extract the tensor shape information + JavaTensorTypeShape typeShape; + code = getTensorTypeShape(jniEnv, &typeShape, api, ortValue); - // Check if we're copying a scalar or not - if (shapeLen == 0) { - // Scalars are passed in as a single element array - copyJavaToPrimitiveArray(jniEnv, onnxType, tensorData, dataObj); - } else { - // Extract the tensor shape information - OrtTensorTypeAndShapeInfo* info; - checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(ortValue, &info)); - size_t dimensions; - checkOrtStatus(jniEnv,api,api->GetDimensionsCount(info,&dimensions)); - size_t arrSize; - checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(info,&arrSize)); - ONNXTensorElementDataType onnxTypeEnum; - checkOrtStatus(jniEnv,api,api->GetTensorElementType(info,&onnxTypeEnum)); - api->ReleaseTensorTypeAndShapeInfo(info); + if (code == ORT_OK) { + // Copy the java array into the tensor + size_t copied = copyJavaToTensor(jniEnv, onnxType, tensorData, typeShape.elementCount, + typeShape.dimensions, dataObj); + failed = copied == 0 ? 1 : failed; + } else { + failed = 1; + } + } + } else { + failed = 1; + } + } - // Copy the java array into the tensor - copyJavaToTensor(jniEnv, onnxType, tensorData, arrSize, dimensions, dataObj); + if (failed) { + api->ReleaseValue(ortValue); + ortValue = NULL; } // Return the pointer to the OrtValue @@ -65,30 +80,34 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor * Signature: (JJLjava/nio/Buffer;IJ[JI)J */ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensorFromBuffer - (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject buffer, jint bufferPos, jlong bufferSize, jlongArray shape, jint onnxTypeJava) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject buffer, jint bufferPos, jlong bufferSize, + jlongArray shape, jint onnxTypeJava) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; const OrtMemoryInfo* allocatorInfo; - checkOrtStatus(jniEnv, api, api->AllocatorGetInfo(allocator,&allocatorInfo)); + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->AllocatorGetInfo(allocator, &allocatorInfo)); + if (code != ORT_OK) { + return (jlong) NULL; + } // Convert type to ONNX C enum ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava); // Extract the buffer - char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv,buffer); + char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, buffer); // Increment by bufferPos bytes bufferArr = bufferArr + bufferPos; // Extract the shape information - jboolean mkCopy; - jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv,shape,&mkCopy); - jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv,shape); + jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, shape, NULL); + jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, shape); // Create the OrtValue - OrtValue* ortValue; - checkOrtStatus(jniEnv, api, api->CreateTensorWithDataAsOrtValue(allocatorInfo,bufferArr,bufferSize,(int64_t*)shapeArr,shapeLen,onnxType,&ortValue)); - (*jniEnv)->ReleaseLongArrayElements(jniEnv,shape,shapeArr,JNI_ABORT); + OrtValue* ortValue = NULL; + checkOrtStatus(jniEnv, api, api->CreateTensorWithDataAsOrtValue(allocatorInfo, bufferArr, bufferSize, + (int64_t*)shapeArr, shapeLen, onnxType, &ortValue)); + (*jniEnv)->ReleaseLongArrayElements(jniEnv, shape, shapeArr, JNI_ABORT); // Return the pointer to the OrtValue return (jlong) ortValue; @@ -101,32 +120,34 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensorFromBuffer */ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createString (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jstring input) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; - // Extract the shape information - int64_t* shapeArr; - checkOrtStatus(jniEnv,api,api->AllocatorAlloc(allocator,sizeof(int64_t),(void**)&shapeArr)); - shapeArr[0] = 1; - - // Create the OrtValue - OrtValue* ortValue; - checkOrtStatus(jniEnv, api, api->CreateTensorAsOrtValue(allocator,shapeArr,0,ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,&ortValue)); - - // Release the shape - checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator,shapeArr)); + // Create the OrtValue + int64_t shapeArr = 1; + OrtValue* ortValue = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->CreateTensorAsOrtValue(allocator, &shapeArr, 0, + ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, &ortValue)); + if (code == ORT_OK) { // Create the buffer for the Java string const char* stringBuffer = (*jniEnv)->GetStringUTFChars(jniEnv,input,NULL); // Assign the strings into the Tensor - checkOrtStatus(jniEnv, api, api->FillStringTensor(ortValue,&stringBuffer,1)); + code = checkOrtStatus(jniEnv, api, api->FillStringTensor(ortValue, &stringBuffer, 1)); - // Release the Java string - (*jniEnv)->ReleaseStringUTFChars(jniEnv,input,stringBuffer); + // Release the Java string whether the call succeeded or failed + (*jniEnv)->ReleaseStringUTFChars(jniEnv, input, stringBuffer); - return (jlong) ortValue; + // Assignment failed, return null + if (code != ORT_OK) { + api->ReleaseValue(ortValue); + return (jlong) NULL; + } + } + + return (jlong) ortValue; } /* @@ -136,47 +157,55 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createString */ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createStringTensor (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobjectArray stringArr, jlongArray shape) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; // Extract the shape information - jboolean mkCopy; - jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv,shape,&mkCopy); - jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv,shape); + jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, shape, NULL); + jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, shape); // Array length jsize length = (*jniEnv)->GetArrayLength(jniEnv, stringArr); // Create the OrtValue - OrtValue* ortValue; - checkOrtStatus(jniEnv, api, api->CreateTensorAsOrtValue(allocator,(int64_t*)shapeArr,shapeLen,ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,&ortValue)); - (*jniEnv)->ReleaseLongArrayElements(jniEnv,shape,shapeArr,JNI_ABORT); + OrtValue* ortValue = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->CreateTensorAsOrtValue(allocator, (int64_t*)shapeArr, shapeLen, + ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, &ortValue)); + (*jniEnv)->ReleaseLongArrayElements(jniEnv, shape, shapeArr, JNI_ABORT); - // Create the buffers for the Java strings - const char** strings; - checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator,sizeof(char*)*length,(void**)&strings)); - jobject* javaStrings; - checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator,sizeof(jobject)*length,(void**)&javaStrings)); + if (code == ORT_OK) { + // Create the buffers for the Java strings + const char** strings = NULL; + code = checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator, sizeof(char*) * length, (void**)&strings)); - // Copy the java strings into the buffers - for (jsize i = 0; i < length; i++) { - javaStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv,stringArr,i); - strings[i] = (*jniEnv)->GetStringUTFChars(jniEnv,javaStrings[i],NULL); + if (code == ORT_OK) { + // Copy the java strings into the buffers + for (jsize i = 0; i < length; i++) { + jobject javaString = (*jniEnv)->GetObjectArrayElement(jniEnv, stringArr, i); + strings[i] = (*jniEnv)->GetStringUTFChars(jniEnv, javaString, NULL); + } + + // Assign the strings into the Tensor + code = checkOrtStatus(jniEnv, api, api->FillStringTensor(ortValue, strings, length)); + + // Release the Java strings + for (int i = 0; i < length; i++) { + jobject javaString = (*jniEnv)->GetObjectArrayElement(jniEnv, stringArr, i); + (*jniEnv)->ReleaseStringUTFChars(jniEnv, javaString, strings[i]); + } + + // Release the buffers + OrtErrorCode freeCode = checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, (void*)strings)); + + // Assignment failed, return null + if ((code != ORT_OK) || (freeCode != ORT_OK)) { + api->ReleaseValue(ortValue); + return (jlong) NULL; + } + } } - // Assign the strings into the Tensor - checkOrtStatus(jniEnv, api, api->FillStringTensor(ortValue,strings,length)); - - // Release the Java strings - for (int i = 0; i < length; i++) { - (*jniEnv)->ReleaseStringUTFChars(jniEnv,javaStrings[i],strings[i]); - } - - // Release the buffers - checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, (void*)strings)); - checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, javaStrings)); - return (jlong) ortValue; } @@ -187,23 +216,24 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createStringTensor */ JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxTensor_getBuffer (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; - OrtTensorTypeAndShapeInfo* info; - checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape((OrtValue*) handle, &info)); - size_t arrSize; - checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(info,&arrSize)); - ONNXTensorElementDataType onnxTypeEnum; - checkOrtStatus(jniEnv,api,api->GetTensorElementType(info,&onnxTypeEnum)); - api->ReleaseTensorTypeAndShapeInfo(info); + OrtValue* ortValue = (OrtValue *) handle; + JavaTensorTypeShape typeShape; + OrtErrorCode code = getTensorTypeShape(jniEnv, &typeShape, api, ortValue); - size_t typeSize = onnxTypeSize(onnxTypeEnum); - size_t sizeBytes = arrSize*typeSize; + if (code == ORT_OK) { + size_t typeSize = onnxTypeSize(typeShape.onnxTypeEnum); + size_t sizeBytes = typeShape.elementCount * typeSize; - uint8_t* arr; - checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr)); + uint8_t* arr = NULL; + code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr)); - return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, sizeBytes); + if (code == ORT_OK) { + return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, (jlong)sizeBytes); + } + } + return NULL; } /* @@ -212,21 +242,25 @@ JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxTensor_getBuffer * Signature: (JI)F */ JNIEXPORT jfloat JNICALL Java_ai_onnxruntime_OnnxTensor_getFloat - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxType) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - if (onnxType == 9) { - uint16_t* arr; - checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr)); - jfloat floatVal = convertHalfToFloat(*arr); - return floatVal; - } else if (onnxType == 10) { - jfloat* arr; - checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr)); - return *arr; - } else { - return NAN; + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxTypeInt) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeInt); + if (onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { + uint16_t* arr = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr)); + if (code == ORT_OK) { + jfloat floatVal = convertHalfToFloat(*arr); + return floatVal; } + } else if (onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { + jfloat* arr = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr)); + if (code == ORT_OK) { + return *arr; + } + } + return NAN; } /* @@ -236,11 +270,15 @@ JNIEXPORT jfloat JNICALL Java_ai_onnxruntime_OnnxTensor_getFloat */ JNIEXPORT jdouble JNICALL Java_ai_onnxruntime_OnnxTensor_getDouble (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; - jdouble* arr; - checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr)); - return *arr; + jdouble* arr = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr)); + if (code == ORT_OK) { + return *arr; + } else { + return NAN; + } } /* @@ -249,20 +287,18 @@ JNIEXPORT jdouble JNICALL Java_ai_onnxruntime_OnnxTensor_getDouble * Signature: (JI)B */ JNIEXPORT jbyte JNICALL Java_ai_onnxruntime_OnnxTensor_getByte - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxType) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxTypeInt) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; - if (onnxType == 1) { - uint8_t* arr; - checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr)); - return (jbyte) *arr; - } else if (onnxType == 2) { - int8_t* arr; - checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr)); - return (jbyte) *arr; - } else { - return (jbyte) 0; + ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeInt); + if ((onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) || (onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)) { + uint8_t* arr = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr)); + if (code == ORT_OK) { + return (jbyte) *arr; + } } + return (jbyte) 0; } /* @@ -271,20 +307,18 @@ JNIEXPORT jbyte JNICALL Java_ai_onnxruntime_OnnxTensor_getByte * Signature: (JI)S */ JNIEXPORT jshort JNICALL Java_ai_onnxruntime_OnnxTensor_getShort - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxType) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxTypeInt) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; - if (onnxType == 3) { - uint16_t* arr; - checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr)); - return (jshort) *arr; - } else if (onnxType == 4) { - int16_t* arr; - checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr)); - return (jshort) *arr; - } else { - return (jshort) 0; + ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeInt); + if ((onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16) || (onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)) { + uint16_t* arr = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr)); + if (code == ORT_OK) { + return (jshort) *arr; + } } + return (jshort) 0; } /* @@ -293,20 +327,18 @@ JNIEXPORT jshort JNICALL Java_ai_onnxruntime_OnnxTensor_getShort * Signature: (JI)I */ JNIEXPORT jint JNICALL Java_ai_onnxruntime_OnnxTensor_getInt - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxType) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxTypeInt) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; - if (onnxType == 5) { - uint32_t* arr; - checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr)); - return (jint) *arr; - } else if (onnxType == 6) { - int32_t* arr; - checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr)); - return (jint) *arr; - } else { - return (jint) 0; + ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeInt); + if ((onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32) || (onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)) { + uint32_t* arr = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr)); + if (code == ORT_OK) { + return (jint) *arr; + } } + return (jint) 0; } /* @@ -315,20 +347,18 @@ JNIEXPORT jint JNICALL Java_ai_onnxruntime_OnnxTensor_getInt * Signature: (JI)J */ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_getLong - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxType) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxTypeInt) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; - if (onnxType == 7) { - uint64_t* arr; - checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr)); - return (jlong) *arr; - } else if (onnxType == 8) { - int64_t* arr; - checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr)); - return (jlong) *arr; - } else { - return (jlong) 0; + ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeInt); + if ((onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64) || (onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) { + uint64_t* arr = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr)); + if (code == ORT_OK) { + return (jlong) *arr; + } } + return (jlong) 0; } /* @@ -338,18 +368,22 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_getLong */ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OnnxTensor_getString (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; // Extract a String array - if this becomes a performance issue we'll refactor later. - jobjectArray outputArray = createStringArrayFromTensor(jniEnv,api, (OrtAllocator*) allocatorHandle, (OrtValue*) handle); + jobjectArray outputArray = createStringArrayFromTensor(jniEnv, api, (OrtAllocator*) allocatorHandle, + (OrtValue*) handle); + if (outputArray != NULL) { + // Get reference to the string + jobject output = (*jniEnv)->GetObjectArrayElement(jniEnv, outputArray, 0); - // Get reference to the string - jobject output = (*jniEnv)->GetObjectArrayElement(jniEnv, outputArray, 0); + // Free array + (*jniEnv)->DeleteLocalRef(jniEnv, outputArray); - // Free array - (*jniEnv)->DeleteLocalRef(jniEnv,outputArray); - - return output; + return output; + } else { + return NULL; + } } /* @@ -359,11 +393,15 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OnnxTensor_getString */ JNIEXPORT jboolean JNICALL Java_ai_onnxruntime_OnnxTensor_getBool (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; - jboolean* arr; - checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr)); - return *arr; + jboolean* arr = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr)); + if (code == ORT_OK) { + return *arr; + } else { + return 0; + } } /* @@ -373,24 +411,22 @@ JNIEXPORT jboolean JNICALL Java_ai_onnxruntime_OnnxTensor_getBool */ JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle, jobject carrier) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; - OrtTensorTypeAndShapeInfo* info; - checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape((OrtValue*) handle, &info)); - size_t dimensions; - checkOrtStatus(jniEnv,api,api->GetDimensionsCount(info,&dimensions)); - size_t arrSize; - checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(info,&arrSize)); - ONNXTensorElementDataType onnxTypeEnum; - checkOrtStatus(jniEnv,api,api->GetTensorElementType(info,&onnxTypeEnum)); - api->ReleaseTensorTypeAndShapeInfo(info); - - if (onnxTypeEnum == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { - copyStringTensorToArray(jniEnv,api, (OrtAllocator*) allocatorHandle, (OrtValue*)handle, arrSize, carrier); - } else { - uint8_t* arr; - checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr)); - copyTensorToJava(jniEnv,onnxTypeEnum,arr,arrSize,dimensions,(jarray)carrier); + OrtValue* value = (OrtValue*) handle; + JavaTensorTypeShape typeShape; + OrtErrorCode code = getTensorTypeShape(jniEnv, &typeShape, api, value); + if (code == ORT_OK) { + if (typeShape.onnxTypeEnum == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { + copyStringTensorToArray(jniEnv, api, (OrtAllocator*) allocatorHandle, value, typeShape.elementCount, carrier); + } else { + uint8_t* arr = NULL; + code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(value, (void**)&arr)); + if (code == ORT_OK) { + copyTensorToJava(jniEnv, typeShape.onnxTypeEnum, arr, typeShape.elementCount, + typeShape.dimensions, (jarray)carrier); + } + } } }