mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[WebNN EP] Fixed bug in int8 data type processing (#19134)
This commit is contained in:
parent
9dee543bed
commit
1bab98988b
6 changed files with 36 additions and 10 deletions
|
|
@@ -166,11 +166,14 @@ bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type) {
   // TODO: Remove legacy "type" once all browsers implement the new "dataType".
   switch (data_type) {
     case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
-    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
     case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
       desc.set("type", emscripten::val("uint8"));
       desc.set("dataType", emscripten::val("uint8"));
       return true;
+    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
+      desc.set("type", emscripten::val("int8"));
+      desc.set("dataType", emscripten::val("int8"));
+      return true;
     case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
       desc.set("type", emscripten::val("float16"));
       desc.set("dataType", emscripten::val("float16"));
|
||||
|
|
|
|||
|
|
@@ -101,10 +101,12 @@ inline bool ReadScalarTensorData(const onnx::TensorProto& tensor, emscripten::va
   }
   switch (tensor.data_type()) {
     case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
-    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
     case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
       scalar = emscripten::val{*reinterpret_cast<uint8_t*>(unpacked_tensor.data())};
       break;
+    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
+      scalar = emscripten::val{*reinterpret_cast<int8_t*>(unpacked_tensor.data())};
+      break;
     case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
       scalar = emscripten::val{MLFloat16::FromBits(*reinterpret_cast<uint16_t*>(unpacked_tensor.data())).ToFloat()};
       break;
|
||||
|
|
|
|||
|
|
@@ -39,10 +39,12 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   std::string operand_type;
   switch (to_type) {
     case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
-    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
     case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
       operand_type = "uint8";
       break;
+    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
+      operand_type = "int8";
+      break;
     case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
       operand_type = "float16";
       break;
|
||||
|
|
|
|||
|
|
@@ -184,10 +184,12 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder,
   size_t element_size{0};
   switch (data_type) {
     case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
-    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
     case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
       element_size = sizeof(uint8_t);
       break;
+    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
+      element_size = sizeof(int8_t);
+      break;
     case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
       element_size = sizeof(uint16_t);
       break;
|
||||
|
|
|
|||
|
|
@@ -33,11 +33,14 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
   emscripten::val view = emscripten::val::undefined();
   switch (tensor.tensor_info.data_type) {
     case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
-    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
     case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
       view = emscripten::val{emscripten::typed_memory_view(num_elements,
                                                           static_cast<const uint8_t*>(tensor.buffer))};
       break;
+    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
+      view = emscripten::val{emscripten::typed_memory_view(num_elements,
+                                                          static_cast<const int8_t*>(tensor.buffer))};
+      break;
     case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
       view = emscripten::val{emscripten::typed_memory_view(num_elements,
                                                           static_cast<const uint16_t*>(tensor.buffer))};
|
||||
|
|
@@ -90,11 +93,14 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
   emscripten::val view = emscripten::val::undefined();
   switch (tensor.tensor_info.data_type) {
     case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
-    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
     case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
       view = emscripten::val{emscripten::typed_memory_view(num_elements,
                                                           static_cast<const uint8_t*>(tensor.buffer))};
       break;
+    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
+      view = emscripten::val{emscripten::typed_memory_view(num_elements,
+                                                          static_cast<const int8_t*>(tensor.buffer))};
+      break;
     case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
       view = emscripten::val{emscripten::typed_memory_view(num_elements,
                                                           static_cast<const uint16_t*>(tensor.buffer))};
|
||||
|
|
@@ -168,10 +174,12 @@ void Model::AllocateInputOutputBuffers() {
   const auto data_type = input_info.data_type;
   switch (data_type) {
     case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
-    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
     case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
      wnn_inputs_.set(input, emscripten::val::global("Uint8Array").new_(num_elements));
      break;
+    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
+      wnn_inputs_.set(input, emscripten::val::global("Int8Array").new_(num_elements));
+      break;
     case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
       wnn_inputs_.set(input, emscripten::val::global("Uint16Array").new_(num_elements));
       break;
|
||||
|
|
@@ -201,10 +209,12 @@ void Model::AllocateInputOutputBuffers() {
   const auto data_type = output_info.data_type;
   switch (data_type) {
     case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
-    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
     case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
       wnn_outputs_.set(output, emscripten::val::global("Uint8Array").new_(num_elements));
       break;
+    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
+      wnn_outputs_.set(output, emscripten::val::global("Int8Array").new_(num_elements));
+      break;
     case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
       wnn_outputs_.set(output, emscripten::val::global("Uint16Array").new_(num_elements));
       break;
|
||||
|
|
|
|||
|
|
@@ -160,12 +160,16 @@ Status ModelBuilder::RegisterInitializers() {
   }
   switch (data_type) {
     case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
-    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
     case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
       desc.set("type", emscripten::val("uint8"));
       view = emscripten::val{emscripten::typed_memory_view(num_elements,
                                                           reinterpret_cast<uint8_t*>(tensor_ptr))};
       break;
+    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
+      desc.set("type", emscripten::val("int8"));
+      view = emscripten::val{emscripten::typed_memory_view(num_elements,
+                                                          reinterpret_cast<int8_t*>(tensor_ptr))};
+      break;
     case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
       view = emscripten::val{emscripten::typed_memory_view(num_elements,
                                                           reinterpret_cast<uint16_t*>(tensor_ptr))};
|
||||
|
|
@@ -318,11 +322,14 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer(
   ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type");
   switch (data_type) {
     case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
-    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
     case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
       view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint8_t),
                                                           reinterpret_cast<const uint8_t*>(dest))};
       break;
+    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
+      view = emscripten::val{emscripten::typed_memory_view(size / sizeof(int8_t),
+                                                          reinterpret_cast<const int8_t*>(dest))};
+      break;
     case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
       view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint16_t),
                                                           reinterpret_cast<const uint16_t*>(dest))};
|
||||
|
|
|
|||
Loading…
Reference in a new issue