diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index 7e3e27bfa5..752652bf22 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -439,27 +439,7 @@ public class OnnxTensor implements OnnxValue { throws OrtException { if ((!env.isClosed()) && (!allocator.isClosed())) { OnnxJavaType type = OnnxJavaType.FLOAT; - int bufferSize = data.capacity() * type.size; - FloatBuffer tmp; - if (data.isDirect()) { - tmp = data; - } else { - ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder()); - tmp = buffer.asFloatBuffer(); - tmp.put(data); - } - TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type); - return new OnnxTensor( - createTensorFromBuffer( - OnnxRuntime.ortApiHandle, - allocator.handle, - tmp, - bufferSize, - shape, - info.onnxType.value), - allocator.handle, - info, - tmp); + return createTensor(type, allocator, data, shape); } else { throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator."); } @@ -500,27 +480,7 @@ public class OnnxTensor implements OnnxValue { throws OrtException { if ((!env.isClosed()) && (!allocator.isClosed())) { OnnxJavaType type = OnnxJavaType.DOUBLE; - int bufferSize = data.capacity() * type.size; - DoubleBuffer tmp; - if (data.isDirect()) { - tmp = data; - } else { - ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder()); - tmp = buffer.asDoubleBuffer(); - tmp.put(data); - } - TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type); - return new OnnxTensor( - createTensorFromBuffer( - OnnxRuntime.ortApiHandle, - allocator.handle, - tmp, - bufferSize, - shape, - info.onnxType.value), - allocator.handle, - info, - tmp); + return createTensor(type, allocator, data, shape); } else { throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator."); } @@ -599,25 +559,7 @@ public class OnnxTensor implements OnnxValue { throws OrtException { if ((!env.isClosed()) && (!allocator.isClosed())) { int bufferSize = data.capacity(); - ByteBuffer tmp; - if (data.isDirect()) { - tmp = data; - } else { - tmp = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder()); - tmp.put(data); - } - TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type); - return new OnnxTensor( - createTensorFromBuffer( - OnnxRuntime.ortApiHandle, - allocator.handle, - tmp, - bufferSize, - shape, - info.onnxType.value), - allocator.handle, - info, - tmp); + return createTensor(type, allocator, data, shape); } else { throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator."); } @@ -658,27 +600,7 @@ public class OnnxTensor implements OnnxValue { throws OrtException { if ((!env.isClosed()) && (!allocator.isClosed())) { OnnxJavaType type = OnnxJavaType.INT16; - int bufferSize = data.capacity() * type.size; - ShortBuffer tmp; - if (data.isDirect()) { - tmp = data; - } else { - ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder()); - tmp = buffer.asShortBuffer(); - tmp.put(data); - } - TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type); - return new OnnxTensor( - createTensorFromBuffer( - OnnxRuntime.ortApiHandle, - allocator.handle, - tmp, - bufferSize, - shape, - info.onnxType.value), - allocator.handle, - info, - tmp); + return createTensor(type, allocator, data, shape); } else { throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator."); } @@ -719,27 +641,7 @@ public class OnnxTensor implements OnnxValue { throws OrtException { if ((!env.isClosed()) && (!allocator.isClosed())) { OnnxJavaType type = OnnxJavaType.INT32; - int bufferSize = data.capacity() * type.size; - IntBuffer tmp; - if (data.isDirect()) { - tmp = data; - } else { - ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder()); - tmp = buffer.asIntBuffer(); - tmp.put(data); - } - TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type); - return new OnnxTensor( - createTensorFromBuffer( - OnnxRuntime.ortApiHandle, - allocator.handle, - tmp, - bufferSize, - shape, - info.onnxType.value), - allocator.handle, - info, - tmp); + return createTensor(type, allocator, data, shape); } else { throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator."); } @@ -780,32 +682,91 @@ public class OnnxTensor implements OnnxValue { throws OrtException { if ((!env.isClosed()) && (!allocator.isClosed())) { OnnxJavaType type = OnnxJavaType.INT64; - int bufferSize = data.capacity() * type.size; - LongBuffer tmp; - if (data.isDirect()) { - tmp = data; - } else { - ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder()); - tmp = buffer.asLongBuffer(); - tmp.put(data); - } - TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type); - return new OnnxTensor( - createTensorFromBuffer( - OnnxRuntime.ortApiHandle, - allocator.handle, - tmp, - bufferSize, - shape, - info.onnxType.value), - allocator.handle, - info, - tmp); + return createTensor(type, allocator, data, shape); } else { throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator."); } } + /** + * Creates a tensor by wrapping the data in a direct byte buffer and passing it to JNI. + * + *
Throws IllegalStateException if the buffer is too large to create a direct byte buffer copy,
+ * which is more than approximately (Integer.MAX_VALUE - 5) / type.size elements.
+ *
+ * @param type The buffer type.
+ * @param allocator The OrtAllocator.
+ * @param data The data.
+ * @param shape The tensor shape.
+ * @return An OnnxTensor instance.
+ * @throws OrtException If the create call failed.
+ */
+ private static OnnxTensor createTensor(
+ OnnxJavaType type, OrtAllocator allocator, Buffer data, long[] shape) throws OrtException {
+ int bufferPos;
+ int bufferSize = data.remaining() * type.size;
+ if ((bufferSize < 0) || (bufferSize > ((Integer.MAX_VALUE / type.size) - 10))) {
+ // overflowed as we can't make a direct byte buffer of that size
+ throw new IllegalStateException(
+ "Cannot allocate a direct buffer of the requested size and type, size "
+ + data.remaining()
+ + ", type = "
+ + type);
+ }
+ Buffer tmp;
+ if (data.isDirect()) {
+ tmp = data;
+ bufferPos = data.position() * type.size;
+ } else {
+ // Copy the data to a new direct buffer, then restore the state of the input.
+ int origPosition = data.position();
+ ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
+ switch (type) {
+ case FLOAT:
+ tmp = buffer.asFloatBuffer().put((FloatBuffer) data);
+ break;
+ case DOUBLE:
+ tmp = buffer.asDoubleBuffer().put((DoubleBuffer) data);
+ break;
+ case INT8:
+ // buffer is already a ByteBuffer, no cast needed.
+ tmp = buffer.put((ByteBuffer) data);
+ break;
+ case INT16:
+ tmp = buffer.asShortBuffer().put((ShortBuffer) data);
+ break;
+ case INT32:
+ tmp = buffer.asIntBuffer().put((IntBuffer) data);
+ break;
+ case INT64:
+ tmp = buffer.asLongBuffer().put((LongBuffer) data);
+ break;
+ case BOOL:
+ case STRING:
+ case UNKNOWN:
+ default:
+ throw new IllegalStateException(
+ "Impossible to reach here, managed to cast a buffer as an incorrect type");
+ }
+ data.position(origPosition);
+ tmp.rewind();
+ bufferPos = 0;
+ }
+ TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
+ return new OnnxTensor(
+ createTensorFromBuffer(
+ OnnxRuntime.ortApiHandle,
+ allocator.handle,
+ tmp,
+ bufferPos,
+ bufferSize,
+ shape,
+ info.onnxType.value),
+ allocator.handle,
+ info,
+ tmp);
+ }
+
private static native long createTensor(
long apiHandle, long allocatorHandle, Object data, long[] shape, int onnxType)
throws OrtException;
@@ -814,6 +775,7 @@ public class OnnxTensor implements OnnxValue {
long apiHandle,
long allocatorHandle,
Buffer data,
+ int bufferPos,
long bufferSize,
long[] shape,
int onnxType)
diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java
index 201305319f..8f53a3a732 100644
--- a/java/src/main/java/ai/onnxruntime/TensorInfo.java
+++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java
@@ -251,16 +251,16 @@ public class TensorInfo implements ValueInfo {
long elementCount = OrtUtil.elementCount(shape);
- long bufferCapacity = buffer.capacity();
+ long bufferRemaining = buffer.remaining();
- if (elementCount != bufferCapacity) {
+ if (elementCount != bufferRemaining) {
throw new OrtException(
"Shape "
+ Arrays.toString(shape)
+ ", requires "
+ elementCount
+ " elements but the buffer has "
- + bufferCapacity
+ + bufferRemaining
+ " elements.");
}
diff --git a/java/src/main/native/ai_onnxruntime_OnnxTensor.c b/java/src/main/native/ai_onnxruntime_OnnxTensor.c
index b24e7a62b2..f1b12dc905 100644
--- a/java/src/main/native/ai_onnxruntime_OnnxTensor.c
+++ b/java/src/main/native/ai_onnxruntime_OnnxTensor.c
@@ -56,10 +56,10 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor
/*
* Class: ai_onnxruntime_OnnxTensor
* Method: createTensorFromBuffer
- * Signature: (JJLjava/nio/Buffer;J[JI)J
+ * Signature: (JJLjava/nio/Buffer;IJ[JI)J
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensorFromBuffer
- (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject buffer, jlong bufferSize, jlongArray shape, jint onnxTypeJava) {
+ (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject buffer, jint bufferPos, jlong bufferSize, jlongArray shape, jint onnxTypeJava) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
@@ -70,7 +70,9 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensorFromBuffer
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava);
// Extract the buffer
- void* bufferArr = (*jniEnv)->GetDirectBufferAddress(jniEnv,buffer);
+ char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv,buffer);
+ // Increment by bufferPos bytes
+ bufferArr = bufferArr + bufferPos;
// Extract the shape information
jboolean mkCopy;
diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java
index 4c29d65cf2..a2315bf7de 100644
--- a/java/src/test/java/ai/onnxruntime/InferenceTest.java
+++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java
@@ -26,11 +26,13 @@ import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
@@ -789,18 +791,68 @@ public class InferenceTest {
float[] resultBufferArray = new float[flatInput.length];
((OnnxTensor) res.get(0)).getFloatBuffer().get(resultBufferArray);
assertArrayEquals(flatInput, resultBufferArray, 1e-6f);
- OnnxValue.close(container);
}
- container.clear();
+ OnnxValue.close(container);
+ }
+ }
- // Now test loading from buffer
- FloatBuffer buffer = FloatBuffer.wrap(flatInput);
- OnnxTensor newTensor = OnnxTensor.createTensor(env, buffer, shape);
- container.put(inputName, newTensor);
- try (OrtSession.Result res = session.run(container)) {
- resultArray = TestHelpers.flattenFloat(res.get(0).getValue());
- assertArrayEquals(flatInput, resultArray, 1e-6f);
- OnnxValue.close(container);
+ @Test
+ public void testModelInputBuffer() throws OrtException {
+ // model takes 1x5 input of fixed type, echoes back
+ String modelPath = getResourcePath("/test_types_FLOAT.pb").toString();
+
+ try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputFLOAT");
+ SessionOptions options = new SessionOptions();
+ OrtSession session = env.createSession(modelPath, options)) {
+ String inputName = session.getInputNames().iterator().next();
+ long[] shape = new long[] {1, 5};
+ Map