[WebNN EP] Add data type constraint (#20779)

WebNN spec has added data type constraint for every op, and its CPU
backend (currently is TFLite) has additional constraint. Add
corresponding constraint to each op in WebNN EP.

Note: Temporarily disable fp16 for CPU backend as which is planned to be
ready in Chromium next month.
This commit is contained in:
Wanming Lin 2024-05-30 01:19:51 +08:00 committed by GitHub
parent e77f238dc6
commit 9ea9f9e46a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 680 additions and 76 deletions

View file

@ -130,16 +130,10 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
return supported_node_groups;
}
bool IsSupportedDataType(const int32_t data_type, const WebnnDeviceType device_type) {
// Current data type implementation status of WebNN is inconsistent along with different backends,
// The XNNPack backend supports only FP32, while the DML backend POC supports more.
if (device_type == WebnnDeviceType::CPU) {
return std::find(supported_cpu_data_types.begin(), supported_cpu_data_types.end(), data_type) !=
supported_cpu_data_types.end();
} else {
return std::find(supported_gpu_data_types.begin(), supported_gpu_data_types.end(), data_type) !=
supported_gpu_data_types.end();
}
bool IsSupportedDataType(const int32_t data_type,
const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType>& supported_data_types) {
return std::find(supported_data_types.begin(), supported_data_types.end(), data_type) !=
supported_data_types.end();
}
bool IsValidMultidirectionalBroadcast(std::vector<int64_t>& shape_a,

View file

@ -262,11 +262,7 @@ inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn
return true;
}
constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 1> supported_cpu_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
};
constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 9> supported_gpu_data_types = {
static const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> webnn_supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
ONNX_NAMESPACE::TensorProto_DataType_INT8,
ONNX_NAMESPACE::TensorProto_DataType_UINT8,
@ -278,7 +274,8 @@ constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 9> supported_gpu_data
ONNX_NAMESPACE::TensorProto_DataType_UINT64,
};
bool IsSupportedDataType(const int32_t data_type, const WebnnDeviceType device_type);
bool IsSupportedDataType(const int32_t data_type,
const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType>& supported_data_types);
bool IsValidMultidirectionalBroadcast(std::vector<int64_t>& shape_a,
std::vector<int64_t>& shape_b,

View file

@ -21,6 +21,8 @@ class ActivationOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -81,6 +83,44 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initi
return true;
}
bool ActivationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
// WebNN relu op supports float32, float16, int32, int8 input data types.
if (op_type == "Relu") {
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
ONNX_NAMESPACE::TensorProto_DataType_INT32,
ONNX_NAMESPACE::TensorProto_DataType_INT8,
};
// WebNN CPU backend does not support int32 data type for relu.
if (device_type == WebnnDeviceType::CPU) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32);
}
} else { // Others only support float32 and float16.
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
};
}
if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return true;
}
void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
if (op_registrations.op_builder_map.count(op_type) > 0)
return;

View file

@ -22,6 +22,8 @@ class ArgMaxMinOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -77,6 +79,31 @@ bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia
return true;
}
bool ArgMaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
// WebNN CPU backend doesn't support int64, uint64 input data types for argMax and argMin.
if (device_type == WebnnDeviceType::CPU) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64);
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
}
if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return true;
}
void CreateArgMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
if (op_registrations.op_builder_map.count(op_type) > 0)
return;

View file

@ -73,11 +73,23 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType d
}
}
// WebNN CPU backend (TFLite) will enable float16 input data type soon,
// temporarily fallback float16 input data type for WebNN CPU.
if (device_type == WebnnDeviceType::CPU) {
const auto& input = *node.InputDefs()[0];
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)
return false;
}
return HasSupportedInputsImpl(node, device_type, logger);
}
bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node,
const WebnnDeviceType device_type,
const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
// We only check the type of input 0 by default, specific op builder can override this.
const auto& input = *node.InputDefs()[0];
@ -86,7 +98,7 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node,
if (!GetType(input, input_type, logger))
return false;
if (!IsSupportedDataType(input_type, device_type)) {
if (!IsSupportedDataType(input_type, webnn_supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << node.OpType()
<< "] Input type: [" << input_type
<< "] is not supported for now";

View file

@ -22,6 +22,8 @@ class BinaryOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -72,6 +74,49 @@ bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
return true;
}
bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type;
int32_t input1_type;
if (!GetType(*input_defs[0], input0_type, logger) ||
!GetType(*input_defs[1], input1_type, logger))
return false;
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
// WebNN prelu op only supports float32, float16, int32, int8 input data types.
if (op_type == "Prelu") {
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
ONNX_NAMESPACE::TensorProto_DataType_INT32,
ONNX_NAMESPACE::TensorProto_DataType_INT8,
};
// WebNN CPU backend doesn't support int32 for prelu.
if (device_type == WebnnDeviceType::CPU) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32);
}
} else {
supported_data_types = webnn_supported_data_types;
}
if (!IsSupportedDataType(input0_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input0_type
<< "] is not supported for now";
return false;
}
if (input0_type != input1_type) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input data types should be the same.";
return false;
}
return true;
}
void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
if (op_registrations.op_builder_map.count(op_type) > 0)
return;

View file

@ -22,7 +22,7 @@ class CastOpBuilder : public BaseOpBuilder {
// Operator support related.
private:
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
};
// Add operator related.
@ -80,12 +80,12 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
bool CastOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
const Node& node,
const WebnnDeviceType device_type,
const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
NodeAttrHelper helper(node);
// Check cast output type.
const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED);
if (!IsSupportedDataType(to_type, device_type)) {
if (!IsSupportedDataType(to_type, webnn_supported_data_types)) {
LOGS(logger, VERBOSE) << "Invalid cast to type " << to_type << ".";
return false;
}

View file

@ -25,6 +25,8 @@ class ClipOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -71,6 +73,33 @@ bool ClipOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
return GetClipMinMax(initializers, node, min, max, logger);
}
bool ClipOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
// WebNN CPU backend doesn't support int32, uint32, int64, uint64 input data types for clamp.
if (device_type == WebnnDeviceType::CPU) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32);
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64);
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
}
if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return true;
}
void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<ClipOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());

View file

@ -29,6 +29,8 @@ class ConvOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const override;
};
void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
@ -148,11 +150,6 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder,
bool is_conv1d) {
const auto& tensor = *model_builder.GetInitializerTensors().at(name);
auto data_type = tensor.data_type();
if (!IsSupportedDataType(data_type, model_builder.GetWebnnDeviceType())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"The initializer of graph has unsupported type, name: ",
tensor.name(), " type: ", data_type);
}
const auto& shape = tensor.dims();
std::vector<uint32_t> dims = GetVecUint32FromVecInt64(std::vector<int64_t>(std::begin(shape), std::end(shape)));
@ -177,7 +174,6 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder,
size_t element_size{0};
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
element_size = sizeof(uint8_t);
break;
@ -190,17 +186,6 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder,
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
element_size = sizeof(float);
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
element_size = sizeof(int32_t);
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
element_size = sizeof(int64_t);
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
element_size = sizeof(uint32_t);
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
element_size = sizeof(uint64_t);
break;
default:
break;
@ -396,6 +381,55 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
return true;
}
bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type; // input data type
int32_t input1_type; // weight data type
int32_t input2_type; // bias or x_zero_point data type
int32_t input3_type; // w_zero_point data type
bool has_input2 = input_defs.size() > 2 && input_defs[2]->Exists();
bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists();
if (!GetType(*input_defs[0], input0_type, logger) ||
!GetType(*input_defs[1], input1_type, logger) ||
(has_input2 && !GetType(*input_defs[2], input2_type, logger)) ||
(has_input3 && !GetType(*input_defs[3], input3_type, logger))) {
return false;
}
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
if (op_type == "Conv" || op_type == "ConvTranspose") {
// WebNN conv2d and convTranspose2d only support float32 and float16 input data types.
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
};
} else if (op_type == "ConvInteger") {
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_INT8,
ONNX_NAMESPACE::TensorProto_DataType_UINT8,
};
}
if (!IsSupportedDataType(input0_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input0_type
<< "] is not supported for now";
return false;
}
if (input0_type != input1_type ||
(has_input2 && input0_type != input2_type) ||
(has_input3 && input0_type != input3_type)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input data types should be the same.";
return false;
}
return true;
}
void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
if (op_registrations.op_builder_map.count(op_type) > 0)
return;

View file

@ -22,6 +22,8 @@ class GatherOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -66,6 +68,31 @@ bool GatherOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ
return true;
}
bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
// WebNN CPU backend doesn't support uint32, uint64 input data types for gather.
if (device_type == WebnnDeviceType::CPU) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
}
if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return true;
}
void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<GatherOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());

View file

@ -25,6 +25,8 @@ class GemmOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -219,6 +221,55 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
return true;
}
bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type; // A data type
int32_t input1_type; // B data type
int32_t input2_type; // C or a_zero_point data type
int32_t input3_type; // b_zero_point data type
bool has_input2 = input_defs.size() > 2 && input_defs[2]->Exists();
bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists();
if (!GetType(*input_defs[0], input0_type, logger) ||
!GetType(*input_defs[1], input1_type, logger) ||
(has_input2 && !GetType(*input_defs[2], input2_type, logger)) ||
(has_input3 && !GetType(*input_defs[3], input3_type, logger))) {
return false;
}
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
if (op_type == "Gemm" || op_type == "MatMul") {
// WebNN gemm and matmul only support float32 and float16 input data types.
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
};
} else if (op_type == "MatMulInteger") {
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_INT8,
ONNX_NAMESPACE::TensorProto_DataType_UINT8,
};
}
if (!IsSupportedDataType(input0_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input0_type
<< "] is not supported for now";
return false;
}
if (input0_type != input1_type ||
(has_input2 && input0_type != input2_type) ||
(has_input3 && input0_type != input3_type)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input data types should be the same.";
return false;
}
return true;
}
void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
if (op_registrations.op_builder_map.count(op_type) > 0)
return;

View file

@ -21,6 +21,8 @@ class LogicalOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -50,6 +52,48 @@ Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
return Status::OK();
}
bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
const Node& node,
const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
const auto& name = node.Name();
const auto& op_type = node.OpType();
const auto& input_defs = node.InputDefs();
if (input_defs.size() < 2) {
LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 2 inputs, actual: "
<< input_defs.size();
return false;
}
return true;
}
bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type;
int32_t input1_type;
if (!GetType(*input_defs[0], input0_type, logger) ||
!GetType(*input_defs[1], input1_type, logger))
return false;
if (!IsSupportedDataType(input0_type, webnn_supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input0_type
<< "] is not supported for now";
return false;
}
if (input0_type != input1_type) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input data types should be the same.";
return false;
}
return true;
}
void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
if (op_registrations.op_builder_map.count(op_type) > 0)
return;
@ -69,20 +113,5 @@ void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations&
}
}
bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
const Node& node,
const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
const auto& name = node.Name();
const auto& op_type = node.OpType();
const auto& input_defs = node.InputDefs();
if (input_defs.size() < 2) {
LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 2 inputs, actual: "
<< input_defs.size();
return false;
}
return true;
}
} // namespace webnn
} // namespace onnxruntime

View file

@ -22,6 +22,8 @@ class MaxMinOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -84,6 +86,33 @@ bool MaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ
return true;
}
bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type;
int32_t input1_type;
if (!GetType(*input_defs[0], input0_type, logger) ||
!GetType(*input_defs[1], input1_type, logger))
return false;
if (!IsSupportedDataType(input0_type, webnn_supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input0_type
<< "] is not supported for now";
return false;
}
if (input0_type != input1_type) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input data types should be the same.";
return false;
}
return true;
}
void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
if (op_registrations.op_builder_map.count(op_type) > 0)
return;

View file

@ -25,6 +25,8 @@ class NormalizationOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const override;
};
Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
@ -173,6 +175,53 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initi
return true;
}
bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type; // input data type
int32_t input1_type; // scale data type
int32_t input2_type; // B data type
int32_t input3_type; // mean data type
int32_t input4_type; // var data type
bool has_input2 = input_defs.size() > 2 && input_defs[2]->Exists();
bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists();
bool has_input4 = input_defs.size() > 3 && input_defs[4]->Exists();
if (!GetType(*input_defs[0], input0_type, logger) ||
!GetType(*input_defs[1], input1_type, logger) ||
(has_input2 && !GetType(*input_defs[2], input2_type, logger)) ||
(has_input3 && !GetType(*input_defs[3], input3_type, logger)) ||
(has_input4 && !GetType(*input_defs[4], input4_type, logger))) {
return false;
}
// WebNN batchNormalization, instanceNormalization, layerNormalization
// only support float32 and float16 input data types.
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
};
if (!IsSupportedDataType(input0_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input0_type
<< "] is not supported for now";
return false;
}
if (input0_type != input1_type ||
(has_input2 && input0_type != input2_type) ||
(has_input3 && input0_type != input3_type) ||
(has_input4 && input0_type != input4_type)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input data types should be the same.";
return false;
}
return true;
}
void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
if (op_registrations.op_builder_map.count(op_type) > 0)
return;

View file

@ -28,6 +28,8 @@ class PadOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -190,6 +192,31 @@ bool PadOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
return true;
} // namespace webnn
bool PadOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
// WebNN CPU backend doesn't support uint32, uint64 input data types for pad.
if (device_type == WebnnDeviceType::CPU) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
}
if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return true;
}
void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<PadOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());

View file

@ -31,6 +31,8 @@ class ReductionOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -144,6 +146,51 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ
return true;
}
bool ReductionOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
if (op_type == "ReduceL1" || op_type == "ReduceProd" ||
op_type == "ReduceSum" || op_type == "ReduceSumSquare") {
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
ONNX_NAMESPACE::TensorProto_DataType_INT32,
ONNX_NAMESPACE::TensorProto_DataType_UINT32,
};
// WebNN CPU backend doesn't support uint32 for reduceProd and reduceSum.
if (device_type == WebnnDeviceType::CPU && (op_type == "ReduceProd" || op_type == "ReduceSum")) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
}
} else if (op_type == "ReduceL2" || op_type == "ReduceLogSum" ||
op_type == "ReduceLogSumExp" || op_type == "ReduceMean") {
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
};
} else { // ReduceMax and ReduceMin
supported_data_types = webnn_supported_data_types;
// WebNN CPU backend doesn't support uint32, uint64 for reduceMax and reduceMin.
if (device_type == WebnnDeviceType::CPU) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
}
}
if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return true;
}
void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
if (op_registrations.op_builder_map.count(op_type) > 0)
return;

View file

@ -35,6 +35,8 @@ class ResizeOpBuilder : public BaseOpBuilder {
// Resize opset 10- is very different than Resize opset 11+, with many key attributes missing.
// We only support Resize opset 11+ here.
int GetMinSupportedOpSet(const Node& /* node */) const override { return 11; }
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const override;
};
// Helper functions
@ -280,6 +282,30 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
return true;
}
bool ResizeOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
// WebNN resample2d op only supports float32 and float16 input data types.
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
};
if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return true;
}
void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<ResizeOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());

View file

@ -66,7 +66,7 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
bool ShapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
const Node& node,
const WebnnDeviceType device_type,
const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
std::vector<int64_t> input_shape;
@ -74,7 +74,7 @@ bool ShapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initialize
return false;
int32_t output_type = ONNX_NAMESPACE::TensorProto_DataType_INT64;
if (!IsSupportedDataType(output_type, device_type)) {
if (!IsSupportedDataType(output_type, webnn_supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << node.OpType()
<< "] Output type: [" << output_type
<< "] is not supported for now";

View file

@ -29,6 +29,8 @@ class SliceOpBuilder : public BaseOpBuilder {
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
// TODO: Support Slice opset < 10, which uses attributes for starts and ends.
int GetMinSupportedOpSet(const Node& /* node */) const override { return 10; }
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -161,6 +163,30 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
return true;
}
bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
// WebNN CPU backend doesn't support uint64 input data type for slice.
if (device_type == WebnnDeviceType::CPU) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
}
if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return true;
}
void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<SliceOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());

View file

@ -24,6 +24,8 @@ class SoftmaxOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const override;
};
Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
@ -131,6 +133,30 @@ bool SoftmaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali
return true;
}
bool SoftmaxOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
// WebNN softmax only supports float32 and float16 input data types.
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
};
if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return true;
}
void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<SoftmaxOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());

View file

@ -18,6 +18,8 @@ class TernaryOpBuilder : public BaseOpBuilder {
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -42,6 +44,42 @@ Status TernaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
return Status::OK();
}
bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type; // condition data type
int32_t input1_type; // X data type
int32_t input2_type; // Y data type
if (!GetType(*input_defs[0], input0_type, logger) ||
!GetType(*input_defs[1], input1_type, logger) ||
!GetType(*input_defs[2], input2_type, logger))
return false;
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
// WebNN CPU backend doesn't support uint64 X, Y data type for where.
if (device_type == WebnnDeviceType::CPU && op_type == "Where") {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
}
// ONNX's condition data type is bool which is same as WebNN.
// Only need to check X, Y data types.
if (!IsSupportedDataType(input1_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input1_type
<< "] is not supported for now";
return false;
}
if (input1_type != input2_type) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input X, Y data types should be the same.";
return false;
}
return true;
}
void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
if (op_registrations.op_builder_map.count(op_type) > 0)
return;

View file

@ -18,6 +18,8 @@ class TransposeOpBuilder : public BaseOpBuilder {
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -47,6 +49,31 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
return Status::OK();
}
bool TransposeOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
// WebNN CPU backend doesn't support uint32, uint64 input data types for transpose.
if (device_type == WebnnDeviceType::CPU) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
}
if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return true;
}
void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<TransposeOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());

View file

@ -18,6 +18,8 @@ class UnaryOpBuilder : public BaseOpBuilder {
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const override;
};
// Add operator related.
@ -66,6 +68,44 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
return Status::OK();
}
bool UnaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
if (op_type == "Identity") {
supported_data_types = webnn_supported_data_types;
} else if (op_type == "Abs" || op_type == "Neg") {
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
ONNX_NAMESPACE::TensorProto_DataType_INT32,
ONNX_NAMESPACE::TensorProto_DataType_INT8,
};
} else if (op_type == "Not") {
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
};
} else { // Others only support float32, float16 input data types.
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
};
}
if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}
return true;
}
void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
if (op_registrations.op_builder_map.count(op_type) > 0)
return;

View file

@ -96,7 +96,7 @@ Status ModelBuilder::RegisterInitializers() {
desc.set("dimensions", emscripten::val::array(dims));
auto data_type = tensor.data_type();
emscripten::val operand = emscripten::val::object();
if (IsSupportedDataType(data_type, wnn_device_type_)) {
if (IsSupportedDataType(data_type, webnn_supported_data_types)) {
ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type");
auto num_elements = SafeInt<size_t>(Product(tensor.dims()));
emscripten::val view = emscripten::val::undefined();
@ -112,12 +112,10 @@ Status ModelBuilder::RegisterInitializers() {
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
desc.set("type", emscripten::val("uint8"));
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint8_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
desc.set("type", emscripten::val("int8"));
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<int8_t*>(tensor_ptr))};
break;

View file

@ -318,25 +318,11 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
auto output_tensor =
ctx.GetOutput(i, output_shape.data(), output_shape.size());
void* output_buffer;
switch (output_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
output_buffer = output_tensor.GetTensorMutableRawData();
break;
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"Unsupported type: ", output_type, " for output: ", output_name);
break;
if (!webnn::IsSupportedDataType(output_type, webnn::webnn_supported_data_types)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"Unsupported type: ", output_type, " for output: ", output_name);
}
void* output_buffer = output_tensor.GetTensorMutableRawData();
outputs.emplace(output_name,
webnn::OnnxTensorData{
webnn::OnnxTensorInfo{output_type, output_shape},