mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[java] Fixing the buffer semantics. (#5223)
* [java] Fixing the buffer semantics. * Renaming bufferCapacity to bufferRemaining. * Adding a cast to char* so the pointer arithmetic works on Windows.
This commit is contained in:
parent
c52561d044
commit
d26c71f55c
4 changed files with 156 additions and 140 deletions
|
|
@ -439,27 +439,7 @@ public class OnnxTensor implements OnnxValue {
|
|||
throws OrtException {
|
||||
if ((!env.isClosed()) && (!allocator.isClosed())) {
|
||||
OnnxJavaType type = OnnxJavaType.FLOAT;
|
||||
int bufferSize = data.capacity() * type.size;
|
||||
FloatBuffer tmp;
|
||||
if (data.isDirect()) {
|
||||
tmp = data;
|
||||
} else {
|
||||
ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
|
||||
tmp = buffer.asFloatBuffer();
|
||||
tmp.put(data);
|
||||
}
|
||||
TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
|
||||
return new OnnxTensor(
|
||||
createTensorFromBuffer(
|
||||
OnnxRuntime.ortApiHandle,
|
||||
allocator.handle,
|
||||
tmp,
|
||||
bufferSize,
|
||||
shape,
|
||||
info.onnxType.value),
|
||||
allocator.handle,
|
||||
info,
|
||||
tmp);
|
||||
return createTensor(type, allocator, data, shape);
|
||||
} else {
|
||||
throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
|
||||
}
|
||||
|
|
@ -500,27 +480,7 @@ public class OnnxTensor implements OnnxValue {
|
|||
throws OrtException {
|
||||
if ((!env.isClosed()) && (!allocator.isClosed())) {
|
||||
OnnxJavaType type = OnnxJavaType.DOUBLE;
|
||||
int bufferSize = data.capacity() * type.size;
|
||||
DoubleBuffer tmp;
|
||||
if (data.isDirect()) {
|
||||
tmp = data;
|
||||
} else {
|
||||
ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
|
||||
tmp = buffer.asDoubleBuffer();
|
||||
tmp.put(data);
|
||||
}
|
||||
TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
|
||||
return new OnnxTensor(
|
||||
createTensorFromBuffer(
|
||||
OnnxRuntime.ortApiHandle,
|
||||
allocator.handle,
|
||||
tmp,
|
||||
bufferSize,
|
||||
shape,
|
||||
info.onnxType.value),
|
||||
allocator.handle,
|
||||
info,
|
||||
tmp);
|
||||
return createTensor(type, allocator, data, shape);
|
||||
} else {
|
||||
throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
|
||||
}
|
||||
|
|
@ -599,25 +559,7 @@ public class OnnxTensor implements OnnxValue {
|
|||
throws OrtException {
|
||||
if ((!env.isClosed()) && (!allocator.isClosed())) {
|
||||
int bufferSize = data.capacity();
|
||||
ByteBuffer tmp;
|
||||
if (data.isDirect()) {
|
||||
tmp = data;
|
||||
} else {
|
||||
tmp = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
|
||||
tmp.put(data);
|
||||
}
|
||||
TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
|
||||
return new OnnxTensor(
|
||||
createTensorFromBuffer(
|
||||
OnnxRuntime.ortApiHandle,
|
||||
allocator.handle,
|
||||
tmp,
|
||||
bufferSize,
|
||||
shape,
|
||||
info.onnxType.value),
|
||||
allocator.handle,
|
||||
info,
|
||||
tmp);
|
||||
return createTensor(type, allocator, data, shape);
|
||||
} else {
|
||||
throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
|
||||
}
|
||||
|
|
@ -658,27 +600,7 @@ public class OnnxTensor implements OnnxValue {
|
|||
throws OrtException {
|
||||
if ((!env.isClosed()) && (!allocator.isClosed())) {
|
||||
OnnxJavaType type = OnnxJavaType.INT16;
|
||||
int bufferSize = data.capacity() * type.size;
|
||||
ShortBuffer tmp;
|
||||
if (data.isDirect()) {
|
||||
tmp = data;
|
||||
} else {
|
||||
ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
|
||||
tmp = buffer.asShortBuffer();
|
||||
tmp.put(data);
|
||||
}
|
||||
TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
|
||||
return new OnnxTensor(
|
||||
createTensorFromBuffer(
|
||||
OnnxRuntime.ortApiHandle,
|
||||
allocator.handle,
|
||||
tmp,
|
||||
bufferSize,
|
||||
shape,
|
||||
info.onnxType.value),
|
||||
allocator.handle,
|
||||
info,
|
||||
tmp);
|
||||
return createTensor(type, allocator, data, shape);
|
||||
} else {
|
||||
throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
|
||||
}
|
||||
|
|
@ -719,27 +641,7 @@ public class OnnxTensor implements OnnxValue {
|
|||
throws OrtException {
|
||||
if ((!env.isClosed()) && (!allocator.isClosed())) {
|
||||
OnnxJavaType type = OnnxJavaType.INT32;
|
||||
int bufferSize = data.capacity() * type.size;
|
||||
IntBuffer tmp;
|
||||
if (data.isDirect()) {
|
||||
tmp = data;
|
||||
} else {
|
||||
ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
|
||||
tmp = buffer.asIntBuffer();
|
||||
tmp.put(data);
|
||||
}
|
||||
TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
|
||||
return new OnnxTensor(
|
||||
createTensorFromBuffer(
|
||||
OnnxRuntime.ortApiHandle,
|
||||
allocator.handle,
|
||||
tmp,
|
||||
bufferSize,
|
||||
shape,
|
||||
info.onnxType.value),
|
||||
allocator.handle,
|
||||
info,
|
||||
tmp);
|
||||
return createTensor(type, allocator, data, shape);
|
||||
} else {
|
||||
throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
|
||||
}
|
||||
|
|
@ -780,32 +682,91 @@ public class OnnxTensor implements OnnxValue {
|
|||
throws OrtException {
|
||||
if ((!env.isClosed()) && (!allocator.isClosed())) {
|
||||
OnnxJavaType type = OnnxJavaType.INT64;
|
||||
int bufferSize = data.capacity() * type.size;
|
||||
LongBuffer tmp;
|
||||
if (data.isDirect()) {
|
||||
tmp = data;
|
||||
} else {
|
||||
ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
|
||||
tmp = buffer.asLongBuffer();
|
||||
tmp.put(data);
|
||||
}
|
||||
TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
|
||||
return new OnnxTensor(
|
||||
createTensorFromBuffer(
|
||||
OnnxRuntime.ortApiHandle,
|
||||
allocator.handle,
|
||||
tmp,
|
||||
bufferSize,
|
||||
shape,
|
||||
info.onnxType.value),
|
||||
allocator.handle,
|
||||
info,
|
||||
tmp);
|
||||
return createTensor(type, allocator, data, shape);
|
||||
} else {
|
||||
throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a tensor by wrapping the data in a direct byte buffer and passing it to JNI.
|
||||
*
|
||||
* <p>Throws IllegalStateException if the buffer is too large to create a direct byte buffer copy,
|
||||
* which is more than approximately (Integer.MAX_VALUE - 5) / type.size elements.
|
||||
*
|
||||
* @param type The buffer type.
|
||||
* @param allocator The OrtAllocator.
|
||||
* @param data The data.
|
||||
* @param shape The tensor shape.
|
||||
* @return An OnnxTensor instance.
|
||||
* @throws OrtException If the create call failed.
|
||||
*/
|
||||
private static OnnxTensor createTensor(
|
||||
OnnxJavaType type, OrtAllocator allocator, Buffer data, long[] shape) throws OrtException {
|
||||
int bufferPos;
|
||||
int bufferSize = data.remaining() * type.size;
|
||||
if ((bufferSize < 0) || (bufferSize > ((Integer.MAX_VALUE / type.size) - 10))) {
|
||||
// overflowed as we can't make a direct byte buffer of that size
|
||||
throw new IllegalStateException(
|
||||
"Cannot allocate a direct buffer of the requested size and type, size "
|
||||
+ data.remaining()
|
||||
+ ", type = "
|
||||
+ type);
|
||||
}
|
||||
Buffer tmp;
|
||||
if (data.isDirect()) {
|
||||
tmp = data;
|
||||
bufferPos = data.position() * type.size;
|
||||
} else {
|
||||
// Copy the data to a new direct buffer, then restore the state of the input.
|
||||
int origPosition = data.position();
|
||||
ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
|
||||
switch (type) {
|
||||
case FLOAT:
|
||||
tmp = buffer.asFloatBuffer().put((FloatBuffer) data);
|
||||
break;
|
||||
case DOUBLE:
|
||||
tmp = buffer.asDoubleBuffer().put((DoubleBuffer) data);
|
||||
break;
|
||||
case INT8:
|
||||
// buffer is already a ByteBuffer, no cast needed.
|
||||
tmp = buffer.put((ByteBuffer) data);
|
||||
break;
|
||||
case INT16:
|
||||
tmp = buffer.asShortBuffer().put((ShortBuffer) data);
|
||||
break;
|
||||
case INT32:
|
||||
tmp = buffer.asIntBuffer().put((IntBuffer) data);
|
||||
break;
|
||||
case INT64:
|
||||
tmp = buffer.asLongBuffer().put((LongBuffer) data);
|
||||
break;
|
||||
case BOOL:
|
||||
case STRING:
|
||||
case UNKNOWN:
|
||||
default:
|
||||
throw new IllegalStateException(
|
||||
"Impossible to reach here, managed to cast a buffer as an incorrect type");
|
||||
}
|
||||
data.position(origPosition);
|
||||
tmp.rewind();
|
||||
bufferPos = 0;
|
||||
}
|
||||
TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
|
||||
return new OnnxTensor(
|
||||
createTensorFromBuffer(
|
||||
OnnxRuntime.ortApiHandle,
|
||||
allocator.handle,
|
||||
tmp,
|
||||
bufferPos,
|
||||
bufferSize,
|
||||
shape,
|
||||
info.onnxType.value),
|
||||
allocator.handle,
|
||||
info,
|
||||
tmp);
|
||||
}
|
||||
|
||||
private static native long createTensor(
|
||||
long apiHandle, long allocatorHandle, Object data, long[] shape, int onnxType)
|
||||
throws OrtException;
|
||||
|
|
@ -814,6 +775,7 @@ public class OnnxTensor implements OnnxValue {
|
|||
long apiHandle,
|
||||
long allocatorHandle,
|
||||
Buffer data,
|
||||
int bufferPos,
|
||||
long bufferSize,
|
||||
long[] shape,
|
||||
int onnxType)
|
||||
|
|
|
|||
|
|
@ -251,16 +251,16 @@ public class TensorInfo implements ValueInfo {
|
|||
|
||||
long elementCount = OrtUtil.elementCount(shape);
|
||||
|
||||
long bufferCapacity = buffer.capacity();
|
||||
long bufferRemaining = buffer.remaining();
|
||||
|
||||
if (elementCount != bufferCapacity) {
|
||||
if (elementCount != bufferRemaining) {
|
||||
throw new OrtException(
|
||||
"Shape "
|
||||
+ Arrays.toString(shape)
|
||||
+ ", requires "
|
||||
+ elementCount
|
||||
+ " elements but the buffer has "
|
||||
+ bufferCapacity
|
||||
+ bufferRemaining
|
||||
+ " elements.");
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -56,10 +56,10 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor
|
|||
/*
|
||||
* Class: ai_onnxruntime_OnnxTensor
|
||||
* Method: createTensorFromBuffer
|
||||
* Signature: (JJLjava/nio/Buffer;J[JI)J
|
||||
* Signature: (JJLjava/nio/Buffer;IJ[JI)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensorFromBuffer
|
||||
(JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject buffer, jlong bufferSize, jlongArray shape, jint onnxTypeJava) {
|
||||
(JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject buffer, jint bufferPos, jlong bufferSize, jlongArray shape, jint onnxTypeJava) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
|
||||
|
|
@ -70,7 +70,9 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensorFromBuffer
|
|||
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava);
|
||||
|
||||
// Extract the buffer
|
||||
void* bufferArr = (*jniEnv)->GetDirectBufferAddress(jniEnv,buffer);
|
||||
char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv,buffer);
|
||||
// Increment by bufferPos bytes
|
||||
bufferArr = bufferArr + bufferPos;
|
||||
|
||||
// Extract the shape information
|
||||
jboolean mkCopy;
|
||||
|
|
|
|||
|
|
@ -26,11 +26,13 @@ import java.io.IOException;
|
|||
import java.io.InputStream;
|
||||
import java.io.UncheckedIOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Iterator;
|
||||
|
|
@ -789,18 +791,68 @@ public class InferenceTest {
|
|||
float[] resultBufferArray = new float[flatInput.length];
|
||||
((OnnxTensor) res.get(0)).getFloatBuffer().get(resultBufferArray);
|
||||
assertArrayEquals(flatInput, resultBufferArray, 1e-6f);
|
||||
OnnxValue.close(container);
|
||||
}
|
||||
container.clear();
|
||||
OnnxValue.close(container);
|
||||
}
|
||||
}
|
||||
|
||||
// Now test loading from buffer
|
||||
FloatBuffer buffer = FloatBuffer.wrap(flatInput);
|
||||
OnnxTensor newTensor = OnnxTensor.createTensor(env, buffer, shape);
|
||||
container.put(inputName, newTensor);
|
||||
try (OrtSession.Result res = session.run(container)) {
|
||||
resultArray = TestHelpers.flattenFloat(res.get(0).getValue());
|
||||
assertArrayEquals(flatInput, resultArray, 1e-6f);
|
||||
OnnxValue.close(container);
|
||||
@Test
|
||||
public void testModelInputBuffer() throws OrtException {
|
||||
// model takes 1x5 input of fixed type, echoes back
|
||||
String modelPath = getResourcePath("/test_types_FLOAT.pb").toString();
|
||||
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputFLOAT");
|
||||
SessionOptions options = new SessionOptions();
|
||||
OrtSession session = env.createSession(modelPath, options)) {
|
||||
String inputName = session.getInputNames().iterator().next();
|
||||
long[] shape = new long[] {1, 5};
|
||||
Map<String, OnnxTensor> container = new HashMap<>();
|
||||
float[] inputArr =
|
||||
new float[] {
|
||||
1.0f, -2.0f, 3.0f, -4.0f, 5.0f, -6.0f, 7.0f, -8.0f, 9.0f, -10.0f, 11.0f, -12.0f, 13.0f,
|
||||
-14.0f, 15
|
||||
};
|
||||
FloatBuffer buffer = FloatBuffer.wrap(inputArr);
|
||||
FloatBuffer directBuffer =
|
||||
ByteBuffer.allocateDirect(inputArr.length * 4)
|
||||
.order(ByteOrder.nativeOrder())
|
||||
.asFloatBuffer()
|
||||
.put(buffer);
|
||||
buffer.rewind();
|
||||
directBuffer.rewind();
|
||||
float[] resultArray;
|
||||
|
||||
// Test loading from buffer
|
||||
for (int i = 0; i < 3; i++) {
|
||||
// Set limits
|
||||
buffer.position(i * 5);
|
||||
buffer.limit((i + 1) * 5);
|
||||
directBuffer.position(i * 5);
|
||||
directBuffer.limit((i + 1) * 5);
|
||||
|
||||
// Check regular buffer (copies to direct)
|
||||
OnnxTensor newTensor = OnnxTensor.createTensor(env, buffer, shape);
|
||||
container.put(inputName, newTensor);
|
||||
try (OrtSession.Result res = session.run(container)) {
|
||||
resultArray = TestHelpers.flattenFloat(res.get(0).getValue());
|
||||
assertArrayEquals(Arrays.copyOfRange(inputArr, i * 5, (i + 1) * 5), resultArray, 1e-6f);
|
||||
OnnxValue.close(container);
|
||||
}
|
||||
container.clear();
|
||||
// buffer should be unchanged
|
||||
assertEquals(i * 5, buffer.position());
|
||||
|
||||
// Check direct buffer (no-copy)
|
||||
newTensor = OnnxTensor.createTensor(env, directBuffer, shape);
|
||||
container.put(inputName, newTensor);
|
||||
try (OrtSession.Result res = session.run(container)) {
|
||||
resultArray = TestHelpers.flattenFloat(res.get(0).getValue());
|
||||
assertArrayEquals(Arrays.copyOfRange(inputArr, i * 5, (i + 1) * 5), resultArray, 1e-6f);
|
||||
OnnxValue.close(container);
|
||||
}
|
||||
container.clear();
|
||||
// direct buffer should be unchanged
|
||||
assertEquals(i * 5, directBuffer.position());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue