mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
Add UINT8 datatype support to Java (#8401)
Add UINT8 datatype support Add inference test for UINT8 model
This commit is contained in:
parent
950fe5e28b
commit
002e427c5b
8 changed files with 46 additions and 1 deletions
|
|
@ -16,9 +16,10 @@ public enum OnnxJavaType {
|
|||
INT64(6, long.class, 8),
|
||||
BOOL(7, boolean.class, 1),
|
||||
STRING(8, String.class, 4),
|
||||
UINT8(9, byte.class, 1),
|
||||
UNKNOWN(0, Object.class, 0);
|
||||
|
||||
private static final OnnxJavaType[] values = new OnnxJavaType[9];
|
||||
private static final OnnxJavaType[] values = new OnnxJavaType[10];
|
||||
|
||||
static {
|
||||
for (OnnxJavaType ot : OnnxJavaType.values()) {
|
||||
|
|
@ -62,6 +63,7 @@ public enum OnnxJavaType {
|
|||
public static OnnxJavaType mapFromOnnxTensorType(OnnxTensorType onnxValue) {
|
||||
switch (onnxValue) {
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
|
||||
return OnnxJavaType.UINT8;
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
|
||||
return OnnxJavaType.INT8;
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
|
||||
|
|
|
|||
|
|
@ -77,6 +77,7 @@ public class OnnxMap implements OnnxValue {
|
|||
return OnnxMapValueType.LONG;
|
||||
case STRING:
|
||||
return OnnxMapValueType.STRING;
|
||||
case UINT8:
|
||||
case INT8:
|
||||
case INT16:
|
||||
case INT32:
|
||||
|
|
|
|||
|
|
@ -100,6 +100,7 @@ public class OnnxSequence implements OnnxValue {
|
|||
list.addAll(Arrays.asList(strings));
|
||||
return list;
|
||||
case BOOL:
|
||||
case UINT8:
|
||||
case INT8:
|
||||
case INT16:
|
||||
case INT32:
|
||||
|
|
|
|||
|
|
@ -78,6 +78,7 @@ public class OnnxTensor implements OnnxValue {
|
|||
return getFloat(OnnxRuntime.ortApiHandle, nativeHandle, info.onnxType.value);
|
||||
case DOUBLE:
|
||||
return getDouble(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
case UINT8:
|
||||
case INT8:
|
||||
return getByte(OnnxRuntime.ortApiHandle, nativeHandle, info.onnxType.value);
|
||||
case INT16:
|
||||
|
|
@ -744,6 +745,7 @@ public class OnnxTensor implements OnnxValue {
|
|||
case DOUBLE:
|
||||
tmp = buffer.asDoubleBuffer().put((DoubleBuffer) data);
|
||||
break;
|
||||
case UINT8:
|
||||
case INT8:
|
||||
// buffer is already a ByteBuffer, no cast needed.
|
||||
tmp = buffer.put((ByteBuffer) data);
|
||||
|
|
|
|||
|
|
@ -434,6 +434,7 @@ public final class OrtUtil {
|
|||
double[] doubleArr = new double[1];
|
||||
doubleArr[0] = (Double) data;
|
||||
return doubleArr;
|
||||
case UINT8:
|
||||
case INT8:
|
||||
byte[] byteArr = new byte[1];
|
||||
byteArr[0] = (Byte) data;
|
||||
|
|
|
|||
|
|
@ -80,6 +80,8 @@ public class TensorInfo implements ValueInfo {
|
|||
return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
|
||||
case INT8:
|
||||
return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
|
||||
case UINT8:
|
||||
return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
||||
case INT16:
|
||||
return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
|
||||
case INT32:
|
||||
|
|
@ -179,6 +181,7 @@ public class TensorInfo implements ValueInfo {
|
|||
return OrtUtil.newFloatArray(shape);
|
||||
case DOUBLE:
|
||||
return OrtUtil.newDoubleArray(shape);
|
||||
case UINT8:
|
||||
case INT8:
|
||||
return OrtUtil.newByteArray(shape);
|
||||
case INT16:
|
||||
|
|
|
|||
|
|
@ -1247,6 +1247,28 @@ public class InferenceTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testModelInputUINT8() throws OrtException {
|
||||
String modelPath = getResourcePath("/test_types_UINT8.pb").toString();
|
||||
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputUINT8");
|
||||
SessionOptions options = new SessionOptions();
|
||||
OrtSession session = env.createSession(modelPath, options)) {
|
||||
String inputName = session.getInputNames().iterator().next();
|
||||
Map<String, OnnxTensor> container = new HashMap<>();
|
||||
byte[] flatInput = new byte[] {1, 2, -3, Byte.MIN_VALUE, Byte.MAX_VALUE};
|
||||
ByteBuffer data = ByteBuffer.wrap(flatInput);
|
||||
long[] shape = new long[] {1, 5};
|
||||
OnnxTensor ov = OnnxTensor.createTensor(env, data, shape, OnnxJavaType.UINT8);
|
||||
container.put(inputName, ov);
|
||||
try (OrtSession.Result res = session.run(container)) {
|
||||
byte[] resultArray = TestHelpers.flattenByte(res.get(0).getValue());
|
||||
assertArrayEquals(flatInput, resultArray);
|
||||
}
|
||||
OnnxValue.close(container);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testModelInputINT16() throws OrtException {
|
||||
// model takes 1x5 input of fixed type, echoes back
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
*/
|
||||
package ai.onnxruntime;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
|
@ -112,4 +113,16 @@ public class TensorCreationTest {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUint8Creation() throws OrtException {
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
|
||||
byte[] buf = new byte[] {0, 1};
|
||||
ByteBuffer data = ByteBuffer.wrap(buf);
|
||||
long[] shape = new long[] {2};
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, data, shape, OnnxJavaType.UINT8)) {
|
||||
Assertions.assertArrayEquals(buf, (byte[]) t.getValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue