mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-21 02:18:09 +00:00
[WebNN EP] Use opSupportLimits to dynamically check data type support (#22025)
- Remove hard code data type checks and use WebNN's opSupportLimits instead - Add HasSupportedOutputsImpl for output data type validation - Get preferred layout info from opSupportLimits - Move Not op to logical_op_builder.cc because it should be there. This avoid the inconsistent input names in `unary_op_builder.cc`.
This commit is contained in:
parent
a89bddd5c2
commit
c63dd0234b
32 changed files with 288 additions and 642 deletions
|
|
@ -45,12 +45,12 @@ bool GetShape(const NodeArg& node_arg, std::vector<int64_t>& shape, const loggin
|
|||
return true;
|
||||
}
|
||||
|
||||
bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer,
|
||||
const WebnnDeviceType device_type, const logging::Logger& logger) {
|
||||
bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const WebnnDeviceType device_type,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) {
|
||||
const auto& op_builders = GetOpBuilders();
|
||||
if (Contains(op_builders, node.OpType())) {
|
||||
const auto* op_builder = op_builders.at(node.OpType());
|
||||
return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, device_type, logger);
|
||||
return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, device_type, wnn_limits, logger);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -86,6 +86,7 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons
|
|||
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
|
||||
const emscripten::val& wnn_builder,
|
||||
const WebnnDeviceType device_type,
|
||||
const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) {
|
||||
std::vector<std::vector<size_t>> supported_node_groups;
|
||||
|
||||
|
|
@ -105,7 +106,7 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
|
|||
// Firstly check if platform supports the WebNN op.
|
||||
if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) {
|
||||
LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser";
|
||||
supported = IsNodeSupported(*node, graph_viewer, device_type, logger);
|
||||
supported = IsNodeSupported(*node, graph_viewer, device_type, wnn_limits, logger);
|
||||
}
|
||||
|
||||
LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType()
|
||||
|
|
@ -130,10 +131,54 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
|
|||
return supported_node_groups;
|
||||
}
|
||||
|
||||
bool IsSupportedDataType(const int32_t data_type,
|
||||
const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType>& supported_data_types) {
|
||||
return std::find(supported_data_types.begin(), supported_data_types.end(), data_type) !=
|
||||
supported_data_types.end();
|
||||
bool AreInputDataTypesSame(const std::string& op_type,
|
||||
gsl::span<const int32_t> input_types,
|
||||
const logging::Logger& logger) {
|
||||
for (size_t i = 1; i < input_types.size(); i++) {
|
||||
if (input_types[0] != input_types[i]) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input data types should be the same, but ["
|
||||
<< input_types[0] << "] does not match "
|
||||
<< input_types[i] << "].";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types) {
|
||||
auto it = onnx_to_webnn_data_type_map.find(static_cast<ONNX_NAMESPACE::TensorProto_DataType>(onnx_data_type));
|
||||
if (it == onnx_to_webnn_data_type_map.end())
|
||||
return false;
|
||||
|
||||
std::string webnn_data_type = it->second;
|
||||
|
||||
// Check if WebNN supports the data type.
|
||||
emscripten::val is_supported = webnn_supported_data_types.call<emscripten::val>("includes",
|
||||
emscripten::val(webnn_data_type));
|
||||
return is_supported.as<bool>();
|
||||
}
|
||||
|
||||
// Check if the input or output data type of ONNX node is supported by the WebNN operator.
|
||||
bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
|
||||
const int32_t onnx_data_type,
|
||||
const emscripten::val& wnn_limits,
|
||||
const std::string& webnn_input_output_name,
|
||||
const std::string& onnx_input_output_name,
|
||||
const logging::Logger& logger) {
|
||||
std::string webnn_op_type;
|
||||
if (!GetWebNNOpType(onnx_op_type, webnn_op_type))
|
||||
return false;
|
||||
|
||||
if (!IsSupportedDataType(onnx_data_type, wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"])) {
|
||||
LOGS(logger, VERBOSE) << "[" << onnx_op_type
|
||||
<< "] " << onnx_input_output_name
|
||||
<< " type: [" << onnx_data_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,
|
||||
|
|
|
|||
|
|
@ -148,6 +148,7 @@ bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, c
|
|||
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
|
||||
const emscripten::val& wnn_builder,
|
||||
const WebnnDeviceType device_type,
|
||||
const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger);
|
||||
static const InlinedHashMap<std::string, std::string> op_map = {
|
||||
{"Abs", "abs"},
|
||||
|
|
@ -250,20 +251,38 @@ inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn
|
|||
return true;
|
||||
}
|
||||
|
||||
static const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> webnn_supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT8,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_UINT8,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT32,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT64,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_UINT32,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_UINT64,
|
||||
inline bool GetWebNNOpType(const std::string& op_type, std::string& webnn_op_type) {
|
||||
auto it = op_map.find(op_type);
|
||||
// Returns false if the op_type is not listed in the op_map.
|
||||
if (it == op_map.end()) {
|
||||
return false;
|
||||
}
|
||||
webnn_op_type = it->second;
|
||||
return true;
|
||||
}
|
||||
|
||||
static const InlinedHashMap<ONNX_NAMESPACE::TensorProto_DataType, std::string> onnx_to_webnn_data_type_map = {
|
||||
{ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"},
|
||||
{ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"},
|
||||
{ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"},
|
||||
{ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, "float16"},
|
||||
{ONNX_NAMESPACE::TensorProto_DataType_FLOAT, "float32"},
|
||||
{ONNX_NAMESPACE::TensorProto_DataType_INT32, "int32"},
|
||||
{ONNX_NAMESPACE::TensorProto_DataType_INT64, "int64"},
|
||||
{ONNX_NAMESPACE::TensorProto_DataType_UINT32, "uint32"},
|
||||
{ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"},
|
||||
};
|
||||
|
||||
bool IsSupportedDataType(const int32_t data_type,
|
||||
const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType>& supported_data_types);
|
||||
bool AreInputDataTypesSame(const std::string& op_type,
|
||||
gsl::span<const int32_t> input_types,
|
||||
const logging::Logger& logger);
|
||||
bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types);
|
||||
bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
|
||||
const int32_t onnx_data_type,
|
||||
const emscripten::val& wnn_limits,
|
||||
const std::string& webnn_input_output_name,
|
||||
const std::string& onnx_input_output_name,
|
||||
const logging::Logger& logger);
|
||||
|
||||
bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,
|
||||
std::vector<int64_t>& shape_b,
|
||||
|
|
|
|||
|
|
@ -21,8 +21,6 @@ class ActivationOpBuilder : public BaseOpBuilder {
|
|||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
WebnnDeviceType device_type, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -94,44 +92,6 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initi
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ActivationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
|
||||
// WebNN relu op supports float32, float16, int32, int8 input data types.
|
||||
if (op_type == "Relu") {
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT32,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT8,
|
||||
};
|
||||
// WebNN CPU backend does not support int32 data type for relu.
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32);
|
||||
}
|
||||
} else { // Others only support float32 and float16.
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
};
|
||||
}
|
||||
|
||||
if (!IsSupportedDataType(input_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
if (op_registrations.op_builder_map.count(op_type) > 0)
|
||||
return;
|
||||
|
|
|
|||
|
|
@ -22,8 +22,6 @@ class ArgMaxMinOpBuilder : public BaseOpBuilder {
|
|||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
WebnnDeviceType device_type, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -77,31 +75,6 @@ bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ArgMaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
|
||||
// WebNN CPU backend doesn't support int64, uint64 input data types for argMax and argMin.
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
|
||||
}
|
||||
|
||||
if (!IsSupportedDataType(input_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateArgMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
if (op_registrations.op_builder_map.count(op_type) > 0)
|
||||
return;
|
||||
|
|
|
|||
|
|
@ -38,9 +38,9 @@ bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node
|
|||
Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const {
|
||||
ORT_RETURN_IF_NOT(
|
||||
IsOpSupported(model_builder.GetInitializerTensors(), node, model_builder.GetWebnnDeviceType(), logger),
|
||||
"Unsupported operator ",
|
||||
node.OpType());
|
||||
IsOpSupported(model_builder.GetInitializerTensors(), node, model_builder.GetWebnnDeviceType(),
|
||||
model_builder.GetOpSupportLimits(), logger),
|
||||
"Unsupported operator ", node.OpType());
|
||||
ORT_RETURN_IF_ERROR(AddToModelBuilderImpl(model_builder, node, logger));
|
||||
LOGS(logger, VERBOSE) << "Operator name: [" << node.Name()
|
||||
<< "] type: [" << node.OpType() << "] was added";
|
||||
|
|
@ -50,8 +50,12 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node&
|
|||
// Operator support related.
|
||||
|
||||
bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType device_type, const logging::Logger& logger) const {
|
||||
if (!HasSupportedInputs(node, device_type, logger))
|
||||
const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
if (!HasSupportedInputs(node, wnn_limits, logger))
|
||||
return false;
|
||||
|
||||
if (!HasSupportedOutputsImpl(node, wnn_limits, logger))
|
||||
return false;
|
||||
|
||||
// We do not support external initializers for now.
|
||||
|
|
@ -64,7 +68,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons
|
|||
return IsOpSupportedImpl(initializers, node, device_type, logger);
|
||||
}
|
||||
|
||||
bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType device_type,
|
||||
bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
|
||||
for (const auto* input : node.InputDefs()) {
|
||||
|
|
@ -73,39 +77,33 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType d
|
|||
}
|
||||
}
|
||||
|
||||
// WebNN CPU backend (TFLite) will enable float16 input data type soon,
|
||||
// temporarily fallback float16 input data type for WebNN CPU.
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)
|
||||
return false;
|
||||
}
|
||||
|
||||
return HasSupportedInputsImpl(node, device_type, logger);
|
||||
return HasSupportedInputsImpl(node, wnn_limits, logger);
|
||||
}
|
||||
|
||||
bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node,
|
||||
const WebnnDeviceType /* device_type */,
|
||||
const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
// We only check the type of input 0 by default, specific op builder can override this.
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
if (!IsSupportedDataType(input_type, webnn_supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << node.OpType()
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger);
|
||||
}
|
||||
|
||||
return true;
|
||||
bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node,
|
||||
const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
// We only check the type of output 0 by default, specific op builder can override this.
|
||||
const auto& output = *node.OutputDefs()[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t output_type;
|
||||
if (!GetType(output, output_type, logger))
|
||||
return false;
|
||||
|
||||
return IsDataTypeSupportedByOp(op_type, output_type, wnn_limits, "output", "Output", logger);
|
||||
}
|
||||
|
||||
bool BaseOpBuilder::HasSupportedOpSet(const Node& node,
|
||||
|
|
|
|||
|
|
@ -28,7 +28,8 @@ class BaseOpBuilder : public IOpBuilder {
|
|||
// Operator support related.
|
||||
public:
|
||||
bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
|
||||
const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
|
||||
protected:
|
||||
virtual bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& /* node */,
|
||||
|
|
@ -36,8 +37,10 @@ class BaseOpBuilder : public IOpBuilder {
|
|||
return true;
|
||||
}
|
||||
|
||||
virtual bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
virtual bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const;
|
||||
virtual bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const;
|
||||
|
||||
// ONNX Runtime only *guarantees* support for models stamped
|
||||
// with opset version 7 or above for opset domain 'ai.onnx'.
|
||||
|
|
@ -50,7 +53,7 @@ class BaseOpBuilder : public IOpBuilder {
|
|||
|
||||
private:
|
||||
bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
|
||||
bool HasSupportedInputs(const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const;
|
||||
bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
|
||||
};
|
||||
|
||||
} // namespace webnn
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ class BinaryOpBuilder : public BaseOpBuilder {
|
|||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
|
|
@ -86,7 +86,7 @@ bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
|
|||
return true;
|
||||
}
|
||||
|
||||
bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
|
|
@ -97,36 +97,14 @@ bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDevice
|
|||
!GetType(*input_defs[1], input1_type, logger))
|
||||
return false;
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
|
||||
// WebNN prelu op only supports float32, float16, int32, int8 input data types.
|
||||
if (op_type == "Prelu") {
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT32,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT8,
|
||||
};
|
||||
// WebNN CPU backend doesn't support int32 for prelu.
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32);
|
||||
}
|
||||
} else {
|
||||
supported_data_types = webnn_supported_data_types;
|
||||
}
|
||||
if (!IsSupportedDataType(input0_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input0_type
|
||||
<< "] is not supported for now";
|
||||
std::array<int32_t, 2> input_types{input0_type, input1_type};
|
||||
if (!AreInputDataTypesSame(op_type, input_types, logger)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input0_type != input1_type) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input data types should be the same.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
std::string webnn_input_name = op_type == "PRelu" ? "input" : "a";
|
||||
std::string onnx_input_name = op_type == "PRelu" || op_type == "Pow" ? "X" : "A";
|
||||
return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, webnn_input_name, onnx_input_name, logger);
|
||||
}
|
||||
|
||||
void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
|
|
|
|||
|
|
@ -21,8 +21,8 @@ class CastOpBuilder : public BaseOpBuilder {
|
|||
|
||||
// Operator support related.
|
||||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -80,26 +80,22 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
}
|
||||
|
||||
// Operator support related.
|
||||
bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
|
||||
if (!GetType(*input_defs[0], input_type, logger))
|
||||
return false;
|
||||
|
||||
if (!IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "input", logger))
|
||||
return false;
|
||||
|
||||
bool CastOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
|
||||
const Node& node,
|
||||
const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const {
|
||||
NodeAttrHelper helper(node);
|
||||
// Check cast output type.
|
||||
// Check cast to type.
|
||||
const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED);
|
||||
|
||||
// WebNN CPU backend doesn't support casting to uint64 data type.
|
||||
if (device_type == WebnnDeviceType::CPU && to_type == ONNX_NAMESPACE::TensorProto_DataType_UINT64) {
|
||||
LOGS(logger, VERBOSE) << "Cast to uint64 is not supported for WebNN CPU backend.";
|
||||
return false;
|
||||
}
|
||||
if (!IsSupportedDataType(to_type, webnn_supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "WebNN doesn't support casting to type " << to_type << ".";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
return IsDataTypeSupportedByOp(op_type, to_type, wnn_limits, "output", "to", logger);
|
||||
}
|
||||
|
||||
void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
|
|
|
|||
|
|
@ -25,8 +25,6 @@ class ClipOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -94,33 +92,6 @@ bool ClipOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
|||
};
|
||||
}
|
||||
|
||||
bool ClipOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
|
||||
// WebNN CPU backend doesn't support int32, uint32, int64, uint64 input data types for clamp.
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32);
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
|
||||
}
|
||||
|
||||
if (!IsSupportedDataType(input_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<ClipOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
|
|
|
|||
|
|
@ -19,6 +19,10 @@ class ConcatOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
|
||||
// Operator support related.
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -52,6 +56,30 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool ConcatOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input0_type;
|
||||
|
||||
if (!GetType(*input_defs[0], input0_type, logger))
|
||||
return false;
|
||||
|
||||
for (size_t i = 1; i < input_defs.size(); i++) {
|
||||
int32_t input_type;
|
||||
if (!GetType(*input_defs[i], input_type, logger)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::array<int32_t, 2> input_types{input0_type, input_type};
|
||||
if (!AreInputDataTypesSame(op_type, input_types, logger)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger);
|
||||
}
|
||||
|
||||
void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<ConcatOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class ConvOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
|
|
@ -397,7 +397,7 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
|
|
@ -415,35 +415,18 @@ bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTy
|
|||
return false;
|
||||
}
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
|
||||
if (op_type == "Conv" || op_type == "ConvTranspose") {
|
||||
// WebNN conv2d and convTranspose2d only support float32 and float16 input data types.
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
};
|
||||
} else if (op_type == "ConvInteger") {
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT8,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_UINT8,
|
||||
};
|
||||
InlinedVector<int32_t, 4> input_types = {input0_type, input1_type};
|
||||
if (has_input2) {
|
||||
input_types.push_back(input2_type);
|
||||
}
|
||||
if (!IsSupportedDataType(input0_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input0_type
|
||||
<< "] is not supported for now";
|
||||
if (has_input3) {
|
||||
input_types.push_back(input3_type);
|
||||
}
|
||||
if (!AreInputDataTypesSame(op_type, input_types, logger)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input0_type != input1_type ||
|
||||
(has_input2 && input0_type != input2_type) ||
|
||||
(has_input3 && input0_type != input3_type)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input data types should be the same.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger);
|
||||
}
|
||||
|
||||
void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ class GatherOpBuilder : public BaseOpBuilder {
|
|||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
|
|
@ -69,29 +69,19 @@ bool GatherOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ
|
|||
return true;
|
||||
}
|
||||
|
||||
bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& indices = *node.InputDefs()[1];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
int32_t indices_type;
|
||||
if (!GetType(input, input_type, logger) ||
|
||||
!GetType(indices, indices_type, logger))
|
||||
return false;
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
|
||||
// WebNN CPU backend doesn't support uint32, uint64 input data types for gather.
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
|
||||
}
|
||||
|
||||
if (!IsSupportedDataType(input_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) &&
|
||||
IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
|
||||
}
|
||||
|
||||
void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class GemmOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
|
|
@ -215,7 +215,7 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializer
|
|||
return true;
|
||||
}
|
||||
|
||||
bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
|
|
@ -233,35 +233,18 @@ bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTy
|
|||
return false;
|
||||
}
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
|
||||
if (op_type == "Gemm" || op_type == "MatMul") {
|
||||
// WebNN gemm and matmul only support float32 and float16 input data types.
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
};
|
||||
} else if (op_type == "MatMulInteger") {
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT8,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_UINT8,
|
||||
};
|
||||
InlinedVector<int32_t, 4> input_types = {input0_type, input1_type};
|
||||
if (has_input2) {
|
||||
input_types.push_back(input2_type);
|
||||
}
|
||||
if (!IsSupportedDataType(input0_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input0_type
|
||||
<< "] is not supported for now";
|
||||
if (has_input3) {
|
||||
input_types.push_back(input3_type);
|
||||
}
|
||||
if (!AreInputDataTypesSame(op_type, input_types, logger)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input0_type != input1_type ||
|
||||
(has_input2 && input0_type != input2_type) ||
|
||||
(has_input3 && input0_type != input3_type)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input data types should be the same.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger);
|
||||
}
|
||||
|
||||
void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class GruOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
|
|
@ -185,7 +185,7 @@ bool GruOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, c
|
|||
return true;
|
||||
}
|
||||
|
||||
bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
|
|
@ -208,37 +208,21 @@ bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTyp
|
|||
return false;
|
||||
}
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
// WebNN CPU backend only support float32 input data type.
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
};
|
||||
} else if (device_type == WebnnDeviceType::GPU) {
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
};
|
||||
InlinedVector<int32_t, 6> input_types = {input0_type, input1_type, input2_type};
|
||||
if (has_input3) {
|
||||
input_types.push_back(input3_type);
|
||||
}
|
||||
|
||||
if (!IsSupportedDataType(input0_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input0_type
|
||||
<< "] is not supported for now";
|
||||
if (has_input4) {
|
||||
input_types.push_back(input4_type);
|
||||
}
|
||||
if (has_input5) {
|
||||
input_types.push_back(input5_type);
|
||||
}
|
||||
if (!AreInputDataTypesSame(op_type, input_types, logger)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input0_type != input1_type ||
|
||||
input0_type != input2_type ||
|
||||
(has_input3 && input0_type != input3_type) ||
|
||||
(has_input4 && input0_type != input4_type) ||
|
||||
(has_input5 && input0_type != input5_type)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input data types should be the same.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger);
|
||||
}
|
||||
|
||||
void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ class LogicalOpBuilder : public BaseOpBuilder {
|
|||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
|
|
@ -29,9 +29,14 @@ class LogicalOpBuilder : public BaseOpBuilder {
|
|||
|
||||
Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& /* logger */) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name());
|
||||
emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name());
|
||||
emscripten::val input0 = model_builder.GetOperand(input_defs[0]->Name());
|
||||
emscripten::val input1 = emscripten::val::undefined();
|
||||
if (input_defs.size() > 1) {
|
||||
input1 = model_builder.GetOperand(input_defs[1]->Name());
|
||||
}
|
||||
|
||||
emscripten::val output = emscripten::val::object();
|
||||
emscripten::val options = emscripten::val::object();
|
||||
options.set("label", node.Name());
|
||||
|
|
@ -45,6 +50,8 @@ Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
|
|||
output = model_builder.GetBuilder().call<emscripten::val>("lesser", input0, input1, options);
|
||||
} else if (op_type == "LessOrEqual") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("lesserOrEqual", input0, input1, options);
|
||||
} else if (op_type == "Not") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("logicalNot", input0, options);
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"LogicalOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
|
||||
|
|
@ -61,7 +68,7 @@ bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali
|
|||
const auto& name = node.Name();
|
||||
const auto& op_type = node.OpType();
|
||||
const auto& input_defs = node.InputDefs();
|
||||
if (input_defs.size() < 2) {
|
||||
if (input_defs.size() < 2 && op_type != "Not") {
|
||||
LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 2 inputs, actual: "
|
||||
<< input_defs.size();
|
||||
return false;
|
||||
|
|
@ -69,31 +76,27 @@ bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali
|
|||
return true;
|
||||
}
|
||||
|
||||
bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input0_type;
|
||||
int32_t input1_type;
|
||||
|
||||
if (!GetType(*input_defs[0], input0_type, logger) ||
|
||||
!GetType(*input_defs[1], input1_type, logger))
|
||||
if (!GetType(*input_defs[0], input0_type, logger))
|
||||
return false;
|
||||
|
||||
if (!IsSupportedDataType(input0_type, webnn_supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input0_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
if (op_type != "Not") {
|
||||
if (!GetType(*input_defs[1], input1_type, logger))
|
||||
return false;
|
||||
std::array<int32_t, 2> input_types{input0_type, input1_type};
|
||||
if (!AreInputDataTypesSame(op_type, input_types, logger)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (input0_type != input1_type) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input data types should be the same.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
std::string onnx_input_name = op_type == "Not" ? "X" : "A";
|
||||
return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", onnx_input_name, logger);
|
||||
}
|
||||
|
||||
void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
|
|
@ -107,6 +110,7 @@ void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations&
|
|||
"GreaterOrEqual",
|
||||
"Less",
|
||||
"LessOrEqual",
|
||||
"Not",
|
||||
};
|
||||
|
||||
op_registrations.builders.push_back(std::make_unique<LogicalOpBuilder>());
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ class MaxMinOpBuilder : public BaseOpBuilder {
|
|||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
|
|
@ -87,31 +87,28 @@ bool MaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ
|
|||
return true;
|
||||
}
|
||||
|
||||
bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input0_type;
|
||||
int32_t input1_type;
|
||||
|
||||
if (!GetType(*input_defs[0], input0_type, logger) ||
|
||||
!GetType(*input_defs[1], input1_type, logger))
|
||||
if (!GetType(*input_defs[0], input0_type, logger))
|
||||
return false;
|
||||
|
||||
if (!IsSupportedDataType(input0_type, webnn_supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input0_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
for (size_t i = 1; i < input_defs.size(); i++) {
|
||||
int32_t input_type;
|
||||
if (!GetType(*input_defs[i], input_type, logger)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::array<int32_t, 2> input_types{input0_type, input_type};
|
||||
if (!AreInputDataTypesSame(op_type, input_types, logger)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (input0_type != input1_type) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input data types should be the same.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger);
|
||||
}
|
||||
|
||||
void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class NormalizationOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
|
|
@ -182,7 +182,7 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initi
|
|||
return true;
|
||||
}
|
||||
|
||||
bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
|
|
@ -203,30 +203,21 @@ bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const Webn
|
|||
return false;
|
||||
}
|
||||
|
||||
// WebNN batchNormalization, instanceNormalization, layerNormalization
|
||||
// only support float32 and float16 input data types.
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
};
|
||||
|
||||
if (!IsSupportedDataType(input0_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input0_type
|
||||
<< "] is not supported for now";
|
||||
std::vector<int32_t> input_types = {input0_type, input1_type};
|
||||
if (has_input2) {
|
||||
input_types.push_back(input2_type);
|
||||
}
|
||||
if (has_input3) {
|
||||
input_types.push_back(input3_type);
|
||||
}
|
||||
if (has_input4) {
|
||||
input_types.push_back(input4_type);
|
||||
}
|
||||
if (!AreInputDataTypesSame(op_type, input_types, logger)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input0_type != input1_type ||
|
||||
(has_input2 && input0_type != input2_type) ||
|
||||
(has_input3 && input0_type != input3_type) ||
|
||||
(has_input4 && input0_type != input4_type)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input data types should be the same.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger);
|
||||
}
|
||||
|
||||
void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
|
|
|
|||
|
|
@ -28,8 +28,6 @@ class PadOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -196,31 +194,6 @@ bool PadOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
|||
return true;
|
||||
} // namespace webnn
|
||||
|
||||
bool PadOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
|
||||
// WebNN CPU backend doesn't support uint32, uint64 input data types for pad.
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
|
||||
}
|
||||
|
||||
if (!IsSupportedDataType(input_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<PadOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
|
|
|
|||
|
|
@ -31,8 +31,6 @@ class ReductionOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -147,56 +145,6 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ReductionOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
|
||||
if (op_type == "ReduceL1" || op_type == "ReduceProd" ||
|
||||
op_type == "ReduceSum" || op_type == "ReduceSumSquare") {
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT32,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_UINT32,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT64,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_UINT64,
|
||||
};
|
||||
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
// WebNN CPU backend doesn't support uint32 and uint64 for reduceL1,
|
||||
// reduceProd, reduceSum and reduceSumSquare.
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
|
||||
}
|
||||
} else if (op_type == "ReduceL2" || op_type == "ReduceLogSum" ||
|
||||
op_type == "ReduceLogSumExp" || op_type == "ReduceMean") {
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
};
|
||||
} else { // ReduceMax and ReduceMin
|
||||
supported_data_types = webnn_supported_data_types;
|
||||
// WebNN CPU backend doesn't support uint32, uint64 for reduceMax and reduceMin.
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
|
||||
}
|
||||
}
|
||||
if (!IsSupportedDataType(input_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
if (op_registrations.op_builder_map.count(op_type) > 0)
|
||||
return;
|
||||
|
|
|
|||
|
|
@ -35,8 +35,6 @@ class ResizeOpBuilder : public BaseOpBuilder {
|
|||
// Resize opset 10- is very different than Resize opset 11+, with many key attributes missing.
|
||||
// We only support Resize opset 11+ here.
|
||||
int GetMinSupportedOpSet(const Node& /* node */) const override { return 11; }
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Helper functions
|
||||
|
|
@ -275,30 +273,6 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ResizeOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
// WebNN resample2d op only supports float32 and float16 input data types.
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
};
|
||||
|
||||
if (!IsSupportedDataType(input_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<ResizeOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
|
|
|
|||
|
|
@ -18,11 +18,6 @@ class ShapeOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
|
||||
// Operator support related.
|
||||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
||||
|
|
@ -69,28 +64,6 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Operator support related.
|
||||
|
||||
bool ShapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
|
||||
const Node& node,
|
||||
const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
std::vector<int64_t> input_shape;
|
||||
if (!GetShape(*input_defs[0], input_shape, logger))
|
||||
return false;
|
||||
|
||||
int32_t output_type = ONNX_NAMESPACE::TensorProto_DataType_INT64;
|
||||
if (!IsSupportedDataType(output_type, webnn_supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << node.OpType()
|
||||
<< "] Output type: [" << output_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateShapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<ShapeOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
|
|
|
|||
|
|
@ -29,8 +29,6 @@ class SliceOpBuilder : public BaseOpBuilder {
|
|||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
// TODO: Support Slice opset < 10, which uses attributes for starts and ends.
|
||||
int GetMinSupportedOpSet(const Node& /* node */) const override { return 10; }
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -166,30 +164,6 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
|||
return true;
|
||||
}
|
||||
|
||||
bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
|
||||
// WebNN CPU backend doesn't support uint64 input data type for slice.
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
|
||||
}
|
||||
|
||||
if (!IsSupportedDataType(input_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<SliceOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
|
|
|
|||
|
|
@ -24,8 +24,6 @@ class SoftmaxOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
||||
|
|
@ -63,30 +61,6 @@ bool SoftmaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali
|
|||
return true;
|
||||
}
|
||||
|
||||
bool SoftmaxOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
// WebNN softmax only supports float32 and float16 input data types.
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
};
|
||||
|
||||
if (!IsSupportedDataType(input_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<SoftmaxOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class TernaryOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
|
|
@ -46,7 +46,7 @@ Status TernaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
|
|
@ -59,27 +59,14 @@ bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDevic
|
|||
!GetType(*input_defs[2], input2_type, logger))
|
||||
return false;
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
|
||||
// WebNN CPU backend doesn't support uint64 X, Y data type for where.
|
||||
if (device_type == WebnnDeviceType::CPU && op_type == "Where") {
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
|
||||
}
|
||||
// ONNX's condition data type is bool which is same as WebNN.
|
||||
// Only need to check X, Y data types.
|
||||
if (!IsSupportedDataType(input1_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input1_type
|
||||
<< "] is not supported for now";
|
||||
std::array<int32_t, 2> input_types{input1_type, input2_type};
|
||||
if (!AreInputDataTypesSame(op_type, input_types, logger)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input1_type != input2_type) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input X, Y data types should be the same.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
return IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger);
|
||||
}
|
||||
|
||||
void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
|
|
|
|||
|
|
@ -18,8 +18,6 @@ class TransposeOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -50,31 +48,6 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool TransposeOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
|
||||
// WebNN CPU backend doesn't support uint32, uint64 input data types for transpose.
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
|
||||
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
|
||||
}
|
||||
|
||||
if (!IsSupportedDataType(input_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<TransposeOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
|
|
|
|||
|
|
@ -18,8 +18,6 @@ class UnaryOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
|
@ -51,8 +49,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
output = model_builder.GetBuilder().call<emscripten::val>("log", input, options);
|
||||
} else if (op_type == "Neg") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("neg", input, options);
|
||||
} else if (op_type == "Not") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("logicalNot", input, options);
|
||||
} else if (op_type == "Reciprocal") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("reciprocal", input, options);
|
||||
} else if (op_type == "Sin") {
|
||||
|
|
@ -70,44 +66,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool UnaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
|
||||
if (op_type == "Identity") {
|
||||
supported_data_types = webnn_supported_data_types;
|
||||
} else if (op_type == "Abs" || op_type == "Neg") {
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT32,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT8,
|
||||
};
|
||||
} else if (op_type == "Not") {
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
|
||||
};
|
||||
} else { // Others only support float32, float16 input data types.
|
||||
supported_data_types = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
|
||||
};
|
||||
}
|
||||
if (!IsSupportedDataType(input_type, supported_data_types)) {
|
||||
LOGS(logger, VERBOSE) << "[" << op_type
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
if (op_registrations.op_builder_map.count(op_type) > 0)
|
||||
return;
|
||||
|
|
@ -123,7 +81,6 @@ void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op
|
|||
"Identity",
|
||||
"Log",
|
||||
"Neg",
|
||||
"Not",
|
||||
"Reciprocal",
|
||||
"Sin",
|
||||
"Sqrt",
|
||||
|
|
|
|||
|
|
@ -21,12 +21,13 @@ namespace webnn {
|
|||
|
||||
ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
|
||||
const emscripten::val& context, const DataLayout preferred_layout,
|
||||
const WebnnDeviceType wnn_device_type)
|
||||
const WebnnDeviceType wnn_device_type, const emscripten::val& wnn_limits)
|
||||
: graph_viewer_(graph_viewer),
|
||||
logger_(logger),
|
||||
wnn_context_(context),
|
||||
preferred_layout_(preferred_layout),
|
||||
wnn_device_type_(wnn_device_type) {
|
||||
wnn_device_type_(wnn_device_type),
|
||||
wnn_limits_(wnn_limits) {
|
||||
// Create WebNN MLGraphBuilder for each ModelBuilder, because MLGraphBuilder.build()
|
||||
// is only allowed to be called once.
|
||||
wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(context);
|
||||
|
|
@ -102,7 +103,7 @@ Status ModelBuilder::RegisterInitializers() {
|
|||
desc.set("dimensions", emscripten::val::array(dims));
|
||||
auto data_type = tensor.data_type();
|
||||
emscripten::val operand = emscripten::val::object();
|
||||
if (IsSupportedDataType(data_type, webnn_supported_data_types)) {
|
||||
if (IsSupportedDataType(data_type, wnn_limits_["constant"]["dataTypes"])) {
|
||||
ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type");
|
||||
auto num_elements = SafeInt<size_t>(Product(shape));
|
||||
emscripten::val view = emscripten::val::undefined();
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class ModelBuilder {
|
|||
public:
|
||||
ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
|
||||
const emscripten::val& context, const DataLayout preferred_layout,
|
||||
const WebnnDeviceType wnn_device_type);
|
||||
const WebnnDeviceType wnn_device_type, const emscripten::val& wnn_limits);
|
||||
~ModelBuilder() = default;
|
||||
|
||||
Status Compile(std::unique_ptr<Model>& model) ORT_MUST_USE_RESULT;
|
||||
|
|
@ -35,6 +35,8 @@ class ModelBuilder {
|
|||
const emscripten::val& GetBuilder() const { return wnn_builder_; }
|
||||
const emscripten::val& GetContext() const { return wnn_context_; }
|
||||
const emscripten::val& GetOperand(const std::string& name) const { return wnn_operands_.at(name); }
|
||||
const emscripten::val& GetOpSupportLimits() const { return wnn_limits_; }
|
||||
|
||||
void AddOperand(const std::string& name, const emscripten::val& operand);
|
||||
const emscripten::val& GetZeroConstant(const std::string& data_type);
|
||||
// Use the buffers to persist WebNN allocated data like transposed weight.
|
||||
|
|
@ -66,6 +68,7 @@ class ModelBuilder {
|
|||
emscripten::val wnn_builder_ = emscripten::val::undefined();
|
||||
DataLayout preferred_layout_;
|
||||
WebnnDeviceType wnn_device_type_;
|
||||
emscripten::val wnn_limits_ = emscripten::val::undefined();
|
||||
InlinedHashMap<std::string, emscripten::val> wnn_operands_;
|
||||
std::vector<std::string> input_names_;
|
||||
std::vector<std::string> output_names_;
|
||||
|
|
|
|||
|
|
@ -29,7 +29,8 @@ class IOpBuilder {
|
|||
public:
|
||||
// Check if an operator is supported.
|
||||
virtual bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType device_type, const logging::Logger& logger) const = 0;
|
||||
const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const = 0;
|
||||
};
|
||||
|
||||
} // namespace webnn
|
||||
|
|
|
|||
|
|
@ -25,7 +25,6 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
|
|||
CreateUnaryOpBuilder("Identity", op_registrations);
|
||||
CreateUnaryOpBuilder("Log", op_registrations);
|
||||
CreateUnaryOpBuilder("Neg", op_registrations);
|
||||
CreateUnaryOpBuilder("Not", op_registrations);
|
||||
CreateUnaryOpBuilder("Reciprocal", op_registrations);
|
||||
CreateUnaryOpBuilder("Sin", op_registrations);
|
||||
CreateUnaryOpBuilder("Sqrt", op_registrations);
|
||||
|
|
@ -118,6 +117,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
|
|||
CreateLogicalOpBuilder("GreaterOrEqual", op_registrations);
|
||||
CreateLogicalOpBuilder("Less", op_registrations);
|
||||
CreateLogicalOpBuilder("LessOrEqual", op_registrations);
|
||||
CreateLogicalOpBuilder("Not", op_registrations);
|
||||
}
|
||||
|
||||
{ // Max/Min
|
||||
|
|
|
|||
|
|
@ -21,10 +21,8 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
|
|||
: IExecutionProvider{onnxruntime::kWebNNExecutionProvider} {
|
||||
// WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend.
|
||||
if (webnn_device_flags.compare("cpu") == 0) {
|
||||
preferred_layout_ = DataLayout::NHWC;
|
||||
wnn_device_type_ = webnn::WebnnDeviceType::CPU;
|
||||
} else {
|
||||
preferred_layout_ = DataLayout::NCHW;
|
||||
if (webnn_device_flags.compare("gpu") == 0) {
|
||||
wnn_device_type_ = webnn::WebnnDeviceType::GPU;
|
||||
} else if (webnn_device_flags.compare("npu") == 0) {
|
||||
|
|
@ -38,6 +36,17 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
|
|||
if (!wnn_context_.as<bool>()) {
|
||||
ORT_THROW("Failed to create WebNN context.");
|
||||
}
|
||||
|
||||
// Retrieve the level of support for different WebNN operators.
|
||||
// This varies across implementations and is obtained via the WebNN's opSupportLimits() function.
|
||||
// https://www.w3.org/TR/webnn/#api-mlcontext-opsupportlimits
|
||||
wnn_limits_ = wnn_context_.call<emscripten::val>("opSupportLimits");
|
||||
|
||||
if (wnn_limits_["preferredInputLayout"].as<std::string>().compare("nhwc") == 0) {
|
||||
preferred_layout_ = DataLayout::NHWC;
|
||||
} else {
|
||||
preferred_layout_ = DataLayout::NCHW;
|
||||
}
|
||||
}
|
||||
|
||||
WebNNExecutionProvider::~WebNNExecutionProvider() {}
|
||||
|
|
@ -82,7 +91,7 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
ORT_THROW("Failed to create WebNN builder.");
|
||||
}
|
||||
|
||||
const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, logger);
|
||||
const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, wnn_limits_, logger);
|
||||
wnn_builder = emscripten::val::undefined();
|
||||
|
||||
if (node_groups.empty()) {
|
||||
|
|
@ -213,7 +222,7 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
|
|||
const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
|
||||
|
||||
webnn::ModelBuilder builder(graph_viewer, *GetLogger(), wnn_context_,
|
||||
preferred_layout_, wnn_device_type_);
|
||||
preferred_layout_, wnn_device_type_, wnn_limits_);
|
||||
std::unique_ptr<webnn::Model> model;
|
||||
ORT_RETURN_IF_ERROR(builder.Compile(model));
|
||||
|
||||
|
|
@ -295,11 +304,6 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
|
|||
auto output_type = output_info.data_type;
|
||||
auto output_tensor =
|
||||
ctx.GetOutput(i, output_shape.data(), output_shape.size());
|
||||
|
||||
if (!webnn::IsSupportedDataType(output_type, webnn::webnn_supported_data_types)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
|
||||
"Unsupported type: ", output_type, " for output: ", output_name);
|
||||
}
|
||||
void* output_buffer = output_tensor.GetTensorMutableRawData();
|
||||
outputs.emplace(output_name,
|
||||
webnn::OnnxTensorData{
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ class WebNNExecutionProvider : public IExecutionProvider {
|
|||
|
||||
private:
|
||||
emscripten::val wnn_context_ = emscripten::val::undefined();
|
||||
emscripten::val wnn_limits_ = emscripten::val::undefined();
|
||||
|
||||
DataLayout preferred_layout_;
|
||||
webnn::WebnnDeviceType wnn_device_type_;
|
||||
|
|
|
|||
Loading…
Reference in a new issue