diff --git a/java/src/main/java/ai/onnxruntime/OnnxJavaType.java b/java/src/main/java/ai/onnxruntime/OnnxJavaType.java index b3d8bb8e0c..12b720327f 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxJavaType.java +++ b/java/src/main/java/ai/onnxruntime/OnnxJavaType.java @@ -16,9 +16,10 @@ public enum OnnxJavaType { INT64(6, long.class, 8), BOOL(7, boolean.class, 1), STRING(8, String.class, 4), + UINT8(9, byte.class, 1), UNKNOWN(0, Object.class, 0); - private static final OnnxJavaType[] values = new OnnxJavaType[9]; + private static final OnnxJavaType[] values = new OnnxJavaType[10]; static { for (OnnxJavaType ot : OnnxJavaType.values()) { @@ -62,6 +63,7 @@ public enum OnnxJavaType { public static OnnxJavaType mapFromOnnxTensorType(OnnxTensorType onnxValue) { switch (onnxValue) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return OnnxJavaType.UINT8; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return OnnxJavaType.INT8; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: diff --git a/java/src/main/java/ai/onnxruntime/OnnxMap.java b/java/src/main/java/ai/onnxruntime/OnnxMap.java index d5286675df..9d60efb961 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxMap.java +++ b/java/src/main/java/ai/onnxruntime/OnnxMap.java @@ -77,6 +77,7 @@ public class OnnxMap implements OnnxValue { return OnnxMapValueType.LONG; case STRING: return OnnxMapValueType.STRING; + case UINT8: case INT8: case INT16: case INT32: diff --git a/java/src/main/java/ai/onnxruntime/OnnxSequence.java b/java/src/main/java/ai/onnxruntime/OnnxSequence.java index 69d56e580e..f5480d9e04 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxSequence.java +++ b/java/src/main/java/ai/onnxruntime/OnnxSequence.java @@ -100,6 +100,7 @@ public class OnnxSequence implements OnnxValue { list.addAll(Arrays.asList(strings)); return list; case BOOL: + case UINT8: case INT8: case INT16: case INT32: diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index 29b19ffef5..fa199c9389 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -78,6 +78,7 @@ public class OnnxTensor implements OnnxValue { return getFloat(OnnxRuntime.ortApiHandle, nativeHandle, info.onnxType.value); case DOUBLE: return getDouble(OnnxRuntime.ortApiHandle, nativeHandle); + case UINT8: case INT8: return getByte(OnnxRuntime.ortApiHandle, nativeHandle, info.onnxType.value); case INT16: @@ -744,6 +745,7 @@ public class OnnxTensor implements OnnxValue { case DOUBLE: tmp = buffer.asDoubleBuffer().put((DoubleBuffer) data); break; + case UINT8: case INT8: // buffer is already a ByteBuffer, no cast needed. tmp = buffer.put((ByteBuffer) data); diff --git a/java/src/main/java/ai/onnxruntime/OrtUtil.java b/java/src/main/java/ai/onnxruntime/OrtUtil.java index cddc34f65f..08c095ea7c 100644 --- a/java/src/main/java/ai/onnxruntime/OrtUtil.java +++ b/java/src/main/java/ai/onnxruntime/OrtUtil.java @@ -434,6 +434,7 @@ public final class OrtUtil { double[] doubleArr = new double[1]; doubleArr[0] = (Double) data; return doubleArr; + case UINT8: case INT8: byte[] byteArr = new byte[1]; byteArr[0] = (Byte) data; diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java index 550ddfe57f..ead90d83bc 100644 --- a/java/src/main/java/ai/onnxruntime/TensorInfo.java +++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java @@ -80,6 +80,8 @@ public class TensorInfo implements ValueInfo { return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; case INT8: return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; + case UINT8: + return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; case INT16: return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; case INT32: @@ -179,6 +181,7 @@ public class TensorInfo implements ValueInfo { return OrtUtil.newFloatArray(shape); case DOUBLE: return OrtUtil.newDoubleArray(shape); + case UINT8: case INT8: return OrtUtil.newByteArray(shape); case INT16: diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 38627d2ec7..a6a09111d1 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -1247,6 +1247,28 @@ public class InferenceTest { } } + @Test + public void testModelInputUINT8() throws OrtException { + String modelPath = getResourcePath("/test_types_UINT8.pb").toString(); + + try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputUINT8"); + SessionOptions options = new SessionOptions(); + OrtSession session = env.createSession(modelPath, options)) { + String inputName = session.getInputNames().iterator().next(); + Map container = new HashMap<>(); + byte[] flatInput = new byte[] {1, 2, -3, Byte.MIN_VALUE, Byte.MAX_VALUE}; + ByteBuffer data = ByteBuffer.wrap(flatInput); + long[] shape = new long[] {1, 5}; + OnnxTensor ov = OnnxTensor.createTensor(env, data, shape, OnnxJavaType.UINT8); + container.put(inputName, ov); + try (OrtSession.Result res = session.run(container)) { + byte[] resultArray = TestHelpers.flattenByte(res.get(0).getValue()); + assertArrayEquals(flatInput, resultArray); + } + OnnxValue.close(container); + } + } + @Test public void testModelInputINT16() throws OrtException { // model takes 1x5 input of fixed type, echoes back diff --git a/java/src/test/java/ai/onnxruntime/TensorCreationTest.java b/java/src/test/java/ai/onnxruntime/TensorCreationTest.java index 8ebba91b31..1b5eafd354 100644 --- a/java/src/test/java/ai/onnxruntime/TensorCreationTest.java +++ b/java/src/test/java/ai/onnxruntime/TensorCreationTest.java @@ -4,6 +4,7 @@ */ package ai.onnxruntime; +import java.nio.ByteBuffer; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -112,4 +113,16 @@ public class TensorCreationTest { } } } + + @Test + public void testUint8Creation() throws OrtException { + try (OrtEnvironment env = OrtEnvironment.getEnvironment()) { + byte[] buf = new byte[] {0, 1}; + ByteBuffer data = ByteBuffer.wrap(buf); + long[] shape = new long[] {2}; + try (OnnxTensor t = OnnxTensor.createTensor(env, data, shape, OnnxJavaType.UINT8)) { + Assertions.assertArrayEquals(buf, (byte[]) t.getValue()); + } + } + } }