From 9ea9f9e46a930b2cfa93196158d355cc367ba36e Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 30 May 2024 01:19:51 +0800 Subject: [PATCH] [WebNN EP] Add data type constraint (#20779) WebNN spec has added data type constraint for every op, and its CPU backend (currently is TFLite) has additional constraint. Add corresponding constraint to each op in WebNN EP. Note: Temporarily disable fp16 for CPU backend as which is planned to be ready in Chromium next month. --- .../core/providers/webnn/builders/helper.cc | 14 ++-- .../core/providers/webnn/builders/helper.h | 9 +-- .../builders/impl/activation_op_builder.cc | 40 +++++++++++ .../builders/impl/argmax_min_op_builder.cc | 27 ++++++++ .../webnn/builders/impl/base_op_builder.cc | 16 ++++- .../webnn/builders/impl/binary_op_builder.cc | 45 ++++++++++++ .../webnn/builders/impl/cast_op_builder.cc | 6 +- .../webnn/builders/impl/clip_op_builder.cc | 29 ++++++++ .../webnn/builders/impl/conv_op_builder.cc | 68 ++++++++++++++----- .../webnn/builders/impl/gather_op_builder.cc | 27 ++++++++ .../webnn/builders/impl/gemm_op_builder.cc | 51 ++++++++++++++ .../webnn/builders/impl/logical_op_builder.cc | 59 ++++++++++++---- .../webnn/builders/impl/max_min_op_builder.cc | 29 ++++++++ .../builders/impl/normalization_op_builder.cc | 49 +++++++++++++ .../webnn/builders/impl/pad_op_builder.cc | 27 ++++++++ .../builders/impl/reduction_op_builder.cc | 47 +++++++++++++ .../webnn/builders/impl/resize_op_builder.cc | 26 +++++++ .../webnn/builders/impl/shape_op_builder.cc | 4 +- .../webnn/builders/impl/slice_op_builder.cc | 26 +++++++ .../webnn/builders/impl/softmax_op_builder.cc | 26 +++++++ .../webnn/builders/impl/ternary_op_builder.cc | 38 +++++++++++ .../builders/impl/transpose_op_builder.cc | 27 ++++++++ .../webnn/builders/impl/unary_op_builder.cc | 40 +++++++++++ .../providers/webnn/builders/model_builder.cc | 4 +- .../webnn/webnn_execution_provider.cc | 22 ++---- 25 files changed, 680 insertions(+), 76 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index ef7c10dae5..b8d2324010 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -130,16 +130,10 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v return supported_node_groups; } -bool IsSupportedDataType(const int32_t data_type, const WebnnDeviceType device_type) { - // Current data type implementation status of WebNN is inconsistent along with different backends, - // The XNNPack backend supports only FP32, while the DML backend POC supports more. - if (device_type == WebnnDeviceType::CPU) { - return std::find(supported_cpu_data_types.begin(), supported_cpu_data_types.end(), data_type) != - supported_cpu_data_types.end(); - } else { - return std::find(supported_gpu_data_types.begin(), supported_gpu_data_types.end(), data_type) != - supported_gpu_data_types.end(); - } +bool IsSupportedDataType(const int32_t data_type, + const std::unordered_set& supported_data_types) { + return std::find(supported_data_types.begin(), supported_data_types.end(), data_type) != + supported_data_types.end(); } bool IsValidMultidirectionalBroadcast(std::vector& shape_a, diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 9cec78bbfe..486f7f69be 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -262,11 +262,7 @@ inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn return true; } -constexpr std::array supported_cpu_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, -}; - -constexpr std::array supported_gpu_data_types = { +static const std::unordered_set webnn_supported_data_types = { ONNX_NAMESPACE::TensorProto_DataType_BOOL, ONNX_NAMESPACE::TensorProto_DataType_INT8, ONNX_NAMESPACE::TensorProto_DataType_UINT8, @@ -278,7 +274,8 @@ constexpr std::array supported_gpu_data ONNX_NAMESPACE::TensorProto_DataType_UINT64, }; -bool IsSupportedDataType(const int32_t data_type, const WebnnDeviceType device_type); +bool IsSupportedDataType(const int32_t data_type, + const std::unordered_set& supported_data_types); bool IsValidMultidirectionalBroadcast(std::vector& shape_a, std::vector& shape_b, diff --git a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc index cd6a95a2aa..163c9b0fb9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc @@ -21,6 +21,8 @@ class ActivationOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const override; }; // Add operator related. @@ -81,6 +83,44 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initi return true; } +bool ActivationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const { + const auto& input = *node.InputDefs()[0]; + const auto& op_type = node.OpType(); + int32_t input_type; + if (!GetType(input, input_type, logger)) + return false; + + std::unordered_set supported_data_types; + // WebNN relu op supports float32, float16, int32, int8 input data types. + if (op_type == "Relu") { + supported_data_types = { + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + ONNX_NAMESPACE::TensorProto_DataType_INT32, + ONNX_NAMESPACE::TensorProto_DataType_INT8, + }; + // WebNN CPU backend does not support int32 data type for relu. + if (device_type == WebnnDeviceType::CPU) { + supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32); + } + } else { // Others only support float32 and float16. + supported_data_types = { + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + }; + } + + if (!IsSupportedDataType(input_type, supported_data_types)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input type: [" << input_type + << "] is not supported for now"; + return false; + } + + return true; +} + void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc index 5f8defe8fc..7926311f3c 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc @@ -22,6 +22,8 @@ class ArgMaxMinOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const override; }; // Add operator related. @@ -77,6 +79,31 @@ bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia return true; } +bool ArgMaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const { + const auto& input = *node.InputDefs()[0]; + const auto& op_type = node.OpType(); + int32_t input_type; + if (!GetType(input, input_type, logger)) + return false; + + std::unordered_set supported_data_types = webnn_supported_data_types; + // WebNN CPU backend doesn't support int64, uint64 input data types for argMax and argMin. + if (device_type == WebnnDeviceType::CPU) { + supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64); + supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); + } + + if (!IsSupportedDataType(input_type, supported_data_types)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input type: [" << input_type + << "] is not supported for now"; + return false; + } + + return true; +} + void CreateArgMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index a893d2ff2c..fa53588929 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -73,11 +73,23 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType d } } + // WebNN CPU backend (TFLite) will enable float16 input data type soon, + // temporarily fallback float16 input data type for WebNN CPU. + if (device_type == WebnnDeviceType::CPU) { + const auto& input = *node.InputDefs()[0]; + + int32_t input_type; + if (!GetType(input, input_type, logger)) + return false; + if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) + return false; + } + return HasSupportedInputsImpl(node, device_type, logger); } bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, - const WebnnDeviceType device_type, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { // We only check the type of input 0 by default, specific op builder can override this. const auto& input = *node.InputDefs()[0]; @@ -86,7 +98,7 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, if (!GetType(input, input_type, logger)) return false; - if (!IsSupportedDataType(input_type, device_type)) { + if (!IsSupportedDataType(input_type, webnn_supported_data_types)) { LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not supported for now"; diff --git a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc index 5adaf80543..2c97ef490f 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc @@ -22,6 +22,8 @@ class BinaryOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const override; }; // Add operator related. @@ -72,6 +74,49 @@ bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return true; } +bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + int32_t input0_type; + int32_t input1_type; + + if (!GetType(*input_defs[0], input0_type, logger) || + !GetType(*input_defs[1], input1_type, logger)) + return false; + + std::unordered_set supported_data_types; + // WebNN prelu op only supports float32, float16, int32, int8 input data types. + if (op_type == "Prelu") { + supported_data_types = { + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + ONNX_NAMESPACE::TensorProto_DataType_INT32, + ONNX_NAMESPACE::TensorProto_DataType_INT8, + }; + // WebNN CPU backend doesn't support int32 for prelu. + if (device_type == WebnnDeviceType::CPU) { + supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32); + } + } else { + supported_data_types = webnn_supported_data_types; + } + if (!IsSupportedDataType(input0_type, supported_data_types)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input type: [" << input0_type + << "] is not supported for now"; + return false; + } + + if (input0_type != input1_type) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input data types should be the same."; + return false; + } + + return true; +} + void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc index 3d961e4589..f7d3d308d2 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc @@ -22,7 +22,7 @@ class CastOpBuilder : public BaseOpBuilder { // Operator support related. private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const override; + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; }; // Add operator related. @@ -80,12 +80,12 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, bool CastOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, - const WebnnDeviceType device_type, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { NodeAttrHelper helper(node); // Check cast output type. const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED); - if (!IsSupportedDataType(to_type, device_type)) { + if (!IsSupportedDataType(to_type, webnn_supported_data_types)) { LOGS(logger, VERBOSE) << "Invalid cast to type " << to_type << "."; return false; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc index 84a51b6679..30848b6660 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc @@ -25,6 +25,8 @@ class ClipOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const override; }; // Add operator related. @@ -71,6 +73,33 @@ bool ClipOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return GetClipMinMax(initializers, node, min, max, logger); } +bool ClipOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const { + const auto& input = *node.InputDefs()[0]; + const auto& op_type = node.OpType(); + int32_t input_type; + if (!GetType(input, input_type, logger)) + return false; + + std::unordered_set supported_data_types = webnn_supported_data_types; + // WebNN CPU backend doesn't support int32, uint32, int64, uint64 input data types for clamp. + if (device_type == WebnnDeviceType::CPU) { + supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32); + supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); + supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64); + supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); + } + + if (!IsSupportedDataType(input_type, supported_data_types)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input type: [" << input_type + << "] is not supported for now"; + return false; + } + + return true; +} + void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index f79920b63a..4eaa4855ff 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -29,6 +29,8 @@ class ConvOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const override; }; void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { @@ -148,11 +150,6 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder, bool is_conv1d) { const auto& tensor = *model_builder.GetInitializerTensors().at(name); auto data_type = tensor.data_type(); - if (!IsSupportedDataType(data_type, model_builder.GetWebnnDeviceType())) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "The initializer of graph has unsupported type, name: ", - tensor.name(), " type: ", data_type); - } const auto& shape = tensor.dims(); std::vector dims = GetVecUint32FromVecInt64(std::vector(std::begin(shape), std::end(shape))); @@ -177,7 +174,6 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder, size_t element_size{0}; switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_BOOL: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: element_size = sizeof(uint8_t); break; @@ -190,17 +186,6 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder, case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: element_size = sizeof(float); break; - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - element_size = sizeof(int32_t); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - element_size = sizeof(int64_t); - break; - case ONNX_NAMESPACE::TensorProto_DataType_UINT32: - element_size = sizeof(uint32_t); - break; - case ONNX_NAMESPACE::TensorProto_DataType_UINT64: - element_size = sizeof(uint64_t); break; default: break; @@ -396,6 +381,55 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } +bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + int32_t input0_type; // input data type + int32_t input1_type; // weight data type + int32_t input2_type; // bias or x_zero_point data type + int32_t input3_type; // w_zero_point data type + bool has_input2 = input_defs.size() > 2 && input_defs[2]->Exists(); + bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists(); + + if (!GetType(*input_defs[0], input0_type, logger) || + !GetType(*input_defs[1], input1_type, logger) || + (has_input2 && !GetType(*input_defs[2], input2_type, logger)) || + (has_input3 && !GetType(*input_defs[3], input3_type, logger))) { + return false; + } + + std::unordered_set supported_data_types; + if (op_type == "Conv" || op_type == "ConvTranspose") { + // WebNN conv2d and convTranspose2d only support float32 and float16 input data types. + supported_data_types = { + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + }; + } else if (op_type == "ConvInteger") { + supported_data_types = { + ONNX_NAMESPACE::TensorProto_DataType_INT8, + ONNX_NAMESPACE::TensorProto_DataType_UINT8, + }; + } + if (!IsSupportedDataType(input0_type, supported_data_types)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input type: [" << input0_type + << "] is not supported for now"; + return false; + } + + if (input0_type != input1_type || + (has_input2 && input0_type != input2_type) || + (has_input3 && input0_type != input3_type)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input data types should be the same."; + return false; + } + + return true; +} + void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc index 74a8f74474..014a08616c 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc @@ -22,6 +22,8 @@ class GatherOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const override; }; // Add operator related. @@ -66,6 +68,31 @@ bool GatherOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } +bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const { + const auto& input = *node.InputDefs()[0]; + const auto& op_type = node.OpType(); + int32_t input_type; + if (!GetType(input, input_type, logger)) + return false; + + std::unordered_set supported_data_types = webnn_supported_data_types; + // WebNN CPU backend doesn't support uint32, uint64 input data types for gather. + if (device_type == WebnnDeviceType::CPU) { + supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); + supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); + } + + if (!IsSupportedDataType(input_type, supported_data_types)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input type: [" << input_type + << "] is not supported for now"; + return false; + } + + return true; +} + void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index ed32013216..248463f473 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -25,6 +25,8 @@ class GemmOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const override; }; // Add operator related. @@ -219,6 +221,55 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } +bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + int32_t input0_type; // A data type + int32_t input1_type; // B data type + int32_t input2_type; // C or a_zero_point data type + int32_t input3_type; // b_zero_point data type + bool has_input2 = input_defs.size() > 2 && input_defs[2]->Exists(); + bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists(); + + if (!GetType(*input_defs[0], input0_type, logger) || + !GetType(*input_defs[1], input1_type, logger) || + (has_input2 && !GetType(*input_defs[2], input2_type, logger)) || + (has_input3 && !GetType(*input_defs[3], input3_type, logger))) { + return false; + } + + std::unordered_set supported_data_types; + if (op_type == "Gemm" || op_type == "MatMul") { + // WebNN gemm and matmul only support float32 and float16 input data types. + supported_data_types = { + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + }; + } else if (op_type == "MatMulInteger") { + supported_data_types = { + ONNX_NAMESPACE::TensorProto_DataType_INT8, + ONNX_NAMESPACE::TensorProto_DataType_UINT8, + }; + } + if (!IsSupportedDataType(input0_type, supported_data_types)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input type: [" << input0_type + << "] is not supported for now"; + return false; + } + + if (input0_type != input1_type || + (has_input2 && input0_type != input2_type) || + (has_input3 && input0_type != input3_type)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input data types should be the same."; + return false; + } + + return true; +} + void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index c8f58fa986..e56e8f6a3e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -21,6 +21,8 @@ class LogicalOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const override; }; // Add operator related. @@ -50,6 +52,48 @@ Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons return Status::OK(); } +bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, + const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& name = node.Name(); + const auto& op_type = node.OpType(); + const auto& input_defs = node.InputDefs(); + if (input_defs.size() < 2) { + LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 2 inputs, actual: " + << input_defs.size(); + return false; + } + return true; +} + +bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + int32_t input0_type; + int32_t input1_type; + + if (!GetType(*input_defs[0], input0_type, logger) || + !GetType(*input_defs[1], input1_type, logger)) + return false; + + if (!IsSupportedDataType(input0_type, webnn_supported_data_types)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input type: [" << input0_type + << "] is not supported for now"; + return false; + } + + if (input0_type != input1_type) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input data types should be the same."; + return false; + } + + return true; +} + void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; @@ -69,20 +113,5 @@ void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& } } -bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, - const Node& node, - const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& name = node.Name(); - const auto& op_type = node.OpType(); - const auto& input_defs = node.InputDefs(); - if (input_defs.size() < 2) { - LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 2 inputs, actual: " - << input_defs.size(); - return false; - } - return true; -} - } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc index cefd4236d8..0168f59273 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc @@ -22,6 +22,8 @@ class MaxMinOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const override; }; // Add operator related. @@ -84,6 +86,33 @@ bool MaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } +bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + int32_t input0_type; + int32_t input1_type; + + if (!GetType(*input_defs[0], input0_type, logger) || + !GetType(*input_defs[1], input1_type, logger)) + return false; + + if (!IsSupportedDataType(input0_type, webnn_supported_data_types)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input type: [" << input0_type + << "] is not supported for now"; + return false; + } + + if (input0_type != input1_type) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input data types should be the same."; + return false; + } + + return true; +} + void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index b6948b48c3..90ad9b48d5 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -25,6 +25,8 @@ class NormalizationOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const override; }; Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -173,6 +175,53 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initi return true; } +bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + int32_t input0_type; // input data type + int32_t input1_type; // scale data type + int32_t input2_type; // B data type + int32_t input3_type; // mean data type + int32_t input4_type; // var data type + bool has_input2 = input_defs.size() > 2 && input_defs[2]->Exists(); + bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists(); + bool has_input4 = input_defs.size() > 3 && input_defs[4]->Exists(); + + if (!GetType(*input_defs[0], input0_type, logger) || + !GetType(*input_defs[1], input1_type, logger) || + (has_input2 && !GetType(*input_defs[2], input2_type, logger)) || + (has_input3 && !GetType(*input_defs[3], input3_type, logger)) || + (has_input4 && !GetType(*input_defs[4], input4_type, logger))) { + return false; + } + + // WebNN batchNormalization, instanceNormalization, layerNormalization + // only support float32 and float16 input data types. + std::unordered_set supported_data_types = { + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + }; + + if (!IsSupportedDataType(input0_type, supported_data_types)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input type: [" << input0_type + << "] is not supported for now"; + return false; + } + + if (input0_type != input1_type || + (has_input2 && input0_type != input2_type) || + (has_input3 && input0_type != input3_type) || + (has_input4 && input0_type != input4_type)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input data types should be the same."; + return false; + } + + return true; +} + void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc index 9852db0abc..bc90821ba4 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc @@ -28,6 +28,8 @@ class PadOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const override; }; // Add operator related. @@ -190,6 +192,31 @@ bool PadOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } // namespace webnn +bool PadOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const { + const auto& input = *node.InputDefs()[0]; + const auto& op_type = node.OpType(); + int32_t input_type; + if (!GetType(input, input_type, logger)) + return false; + + std::unordered_set supported_data_types = webnn_supported_data_types; + // WebNN CPU backend doesn't support uint32, uint64 input data types for pad. + if (device_type == WebnnDeviceType::CPU) { + supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); + supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); + } + + if (!IsSupportedDataType(input_type, supported_data_types)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input type: [" << input_type + << "] is not supported for now"; + return false; + } + + return true; +} + void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc index c0954f7cf6..de65f015de 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc @@ -31,6 +31,8 @@ class ReductionOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const override; }; // Add operator related. @@ -144,6 +146,51 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ return true; } +bool ReductionOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const { + const auto& input = *node.InputDefs()[0]; + const auto& op_type = node.OpType(); + int32_t input_type; + if (!GetType(input, input_type, logger)) + return false; + + std::unordered_set supported_data_types; + if (op_type == "ReduceL1" || op_type == "ReduceProd" || + op_type == "ReduceSum" || op_type == "ReduceSumSquare") { + supported_data_types = { + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + ONNX_NAMESPACE::TensorProto_DataType_INT32, + ONNX_NAMESPACE::TensorProto_DataType_UINT32, + }; + // WebNN CPU backend doesn't support uint32 for reduceProd and reduceSum. + if (device_type == WebnnDeviceType::CPU && (op_type == "ReduceProd" || op_type == "ReduceSum")) { + supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); + } + } else if (op_type == "ReduceL2" || op_type == "ReduceLogSum" || + op_type == "ReduceLogSumExp" || op_type == "ReduceMean") { + supported_data_types = { + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + }; + } else { // ReduceMax and ReduceMin + supported_data_types = webnn_supported_data_types; + // WebNN CPU backend doesn't support uint32, uint64 for reduceMax and reduceMin. + if (device_type == WebnnDeviceType::CPU) { + supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); + supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); + } + } + if (!IsSupportedDataType(input_type, supported_data_types)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input type: [" << input_type + << "] is not supported for now"; + return false; + } + + return true; +} + void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc index 9018f8c96f..ea54b70a66 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -35,6 +35,8 @@ class ResizeOpBuilder : public BaseOpBuilder { // Resize opset 10- is very different than Resize opset 11+, with many key attributes missing. // We only support Resize opset 11+ here. int GetMinSupportedOpSet(const Node& /* node */) const override { return 11; } + bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const override; }; // Helper functions @@ -280,6 +282,30 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return true; } +bool ResizeOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input = *node.InputDefs()[0]; + const auto& op_type = node.OpType(); + int32_t input_type; + if (!GetType(input, input_type, logger)) + return false; + + // WebNN resample2d op only supports float32 and float16 input data types. + std::unordered_set supported_data_types = { + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + }; + + if (!IsSupportedDataType(input_type, supported_data_types)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input type: [" << input_type + << "] is not supported for now"; + return false; + } + + return true; +} + void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc index 12c2cf6dd0..1552023d3f 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc @@ -66,7 +66,7 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, bool ShapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, - const WebnnDeviceType device_type, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); std::vector input_shape; @@ -74,7 +74,7 @@ bool ShapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initialize return false; int32_t output_type = ONNX_NAMESPACE::TensorProto_DataType_INT64; - if (!IsSupportedDataType(output_type, device_type)) { + if (!IsSupportedDataType(output_type, webnn_supported_data_types)) { LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Output type: [" << output_type << "] is not supported for now"; diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index 4e0628581a..fb452aec1c 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -29,6 +29,8 @@ class SliceOpBuilder : public BaseOpBuilder { const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; // TODO: Support Slice opset < 10, which uses attributes for starts and ends. int GetMinSupportedOpSet(const Node& /* node */) const override { return 10; } + bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const override; }; // Add operator related. @@ -161,6 +163,30 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } +bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const { + const auto& input = *node.InputDefs()[0]; + const auto& op_type = node.OpType(); + int32_t input_type; + if (!GetType(input, input_type, logger)) + return false; + + std::unordered_set supported_data_types = webnn_supported_data_types; + // WebNN CPU backend doesn't support uint64 input data type for slice. + if (device_type == WebnnDeviceType::CPU) { + supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); + } + + if (!IsSupportedDataType(input_type, supported_data_types)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input type: [" << input_type + << "] is not supported for now"; + return false; + } + + return true; +} + void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index beee8b1d77..283badb98a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -24,6 +24,8 @@ class SoftmaxOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const override; }; Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -131,6 +133,30 @@ bool SoftmaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali return true; } +bool SoftmaxOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input = *node.InputDefs()[0]; + const auto& op_type = node.OpType(); + int32_t input_type; + if (!GetType(input, input_type, logger)) + return false; + + // WebNN softmax only supports float32 and float16 input data types. + std::unordered_set supported_data_types = { + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + }; + + if (!IsSupportedDataType(input_type, supported_data_types)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input type: [" << input_type + << "] is not supported for now"; + return false; + } + + return true; +} + void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc index 9c23554a44..841e2d1824 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc @@ -18,6 +18,8 @@ class TernaryOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const override; }; // Add operator related. @@ -42,6 +44,42 @@ Status TernaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons return Status::OK(); } +bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + int32_t input0_type; // condition data type + int32_t input1_type; // X data type + int32_t input2_type; // Y data type + + if (!GetType(*input_defs[0], input0_type, logger) || + !GetType(*input_defs[1], input1_type, logger) || + !GetType(*input_defs[2], input2_type, logger)) + return false; + + std::unordered_set supported_data_types = webnn_supported_data_types; + // WebNN CPU backend doesn't support uint64 X, Y data type for where. + if (device_type == WebnnDeviceType::CPU && op_type == "Where") { + supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); + } + // ONNX's condition data type is bool which is same as WebNN. + // Only need to check X, Y data types. + if (!IsSupportedDataType(input1_type, supported_data_types)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input type: [" << input1_type + << "] is not supported for now"; + return false; + } + + if (input1_type != input2_type) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input X, Y data types should be the same."; + return false; + } + + return true; +} + void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc index 79f60c51ac..3921b1da18 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc @@ -18,6 +18,8 @@ class TransposeOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const override; }; // Add operator related. @@ -47,6 +49,31 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } +bool TransposeOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const { + const auto& input = *node.InputDefs()[0]; + const auto& op_type = node.OpType(); + int32_t input_type; + if (!GetType(input, input_type, logger)) + return false; + + std::unordered_set supported_data_types = webnn_supported_data_types; + // WebNN CPU backend doesn't support uint32, uint64 input data types for transpose. + if (device_type == WebnnDeviceType::CPU) { + supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); + supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); + } + + if (!IsSupportedDataType(input_type, supported_data_types)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input type: [" << input_type + << "] is not supported for now"; + return false; + } + + return true; +} + void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc index e6c5cf2408..e0016de8e6 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc @@ -18,6 +18,8 @@ class UnaryOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const override; }; // Add operator related. @@ -66,6 +68,44 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const return Status::OK(); } +bool UnaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input = *node.InputDefs()[0]; + const auto& op_type = node.OpType(); + int32_t input_type; + if (!GetType(input, input_type, logger)) + return false; + + std::unordered_set supported_data_types; + if (op_type == "Identity") { + supported_data_types = webnn_supported_data_types; + } else if (op_type == "Abs" || op_type == "Neg") { + supported_data_types = { + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + ONNX_NAMESPACE::TensorProto_DataType_INT32, + ONNX_NAMESPACE::TensorProto_DataType_INT8, + }; + } else if (op_type == "Not") { + supported_data_types = { + ONNX_NAMESPACE::TensorProto_DataType_BOOL, + }; + } else { // Others only support float32, float16 input data types. + supported_data_types = { + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + }; + } + if (!IsSupportedDataType(input_type, supported_data_types)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input type: [" << input_type + << "] is not supported for now"; + return false; + } + + return true; +} + void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index ff3e1a7179..c46b04a3c2 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -96,7 +96,7 @@ Status ModelBuilder::RegisterInitializers() { desc.set("dimensions", emscripten::val::array(dims)); auto data_type = tensor.data_type(); emscripten::val operand = emscripten::val::object(); - if (IsSupportedDataType(data_type, wnn_device_type_)) { + if (IsSupportedDataType(data_type, webnn_supported_data_types)) { ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type"); auto num_elements = SafeInt(Product(tensor.dims())); emscripten::val view = emscripten::val::undefined(); @@ -112,12 +112,10 @@ Status ModelBuilder::RegisterInitializers() { switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: - desc.set("type", emscripten::val("uint8")); view = emscripten::val{emscripten::typed_memory_view(num_elements, reinterpret_cast(tensor_ptr))}; break; case ONNX_NAMESPACE::TensorProto_DataType_INT8: - desc.set("type", emscripten::val("int8")); view = emscripten::val{emscripten::typed_memory_view(num_elements, reinterpret_cast(tensor_ptr))}; break; diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index d72abf1a72..13ed29667d 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -318,25 +318,11 @@ common::Status WebNNExecutionProvider::Compile(const std::vector