mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
[WebNN EP] Remove activation fusion (#20635)
WebNN spec has removed activation option for conv and batchNormalization. We don't need additional activation fusion in WebNN EP anymore. [edit by fdwr] Note this is handled in the browser now, which knows more about the backend platform version and can more safely make decisions about which fusions are possible (e.g. for the DirectML backend, whether softmax and gelu can fuse successfully with their base operator).
This commit is contained in:
parent
d1e66f0446
commit
f5bfbd6d81
6 changed files with 30 additions and 175 deletions
|
|
@ -32,40 +32,35 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name());
|
||||
emscripten::val output = emscripten::val::object();
|
||||
|
||||
if (Contains(model_builder.GetFusedActivations(), node.InputDefs()[0]->Name())) {
|
||||
LOGS_DEFAULT(VERBOSE) << op_type << " Node [" << node.Name() << "] fused";
|
||||
output = input;
|
||||
NodeAttrHelper helper(node);
|
||||
emscripten::val options = emscripten::val::object();
|
||||
if (op_type == "Elu") {
|
||||
options.set("alpha", helper.Get("alpha", 1.0f));
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("elu", input, options);
|
||||
} else if (op_type == "Gelu") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("gelu", input, options);
|
||||
} else if (op_type == "HardSigmoid") {
|
||||
options.set("alpha", helper.Get("alpha", 0.2f));
|
||||
options.set("beta", helper.Get("beta", 0.5f));
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("hardSigmoid", input, options);
|
||||
} else if (op_type == "HardSwish") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("hardSwish", input);
|
||||
} else if (op_type == "LeakyRelu") {
|
||||
options.set("alpha", helper.Get("alpha", 0.0f));
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("leakyRelu", input, options);
|
||||
} else if (op_type == "Relu") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("relu", input);
|
||||
} else if (op_type == "Sigmoid") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("sigmoid", input);
|
||||
} else if (op_type == "Softplus") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("softplus", input);
|
||||
} else if (op_type == "Softsign") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("softsign", input);
|
||||
} else if (op_type == "Tanh") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("tanh", input);
|
||||
} else {
|
||||
NodeAttrHelper helper(node);
|
||||
emscripten::val options = emscripten::val::object();
|
||||
if (op_type == "Elu") {
|
||||
options.set("alpha", helper.Get("alpha", 1.0f));
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("elu", input, options);
|
||||
} else if (op_type == "Gelu") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("gelu", input, options);
|
||||
} else if (op_type == "HardSigmoid") {
|
||||
options.set("alpha", helper.Get("alpha", 0.2f));
|
||||
options.set("beta", helper.Get("beta", 0.5f));
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("hardSigmoid", input, options);
|
||||
} else if (op_type == "HardSwish") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("hardSwish", input);
|
||||
} else if (op_type == "LeakyRelu") {
|
||||
options.set("alpha", helper.Get("alpha", 0.0f));
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("leakyRelu", input, options);
|
||||
} else if (op_type == "Relu") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("relu", input);
|
||||
} else if (op_type == "Sigmoid") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("sigmoid", input);
|
||||
} else if (op_type == "Softplus") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("softplus", input);
|
||||
} else if (op_type == "Softsign") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("softsign", input);
|
||||
} else if (op_type == "Tanh") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("tanh", input);
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
|
||||
}
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
|
||||
}
|
||||
|
||||
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
|
||||
|
|
|
|||
|
|
@ -52,13 +52,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
options.set("minValue", minValue);
|
||||
options.set("maxValue", maxValue);
|
||||
emscripten::val input = model_builder.GetOperand(input_name);
|
||||
emscripten::val output = emscripten::val::object();
|
||||
if (Contains(model_builder.GetFusedActivations(), input_name)) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Clip Node [" << node.Name() << "] fused";
|
||||
output = input;
|
||||
} else {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("clamp", input, options);
|
||||
}
|
||||
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("clamp", input, options);
|
||||
|
||||
model_builder.AddOperand(output_name, std::move(output));
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -138,11 +138,6 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
|
|||
options.set("bias", model_builder.GetOperand(input_defs[2]->Name()));
|
||||
}
|
||||
|
||||
emscripten::val activation = model_builder.FindActivation(node, *node.OutputDefs()[0]);
|
||||
if (emscripten::val::null() != activation) {
|
||||
options.set("activation", activation);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -79,10 +79,7 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
|
|||
if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
|
||||
options.set("axis", rank - 1);
|
||||
}
|
||||
emscripten::val activation = model_builder.FindActivation(node, *node.OutputDefs()[0]);
|
||||
if (emscripten::val::null() != activation) {
|
||||
options.set("activation", activation);
|
||||
}
|
||||
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("batchNormalization", input, mean, variance, options);
|
||||
} else if (op_type == "LayerNormalization") {
|
||||
int64_t axis = helper.Get("axis", -1);
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge
|
|||
|
||||
Status ModelBuilder::Initialize() {
|
||||
PreprocessInitializers();
|
||||
PreprocessActivations();
|
||||
ORT_RETURN_IF_ERROR(RegisterInitializers());
|
||||
ORT_RETURN_IF_ERROR(RegisterModelInputs());
|
||||
ORT_RETURN_IF_ERROR(AddOperations());
|
||||
|
|
@ -78,79 +77,6 @@ void ModelBuilder::PreprocessInitializers() {
|
|||
}
|
||||
}
|
||||
|
||||
void ModelBuilder::PreprocessActivations() {
|
||||
const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder();
|
||||
|
||||
if (wnn_device_type_ == WebnnDeviceType::CPU) {
|
||||
// WebNN CPU currently only supports "Relu" and "Clip" fusion.
|
||||
supported_activation_nodes_ = {"Clip", "Relu"};
|
||||
} else {
|
||||
supported_activation_nodes_ = {
|
||||
// Temporarily disable clamp fusion for WebNN GPU as which is not supported yet.
|
||||
// "Clip",
|
||||
"Elu",
|
||||
"Gelu",
|
||||
"HardSigmoid",
|
||||
"HardSwish",
|
||||
"Relu",
|
||||
"LeakyRelu",
|
||||
"Sigmoid",
|
||||
"Softplus",
|
||||
"Softsign",
|
||||
"Tanh",
|
||||
};
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < node_indices.size(); i++) {
|
||||
const auto* node(graph_viewer_.GetNode(node_indices[i]));
|
||||
const auto& op_type(node->OpType());
|
||||
|
||||
// Ignore unsupported activation nodes.
|
||||
if (!Contains(supported_activation_nodes_, op_type)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (op_type == "Clip") {
|
||||
float minValue, maxValue;
|
||||
GetClipMinMax(GetInitializerTensors(), *node, minValue, maxValue, logger_);
|
||||
emscripten::val options = emscripten::val::object();
|
||||
options.set("minValue", minValue);
|
||||
options.set("maxValue", maxValue);
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("clamp", options));
|
||||
} else if (op_type == "Elu") {
|
||||
NodeAttrHelper helper(*node);
|
||||
emscripten::val options = emscripten::val::object();
|
||||
options.set("alpha", helper.Get("alpha", 1.0f));
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("elu", options));
|
||||
} else if (op_type == "Gelu") {
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("gelu"));
|
||||
} else if (op_type == "HardSigmoid") {
|
||||
NodeAttrHelper helper(*node);
|
||||
emscripten::val options = emscripten::val::object();
|
||||
options.set("alpha", helper.Get("alpha", 0.2f));
|
||||
options.set("beta", helper.Get("beta", 0.5f));
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("hardSigmoid", options));
|
||||
} else if (op_type == "HardSwish") {
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("hardSwish"));
|
||||
} else if (op_type == "Relu") {
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("relu"));
|
||||
} else if (op_type == "LeakyRelu") {
|
||||
NodeAttrHelper helper(*node);
|
||||
emscripten::val options = emscripten::val::object();
|
||||
options.set("alpha", helper.Get("alpha", 0.0f));
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("leakyRelu", options));
|
||||
} else if (op_type == "Sigmoid") {
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("sigmoid"));
|
||||
} else if (op_type == "Softplus") {
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("softplus"));
|
||||
} else if (op_type == "Softsign") {
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("softsign"));
|
||||
} else if (op_type == "Tanh") {
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("tanh"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status ModelBuilder::RegisterInitializers() {
|
||||
for (const auto& pair : GetInitializerTensors()) {
|
||||
const auto& tensor = *pair.second;
|
||||
|
|
@ -421,44 +347,6 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
emscripten::val ModelBuilder::FindActivation(const Node& node, const NodeArg& output) {
|
||||
emscripten::val fused_op = emscripten::val::null();
|
||||
for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
|
||||
const auto& dst_node = it->GetNode();
|
||||
const auto* dst_input = dst_node.InputDefs()[it->GetDstArgIndex()];
|
||||
if (!Contains(supported_activation_nodes_, dst_node.OpType())) {
|
||||
return emscripten::val::null();
|
||||
}
|
||||
if (Contains(activation_nodes_, dst_node.Index())) {
|
||||
if (&output == dst_input) {
|
||||
fused_op = activation_nodes_.at(dst_node.Index());
|
||||
}
|
||||
} else {
|
||||
// If there is any other non-relu node using the output
|
||||
// will add relu separately.
|
||||
if (&output == dst_input) {
|
||||
return emscripten::val::null();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If output is a graph output, will add relu separately.
|
||||
if (fused_op != emscripten::val::null()) {
|
||||
for (const auto* graph_output : graph_viewer_.GetOutputs()) {
|
||||
if (&output == graph_output) {
|
||||
return emscripten::val::null();
|
||||
}
|
||||
}
|
||||
|
||||
LOGS_DEFAULT(VERBOSE) << "Node [" << node.Name() << "] type [" << node.OpType()
|
||||
<< "], fused the output [" << output.Name() << "]";
|
||||
|
||||
fused_activations_.insert(output.Name());
|
||||
}
|
||||
|
||||
return fused_op;
|
||||
}
|
||||
|
||||
void ModelBuilder::AddScalarOutput(const std::string& output_name) {
|
||||
scalar_outputs_.insert(output_name);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,11 +44,6 @@ class ModelBuilder {
|
|||
Status AddOperandFromPersistMemoryBuffer(
|
||||
const std::string& name, const void* buffer,
|
||||
const size_t size, const std::vector<uint32_t> shape, const int32_t data_type);
|
||||
// Find if an output has a fuseable activation (e.g., Relu).
|
||||
emscripten::val FindActivation(const Node& node, const NodeArg& output);
|
||||
|
||||
const InlinedHashSet<std::string>&
|
||||
GetFusedActivations() const { return fused_activations_; }
|
||||
|
||||
DataLayout GetPreferredLayout() const { return preferred_layout_; }
|
||||
|
||||
|
|
@ -82,22 +77,13 @@ class ModelBuilder {
|
|||
InlinedHashSet<std::string> skipped_initializers_;
|
||||
InlinedHashSet<std::string> skipped_inputs_;
|
||||
|
||||
InlinedHashSet<std::string> fused_activations_;
|
||||
|
||||
InlinedHashSet<std::string> supported_activation_nodes_;
|
||||
|
||||
uint32_t name_token_{0};
|
||||
InlinedHashSet<std::string> unique_names_;
|
||||
|
||||
// All activation nodes (e.g., Relu) as a map <NodeIndex, FusionOperator>.
|
||||
InlinedHashMap<NodeIndex, emscripten::val> activation_nodes_;
|
||||
|
||||
// Convert the onnx model to WebNN operands
|
||||
Status Initialize() ORT_MUST_USE_RESULT;
|
||||
|
||||
void PreprocessInitializers();
|
||||
// Preprocess all the activation nodes (e.g., Relu) for easy query later.
|
||||
void PreprocessActivations();
|
||||
|
||||
// Copy and process all the initializers to WebNN constants.
|
||||
Status RegisterInitializers() ORT_MUST_USE_RESULT;
|
||||
|
|
|
|||
Loading…
Reference in a new issue