From 2efd2878abe95aef74bb6a5624a28b585f2636e4 Mon Sep 17 00:00:00 2001 From: Rachel Guo <35738743+YUNQIUGUO@users.noreply.github.com> Date: Wed, 16 Nov 2022 12:37:47 -0800 Subject: [PATCH] [rn] Add uint8 typedArray support for react native android (#13622) ### Description - Add missing uint8 typedArray case - Add createInputTensor_uint8 unit test in TensorHelperTest.java file ### Motivation and Context Detected inferencesession.run() call error when running react native app with uint8array input ort tensor. Add missing support to fix. --- .../reactnative/TensorHelperTest.java | 31 +++++++++++++++++++ js/react_native/lib/backend.ts | 4 ++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java b/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java index 19c27441a3..f8caae96bb 100644 --- a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java +++ b/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java @@ -108,6 +108,37 @@ public class TensorHelperTest { outputTensor.close(); } + @Test + public void createInputTensor_uint8() throws Exception { + OnnxTensor outputTensor = OnnxTensor.createTensor(ortEnvironment, ByteBuffer.wrap(new byte[] {0, 2, (byte)255}), + new long[] {3}, OnnxJavaType.UINT8); + + JavaOnlyMap inputTensorMap = new JavaOnlyMap(); + + JavaOnlyArray dims = new JavaOnlyArray(); + dims.pushInt(3); + inputTensorMap.putArray("dims", dims); + + inputTensorMap.putString("type", TensorHelper.JsTensorTypeUnsignedByte); + + ByteBuffer dataByteBuffer = ByteBuffer.allocate(3); + dataByteBuffer.put((byte)0); + dataByteBuffer.put((byte)2); + dataByteBuffer.put((byte)255); + String dataEncoded = Base64.encodeToString(dataByteBuffer.array(), Base64.DEFAULT); + inputTensorMap.putString("data", dataEncoded); + + OnnxTensor inputTensor = TensorHelper.createInputTensor(inputTensorMap, ortEnvironment); + + Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8); + Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8); + Assert.assertEquals(inputTensor.toString(), outputTensor.toString()); + Assert.assertArrayEquals(inputTensor.getByteBuffer().array(), outputTensor.getByteBuffer().array()); + + inputTensor.close(); + outputTensor.close(); + } + @Test public void createInputTensor_int32() throws Exception { OnnxTensor outputTensor = diff --git a/js/react_native/lib/backend.ts b/js/react_native/lib/backend.ts index 98035af91f..4ebc364cd8 100644 --- a/js/react_native/lib/backend.ts +++ b/js/react_native/lib/backend.ts @@ -10,12 +10,14 @@ import {binding, Binding} from './binding'; type SupportedTypedArray = Exclude; const tensorTypeToTypedArray = (type: Tensor.Type):|Float32ArrayConstructor|Int8ArrayConstructor|Int16ArrayConstructor| - Int32ArrayConstructor|BigInt64ArrayConstructor|Float64ArrayConstructor => { + Int32ArrayConstructor|BigInt64ArrayConstructor|Float64ArrayConstructor|Uint8ArrayConstructor => { switch (type) { case 'float32': return Float32Array; case 'int8': return Int8Array; + case 'uint8': + return Uint8Array; case 'int16': return Int16Array; case 'int32':