diff --git a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java b/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java index 19c27441a3..f8caae96bb 100644 --- a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java +++ b/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java @@ -108,6 +108,37 @@ public class TensorHelperTest { outputTensor.close(); } + @Test + public void createInputTensor_uint8() throws Exception { + OnnxTensor outputTensor = OnnxTensor.createTensor(ortEnvironment, ByteBuffer.wrap(new byte[] {0, 2, (byte)255}), + new long[] {3}, OnnxJavaType.UINT8); + + JavaOnlyMap inputTensorMap = new JavaOnlyMap(); + + JavaOnlyArray dims = new JavaOnlyArray(); + dims.pushInt(3); + inputTensorMap.putArray("dims", dims); + + inputTensorMap.putString("type", TensorHelper.JsTensorTypeUnsignedByte); + + ByteBuffer dataByteBuffer = ByteBuffer.allocate(3); + dataByteBuffer.put((byte)0); + dataByteBuffer.put((byte)2); + dataByteBuffer.put((byte)255); + String dataEncoded = Base64.encodeToString(dataByteBuffer.array(), Base64.DEFAULT); + inputTensorMap.putString("data", dataEncoded); + + OnnxTensor inputTensor = TensorHelper.createInputTensor(inputTensorMap, ortEnvironment); + + Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8); + Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8); + Assert.assertEquals(inputTensor.toString(), outputTensor.toString()); + Assert.assertArrayEquals(inputTensor.getByteBuffer().array(), outputTensor.getByteBuffer().array()); + + inputTensor.close(); + outputTensor.close(); + } + @Test public void createInputTensor_int32() throws Exception { OnnxTensor outputTensor = diff --git a/js/react_native/lib/backend.ts b/js/react_native/lib/backend.ts index 98035af91f..4ebc364cd8 100644 --- a/js/react_native/lib/backend.ts +++ b/js/react_native/lib/backend.ts @@ -10,12 +10,14 @@ import {binding, Binding} from './binding'; type SupportedTypedArray = Exclude; const tensorTypeToTypedArray = (type: Tensor.Type):|Float32ArrayConstructor|Int8ArrayConstructor|Int16ArrayConstructor| - Int32ArrayConstructor|BigInt64ArrayConstructor|Float64ArrayConstructor => { + Int32ArrayConstructor|BigInt64ArrayConstructor|Float64ArrayConstructor|Uint8ArrayConstructor => { switch (type) { case 'float32': return Float32Array; case 'int8': return Int8Array; + case 'uint8': + return Uint8Array; case 'int16': return Int16Array; case 'int32':