[WebNN EP] Fixed bug in int8 data type processing (#19134)

This commit is contained in:
Wanming Lin 2024-01-16 10:44:25 +08:00 committed by GitHub
parent 9dee543bed
commit 1bab98988b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 36 additions and 10 deletions

View file

@@ -166,11 +166,14 @@ bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type) {
// TODO: Remove legacy "type" once all browsers implement the new "dataType".
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
desc.set("type", emscripten::val("uint8"));
desc.set("dataType", emscripten::val("uint8"));
return true;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
desc.set("type", emscripten::val("int8"));
desc.set("dataType", emscripten::val("int8"));
return true;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
desc.set("type", emscripten::val("float16"));
desc.set("dataType", emscripten::val("float16"));

View file

@@ -101,10 +101,12 @@ inline bool ReadScalarTensorData(const onnx::TensorProto& tensor, emscripten::va
}
switch (tensor.data_type()) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
scalar = emscripten::val{*reinterpret_cast<uint8_t*>(unpacked_tensor.data())};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
scalar = emscripten::val{*reinterpret_cast<int8_t*>(unpacked_tensor.data())};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
scalar = emscripten::val{MLFloat16::FromBits(*reinterpret_cast<uint16_t*>(unpacked_tensor.data())).ToFloat()};
break;

View file

@@ -39,10 +39,12 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
std::string operand_type;
switch (to_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
operand_type = "uint8";
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
operand_type = "int8";
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
operand_type = "float16";
break;

View file

@@ -184,10 +184,12 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder,
size_t element_size{0};
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
element_size = sizeof(uint8_t);
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
element_size = sizeof(int8_t);
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
element_size = sizeof(uint16_t);
break;

View file

@@ -33,11 +33,14 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
emscripten::val view = emscripten::val::undefined();
switch (tensor.tensor_info.data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const uint8_t*>(tensor.buffer))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const int8_t*>(tensor.buffer))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const uint16_t*>(tensor.buffer))};
@@ -90,11 +93,14 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
emscripten::val view = emscripten::val::undefined();
switch (tensor.tensor_info.data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const uint8_t*>(tensor.buffer))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const int8_t*>(tensor.buffer))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const uint16_t*>(tensor.buffer))};
@@ -168,10 +174,12 @@ void Model::AllocateInputOutputBuffers() {
const auto data_type = input_info.data_type;
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
wnn_inputs_.set(input, emscripten::val::global("Uint8Array").new_(num_elements));
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
wnn_inputs_.set(input, emscripten::val::global("Int8Array").new_(num_elements));
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
wnn_inputs_.set(input, emscripten::val::global("Uint16Array").new_(num_elements));
break;
@@ -201,10 +209,12 @@ void Model::AllocateInputOutputBuffers() {
const auto data_type = output_info.data_type;
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
wnn_outputs_.set(output, emscripten::val::global("Uint8Array").new_(num_elements));
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
wnn_outputs_.set(output, emscripten::val::global("Int8Array").new_(num_elements));
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
wnn_outputs_.set(output, emscripten::val::global("Uint16Array").new_(num_elements));
break;

View file

@@ -160,12 +160,16 @@ Status ModelBuilder::RegisterInitializers() {
}
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
desc.set("type", emscripten::val("uint8"));
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint8_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
desc.set("type", emscripten::val("int8"));
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<int8_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint16_t*>(tensor_ptr))};
@@ -318,11 +322,14 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer(
ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type");
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint8_t),
reinterpret_cast<const uint8_t*>(dest))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(int8_t),
reinterpret_cast<const int8_t*>(dest))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint16_t),
reinterpret_cast<const uint16_t*>(dest))};