mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
[WebNN EP] Add quantize Ops (#18011)
### Description <!-- Describe your changes. --> Add four quantize Ops: MatMulInteger, ConvInteger, DynamicQuantizeLinear and DequantizeLinear. Add datatype TensorProto_DataType_INT8 and TensorProto_DataType_UINT8. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Support quantized models.
This commit is contained in:
parent
acba63c36a
commit
e1db44b4f0
13 changed files with 232 additions and 4 deletions
|
|
@ -166,6 +166,8 @@ bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type) {
|
|||
// TODO: Remove legacy "type" once all browsers implement the new "dataType".
|
||||
switch (data_type) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
|
||||
desc.set("type", emscripten::val("uint8"));
|
||||
desc.set("dataType", emscripten::val("uint8"));
|
||||
return true;
|
||||
|
|
|
|||
|
|
@ -101,6 +101,8 @@ inline bool ReadScalarTensorData(const onnx::TensorProto& tensor, emscripten::va
|
|||
}
|
||||
switch (tensor.data_type()) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
|
||||
scalar = emscripten::val{*reinterpret_cast<uint8_t*>(unpacked_tensor.data())};
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
|
||||
|
|
@ -148,9 +150,12 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
|
|||
{"Clip", {"clamp", true}},
|
||||
{"Concat", {"concat", true}},
|
||||
{"Conv", {"conv2d", true}},
|
||||
{"ConvInteger", {"conv2dInteger", false}},
|
||||
{"ConvTranspose", {"convTranspose2d", true}},
|
||||
{"Cos", {"cos", false}},
|
||||
{"Div", {"div", true}},
|
||||
{"DequantizeLinear", {"dequantizeLinear", false}},
|
||||
{"DynamicQuantizeLinear", {"dynamicQuantizeLinear", false}},
|
||||
{"Elu", {"elu", true}},
|
||||
{"Equal", {"equal", false}},
|
||||
{"Erf", {"erf", false}},
|
||||
|
|
@ -176,6 +181,7 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
|
|||
{"Log", {"log", false}},
|
||||
{"LpPool", {"l2Pool2d", false}},
|
||||
{"MatMul", {"matmul", false}},
|
||||
{"MatMulInteger", {"matmulInteger", false}},
|
||||
{"Max", {"max", true}},
|
||||
{"MaxPool", {"maxPool2d", true}},
|
||||
{"Min", {"min", true}},
|
||||
|
|
@ -242,8 +248,10 @@ constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 1> supported_cpu_data
|
|||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
};
|
||||
|
||||
constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 7> supported_gpu_data_types = {
|
||||
constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 9> supported_gpu_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT8,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_UINT8,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT32,
|
||||
|
|
|
|||
|
|
@ -39,6 +39,8 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
std::string operand_type;
|
||||
switch (to_type) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
|
||||
operand_type = "uint8";
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
|
||||
|
|
|
|||
|
|
@ -183,6 +183,11 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder,
|
|||
|
||||
size_t element_size{0};
|
||||
switch (data_type) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
|
||||
element_size = sizeof(uint8_t);
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
|
||||
element_size = sizeof(uint16_t);
|
||||
break;
|
||||
|
|
@ -257,7 +262,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
|
|||
const auto& weight_name = input_defs[1]->Name();
|
||||
emscripten::val options = emscripten::val::object();
|
||||
ORT_RETURN_IF_ERROR(SetConvBaseOptions(model_builder, node, options, strides, dilations, pads, logger));
|
||||
if (op_type == "Conv") {
|
||||
if (op_type == "Conv" || op_type == "ConvInteger") {
|
||||
int groups = options["groups"].as<int>();
|
||||
std::vector<int64_t> input_shape;
|
||||
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
|
||||
|
|
@ -271,9 +276,26 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
|
|||
options.set("filterLayout", emscripten::val("ihwo"));
|
||||
}
|
||||
}
|
||||
emscripten::val filter = model_builder.GetOperand(input_defs[1]->Name());
|
||||
emscripten::val filter = model_builder.GetOperand(weight_name);
|
||||
if (op_type == "Conv") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("conv2d", input, filter, options);
|
||||
} else {
|
||||
emscripten::val x_zero_point = emscripten::val::null();
|
||||
emscripten::val w_zero_point = emscripten::val::null();
|
||||
if (input_defs.size() >= 3) {
|
||||
x_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name());
|
||||
} else {
|
||||
x_zero_point = model_builder.GetZeroConstant("uint8");
|
||||
}
|
||||
if (input_defs.size() >= 4) {
|
||||
w_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name());
|
||||
} else {
|
||||
w_zero_point = model_builder.GetZeroConstant("uint8");
|
||||
}
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("conv2dInteger",
|
||||
input, x_zero_point, filter, w_zero_point, options);
|
||||
}
|
||||
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("conv2d", input, filter, options);
|
||||
} else {
|
||||
if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
|
||||
options.set("inputLayout", emscripten::val("nhwc"));
|
||||
|
|
@ -341,6 +363,7 @@ void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_
|
|||
static std::vector<std::string> op_types =
|
||||
{
|
||||
"Conv",
|
||||
"ConvInteger",
|
||||
"ConvTranspose",
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,70 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Copyright (c) Intel Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
#include "core/providers/webnn/builders/helper.h"
|
||||
#include "core/providers/webnn/builders/model_builder.h"
|
||||
#include "core/providers/webnn/builders/op_builder_factory.h"
|
||||
|
||||
#include "core/providers/webnn/builders/impl/base_op_builder.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace webnn {
|
||||
|
||||
class DequantizeLinearOpBuilder : public BaseOpBuilder {
|
||||
// Add operator related.
|
||||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
};
|
||||
|
||||
// Lowers a DequantizeLinear node to a WebNN "dequantizeLinear" call.
//
// Inputs read from the node: the quantized tensor (input 0), the scale
// (input 1) and, when the node has exactly three inputs, the zero point
// (input 2). Without a zero point the shared scalar uint8 zero constant of the
// model builder is used instead.
//
// When the scale is 1-D and the input has rank > 1, the scale is reshaped to
// [1, ..., input_shape[axis], ..., 1] so that it broadcasts along `axis`.
// A zero point supplied by the node is reshaped the same way. The default zero
// constant is left untouched: it is a scalar holding a single element, so it
// cannot be reshaped to input_shape[axis] elements.
//
// Returns an error status if the input or scale shape cannot be obtained;
// otherwise registers the result under the node's first output name.
Status DequantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                                        const Node& node,
                                                        const logging::Logger& logger) const {
  const auto& input_defs = node.InputDefs();
  emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
  emscripten::val scale = model_builder.GetOperand(input_defs[1]->Name());

  // Remember whether the zero point came from the node: only then does it
  // share the scale's 1-D shape and need the broadcast reshape below.
  const bool has_zero_point = input_defs.size() == 3;
  emscripten::val zero_point = has_zero_point ? model_builder.GetOperand(input_defs[2]->Name())
                                              : model_builder.GetZeroConstant("uint8");

  std::vector<int64_t> input_shape;
  std::vector<int64_t> scale_shape;
  ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape");
  ORT_RETURN_IF_NOT(GetShape(*input_defs[1], scale_shape, logger), "Cannot get scale shape");

  NodeAttrHelper helper(node);
  int32_t axis = helper.Get("axis", 1);
  // axis is valid for input shape greater than 1D.
  if (input_shape.size() > 1) {
    axis = static_cast<int32_t>(HandleNegativeAxis(axis, input_shape.size()));
  }

  // Insert ones before and after the axis dimension for broadcasting of 1D scale tensor.
  if (1 == scale_shape.size() && 1 < input_shape.size()) {
    std::vector<int32_t> target_shape(input_shape.size(), 1);
    target_shape[axis] = static_cast<int32_t>(input_shape[axis]);
    scale = model_builder.GetBuilder().call<emscripten::val>("reshape", scale, emscripten::val::array(target_shape));
    if (has_zero_point) {
      zero_point = model_builder.GetBuilder().call<emscripten::val>("reshape",
                                                                    zero_point, emscripten::val::array(target_shape));
    }
  }

  emscripten::val output =
      model_builder.GetBuilder().call<emscripten::val>("dequantizeLinear", input, scale, zero_point);

  model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));

  return Status::OK();
}
|
||||
|
||||
// Creates a DequantizeLinearOpBuilder, stores it in the registrations' builder
// list, and maps `op_type` to the stored instance.
void CreateDequantizeLinearOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  auto& owned_builders = op_registrations.builders;
  owned_builders.push_back(std::make_unique<DequantizeLinearOpBuilder>());

  // The map holds a raw pointer to the builder that was just appended.
  op_registrations.op_builder_map.emplace(op_type, owned_builders.back().get());
}
|
||||
|
||||
} // namespace webnn
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Copyright (c) Intel Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
#include "core/providers/webnn/builders/helper.h"
|
||||
#include "core/providers/webnn/builders/model_builder.h"
|
||||
#include "core/providers/webnn/builders/op_builder_factory.h"
|
||||
|
||||
#include "core/providers/webnn/builders/impl/base_op_builder.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace webnn {
|
||||
|
||||
class DynamicQuantizaLinearOpBuilder : public BaseOpBuilder {
|
||||
// Add operator related.
|
||||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
};
|
||||
|
||||
Status DynamicQuantizaLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
||||
const Node& node,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
|
||||
emscripten::val output_array;
|
||||
std::vector<int64_t> input_shape;
|
||||
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
|
||||
emscripten::val options = emscripten::val::object();
|
||||
|
||||
output_array = model_builder.GetBuilder().call<emscripten::val>("dynamicQuantizeLinear", input);
|
||||
|
||||
for (size_t i = 0, count = output_array["length"].as<size_t>(); i < count; i++) {
|
||||
model_builder.AddOperand(node.OutputDefs()[i]->Name(), std::move(output_array[i]));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void CreateDynamicQuantizeLinearOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<DynamicQuantizaLinearOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
}
|
||||
|
||||
} // namespace webnn
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -39,6 +39,20 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
|
|||
emscripten::val output = emscripten::val::object();
|
||||
if (op_type == "MatMul") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("matmul", a, b);
|
||||
} else if (op_type == "MatMulInteger") {
|
||||
emscripten::val a_zero_point = emscripten::val::null();
|
||||
emscripten::val b_zero_point = emscripten::val::null();
|
||||
if (input_defs.size() >= 3) {
|
||||
a_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name());
|
||||
} else {
|
||||
a_zero_point = model_builder.GetZeroConstant("uint8");
|
||||
}
|
||||
if (input_defs.size() >= 4) {
|
||||
b_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name());
|
||||
} else {
|
||||
b_zero_point = model_builder.GetZeroConstant("uint8");
|
||||
}
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("matmulInteger", a, a_zero_point, b, b_zero_point);
|
||||
} else { // Gemm
|
||||
emscripten::val options = emscripten::val::object();
|
||||
NodeAttrHelper helper(node);
|
||||
|
|
@ -149,6 +163,7 @@ void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_
|
|||
{
|
||||
"Gemm",
|
||||
"MatMul",
|
||||
"MatMulInteger",
|
||||
};
|
||||
|
||||
op_registrations.builders.push_back(std::make_unique<GemmOpBuilder>());
|
||||
|
|
|
|||
|
|
@ -33,6 +33,8 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
|
|||
emscripten::val view = emscripten::val::undefined();
|
||||
switch (tensor.tensor_info.data_type) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
|
||||
view = emscripten::val{emscripten::typed_memory_view(num_elements,
|
||||
static_cast<const uint8_t*>(tensor.buffer))};
|
||||
break;
|
||||
|
|
@ -88,6 +90,8 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
|
|||
emscripten::val view = emscripten::val::undefined();
|
||||
switch (tensor.tensor_info.data_type) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
|
||||
view = emscripten::val{emscripten::typed_memory_view(num_elements,
|
||||
static_cast<const uint8_t*>(tensor.buffer))};
|
||||
break;
|
||||
|
|
@ -164,6 +168,8 @@ void Model::AllocateInputOutputBuffers() {
|
|||
const auto data_type = input_info.data_type;
|
||||
switch (data_type) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
|
||||
wnn_inputs_.set(input, emscripten::val::global("Uint8Array").new_(num_elements));
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
|
||||
|
|
@ -195,6 +201,8 @@ void Model::AllocateInputOutputBuffers() {
|
|||
const auto data_type = output_info.data_type;
|
||||
switch (data_type) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
|
||||
wnn_outputs_.set(output, emscripten::val::global("Uint8Array").new_(num_elements));
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@
|
|||
#include "core/providers/common.h"
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace webnn {
|
||||
|
||||
|
|
@ -158,6 +160,9 @@ Status ModelBuilder::RegisterInitializers() {
|
|||
}
|
||||
switch (data_type) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
|
||||
desc.set("type", emscripten::val("uint8"));
|
||||
view = emscripten::val{emscripten::typed_memory_view(num_elements,
|
||||
reinterpret_cast<uint8_t*>(tensor_ptr))};
|
||||
break;
|
||||
|
|
@ -313,6 +318,8 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer(
|
|||
ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type");
|
||||
switch (data_type) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
|
||||
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint8_t),
|
||||
reinterpret_cast<const uint8_t*>(dest))};
|
||||
break;
|
||||
|
|
@ -439,6 +446,38 @@ void ModelBuilder::AddOperand(const std::string& name, const emscripten::val& op
|
|||
wnn_operands_.insert(std::make_pair(name, operand));
|
||||
}
|
||||
|
||||
// Get the zero scalar constant.
|
||||
// Workaround for builer.constant(value, type) method since it has not been implemented now.
|
||||
// https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-constant-value-type
|
||||
// BTW, the spec is discussing if the builer.constant(value, type) should be dropped at
|
||||
// https://github.com/webmachinelearning/webnn/issues/475. Fix me according to the spec decision.
|
||||
const emscripten::val& ModelBuilder::GetZeroConstant(const std::string& data_type) {
|
||||
std::string name = "webnn_zero_constant_" + data_type;
|
||||
// If the operand does not exist, create it.
|
||||
if (wnn_operands_.find(name) == wnn_operands_.end()) {
|
||||
emscripten::val desc = emscripten::val::object();
|
||||
emscripten::val dims = emscripten::val::array();
|
||||
desc.set("dimensions", dims);
|
||||
emscripten::val zero_buffer = emscripten::val::undefined();
|
||||
if (data_type == "uint8") {
|
||||
if (!SetWebnnDataType(desc, ONNX_NAMESPACE::TensorProto_DataType_UINT8)) {
|
||||
ORT_THROW("Unsupported data type: " + data_type);
|
||||
}
|
||||
zero_buffer = emscripten::val::global("Uint8Array").new_(1);
|
||||
} else if (data_type == "float32") {
|
||||
if (!SetWebnnDataType(desc, ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) {
|
||||
ORT_THROW("Unsupported data type: " + data_type);
|
||||
}
|
||||
zero_buffer = emscripten::val::global("Float32Array").new_(1);
|
||||
} else {
|
||||
ORT_THROW("Unsupported data type: " + data_type);
|
||||
}
|
||||
emscripten::val zero_constant = wnn_builder_.call<emscripten::val>("constant", desc, zero_buffer);
|
||||
wnn_operands_.insert(std::make_pair(name, zero_constant));
|
||||
}
|
||||
return wnn_operands_.at(name);
|
||||
}
|
||||
|
||||
// Inserts `tensor_name` into skipped_initializers_.
void ModelBuilder::AddInitializerToSkip(const std::string& tensor_name) {
|
||||
skipped_initializers_.insert(tensor_name);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ class ModelBuilder {
|
|||
const emscripten::val& GetContext() const { return wnn_context_; }
|
||||
const emscripten::val& GetOperand(const std::string& name) const { return wnn_operands_.at(name); }
|
||||
void AddOperand(const std::string& name, const emscripten::val& operand);
|
||||
const emscripten::val& GetZeroConstant(const std::string& data_type);
|
||||
// Use the buffers to persist WebNN allocated data like transposed weight.
|
||||
// It ensures the validity during inference session.
|
||||
std::vector<std::unique_ptr<uint8_t[]>> mem_persist_buffers_;
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
|
|||
|
||||
{ // Conv
|
||||
CreateConvOpBuilder("Conv", op_registrations);
|
||||
CreateConvOpBuilder("ConvInteger", op_registrations);
|
||||
CreateConvOpBuilder("ConvTranspose", op_registrations);
|
||||
}
|
||||
|
||||
|
|
@ -79,6 +80,11 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
|
|||
CreateConcatOpBuilder("Concat", op_registrations);
|
||||
}
|
||||
|
||||
{ // Quantize/Dequantize
|
||||
CreateDynamicQuantizeLinearOpBuilder("DynamicQuantizeLinear", op_registrations);
|
||||
CreateDequantizeLinearOpBuilder("DequantizeLinear", op_registrations);
|
||||
}
|
||||
|
||||
{ // Expand
|
||||
CreateExpandOpBuilder("Expand", op_registrations);
|
||||
}
|
||||
|
|
@ -94,6 +100,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
|
|||
{ // Gemm/MatMul
|
||||
CreateGemmOpBuilder("Gemm", op_registrations);
|
||||
CreateGemmOpBuilder("MatMul", op_registrations);
|
||||
CreateGemmOpBuilder("MatMulInteger", op_registrations);
|
||||
}
|
||||
|
||||
{ // Logical
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_
|
|||
void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateDynamicQuantizeLinearOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateDequantizeLinearOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
|
|
|
|||
|
|
@ -317,6 +317,8 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
|
|||
void* output_buffer;
|
||||
switch (output_type) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
|
||||
|
|
|
|||
Loading…
Reference in a new issue