[WebNN EP] Use opSupportLimits to dynamically check data type support (#22025)

- Remove hard code data type checks and use WebNN's opSupportLimits
instead
- Add HasSupportedOutputsImpl for output data type validation
- Get preferred layout info from opSupportLimits
- Move Not op to logical_op_builder.cc because it should be there. This
avoid the inconsistent input names in `unary_op_builder.cc`.
This commit is contained in:
Wanming Lin 2024-09-14 12:36:20 +08:00 committed by GitHub
parent a89bddd5c2
commit c63dd0234b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 288 additions and 642 deletions

View file

@ -45,12 +45,12 @@ bool GetShape(const NodeArg& node_arg, std::vector<int64_t>& shape, const loggin
return true;
}
bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer,
const WebnnDeviceType device_type, const logging::Logger& logger) {
bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const WebnnDeviceType device_type,
const emscripten::val& wnn_limits, const logging::Logger& logger) {
const auto& op_builders = GetOpBuilders();
if (Contains(op_builders, node.OpType())) {
const auto* op_builder = op_builders.at(node.OpType());
return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, device_type, logger);
return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, device_type, wnn_limits, logger);
} else {
return false;
}
@ -86,6 +86,7 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
const emscripten::val& wnn_builder,
const WebnnDeviceType device_type,
const emscripten::val& wnn_limits,
const logging::Logger& logger) {
std::vector<std::vector<size_t>> supported_node_groups;
@ -105,7 +106,7 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
// Firstly check if platform supports the WebNN op.
if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) {
LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser";
supported = IsNodeSupported(*node, graph_viewer, device_type, logger);
supported = IsNodeSupported(*node, graph_viewer, device_type, wnn_limits, logger);
}
LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType()
@ -130,10 +131,54 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
return supported_node_groups;
}
bool IsSupportedDataType(const int32_t data_type,
const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType>& supported_data_types) {
return std::find(supported_data_types.begin(), supported_data_types.end(), data_type) !=
supported_data_types.end();
bool AreInputDataTypesSame(const std::string& op_type,
gsl::span<const int32_t> input_types,
const logging::Logger& logger) {
for (size_t i = 1; i < input_types.size(); i++) {
if (input_types[0] != input_types[i]) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input data types should be the same, but ["
<< input_types[0] << "] does not match "
<< input_types[i] << "].";
return false;
}
}
return true;
}
bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types) {
auto it = onnx_to_webnn_data_type_map.find(static_cast<ONNX_NAMESPACE::TensorProto_DataType>(onnx_data_type));
if (it == onnx_to_webnn_data_type_map.end())
return false;
std::string webnn_data_type = it->second;
// Check if WebNN supports the data type.
emscripten::val is_supported = webnn_supported_data_types.call<emscripten::val>("includes",
emscripten::val(webnn_data_type));
return is_supported.as<bool>();
}
// Check if the input or output data type of ONNX node is supported by the WebNN operator.
bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
const int32_t onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const logging::Logger& logger) {
std::string webnn_op_type;
if (!GetWebNNOpType(onnx_op_type, webnn_op_type))
return false;
if (!IsSupportedDataType(onnx_data_type, wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"])) {
LOGS(logger, VERBOSE) << "[" << onnx_op_type
<< "] " << onnx_input_output_name
<< " type: [" << onnx_data_type
<< "] is not supported for now";
return false;
}
return true;
}
bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,

View file

@ -148,6 +148,7 @@ bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, c
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
const emscripten::val& wnn_builder,
const WebnnDeviceType device_type,
const emscripten::val& wnn_limits,
const logging::Logger& logger);
static const InlinedHashMap<std::string, std::string> op_map = {
{"Abs", "abs"},
@ -250,20 +251,38 @@ inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn
return true;
}
static const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> webnn_supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
ONNX_NAMESPACE::TensorProto_DataType_INT8,
ONNX_NAMESPACE::TensorProto_DataType_UINT8,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_INT32,
ONNX_NAMESPACE::TensorProto_DataType_INT64,
ONNX_NAMESPACE::TensorProto_DataType_UINT32,
ONNX_NAMESPACE::TensorProto_DataType_UINT64,
inline bool GetWebNNOpType(const std::string& op_type, std::string& webnn_op_type) {
auto it = op_map.find(op_type);
// Returns false if the op_type is not listed in the op_map.
if (it == op_map.end()) {
return false;
}
webnn_op_type = it->second;
return true;
}
static const InlinedHashMap<ONNX_NAMESPACE::TensorProto_DataType, std::string> onnx_to_webnn_data_type_map = {
{ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"},
{ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"},
{ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"},
{ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, "float16"},
{ONNX_NAMESPACE::TensorProto_DataType_FLOAT, "float32"},
{ONNX_NAMESPACE::TensorProto_DataType_INT32, "int32"},
{ONNX_NAMESPACE::TensorProto_DataType_INT64, "int64"},
{ONNX_NAMESPACE::TensorProto_DataType_UINT32, "uint32"},
{ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"},
};
bool IsSupportedDataType(const int32_t data_type,
const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType>& supported_data_types);
bool AreInputDataTypesSame(const std::string& op_type,
gsl::span<const int32_t> input_types,
const logging::Logger& logger);
bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types);
bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
const int32_t onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const logging::Logger& logger);
bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,
std::vector<int64_t>& shape_b,

View file

@ -21,8 +21,6 @@ class ActivationOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
WebnnDeviceType device_type, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -94,44 +92,6 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initi
return true;
}
bool ActivationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
// WebNN relu op supports float32, float16, int32, int8 input data types.
if (op_type == "Relu") {
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
ONNX_NAMESPACE::TensorProto_DataType_INT32,
ONNX_NAMESPACE::TensorProto_DataType_INT8,
};
// WebNN CPU backend does not support int32 data type for relu.
if (device_type == WebnnDeviceType::CPU) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32);
}
} else { // Others only support float32 and float16.
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
};
}
if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return true;
}
void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
if (op_registrations.op_builder_map.count(op_type) > 0)
return;

View file

@ -22,8 +22,6 @@ class ArgMaxMinOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
WebnnDeviceType device_type, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -77,31 +75,6 @@ bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia
return true;
}
bool ArgMaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
// WebNN CPU backend doesn't support int64, uint64 input data types for argMax and argMin.
if (device_type == WebnnDeviceType::CPU) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64);
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
}
if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return true;
}
void CreateArgMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
if (op_registrations.op_builder_map.count(op_type) > 0)
return;

View file

@ -38,9 +38,9 @@ bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node
Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const {
ORT_RETURN_IF_NOT(
IsOpSupported(model_builder.GetInitializerTensors(), node, model_builder.GetWebnnDeviceType(), logger),
"Unsupported operator ",
node.OpType());
IsOpSupported(model_builder.GetInitializerTensors(), node, model_builder.GetWebnnDeviceType(),
model_builder.GetOpSupportLimits(), logger),
"Unsupported operator ", node.OpType());
ORT_RETURN_IF_ERROR(AddToModelBuilderImpl(model_builder, node, logger));
LOGS(logger, VERBOSE) << "Operator name: [" << node.Name()
<< "] type: [" << node.OpType() << "] was added";
@ -50,8 +50,12 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node&
// Operator support related.
bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType device_type, const logging::Logger& logger) const {
if (!HasSupportedInputs(node, device_type, logger))
const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
if (!HasSupportedInputs(node, wnn_limits, logger))
return false;
if (!HasSupportedOutputsImpl(node, wnn_limits, logger))
return false;
// We do not support external initializers for now.
@ -64,7 +68,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons
return IsOpSupportedImpl(initializers, node, device_type, logger);
}
bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType device_type,
bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
for (const auto* input : node.InputDefs()) {
@ -73,39 +77,33 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType d
}
}
// WebNN CPU backend (TFLite) will enable float16 input data type soon,
// temporarily fallback float16 input data type for WebNN CPU.
if (device_type == WebnnDeviceType::CPU) {
const auto& input = *node.InputDefs()[0];
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)
return false;
}
return HasSupportedInputsImpl(node, device_type, logger);
return HasSupportedInputsImpl(node, wnn_limits, logger);
}
bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node,
const WebnnDeviceType /* device_type */,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
// We only check the type of input 0 by default, specific op builder can override this.
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
if (!IsSupportedDataType(input_type, webnn_supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << node.OpType()
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger);
}
return true;
bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
// We only check the type of output 0 by default, specific op builder can override this.
const auto& output = *node.OutputDefs()[0];
const auto& op_type = node.OpType();
int32_t output_type;
if (!GetType(output, output_type, logger))
return false;
return IsDataTypeSupportedByOp(op_type, output_type, wnn_limits, "output", "Output", logger);
}
bool BaseOpBuilder::HasSupportedOpSet(const Node& node,

View file

@ -28,7 +28,8 @@ class BaseOpBuilder : public IOpBuilder {
// Operator support related.
public:
bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
protected:
virtual bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& /* node */,
@ -36,8 +37,10 @@ class BaseOpBuilder : public IOpBuilder {
return true;
}
virtual bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
virtual bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const;
virtual bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const;
// ONNX Runtime only *guarantees* support for models stamped
// with opset version 7 or above for opset domain 'ai.onnx'.
@ -50,7 +53,7 @@ class BaseOpBuilder : public IOpBuilder {
private:
bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
bool HasSupportedInputs(const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const;
bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
};
} // namespace webnn

View file

@ -22,7 +22,7 @@ class BinaryOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};
@ -86,7 +86,7 @@ bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
return true;
}
bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
@ -97,36 +97,14 @@ bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDevice
!GetType(*input_defs[1], input1_type, logger))
return false;
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
// WebNN prelu op only supports float32, float16, int32, int8 input data types.
if (op_type == "Prelu") {
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
ONNX_NAMESPACE::TensorProto_DataType_INT32,
ONNX_NAMESPACE::TensorProto_DataType_INT8,
};
// WebNN CPU backend doesn't support int32 for prelu.
if (device_type == WebnnDeviceType::CPU) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32);
}
} else {
supported_data_types = webnn_supported_data_types;
}
if (!IsSupportedDataType(input0_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input0_type
<< "] is not supported for now";
std::array<int32_t, 2> input_types{input0_type, input1_type};
if (!AreInputDataTypesSame(op_type, input_types, logger)) {
return false;
}
if (input0_type != input1_type) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input data types should be the same.";
return false;
}
return true;
std::string webnn_input_name = op_type == "PRelu" ? "input" : "a";
std::string onnx_input_name = op_type == "PRelu" || op_type == "Pow" ? "X" : "A";
return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, webnn_input_name, onnx_input_name, logger);
}
void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {

View file

@ -21,8 +21,8 @@ class CastOpBuilder : public BaseOpBuilder {
// Operator support related.
private:
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -80,26 +80,22 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
}
// Operator support related.
bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(*input_defs[0], input_type, logger))
return false;
if (!IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "input", logger))
return false;
bool CastOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
const Node& node,
const WebnnDeviceType device_type,
const logging::Logger& logger) const {
NodeAttrHelper helper(node);
// Check cast output type.
// Check cast to type.
const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED);
// WebNN CPU backend doesn't support casting to uint64 data type.
if (device_type == WebnnDeviceType::CPU && to_type == ONNX_NAMESPACE::TensorProto_DataType_UINT64) {
LOGS(logger, VERBOSE) << "Cast to uint64 is not supported for WebNN CPU backend.";
return false;
}
if (!IsSupportedDataType(to_type, webnn_supported_data_types)) {
LOGS(logger, VERBOSE) << "WebNN doesn't support casting to type " << to_type << ".";
return false;
}
return true;
return IsDataTypeSupportedByOp(op_type, to_type, wnn_limits, "output", "to", logger);
}
void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {

View file

@ -25,8 +25,6 @@ class ClipOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -94,33 +92,6 @@ bool ClipOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
};
}
bool ClipOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
// WebNN CPU backend doesn't support int32, uint32, int64, uint64 input data types for clamp.
if (device_type == WebnnDeviceType::CPU) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32);
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64);
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
}
if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return true;
}
void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<ClipOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());

View file

@ -19,6 +19,10 @@ class ConcatOpBuilder : public BaseOpBuilder {
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
// Operator support related.
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -52,6 +56,30 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
return Status::OK();
}
bool ConcatOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type;
if (!GetType(*input_defs[0], input0_type, logger))
return false;
for (size_t i = 1; i < input_defs.size(); i++) {
int32_t input_type;
if (!GetType(*input_defs[i], input_type, logger)) {
return false;
}
std::array<int32_t, 2> input_types{input0_type, input_type};
if (!AreInputDataTypesSame(op_type, input_types, logger)) {
return false;
}
}
return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger);
}
void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<ConcatOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());

View file

@ -29,7 +29,7 @@ class ConvOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};
@ -397,7 +397,7 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
return true;
}
bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
@ -415,35 +415,18 @@ bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTy
return false;
}
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
if (op_type == "Conv" || op_type == "ConvTranspose") {
// WebNN conv2d and convTranspose2d only support float32 and float16 input data types.
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
};
} else if (op_type == "ConvInteger") {
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_INT8,
ONNX_NAMESPACE::TensorProto_DataType_UINT8,
};
InlinedVector<int32_t, 4> input_types = {input0_type, input1_type};
if (has_input2) {
input_types.push_back(input2_type);
}
if (!IsSupportedDataType(input0_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input0_type
<< "] is not supported for now";
if (has_input3) {
input_types.push_back(input3_type);
}
if (!AreInputDataTypesSame(op_type, input_types, logger)) {
return false;
}
if (input0_type != input1_type ||
(has_input2 && input0_type != input2_type) ||
(has_input3 && input0_type != input3_type)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input data types should be the same.";
return false;
}
return true;
return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger);
}
void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {

View file

@ -22,7 +22,7 @@ class GatherOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};
@ -69,29 +69,19 @@ bool GatherOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ
return true;
}
bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
int32_t indices_type;
if (!GetType(input, input_type, logger) ||
!GetType(indices, indices_type, logger))
return false;
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
// WebNN CPU backend doesn't support uint32, uint64 input data types for gather.
if (device_type == WebnnDeviceType::CPU) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
}
if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return true;
return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) &&
IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
}
void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {

View file

@ -25,7 +25,7 @@ class GemmOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};
@ -215,7 +215,7 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializer
return true;
}
bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
@ -233,35 +233,18 @@ bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTy
return false;
}
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
if (op_type == "Gemm" || op_type == "MatMul") {
// WebNN gemm and matmul only support float32 and float16 input data types.
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
};
} else if (op_type == "MatMulInteger") {
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_INT8,
ONNX_NAMESPACE::TensorProto_DataType_UINT8,
};
InlinedVector<int32_t, 4> input_types = {input0_type, input1_type};
if (has_input2) {
input_types.push_back(input2_type);
}
if (!IsSupportedDataType(input0_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input0_type
<< "] is not supported for now";
if (has_input3) {
input_types.push_back(input3_type);
}
if (!AreInputDataTypesSame(op_type, input_types, logger)) {
return false;
}
if (input0_type != input1_type ||
(has_input2 && input0_type != input2_type) ||
(has_input3 && input0_type != input3_type)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input data types should be the same.";
return false;
}
return true;
return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger);
}
void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {

View file

@ -26,7 +26,7 @@ class GruOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};
@ -185,7 +185,7 @@ bool GruOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, c
return true;
}
bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
@ -208,37 +208,21 @@ bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTyp
return false;
}
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
if (device_type == WebnnDeviceType::CPU) {
// WebNN CPU backend only support float32 input data type.
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
};
} else if (device_type == WebnnDeviceType::GPU) {
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
};
InlinedVector<int32_t, 6> input_types = {input0_type, input1_type, input2_type};
if (has_input3) {
input_types.push_back(input3_type);
}
if (!IsSupportedDataType(input0_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input0_type
<< "] is not supported for now";
if (has_input4) {
input_types.push_back(input4_type);
}
if (has_input5) {
input_types.push_back(input5_type);
}
if (!AreInputDataTypesSame(op_type, input_types, logger)) {
return false;
}
if (input0_type != input1_type ||
input0_type != input2_type ||
(has_input3 && input0_type != input3_type) ||
(has_input4 && input0_type != input4_type) ||
(has_input5 && input0_type != input5_type)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input data types should be the same.";
return false;
}
return true;
return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger);
}
void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {

View file

@ -21,7 +21,7 @@ class LogicalOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};
@ -29,9 +29,14 @@ class LogicalOpBuilder : public BaseOpBuilder {
Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& /* logger */) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name());
emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name());
emscripten::val input0 = model_builder.GetOperand(input_defs[0]->Name());
emscripten::val input1 = emscripten::val::undefined();
if (input_defs.size() > 1) {
input1 = model_builder.GetOperand(input_defs[1]->Name());
}
emscripten::val output = emscripten::val::object();
emscripten::val options = emscripten::val::object();
options.set("label", node.Name());
@ -45,6 +50,8 @@ Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
output = model_builder.GetBuilder().call<emscripten::val>("lesser", input0, input1, options);
} else if (op_type == "LessOrEqual") {
output = model_builder.GetBuilder().call<emscripten::val>("lesserOrEqual", input0, input1, options);
} else if (op_type == "Not") {
output = model_builder.GetBuilder().call<emscripten::val>("logicalNot", input0, options);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"LogicalOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
@ -61,7 +68,7 @@ bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali
const auto& name = node.Name();
const auto& op_type = node.OpType();
const auto& input_defs = node.InputDefs();
if (input_defs.size() < 2) {
if (input_defs.size() < 2 && op_type != "Not") {
LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 2 inputs, actual: "
<< input_defs.size();
return false;
@ -69,31 +76,27 @@ bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali
return true;
}
bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type;
int32_t input1_type;
if (!GetType(*input_defs[0], input0_type, logger) ||
!GetType(*input_defs[1], input1_type, logger))
if (!GetType(*input_defs[0], input0_type, logger))
return false;
if (!IsSupportedDataType(input0_type, webnn_supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input0_type
<< "] is not supported for now";
return false;
if (op_type != "Not") {
if (!GetType(*input_defs[1], input1_type, logger))
return false;
std::array<int32_t, 2> input_types{input0_type, input1_type};
if (!AreInputDataTypesSame(op_type, input_types, logger)) {
return false;
}
}
if (input0_type != input1_type) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input data types should be the same.";
return false;
}
return true;
std::string onnx_input_name = op_type == "Not" ? "X" : "A";
return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", onnx_input_name, logger);
}
void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
@ -107,6 +110,7 @@ void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations&
"GreaterOrEqual",
"Less",
"LessOrEqual",
"Not",
};
op_registrations.builders.push_back(std::make_unique<LogicalOpBuilder>());

View file

@ -22,7 +22,7 @@ class MaxMinOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};
@ -87,31 +87,28 @@ bool MaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ
return true;
}
bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type;
int32_t input1_type;
if (!GetType(*input_defs[0], input0_type, logger) ||
!GetType(*input_defs[1], input1_type, logger))
if (!GetType(*input_defs[0], input0_type, logger))
return false;
if (!IsSupportedDataType(input0_type, webnn_supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input0_type
<< "] is not supported for now";
return false;
for (size_t i = 1; i < input_defs.size(); i++) {
int32_t input_type;
if (!GetType(*input_defs[i], input_type, logger)) {
return false;
}
std::array<int32_t, 2> input_types{input0_type, input_type};
if (!AreInputDataTypesSame(op_type, input_types, logger)) {
return false;
}
}
if (input0_type != input1_type) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input data types should be the same.";
return false;
}
return true;
return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger);
}
void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {

View file

@ -25,7 +25,7 @@ class NormalizationOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};
@ -182,7 +182,7 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initi
return true;
}
bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
@ -203,30 +203,21 @@ bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const Webn
return false;
}
// WebNN batchNormalization, instanceNormalization, layerNormalization
// only support float32 and float16 input data types.
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
};
if (!IsSupportedDataType(input0_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input0_type
<< "] is not supported for now";
std::vector<int32_t> input_types = {input0_type, input1_type};
if (has_input2) {
input_types.push_back(input2_type);
}
if (has_input3) {
input_types.push_back(input3_type);
}
if (has_input4) {
input_types.push_back(input4_type);
}
if (!AreInputDataTypesSame(op_type, input_types, logger)) {
return false;
}
if (input0_type != input1_type ||
(has_input2 && input0_type != input2_type) ||
(has_input3 && input0_type != input3_type) ||
(has_input4 && input0_type != input4_type)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input data types should be the same.";
return false;
}
return true;
return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger);
}
void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {

View file

@ -28,8 +28,6 @@ class PadOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -196,31 +194,6 @@ bool PadOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
return true;
} // namespace webnn
bool PadOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
// WebNN CPU backend doesn't support uint32, uint64 input data types for pad.
if (device_type == WebnnDeviceType::CPU) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
}
if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return true;
}
void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<PadOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());

View file

@ -31,8 +31,6 @@ class ReductionOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -147,56 +145,6 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ
return true;
}
bool ReductionOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
if (op_type == "ReduceL1" || op_type == "ReduceProd" ||
op_type == "ReduceSum" || op_type == "ReduceSumSquare") {
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
ONNX_NAMESPACE::TensorProto_DataType_INT32,
ONNX_NAMESPACE::TensorProto_DataType_UINT32,
ONNX_NAMESPACE::TensorProto_DataType_INT64,
ONNX_NAMESPACE::TensorProto_DataType_UINT64,
};
if (device_type == WebnnDeviceType::CPU) {
// WebNN CPU backend doesn't support uint32 and uint64 for reduceL1,
// reduceProd, reduceSum and reduceSumSquare.
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
}
} else if (op_type == "ReduceL2" || op_type == "ReduceLogSum" ||
op_type == "ReduceLogSumExp" || op_type == "ReduceMean") {
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
};
} else { // ReduceMax and ReduceMin
supported_data_types = webnn_supported_data_types;
// WebNN CPU backend doesn't support uint32, uint64 for reduceMax and reduceMin.
if (device_type == WebnnDeviceType::CPU) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
}
}
if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return true;
}
void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
if (op_registrations.op_builder_map.count(op_type) > 0)
return;

View file

@ -35,8 +35,6 @@ class ResizeOpBuilder : public BaseOpBuilder {
// Resize opset 10- is very different than Resize opset 11+, with many key attributes missing.
// We only support Resize opset 11+ here.
int GetMinSupportedOpSet(const Node& /* node */) const override { return 11; }
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const override;
};
// Helper functions
@ -275,30 +273,6 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
return true;
}
bool ResizeOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
// WebNN resample2d op only supports float32 and float16 input data types.
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
};
if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return true;
}
void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<ResizeOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());

View file

@ -18,11 +18,6 @@ class ShapeOpBuilder : public BaseOpBuilder {
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
// Operator support related.
private:
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
};
Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
@ -69,28 +64,6 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
return Status::OK();
}
// Operator support related.
bool ShapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
const Node& node,
const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
std::vector<int64_t> input_shape;
if (!GetShape(*input_defs[0], input_shape, logger))
return false;
int32_t output_type = ONNX_NAMESPACE::TensorProto_DataType_INT64;
if (!IsSupportedDataType(output_type, webnn_supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << node.OpType()
<< "] Output type: [" << output_type
<< "] is not supported for now";
return false;
}
return true;
}
void CreateShapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<ShapeOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());

View file

@ -29,8 +29,6 @@ class SliceOpBuilder : public BaseOpBuilder {
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
// TODO: Support Slice opset < 10, which uses attributes for starts and ends.
int GetMinSupportedOpSet(const Node& /* node */) const override { return 10; }
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -166,30 +164,6 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
return true;
}
bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
// WebNN CPU backend doesn't support uint64 input data type for slice.
if (device_type == WebnnDeviceType::CPU) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
}
if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return true;
}
void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<SliceOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());

View file

@ -24,8 +24,6 @@ class SoftmaxOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const override;
};
Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
@ -63,30 +61,6 @@ bool SoftmaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali
return true;
}
bool SoftmaxOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
// WebNN softmax only supports float32 and float16 input data types.
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
};
if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return true;
}
void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<SoftmaxOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());

View file

@ -18,7 +18,7 @@ class TernaryOpBuilder : public BaseOpBuilder {
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};
@ -46,7 +46,7 @@ Status TernaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
return Status::OK();
}
bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
@ -59,27 +59,14 @@ bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDevic
!GetType(*input_defs[2], input2_type, logger))
return false;
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
// WebNN CPU backend doesn't support uint64 X, Y data type for where.
if (device_type == WebnnDeviceType::CPU && op_type == "Where") {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
}
// ONNX's condition data type is bool which is same as WebNN.
// Only need to check X, Y data types.
if (!IsSupportedDataType(input1_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input1_type
<< "] is not supported for now";
std::array<int32_t, 2> input_types{input1_type, input2_type};
if (!AreInputDataTypesSame(op_type, input_types, logger)) {
return false;
}
if (input1_type != input2_type) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input X, Y data types should be the same.";
return false;
}
return true;
return IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger);
}
void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {

View file

@ -18,8 +18,6 @@ class TransposeOpBuilder : public BaseOpBuilder {
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -50,31 +48,6 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
return Status::OK();
}
bool TransposeOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
// WebNN CPU backend doesn't support uint32, uint64 input data types for transpose.
if (device_type == WebnnDeviceType::CPU) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
}
if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return true;
}
void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<TransposeOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());

View file

@ -18,8 +18,6 @@ class UnaryOpBuilder : public BaseOpBuilder {
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -51,8 +49,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
output = model_builder.GetBuilder().call<emscripten::val>("log", input, options);
} else if (op_type == "Neg") {
output = model_builder.GetBuilder().call<emscripten::val>("neg", input, options);
} else if (op_type == "Not") {
output = model_builder.GetBuilder().call<emscripten::val>("logicalNot", input, options);
} else if (op_type == "Reciprocal") {
output = model_builder.GetBuilder().call<emscripten::val>("reciprocal", input, options);
} else if (op_type == "Sin") {
@ -70,44 +66,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
return Status::OK();
}
bool UnaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
if (op_type == "Identity") {
supported_data_types = webnn_supported_data_types;
} else if (op_type == "Abs" || op_type == "Neg") {
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
ONNX_NAMESPACE::TensorProto_DataType_INT32,
ONNX_NAMESPACE::TensorProto_DataType_INT8,
};
} else if (op_type == "Not") {
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
};
} else { // Others only support float32, float16 input data types.
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
};
}
if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return true;
}
void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
if (op_registrations.op_builder_map.count(op_type) > 0)
return;
@ -123,7 +81,6 @@ void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op
"Identity",
"Log",
"Neg",
"Not",
"Reciprocal",
"Sin",
"Sqrt",

View file

@ -21,12 +21,13 @@ namespace webnn {
ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
const emscripten::val& context, const DataLayout preferred_layout,
const WebnnDeviceType wnn_device_type)
const WebnnDeviceType wnn_device_type, const emscripten::val& wnn_limits)
: graph_viewer_(graph_viewer),
logger_(logger),
wnn_context_(context),
preferred_layout_(preferred_layout),
wnn_device_type_(wnn_device_type) {
wnn_device_type_(wnn_device_type),
wnn_limits_(wnn_limits) {
// Create WebNN MLGraphBuilder for each ModelBuilder, because MLGraphBuilder.build()
// is only allowed to be called once.
wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(context);
@ -102,7 +103,7 @@ Status ModelBuilder::RegisterInitializers() {
desc.set("dimensions", emscripten::val::array(dims));
auto data_type = tensor.data_type();
emscripten::val operand = emscripten::val::object();
if (IsSupportedDataType(data_type, webnn_supported_data_types)) {
if (IsSupportedDataType(data_type, wnn_limits_["constant"]["dataTypes"])) {
ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type");
auto num_elements = SafeInt<size_t>(Product(shape));
emscripten::val view = emscripten::val::undefined();

View file

@ -23,7 +23,7 @@ class ModelBuilder {
public:
ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
const emscripten::val& context, const DataLayout preferred_layout,
const WebnnDeviceType wnn_device_type);
const WebnnDeviceType wnn_device_type, const emscripten::val& wnn_limits);
~ModelBuilder() = default;
Status Compile(std::unique_ptr<Model>& model) ORT_MUST_USE_RESULT;
@ -35,6 +35,8 @@ class ModelBuilder {
const emscripten::val& GetBuilder() const { return wnn_builder_; }
const emscripten::val& GetContext() const { return wnn_context_; }
const emscripten::val& GetOperand(const std::string& name) const { return wnn_operands_.at(name); }
const emscripten::val& GetOpSupportLimits() const { return wnn_limits_; }
void AddOperand(const std::string& name, const emscripten::val& operand);
const emscripten::val& GetZeroConstant(const std::string& data_type);
// Use the buffers to persist WebNN allocated data like transposed weight.
@ -66,6 +68,7 @@ class ModelBuilder {
emscripten::val wnn_builder_ = emscripten::val::undefined();
DataLayout preferred_layout_;
WebnnDeviceType wnn_device_type_;
emscripten::val wnn_limits_ = emscripten::val::undefined();
InlinedHashMap<std::string, emscripten::val> wnn_operands_;
std::vector<std::string> input_names_;
std::vector<std::string> output_names_;

View file

@ -29,7 +29,8 @@ class IOpBuilder {
public:
// Check if an operator is supported.
virtual bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType device_type, const logging::Logger& logger) const = 0;
const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
const logging::Logger& logger) const = 0;
};
} // namespace webnn

View file

@ -25,7 +25,6 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateUnaryOpBuilder("Identity", op_registrations);
CreateUnaryOpBuilder("Log", op_registrations);
CreateUnaryOpBuilder("Neg", op_registrations);
CreateUnaryOpBuilder("Not", op_registrations);
CreateUnaryOpBuilder("Reciprocal", op_registrations);
CreateUnaryOpBuilder("Sin", op_registrations);
CreateUnaryOpBuilder("Sqrt", op_registrations);
@ -118,6 +117,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateLogicalOpBuilder("GreaterOrEqual", op_registrations);
CreateLogicalOpBuilder("Less", op_registrations);
CreateLogicalOpBuilder("LessOrEqual", op_registrations);
CreateLogicalOpBuilder("Not", op_registrations);
}
{ // Max/Min

View file

@ -21,10 +21,8 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
: IExecutionProvider{onnxruntime::kWebNNExecutionProvider} {
// WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend.
if (webnn_device_flags.compare("cpu") == 0) {
preferred_layout_ = DataLayout::NHWC;
wnn_device_type_ = webnn::WebnnDeviceType::CPU;
} else {
preferred_layout_ = DataLayout::NCHW;
if (webnn_device_flags.compare("gpu") == 0) {
wnn_device_type_ = webnn::WebnnDeviceType::GPU;
} else if (webnn_device_flags.compare("npu") == 0) {
@ -38,6 +36,17 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
if (!wnn_context_.as<bool>()) {
ORT_THROW("Failed to create WebNN context.");
}
// Retrieve the level of support for different WebNN operators.
// This varies across implementations and is obtained via the WebNN's opSupportLimits() function.
// https://www.w3.org/TR/webnn/#api-mlcontext-opsupportlimits
wnn_limits_ = wnn_context_.call<emscripten::val>("opSupportLimits");
if (wnn_limits_["preferredInputLayout"].as<std::string>().compare("nhwc") == 0) {
preferred_layout_ = DataLayout::NHWC;
} else {
preferred_layout_ = DataLayout::NCHW;
}
}
WebNNExecutionProvider::~WebNNExecutionProvider() {}
@ -82,7 +91,7 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
ORT_THROW("Failed to create WebNN builder.");
}
const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, logger);
const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, wnn_limits_, logger);
wnn_builder = emscripten::val::undefined();
if (node_groups.empty()) {
@ -213,7 +222,7 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
webnn::ModelBuilder builder(graph_viewer, *GetLogger(), wnn_context_,
preferred_layout_, wnn_device_type_);
preferred_layout_, wnn_device_type_, wnn_limits_);
std::unique_ptr<webnn::Model> model;
ORT_RETURN_IF_ERROR(builder.Compile(model));
@ -295,11 +304,6 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
auto output_type = output_info.data_type;
auto output_tensor =
ctx.GetOutput(i, output_shape.data(), output_shape.size());
if (!webnn::IsSupportedDataType(output_type, webnn::webnn_supported_data_types)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"Unsupported type: ", output_type, " for output: ", output_name);
}
void* output_buffer = output_tensor.GetTensorMutableRawData();
outputs.emplace(output_name,
webnn::OnnxTensorData{

View file

@ -43,6 +43,7 @@ class WebNNExecutionProvider : public IExecutionProvider {
private:
emscripten::val wnn_context_ = emscripten::val::undefined();
emscripten::val wnn_limits_ = emscripten::val::undefined();
DataLayout preferred_layout_;
webnn::WebnnDeviceType wnn_device_type_;