[WebNN EP] Improve activation fusion (#20320)

- Create a common utility to get the set of supported activations
- Fuse the activation into BatchNormalization when possible
This commit is contained in:
Wanming Lin 2024-04-26 08:16:55 -07:00 committed by GitHub
parent 88904b9220
commit ddd4e8c3e3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 43 additions and 17 deletions

View file

@@ -137,8 +137,8 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
if (input_defs.size() > 2) {
options.set("bias", model_builder.GetOperand(input_defs[2]->Name()));
}
InlinedHashSet<std::string> supported_nodes{"Clip", "Relu"};
emscripten::val activation = model_builder.FindActivation(node, *node.OutputDefs()[0], supported_nodes);
emscripten::val activation = model_builder.FindActivation(node, *node.OutputDefs()[0]);
if (emscripten::val::null() != activation) {
options.set("activation", activation);
}

View file

@@ -79,6 +79,10 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
options.set("axis", rank - 1);
}
emscripten::val activation = model_builder.FindActivation(node, *node.OutputDefs()[0]);
if (emscripten::val::null() != activation) {
options.set("activation", activation);
}
output = model_builder.GetBuilder().call<emscripten::val>("batchNormalization", input, mean, variance, options);
} else if (op_type == "LayerNormalization") {
int64_t axis = helper.Get("axis", -1);

View file

@@ -80,20 +80,43 @@ void ModelBuilder::PreprocessInitializers() {
void ModelBuilder::PreprocessActivations() {
const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder();
if (wnn_device_type_ == WebnnDeviceType::CPU) {
// WebNN CPU currently only supports "Relu" and "Clip" fusion.
supported_activation_nodes_ = {"Clip", "Relu"};
} else {
supported_activation_nodes_ = {
// Temporarily disable clamp fusion for WebNN GPU because it is not supported yet.
// "Clip",
"Elu",
"Gelu",
"HardSigmoid",
"HardSwish",
"Relu",
"LeakyRelu",
"Sigmoid",
"Softplus",
"Softsign",
"Tanh",
};
}
for (size_t i = 0; i < node_indices.size(); i++) {
const auto* node(graph_viewer_.GetNode(node_indices[i]));
const auto& op_type(node->OpType());
// Ignore unsupported activation nodes.
if (!Contains(supported_activation_nodes_, op_type)) {
continue;
}
if (op_type == "Clip") {
// Temporarily disable clamp fusion for WebNN GPU as which is not supported yet.
if (wnn_device_type_ == WebnnDeviceType::CPU) {
float minValue, maxValue;
GetClipMinMax(GetInitializerTensors(), *node, minValue, maxValue, logger_);
emscripten::val options = emscripten::val::object();
options.set("minValue", minValue);
options.set("maxValue", maxValue);
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("clamp", options));
}
float minValue, maxValue;
GetClipMinMax(GetInitializerTensors(), *node, minValue, maxValue, logger_);
emscripten::val options = emscripten::val::object();
options.set("minValue", minValue);
options.set("maxValue", maxValue);
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("clamp", options));
} else if (op_type == "Elu") {
NodeAttrHelper helper(*node);
emscripten::val options = emscripten::val::object();
@@ -398,14 +421,12 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
return Status::OK();
}
// supported_nodes is provided by the op to indicate whether it can be fused with the activation node.
emscripten::val ModelBuilder::FindActivation(const Node& node, const NodeArg& output,
const InlinedHashSet<std::string> supported_nodes) {
emscripten::val ModelBuilder::FindActivation(const Node& node, const NodeArg& output) {
emscripten::val fused_op = emscripten::val::null();
for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
const auto& dst_node = it->GetNode();
const auto* dst_input = dst_node.InputDefs()[it->GetDstArgIndex()];
if (!Contains(supported_nodes, dst_node.OpType())) {
if (!Contains(supported_activation_nodes_, dst_node.OpType())) {
return emscripten::val::null();
}
if (Contains(activation_nodes_, dst_node.Index())) {

View file

@@ -45,8 +45,7 @@ class ModelBuilder {
const std::string& name, const void* buffer,
const size_t size, const std::vector<uint32_t> shape, const int32_t data_type);
// Find if an output has a fuseable activation (e.g., Relu).
emscripten::val FindActivation(const Node& node, const NodeArg& output,
const InlinedHashSet<std::string> supported_nodes = {});
emscripten::val FindActivation(const Node& node, const NodeArg& output);
const InlinedHashSet<std::string>&
GetFusedActivations() const { return fused_activations_; }
@@ -85,6 +84,8 @@ class ModelBuilder {
InlinedHashSet<std::string> fused_activations_;
InlinedHashSet<std::string> supported_activation_nodes_;
uint32_t name_token_{0};
InlinedHashSet<std::string> unique_names_;