mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
[WebNN EP] Add data type constraint (#20779)
The WebNN spec has added data type constraints for every op, and its CPU backend (currently TFLite) has additional constraints. Add the corresponding constraints to each op in the WebNN EP. Note: fp16 is temporarily disabled for the CPU backend, as support for it is planned to be ready in Chromium next month.
This commit is contained in:
parent
e77f238dc6
commit
9ea9f9e46a
25 changed files with 680 additions and 76 deletions
|
|
@ -130,16 +130,10 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
|
|||
return supported_node_groups;
|
||||
}
|
||||
|
||||
bool IsSupportedDataType(const int32_t data_type, const WebnnDeviceType device_type) {
|
||||
// Current data type implementation status of WebNN is inconsistent along with different backends,
|
||||
// The XNNPack backend supports only FP32, while the DML backend POC supports more.
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
return std::find(supported_cpu_data_types.begin(), supported_cpu_data_types.end(), data_type) !=
|
||||
supported_cpu_data_types.end();
|
||||
} else {
|
||||
return std::find(supported_gpu_data_types.begin(), supported_gpu_data_types.end(), data_type) !=
|
||||
supported_gpu_data_types.end();
|
||||
}
|
||||
bool IsSupportedDataType(const int32_t data_type,
|
||||
const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType>& supported_data_types) {
|
||||
return std::find(supported_data_types.begin(), supported_data_types.end(), data_type) !=
|
||||
supported_data_types.end();
|
||||
}
|
||||
|
||||
bool IsValidMultidirectionalBroadcast(std::vector<int64_t>& shape_a,
|
||||
|
|
|
|||
|
|
@ -262,11 +262,7 @@ inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn
|
|||
return true;
|
||||
}
|
||||
|
||||
constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 1> supported_cpu_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
};
|
||||
|
||||
constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 9> supported_gpu_data_types = {
|
||||
static const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> webnn_supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT8,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_UINT8,
|
||||
|
|
@ -278,7 +274,8 @@ constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 9> supported_gpu_data
|
|||
ONNX_NAMESPACE::TensorProto_DataType_UINT64,
|
||||
};
|
||||
|
||||
bool IsSupportedDataType(const int32_t data_type, const WebnnDeviceType device_type);
|
||||
bool IsSupportedDataType(const int32_t data_type,
|
||||
const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType>& supported_data_types);
|
||||
|
||||
bool IsValidMultidirectionalBroadcast(std::vector<int64_t>& shape_a,
|
||||
std::vector<int64_t>& shape_b,
|
||||
|
|
|
|||
|
|
@ -21,6 +21,8 @@ class ActivationOpBuilder : public BaseOpBuilder {
|
|||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -81,6 +83,44 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initi
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ActivationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
|
||||
// WebNN relu op supports float32, float16, int32, int8 input data types.
|
||||
if (op_type == "Relu") {
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT32,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT8,
|
||||
};
|
||||
// WebNN CPU backend does not support int32 data type for relu.
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32);
|
||||
}
|
||||
} else { // Others only support float32 and float16.
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
};
|
||||
}
|
||||
|
||||
if (!IsSupportedDataType(input_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
if (op_registrations.op_builder_map.count(op_type) > 0)
|
||||
return;
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ class ArgMaxMinOpBuilder : public BaseOpBuilder {
|
|||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -77,6 +79,31 @@ bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ArgMaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
|
||||
// WebNN CPU backend doesn't support int64, uint64 input data types for argMax and argMin.
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
|
||||
}
|
||||
|
||||
if (!IsSupportedDataType(input_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateArgMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
if (op_registrations.op_builder_map.count(op_type) > 0)
|
||||
return;
|
||||
|
|
|
|||
|
|
@ -73,11 +73,23 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType d
|
|||
}
|
||||
}
|
||||
|
||||
// WebNN CPU backend (TFLite) will enable float16 input data type soon,
|
||||
// temporarily fallback float16 input data type for WebNN CPU.
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)
|
||||
return false;
|
||||
}
|
||||
|
||||
return HasSupportedInputsImpl(node, device_type, logger);
|
||||
}
|
||||
|
||||
bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node,
|
||||
const WebnnDeviceType device_type,
|
||||
const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
// We only check the type of input 0 by default, specific op builder can override this.
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
|
|
@ -86,7 +98,7 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node,
|
|||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
if (!IsSupportedDataType(input_type, device_type)) {
|
||||
if (!IsSupportedDataType(input_type, webnn_supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << node.OpType()
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ class BinaryOpBuilder : public BaseOpBuilder {
|
|||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -72,6 +74,49 @@ bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
|
|||
return true;
|
||||
}
|
||||
|
||||
bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input0_type;
|
||||
int32_t input1_type;
|
||||
|
||||
if (!GetType(*input_defs[0], input0_type, logger) ||
|
||||
!GetType(*input_defs[1], input1_type, logger))
|
||||
return false;
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
|
||||
// WebNN prelu op only supports float32, float16, int32, int8 input data types.
|
||||
if (op_type == "Prelu") {
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT32,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT8,
|
||||
};
|
||||
// WebNN CPU backend doesn't support int32 for prelu.
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32);
|
||||
}
|
||||
} else {
|
||||
supported_data_types = webnn_supported_data_types;
|
||||
}
|
||||
if (!IsSupportedDataType(input0_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input0_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input0_type != input1_type) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input data types should be the same.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
if (op_registrations.op_builder_map.count(op_type) > 0)
|
||||
return;
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ class CastOpBuilder : public BaseOpBuilder {
|
|||
// Operator support related.
|
||||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -80,12 +80,12 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
|
||||
bool CastOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
|
||||
const Node& node,
|
||||
const WebnnDeviceType device_type,
|
||||
const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
NodeAttrHelper helper(node);
|
||||
// Check cast output type.
|
||||
const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED);
|
||||
if (!IsSupportedDataType(to_type, device_type)) {
|
||||
if (!IsSupportedDataType(to_type, webnn_supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "Invalid cast to type " << to_type << ".";
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ class ClipOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -71,6 +73,33 @@ bool ClipOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
|||
return GetClipMinMax(initializers, node, min, max, logger);
|
||||
}
|
||||
|
||||
bool ClipOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
|
||||
// WebNN CPU backend doesn't support int32, uint32, int64, uint64 input data types for clamp.
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32);
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
|
||||
}
|
||||
|
||||
if (!IsSupportedDataType(input_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<ClipOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
|
|
|
|||
|
|
@ -29,6 +29,8 @@ class ConvOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
|
||||
|
|
@ -148,11 +150,6 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder,
|
|||
bool is_conv1d) {
|
||||
const auto& tensor = *model_builder.GetInitializerTensors().at(name);
|
||||
auto data_type = tensor.data_type();
|
||||
if (!IsSupportedDataType(data_type, model_builder.GetWebnnDeviceType())) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"The initializer of graph has unsupported type, name: ",
|
||||
tensor.name(), " type: ", data_type);
|
||||
}
|
||||
|
||||
const auto& shape = tensor.dims();
|
||||
std::vector<uint32_t> dims = GetVecUint32FromVecInt64(std::vector<int64_t>(std::begin(shape), std::end(shape)));
|
||||
|
|
@ -177,7 +174,6 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder,
|
|||
|
||||
size_t element_size{0};
|
||||
switch (data_type) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
|
||||
element_size = sizeof(uint8_t);
|
||||
break;
|
||||
|
|
@ -190,17 +186,6 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder,
|
|||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
|
||||
element_size = sizeof(float);
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
|
||||
element_size = sizeof(int32_t);
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
|
||||
element_size = sizeof(int64_t);
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
|
||||
element_size = sizeof(uint32_t);
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
|
||||
element_size = sizeof(uint64_t);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
|
|
@ -396,6 +381,55 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input0_type; // input data type
|
||||
int32_t input1_type; // weight data type
|
||||
int32_t input2_type; // bias or x_zero_point data type
|
||||
int32_t input3_type; // w_zero_point data type
|
||||
bool has_input2 = input_defs.size() > 2 && input_defs[2]->Exists();
|
||||
bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists();
|
||||
|
||||
if (!GetType(*input_defs[0], input0_type, logger) ||
|
||||
!GetType(*input_defs[1], input1_type, logger) ||
|
||||
(has_input2 && !GetType(*input_defs[2], input2_type, logger)) ||
|
||||
(has_input3 && !GetType(*input_defs[3], input3_type, logger))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
|
||||
if (op_type == "Conv" || op_type == "ConvTranspose") {
|
||||
// WebNN conv2d and convTranspose2d only support float32 and float16 input data types.
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
};
|
||||
} else if (op_type == "ConvInteger") {
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT8,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_UINT8,
|
||||
};
|
||||
}
|
||||
if (!IsSupportedDataType(input0_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input0_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input0_type != input1_type ||
|
||||
(has_input2 && input0_type != input2_type) ||
|
||||
(has_input3 && input0_type != input3_type)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input data types should be the same.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
if (op_registrations.op_builder_map.count(op_type) > 0)
|
||||
return;
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ class GatherOpBuilder : public BaseOpBuilder {
|
|||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -66,6 +68,31 @@ bool GatherOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ
|
|||
return true;
|
||||
}
|
||||
|
||||
bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
|
||||
// WebNN CPU backend doesn't support uint32, uint64 input data types for gather.
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
|
||||
}
|
||||
|
||||
if (!IsSupportedDataType(input_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<GatherOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ class GemmOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -219,6 +221,55 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
|||
return true;
|
||||
}
|
||||
|
||||
bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input0_type; // A data type
|
||||
int32_t input1_type; // B data type
|
||||
int32_t input2_type; // C or a_zero_point data type
|
||||
int32_t input3_type; // b_zero_point data type
|
||||
bool has_input2 = input_defs.size() > 2 && input_defs[2]->Exists();
|
||||
bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists();
|
||||
|
||||
if (!GetType(*input_defs[0], input0_type, logger) ||
|
||||
!GetType(*input_defs[1], input1_type, logger) ||
|
||||
(has_input2 && !GetType(*input_defs[2], input2_type, logger)) ||
|
||||
(has_input3 && !GetType(*input_defs[3], input3_type, logger))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
|
||||
if (op_type == "Gemm" || op_type == "MatMul") {
|
||||
// WebNN gemm and matmul only support float32 and float16 input data types.
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
};
|
||||
} else if (op_type == "MatMulInteger") {
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT8,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_UINT8,
|
||||
};
|
||||
}
|
||||
if (!IsSupportedDataType(input0_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input0_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input0_type != input1_type ||
|
||||
(has_input2 && input0_type != input2_type) ||
|
||||
(has_input3 && input0_type != input3_type)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input data types should be the same.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
if (op_registrations.op_builder_map.count(op_type) > 0)
|
||||
return;
|
||||
|
|
|
|||
|
|
@ -21,6 +21,8 @@ class LogicalOpBuilder : public BaseOpBuilder {
|
|||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -50,6 +52,48 @@ Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
|
||||
const Node& node,
|
||||
const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& name = node.Name();
|
||||
const auto& op_type = node.OpType();
|
||||
const auto& input_defs = node.InputDefs();
|
||||
if (input_defs.size() < 2) {
|
||||
LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 2 inputs, actual: "
|
||||
<< input_defs.size();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input0_type;
|
||||
int32_t input1_type;
|
||||
|
||||
if (!GetType(*input_defs[0], input0_type, logger) ||
|
||||
!GetType(*input_defs[1], input1_type, logger))
|
||||
return false;
|
||||
|
||||
if (!IsSupportedDataType(input0_type, webnn_supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input0_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input0_type != input1_type) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input data types should be the same.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
if (op_registrations.op_builder_map.count(op_type) > 0)
|
||||
return;
|
||||
|
|
@ -69,20 +113,5 @@ void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations&
|
|||
}
|
||||
}
|
||||
|
||||
bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
|
||||
const Node& node,
|
||||
const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& name = node.Name();
|
||||
const auto& op_type = node.OpType();
|
||||
const auto& input_defs = node.InputDefs();
|
||||
if (input_defs.size() < 2) {
|
||||
LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 2 inputs, actual: "
|
||||
<< input_defs.size();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace webnn
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ class MaxMinOpBuilder : public BaseOpBuilder {
|
|||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -84,6 +86,33 @@ bool MaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ
|
|||
return true;
|
||||
}
|
||||
|
||||
bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input0_type;
|
||||
int32_t input1_type;
|
||||
|
||||
if (!GetType(*input_defs[0], input0_type, logger) ||
|
||||
!GetType(*input_defs[1], input1_type, logger))
|
||||
return false;
|
||||
|
||||
if (!IsSupportedDataType(input0_type, webnn_supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input0_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input0_type != input1_type) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input data types should be the same.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
if (op_registrations.op_builder_map.count(op_type) > 0)
|
||||
return;
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ class NormalizationOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
||||
|
|
@ -173,6 +175,53 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initi
|
|||
return true;
|
||||
}
|
||||
|
||||
bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input0_type; // input data type
|
||||
int32_t input1_type; // scale data type
|
||||
int32_t input2_type; // B data type
|
||||
int32_t input3_type; // mean data type
|
||||
int32_t input4_type; // var data type
|
||||
bool has_input2 = input_defs.size() > 2 && input_defs[2]->Exists();
|
||||
bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists();
|
||||
bool has_input4 = input_defs.size() > 3 && input_defs[4]->Exists();
|
||||
|
||||
if (!GetType(*input_defs[0], input0_type, logger) ||
|
||||
!GetType(*input_defs[1], input1_type, logger) ||
|
||||
(has_input2 && !GetType(*input_defs[2], input2_type, logger)) ||
|
||||
(has_input3 && !GetType(*input_defs[3], input3_type, logger)) ||
|
||||
(has_input4 && !GetType(*input_defs[4], input4_type, logger))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// WebNN batchNormalization, instanceNormalization, layerNormalization
|
||||
// only support float32 and float16 input data types.
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
};
|
||||
|
||||
if (!IsSupportedDataType(input0_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input0_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input0_type != input1_type ||
|
||||
(has_input2 && input0_type != input2_type) ||
|
||||
(has_input3 && input0_type != input3_type) ||
|
||||
(has_input4 && input0_type != input4_type)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input data types should be the same.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
if (op_registrations.op_builder_map.count(op_type) > 0)
|
||||
return;
|
||||
|
|
|
|||
|
|
@ -28,6 +28,8 @@ class PadOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -190,6 +192,31 @@ bool PadOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
|||
return true;
|
||||
}
|
||||
|
||||
bool PadOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
|
||||
// WebNN CPU backend doesn't support uint32, uint64 input data types for pad.
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
|
||||
}
|
||||
|
||||
if (!IsSupportedDataType(input_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<PadOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
|
|
|
|||
|
|
@ -31,6 +31,8 @@ class ReductionOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -144,6 +146,51 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ReductionOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
|
||||
if (op_type == "ReduceL1" || op_type == "ReduceProd" ||
|
||||
op_type == "ReduceSum" || op_type == "ReduceSumSquare") {
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT32,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_UINT32,
|
||||
};
|
||||
// WebNN CPU backend doesn't support uint32 for reduceProd and reduceSum.
|
||||
if (device_type == WebnnDeviceType::CPU && (op_type == "ReduceProd" || op_type == "ReduceSum")) {
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
|
||||
}
|
||||
} else if (op_type == "ReduceL2" || op_type == "ReduceLogSum" ||
|
||||
op_type == "ReduceLogSumExp" || op_type == "ReduceMean") {
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
};
|
||||
} else { // ReduceMax and ReduceMin
|
||||
supported_data_types = webnn_supported_data_types;
|
||||
// WebNN CPU backend doesn't support uint32, uint64 for reduceMax and reduceMin.
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
|
||||
}
|
||||
}
|
||||
if (!IsSupportedDataType(input_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
if (op_registrations.op_builder_map.count(op_type) > 0)
|
||||
return;
|
||||
|
|
|
|||
|
|
@ -35,6 +35,8 @@ class ResizeOpBuilder : public BaseOpBuilder {
|
|||
// Resize opset 10- is very different than Resize opset 11+, with many key attributes missing.
|
||||
// We only support Resize opset 11+ here.
|
||||
int GetMinSupportedOpSet(const Node& /* node */) const override { return 11; }
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Helper functions
|
||||
|
|
@ -280,6 +282,30 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ResizeOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
// WebNN resample2d op only supports float32 and float16 input data types.
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
};
|
||||
|
||||
if (!IsSupportedDataType(input_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<ResizeOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
|
||||
bool ShapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
|
||||
const Node& node,
|
||||
const WebnnDeviceType device_type,
|
||||
const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
std::vector<int64_t> input_shape;
|
||||
|
|
@ -74,7 +74,7 @@ bool ShapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initialize
|
|||
return false;
|
||||
|
||||
int32_t output_type = ONNX_NAMESPACE::TensorProto_DataType_INT64;
|
||||
if (!IsSupportedDataType(output_type, device_type)) {
|
||||
if (!IsSupportedDataType(output_type, webnn_supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << node.OpType()
|
||||
<< "] Output type: [" << output_type
|
||||
<< "] is not supported for now";
|
||||
|
|
|
|||
|
|
@ -29,6 +29,8 @@ class SliceOpBuilder : public BaseOpBuilder {
|
|||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
// TODO: Support Slice opset < 10, which uses attributes for starts and ends.
|
||||
int GetMinSupportedOpSet(const Node& /* node */) const override { return 10; }
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -161,6 +163,30 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
|||
return true;
|
||||
}
|
||||
|
||||
bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
|
||||
// WebNN CPU backend doesn't support uint64 input data type for slice.
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
|
||||
}
|
||||
|
||||
if (!IsSupportedDataType(input_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<SliceOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ class SoftmaxOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
||||
|
|
@ -131,6 +133,30 @@ bool SoftmaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali
|
|||
return true;
|
||||
}
|
||||
|
||||
bool SoftmaxOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
// WebNN softmax only supports float32 and float16 input data types.
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
};
|
||||
|
||||
if (!IsSupportedDataType(input_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<SoftmaxOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ class TernaryOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -42,6 +44,42 @@ Status TernaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input0_type; // condition data type
|
||||
int32_t input1_type; // X data type
|
||||
int32_t input2_type; // Y data type
|
||||
|
||||
if (!GetType(*input_defs[0], input0_type, logger) ||
|
||||
!GetType(*input_defs[1], input1_type, logger) ||
|
||||
!GetType(*input_defs[2], input2_type, logger))
|
||||
return false;
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
|
||||
// WebNN CPU backend doesn't support uint64 X, Y data type for where.
|
||||
if (device_type == WebnnDeviceType::CPU && op_type == "Where") {
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
|
||||
}
|
||||
// ONNX's condition data type is bool which is same as WebNN.
|
||||
// Only need to check X, Y data types.
|
||||
if (!IsSupportedDataType(input1_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input1_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input1_type != input2_type) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input X, Y data types should be the same.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
if (op_registrations.op_builder_map.count(op_type) > 0)
|
||||
return;
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ class TransposeOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -47,6 +49,31 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool TransposeOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
|
||||
// WebNN CPU backend doesn't support uint32, uint64 input data types for transpose.
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
|
||||
}
|
||||
|
||||
if (!IsSupportedDataType(input_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<TransposeOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ class UnaryOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -66,6 +68,44 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool UnaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
|
||||
if (op_type == "Identity") {
|
||||
supported_data_types = webnn_supported_data_types;
|
||||
} else if (op_type == "Abs" || op_type == "Neg") {
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT32,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT8,
|
||||
};
|
||||
} else if (op_type == "Not") {
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
|
||||
};
|
||||
} else { // Others only support float32, float16 input data types.
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
};
|
||||
}
|
||||
if (!IsSupportedDataType(input_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
if (op_registrations.op_builder_map.count(op_type) > 0)
|
||||
return;
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ Status ModelBuilder::RegisterInitializers() {
|
|||
desc.set("dimensions", emscripten::val::array(dims));
|
||||
auto data_type = tensor.data_type();
|
||||
emscripten::val operand = emscripten::val::object();
|
||||
if (IsSupportedDataType(data_type, wnn_device_type_)) {
|
||||
if (IsSupportedDataType(data_type, webnn_supported_data_types)) {
|
||||
ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type");
|
||||
auto num_elements = SafeInt<size_t>(Product(tensor.dims()));
|
||||
emscripten::val view = emscripten::val::undefined();
|
||||
|
|
@ -112,12 +112,10 @@ Status ModelBuilder::RegisterInitializers() {
|
|||
switch (data_type) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
|
||||
desc.set("type", emscripten::val("uint8"));
|
||||
view = emscripten::val{emscripten::typed_memory_view(num_elements,
|
||||
reinterpret_cast<uint8_t*>(tensor_ptr))};
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
|
||||
desc.set("type", emscripten::val("int8"));
|
||||
view = emscripten::val{emscripten::typed_memory_view(num_elements,
|
||||
reinterpret_cast<int8_t*>(tensor_ptr))};
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -318,25 +318,11 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
|
|||
auto output_tensor =
|
||||
ctx.GetOutput(i, output_shape.data(), output_shape.size());
|
||||
|
||||
void* output_buffer;
|
||||
switch (output_type) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
|
||||
output_buffer = output_tensor.GetTensorMutableRawData();
|
||||
break;
|
||||
default:
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
|
||||
"Unsupported type: ", output_type, " for output: ", output_name);
|
||||
break;
|
||||
if (!webnn::IsSupportedDataType(output_type, webnn::webnn_supported_data_types)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
|
||||
"Unsupported type: ", output_type, " for output: ", output_name);
|
||||
}
|
||||
|
||||
void* output_buffer = output_tensor.GetTensorMutableRawData();
|
||||
outputs.emplace(output_name,
|
||||
webnn::OnnxTensorData{
|
||||
webnn::OnnxTensorInfo{output_type, output_shape},
|
||||
|
|
|
|||
Loading…
Reference in a new issue