mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
[WebNN EP] Improve activation fusion (#20320)
- Create a common util to get the set of supported activations
- Fuse the activation into BatchNormalization if possible
This commit is contained in:
parent
88904b9220
commit
ddd4e8c3e3
4 changed files with 43 additions and 17 deletions
|
|
@ -137,8 +137,8 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
|
|||
if (input_defs.size() > 2) {
|
||||
options.set("bias", model_builder.GetOperand(input_defs[2]->Name()));
|
||||
}
|
||||
InlinedHashSet<std::string> supported_nodes{"Clip", "Relu"};
|
||||
emscripten::val activation = model_builder.FindActivation(node, *node.OutputDefs()[0], supported_nodes);
|
||||
|
||||
emscripten::val activation = model_builder.FindActivation(node, *node.OutputDefs()[0]);
|
||||
if (emscripten::val::null() != activation) {
|
||||
options.set("activation", activation);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -79,6 +79,10 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
|
|||
if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
|
||||
options.set("axis", rank - 1);
|
||||
}
|
||||
emscripten::val activation = model_builder.FindActivation(node, *node.OutputDefs()[0]);
|
||||
if (emscripten::val::null() != activation) {
|
||||
options.set("activation", activation);
|
||||
}
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("batchNormalization", input, mean, variance, options);
|
||||
} else if (op_type == "LayerNormalization") {
|
||||
int64_t axis = helper.Get("axis", -1);
|
||||
|
|
|
|||
|
|
@ -80,20 +80,43 @@ void ModelBuilder::PreprocessInitializers() {
|
|||
|
||||
void ModelBuilder::PreprocessActivations() {
|
||||
const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder();
|
||||
|
||||
if (wnn_device_type_ == WebnnDeviceType::CPU) {
|
||||
// WebNN CPU currently only supports "Relu" and "Clip" fusion.
|
||||
supported_activation_nodes_ = {"Clip", "Relu"};
|
||||
} else {
|
||||
supported_activation_nodes_ = {
|
||||
// Temporarily disable clamp fusion for WebNN GPU, as it is not supported yet.
|
||||
// "Clip",
|
||||
"Elu",
|
||||
"Gelu",
|
||||
"HardSigmoid",
|
||||
"HardSwish",
|
||||
"Relu",
|
||||
"LeakyRelu",
|
||||
"Sigmoid",
|
||||
"Softplus",
|
||||
"Softsign",
|
||||
"Tanh",
|
||||
};
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < node_indices.size(); i++) {
|
||||
const auto* node(graph_viewer_.GetNode(node_indices[i]));
|
||||
const auto& op_type(node->OpType());
|
||||
|
||||
// Ignore unsupported activation nodes.
|
||||
if (!Contains(supported_activation_nodes_, op_type)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (op_type == "Clip") {
|
||||
// Temporarily disable clamp fusion for WebNN GPU, as it is not supported yet.
|
||||
if (wnn_device_type_ == WebnnDeviceType::CPU) {
|
||||
float minValue, maxValue;
|
||||
GetClipMinMax(GetInitializerTensors(), *node, minValue, maxValue, logger_);
|
||||
emscripten::val options = emscripten::val::object();
|
||||
options.set("minValue", minValue);
|
||||
options.set("maxValue", maxValue);
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("clamp", options));
|
||||
}
|
||||
float minValue, maxValue;
|
||||
GetClipMinMax(GetInitializerTensors(), *node, minValue, maxValue, logger_);
|
||||
emscripten::val options = emscripten::val::object();
|
||||
options.set("minValue", minValue);
|
||||
options.set("maxValue", maxValue);
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("clamp", options));
|
||||
} else if (op_type == "Elu") {
|
||||
NodeAttrHelper helper(*node);
|
||||
emscripten::val options = emscripten::val::object();
|
||||
|
|
@ -398,14 +421,12 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// supported_nodes is provided by the op to indicate whether it can be fused with the activation node.
|
||||
emscripten::val ModelBuilder::FindActivation(const Node& node, const NodeArg& output,
|
||||
const InlinedHashSet<std::string> supported_nodes) {
|
||||
emscripten::val ModelBuilder::FindActivation(const Node& node, const NodeArg& output) {
|
||||
emscripten::val fused_op = emscripten::val::null();
|
||||
for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
|
||||
const auto& dst_node = it->GetNode();
|
||||
const auto* dst_input = dst_node.InputDefs()[it->GetDstArgIndex()];
|
||||
if (!Contains(supported_nodes, dst_node.OpType())) {
|
||||
if (!Contains(supported_activation_nodes_, dst_node.OpType())) {
|
||||
return emscripten::val::null();
|
||||
}
|
||||
if (Contains(activation_nodes_, dst_node.Index())) {
|
||||
|
|
|
|||
|
|
@ -45,8 +45,7 @@ class ModelBuilder {
|
|||
const std::string& name, const void* buffer,
|
||||
const size_t size, const std::vector<uint32_t> shape, const int32_t data_type);
|
||||
// Find if an output has a fuseable activation (e.g., Relu).
|
||||
emscripten::val FindActivation(const Node& node, const NodeArg& output,
|
||||
const InlinedHashSet<std::string> supported_nodes = {});
|
||||
emscripten::val FindActivation(const Node& node, const NodeArg& output);
|
||||
|
||||
const InlinedHashSet<std::string>&
|
||||
GetFusedActivations() const { return fused_activations_; }
|
||||
|
|
@ -85,6 +84,8 @@ class ModelBuilder {
|
|||
|
||||
InlinedHashSet<std::string> fused_activations_;
|
||||
|
||||
InlinedHashSet<std::string> supported_activation_nodes_;
|
||||
|
||||
uint32_t name_token_{0};
|
||||
InlinedHashSet<std::string> unique_names_;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue