mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[Java] JNI refactor for OrtJniUtil (#12516)
Refactoring more JNI methods in OrtJniUtil. Make the strings const. Removing unnecessary use of OrtAllocator.
This commit is contained in:
parent
60e4d012e0
commit
5d55b0730e
9 changed files with 680 additions and 607 deletions
|
|
@ -4,6 +4,8 @@
|
|||
*/
|
||||
package ai.onnxruntime;
|
||||
|
||||
import ai.onnxruntime.TensorInfo.OnnxTensorType;
|
||||
|
||||
/** Describes an {@link OnnxMap} object or output node. */
|
||||
public class MapInfo implements ValueInfo {
|
||||
|
||||
|
|
@ -42,6 +44,21 @@ public class MapInfo implements ValueInfo {
|
|||
this.valueType = valueType;
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct a MapInfo with the specified size, key type and value type.
|
||||
*
|
||||
* <p>Called from JNI.
|
||||
*
|
||||
* @param size The size.
|
||||
* @param keyTypeInt The int representing the {@link OnnxTensorType} of the keys.
|
||||
* @param valueTypeInt The int representing the {@link OnnxTensorType} of the values.
|
||||
*/
|
||||
MapInfo(int size, int keyTypeInt, int valueTypeInt) {
|
||||
this.size = size;
|
||||
this.keyType = OnnxJavaType.mapFromOnnxTensorType(OnnxTensorType.mapFromInt(keyTypeInt));
|
||||
this.valueType = OnnxJavaType.mapFromOnnxTensorType(OnnxTensorType.mapFromInt(valueTypeInt));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
String initial = size == -1 ? "MapInfo(size=UNKNOWN" : "MapInfo(size=" + size;
|
||||
|
|
|
|||
|
|
@ -90,14 +90,14 @@ public class OnnxTensor implements OnnxValue {
|
|||
case BOOL:
|
||||
return getBool(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
case STRING:
|
||||
return getString(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle);
|
||||
return getString(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
case UNKNOWN:
|
||||
default:
|
||||
throw new OrtException("Extracting the value of an invalid Tensor.");
|
||||
}
|
||||
} else {
|
||||
Object carrier = info.makeCarrier();
|
||||
getArray(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, carrier);
|
||||
getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier);
|
||||
if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) {
|
||||
// We read the strings out from native code in a flat array and then reshape
|
||||
// to the desired output shape.
|
||||
|
|
@ -284,13 +284,12 @@ public class OnnxTensor implements OnnxValue {
|
|||
|
||||
private native long getLong(long apiHandle, long nativeHandle, int onnxType) throws OrtException;
|
||||
|
||||
private native String getString(long apiHandle, long nativeHandle, long allocatorHandle)
|
||||
throws OrtException;
|
||||
private native String getString(long apiHandle, long nativeHandle) throws OrtException;
|
||||
|
||||
private native boolean getBool(long apiHandle, long nativeHandle) throws OrtException;
|
||||
|
||||
private native void getArray(
|
||||
long apiHandle, long nativeHandle, long allocatorHandle, Object carrier) throws OrtException;
|
||||
private native void getArray(long apiHandle, long nativeHandle, Object carrier)
|
||||
throws OrtException;
|
||||
|
||||
private native void close(long apiHandle, long nativeHandle);
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
*/
|
||||
package ai.onnxruntime;
|
||||
|
||||
import ai.onnxruntime.TensorInfo.OnnxTensorType;
|
||||
|
||||
/** Describes an {@link OnnxSequence}, including it's element type if known. */
|
||||
public class SequenceInfo implements ValueInfo {
|
||||
|
||||
|
|
@ -35,6 +37,23 @@ public class SequenceInfo implements ValueInfo {
|
|||
this.mapInfo = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct a sequence of known length, with the specified type. This sequence does not contain
|
||||
* maps.
|
||||
*
|
||||
* <p>Called from JNI.
|
||||
*
|
||||
* @param length The length of the sequence.
|
||||
* @param sequenceTypeInt The element type int of the sequence mapped from {@link OnnxTensorType}.
|
||||
*/
|
||||
SequenceInfo(int length, int sequenceTypeInt) {
|
||||
this.length = length;
|
||||
this.sequenceType =
|
||||
OnnxJavaType.mapFromOnnxTensorType(OnnxTensorType.mapFromInt(sequenceTypeInt));
|
||||
this.sequenceOfMaps = false;
|
||||
this.mapInfo = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct a sequence of known length containing maps.
|
||||
*
|
||||
|
|
|
|||
|
|
@ -120,6 +120,20 @@ public class TensorInfo implements ValueInfo {
|
|||
this.onnxType = onnxType;
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a TensorInfo with the specified shape and native type int.
|
||||
*
|
||||
* <p>Called from JNI.
|
||||
*
|
||||
* @param shape The tensor shape.
|
||||
* @param typeInt The native type int.
|
||||
*/
|
||||
TensorInfo(long[] shape, int typeInt) {
|
||||
this.shape = shape;
|
||||
this.onnxType = OnnxTensorType.mapFromInt(typeInt);
|
||||
this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a copy of the tensor's shape.
|
||||
*
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -38,29 +38,27 @@ OrtErrorCode getTensorTypeShape(JNIEnv * jniEnv, JavaTensorTypeShape * output, c
|
|||
|
||||
jfloat convertHalfToFloat(uint16_t half);
|
||||
|
||||
jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, OrtTypeInfo * info);
|
||||
jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo * info);
|
||||
|
||||
jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorTypeAndShapeInfo * info);
|
||||
|
||||
jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtMapTypeInfo * info);
|
||||
|
||||
jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSequenceTypeInfo * info);
|
||||
|
||||
jobject createEmptyMapInfo(JNIEnv *jniEnv);
|
||||
jobject createEmptySequenceInfo(JNIEnv *jniEnv);
|
||||
int64_t copyJavaToPrimitiveArray(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, jarray inputArray, uint8_t* outputTensor);
|
||||
|
||||
size_t copyJavaToPrimitiveArray(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, jarray input);
|
||||
int64_t copyJavaToTensor(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, size_t tensorSize, size_t dimensionsRemaining, jarray inputArray, uint8_t* outputTensor);
|
||||
|
||||
size_t copyJavaToTensor(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, size_t tensorSize, size_t dimensionsRemaining, jarray input);
|
||||
int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, jarray outputArray);
|
||||
|
||||
size_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, jarray output);
|
||||
int64_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, size_t tensorSize, size_t dimensionsRemaining, jarray outputArray);
|
||||
|
||||
size_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, size_t tensorSize, size_t dimensionsRemaining, jarray output);
|
||||
jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor);
|
||||
|
||||
jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor);
|
||||
OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor, size_t length, jobjectArray outputArray);
|
||||
|
||||
void copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor, size_t length, jobjectArray outputArray);
|
||||
|
||||
jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor);
|
||||
jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor);
|
||||
|
||||
jlongArray createLongArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor);
|
||||
|
||||
|
|
@ -74,6 +72,8 @@ jobject createJavaSequenceFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAlloca
|
|||
|
||||
jobject createJavaMapFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* map);
|
||||
|
||||
jobject createMapInfoFromValue(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator * allocator, const OrtValue * map);
|
||||
|
||||
jobject convertOrtValueToONNXValue(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* onnxValue);
|
||||
|
||||
jint throwOrtException(JNIEnv *env, int messageId, const char *message);
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxMap_getStringKeys(JNIEnv*
|
|||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, 0, allocator, &keys));
|
||||
if (code == ORT_OK) {
|
||||
// Convert to Java String array
|
||||
jobjectArray output = createStringArrayFromTensor(jniEnv, api, allocator, keys);
|
||||
jobjectArray output = createStringArrayFromTensor(jniEnv, api, keys);
|
||||
|
||||
api->ReleaseValue(keys);
|
||||
|
||||
|
|
@ -72,7 +72,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxMap_getStringValues(JNIEn
|
|||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, 1, allocator, &values));
|
||||
if (code == ORT_OK) {
|
||||
// Convert to Java String array
|
||||
jobjectArray output = createStringArrayFromTensor(jniEnv, api, allocator, values);
|
||||
jobjectArray output = createStringArrayFromTensor(jniEnv, api, values);
|
||||
|
||||
api->ReleaseValue(values);
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStringKeys(JN
|
|||
|
||||
if (code == ORT_OK) {
|
||||
// Convert to Java String array
|
||||
output = createStringArrayFromTensor(jniEnv, api, allocator, keys);
|
||||
output = createStringArrayFromTensor(jniEnv, api, keys);
|
||||
// Release if valid
|
||||
api->ReleaseValue(element);
|
||||
}
|
||||
|
|
@ -94,7 +94,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStringValues(
|
|||
|
||||
if (code == ORT_OK) {
|
||||
// Convert to Java String array
|
||||
output = createStringArrayFromTensor(jniEnv, api, allocator, values);
|
||||
output = createStringArrayFromTensor(jniEnv, api, values);
|
||||
// Release if valid
|
||||
api->ReleaseValue(element);
|
||||
}
|
||||
|
|
@ -230,7 +230,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStrings(JNIEn
|
|||
OrtValue* element;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element));
|
||||
if (code == ORT_OK) {
|
||||
jobject str = createStringFromStringTensor(jniEnv, api, allocator, element);
|
||||
jobject str = createStringFromStringTensor(jniEnv, api, element);
|
||||
if (str == NULL) {
|
||||
api->ReleaseValue(element);
|
||||
// bail out as exception has been thrown
|
||||
|
|
|
|||
|
|
@ -44,8 +44,8 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor
|
|||
// Check if we're copying a scalar or not
|
||||
if (shapeLen == 0) {
|
||||
// Scalars are passed in as a single element array
|
||||
size_t copied = copyJavaToPrimitiveArray(jniEnv, onnxType, tensorData, dataObj);
|
||||
failed = copied == 0 ? 1 : failed;
|
||||
int64_t copied = copyJavaToPrimitiveArray(jniEnv, onnxType, dataObj, tensorData);
|
||||
failed = copied == -1 ? 1 : failed;
|
||||
} else {
|
||||
// Extract the tensor shape information
|
||||
JavaTensorTypeShape typeShape;
|
||||
|
|
@ -53,9 +53,9 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor
|
|||
|
||||
if (code == ORT_OK) {
|
||||
// Copy the java array into the tensor
|
||||
size_t copied = copyJavaToTensor(jniEnv, onnxType, tensorData, typeShape.elementCount,
|
||||
typeShape.dimensions, dataObj);
|
||||
failed = copied == 0 ? 1 : failed;
|
||||
int64_t copied = copyJavaToTensor(jniEnv, onnxType, typeShape.elementCount,
|
||||
typeShape.dimensions, dataObj, tensorData);
|
||||
failed = copied == -1 ? 1 : failed;
|
||||
} else {
|
||||
failed = 1;
|
||||
}
|
||||
|
|
@ -367,12 +367,11 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_getLong
|
|||
* Signature: (JJ)Ljava/lang/String;
|
||||
*/
|
||||
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OnnxTensor_getString
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle) {
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
// Extract a String array - if this becomes a performance issue we'll refactor later.
|
||||
jobjectArray outputArray = createStringArrayFromTensor(jniEnv, api, (OrtAllocator*) allocatorHandle,
|
||||
(OrtValue*) handle);
|
||||
jobjectArray outputArray = createStringArrayFromTensor(jniEnv, api, (OrtValue*) handle);
|
||||
if (outputArray != NULL) {
|
||||
// Get reference to the string
|
||||
jobject output = (*jniEnv)->GetObjectArrayElement(jniEnv, outputArray, 0);
|
||||
|
|
@ -410,7 +409,7 @@ JNIEXPORT jboolean JNICALL Java_ai_onnxruntime_OnnxTensor_getBool
|
|||
* Signature: (JJLjava/lang/Object;)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle, jobject carrier) {
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jobject carrier) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtValue* value = (OrtValue*) handle;
|
||||
|
|
@ -418,7 +417,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray
|
|||
OrtErrorCode code = getTensorTypeShape(jniEnv, &typeShape, api, value);
|
||||
if (code == ORT_OK) {
|
||||
if (typeShape.onnxTypeEnum == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
|
||||
copyStringTensorToArray(jniEnv, api, (OrtAllocator*) allocatorHandle, value, typeShape.elementCount, carrier);
|
||||
copyStringTensorToArray(jniEnv, api, value, typeShape.elementCount, carrier);
|
||||
} else {
|
||||
uint8_t* arr = NULL;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(value, (void**)&arr));
|
||||
|
|
|
|||
Loading…
Reference in a new issue