From 814e5cfa4c81b059b9865edcbc33510bdf1b1ebe Mon Sep 17 00:00:00 2001 From: Rachel Guo <35738743+YUNQIUGUO@users.noreply.github.com> Date: Thu, 6 Oct 2022 11:35:25 -0700 Subject: [PATCH] [rn] Support UINT8 type for onnxruntime-react-native on iOS (#13210) ### Description As title. ### Motivation and Context Uint8 type might be required for some model used in sample application. To match supported data types for onnxruntime-react-native for Android. Co-authored-by: rachguo Co-authored-by: rachguo --- .../Resources/test_types_uint8.ort | Bin 0 -> 1872 bytes .../OnnxruntimeModuleTest/TensorHelperTest.mm | 16 ++++++++++++++++ js/react_native/ios/TensorHelper.h | 1 + js/react_native/ios/TensorHelper.mm | 9 +++++++-- 4 files changed, 24 insertions(+), 2 deletions(-) create mode 100644 js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_uint8.ort diff --git a/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_uint8.ort b/js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_uint8.ort new file mode 100644 index 0000000000000000000000000000000000000000..9f5310803323a5cd6e3db7cdae24041143608da2 GIT binary patch literal 1872 zcmZ8hJ7^S96uqNkGDeN7F0zP(6kDVS!-^jySom3sK-5K}1PQ`2J4pt1-z<~K$I2~4 z#43e_g{6py45k~RiQiHA?W0DyvMD5 z9wXhme7p;+;_JIW?H7_>BF@!*4&CwIvpkMhS$hNHAn#}oFVFD>lkD9)VC(S>LA(H6 z@IJC8dt>jPEH`c5H~PN9XZGi4-~PDsFz?yd@7cr^`r>p8;K9|BXTS_#*C~&{&3?F6 zzV8XZ_6N=Er5ktm&-l+_;^RPm89wH_nyyFME}!~Y8P%_gFEPL>M{S<#!|}k~Cxbj^ zFiRiziT|^7C2z0b;?w3@djDCBW6!3(8JW)>6Q$Gizw{vjY5?ESHs(G68$bi70OWp! zJ4+r4ZLT@Z7Q^=aJk6q|u%5??bDINkC*X_HR@h3SEL~0)@-SX+$63;fRbFq3bg-X} zm-*!8zojn+_C!w{WB9MDCt<48df3k5)g)dE^Q@kzL@i=U<_6-~uZ!=E;~OlXD(}Y; zCR$+)-e=v%a+s(*&Q!e_s-?AvA@9J?`$b_a|FMq0g>`)ysDBLeeQOF;M5+MQ&8rYO Vj2H^}obY~Jnx0k*S^fVx{sZOF4GjPQ literal 0 HcmV?d00001 diff --git a/js/react_native/ios/OnnxruntimeModuleTest/TensorHelperTest.mm b/js/react_native/ios/OnnxruntimeModuleTest/TensorHelperTest.mm index e2fd4ec7c6..3ed082e222 100644 --- a/js/react_native/ios/OnnxruntimeModuleTest/TensorHelperTest.mm +++ b/js/react_native/ios/OnnxruntimeModuleTest/TensorHelperTest.mm @@ -73,6 +73,14 @@ static void testCreateInputTensorT(const std::array &outValues, std::funct testCreateInputTensorT(outValues, convert, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, JsTensorTypeBool); } +- (void)testCreateInputTensorUInt8 { + std::array outValues{std::numeric_limits::min(), 2, std::numeric_limits::max()}; + std::function convert = [](uint8_t value) { + return [NSNumber numberWithUnsignedChar:value]; + }; + testCreateInputTensorT(outValues, convert, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, JsTensorTypeUnsignedByte); +} + - (void)testCreateInputTensorInt8 { std::array outValues{std::numeric_limits::min(), 2, std::numeric_limits::max()}; std::function convert = [](int8_t value) { return [NSNumber numberWithChar:value]; }; @@ -225,6 +233,14 @@ static void testCreateOutputTensorT(const std::array &outValues, std::func testCreateOutputTensorT(outValues, convert, JsTensorTypeBool, @"test_types_bool", @"onnx"); } +- (void)testCreateOutputTensorUInt8 { + std::array outValues{std::numeric_limits::min(), 1, 2, 3, std::numeric_limits::max()}; + std::function convert = [](uint8_t value) { + return [NSNumber numberWithUnsignedChar:value]; + }; + testCreateOutputTensorT(outValues, convert, JsTensorTypeUnsignedByte, @"test_types_uint8", @"ort"); +} + - (void)testCreateOutputTensorInt8 { std::array outValues{std::numeric_limits::min(), 1, -2, 3, std::numeric_limits::max()}; std::function convert = [](int8_t value) { return [NSNumber numberWithChar:value]; }; diff --git a/js/react_native/ios/TensorHelper.h b/js/react_native/ios/TensorHelper.h index 40203c0cf9..f0936cce8b 100644 --- a/js/react_native/ios/TensorHelper.h +++ b/js/react_native/ios/TensorHelper.h @@ -13,6 +13,7 @@ * Supported tensor data type */ FOUNDATION_EXPORT NSString* const JsTensorTypeBool; +FOUNDATION_EXPORT NSString* const JsTensorTypeUnsignedByte; FOUNDATION_EXPORT NSString* const JsTensorTypeByte; FOUNDATION_EXPORT NSString* const JsTensorTypeShort; FOUNDATION_EXPORT NSString* const JsTensorTypeInt; diff --git a/js/react_native/ios/TensorHelper.mm b/js/react_native/ios/TensorHelper.mm index eecf064d7d..00c1c79def 100644 --- a/js/react_native/ios/TensorHelper.mm +++ b/js/react_native/ios/TensorHelper.mm @@ -10,6 +10,7 @@ * Supported tensor data type */ NSString *const JsTensorTypeBool = @"bool"; +NSString *const JsTensorTypeUnsignedByte = @"uint8"; NSString *const JsTensorTypeByte = @"int8"; NSString *const JsTensorTypeShort = @"int16"; NSString *const JsTensorTypeInt = @"int32"; @@ -137,6 +138,8 @@ static Ort::Value createInputTensorT(OrtAllocator *ortAllocator, const std::vect switch (tensorType) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return createInputTensorT(ortAllocator, dims, buffer, allocations); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return createInputTensorT(ortAllocator, dims, buffer, allocations); case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return createInputTensorT(ortAllocator, dims, buffer, allocations); case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: @@ -150,7 +153,6 @@ static Ort::Value createInputTensorT(OrtAllocator *ortAllocator, const std::vect case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return createInputTensorT(ortAllocator, dims, buffer, allocations); case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: @@ -182,6 +184,8 @@ template static NSString *createOutputTensorT(const Ort::Value &ten switch (tensorType) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return createOutputTensorT(tensor); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return createOutputTensorT(tensor); case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return createOutputTensorT(tensor); case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: @@ -195,7 +199,6 @@ template static NSString *createOutputTensorT(const Ort::Value &ten case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return createOutputTensorT(tensor); case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: @@ -219,6 +222,7 @@ NSDictionary *OnnxTensorTypeToJsTensorTypeMap; + (void)initialize { JsTensorTypeToOnnxTensorTypeMap = @{ JsTensorTypeFloat : @(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT), + JsTensorTypeUnsignedByte : @(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8), JsTensorTypeByte : @(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8), JsTensorTypeShort : @(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16), JsTensorTypeInt : @(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32), @@ -230,6 +234,7 @@ NSDictionary *OnnxTensorTypeToJsTensorTypeMap; OnnxTensorTypeToJsTensorTypeMap = @{ @(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) : JsTensorTypeFloat, + @(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) : JsTensorTypeUnsignedByte, @(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) : JsTensorTypeByte, @(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16) : JsTensorTypeShort, @(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) : JsTensorTypeInt,