[WebNN EP] Remove activation fusion (#20635)

The WebNN spec has removed the activation option for conv and
batchNormalization. We don't need additional activation fusion in the
WebNN EP anymore.

[edit by fdwr] Note this is handled in the browser now, which knows more
about the backend platform version and can more safely make decisions
about which fusions are possible (e.g. for the DirectML backend, whether
softmax and gelu can fuse successfully with their base operator).
This commit is contained in:
Wanming Lin 2024-05-16 07:49:07 +08:00 committed by GitHub
parent d1e66f0446
commit f5bfbd6d81
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 30 additions and 175 deletions

View file

@ -32,40 +32,35 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name());
emscripten::val output = emscripten::val::object();
if (Contains(model_builder.GetFusedActivations(), node.InputDefs()[0]->Name())) {
LOGS_DEFAULT(VERBOSE) << op_type << " Node [" << node.Name() << "] fused";
output = input;
NodeAttrHelper helper(node);
emscripten::val options = emscripten::val::object();
if (op_type == "Elu") {
options.set("alpha", helper.Get("alpha", 1.0f));
output = model_builder.GetBuilder().call<emscripten::val>("elu", input, options);
} else if (op_type == "Gelu") {
output = model_builder.GetBuilder().call<emscripten::val>("gelu", input, options);
} else if (op_type == "HardSigmoid") {
options.set("alpha", helper.Get("alpha", 0.2f));
options.set("beta", helper.Get("beta", 0.5f));
output = model_builder.GetBuilder().call<emscripten::val>("hardSigmoid", input, options);
} else if (op_type == "HardSwish") {
output = model_builder.GetBuilder().call<emscripten::val>("hardSwish", input);
} else if (op_type == "LeakyRelu") {
options.set("alpha", helper.Get("alpha", 0.0f));
output = model_builder.GetBuilder().call<emscripten::val>("leakyRelu", input, options);
} else if (op_type == "Relu") {
output = model_builder.GetBuilder().call<emscripten::val>("relu", input);
} else if (op_type == "Sigmoid") {
output = model_builder.GetBuilder().call<emscripten::val>("sigmoid", input);
} else if (op_type == "Softplus") {
output = model_builder.GetBuilder().call<emscripten::val>("softplus", input);
} else if (op_type == "Softsign") {
output = model_builder.GetBuilder().call<emscripten::val>("softsign", input);
} else if (op_type == "Tanh") {
output = model_builder.GetBuilder().call<emscripten::val>("tanh", input);
} else {
NodeAttrHelper helper(node);
emscripten::val options = emscripten::val::object();
if (op_type == "Elu") {
options.set("alpha", helper.Get("alpha", 1.0f));
output = model_builder.GetBuilder().call<emscripten::val>("elu", input, options);
} else if (op_type == "Gelu") {
output = model_builder.GetBuilder().call<emscripten::val>("gelu", input, options);
} else if (op_type == "HardSigmoid") {
options.set("alpha", helper.Get("alpha", 0.2f));
options.set("beta", helper.Get("beta", 0.5f));
output = model_builder.GetBuilder().call<emscripten::val>("hardSigmoid", input, options);
} else if (op_type == "HardSwish") {
output = model_builder.GetBuilder().call<emscripten::val>("hardSwish", input);
} else if (op_type == "LeakyRelu") {
options.set("alpha", helper.Get("alpha", 0.0f));
output = model_builder.GetBuilder().call<emscripten::val>("leakyRelu", input, options);
} else if (op_type == "Relu") {
output = model_builder.GetBuilder().call<emscripten::val>("relu", input);
} else if (op_type == "Sigmoid") {
output = model_builder.GetBuilder().call<emscripten::val>("sigmoid", input);
} else if (op_type == "Softplus") {
output = model_builder.GetBuilder().call<emscripten::val>("softplus", input);
} else if (op_type == "Softsign") {
output = model_builder.GetBuilder().call<emscripten::val>("softsign", input);
} else if (op_type == "Tanh") {
output = model_builder.GetBuilder().call<emscripten::val>("tanh", input);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
}
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
}
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));

View file

@ -52,13 +52,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
options.set("minValue", minValue);
options.set("maxValue", maxValue);
emscripten::val input = model_builder.GetOperand(input_name);
emscripten::val output = emscripten::val::object();
if (Contains(model_builder.GetFusedActivations(), input_name)) {
LOGS_DEFAULT(VERBOSE) << "Clip Node [" << node.Name() << "] fused";
output = input;
} else {
output = model_builder.GetBuilder().call<emscripten::val>("clamp", input, options);
}
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("clamp", input, options);
model_builder.AddOperand(output_name, std::move(output));
return Status::OK();

View file

@ -138,11 +138,6 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
options.set("bias", model_builder.GetOperand(input_defs[2]->Name()));
}
emscripten::val activation = model_builder.FindActivation(node, *node.OutputDefs()[0]);
if (emscripten::val::null() != activation) {
options.set("activation", activation);
}
return Status::OK();
}

View file

@ -79,10 +79,7 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
options.set("axis", rank - 1);
}
emscripten::val activation = model_builder.FindActivation(node, *node.OutputDefs()[0]);
if (emscripten::val::null() != activation) {
options.set("activation", activation);
}
output = model_builder.GetBuilder().call<emscripten::val>("batchNormalization", input, mean, variance, options);
} else if (op_type == "LayerNormalization") {
int64_t axis = helper.Get("axis", -1);

View file

@ -31,7 +31,6 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge
Status ModelBuilder::Initialize() {
PreprocessInitializers();
PreprocessActivations();
ORT_RETURN_IF_ERROR(RegisterInitializers());
ORT_RETURN_IF_ERROR(RegisterModelInputs());
ORT_RETURN_IF_ERROR(AddOperations());
@ -78,79 +77,6 @@ void ModelBuilder::PreprocessInitializers() {
}
}
// Preprocesses the activation nodes of the graph so that FindActivation() can later
// look up a ready-made WebNN activation operator by node index.
//
// Two members are (re)populated:
//   - supported_activation_nodes_: the op types eligible for fusion, chosen from
//     wnn_device_type_.
//   - activation_nodes_: NodeIndex -> activation operator created through wnn_builder_.
//     Nodes whose op type is not in supported_activation_nodes_ are skipped.
void ModelBuilder::PreprocessActivations() {
  const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder();
  if (wnn_device_type_ == WebnnDeviceType::CPU) {
    // WebNN CPU currently only supports "Relu" and "Clip" fusion.
    supported_activation_nodes_ = {"Clip", "Relu"};
  } else {
    supported_activation_nodes_ = {
        // Clamp fusion is temporarily disabled for WebNN GPU because it is not supported yet.
        // "Clip",
        "Elu",
        "Gelu",
        "HardSigmoid",
        "HardSwish",
        "Relu",
        "LeakyRelu",
        "Sigmoid",
        "Softplus",
        "Softsign",
        "Tanh",
    };
  }

  for (const auto node_index : node_indices) {
    const auto* node(graph_viewer_.GetNode(node_index));
    const auto& op_type(node->OpType());

    // Ignore unsupported activation nodes.
    if (!Contains(supported_activation_nodes_, op_type)) {
      continue;
    }

    if (op_type == "Clip") {
      float minValue, maxValue;
      GetClipMinMax(GetInitializerTensors(), *node, minValue, maxValue, logger_);
      emscripten::val options = emscripten::val::object();
      options.set("minValue", minValue);
      options.set("maxValue", maxValue);
      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("clamp", options));
    } else if (op_type == "Elu") {
      NodeAttrHelper helper(*node);
      emscripten::val options = emscripten::val::object();
      options.set("alpha", helper.Get("alpha", 1.0f));
      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("elu", options));
    } else if (op_type == "Gelu") {
      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("gelu"));
    } else if (op_type == "HardSigmoid") {
      NodeAttrHelper helper(*node);
      emscripten::val options = emscripten::val::object();
      options.set("alpha", helper.Get("alpha", 0.2f));
      options.set("beta", helper.Get("beta", 0.5f));
      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("hardSigmoid", options));
    } else if (op_type == "HardSwish") {
      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("hardSwish"));
    } else if (op_type == "Relu") {
      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("relu"));
    } else if (op_type == "LeakyRelu") {
      NodeAttrHelper helper(*node);
      emscripten::val options = emscripten::val::object();
      // ONNX defines the default LeakyRelu alpha as 0.01. Falling back to 0.0 would make a
      // node that omits the attribute behave like a plain Relu.
      // NOTE(review): the standalone LeakyRelu path in ActivationOpBuilder also falls back
      // to 0.0f - confirm and align it with this value.
      options.set("alpha", helper.Get("alpha", 0.01f));
      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("leakyRelu", options));
    } else if (op_type == "Sigmoid") {
      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("sigmoid"));
    } else if (op_type == "Softplus") {
      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("softplus"));
    } else if (op_type == "Softsign") {
      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("softsign"));
    } else if (op_type == "Tanh") {
      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("tanh"));
    }
  }
}
Status ModelBuilder::RegisterInitializers() {
for (const auto& pair : GetInitializerTensors()) {
const auto& tensor = *pair.second;
@ -421,44 +347,6 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
return Status::OK();
}
// Looks for an activation that can be fused into `node` through its output `output`.
//
// Returns the pre-built activation operator from activation_nodes_ when `output` is
// consumed by exactly one registered activation node, every consumer of `node` is a
// supported activation op type, and `output` is not a graph output. On success the output
// name is recorded in fused_activations_. In every other case emscripten::val::null() is
// returned and nothing is recorded, so the activation is added as a separate operation.
emscripten::val ModelBuilder::FindActivation(const Node& node, const NodeArg& output) {
  emscripten::val fused_op = emscripten::val::null();
  bool found_activation_consumer = false;

  for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
    const auto& dst_node = it->GetNode();
    if (!Contains(supported_activation_nodes_, dst_node.OpType())) {
      return emscripten::val::null();
    }

    // Index the consumer's inputs only once it is known to be a supported activation.
    const auto* dst_input = dst_node.InputDefs()[it->GetDstArgIndex()];
    if (&output != dst_input) {
      continue;
    }

    const auto activation_it = activation_nodes_.find(dst_node.Index());
    if (activation_it == activation_nodes_.end()) {
      // The consumer has no pre-built activation operator, so the activation
      // will be added separately.
      return emscripten::val::null();
    }

    if (found_activation_consumer) {
      // Several activation nodes consume this output. Only one activation can be fused,
      // yet marking the output as fused makes every consumer get skipped, so fusing here
      // would give the other consumers the wrong result. Do not fuse.
      return emscripten::val::null();
    }
    found_activation_consumer = true;
    fused_op = activation_it->second;
  }

  // If output is a graph output, the activation will be added separately.
  if (fused_op != emscripten::val::null()) {
    for (const auto* graph_output : graph_viewer_.GetOutputs()) {
      if (&output == graph_output) {
        return emscripten::val::null();
      }
    }

    LOGS_DEFAULT(VERBOSE) << "Node [" << node.Name() << "] type [" << node.OpType()
                          << "], fused the output [" << output.Name() << "]";

    fused_activations_.insert(output.Name());
  }

  return fused_op;
}
// Records `output_name` in scalar_outputs_.
void ModelBuilder::AddScalarOutput(const std::string& output_name) {
  scalar_outputs_.emplace(output_name);
}

View file

@ -44,11 +44,6 @@ class ModelBuilder {
Status AddOperandFromPersistMemoryBuffer(
const std::string& name, const void* buffer,
const size_t size, const std::vector<uint32_t> shape, const int32_t data_type);
// Find if an output has a fuseable activation (e.g., Relu).
emscripten::val FindActivation(const Node& node, const NodeArg& output);
// Read-only view of the output names recorded in fused_activations_.
const InlinedHashSet<std::string>& GetFusedActivations() const {
  return fused_activations_;
}
// Returns the value of preferred_layout_.
DataLayout GetPreferredLayout() const {
  return preferred_layout_;
}
@ -82,22 +77,13 @@ class ModelBuilder {
InlinedHashSet<std::string> skipped_initializers_;
InlinedHashSet<std::string> skipped_inputs_;
InlinedHashSet<std::string> fused_activations_;
InlinedHashSet<std::string> supported_activation_nodes_;
uint32_t name_token_{0};
InlinedHashSet<std::string> unique_names_;
// All activation nodes (e.g., Relu) as a map <NodeIndex, FusionOperator>.
InlinedHashMap<NodeIndex, emscripten::val> activation_nodes_;
// Convert the onnx model to WebNN operands
Status Initialize() ORT_MUST_USE_RESULT;
void PreprocessInitializers();
// Preprocess all the activation nodes (e.g., Relu) for easy query later.
void PreprocessActivations();
// Copy and process all the initializers to WebNN constants.
Status RegisterInitializers() ORT_MUST_USE_RESULT;