mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
[rn] Support UINT8 type for onnxruntime-react-native on iOS (#13210)
### Description <!-- Describe your changes. --> As title. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Uint8 type might be required for some model used in sample application. To match supported data types for onnxruntime-react-native for Android. Co-authored-by: rachguo <rachguo@rachguos-Mac-mini.local> Co-authored-by: rachguo <rachguo@rachguos-Mini.attlocal.net>
This commit is contained in:
parent
b09dd11ece
commit
814e5cfa4c
4 changed files with 24 additions and 2 deletions
Binary file not shown.
|
|
@ -73,6 +73,14 @@ static void testCreateInputTensorT(const std::array<T, 3> &outValues, std::funct
|
|||
testCreateInputTensorT<bool>(outValues, convert, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, JsTensorTypeBool);
|
||||
}
|
||||
|
||||
- (void)testCreateInputTensorUInt8 {
|
||||
std::array<uint8_t, 3> outValues{std::numeric_limits<uint8_t>::min(), 2, std::numeric_limits<uint8_t>::max()};
|
||||
std::function<NSNumber *(uint8_t value)> convert = [](uint8_t value) {
|
||||
return [NSNumber numberWithUnsignedChar:value];
|
||||
};
|
||||
testCreateInputTensorT<uint8_t>(outValues, convert, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, JsTensorTypeUnsignedByte);
|
||||
}
|
||||
|
||||
- (void)testCreateInputTensorInt8 {
|
||||
std::array<int8_t, 3> outValues{std::numeric_limits<int8_t>::min(), 2, std::numeric_limits<int8_t>::max()};
|
||||
std::function<NSNumber *(int8_t value)> convert = [](int8_t value) { return [NSNumber numberWithChar:value]; };
|
||||
|
|
@ -225,6 +233,14 @@ static void testCreateOutputTensorT(const std::array<T, 5> &outValues, std::func
|
|||
testCreateOutputTensorT<bool>(outValues, convert, JsTensorTypeBool, @"test_types_bool", @"onnx");
|
||||
}
|
||||
|
||||
- (void)testCreateOutputTensorUInt8 {
|
||||
std::array<uint8_t, 5> outValues{std::numeric_limits<uint8_t>::min(), 1, 2, 3, std::numeric_limits<uint8_t>::max()};
|
||||
std::function<NSNumber *(uint8_t value)> convert = [](uint8_t value) {
|
||||
return [NSNumber numberWithUnsignedChar:value];
|
||||
};
|
||||
testCreateOutputTensorT<uint8_t>(outValues, convert, JsTensorTypeUnsignedByte, @"test_types_uint8", @"ort");
|
||||
}
|
||||
|
||||
- (void)testCreateOutputTensorInt8 {
|
||||
std::array<int8_t, 5> outValues{std::numeric_limits<int8_t>::min(), 1, -2, 3, std::numeric_limits<int8_t>::max()};
|
||||
std::function<NSNumber *(int8_t value)> convert = [](int8_t value) { return [NSNumber numberWithChar:value]; };
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@
|
|||
* Supported tensor data type
|
||||
*/
|
||||
FOUNDATION_EXPORT NSString* const JsTensorTypeBool;
|
||||
FOUNDATION_EXPORT NSString* const JsTensorTypeUnsignedByte;
|
||||
FOUNDATION_EXPORT NSString* const JsTensorTypeByte;
|
||||
FOUNDATION_EXPORT NSString* const JsTensorTypeShort;
|
||||
FOUNDATION_EXPORT NSString* const JsTensorTypeInt;
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
* Supported tensor data type
|
||||
*/
|
||||
NSString *const JsTensorTypeBool = @"bool";
|
||||
NSString *const JsTensorTypeUnsignedByte = @"uint8";
|
||||
NSString *const JsTensorTypeByte = @"int8";
|
||||
NSString *const JsTensorTypeShort = @"int16";
|
||||
NSString *const JsTensorTypeInt = @"int32";
|
||||
|
|
@ -137,6 +138,8 @@ static Ort::Value createInputTensorT(OrtAllocator *ortAllocator, const std::vect
|
|||
switch (tensorType) {
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
|
||||
return createInputTensorT<float_t>(ortAllocator, dims, buffer, allocations);
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
|
||||
return createInputTensorT<uint8_t>(ortAllocator, dims, buffer, allocations);
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
|
||||
return createInputTensorT<int8_t>(ortAllocator, dims, buffer, allocations);
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
|
||||
|
|
@ -150,7 +153,6 @@ static Ort::Value createInputTensorT(OrtAllocator *ortAllocator, const std::vect
|
|||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
|
||||
return createInputTensorT<double_t>(ortAllocator, dims, buffer, allocations);
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED:
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
|
||||
|
|
@ -182,6 +184,8 @@ template <typename T> static NSString *createOutputTensorT(const Ort::Value &ten
|
|||
switch (tensorType) {
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
|
||||
return createOutputTensorT<float_t>(tensor);
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
|
||||
return createOutputTensorT<uint8_t>(tensor);
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
|
||||
return createOutputTensorT<int8_t>(tensor);
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
|
||||
|
|
@ -195,7 +199,6 @@ template <typename T> static NSString *createOutputTensorT(const Ort::Value &ten
|
|||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
|
||||
return createOutputTensorT<double_t>(tensor);
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED:
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
|
||||
|
|
@ -219,6 +222,7 @@ NSDictionary *OnnxTensorTypeToJsTensorTypeMap;
|
|||
+ (void)initialize {
|
||||
JsTensorTypeToOnnxTensorTypeMap = @{
|
||||
JsTensorTypeFloat : @(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT),
|
||||
JsTensorTypeUnsignedByte : @(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8),
|
||||
JsTensorTypeByte : @(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8),
|
||||
JsTensorTypeShort : @(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16),
|
||||
JsTensorTypeInt : @(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32),
|
||||
|
|
@ -230,6 +234,7 @@ NSDictionary *OnnxTensorTypeToJsTensorTypeMap;
|
|||
|
||||
OnnxTensorTypeToJsTensorTypeMap = @{
|
||||
@(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) : JsTensorTypeFloat,
|
||||
@(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) : JsTensorTypeUnsignedByte,
|
||||
@(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) : JsTensorTypeByte,
|
||||
@(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16) : JsTensorTypeShort,
|
||||
@(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) : JsTensorTypeInt,
|
||||
|
|
|
|||
Loading…
Reference in a new issue