[Java] JNI refactor for ONNX Tensor (#12281)

Working on JNI refactor for OnnxTensor.
  Simplifying the error handling logic in createTensor.
  Collapsing casting branches and migrating to ONNX element type enum.
  Disable cpplint for JNI C files.
This commit is contained in:
Adam Pocock 2022-08-08 15:48:30 -04:00 committed by GitHub
parent 8c5c283471
commit 8a86b346a5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 272 additions and 199 deletions

View file

@ -86,7 +86,7 @@ jobs:
github_token: ${{ secrets.github_token }}
reporter: github-pr-check
level: warning
flags: --linelength=120
flags: --linelength=120 --exclude=java/src/main/native/*.c
filter: "-runtime/references"
lint-js:

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2020 Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2022 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
@ -183,6 +183,31 @@ size_t onnxTypeSize(ONNXTensorElementDataType type) {
}
}
OrtErrorCode getTensorTypeShape(JNIEnv * jniEnv, JavaTensorTypeShape* output, const OrtApi * api, const OrtValue * value) {
OrtTensorTypeAndShapeInfo* info;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorTypeAndShape(value, &info));
if (code != ORT_OK) {
return code;
}
code = checkOrtStatus(jniEnv, api, api->GetDimensionsCount(info, &output->dimensions));
if (code != ORT_OK) {
api->ReleaseTensorTypeAndShapeInfo(info);
return code;
}
code = checkOrtStatus(jniEnv, api, api->GetTensorShapeElementCount(info, &output->elementCount));
if (code != ORT_OK) {
api->ReleaseTensorTypeAndShapeInfo(info);
return code;
}
code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(info, &output->onnxTypeEnum));
api->ReleaseTensorTypeAndShapeInfo(info);
if (code != ORT_OK) {
return code;
}
return ORT_OK;
}
typedef union FP32 {
int intVal;
float floatVal;
@ -1038,6 +1063,7 @@ OrtErrorCode checkOrtStatus(JNIEnv *jniEnv, const OrtApi * api, OrtStatus * stat
size_t len = strlen(message)+1;
char* copy = malloc(sizeof(char)*len);
if (copy == NULL) {
api->ReleaseStatus(status);
throwOrtException(jniEnv, 1, "Not enough memory");
return ORT_FAIL;
}

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
@ -11,9 +11,14 @@
extern "C" {
#endif
jsize safecast_size_t_to_jsize(size_t v);
jsize safecast_int64_to_jsize(int64_t v);
typedef struct {
/* The number of dimensions in the Tensor */
size_t dimensions;
/* The number of elements in the Tensor */
size_t elementCount;
/* The type of the Tensor */
ONNXTensorElementDataType onnxTypeEnum;
} JavaTensorTypeShape;
jint JNI_OnLoad(JavaVM *vm, void *reserved);
@ -29,6 +34,8 @@ ONNXTensorElementDataType convertToONNXDataFormat(jint type);
size_t onnxTypeSize(ONNXTensorElementDataType type);
OrtErrorCode getTensorTypeShape(JNIEnv * jniEnv, JavaTensorTypeShape * output, const OrtApi * api, const OrtValue * value);
jfloat convertHalfToFloat(uint16_t half);
jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, OrtTypeInfo * info);
@ -75,6 +82,10 @@ jint convertErrorCode(OrtErrorCode code);
OrtErrorCode checkOrtStatus(JNIEnv * env, const OrtApi * api, OrtStatus * status);
jsize safecast_size_t_to_jsize(size_t v);
jsize safecast_int64_to_jsize(int64_t v);
#ifdef __cplusplus
}
#endif

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
@ -14,45 +14,60 @@
* Signature: (JJLjava/lang/Object;[JI)J
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor
(JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject dataObj, jlongArray shape, jint onnxTypeJava) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
(JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject dataObj,
jlongArray shape, jint onnxTypeJava) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
// Convert type to ONNX C enum
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava);
// Extract the shape information
jboolean mkCopy;
jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv,shape,&mkCopy);
jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv,shape);
jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, shape, NULL);
jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, shape);
// Create the OrtValue
OrtValue* ortValue;
checkOrtStatus(jniEnv, api, api->CreateTensorAsOrtValue(allocator,(int64_t*)shapeArr,shapeLen,onnxType,&ortValue));
(*jniEnv)->ReleaseLongArrayElements(jniEnv,shape,shapeArr,JNI_ABORT);
OrtValue* ortValue = NULL;
OrtErrorCode code = checkOrtStatus(jniEnv, api,
api->CreateTensorAsOrtValue(
allocator, (int64_t*)shapeArr, shapeLen, onnxType, &ortValue
)
);
(*jniEnv)->ReleaseLongArrayElements(jniEnv, shape, shapeArr, JNI_ABORT);
// Get a reference to the OrtValue's data
uint8_t* tensorData;
checkOrtStatus(jniEnv, api, api->GetTensorMutableData(ortValue, (void**) &tensorData));
int failed = 0;
if (code == ORT_OK) {
// Get a reference to the OrtValue's data
uint8_t* tensorData = NULL;
code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(ortValue, (void**)&tensorData));
if (code == ORT_OK) {
// Check if we're copying a scalar or not
if (shapeLen == 0) {
// Scalars are passed in as a single element array
size_t copied = copyJavaToPrimitiveArray(jniEnv, onnxType, tensorData, dataObj);
failed = copied == 0 ? 1 : failed;
} else {
// Extract the tensor shape information
JavaTensorTypeShape typeShape;
code = getTensorTypeShape(jniEnv, &typeShape, api, ortValue);
// Check if we're copying a scalar or not
if (shapeLen == 0) {
// Scalars are passed in as a single element array
copyJavaToPrimitiveArray(jniEnv, onnxType, tensorData, dataObj);
} else {
// Extract the tensor shape information
OrtTensorTypeAndShapeInfo* info;
checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(ortValue, &info));
size_t dimensions;
checkOrtStatus(jniEnv,api,api->GetDimensionsCount(info,&dimensions));
size_t arrSize;
checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(info,&arrSize));
ONNXTensorElementDataType onnxTypeEnum;
checkOrtStatus(jniEnv,api,api->GetTensorElementType(info,&onnxTypeEnum));
api->ReleaseTensorTypeAndShapeInfo(info);
if (code == ORT_OK) {
// Copy the java array into the tensor
size_t copied = copyJavaToTensor(jniEnv, onnxType, tensorData, typeShape.elementCount,
typeShape.dimensions, dataObj);
failed = copied == 0 ? 1 : failed;
} else {
failed = 1;
}
}
} else {
failed = 1;
}
}
// Copy the java array into the tensor
copyJavaToTensor(jniEnv, onnxType, tensorData, arrSize, dimensions, dataObj);
if (failed) {
api->ReleaseValue(ortValue);
ortValue = NULL;
}
// Return the pointer to the OrtValue
@ -65,30 +80,34 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor
* Signature: (JJLjava/nio/Buffer;IJ[JI)J
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensorFromBuffer
(JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject buffer, jint bufferPos, jlong bufferSize, jlongArray shape, jint onnxTypeJava) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
(JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject buffer, jint bufferPos, jlong bufferSize,
jlongArray shape, jint onnxTypeJava) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
const OrtMemoryInfo* allocatorInfo;
checkOrtStatus(jniEnv, api, api->AllocatorGetInfo(allocator,&allocatorInfo));
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->AllocatorGetInfo(allocator, &allocatorInfo));
if (code != ORT_OK) {
return (jlong) NULL;
}
// Convert type to ONNX C enum
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava);
// Extract the buffer
char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv,buffer);
char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, buffer);
// Increment by bufferPos bytes
bufferArr = bufferArr + bufferPos;
// Extract the shape information
jboolean mkCopy;
jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv,shape,&mkCopy);
jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv,shape);
jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, shape, NULL);
jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, shape);
// Create the OrtValue
OrtValue* ortValue;
checkOrtStatus(jniEnv, api, api->CreateTensorWithDataAsOrtValue(allocatorInfo,bufferArr,bufferSize,(int64_t*)shapeArr,shapeLen,onnxType,&ortValue));
(*jniEnv)->ReleaseLongArrayElements(jniEnv,shape,shapeArr,JNI_ABORT);
OrtValue* ortValue = NULL;
checkOrtStatus(jniEnv, api, api->CreateTensorWithDataAsOrtValue(allocatorInfo, bufferArr, bufferSize,
(int64_t*)shapeArr, shapeLen, onnxType, &ortValue));
(*jniEnv)->ReleaseLongArrayElements(jniEnv, shape, shapeArr, JNI_ABORT);
// Return the pointer to the OrtValue
return (jlong) ortValue;
@ -101,32 +120,34 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensorFromBuffer
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createString
(JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jstring input) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
// Extract the shape information
int64_t* shapeArr;
checkOrtStatus(jniEnv,api,api->AllocatorAlloc(allocator,sizeof(int64_t),(void**)&shapeArr));
shapeArr[0] = 1;
// Create the OrtValue
OrtValue* ortValue;
checkOrtStatus(jniEnv, api, api->CreateTensorAsOrtValue(allocator,shapeArr,0,ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,&ortValue));
// Release the shape
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator,shapeArr));
// Create the OrtValue
int64_t shapeArr = 1;
OrtValue* ortValue = NULL;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->CreateTensorAsOrtValue(allocator, &shapeArr, 0,
ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, &ortValue));
if (code == ORT_OK) {
// Create the buffer for the Java string
const char* stringBuffer = (*jniEnv)->GetStringUTFChars(jniEnv,input,NULL);
// Assign the strings into the Tensor
checkOrtStatus(jniEnv, api, api->FillStringTensor(ortValue,&stringBuffer,1));
code = checkOrtStatus(jniEnv, api, api->FillStringTensor(ortValue, &stringBuffer, 1));
// Release the Java string
(*jniEnv)->ReleaseStringUTFChars(jniEnv,input,stringBuffer);
// Release the Java string whether the call succeeded or failed
(*jniEnv)->ReleaseStringUTFChars(jniEnv, input, stringBuffer);
return (jlong) ortValue;
// Assignment failed, return null
if (code != ORT_OK) {
api->ReleaseValue(ortValue);
return (jlong) NULL;
}
}
return (jlong) ortValue;
}
/*
@ -136,47 +157,55 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createString
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createStringTensor
(JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobjectArray stringArr, jlongArray shape) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
// Extract the shape information
jboolean mkCopy;
jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv,shape,&mkCopy);
jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv,shape);
jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, shape, NULL);
jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, shape);
// Array length
jsize length = (*jniEnv)->GetArrayLength(jniEnv, stringArr);
// Create the OrtValue
OrtValue* ortValue;
checkOrtStatus(jniEnv, api, api->CreateTensorAsOrtValue(allocator,(int64_t*)shapeArr,shapeLen,ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,&ortValue));
(*jniEnv)->ReleaseLongArrayElements(jniEnv,shape,shapeArr,JNI_ABORT);
OrtValue* ortValue = NULL;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->CreateTensorAsOrtValue(allocator, (int64_t*)shapeArr, shapeLen,
ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, &ortValue));
(*jniEnv)->ReleaseLongArrayElements(jniEnv, shape, shapeArr, JNI_ABORT);
// Create the buffers for the Java strings
const char** strings;
checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator,sizeof(char*)*length,(void**)&strings));
jobject* javaStrings;
checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator,sizeof(jobject)*length,(void**)&javaStrings));
if (code == ORT_OK) {
// Create the buffers for the Java strings
const char** strings = NULL;
code = checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator, sizeof(char*) * length, (void**)&strings));
// Copy the java strings into the buffers
for (jsize i = 0; i < length; i++) {
javaStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv,stringArr,i);
strings[i] = (*jniEnv)->GetStringUTFChars(jniEnv,javaStrings[i],NULL);
if (code == ORT_OK) {
// Copy the java strings into the buffers
for (jsize i = 0; i < length; i++) {
jobject javaString = (*jniEnv)->GetObjectArrayElement(jniEnv, stringArr, i);
strings[i] = (*jniEnv)->GetStringUTFChars(jniEnv, javaString, NULL);
}
// Assign the strings into the Tensor
code = checkOrtStatus(jniEnv, api, api->FillStringTensor(ortValue, strings, length));
// Release the Java strings
for (int i = 0; i < length; i++) {
jobject javaString = (*jniEnv)->GetObjectArrayElement(jniEnv, stringArr, i);
(*jniEnv)->ReleaseStringUTFChars(jniEnv, javaString, strings[i]);
}
// Release the buffers
OrtErrorCode freeCode = checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, (void*)strings));
// Assignment failed, return null
if ((code != ORT_OK) || (freeCode != ORT_OK)) {
api->ReleaseValue(ortValue);
return (jlong) NULL;
}
}
}
// Assign the strings into the Tensor
checkOrtStatus(jniEnv, api, api->FillStringTensor(ortValue,strings,length));
// Release the Java strings
for (int i = 0; i < length; i++) {
(*jniEnv)->ReleaseStringUTFChars(jniEnv,javaStrings[i],strings[i]);
}
// Release the buffers
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, (void*)strings));
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, javaStrings));
return (jlong) ortValue;
}
@ -187,23 +216,24 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createStringTensor
*/
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxTensor_getBuffer
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtTensorTypeAndShapeInfo* info;
checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape((OrtValue*) handle, &info));
size_t arrSize;
checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(info,&arrSize));
ONNXTensorElementDataType onnxTypeEnum;
checkOrtStatus(jniEnv,api,api->GetTensorElementType(info,&onnxTypeEnum));
api->ReleaseTensorTypeAndShapeInfo(info);
OrtValue* ortValue = (OrtValue *) handle;
JavaTensorTypeShape typeShape;
OrtErrorCode code = getTensorTypeShape(jniEnv, &typeShape, api, ortValue);
size_t typeSize = onnxTypeSize(onnxTypeEnum);
size_t sizeBytes = arrSize*typeSize;
if (code == ORT_OK) {
size_t typeSize = onnxTypeSize(typeShape.onnxTypeEnum);
size_t sizeBytes = typeShape.elementCount * typeSize;
uint8_t* arr;
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
uint8_t* arr = NULL;
code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, sizeBytes);
if (code == ORT_OK) {
return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, (jlong)sizeBytes);
}
}
return NULL;
}
/*
@ -212,21 +242,25 @@ JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxTensor_getBuffer
* Signature: (JI)F
*/
JNIEXPORT jfloat JNICALL Java_ai_onnxruntime_OnnxTensor_getFloat
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxType) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
if (onnxType == 9) {
uint16_t* arr;
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
jfloat floatVal = convertHalfToFloat(*arr);
return floatVal;
} else if (onnxType == 10) {
jfloat* arr;
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
return *arr;
} else {
return NAN;
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxTypeInt) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeInt);
if (onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
uint16_t* arr = NULL;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
if (code == ORT_OK) {
jfloat floatVal = convertHalfToFloat(*arr);
return floatVal;
}
} else if (onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
jfloat* arr = NULL;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
if (code == ORT_OK) {
return *arr;
}
}
return NAN;
}
/*
@ -236,11 +270,15 @@ JNIEXPORT jfloat JNICALL Java_ai_onnxruntime_OnnxTensor_getFloat
*/
JNIEXPORT jdouble JNICALL Java_ai_onnxruntime_OnnxTensor_getDouble
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
jdouble* arr;
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
return *arr;
jdouble* arr = NULL;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
if (code == ORT_OK) {
return *arr;
} else {
return NAN;
}
}
/*
@ -249,20 +287,18 @@ JNIEXPORT jdouble JNICALL Java_ai_onnxruntime_OnnxTensor_getDouble
* Signature: (JI)B
*/
JNIEXPORT jbyte JNICALL Java_ai_onnxruntime_OnnxTensor_getByte
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxType) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxTypeInt) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
if (onnxType == 1) {
uint8_t* arr;
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
return (jbyte) *arr;
} else if (onnxType == 2) {
int8_t* arr;
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
return (jbyte) *arr;
} else {
return (jbyte) 0;
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeInt);
if ((onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) || (onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)) {
uint8_t* arr = NULL;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
if (code == ORT_OK) {
return (jbyte) *arr;
}
}
return (jbyte) 0;
}
/*
@ -271,20 +307,18 @@ JNIEXPORT jbyte JNICALL Java_ai_onnxruntime_OnnxTensor_getByte
* Signature: (JI)S
*/
JNIEXPORT jshort JNICALL Java_ai_onnxruntime_OnnxTensor_getShort
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxType) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxTypeInt) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
if (onnxType == 3) {
uint16_t* arr;
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
return (jshort) *arr;
} else if (onnxType == 4) {
int16_t* arr;
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
return (jshort) *arr;
} else {
return (jshort) 0;
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeInt);
if ((onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16) || (onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)) {
uint16_t* arr = NULL;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
if (code == ORT_OK) {
return (jshort) *arr;
}
}
return (jshort) 0;
}
/*
@ -293,20 +327,18 @@ JNIEXPORT jshort JNICALL Java_ai_onnxruntime_OnnxTensor_getShort
* Signature: (JI)I
*/
JNIEXPORT jint JNICALL Java_ai_onnxruntime_OnnxTensor_getInt
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxType) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxTypeInt) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
if (onnxType == 5) {
uint32_t* arr;
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
return (jint) *arr;
} else if (onnxType == 6) {
int32_t* arr;
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
return (jint) *arr;
} else {
return (jint) 0;
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeInt);
if ((onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32) || (onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)) {
uint32_t* arr = NULL;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
if (code == ORT_OK) {
return (jint) *arr;
}
}
return (jint) 0;
}
/*
@ -315,20 +347,18 @@ JNIEXPORT jint JNICALL Java_ai_onnxruntime_OnnxTensor_getInt
* Signature: (JI)J
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_getLong
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxType) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxTypeInt) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
if (onnxType == 7) {
uint64_t* arr;
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
return (jlong) *arr;
} else if (onnxType == 8) {
int64_t* arr;
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
return (jlong) *arr;
} else {
return (jlong) 0;
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeInt);
if ((onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64) || (onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) {
uint64_t* arr = NULL;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
if (code == ORT_OK) {
return (jlong) *arr;
}
}
return (jlong) 0;
}
/*
@ -338,18 +368,22 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_getLong
*/
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OnnxTensor_getString
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
// Extract a String array - if this becomes a performance issue we'll refactor later.
jobjectArray outputArray = createStringArrayFromTensor(jniEnv,api, (OrtAllocator*) allocatorHandle, (OrtValue*) handle);
jobjectArray outputArray = createStringArrayFromTensor(jniEnv, api, (OrtAllocator*) allocatorHandle,
(OrtValue*) handle);
if (outputArray != NULL) {
// Get reference to the string
jobject output = (*jniEnv)->GetObjectArrayElement(jniEnv, outputArray, 0);
// Get reference to the string
jobject output = (*jniEnv)->GetObjectArrayElement(jniEnv, outputArray, 0);
// Free array
(*jniEnv)->DeleteLocalRef(jniEnv, outputArray);
// Free array
(*jniEnv)->DeleteLocalRef(jniEnv,outputArray);
return output;
return output;
} else {
return NULL;
}
}
/*
@ -359,11 +393,15 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OnnxTensor_getString
*/
JNIEXPORT jboolean JNICALL Java_ai_onnxruntime_OnnxTensor_getBool
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
jboolean* arr;
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
return *arr;
jboolean* arr = NULL;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
if (code == ORT_OK) {
return *arr;
} else {
return 0;
}
}
/*
@ -373,24 +411,22 @@ JNIEXPORT jboolean JNICALL Java_ai_onnxruntime_OnnxTensor_getBool
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle, jobject carrier) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtTensorTypeAndShapeInfo* info;
checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape((OrtValue*) handle, &info));
size_t dimensions;
checkOrtStatus(jniEnv,api,api->GetDimensionsCount(info,&dimensions));
size_t arrSize;
checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(info,&arrSize));
ONNXTensorElementDataType onnxTypeEnum;
checkOrtStatus(jniEnv,api,api->GetTensorElementType(info,&onnxTypeEnum));
api->ReleaseTensorTypeAndShapeInfo(info);
if (onnxTypeEnum == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
copyStringTensorToArray(jniEnv,api, (OrtAllocator*) allocatorHandle, (OrtValue*)handle, arrSize, carrier);
} else {
uint8_t* arr;
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
copyTensorToJava(jniEnv,onnxTypeEnum,arr,arrSize,dimensions,(jarray)carrier);
OrtValue* value = (OrtValue*) handle;
JavaTensorTypeShape typeShape;
OrtErrorCode code = getTensorTypeShape(jniEnv, &typeShape, api, value);
if (code == ORT_OK) {
if (typeShape.onnxTypeEnum == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
copyStringTensorToArray(jniEnv, api, (OrtAllocator*) allocatorHandle, value, typeShape.elementCount, carrier);
} else {
uint8_t* arr = NULL;
code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(value, (void**)&arr));
if (code == ORT_OK) {
copyTensorToJava(jniEnv, typeShape.onnxTypeEnum, arr, typeShape.elementCount,
typeShape.dimensions, (jarray)carrier);
}
}
}
}