mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-16 01:33:39 +00:00
**Description**: Adds support for creating and receiving sparse tensors in the ORT Java API. CSRC and COO tensors as inputs are tested, but there is no op which accepts a block sparse tensor to test. COO tensors are tested as outputs, but there is no op which emits a CSRC or block sparse tensor to test. **Motivation and Context** - Why is this change required? What problem does it solve? Request to expose ORT sparse tensor support in Java. cc @yuslepukhin
534 lines
20 KiB
C
534 lines
20 KiB
C
/*
|
|
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
|
|
* Licensed under the MIT License.
|
|
*/
|
|
#include <jni.h>
|
|
#include <math.h>
|
|
#include <stdlib.h>
|
|
#include "onnxruntime/core/session/onnxruntime_c_api.h"
|
|
#include "OrtJniUtil.h"
|
|
#include "ai_onnxruntime_OnnxSparseTensor.h"
|
|
|
|
/*
|
|
* Class: ai_onnxruntime_OnnxSparseTensor
|
|
* Method: getIndicesBuffer
|
|
* Signature: (JJ)Ljava/nio/ByteBuffer;
|
|
*/
|
|
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getIndicesBuffer
|
|
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
|
|
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
|
const OrtApi* api = (const OrtApi*) apiHandle;
|
|
const OrtValue* ortValue = (const OrtValue*) handle;
|
|
OrtSparseFormat format;
|
|
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetSparseTensorFormat(ortValue, &format));
|
|
if (code != ORT_OK) {
|
|
return NULL;
|
|
}
|
|
enum OrtSparseIndicesFormat indicesFormat;
|
|
switch (format) {
|
|
case ORT_SPARSE_COO:
|
|
indicesFormat = ORT_SPARSE_COO_INDICES;
|
|
break;
|
|
case ORT_SPARSE_CSRC:
|
|
indicesFormat = ORT_SPARSE_CSR_OUTER_INDICES;
|
|
break;
|
|
case ORT_SPARSE_BLOCK_SPARSE:
|
|
indicesFormat = ORT_SPARSE_BLOCK_SPARSE_INDICES;
|
|
break;
|
|
case ORT_SPARSE_UNDEFINED:
|
|
default: {
|
|
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Sparse format is ORT_SPARSE_UNDEFINED, cannot get indices");
|
|
return NULL;
|
|
}
|
|
}
|
|
|
|
OrtTensorTypeAndShapeInfo* info = NULL;
|
|
code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndicesTypeShape(ortValue, indicesFormat, &info));
|
|
if (code != ORT_OK) {
|
|
return NULL;
|
|
}
|
|
size_t arrSize = 0;
|
|
code = checkOrtStatus(jniEnv, api, api->GetTensorShapeElementCount(info, &arrSize));
|
|
if (code != ORT_OK) {
|
|
api->ReleaseTensorTypeAndShapeInfo(info);
|
|
return NULL;
|
|
}
|
|
ONNXTensorElementDataType onnxTypeEnum;
|
|
code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(info, &onnxTypeEnum));
|
|
api->ReleaseTensorTypeAndShapeInfo(info);
|
|
if (code != ORT_OK) {
|
|
return NULL;
|
|
}
|
|
|
|
size_t typeSize = onnxTypeSize(onnxTypeEnum);
|
|
size_t sizeBytes = arrSize * typeSize;
|
|
|
|
uint8_t* arr = NULL;
|
|
size_t indices_size = 0;
|
|
code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndices(ortValue, indicesFormat, &indices_size, (const void**)&arr));
|
|
if (code != ORT_OK) {
|
|
return NULL;
|
|
}
|
|
|
|
if (indices_size != arrSize) {
|
|
throwOrtException(jniEnv, convertErrorCode(ORT_RUNTIME_EXCEPTION), "Unexpected size");
|
|
return NULL;
|
|
} else {
|
|
return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, sizeBytes);
|
|
}
|
|
}
|
|
|
|
/*
|
|
* Class: ai_onnxruntime_OnnxSparseTensor
|
|
* Method: getInnerIndicesBuffer
|
|
* Signature: (JJ)Ljava/nio/ByteBuffer;
|
|
*/
|
|
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getInnerIndicesBuffer
|
|
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
|
|
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
|
const OrtApi* api = (const OrtApi*) apiHandle;
|
|
const OrtValue* ortValue = (const OrtValue*) handle;
|
|
OrtSparseFormat format;
|
|
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetSparseTensorFormat(ortValue, &format));
|
|
if (code != ORT_OK) {
|
|
return NULL;
|
|
}
|
|
enum OrtSparseIndicesFormat indicesFormat;
|
|
switch (format) {
|
|
case ORT_SPARSE_CSRC:
|
|
indicesFormat = ORT_SPARSE_CSR_INNER_INDICES;
|
|
break;
|
|
case ORT_SPARSE_COO:
|
|
case ORT_SPARSE_BLOCK_SPARSE:
|
|
case ORT_SPARSE_UNDEFINED:
|
|
default: {
|
|
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED),
|
|
"Sparse format is ORT_SPARSE_COO, ORT_SPARSE_BLOCK_SPARSE, or ORT_SPARSE_UNDEFINED, inner indices are not defined.");
|
|
return NULL;
|
|
}
|
|
}
|
|
|
|
OrtTensorTypeAndShapeInfo* info = NULL;
|
|
code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndicesTypeShape(ortValue, indicesFormat, &info));
|
|
if (code != ORT_OK) {
|
|
return NULL;
|
|
}
|
|
size_t arrSize = 0;
|
|
code = checkOrtStatus(jniEnv, api, api->GetTensorShapeElementCount(info, &arrSize));
|
|
if (code != ORT_OK) {
|
|
api->ReleaseTensorTypeAndShapeInfo(info);
|
|
return NULL;
|
|
}
|
|
ONNXTensorElementDataType onnxTypeEnum;
|
|
code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(info, &onnxTypeEnum));
|
|
api->ReleaseTensorTypeAndShapeInfo(info);
|
|
if (code != ORT_OK) {
|
|
return NULL;
|
|
}
|
|
|
|
size_t typeSize = onnxTypeSize(onnxTypeEnum);
|
|
size_t sizeBytes = arrSize * typeSize;
|
|
|
|
uint8_t* arr;
|
|
size_t indices_size;
|
|
code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndices(ortValue, indicesFormat, &indices_size, (const void**)&arr));
|
|
if (code != ORT_OK) {
|
|
return NULL;
|
|
}
|
|
|
|
if (indices_size != arrSize) {
|
|
throwOrtException(jniEnv, convertErrorCode(ORT_RUNTIME_EXCEPTION), "Unexpected size");
|
|
return NULL;
|
|
} else {
|
|
return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, sizeBytes);
|
|
}
|
|
}
|
|
|
|
/*
|
|
* Class: ai_onnxruntime_OnnxSparseTensor
|
|
* Method: getValuesBuffer
|
|
* Signature: (JJ)Ljava/nio/ByteBuffer;
|
|
*/
|
|
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getValuesBuffer
|
|
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
|
|
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
|
const OrtApi* api = (const OrtApi*) apiHandle;
|
|
const OrtValue* ortValue = (const OrtValue*) handle;
|
|
OrtSparseFormat format;
|
|
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetSparseTensorFormat(ortValue, &format));
|
|
if (code != ORT_OK) {
|
|
return NULL;
|
|
}
|
|
switch (format) {
|
|
case ORT_SPARSE_COO:
|
|
case ORT_SPARSE_CSRC:
|
|
case ORT_SPARSE_BLOCK_SPARSE: {
|
|
OrtTensorTypeAndShapeInfo* info = NULL;
|
|
checkOrtStatus(jniEnv, api, api->GetSparseTensorValuesTypeAndShape(ortValue, &info));
|
|
if (code != ORT_OK) {
|
|
return NULL;
|
|
}
|
|
size_t arrSize = 0;
|
|
code = checkOrtStatus(jniEnv, api, api->GetTensorShapeElementCount(info, &arrSize));
|
|
if (code != ORT_OK) {
|
|
api->ReleaseTensorTypeAndShapeInfo(info);
|
|
return NULL;
|
|
}
|
|
ONNXTensorElementDataType onnxTypeEnum;
|
|
code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(info, &onnxTypeEnum));
|
|
api->ReleaseTensorTypeAndShapeInfo(info);
|
|
if (code != ORT_OK) {
|
|
return NULL;
|
|
}
|
|
|
|
size_t typeSize = onnxTypeSize(onnxTypeEnum);
|
|
size_t sizeBytes = arrSize * typeSize;
|
|
|
|
uint8_t* arr = NULL;
|
|
checkOrtStatus(jniEnv, api, api->GetSparseTensorValues(ortValue, (const void**)&arr));
|
|
|
|
return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, sizeBytes);
|
|
}
|
|
case ORT_SPARSE_UNDEFINED:
|
|
default: {
|
|
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED),
|
|
"Sparse format is ORT_SPARSE_UNDEFINED, cannot get data");
|
|
return NULL;
|
|
}
|
|
}
|
|
}
|
|
|
|
/*
|
|
* Class: ai_onnxruntime_OnnxSparseTensor
|
|
* Method: getInnerIndicesShape
|
|
* Signature: (JJ)[J;
|
|
*/
|
|
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getInnerIndicesShape
|
|
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
|
|
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
|
const OrtApi* api = (const OrtApi*) apiHandle;
|
|
const OrtValue* value = (const OrtValue*) handle;
|
|
|
|
// Extract the info
|
|
OrtTensorTypeAndShapeInfo* info;
|
|
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndicesTypeShape(value, ORT_SPARSE_CSR_INNER_INDICES, &info));
|
|
if (code != ORT_OK) {
|
|
return NULL;
|
|
}
|
|
|
|
// Extract the shape
|
|
size_t numDim = 0;
|
|
code = checkOrtStatus(jniEnv, api, api->GetDimensionsCount(info, &numDim));
|
|
if (code != ORT_OK) {
|
|
api->ReleaseTensorTypeAndShapeInfo(info);
|
|
return NULL;
|
|
}
|
|
int64_t* dimensions = malloc(sizeof(int64_t) * numDim);
|
|
if (dimensions == NULL) {
|
|
throwOrtException(jniEnv, convertErrorCode(ORT_FAIL), "Out of memory when trying to allocate dimensions array");
|
|
api->ReleaseTensorTypeAndShapeInfo(info);
|
|
return NULL;
|
|
}
|
|
code = checkOrtStatus(jniEnv, api, api->GetDimensions(info, dimensions, numDim));
|
|
// Free the info
|
|
api->ReleaseTensorTypeAndShapeInfo(info);
|
|
if (code != ORT_OK) {
|
|
free((void*)dimensions);
|
|
return NULL;
|
|
}
|
|
|
|
// Create the long array for the shape.
|
|
jlongArray shape = (*jniEnv)->NewLongArray(jniEnv, safecast_size_t_to_jsize(numDim));
|
|
(*jniEnv)->SetLongArrayRegion(jniEnv, shape, 0, safecast_size_t_to_jsize(numDim), (jlong*)dimensions);
|
|
|
|
// Free the dimensions array
|
|
free((void*)dimensions);
|
|
|
|
return shape;
|
|
}
|
|
|
|
/*
|
|
* Class: ai_onnxruntime_OnnxSparseTensor
|
|
* Method: getIndicesShape
|
|
* Signature: (JJ)[J;
|
|
*/
|
|
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getIndicesShape
|
|
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
|
|
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
|
const OrtApi* api = (const OrtApi*) apiHandle;
|
|
const OrtValue* value = (const OrtValue*) handle;
|
|
|
|
// Get the indices format
|
|
OrtSparseFormat format;
|
|
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetSparseTensorFormat(value, &format));
|
|
if (code != ORT_OK) {
|
|
return NULL;
|
|
}
|
|
enum OrtSparseIndicesFormat indicesFormat;
|
|
switch (format) {
|
|
case ORT_SPARSE_CSRC:
|
|
indicesFormat = ORT_SPARSE_CSR_OUTER_INDICES;
|
|
break;
|
|
case ORT_SPARSE_COO:
|
|
indicesFormat = ORT_SPARSE_COO_INDICES;
|
|
break;
|
|
case ORT_SPARSE_BLOCK_SPARSE:
|
|
indicesFormat = ORT_SPARSE_BLOCK_SPARSE_INDICES;
|
|
break;
|
|
case ORT_SPARSE_UNDEFINED:
|
|
default: {
|
|
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED),
|
|
"Sparse format is ORT_SPARSE_UNDEFINED, indices are not defined.");
|
|
return NULL;
|
|
}
|
|
}
|
|
|
|
// Extract the info
|
|
OrtTensorTypeAndShapeInfo* info;
|
|
code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndicesTypeShape(value, indicesFormat, &info));
|
|
if (code != ORT_OK) {
|
|
return NULL;
|
|
}
|
|
|
|
// Extract the shape
|
|
size_t numDim = 0;
|
|
code = checkOrtStatus(jniEnv, api, api->GetDimensionsCount(info, &numDim));
|
|
if (code != ORT_OK) {
|
|
api->ReleaseTensorTypeAndShapeInfo(info);
|
|
return NULL;
|
|
}
|
|
int64_t* dimensions = malloc(sizeof(int64_t) * numDim);
|
|
if (dimensions == NULL) {
|
|
throwOrtException(jniEnv, convertErrorCode(ORT_FAIL), "Out of memory when trying to allocate dimensions array");
|
|
api->ReleaseTensorTypeAndShapeInfo(info);
|
|
return NULL;
|
|
}
|
|
code = checkOrtStatus(jniEnv, api, api->GetDimensions(info, dimensions, numDim));
|
|
// Free the info
|
|
api->ReleaseTensorTypeAndShapeInfo(info);
|
|
if (code != ORT_OK) {
|
|
free((void*)dimensions);
|
|
return NULL;
|
|
}
|
|
|
|
// Create the long array for the shape.
|
|
jlongArray shape = (*jniEnv)->NewLongArray(jniEnv, safecast_size_t_to_jsize(numDim));
|
|
(*jniEnv)->SetLongArrayRegion(jniEnv, shape, 0, safecast_size_t_to_jsize(numDim), (jlong*)dimensions);
|
|
// Free the dimensions array
|
|
free((void*)dimensions);
|
|
|
|
return shape;
|
|
}
|
|
|
|
/*
|
|
* Class: ai_onnxruntime_OnnxSparseTensor
|
|
* Method: getValuesShape
|
|
* Signature: (JJ)[J;
|
|
*/
|
|
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getValuesShape
|
|
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
|
|
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
|
const OrtApi* api = (const OrtApi*) apiHandle;
|
|
const OrtValue* value = (const OrtValue*) handle;
|
|
|
|
// Extract the info
|
|
OrtTensorTypeAndShapeInfo* info;
|
|
OrtErrorCode code = checkOrtStatus(jniEnv,api,api->GetSparseTensorValuesTypeAndShape(value,&info));
|
|
if (code != ORT_OK) {
|
|
return NULL;
|
|
}
|
|
|
|
// Extract the shape
|
|
size_t numDim = 0;
|
|
code = checkOrtStatus(jniEnv,api,api->GetDimensionsCount(info,&numDim));
|
|
if (code != ORT_OK) {
|
|
api->ReleaseTensorTypeAndShapeInfo(info);
|
|
return NULL;
|
|
}
|
|
int64_t* dimensions = malloc(sizeof(int64_t)*numDim);
|
|
if (dimensions == NULL) {
|
|
throwOrtException(jniEnv, convertErrorCode(ORT_FAIL), "Out of memory when trying to allocate dimensions array");
|
|
api->ReleaseTensorTypeAndShapeInfo(info);
|
|
return NULL;
|
|
}
|
|
code = checkOrtStatus(jniEnv,api,api->GetDimensions(info, dimensions, numDim));
|
|
// Free the info
|
|
api->ReleaseTensorTypeAndShapeInfo(info);
|
|
if (code != ORT_OK) {
|
|
free((void*)dimensions);
|
|
return NULL;
|
|
}
|
|
|
|
// Create the long array for the shape.
|
|
jlongArray shape = (*jniEnv)->NewLongArray(jniEnv, safecast_size_t_to_jsize(numDim));
|
|
(*jniEnv)->SetLongArrayRegion(jniEnv, shape, 0, safecast_size_t_to_jsize(numDim), (jlong*)dimensions);
|
|
|
|
// Free the dimensions array
|
|
free((void*)dimensions);
|
|
|
|
return shape;
|
|
}
|
|
|
|
/*
|
|
* Class: ai_onnxruntime_OnnxSparseTensor
|
|
* Method: close
|
|
* Signature: (JJ)V
|
|
*/
|
|
JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxSparseTensor_close(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
|
|
(void) jniEnv; (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
|
const OrtApi* api = (const OrtApi*) apiHandle;
|
|
api->ReleaseValue((OrtValue*)handle);
|
|
}
|
|
|
|
/*
|
|
* Class: ai_onnxruntime_OnnxSparseTensor
|
|
* Method: createCSRCSparseTensorFromBuffer
|
|
* Signature: (JJLjava/nio/Buffer;IJLjava/nio/Buffer;IJLjava/nio/Buffer;IJ[J[JI)J
|
|
*/
|
|
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxSparseTensor_createCSRCSparseTensorFromBuffer
|
|
(JNIEnv * jniEnv, jclass cls, jlong apiHandle, jlong allocatorHandle,
|
|
jobject indicesBuffer, jint indicesBufferPos, jlong indicesBufferSize,
|
|
jobject innerIndicesBuffer, jint innerIndicesBufferPos, jlong innerIndicesBufferSize,
|
|
jobject dataBuffer, jint dataBufferPos,
|
|
jlongArray denseShape, jlongArray valuesShape,
|
|
jint onnxTypeJava) {
|
|
(void) cls; // Required JNI parameters not needed by functions which don't need to access their host object.
|
|
const OrtApi* api = (const OrtApi*) apiHandle;
|
|
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
|
|
const OrtMemoryInfo* allocatorInfo;
|
|
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->AllocatorGetInfo(allocator, &allocatorInfo));
|
|
if (code != ORT_OK) {
|
|
return 0;
|
|
}
|
|
|
|
// Convert types to ONNX C enums
|
|
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava);
|
|
|
|
// Extract the buffers
|
|
char* indicesBufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, indicesBuffer);
|
|
char* innerIndicesBufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, innerIndicesBuffer);
|
|
char* dataBufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, dataBuffer);
|
|
// Increment by bufferPos bytes
|
|
indicesBufferArr = indicesBufferArr + indicesBufferPos;
|
|
innerIndicesBufferArr = innerIndicesBufferArr + innerIndicesBufferPos;
|
|
dataBufferArr = dataBufferArr + dataBufferPos;
|
|
|
|
// Extract the dense shape information
|
|
jboolean mkCopy;
|
|
jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, denseShape, &mkCopy);
|
|
jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, denseShape);
|
|
|
|
// Extract the value shape
|
|
jlong* valuesShapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, valuesShape, &mkCopy);
|
|
jsize valuesShapeLen = (*jniEnv)->GetArrayLength(jniEnv, valuesShape);
|
|
|
|
// Create the OrtValue
|
|
OrtValue* ortValue = NULL;
|
|
code = checkOrtStatus(jniEnv, api, api->CreateSparseTensorWithValuesAsOrtValue(allocatorInfo, dataBufferArr,
|
|
(int64_t*) shapeArr, shapeLen, (int64_t*) valuesShapeArr, valuesShapeLen, onnxType, &ortValue));
|
|
// Release shapes
|
|
(*jniEnv)->ReleaseLongArrayElements(jniEnv, denseShape, shapeArr, JNI_ABORT);
|
|
(*jniEnv)->ReleaseLongArrayElements(jniEnv, valuesShape, valuesShapeArr, JNI_ABORT);
|
|
if (code != ORT_OK) {
|
|
return 0;
|
|
}
|
|
|
|
// Fill it with indices
|
|
code = checkOrtStatus(jniEnv, api, api->UseCsrIndices(ortValue,
|
|
(int64_t *) innerIndicesBufferArr, innerIndicesBufferSize,
|
|
(int64_t *) indicesBufferArr, indicesBufferSize));
|
|
if (code != ORT_OK) {
|
|
return 0;
|
|
} else {
|
|
// Return the pointer to the OrtValue
|
|
return (jlong) ortValue;
|
|
}
|
|
}
|
|
|
|
/*
|
|
* Class: ai_onnxruntime_OnnxSparseTensor
|
|
* Method: createSparseTensorFromBuffer
|
|
* Signature: (JJLjava/nio/Buffer;IJLjava/nio/Buffer;IJ[J[J[JII)J
|
|
*/
|
|
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxSparseTensor_createSparseTensorFromBuffer
|
|
(JNIEnv * jniEnv, jclass cls, jlong apiHandle, jlong allocatorHandle,
|
|
jobject indicesBuffer, jint indicesBufferPos, jlong indicesBufferSize,
|
|
jobject dataBuffer, jint dataBufferPos,
|
|
jlongArray denseShape, jlongArray indicesShape, jlongArray valuesShape,
|
|
jint onnxTypeJava, jint sparsityTypeJava) {
|
|
(void) cls; // Required JNI parameters not needed by functions which don't need to access their host object.
|
|
const OrtApi* api = (const OrtApi*) apiHandle;
|
|
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
|
|
const OrtMemoryInfo* allocatorInfo;
|
|
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->AllocatorGetInfo(allocator, &allocatorInfo));
|
|
if (code != ORT_OK) {
|
|
return 0;
|
|
}
|
|
|
|
// Convert types to ONNX C enums
|
|
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava);
|
|
OrtSparseFormat sparsityType = convertToOrtSparseFormat(sparsityTypeJava);
|
|
|
|
// Extract the buffers
|
|
char* indicesBufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, indicesBuffer);
|
|
char* dataBufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, dataBuffer);
|
|
// Increment by bufferPos bytes
|
|
indicesBufferArr = indicesBufferArr + indicesBufferPos;
|
|
dataBufferArr = dataBufferArr + dataBufferPos;
|
|
|
|
// Extract the dense shape information
|
|
jboolean mkCopy;
|
|
jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, denseShape, &mkCopy);
|
|
jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, denseShape);
|
|
|
|
// Extract the value shape
|
|
jlong* valuesShapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, valuesShape, &mkCopy);
|
|
jsize valuesShapeLen = (*jniEnv)->GetArrayLength(jniEnv, valuesShape);
|
|
|
|
// Create the OrtValue
|
|
OrtValue* ortValue = NULL;
|
|
code = checkOrtStatus(jniEnv, api, api->CreateSparseTensorWithValuesAsOrtValue(allocatorInfo, dataBufferArr,
|
|
(int64_t*) shapeArr, shapeLen, (int64_t*) valuesShapeArr, valuesShapeLen, onnxType, &ortValue));
|
|
|
|
// Release shapes
|
|
(*jniEnv)->ReleaseLongArrayElements(jniEnv, denseShape, shapeArr, JNI_ABORT);
|
|
(*jniEnv)->ReleaseLongArrayElements(jniEnv, valuesShape, valuesShapeArr, JNI_ABORT);
|
|
if (code != ORT_OK) {
|
|
return 0;
|
|
}
|
|
|
|
// Fill it with indices
|
|
switch (sparsityType) {
|
|
case ORT_SPARSE_COO: {
|
|
// The cast is because we compute the offset in bytes in Java.
|
|
code = checkOrtStatus(jniEnv, api, api->UseCooIndices(ortValue, (int64_t *) indicesBufferArr,
|
|
indicesBufferSize));
|
|
break;
|
|
}
|
|
case ORT_SPARSE_BLOCK_SPARSE: {
|
|
// Extract the indices shape
|
|
jlong* indicesShapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, indicesShape, &mkCopy);
|
|
jsize indicesShapeLen = (*jniEnv)->GetArrayLength(jniEnv, indicesShape);
|
|
|
|
// The cast is because we compute the offset in bytes in Java.
|
|
code = checkOrtStatus(jniEnv, api, api->UseBlockSparseIndices(ortValue, (int64_t *) indicesShapeArr,
|
|
indicesShapeLen, (int32_t *) indicesBufferArr));
|
|
|
|
// Release the indices shape
|
|
(*jniEnv)->ReleaseLongArrayElements(jniEnv, indicesShape, indicesShapeArr, JNI_ABORT);
|
|
break;
|
|
}
|
|
case ORT_SPARSE_CSRC:
|
|
case ORT_SPARSE_UNDEFINED: {
|
|
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED),
|
|
"These types are unsupported by this method - ORT_SPARSE_CSRC, ORT_SPARSE_UNDEFINED");
|
|
code = ORT_NOT_IMPLEMENTED;
|
|
}
|
|
}
|
|
if (code != ORT_OK) {
|
|
return 0;
|
|
} else {
|
|
// Return the pointer to the OrtValue
|
|
return (jlong) ortValue;
|
|
}
|
|
}
|