mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[Java] JNI refactor for ONNX Tensor (#12281)
Working on JNI refactor for OnnxTensor. Simplifying the error handling logic in createTensor. Collapsing casting branches and migrating to ONNX element type enum. Disable cpplint for JNI C files.
This commit is contained in:
parent
8c5c283471
commit
8a86b346a5
4 changed files with 272 additions and 199 deletions
2
.github/workflows/lint.yml
vendored
2
.github/workflows/lint.yml
vendored
|
|
@ -86,7 +86,7 @@ jobs:
|
|||
github_token: ${{ secrets.github_token }}
|
||||
reporter: github-pr-check
|
||||
level: warning
|
||||
flags: --linelength=120
|
||||
flags: --linelength=120 --exclude=java/src/main/native/*.c
|
||||
filter: "-runtime/references"
|
||||
|
||||
lint-js:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2019, 2020 Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2022 Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
#include <jni.h>
|
||||
|
|
@ -183,6 +183,31 @@ size_t onnxTypeSize(ONNXTensorElementDataType type) {
|
|||
}
|
||||
}
|
||||
|
||||
OrtErrorCode getTensorTypeShape(JNIEnv * jniEnv, JavaTensorTypeShape* output, const OrtApi * api, const OrtValue * value) {
|
||||
OrtTensorTypeAndShapeInfo* info;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorTypeAndShape(value, &info));
|
||||
if (code != ORT_OK) {
|
||||
return code;
|
||||
}
|
||||
code = checkOrtStatus(jniEnv, api, api->GetDimensionsCount(info, &output->dimensions));
|
||||
if (code != ORT_OK) {
|
||||
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||
return code;
|
||||
}
|
||||
code = checkOrtStatus(jniEnv, api, api->GetTensorShapeElementCount(info, &output->elementCount));
|
||||
if (code != ORT_OK) {
|
||||
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||
return code;
|
||||
}
|
||||
code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(info, &output->onnxTypeEnum));
|
||||
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||
if (code != ORT_OK) {
|
||||
return code;
|
||||
}
|
||||
|
||||
return ORT_OK;
|
||||
}
|
||||
|
||||
typedef union FP32 {
|
||||
int intVal;
|
||||
float floatVal;
|
||||
|
|
@ -1038,6 +1063,7 @@ OrtErrorCode checkOrtStatus(JNIEnv *jniEnv, const OrtApi * api, OrtStatus * stat
|
|||
size_t len = strlen(message)+1;
|
||||
char* copy = malloc(sizeof(char)*len);
|
||||
if (copy == NULL) {
|
||||
api->ReleaseStatus(status);
|
||||
throwOrtException(jniEnv, 1, "Not enough memory");
|
||||
return ORT_FAIL;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
#include <jni.h>
|
||||
|
|
@ -11,9 +11,14 @@
|
|||
extern "C" {
|
||||
#endif
|
||||
|
||||
jsize safecast_size_t_to_jsize(size_t v);
|
||||
|
||||
jsize safecast_int64_to_jsize(int64_t v);
|
||||
typedef struct {
|
||||
/* The number of dimensions in the Tensor */
|
||||
size_t dimensions;
|
||||
/* The number of elements in the Tensor */
|
||||
size_t elementCount;
|
||||
/* The type of the Tensor */
|
||||
ONNXTensorElementDataType onnxTypeEnum;
|
||||
} JavaTensorTypeShape;
|
||||
|
||||
jint JNI_OnLoad(JavaVM *vm, void *reserved);
|
||||
|
||||
|
|
@ -29,6 +34,8 @@ ONNXTensorElementDataType convertToONNXDataFormat(jint type);
|
|||
|
||||
size_t onnxTypeSize(ONNXTensorElementDataType type);
|
||||
|
||||
OrtErrorCode getTensorTypeShape(JNIEnv * jniEnv, JavaTensorTypeShape * output, const OrtApi * api, const OrtValue * value);
|
||||
|
||||
jfloat convertHalfToFloat(uint16_t half);
|
||||
|
||||
jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, OrtTypeInfo * info);
|
||||
|
|
@ -75,6 +82,10 @@ jint convertErrorCode(OrtErrorCode code);
|
|||
|
||||
OrtErrorCode checkOrtStatus(JNIEnv * env, const OrtApi * api, OrtStatus * status);
|
||||
|
||||
jsize safecast_size_t_to_jsize(size_t v);
|
||||
|
||||
jsize safecast_int64_to_jsize(int64_t v);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
#include <jni.h>
|
||||
|
|
@ -14,45 +14,60 @@
|
|||
* Signature: (JJLjava/lang/Object;[JI)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor
|
||||
(JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject dataObj, jlongArray shape, jint onnxTypeJava) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
(JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject dataObj,
|
||||
jlongArray shape, jint onnxTypeJava) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
|
||||
// Convert type to ONNX C enum
|
||||
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava);
|
||||
|
||||
// Extract the shape information
|
||||
jboolean mkCopy;
|
||||
jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv,shape,&mkCopy);
|
||||
jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv,shape);
|
||||
jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, shape, NULL);
|
||||
jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, shape);
|
||||
|
||||
// Create the OrtValue
|
||||
OrtValue* ortValue;
|
||||
checkOrtStatus(jniEnv, api, api->CreateTensorAsOrtValue(allocator,(int64_t*)shapeArr,shapeLen,onnxType,&ortValue));
|
||||
(*jniEnv)->ReleaseLongArrayElements(jniEnv,shape,shapeArr,JNI_ABORT);
|
||||
OrtValue* ortValue = NULL;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api,
|
||||
api->CreateTensorAsOrtValue(
|
||||
allocator, (int64_t*)shapeArr, shapeLen, onnxType, &ortValue
|
||||
)
|
||||
);
|
||||
(*jniEnv)->ReleaseLongArrayElements(jniEnv, shape, shapeArr, JNI_ABORT);
|
||||
|
||||
// Get a reference to the OrtValue's data
|
||||
uint8_t* tensorData;
|
||||
checkOrtStatus(jniEnv, api, api->GetTensorMutableData(ortValue, (void**) &tensorData));
|
||||
int failed = 0;
|
||||
if (code == ORT_OK) {
|
||||
// Get a reference to the OrtValue's data
|
||||
uint8_t* tensorData = NULL;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(ortValue, (void**)&tensorData));
|
||||
if (code == ORT_OK) {
|
||||
// Check if we're copying a scalar or not
|
||||
if (shapeLen == 0) {
|
||||
// Scalars are passed in as a single element array
|
||||
size_t copied = copyJavaToPrimitiveArray(jniEnv, onnxType, tensorData, dataObj);
|
||||
failed = copied == 0 ? 1 : failed;
|
||||
} else {
|
||||
// Extract the tensor shape information
|
||||
JavaTensorTypeShape typeShape;
|
||||
code = getTensorTypeShape(jniEnv, &typeShape, api, ortValue);
|
||||
|
||||
// Check if we're copying a scalar or not
|
||||
if (shapeLen == 0) {
|
||||
// Scalars are passed in as a single element array
|
||||
copyJavaToPrimitiveArray(jniEnv, onnxType, tensorData, dataObj);
|
||||
} else {
|
||||
// Extract the tensor shape information
|
||||
OrtTensorTypeAndShapeInfo* info;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(ortValue, &info));
|
||||
size_t dimensions;
|
||||
checkOrtStatus(jniEnv,api,api->GetDimensionsCount(info,&dimensions));
|
||||
size_t arrSize;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(info,&arrSize));
|
||||
ONNXTensorElementDataType onnxTypeEnum;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorElementType(info,&onnxTypeEnum));
|
||||
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||
if (code == ORT_OK) {
|
||||
// Copy the java array into the tensor
|
||||
size_t copied = copyJavaToTensor(jniEnv, onnxType, tensorData, typeShape.elementCount,
|
||||
typeShape.dimensions, dataObj);
|
||||
failed = copied == 0 ? 1 : failed;
|
||||
} else {
|
||||
failed = 1;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
failed = 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Copy the java array into the tensor
|
||||
copyJavaToTensor(jniEnv, onnxType, tensorData, arrSize, dimensions, dataObj);
|
||||
if (failed) {
|
||||
api->ReleaseValue(ortValue);
|
||||
ortValue = NULL;
|
||||
}
|
||||
|
||||
// Return the pointer to the OrtValue
|
||||
|
|
@ -65,30 +80,34 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor
|
|||
* Signature: (JJLjava/nio/Buffer;IJ[JI)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensorFromBuffer
|
||||
(JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject buffer, jint bufferPos, jlong bufferSize, jlongArray shape, jint onnxTypeJava) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
(JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject buffer, jint bufferPos, jlong bufferSize,
|
||||
jlongArray shape, jint onnxTypeJava) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
|
||||
const OrtMemoryInfo* allocatorInfo;
|
||||
checkOrtStatus(jniEnv, api, api->AllocatorGetInfo(allocator,&allocatorInfo));
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->AllocatorGetInfo(allocator, &allocatorInfo));
|
||||
if (code != ORT_OK) {
|
||||
return (jlong) NULL;
|
||||
}
|
||||
|
||||
// Convert type to ONNX C enum
|
||||
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava);
|
||||
|
||||
// Extract the buffer
|
||||
char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv,buffer);
|
||||
char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, buffer);
|
||||
// Increment by bufferPos bytes
|
||||
bufferArr = bufferArr + bufferPos;
|
||||
|
||||
// Extract the shape information
|
||||
jboolean mkCopy;
|
||||
jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv,shape,&mkCopy);
|
||||
jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv,shape);
|
||||
jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, shape, NULL);
|
||||
jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, shape);
|
||||
|
||||
// Create the OrtValue
|
||||
OrtValue* ortValue;
|
||||
checkOrtStatus(jniEnv, api, api->CreateTensorWithDataAsOrtValue(allocatorInfo,bufferArr,bufferSize,(int64_t*)shapeArr,shapeLen,onnxType,&ortValue));
|
||||
(*jniEnv)->ReleaseLongArrayElements(jniEnv,shape,shapeArr,JNI_ABORT);
|
||||
OrtValue* ortValue = NULL;
|
||||
checkOrtStatus(jniEnv, api, api->CreateTensorWithDataAsOrtValue(allocatorInfo, bufferArr, bufferSize,
|
||||
(int64_t*)shapeArr, shapeLen, onnxType, &ortValue));
|
||||
(*jniEnv)->ReleaseLongArrayElements(jniEnv, shape, shapeArr, JNI_ABORT);
|
||||
|
||||
// Return the pointer to the OrtValue
|
||||
return (jlong) ortValue;
|
||||
|
|
@ -101,32 +120,34 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensorFromBuffer
|
|||
*/
|
||||
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createString
|
||||
(JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jstring input) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
|
||||
|
||||
// Extract the shape information
|
||||
int64_t* shapeArr;
|
||||
checkOrtStatus(jniEnv,api,api->AllocatorAlloc(allocator,sizeof(int64_t),(void**)&shapeArr));
|
||||
shapeArr[0] = 1;
|
||||
|
||||
// Create the OrtValue
|
||||
OrtValue* ortValue;
|
||||
checkOrtStatus(jniEnv, api, api->CreateTensorAsOrtValue(allocator,shapeArr,0,ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,&ortValue));
|
||||
|
||||
// Release the shape
|
||||
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator,shapeArr));
|
||||
// Create the OrtValue
|
||||
int64_t shapeArr = 1;
|
||||
OrtValue* ortValue = NULL;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->CreateTensorAsOrtValue(allocator, &shapeArr, 0,
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, &ortValue));
|
||||
|
||||
if (code == ORT_OK) {
|
||||
// Create the buffer for the Java string
|
||||
const char* stringBuffer = (*jniEnv)->GetStringUTFChars(jniEnv,input,NULL);
|
||||
|
||||
// Assign the strings into the Tensor
|
||||
checkOrtStatus(jniEnv, api, api->FillStringTensor(ortValue,&stringBuffer,1));
|
||||
code = checkOrtStatus(jniEnv, api, api->FillStringTensor(ortValue, &stringBuffer, 1));
|
||||
|
||||
// Release the Java string
|
||||
(*jniEnv)->ReleaseStringUTFChars(jniEnv,input,stringBuffer);
|
||||
// Release the Java string whether the call succeeded or failed
|
||||
(*jniEnv)->ReleaseStringUTFChars(jniEnv, input, stringBuffer);
|
||||
|
||||
return (jlong) ortValue;
|
||||
// Assignment failed, return null
|
||||
if (code != ORT_OK) {
|
||||
api->ReleaseValue(ortValue);
|
||||
return (jlong) NULL;
|
||||
}
|
||||
}
|
||||
|
||||
return (jlong) ortValue;
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
@ -136,47 +157,55 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createString
|
|||
*/
|
||||
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createStringTensor
|
||||
(JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobjectArray stringArr, jlongArray shape) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
|
||||
|
||||
// Extract the shape information
|
||||
jboolean mkCopy;
|
||||
jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv,shape,&mkCopy);
|
||||
jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv,shape);
|
||||
jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, shape, NULL);
|
||||
jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, shape);
|
||||
|
||||
// Array length
|
||||
jsize length = (*jniEnv)->GetArrayLength(jniEnv, stringArr);
|
||||
|
||||
// Create the OrtValue
|
||||
OrtValue* ortValue;
|
||||
checkOrtStatus(jniEnv, api, api->CreateTensorAsOrtValue(allocator,(int64_t*)shapeArr,shapeLen,ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,&ortValue));
|
||||
(*jniEnv)->ReleaseLongArrayElements(jniEnv,shape,shapeArr,JNI_ABORT);
|
||||
OrtValue* ortValue = NULL;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->CreateTensorAsOrtValue(allocator, (int64_t*)shapeArr, shapeLen,
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, &ortValue));
|
||||
(*jniEnv)->ReleaseLongArrayElements(jniEnv, shape, shapeArr, JNI_ABORT);
|
||||
|
||||
// Create the buffers for the Java strings
|
||||
const char** strings;
|
||||
checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator,sizeof(char*)*length,(void**)&strings));
|
||||
jobject* javaStrings;
|
||||
checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator,sizeof(jobject)*length,(void**)&javaStrings));
|
||||
if (code == ORT_OK) {
|
||||
// Create the buffers for the Java strings
|
||||
const char** strings = NULL;
|
||||
code = checkOrtStatus(jniEnv, api, api->AllocatorAlloc(allocator, sizeof(char*) * length, (void**)&strings));
|
||||
|
||||
// Copy the java strings into the buffers
|
||||
for (jsize i = 0; i < length; i++) {
|
||||
javaStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv,stringArr,i);
|
||||
strings[i] = (*jniEnv)->GetStringUTFChars(jniEnv,javaStrings[i],NULL);
|
||||
if (code == ORT_OK) {
|
||||
// Copy the java strings into the buffers
|
||||
for (jsize i = 0; i < length; i++) {
|
||||
jobject javaString = (*jniEnv)->GetObjectArrayElement(jniEnv, stringArr, i);
|
||||
strings[i] = (*jniEnv)->GetStringUTFChars(jniEnv, javaString, NULL);
|
||||
}
|
||||
|
||||
// Assign the strings into the Tensor
|
||||
code = checkOrtStatus(jniEnv, api, api->FillStringTensor(ortValue, strings, length));
|
||||
|
||||
// Release the Java strings
|
||||
for (int i = 0; i < length; i++) {
|
||||
jobject javaString = (*jniEnv)->GetObjectArrayElement(jniEnv, stringArr, i);
|
||||
(*jniEnv)->ReleaseStringUTFChars(jniEnv, javaString, strings[i]);
|
||||
}
|
||||
|
||||
// Release the buffers
|
||||
OrtErrorCode freeCode = checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, (void*)strings));
|
||||
|
||||
// Assignment failed, return null
|
||||
if ((code != ORT_OK) || (freeCode != ORT_OK)) {
|
||||
api->ReleaseValue(ortValue);
|
||||
return (jlong) NULL;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Assign the strings into the Tensor
|
||||
checkOrtStatus(jniEnv, api, api->FillStringTensor(ortValue,strings,length));
|
||||
|
||||
// Release the Java strings
|
||||
for (int i = 0; i < length; i++) {
|
||||
(*jniEnv)->ReleaseStringUTFChars(jniEnv,javaStrings[i],strings[i]);
|
||||
}
|
||||
|
||||
// Release the buffers
|
||||
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, (void*)strings));
|
||||
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, javaStrings));
|
||||
|
||||
return (jlong) ortValue;
|
||||
}
|
||||
|
||||
|
|
@ -187,23 +216,24 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createStringTensor
|
|||
*/
|
||||
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxTensor_getBuffer
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtTensorTypeAndShapeInfo* info;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape((OrtValue*) handle, &info));
|
||||
size_t arrSize;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(info,&arrSize));
|
||||
ONNXTensorElementDataType onnxTypeEnum;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorElementType(info,&onnxTypeEnum));
|
||||
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||
OrtValue* ortValue = (OrtValue *) handle;
|
||||
JavaTensorTypeShape typeShape;
|
||||
OrtErrorCode code = getTensorTypeShape(jniEnv, &typeShape, api, ortValue);
|
||||
|
||||
size_t typeSize = onnxTypeSize(onnxTypeEnum);
|
||||
size_t sizeBytes = arrSize*typeSize;
|
||||
if (code == ORT_OK) {
|
||||
size_t typeSize = onnxTypeSize(typeShape.onnxTypeEnum);
|
||||
size_t sizeBytes = typeShape.elementCount * typeSize;
|
||||
|
||||
uint8_t* arr;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
|
||||
uint8_t* arr = NULL;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
|
||||
|
||||
return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, sizeBytes);
|
||||
if (code == ORT_OK) {
|
||||
return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, (jlong)sizeBytes);
|
||||
}
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
@ -212,21 +242,25 @@ JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxTensor_getBuffer
|
|||
* Signature: (JI)F
|
||||
*/
|
||||
JNIEXPORT jfloat JNICALL Java_ai_onnxruntime_OnnxTensor_getFloat
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxType) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
if (onnxType == 9) {
|
||||
uint16_t* arr;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
|
||||
jfloat floatVal = convertHalfToFloat(*arr);
|
||||
return floatVal;
|
||||
} else if (onnxType == 10) {
|
||||
jfloat* arr;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
|
||||
return *arr;
|
||||
} else {
|
||||
return NAN;
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxTypeInt) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeInt);
|
||||
if (onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
|
||||
uint16_t* arr = NULL;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
|
||||
if (code == ORT_OK) {
|
||||
jfloat floatVal = convertHalfToFloat(*arr);
|
||||
return floatVal;
|
||||
}
|
||||
} else if (onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
|
||||
jfloat* arr = NULL;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
|
||||
if (code == ORT_OK) {
|
||||
return *arr;
|
||||
}
|
||||
}
|
||||
return NAN;
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
@ -236,11 +270,15 @@ JNIEXPORT jfloat JNICALL Java_ai_onnxruntime_OnnxTensor_getFloat
|
|||
*/
|
||||
JNIEXPORT jdouble JNICALL Java_ai_onnxruntime_OnnxTensor_getDouble
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
jdouble* arr;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
|
||||
return *arr;
|
||||
jdouble* arr = NULL;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
|
||||
if (code == ORT_OK) {
|
||||
return *arr;
|
||||
} else {
|
||||
return NAN;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
@ -249,20 +287,18 @@ JNIEXPORT jdouble JNICALL Java_ai_onnxruntime_OnnxTensor_getDouble
|
|||
* Signature: (JI)B
|
||||
*/
|
||||
JNIEXPORT jbyte JNICALL Java_ai_onnxruntime_OnnxTensor_getByte
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxType) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxTypeInt) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
if (onnxType == 1) {
|
||||
uint8_t* arr;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
|
||||
return (jbyte) *arr;
|
||||
} else if (onnxType == 2) {
|
||||
int8_t* arr;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
|
||||
return (jbyte) *arr;
|
||||
} else {
|
||||
return (jbyte) 0;
|
||||
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeInt);
|
||||
if ((onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) || (onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)) {
|
||||
uint8_t* arr = NULL;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
|
||||
if (code == ORT_OK) {
|
||||
return (jbyte) *arr;
|
||||
}
|
||||
}
|
||||
return (jbyte) 0;
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
@ -271,20 +307,18 @@ JNIEXPORT jbyte JNICALL Java_ai_onnxruntime_OnnxTensor_getByte
|
|||
* Signature: (JI)S
|
||||
*/
|
||||
JNIEXPORT jshort JNICALL Java_ai_onnxruntime_OnnxTensor_getShort
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxType) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxTypeInt) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
if (onnxType == 3) {
|
||||
uint16_t* arr;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
|
||||
return (jshort) *arr;
|
||||
} else if (onnxType == 4) {
|
||||
int16_t* arr;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
|
||||
return (jshort) *arr;
|
||||
} else {
|
||||
return (jshort) 0;
|
||||
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeInt);
|
||||
if ((onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16) || (onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)) {
|
||||
uint16_t* arr = NULL;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
|
||||
if (code == ORT_OK) {
|
||||
return (jshort) *arr;
|
||||
}
|
||||
}
|
||||
return (jshort) 0;
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
@ -293,20 +327,18 @@ JNIEXPORT jshort JNICALL Java_ai_onnxruntime_OnnxTensor_getShort
|
|||
* Signature: (JI)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ai_onnxruntime_OnnxTensor_getInt
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxType) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxTypeInt) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
if (onnxType == 5) {
|
||||
uint32_t* arr;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
|
||||
return (jint) *arr;
|
||||
} else if (onnxType == 6) {
|
||||
int32_t* arr;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
|
||||
return (jint) *arr;
|
||||
} else {
|
||||
return (jint) 0;
|
||||
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeInt);
|
||||
if ((onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32) || (onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)) {
|
||||
uint32_t* arr = NULL;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
|
||||
if (code == ORT_OK) {
|
||||
return (jint) *arr;
|
||||
}
|
||||
}
|
||||
return (jint) 0;
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
@ -315,20 +347,18 @@ JNIEXPORT jint JNICALL Java_ai_onnxruntime_OnnxTensor_getInt
|
|||
* Signature: (JI)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_getLong
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxType) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint onnxTypeInt) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
if (onnxType == 7) {
|
||||
uint64_t* arr;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
|
||||
return (jlong) *arr;
|
||||
} else if (onnxType == 8) {
|
||||
int64_t* arr;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
|
||||
return (jlong) *arr;
|
||||
} else {
|
||||
return (jlong) 0;
|
||||
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeInt);
|
||||
if ((onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64) || (onnxType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) {
|
||||
uint64_t* arr = NULL;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
|
||||
if (code == ORT_OK) {
|
||||
return (jlong) *arr;
|
||||
}
|
||||
}
|
||||
return (jlong) 0;
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
@ -338,18 +368,22 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_getLong
|
|||
*/
|
||||
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OnnxTensor_getString
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
// Extract a String array - if this becomes a performance issue we'll refactor later.
|
||||
jobjectArray outputArray = createStringArrayFromTensor(jniEnv,api, (OrtAllocator*) allocatorHandle, (OrtValue*) handle);
|
||||
jobjectArray outputArray = createStringArrayFromTensor(jniEnv, api, (OrtAllocator*) allocatorHandle,
|
||||
(OrtValue*) handle);
|
||||
if (outputArray != NULL) {
|
||||
// Get reference to the string
|
||||
jobject output = (*jniEnv)->GetObjectArrayElement(jniEnv, outputArray, 0);
|
||||
|
||||
// Get reference to the string
|
||||
jobject output = (*jniEnv)->GetObjectArrayElement(jniEnv, outputArray, 0);
|
||||
// Free array
|
||||
(*jniEnv)->DeleteLocalRef(jniEnv, outputArray);
|
||||
|
||||
// Free array
|
||||
(*jniEnv)->DeleteLocalRef(jniEnv,outputArray);
|
||||
|
||||
return output;
|
||||
return output;
|
||||
} else {
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
@ -359,11 +393,15 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OnnxTensor_getString
|
|||
*/
|
||||
JNIEXPORT jboolean JNICALL Java_ai_onnxruntime_OnnxTensor_getBool
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
jboolean* arr;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
|
||||
return *arr;
|
||||
jboolean* arr = NULL;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
|
||||
if (code == ORT_OK) {
|
||||
return *arr;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
@ -373,24 +411,22 @@ JNIEXPORT jboolean JNICALL Java_ai_onnxruntime_OnnxTensor_getBool
|
|||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle, jobject carrier) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtTensorTypeAndShapeInfo* info;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape((OrtValue*) handle, &info));
|
||||
size_t dimensions;
|
||||
checkOrtStatus(jniEnv,api,api->GetDimensionsCount(info,&dimensions));
|
||||
size_t arrSize;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(info,&arrSize));
|
||||
ONNXTensorElementDataType onnxTypeEnum;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorElementType(info,&onnxTypeEnum));
|
||||
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||
|
||||
if (onnxTypeEnum == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
|
||||
copyStringTensorToArray(jniEnv,api, (OrtAllocator*) allocatorHandle, (OrtValue*)handle, arrSize, carrier);
|
||||
} else {
|
||||
uint8_t* arr;
|
||||
checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)handle,(void**)&arr));
|
||||
copyTensorToJava(jniEnv,onnxTypeEnum,arr,arrSize,dimensions,(jarray)carrier);
|
||||
OrtValue* value = (OrtValue*) handle;
|
||||
JavaTensorTypeShape typeShape;
|
||||
OrtErrorCode code = getTensorTypeShape(jniEnv, &typeShape, api, value);
|
||||
if (code == ORT_OK) {
|
||||
if (typeShape.onnxTypeEnum == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
|
||||
copyStringTensorToArray(jniEnv, api, (OrtAllocator*) allocatorHandle, value, typeShape.elementCount, carrier);
|
||||
} else {
|
||||
uint8_t* arr = NULL;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(value, (void**)&arr));
|
||||
if (code == ORT_OK) {
|
||||
copyTensorToJava(jniEnv, typeShape.onnxTypeEnum, arr, typeShape.elementCount,
|
||||
typeShape.dimensions, (jarray)carrier);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue