[Java] JNI refactor for OrtJniUtil (#12516)

Refactoring more JNI methods in OrtJniUtil.
Make the strings const.
Removing unnecessary use of OrtAllocator.
This commit is contained in:
Adam Pocock 2022-09-08 20:04:42 -04:00 committed by GitHub
parent 60e4d012e0
commit 5d55b0730e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 680 additions and 607 deletions

View file

@ -4,6 +4,8 @@
*/
package ai.onnxruntime;
import ai.onnxruntime.TensorInfo.OnnxTensorType;
/** Describes an {@link OnnxMap} object or output node. */
public class MapInfo implements ValueInfo {
@ -42,6 +44,21 @@ public class MapInfo implements ValueInfo {
this.valueType = valueType;
}
/**
* Construct a MapInfo with the specified size, key type and value type.
*
* <p>Called from JNI.
*
* @param size The size.
* @param keyTypeInt The int representing the {@link OnnxTensorType} of the keys.
* @param valueTypeInt The int representing the {@link OnnxTensorType} of the values.
*/
MapInfo(int size, int keyTypeInt, int valueTypeInt) {
this.size = size;
this.keyType = OnnxJavaType.mapFromOnnxTensorType(OnnxTensorType.mapFromInt(keyTypeInt));
this.valueType = OnnxJavaType.mapFromOnnxTensorType(OnnxTensorType.mapFromInt(valueTypeInt));
}
@Override
public String toString() {
String initial = size == -1 ? "MapInfo(size=UNKNOWN" : "MapInfo(size=" + size;

View file

@ -90,14 +90,14 @@ public class OnnxTensor implements OnnxValue {
case BOOL:
return getBool(OnnxRuntime.ortApiHandle, nativeHandle);
case STRING:
return getString(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle);
return getString(OnnxRuntime.ortApiHandle, nativeHandle);
case UNKNOWN:
default:
throw new OrtException("Extracting the value of an invalid Tensor.");
}
} else {
Object carrier = info.makeCarrier();
getArray(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, carrier);
getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier);
if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) {
// We read the strings out from native code in a flat array and then reshape
// to the desired output shape.
@ -284,13 +284,12 @@ public class OnnxTensor implements OnnxValue {
private native long getLong(long apiHandle, long nativeHandle, int onnxType) throws OrtException;
private native String getString(long apiHandle, long nativeHandle, long allocatorHandle)
throws OrtException;
private native String getString(long apiHandle, long nativeHandle) throws OrtException;
private native boolean getBool(long apiHandle, long nativeHandle) throws OrtException;
private native void getArray(
long apiHandle, long nativeHandle, long allocatorHandle, Object carrier) throws OrtException;
private native void getArray(long apiHandle, long nativeHandle, Object carrier)
throws OrtException;
private native void close(long apiHandle, long nativeHandle);

View file

@ -4,6 +4,8 @@
*/
package ai.onnxruntime;
import ai.onnxruntime.TensorInfo.OnnxTensorType;
/** Describes an {@link OnnxSequence}, including it's element type if known. */
public class SequenceInfo implements ValueInfo {
@ -35,6 +37,23 @@ public class SequenceInfo implements ValueInfo {
this.mapInfo = null;
}
/**
* Construct a sequence of known length, with the specified type. This sequence does not contain
* maps.
*
* <p>Called from JNI.
*
* @param length The length of the sequence.
* @param sequenceTypeInt The element type int of the sequence mapped from {@link OnnxTensorType}.
*/
SequenceInfo(int length, int sequenceTypeInt) {
this.length = length;
this.sequenceType =
OnnxJavaType.mapFromOnnxTensorType(OnnxTensorType.mapFromInt(sequenceTypeInt));
this.sequenceOfMaps = false;
this.mapInfo = null;
}
/**
* Construct a sequence of known length containing maps.
*

View file

@ -120,6 +120,20 @@ public class TensorInfo implements ValueInfo {
this.onnxType = onnxType;
}
/**
* Constructs a TensorInfo with the specified shape and native type int.
*
* <p>Called from JNI.
*
* @param shape The tensor shape.
* @param typeInt The native type int.
*/
TensorInfo(long[] shape, int typeInt) {
this.shape = shape;
this.onnxType = OnnxTensorType.mapFromInt(typeInt);
this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType);
}
/**
* Get a copy of the tensor's shape.
*

File diff suppressed because it is too large Load diff

View file

@ -38,29 +38,27 @@ OrtErrorCode getTensorTypeShape(JNIEnv * jniEnv, JavaTensorTypeShape * output, c
jfloat convertHalfToFloat(uint16_t half);
jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, OrtTypeInfo * info);
jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo * info);
jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorTypeAndShapeInfo * info);
jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtMapTypeInfo * info);
jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSequenceTypeInfo * info);
jobject createEmptyMapInfo(JNIEnv *jniEnv);
jobject createEmptySequenceInfo(JNIEnv *jniEnv);
int64_t copyJavaToPrimitiveArray(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, jarray inputArray, uint8_t* outputTensor);
size_t copyJavaToPrimitiveArray(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, jarray input);
int64_t copyJavaToTensor(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, size_t tensorSize, size_t dimensionsRemaining, jarray inputArray, uint8_t* outputTensor);
size_t copyJavaToTensor(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, size_t tensorSize, size_t dimensionsRemaining, jarray input);
int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, jarray outputArray);
size_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, jarray output);
int64_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, size_t tensorSize, size_t dimensionsRemaining, jarray outputArray);
size_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, size_t tensorSize, size_t dimensionsRemaining, jarray output);
jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor);
jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor);
OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor, size_t length, jobjectArray outputArray);
void copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor, size_t length, jobjectArray outputArray);
jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor);
jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor);
jlongArray createLongArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor);
@ -74,6 +72,8 @@ jobject createJavaSequenceFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAlloca
jobject createJavaMapFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* map);
jobject createMapInfoFromValue(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator * allocator, const OrtValue * map);
jobject convertOrtValueToONNXValue(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* onnxValue);
jint throwOrtException(JNIEnv *env, int messageId, const char *message);

View file

@ -22,7 +22,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxMap_getStringKeys(JNIEnv*
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, 0, allocator, &keys));
if (code == ORT_OK) {
// Convert to Java String array
jobjectArray output = createStringArrayFromTensor(jniEnv, api, allocator, keys);
jobjectArray output = createStringArrayFromTensor(jniEnv, api, keys);
api->ReleaseValue(keys);
@ -72,7 +72,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxMap_getStringValues(JNIEn
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, 1, allocator, &values));
if (code == ORT_OK) {
// Convert to Java String array
jobjectArray output = createStringArrayFromTensor(jniEnv, api, allocator, values);
jobjectArray output = createStringArrayFromTensor(jniEnv, api, values);
api->ReleaseValue(values);

View file

@ -28,7 +28,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStringKeys(JN
if (code == ORT_OK) {
// Convert to Java String array
output = createStringArrayFromTensor(jniEnv, api, allocator, keys);
output = createStringArrayFromTensor(jniEnv, api, keys);
// Release if valid
api->ReleaseValue(element);
}
@ -94,7 +94,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStringValues(
if (code == ORT_OK) {
// Convert to Java String array
output = createStringArrayFromTensor(jniEnv, api, allocator, values);
output = createStringArrayFromTensor(jniEnv, api, values);
// Release if valid
api->ReleaseValue(element);
}
@ -230,7 +230,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStrings(JNIEn
OrtValue* element;
code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element));
if (code == ORT_OK) {
jobject str = createStringFromStringTensor(jniEnv, api, allocator, element);
jobject str = createStringFromStringTensor(jniEnv, api, element);
if (str == NULL) {
api->ReleaseValue(element);
// bail out as exception has been thrown

View file

@ -44,8 +44,8 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor
// Check if we're copying a scalar or not
if (shapeLen == 0) {
// Scalars are passed in as a single element array
size_t copied = copyJavaToPrimitiveArray(jniEnv, onnxType, tensorData, dataObj);
failed = copied == 0 ? 1 : failed;
int64_t copied = copyJavaToPrimitiveArray(jniEnv, onnxType, dataObj, tensorData);
failed = copied == -1 ? 1 : failed;
} else {
// Extract the tensor shape information
JavaTensorTypeShape typeShape;
@ -53,9 +53,9 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor
if (code == ORT_OK) {
// Copy the java array into the tensor
size_t copied = copyJavaToTensor(jniEnv, onnxType, tensorData, typeShape.elementCount,
typeShape.dimensions, dataObj);
failed = copied == 0 ? 1 : failed;
int64_t copied = copyJavaToTensor(jniEnv, onnxType, typeShape.elementCount,
typeShape.dimensions, dataObj, tensorData);
failed = copied == -1 ? 1 : failed;
} else {
failed = 1;
}
@ -367,12 +367,11 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_getLong
* Signature: (JJ)Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OnnxTensor_getString
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle) {
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
// Extract a String array - if this becomes a performance issue we'll refactor later.
jobjectArray outputArray = createStringArrayFromTensor(jniEnv, api, (OrtAllocator*) allocatorHandle,
(OrtValue*) handle);
jobjectArray outputArray = createStringArrayFromTensor(jniEnv, api, (OrtValue*) handle);
if (outputArray != NULL) {
// Get reference to the string
jobject output = (*jniEnv)->GetObjectArrayElement(jniEnv, outputArray, 0);
@ -410,7 +409,7 @@ JNIEXPORT jboolean JNICALL Java_ai_onnxruntime_OnnxTensor_getBool
* Signature: (JJLjava/lang/Object;)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle, jobject carrier) {
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jobject carrier) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtValue* value = (OrtValue*) handle;
@ -418,7 +417,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray
OrtErrorCode code = getTensorTypeShape(jniEnv, &typeShape, api, value);
if (code == ORT_OK) {
if (typeShape.onnxTypeEnum == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
copyStringTensorToArray(jniEnv, api, (OrtAllocator*) allocatorHandle, value, typeShape.elementCount, carrier);
copyStringTensorToArray(jniEnv, api, value, typeShape.elementCount, carrier);
} else {
uint8_t* arr = NULL;
code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(value, (void**)&arr));