[WebNN EP] Support several activation ops (#16693)

Support Elu, HardSigmoid, HardSwish, Softplus, Softsign, Tanh.
Wanming Lin 2023-07-15 05:36:15 +08:00 committed by GitHub
parent a189e76fde
commit ea43671eb6
4 changed files with 112 additions and 69 deletions
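For reference, each of these activations is lowered to the corresponding WebNN MLGraphBuilder method through Embind, with ONNX attributes forwarded as an options object. The following is a minimal sketch of that pattern for Elu only; the names BuildElu, builder, and input are illustrative and not part of this commit, and the 1.0f default mirrors the ONNX Elu "alpha" default used in the diff below.

#include <emscripten/val.h>

// Sketch: builder is assumed to hold a WebNN MLGraphBuilder and input an
// MLOperand, both as emscripten::val, as in the EP code in this commit.
emscripten::val BuildElu(const emscripten::val& builder,
                         const emscripten::val& input,
                         float alpha = 1.0f) {
  emscripten::val options = emscripten::val::object();
  options.set("alpha", alpha);
  // Equivalent to builder.elu(input, {alpha}) on the JavaScript side.
  return builder.call<emscripten::val>("elu", input, options);
}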

@@ -93,55 +93,61 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
const logging::Logger& logger);
static const InlinedHashMap<std::string, std::string> op_map = {
{"Abs", "abs"},
{"Add", "add"},
{"ArgMax", "argMax"},
{"ArgMin", "argMin"},
{"Add", "add"},
{"Sub", "sub"},
{"Mul", "mul"},
{"Div", "div"},
{"Pow", "pow"},
{"AveragePool", "averagePool2d"},
{"Cast", "cast"},
{"Ceil", "ceil"},
{"Clip", "clamp"},
{"Concat", "concat"},
{"Conv", "conv2d"},
{"ConvTranspose", "convTranspose2d"},
{"Cos", "cos"},
{"Div", "div"},
{"Elu", "elu"},
{"Equal", "equal"},
{"Erf", "erf"},
{"Exp", "exp"},
{"Neg", "neg"},
{"Not", "logicalNot"},
{"Floor", "floor"},
{"Flatten", "flattenTo2d"},
{"Identity", "identity"},
{"Reciprocal", "reciprocal"},
{"Sin", "sin"},
{"Sqrt", "sqrt"},
{"Tan", "tan"},
{"Relu", "relu"},
{"LeakyRelu", "leakyRelu"},
{"Sigmoid", "sigmoid"},
{"Slice", "slice"},
{"Softmax", "softmax"},
{"Cast", "cast"},
{"Clip", "clamp"},
{"Conv", "conv2d"},
{"ConvTranspose", "convTranspose2d"},
{"Concat", "concat"},
{"Expand", "expand"},
{"Flatten", "flattenTo2d"},
{"Floor", "floor"},
{"Gather", "gather"},
{"Gemm", "gemm"},
{"MatMul", "matmul"},
{"GlobalAveragePool", "averagePool2d"},
{"GlobalMaxPool", "maxPool2d"},
{"AveragePool", "averagePool2d"},
{"GroupNormalization", "meanVarianceNormalization"},
{"HardSigmoid", "hardSigmoid"},
{"HardSwish", "hardSwish"},
{"Identity", "identity"},
{"InstanceNormalization", "meanVarianceNormalization"},
{"LayerNormalization", "meanVarianceNormalization"},
{"LeakyRelu", "leakyRelu"},
{"MatMul", "matmul"},
{"MaxPool", "maxPool2d"},
{"Mul", "mul"},
{"Neg", "neg"},
{"Not", "logicalNot"},
{"Pow", "pow"},
{"Reciprocal", "reciprocal"},
{"ReduceMax", "reduceMax"},
{"ReduceMean", "reduceMean"},
{"Relu", "relu"},
{"Reshape", "reshape"},
{"Resize", "resample2d"},
{"Shape", "slice"},
{"Sigmoid", "sigmoid"},
{"Softplus", "softplus"},
{"Softsign", "softsign"},
{"Sin", "sin"},
{"Slice", "slice"},
{"Softmax", "softmax"},
{"Split", "split"},
{"Sqrt", "sqrt"},
{"Squeeze", "squeeze"},
{"Sub", "sub"},
{"Tan", "tan"},
{"Tanh", "tanh"},
{"Transpose", "transpose"},
{"Unsqueeze", "unsqueeze"},
{"Where", "elementwiseIf"},

@@ -31,34 +31,40 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const auto& op_type(node.OpType());
emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name());
emscripten::val output = emscripten::val::object();
if (op_type == "Relu") {
if (Contains(model_builder.GetFusedActivations(), node.InputDefs()[0]->Name())) {
LOGS_DEFAULT(VERBOSE) << "Relu Node [" << node.Name() << "] fused";
output = input;
} else {
output = model_builder.GetBuilder().call<emscripten::val>("relu", input);
}
} else if (op_type == "LeakyRelu") {
if (Contains(model_builder.GetFusedActivations(), node.InputDefs()[0]->Name())) {
LOGS_DEFAULT(VERBOSE) << "LeakyRelu Node [" << node.Name() << "] fused";
LOGS_DEFAULT(VERBOSE) << op_type << " Node [" << node.Name() << "] fused";
output = input;
} else {
NodeAttrHelper helper(node);
emscripten::val options = emscripten::val::object();
options.set("alpha", helper.Get("alpha", (float)0.0));
if (op_type == "Elu") {
options.set("alpha", helper.Get("alpha", 1.0f));
output = model_builder.GetBuilder().call<emscripten::val>("elu", input, options);
} else if (op_type == "HardSigmoid") {
options.set("alpha", helper.Get("alpha", 0.2f));
options.set("beta", helper.Get("beta", 0.5f));
output = model_builder.GetBuilder().call<emscripten::val>("hardSigmoid", input, options);
} else if (op_type == "HardSwish") {
output = model_builder.GetBuilder().call<emscripten::val>("hardSwish", input);
} else if (op_type == "LeakyRelu") {
options.set("alpha", helper.Get("alpha", 0.0f));
output = model_builder.GetBuilder().call<emscripten::val>("leakyRelu", input, options);
}
} else if (op_type == "Relu") {
output = model_builder.GetBuilder().call<emscripten::val>("relu", input);
} else if (op_type == "Sigmoid") {
if (Contains(model_builder.GetFusedActivations(), node.InputDefs()[0]->Name())) {
LOGS_DEFAULT(VERBOSE) << "Sigmoid Node [" << node.Name() << "] fused";
output = input;
} else {
output = model_builder.GetBuilder().call<emscripten::val>("sigmoid", input);
}
} else if (op_type == "Softplus") {
output = model_builder.GetBuilder().call<emscripten::val>("softplus", input);
} else if (op_type == "Softsign") {
output = model_builder.GetBuilder().call<emscripten::val>("softsign", input);
} else if (op_type == "Tanh") {
output = model_builder.GetBuilder().call<emscripten::val>("tanh", input);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
}
}
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
return Status::OK();
@@ -67,7 +73,7 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
// Operator support related.
int ActivationOpBuilder::GetMinSupportedOpSet(const Node& /* node */) const {
// Operators below opset 6 use the deprecated "consumed_inputs" attribute, which is not supported.
return 6;
}
@@ -75,7 +81,18 @@ void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistration
if (op_registrations.op_builder_map.find(op_type) != op_registrations.op_builder_map.cend())
return;
static std::vector<std::string> op_types = {"Relu", "LeakyRelu", "Sigmoid"};
static std::vector<std::string> op_types =
{
"Elu",
"HardSigmoid",
"HardSwish",
"LeakyRelu",
"Relu",
"Sigmoid",
"Softplus",
"Softsign",
"Tanh",
};
op_registrations.builders.push_back(std::make_unique<ActivationOpBuilder>());
for (const auto& type : op_types) {

@@ -57,33 +57,47 @@ void ModelBuilder::PreprocessInitializers() {
}
}
void ModelBuilder::PreprocessActivations() {
  const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder();
  for (size_t i = 0; i < node_indices.size(); i++) {
    const auto* node(graph_viewer_.GetNode(node_indices[i]));
    const auto& op_type(node->OpType());

    if (op_type == "Clip") {
      float minValue, maxValue;
      GetClipMinMax(GetInitializerTensors(), *node, minValue, maxValue, logger_);
      emscripten::val options = emscripten::val::object();
      options.set("minValue", minValue);
      options.set("maxValue", maxValue);
      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("clamp", options));
    } else if (op_type == "Elu") {
      NodeAttrHelper helper(*node);
      emscripten::val options = emscripten::val::object();
      options.set("alpha", helper.Get("alpha", 1.0f));
      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("elu", options));
    } else if (op_type == "HardSigmoid") {
      NodeAttrHelper helper(*node);
      emscripten::val options = emscripten::val::object();
      options.set("alpha", helper.Get("alpha", 0.2f));
      options.set("beta", helper.Get("beta", 0.5f));
      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("hardSigmoid", options));
    } else if (op_type == "HardSwish") {
      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("hardSwish"));
    } else if (op_type == "Relu") {
      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("relu"));
    } else if (op_type == "LeakyRelu") {
      NodeAttrHelper helper(*node);
      emscripten::val options = emscripten::val::object();
      options.set("alpha", helper.Get("alpha", 0.0f));
      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("leakyRelu", options));
    } else if (op_type == "Sigmoid") {
      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("sigmoid"));
    } else if (op_type == "Softplus") {
      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("softplus"));
    } else if (op_type == "Softsign") {
      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("softsign"));
    } else if (op_type == "Tanh") {
      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("tanh"));
    }
  }
}

@@ -44,9 +44,15 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
}
{ // Activations
CreateActivationOpBuilder("Relu", op_registrations);
CreateActivationOpBuilder("Elu", op_registrations);
CreateActivationOpBuilder("HardSigmoid", op_registrations);
CreateActivationOpBuilder("HardSwish", op_registrations);
CreateActivationOpBuilder("LeakyRelu", op_registrations);
CreateActivationOpBuilder("Relu", op_registrations);
CreateActivationOpBuilder("Sigmoid", op_registrations);
CreateActivationOpBuilder("Softplus", op_registrations);
CreateActivationOpBuilder("Softsign", op_registrations);
CreateActivationOpBuilder("Tanh", op_registrations);
}
{ // ArgMax/ArgMin