[java] Fixing the buffer semantics. (#5223)

* [java] Fixing the buffer semantics.
* Renaming bufferCapacity to bufferRemaining.
* Adding a cast to char* so the pointer arithmetic works on Windows.
This commit is contained in:
Adam Pocock 2020-09-23 00:29:01 -04:00 committed by GitHub
parent c52561d044
commit d26c71f55c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 156 additions and 140 deletions

View file

@ -439,27 +439,7 @@ public class OnnxTensor implements OnnxValue {
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
OnnxJavaType type = OnnxJavaType.FLOAT;
int bufferSize = data.capacity() * type.size;
FloatBuffer tmp;
if (data.isDirect()) {
tmp = data;
} else {
ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
tmp = buffer.asFloatBuffer();
tmp.put(data);
}
TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
return new OnnxTensor(
createTensorFromBuffer(
OnnxRuntime.ortApiHandle,
allocator.handle,
tmp,
bufferSize,
shape,
info.onnxType.value),
allocator.handle,
info,
tmp);
return createTensor(type, allocator, data, shape);
} else {
throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
}
@ -500,27 +480,7 @@ public class OnnxTensor implements OnnxValue {
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
OnnxJavaType type = OnnxJavaType.DOUBLE;
int bufferSize = data.capacity() * type.size;
DoubleBuffer tmp;
if (data.isDirect()) {
tmp = data;
} else {
ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
tmp = buffer.asDoubleBuffer();
tmp.put(data);
}
TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
return new OnnxTensor(
createTensorFromBuffer(
OnnxRuntime.ortApiHandle,
allocator.handle,
tmp,
bufferSize,
shape,
info.onnxType.value),
allocator.handle,
info,
tmp);
return createTensor(type, allocator, data, shape);
} else {
throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
}
@ -599,25 +559,7 @@ public class OnnxTensor implements OnnxValue {
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
int bufferSize = data.capacity();
ByteBuffer tmp;
if (data.isDirect()) {
tmp = data;
} else {
tmp = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
tmp.put(data);
}
TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
return new OnnxTensor(
createTensorFromBuffer(
OnnxRuntime.ortApiHandle,
allocator.handle,
tmp,
bufferSize,
shape,
info.onnxType.value),
allocator.handle,
info,
tmp);
return createTensor(type, allocator, data, shape);
} else {
throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
}
@ -658,27 +600,7 @@ public class OnnxTensor implements OnnxValue {
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
OnnxJavaType type = OnnxJavaType.INT16;
int bufferSize = data.capacity() * type.size;
ShortBuffer tmp;
if (data.isDirect()) {
tmp = data;
} else {
ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
tmp = buffer.asShortBuffer();
tmp.put(data);
}
TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
return new OnnxTensor(
createTensorFromBuffer(
OnnxRuntime.ortApiHandle,
allocator.handle,
tmp,
bufferSize,
shape,
info.onnxType.value),
allocator.handle,
info,
tmp);
return createTensor(type, allocator, data, shape);
} else {
throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
}
@ -719,27 +641,7 @@ public class OnnxTensor implements OnnxValue {
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
OnnxJavaType type = OnnxJavaType.INT32;
int bufferSize = data.capacity() * type.size;
IntBuffer tmp;
if (data.isDirect()) {
tmp = data;
} else {
ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
tmp = buffer.asIntBuffer();
tmp.put(data);
}
TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
return new OnnxTensor(
createTensorFromBuffer(
OnnxRuntime.ortApiHandle,
allocator.handle,
tmp,
bufferSize,
shape,
info.onnxType.value),
allocator.handle,
info,
tmp);
return createTensor(type, allocator, data, shape);
} else {
throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
}
@ -780,32 +682,91 @@ public class OnnxTensor implements OnnxValue {
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
OnnxJavaType type = OnnxJavaType.INT64;
int bufferSize = data.capacity() * type.size;
LongBuffer tmp;
if (data.isDirect()) {
tmp = data;
} else {
ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
tmp = buffer.asLongBuffer();
tmp.put(data);
}
TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
return new OnnxTensor(
createTensorFromBuffer(
OnnxRuntime.ortApiHandle,
allocator.handle,
tmp,
bufferSize,
shape,
info.onnxType.value),
allocator.handle,
info,
tmp);
return createTensor(type, allocator, data, shape);
} else {
throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
}
}
/**
 * Creates a tensor by wrapping the data in a direct byte buffer and passing it to JNI.
 *
 * <p>Only the elements between the buffer's current position and its limit are used. A direct
 * buffer is handed to the native code as-is, together with the byte offset of its position. A
 * non-direct buffer is copied into a newly allocated direct buffer in native byte order, and the
 * position of the input buffer is restored after the copy.
 *
 * <p>Throws IllegalStateException if the remaining data is too large to fit in a direct byte
 * buffer, which is more than approximately (Integer.MAX_VALUE / type.size) - 8 elements.
 *
 * @param type The buffer type.
 * @param allocator The OrtAllocator.
 * @param data The data.
 * @param shape The tensor shape.
 * @return An OnnxTensor instance.
 * @throws OrtException If the create call failed.
 */
private static OnnxTensor createTensor(
    OnnxJavaType type, OrtAllocator allocator, Buffer data, long[] shape) throws OrtException {
  int bufferPos;
  // Compute the byte count in 64-bit arithmetic. An int multiplication can wrap around to a
  // positive value, which a simple "< 0" test would not detect.
  long bufferSizeLong = data.remaining() * (long) type.size;
  // The limit is expressed in bytes (the same unit as bufferSizeLong). It is kept 8 elements
  // below Integer.MAX_VALUE because the largest allocatable direct byte buffer is slightly
  // smaller than Integer.MAX_VALUE bytes.
  if (bufferSizeLong > (Integer.MAX_VALUE - (8L * type.size))) {
    throw new IllegalStateException(
        "Cannot allocate a direct buffer of the requested size and type, size "
            + data.remaining()
            + ", type = "
            + type);
  }
  // The range check above guarantees this narrowing cast is lossless.
  int bufferSize = (int) bufferSizeLong;
  Buffer tmp;
  if (data.isDirect()) {
    // No copy: the native side offsets the buffer address by bufferPos bytes.
    tmp = data;
    bufferPos = data.position() * type.size;
  } else {
    // Copy the data to a new direct buffer, then restore the state of the input.
    int origPosition = data.position();
    ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
    switch (type) {
      case FLOAT:
        tmp = buffer.asFloatBuffer().put((FloatBuffer) data);
        break;
      case DOUBLE:
        tmp = buffer.asDoubleBuffer().put((DoubleBuffer) data);
        break;
      case INT8:
        // buffer is already a ByteBuffer, no cast needed.
        tmp = buffer.put((ByteBuffer) data);
        break;
      case INT16:
        tmp = buffer.asShortBuffer().put((ShortBuffer) data);
        break;
      case INT32:
        tmp = buffer.asIntBuffer().put((IntBuffer) data);
        break;
      case INT64:
        tmp = buffer.asLongBuffer().put((LongBuffer) data);
        break;
      case BOOL:
      case STRING:
      case UNKNOWN:
      default:
        throw new IllegalStateException(
            "Impossible to reach here, managed to cast a buffer as an incorrect type");
    }
    // The bulk put advanced the caller's buffer; put its position back where it was.
    data.position(origPosition);
    // The copy starts at the beginning of the new buffer, so there is no offset.
    tmp.rewind();
    bufferPos = 0;
  }
  TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
  return new OnnxTensor(
      createTensorFromBuffer(
          OnnxRuntime.ortApiHandle,
          allocator.handle,
          tmp,
          bufferPos,
          bufferSize,
          shape,
          info.onnxType.value),
      allocator.handle,
      info,
      tmp);
}
private static native long createTensor(
long apiHandle, long allocatorHandle, Object data, long[] shape, int onnxType)
throws OrtException;
@ -814,6 +775,7 @@ public class OnnxTensor implements OnnxValue {
long apiHandle,
long allocatorHandle,
Buffer data,
int bufferPos,
long bufferSize,
long[] shape,
int onnxType)

View file

@ -251,16 +251,16 @@ public class TensorInfo implements ValueInfo {
long elementCount = OrtUtil.elementCount(shape);
long bufferCapacity = buffer.capacity();
long bufferRemaining = buffer.remaining();
if (elementCount != bufferCapacity) {
if (elementCount != bufferRemaining) {
throw new OrtException(
"Shape "
+ Arrays.toString(shape)
+ ", requires "
+ elementCount
+ " elements but the buffer has "
+ bufferCapacity
+ bufferRemaining
+ " elements.");
}

View file

@ -56,10 +56,10 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor
/*
* Class: ai_onnxruntime_OnnxTensor
* Method: createTensorFromBuffer
* Signature: (JJLjava/nio/Buffer;J[JI)J
* Signature: (JJLjava/nio/Buffer;IJ[JI)J
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensorFromBuffer
(JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject buffer, jlong bufferSize, jlongArray shape, jint onnxTypeJava) {
(JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject buffer, jint bufferPos, jlong bufferSize, jlongArray shape, jint onnxTypeJava) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
@ -70,7 +70,9 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensorFromBuffer
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava);
// Extract the buffer
void* bufferArr = (*jniEnv)->GetDirectBufferAddress(jniEnv,buffer);
char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv,buffer);
// Increment by bufferPos bytes
bufferArr = bufferArr + bufferPos;
// Extract the shape information
jboolean mkCopy;

View file

@ -26,11 +26,13 @@ import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
@ -789,18 +791,68 @@ public class InferenceTest {
float[] resultBufferArray = new float[flatInput.length];
((OnnxTensor) res.get(0)).getFloatBuffer().get(resultBufferArray);
assertArrayEquals(flatInput, resultBufferArray, 1e-6f);
OnnxValue.close(container);
}
container.clear();
OnnxValue.close(container);
}
}
// Now test loading from buffer
FloatBuffer buffer = FloatBuffer.wrap(flatInput);
OnnxTensor newTensor = OnnxTensor.createTensor(env, buffer, shape);
container.put(inputName, newTensor);
try (OrtSession.Result res = session.run(container)) {
resultArray = TestHelpers.flattenFloat(res.get(0).getValue());
assertArrayEquals(flatInput, resultArray, 1e-6f);
OnnxValue.close(container);
@Test
public void testModelInputBuffer() throws OrtException {
  // The model accepts a 1x5 input of a fixed type and echoes it back unchanged.
  String modelPath = getResourcePath("/test_types_FLOAT.pb").toString();
  try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputFLOAT");
      SessionOptions options = new SessionOptions();
      OrtSession session = env.createSession(modelPath, options)) {
    String inputName = session.getInputNames().iterator().next();
    long[] shape = new long[] {1, 5};
    Map<String, OnnxTensor> container = new HashMap<>();

    // Fifteen values, consumed below as three consecutive windows of five elements each.
    float[] inputArr = {
      1.0f, -2.0f, 3.0f, -4.0f, 5.0f,
      -6.0f, 7.0f, -8.0f, 9.0f, -10.0f,
      11.0f, -12.0f, 13.0f, -14.0f, 15.0f
    };

    // One heap-backed view of the array and one direct copy of it in native byte order.
    FloatBuffer buffer = FloatBuffer.wrap(inputArr);
    ByteBuffer directBytes =
        ByteBuffer.allocateDirect(inputArr.length * 4).order(ByteOrder.nativeOrder());
    FloatBuffer directBuffer = directBytes.asFloatBuffer().put(buffer);
    buffer.rewind();
    directBuffer.rewind();

    for (int i = 0; i < 3; i++) {
      int start = i * 5;
      int end = (i + 1) * 5;
      float[] expected = Arrays.copyOfRange(inputArr, start, end);

      // Narrow both buffers to the current window.
      buffer.position(start);
      buffer.limit(end);
      directBuffer.position(start);
      directBuffer.limit(end);

      // Heap buffer: the data is copied into a direct buffer by createTensor.
      OnnxTensor heapTensor = OnnxTensor.createTensor(env, buffer, shape);
      container.put(inputName, heapTensor);
      try (OrtSession.Result res = session.run(container)) {
        float[] heapResult = TestHelpers.flattenFloat(res.get(0).getValue());
        assertArrayEquals(expected, heapResult, 1e-6f);
        OnnxValue.close(container);
      }
      container.clear();
      // Creating the tensor must leave the caller's buffer position untouched.
      assertEquals(start, buffer.position());

      // Direct buffer: used in place, without a copy.
      OnnxTensor directTensor = OnnxTensor.createTensor(env, directBuffer, shape);
      container.put(inputName, directTensor);
      try (OrtSession.Result res = session.run(container)) {
        float[] directResult = TestHelpers.flattenFloat(res.get(0).getValue());
        assertArrayEquals(expected, directResult, 1e-6f);
        OnnxValue.close(container);
      }
      container.clear();
      // The direct buffer's position must be untouched as well.
      assertEquals(start, directBuffer.position());
    }
  }
}