From 1bab98988b4e7b6d33be0e672fce361ccbb1d397 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Tue, 16 Jan 2024 10:44:25 +0800 Subject: [PATCH] [WebNN EP] Fixed bug in int8 data type processing (#19134) --- .../core/providers/webnn/builders/helper.cc | 5 ++++- .../core/providers/webnn/builders/helper.h | 4 +++- .../webnn/builders/impl/cast_op_builder.cc | 4 +++- .../webnn/builders/impl/conv_op_builder.cc | 4 +++- .../core/providers/webnn/builders/model.cc | 18 ++++++++++++++---- .../providers/webnn/builders/model_builder.cc | 11 +++++++++-- 6 files changed, 36 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index a55145b012..ef7c10dae5 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -166,11 +166,14 @@ bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type) { // TODO: Remove legacy "type" once all browsers implement the new "dataType". 
switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: desc.set("type", emscripten::val("uint8")); desc.set("dataType", emscripten::val("uint8")); return true; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + desc.set("type", emscripten::val("int8")); + desc.set("dataType", emscripten::val("int8")); + return true; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: desc.set("type", emscripten::val("float16")); desc.set("dataType", emscripten::val("float16")); diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index f3fc7ec5cc..85dafcaf66 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -101,10 +101,12 @@ inline bool ReadScalarTensorData(const onnx::TensorProto& tensor, emscripten::va } switch (tensor.data_type()) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: scalar = emscripten::val{*reinterpret_cast<uint8_t*>(unpacked_tensor.data())}; break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + scalar = emscripten::val{*reinterpret_cast<int8_t*>(unpacked_tensor.data())}; + break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: scalar = emscripten::val{MLFloat16::FromBits(*reinterpret_cast<uint16_t*>(unpacked_tensor.data())).ToFloat()}; break; diff --git a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc index 062f1c5606..3d961e4589 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc @@ -39,10 +39,12 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::string operand_type; switch (to_type) { case 
ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: operand_type = "uint8"; break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + operand_type = "int8"; + break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: operand_type = "float16"; break; diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 123a9cc016..ceacb7c2b3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -184,10 +184,12 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder, size_t element_size{0}; switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: element_size = sizeof(uint8_t); break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + element_size = sizeof(int8_t); + break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: element_size = sizeof(uint16_t); break; diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc index a4031fd935..eaf549ef4e 100644 --- a/onnxruntime/core/providers/webnn/builders/model.cc +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -33,11 +33,14 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs, emscripten::val view = emscripten::val::undefined(); switch (tensor.tensor_info.data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: view = emscripten::val{emscripten::typed_memory_view(num_elements, static_cast<const uint8_t*>(tensor.buffer))}; break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + static_cast<const int8_t*>(tensor.buffer))}; 
+ break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: view = emscripten::val{emscripten::typed_memory_view(num_elements, static_cast<const uint16_t*>(tensor.buffer))}; @@ -90,11 +93,14 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs, emscripten::val view = emscripten::val::undefined(); switch (tensor.tensor_info.data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: view = emscripten::val{emscripten::typed_memory_view(num_elements, static_cast<const uint8_t*>(tensor.buffer))}; break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + static_cast<const int8_t*>(tensor.buffer))}; + break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: view = emscripten::val{emscripten::typed_memory_view(num_elements, static_cast<const uint16_t*>(tensor.buffer))}; @@ -168,10 +174,12 @@ void Model::AllocateInputOutputBuffers() { const auto data_type = input_info.data_type; switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: wnn_inputs_.set(input, emscripten::val::global("Uint8Array").new_(num_elements)); break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + wnn_inputs_.set(input, emscripten::val::global("Int8Array").new_(num_elements)); + break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: wnn_inputs_.set(input, emscripten::val::global("Uint16Array").new_(num_elements)); break; @@ -201,10 +209,12 @@ void Model::AllocateInputOutputBuffers() { const auto data_type = output_info.data_type; switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: wnn_outputs_.set(output, emscripten::val::global("Uint8Array").new_(num_elements)); break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + wnn_outputs_.set(output, 
emscripten::val::global("Int8Array").new_(num_elements)); + break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: wnn_outputs_.set(output, emscripten::val::global("Uint16Array").new_(num_elements)); break; diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 4e0c83db8b..cf8a0e23db 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -160,12 +160,16 @@ Status ModelBuilder::RegisterInitializers() { } switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: desc.set("type", emscripten::val("uint8")); view = emscripten::val{emscripten::typed_memory_view(num_elements, reinterpret_cast<uint8_t*>(tensor_ptr))}; break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + desc.set("type", emscripten::val("int8")); + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast<int8_t*>(tensor_ptr))}; + break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: view = emscripten::val{emscripten::typed_memory_view(num_elements, reinterpret_cast<uint16_t*>(tensor_ptr))}; @@ -318,11 +322,14 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer( ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type"); switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint8_t), reinterpret_cast<const uint8_t*>(dest))}; break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + view = emscripten::val{emscripten::typed_memory_view(size / sizeof(int8_t), + reinterpret_cast<const int8_t*>(dest))}; + break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint16_t), 
reinterpret_cast<const uint16_t*>(dest))};