Add UINT8 datatype support to Java (#8401)

Add UINT8 datatype support
Add inference test for UINT8 model
This commit is contained in:
Frank Liu 2021-07-22 17:11:49 -07:00 committed by GitHub
parent 950fe5e28b
commit 002e427c5b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 46 additions and 1 deletions

View file

@ -16,9 +16,10 @@ public enum OnnxJavaType {
INT64(6, long.class, 8),
BOOL(7, boolean.class, 1),
STRING(8, String.class, 4),
UINT8(9, byte.class, 1),
UNKNOWN(0, Object.class, 0);
private static final OnnxJavaType[] values = new OnnxJavaType[9];
private static final OnnxJavaType[] values = new OnnxJavaType[10];
static {
for (OnnxJavaType ot : OnnxJavaType.values()) {
@ -62,6 +63,7 @@ public enum OnnxJavaType {
public static OnnxJavaType mapFromOnnxTensorType(OnnxTensorType onnxValue) {
switch (onnxValue) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
return OnnxJavaType.UINT8;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
return OnnxJavaType.INT8;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:

View file

@ -77,6 +77,7 @@ public class OnnxMap implements OnnxValue {
return OnnxMapValueType.LONG;
case STRING:
return OnnxMapValueType.STRING;
case UINT8:
case INT8:
case INT16:
case INT32:

View file

@ -100,6 +100,7 @@ public class OnnxSequence implements OnnxValue {
list.addAll(Arrays.asList(strings));
return list;
case BOOL:
case UINT8:
case INT8:
case INT16:
case INT32:

View file

@ -78,6 +78,7 @@ public class OnnxTensor implements OnnxValue {
return getFloat(OnnxRuntime.ortApiHandle, nativeHandle, info.onnxType.value);
case DOUBLE:
return getDouble(OnnxRuntime.ortApiHandle, nativeHandle);
case UINT8:
case INT8:
return getByte(OnnxRuntime.ortApiHandle, nativeHandle, info.onnxType.value);
case INT16:
@ -744,6 +745,7 @@ public class OnnxTensor implements OnnxValue {
case DOUBLE:
tmp = buffer.asDoubleBuffer().put((DoubleBuffer) data);
break;
case UINT8:
case INT8:
// buffer is already a ByteBuffer, no cast needed.
tmp = buffer.put((ByteBuffer) data);

View file

@ -434,6 +434,7 @@ public final class OrtUtil {
double[] doubleArr = new double[1];
doubleArr[0] = (Double) data;
return doubleArr;
case UINT8:
case INT8:
byte[] byteArr = new byte[1];
byteArr[0] = (Byte) data;

View file

@ -80,6 +80,8 @@ public class TensorInfo implements ValueInfo {
return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
case INT8:
return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
case UINT8:
return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
case INT16:
return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
case INT32:
@ -179,6 +181,7 @@ public class TensorInfo implements ValueInfo {
return OrtUtil.newFloatArray(shape);
case DOUBLE:
return OrtUtil.newDoubleArray(shape);
case UINT8:
case INT8:
return OrtUtil.newByteArray(shape);
case INT16:

View file

@ -1247,6 +1247,28 @@ public class InferenceTest {
}
}
@Test
public void testModelInputUINT8() throws OrtException {
String modelPath = getResourcePath("/test_types_UINT8.pb").toString();
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputUINT8");
SessionOptions options = new SessionOptions();
OrtSession session = env.createSession(modelPath, options)) {
String inputName = session.getInputNames().iterator().next();
Map<String, OnnxTensor> container = new HashMap<>();
byte[] flatInput = new byte[] {1, 2, -3, Byte.MIN_VALUE, Byte.MAX_VALUE};
ByteBuffer data = ByteBuffer.wrap(flatInput);
long[] shape = new long[] {1, 5};
OnnxTensor ov = OnnxTensor.createTensor(env, data, shape, OnnxJavaType.UINT8);
container.put(inputName, ov);
try (OrtSession.Result res = session.run(container)) {
byte[] resultArray = TestHelpers.flattenByte(res.get(0).getValue());
assertArrayEquals(flatInput, resultArray);
}
OnnxValue.close(container);
}
}
@Test
public void testModelInputINT16() throws OrtException {
// model takes 1x5 input of fixed type, echoes back

View file

@ -4,6 +4,7 @@
*/
package ai.onnxruntime;
import java.nio.ByteBuffer;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
@ -112,4 +113,16 @@ public class TensorCreationTest {
}
}
}
@Test
public void testUint8Creation() throws OrtException {
try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
byte[] buf = new byte[] {0, 1};
ByteBuffer data = ByteBuffer.wrap(buf);
long[] shape = new long[] {2};
try (OnnxTensor t = OnnxTensor.createTensor(env, data, shape, OnnxJavaType.UINT8)) {
Assertions.assertArrayEquals(buf, (byte[]) t.getValue());
}
}
}
}