mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[WebNN EP] Support several activation ops (#16693)
Support Elu, HardSigmoid, HardSwish, Softplus, Softsign, Tanh.
This commit is contained in:
parent
a189e76fde
commit
ea43671eb6
4 changed files with 112 additions and 69 deletions
|
|
@ -93,55 +93,61 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
|
|||
const logging::Logger& logger);
|
||||
static const InlinedHashMap<std::string, std::string> op_map = {
|
||||
{"Abs", "abs"},
|
||||
{"Add", "add"},
|
||||
{"ArgMax", "argMax"},
|
||||
{"ArgMin", "argMin"},
|
||||
{"Add", "add"},
|
||||
{"Sub", "sub"},
|
||||
{"Mul", "mul"},
|
||||
{"Div", "div"},
|
||||
{"Pow", "pow"},
|
||||
{"AveragePool", "averagePool2d"},
|
||||
{"Cast", "cast"},
|
||||
{"Ceil", "ceil"},
|
||||
{"Clip", "clamp"},
|
||||
{"Concat", "concat"},
|
||||
{"Conv", "conv2d"},
|
||||
{"ConvTranspose", "convTranspose2d"},
|
||||
{"Cos", "cos"},
|
||||
{"Div", "div"},
|
||||
{"Elu", "elu"},
|
||||
{"Equal", "equal"},
|
||||
{"Erf", "erf"},
|
||||
{"Exp", "exp"},
|
||||
{"Neg", "neg"},
|
||||
{"Not", "logicalNot"},
|
||||
{"Floor", "floor"},
|
||||
{"Flatten", "flattenTo2d"},
|
||||
{"Identity", "identity"},
|
||||
{"Reciprocal", "reciprocal"},
|
||||
{"Sin", "sin"},
|
||||
{"Sqrt", "sqrt"},
|
||||
{"Tan", "tan"},
|
||||
{"Relu", "relu"},
|
||||
{"LeakyRelu", "leakyRelu"},
|
||||
{"Sigmoid", "sigmoid"},
|
||||
{"Slice", "slice"},
|
||||
{"Softmax", "softmax"},
|
||||
{"Cast", "cast"},
|
||||
{"Clip", "clamp"},
|
||||
{"Conv", "conv2d"},
|
||||
{"ConvTranspose", "convTranspose2d"},
|
||||
{"Concat", "concat"},
|
||||
{"Expand", "expand"},
|
||||
{"Flatten", "flattenTo2d"},
|
||||
{"Floor", "floor"},
|
||||
{"Gather", "gather"},
|
||||
{"Gemm", "gemm"},
|
||||
{"MatMul", "matmul"},
|
||||
{"GlobalAveragePool", "averagePool2d"},
|
||||
{"GlobalMaxPool", "maxPool2d"},
|
||||
{"AveragePool", "averagePool2d"},
|
||||
{"GroupNormalization", "meanVarianceNormalization"},
|
||||
{"HardSigmoid", "hardSigmoid"},
|
||||
{"HardSwish", "hardSwish"},
|
||||
{"Identity", "identity"},
|
||||
{"InstanceNormalization", "meanVarianceNormalization"},
|
||||
{"LayerNormalization", "meanVarianceNormalization"},
|
||||
{"LeakyRelu", "leakyRelu"},
|
||||
{"MatMul", "matmul"},
|
||||
{"MaxPool", "maxPool2d"},
|
||||
{"Mul", "mul"},
|
||||
{"Neg", "neg"},
|
||||
{"Not", "logicalNot"},
|
||||
{"Pow", "pow"},
|
||||
{"Reciprocal", "reciprocal"},
|
||||
{"ReduceMax", "reduceMax"},
|
||||
{"ReduceMean", "reduceMean"},
|
||||
{"Relu", "relu"},
|
||||
{"Reshape", "reshape"},
|
||||
{"Resize", "resample2d"},
|
||||
{"Shape", "slice"},
|
||||
{"Sigmoid", "sigmoid"},
|
||||
{"Softplus", "softplus"},
|
||||
{"Softsign", "softsign"},
|
||||
{"Sin", "sin"},
|
||||
{"Slice", "slice"},
|
||||
{"Softmax", "softmax"},
|
||||
{"Split", "split"},
|
||||
{"Sqrt", "sqrt"},
|
||||
{"Squeeze", "squeeze"},
|
||||
{"Sub", "sub"},
|
||||
{"Tan", "tan"},
|
||||
{"Tanh", "tanh"},
|
||||
{"Transpose", "transpose"},
|
||||
{"Unsqueeze", "unsqueeze"},
|
||||
{"Where", "elementwiseIf"},
|
||||
|
|
|
|||
|
|
@ -31,33 +31,39 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
const auto& op_type(node.OpType());
|
||||
emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name());
|
||||
emscripten::val output = emscripten::val::object();
|
||||
if (op_type == "Relu") {
|
||||
if (Contains(model_builder.GetFusedActivations(), node.InputDefs()[0]->Name())) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Relu Node [" << node.Name() << "] fused";
|
||||
output = input;
|
||||
} else {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("relu", input);
|
||||
}
|
||||
} else if (op_type == "LeakyRelu") {
|
||||
if (Contains(model_builder.GetFusedActivations(), node.InputDefs()[0]->Name())) {
|
||||
LOGS_DEFAULT(VERBOSE) << "LeakyRelu Node [" << node.Name() << "] fused";
|
||||
output = input;
|
||||
} else {
|
||||
NodeAttrHelper helper(node);
|
||||
emscripten::val options = emscripten::val::object();
|
||||
options.set("alpha", helper.Get("alpha", (float)0.0));
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("leakyRelu", input, options);
|
||||
}
|
||||
} else if (op_type == "Sigmoid") {
|
||||
if (Contains(model_builder.GetFusedActivations(), node.InputDefs()[0]->Name())) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Sigmoid Node [" << node.Name() << "] fused";
|
||||
output = input;
|
||||
} else {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("sigmoid", input);
|
||||
}
|
||||
|
||||
if (Contains(model_builder.GetFusedActivations(), node.InputDefs()[0]->Name())) {
|
||||
LOGS_DEFAULT(VERBOSE) << op_type << " Node [" << node.Name() << "] fused";
|
||||
output = input;
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
|
||||
NodeAttrHelper helper(node);
|
||||
emscripten::val options = emscripten::val::object();
|
||||
if (op_type == "Elu") {
|
||||
options.set("alpha", helper.Get("alpha", 1.0f));
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("elu", input, options);
|
||||
} else if (op_type == "HardSigmoid") {
|
||||
options.set("alpha", helper.Get("alpha", 0.2f));
|
||||
options.set("beta", helper.Get("beta", 0.5f));
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("hardSigmoid", input, options);
|
||||
} else if (op_type == "HardSwish") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("hardSwish", input);
|
||||
} else if (op_type == "LeakyRelu") {
|
||||
options.set("alpha", helper.Get("alpha", 0.0f));
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("leakyRelu", input, options);
|
||||
} else if (op_type == "Relu") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("relu", input);
|
||||
} else if (op_type == "Sigmoid") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("sigmoid", input);
|
||||
} else if (op_type == "Softplus") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("softplus", input);
|
||||
} else if (op_type == "Softsign") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("softsign", input);
|
||||
} else if (op_type == "Tanh") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("tanh", input);
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
|
||||
}
|
||||
}
|
||||
|
||||
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
|
||||
|
|
@ -67,7 +73,7 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
// Operator support related.
|
||||
|
||||
int ActivationOpBuilder::GetMinSupportedOpSet(const Node& /* node */) const {
|
||||
// All ops opset 5- uses consumed_inputs attribute which is not supported for now.
|
||||
// Any operators < opset 6 used the deprecated "consumed_inputs attribute" which is unsupported.
|
||||
return 6;
|
||||
}
|
||||
|
||||
|
|
@ -75,7 +81,18 @@ void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistration
|
|||
if (op_registrations.op_builder_map.find(op_type) != op_registrations.op_builder_map.cend())
|
||||
return;
|
||||
|
||||
static std::vector<std::string> op_types = {"Relu", "LeakyRelu", "Sigmoid"};
|
||||
static std::vector<std::string> op_types =
|
||||
{
|
||||
"Elu",
|
||||
"HardSigmoid",
|
||||
"HardSwish",
|
||||
"LeakyRelu",
|
||||
"Relu",
|
||||
"Sigmoid",
|
||||
"Softplus",
|
||||
"Softsign",
|
||||
"Tanh",
|
||||
};
|
||||
|
||||
op_registrations.builders.push_back(std::make_unique<ActivationOpBuilder>());
|
||||
for (const auto& type : op_types) {
|
||||
|
|
|
|||
|
|
@ -57,33 +57,47 @@ void ModelBuilder::PreprocessInitializers() {
|
|||
}
|
||||
}
|
||||
|
||||
emscripten::val GetClampOperator(
|
||||
const emscripten::val& builder, float min_value, float max_value) {
|
||||
emscripten::val options = emscripten::val::object();
|
||||
options.set("minValue", min_value);
|
||||
options.set("maxValue", max_value);
|
||||
return builder.call<emscripten::val>("clamp", options);
|
||||
}
|
||||
|
||||
void ModelBuilder::PreprocessActivations() {
|
||||
const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder();
|
||||
for (size_t i = 0; i < node_indices.size(); i++) {
|
||||
const auto* node(graph_viewer_.GetNode(node_indices[i]));
|
||||
const auto& op_type(node->OpType());
|
||||
|
||||
if (op_type == "Relu") {
|
||||
if (op_type == "Clip") {
|
||||
float minValue, maxValue;
|
||||
GetClipMinMax(GetInitializerTensors(), *node, minValue, maxValue, logger_);
|
||||
emscripten::val options = emscripten::val::object();
|
||||
options.set("minValue", minValue);
|
||||
options.set("maxValue", maxValue);
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("clamp", options));
|
||||
} else if (op_type == "Elu") {
|
||||
NodeAttrHelper helper(*node);
|
||||
emscripten::val options = emscripten::val::object();
|
||||
options.set("alpha", helper.Get("alpha", 1.0f));
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("elu", options));
|
||||
} else if (op_type == "HardSigmoid") {
|
||||
NodeAttrHelper helper(*node);
|
||||
emscripten::val options = emscripten::val::object();
|
||||
options.set("alpha", helper.Get("alpha", 0.2f));
|
||||
options.set("beta", helper.Get("beta", 0.5f));
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("hardSigmoid", options));
|
||||
} else if (op_type == "HardSwish") {
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("hardSwish"));
|
||||
} else if (op_type == "Relu") {
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("relu"));
|
||||
} else if (op_type == "LeakyRelu") {
|
||||
NodeAttrHelper helper(*node);
|
||||
emscripten::val options = emscripten::val::object();
|
||||
options.set("alpha", helper.Get("alpha", (float)0.0));
|
||||
options.set("alpha", helper.Get("alpha", 0.0f));
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("leakyRelu", options));
|
||||
} else if (op_type == "Sigmoid") {
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("sigmoid"));
|
||||
} else if (op_type == "Clip") {
|
||||
float minValue, maxValue;
|
||||
GetClipMinMax(GetInitializerTensors(), *node, minValue, maxValue, logger_);
|
||||
activation_nodes_.emplace(node->Index(), GetClampOperator(wnn_builder_, minValue, maxValue));
|
||||
} else if (op_type == "Softplus") {
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("softplus"));
|
||||
} else if (op_type == "Softsign") {
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("softsign"));
|
||||
} else if (op_type == "Tanh") {
|
||||
activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("tanh"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,9 +44,15 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
|
|||
}
|
||||
|
||||
{ // Activations
|
||||
CreateActivationOpBuilder("Relu", op_registrations);
|
||||
CreateActivationOpBuilder("Elu", op_registrations);
|
||||
CreateActivationOpBuilder("HardSigmoid", op_registrations);
|
||||
CreateActivationOpBuilder("HardSwish", op_registrations);
|
||||
CreateActivationOpBuilder("LeakyRelu", op_registrations);
|
||||
CreateActivationOpBuilder("Relu", op_registrations);
|
||||
CreateActivationOpBuilder("Sigmoid", op_registrations);
|
||||
CreateActivationOpBuilder("Softplus", op_registrations);
|
||||
CreateActivationOpBuilder("Softsign", op_registrations);
|
||||
CreateActivationOpBuilder("Tanh", op_registrations);
|
||||
}
|
||||
|
||||
{ // ArgMax/ArgMin
|
||||
|
|
|
|||
Loading…
Reference in a new issue