mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
[rn] Add uint8 typedArray support for react native android (#13622)
### Description <!-- Describe your changes. --> - Add missing uint8 typedArray case - Add createInputTensor_uint8 unit test in TensorHelperTest.java file ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Detected inferencesession.run() call error when running react native app with uint8array input ort tensor. Add missing support to fix.
This commit is contained in:
parent
359091f64a
commit
2efd2878ab
2 changed files with 34 additions and 1 deletions
|
|
@ -108,6 +108,37 @@ public class TensorHelperTest {
|
|||
outputTensor.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createInputTensor_uint8() throws Exception {
|
||||
OnnxTensor outputTensor = OnnxTensor.createTensor(ortEnvironment, ByteBuffer.wrap(new byte[] {0, 2, (byte)255}),
|
||||
new long[] {3}, OnnxJavaType.UINT8);
|
||||
|
||||
JavaOnlyMap inputTensorMap = new JavaOnlyMap();
|
||||
|
||||
JavaOnlyArray dims = new JavaOnlyArray();
|
||||
dims.pushInt(3);
|
||||
inputTensorMap.putArray("dims", dims);
|
||||
|
||||
inputTensorMap.putString("type", TensorHelper.JsTensorTypeUnsignedByte);
|
||||
|
||||
ByteBuffer dataByteBuffer = ByteBuffer.allocate(3);
|
||||
dataByteBuffer.put((byte)0);
|
||||
dataByteBuffer.put((byte)2);
|
||||
dataByteBuffer.put((byte)255);
|
||||
String dataEncoded = Base64.encodeToString(dataByteBuffer.array(), Base64.DEFAULT);
|
||||
inputTensorMap.putString("data", dataEncoded);
|
||||
|
||||
OnnxTensor inputTensor = TensorHelper.createInputTensor(inputTensorMap, ortEnvironment);
|
||||
|
||||
Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
|
||||
Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
|
||||
Assert.assertEquals(inputTensor.toString(), outputTensor.toString());
|
||||
Assert.assertArrayEquals(inputTensor.getByteBuffer().array(), outputTensor.getByteBuffer().array());
|
||||
|
||||
inputTensor.close();
|
||||
outputTensor.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createInputTensor_int32() throws Exception {
|
||||
OnnxTensor outputTensor =
|
||||
|
|
|
|||
|
|
@ -10,12 +10,14 @@ import {binding, Binding} from './binding';
|
|||
type SupportedTypedArray = Exclude<Tensor.DataType, string[]>;
|
||||
|
||||
const tensorTypeToTypedArray = (type: Tensor.Type):|Float32ArrayConstructor|Int8ArrayConstructor|Int16ArrayConstructor|
|
||||
Int32ArrayConstructor|BigInt64ArrayConstructor|Float64ArrayConstructor => {
|
||||
Int32ArrayConstructor|BigInt64ArrayConstructor|Float64ArrayConstructor|Uint8ArrayConstructor => {
|
||||
switch (type) {
|
||||
case 'float32':
|
||||
return Float32Array;
|
||||
case 'int8':
|
||||
return Int8Array;
|
||||
case 'uint8':
|
||||
return Uint8Array;
|
||||
case 'int16':
|
||||
return Int16Array;
|
||||
case 'int32':
|
||||
|
|
|
|||
Loading…
Reference in a new issue