diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java
index 7e3e27bfa5..752652bf22 100644
--- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java
+++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java
@@ -439,27 +439,7 @@ public class OnnxTensor implements OnnxValue {
       throws OrtException {
     if ((!env.isClosed()) && (!allocator.isClosed())) {
       OnnxJavaType type = OnnxJavaType.FLOAT;
-      int bufferSize = data.capacity() * type.size;
-      FloatBuffer tmp;
-      if (data.isDirect()) {
-        tmp = data;
-      } else {
-        ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
-        tmp = buffer.asFloatBuffer();
-        tmp.put(data);
-      }
-      TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
-      return new OnnxTensor(
-          createTensorFromBuffer(
-              OnnxRuntime.ortApiHandle,
-              allocator.handle,
-              tmp,
-              bufferSize,
-              shape,
-              info.onnxType.value),
-          allocator.handle,
-          info,
-          tmp);
+      return createTensor(type, allocator, data, shape);
     } else {
       throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
     }
@@ -500,27 +480,7 @@ public class OnnxTensor implements OnnxValue {
       throws OrtException {
     if ((!env.isClosed()) && (!allocator.isClosed())) {
       OnnxJavaType type = OnnxJavaType.DOUBLE;
-      int bufferSize = data.capacity() * type.size;
-      DoubleBuffer tmp;
-      if (data.isDirect()) {
-        tmp = data;
-      } else {
-        ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
-        tmp = buffer.asDoubleBuffer();
-        tmp.put(data);
-      }
-      TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
-      return new OnnxTensor(
-          createTensorFromBuffer(
-              OnnxRuntime.ortApiHandle,
-              allocator.handle,
-              tmp,
-              bufferSize,
-              shape,
-              info.onnxType.value),
-          allocator.handle,
-          info,
-          tmp);
+      return createTensor(type, allocator, data, shape);
     } else {
       throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
     }
@@ -599,25 +559,7 @@ public class OnnxTensor implements OnnxValue {
       throws OrtException {
     if ((!env.isClosed()) && (!allocator.isClosed())) {
       int bufferSize = data.capacity();
-      ByteBuffer tmp;
-      if (data.isDirect()) {
-        tmp = data;
-      } else {
-        tmp = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
-        tmp.put(data);
-      }
-      TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
-      return new OnnxTensor(
-          createTensorFromBuffer(
-              OnnxRuntime.ortApiHandle,
-              allocator.handle,
-              tmp,
-              bufferSize,
-              shape,
-              info.onnxType.value),
-          allocator.handle,
-          info,
-          tmp);
+      return createTensor(type, allocator, data, shape);
     } else {
       throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
     }
@@ -658,27 +600,7 @@ public class OnnxTensor implements OnnxValue {
       throws OrtException {
     if ((!env.isClosed()) && (!allocator.isClosed())) {
       OnnxJavaType type = OnnxJavaType.INT16;
-      int bufferSize = data.capacity() * type.size;
-      ShortBuffer tmp;
-      if (data.isDirect()) {
-        tmp = data;
-      } else {
-        ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
-        tmp = buffer.asShortBuffer();
-        tmp.put(data);
-      }
-      TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
-      return new OnnxTensor(
-          createTensorFromBuffer(
-              OnnxRuntime.ortApiHandle,
-              allocator.handle,
-              tmp,
-              bufferSize,
-              shape,
-              info.onnxType.value),
-          allocator.handle,
-          info,
-          tmp);
+      return createTensor(type, allocator, data, shape);
     } else {
       throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
     }
@@ -719,27 +641,7 @@ public class OnnxTensor implements OnnxValue {
       throws OrtException {
     if ((!env.isClosed()) && (!allocator.isClosed())) {
       OnnxJavaType type = OnnxJavaType.INT32;
-      int bufferSize = data.capacity() * type.size;
-      IntBuffer tmp;
-      if (data.isDirect()) {
-        tmp = data;
-      } else {
-        ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
-        tmp = buffer.asIntBuffer();
-        tmp.put(data);
-      }
-      TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
-      return new OnnxTensor(
-          createTensorFromBuffer(
-              OnnxRuntime.ortApiHandle,
-              allocator.handle,
-              tmp,
-              bufferSize,
-              shape,
-              info.onnxType.value),
-          allocator.handle,
-          info,
-          tmp);
+      return createTensor(type, allocator, data, shape);
     } else {
       throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
     }
@@ -780,32 +682,91 @@ public class OnnxTensor implements OnnxValue {
       throws OrtException {
     if ((!env.isClosed()) && (!allocator.isClosed())) {
       OnnxJavaType type = OnnxJavaType.INT64;
-      int bufferSize = data.capacity() * type.size;
-      LongBuffer tmp;
-      if (data.isDirect()) {
-        tmp = data;
-      } else {
-        ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
-        tmp = buffer.asLongBuffer();
-        tmp.put(data);
-      }
-      TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
-      return new OnnxTensor(
-          createTensorFromBuffer(
-              OnnxRuntime.ortApiHandle,
-              allocator.handle,
-              tmp,
-              bufferSize,
-              shape,
-              info.onnxType.value),
-          allocator.handle,
-          info,
-          tmp);
+      return createTensor(type, allocator, data, shape);
     } else {
       throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
     }
   }
 
+  /**
+   * Creates a tensor by wrapping the data in a direct byte buffer and passing it to JNI.
+   *
+   * <p>Throws IllegalStateException if the buffer is too large to create a direct byte buffer copy,
+   * which is more than (Integer.MAX_VALUE / type.size) - 10 elements.
+   *
+   * @param type The buffer type.
+   * @param allocator The OrtAllocator.
+   * @param data The data.
+   * @param shape The tensor shape.
+   * @return An OnnxTensor instance.
+   * @throws OrtException If the create call failed.
+   */
+  private static OnnxTensor createTensor(
+      OnnxJavaType type, OrtAllocator allocator, Buffer data, long[] shape) throws OrtException {
+    int bufferPos;
+    int bufferSize = data.remaining() * type.size;
+    if (data.remaining() > ((Integer.MAX_VALUE / type.size) - 10)) {
+      // compare element counts, as the int byte size above may already have wrapped around
+      throw new IllegalStateException(
+          "Cannot allocate a direct buffer of the requested size and type, size "
+              + data.remaining()
+              + ", type = "
+              + type);
+    }
+    Buffer tmp;
+    if (data.isDirect()) {
+      tmp = data;
+      bufferPos = data.position() * type.size;
+    } else {
+      // Copy the data to a new direct buffer, then restore the state of the input.
+      int origPosition = data.position();
+      ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
+      switch (type) {
+        case FLOAT:
+          tmp = buffer.asFloatBuffer().put((FloatBuffer) data);
+          break;
+        case DOUBLE:
+          tmp = buffer.asDoubleBuffer().put((DoubleBuffer) data);
+          break;
+        case INT8:
+          // buffer is already a ByteBuffer, no cast needed.
+          tmp = buffer.put((ByteBuffer) data);
+          break;
+        case INT16:
+          tmp = buffer.asShortBuffer().put((ShortBuffer) data);
+          break;
+        case INT32:
+          tmp = buffer.asIntBuffer().put((IntBuffer) data);
+          break;
+        case INT64:
+          tmp = buffer.asLongBuffer().put((LongBuffer) data);
+          break;
+        case BOOL:
+        case STRING:
+        case UNKNOWN:
+        default:
+          throw new IllegalStateException(
+              "Impossible to reach here, managed to cast a buffer as an incorrect type");
+      }
+      data.position(origPosition);
+      tmp.rewind();
+      bufferPos = 0;
+    }
+    TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
+    return new OnnxTensor(
+        createTensorFromBuffer(
+            OnnxRuntime.ortApiHandle,
+            allocator.handle,
+            tmp,
+            bufferPos,
+            bufferSize,
+            shape,
+            info.onnxType.value),
+        allocator.handle,
+        info,
+        tmp);
+  }
+
   private static native long createTensor(
       long apiHandle, long allocatorHandle, Object data, long[] shape, int onnxType)
       throws OrtException;
@@ -814,6 +775,7 @@ public class OnnxTensor implements OnnxValue {
       long apiHandle,
       long allocatorHandle,
       Buffer data,
+      int bufferPos,
       long bufferSize,
       long[] shape,
       int onnxType)
diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java
index 201305319f..8f53a3a732 100644
--- a/java/src/main/java/ai/onnxruntime/TensorInfo.java
+++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java
@@ -251,16 +251,16 @@ public class TensorInfo implements ValueInfo {
 
     long elementCount = OrtUtil.elementCount(shape);
 
-    long bufferCapacity = buffer.capacity();
+    long bufferRemaining = buffer.remaining();
 
-    if (elementCount != bufferCapacity) {
+    if (elementCount != bufferRemaining) {
       throw new OrtException(
           "Shape "
               + Arrays.toString(shape)
               + ", requires "
               + elementCount
               + " elements but the buffer has "
-              + bufferCapacity
+              + bufferRemaining
               + " elements.");
     }
 
diff --git a/java/src/main/native/ai_onnxruntime_OnnxTensor.c b/java/src/main/native/ai_onnxruntime_OnnxTensor.c
index b24e7a62b2..f1b12dc905 100644
--- a/java/src/main/native/ai_onnxruntime_OnnxTensor.c
+++ b/java/src/main/native/ai_onnxruntime_OnnxTensor.c
@@ -56,10 +56,10 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor
 /*
  * Class:     ai_onnxruntime_OnnxTensor
  * Method:    createTensorFromBuffer
- * Signature: (JJLjava/nio/Buffer;J[JI)J
+ * Signature: (JJLjava/nio/Buffer;IJ[JI)J
  */
 JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensorFromBuffer
-  (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject buffer, jlong bufferSize, jlongArray shape, jint onnxTypeJava) {
+  (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject buffer, jint bufferPos, jlong bufferSize, jlongArray shape, jint onnxTypeJava) {
     (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
     const OrtApi* api = (const OrtApi*) apiHandle;
     OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
@@ -70,7 +70,9 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensorFromBuffer
     ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava);
 
     // Extract the buffer
-    void* bufferArr = (*jniEnv)->GetDirectBufferAddress(jniEnv,buffer);
+    char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv,buffer);
+    // Increment by bufferPos bytes
+    bufferArr = bufferArr + bufferPos;
 
     // Extract the shape information
     jboolean mkCopy;
diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java
index 4c29d65cf2..a2315bf7de 100644
--- a/java/src/test/java/ai/onnxruntime/InferenceTest.java
+++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java
@@ -26,11 +26,13 @@ import java.io.IOException;
 import java.io.InputStream;
 import java.io.UncheckedIOException;
 import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
 import java.nio.FloatBuffer;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
@@ -789,18 +791,68 @@ public class InferenceTest {
         float[] resultBufferArray = new float[flatInput.length];
         ((OnnxTensor) res.get(0)).getFloatBuffer().get(resultBufferArray);
         assertArrayEquals(flatInput, resultBufferArray, 1e-6f);
-        OnnxValue.close(container);
       }
-      container.clear();
+      OnnxValue.close(container);
+    }
+  }
 
-      // Now test loading from buffer
-      FloatBuffer buffer = FloatBuffer.wrap(flatInput);
-      OnnxTensor newTensor = OnnxTensor.createTensor(env, buffer, shape);
-      container.put(inputName, newTensor);
-      try (OrtSession.Result res = session.run(container)) {
-        resultArray = TestHelpers.flattenFloat(res.get(0).getValue());
-        assertArrayEquals(flatInput, resultArray, 1e-6f);
-        OnnxValue.close(container);
+  @Test
+  public void testModelInputBuffer() throws OrtException {
+    // model takes 1x5 input of fixed type, echoes back
+    String modelPath = getResourcePath("/test_types_FLOAT.pb").toString();
+
+    try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputFLOAT");
+        SessionOptions options = new SessionOptions();
+        OrtSession session = env.createSession(modelPath, options)) {
+      String inputName = session.getInputNames().iterator().next();
+      long[] shape = new long[] {1, 5};
+      Map<String, OnnxTensor> container = new HashMap<>();
+      float[] inputArr =
+          new float[] {
+            1.0f, -2.0f, 3.0f, -4.0f, 5.0f, -6.0f, 7.0f, -8.0f, 9.0f, -10.0f, 11.0f, -12.0f, 13.0f,
+            -14.0f, 15
+          };
+      FloatBuffer buffer = FloatBuffer.wrap(inputArr);
+      FloatBuffer directBuffer =
+          ByteBuffer.allocateDirect(inputArr.length * 4)
+              .order(ByteOrder.nativeOrder())
+              .asFloatBuffer()
+              .put(buffer);
+      buffer.rewind();
+      directBuffer.rewind();
+      float[] resultArray;
+
+      // Test loading from buffer
+      for (int i = 0; i < 3; i++) {
+        // Set limits
+        buffer.position(i * 5);
+        buffer.limit((i + 1) * 5);
+        directBuffer.position(i * 5);
+        directBuffer.limit((i + 1) * 5);
+
+        // Check regular buffer (copies to direct)
+        OnnxTensor newTensor = OnnxTensor.createTensor(env, buffer, shape);
+        container.put(inputName, newTensor);
+        try (OrtSession.Result res = session.run(container)) {
+          resultArray = TestHelpers.flattenFloat(res.get(0).getValue());
+          assertArrayEquals(Arrays.copyOfRange(inputArr, i * 5, (i + 1) * 5), resultArray, 1e-6f);
+          OnnxValue.close(container);
+        }
+        container.clear();
+        // buffer should be unchanged
+        assertEquals(i * 5, buffer.position());
+
+        // Check direct buffer (no-copy)
+        newTensor = OnnxTensor.createTensor(env, directBuffer, shape);
+        container.put(inputName, newTensor);
+        try (OrtSession.Result res = session.run(container)) {
+          resultArray = TestHelpers.flattenFloat(res.get(0).getValue());
+          assertArrayEquals(Arrays.copyOfRange(inputArr, i * 5, (i + 1) * 5), resultArray, 1e-6f);
+          OnnxValue.close(container);
+        }
+        container.clear();
+        // direct buffer should be unchanged
+        assertEquals(i * 5, directBuffer.position());
       }
     }
   }