Add ML Program support for basic activation ops (#21326)

### Description

Add support for:

- Sigmoid
- Relu
- Tanh

### Motivation and Context
Enable support for Autodesk model
This commit is contained in:
vraspar 2024-07-15 22:30:20 -07:00 committed by GitHub
parent 4005d12ed4
commit 218301403d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 68 additions and 25 deletions

View file

@ -26,6 +26,8 @@ class ActivationOpBuilder : public BaseOpBuilder {
const logging::Logger& logger) const override;
int GetMinSupportedOpSet(const Node& node) const override;
bool SupportsMLProgram() const override { return true; }
};
void ActivationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
@ -74,33 +76,61 @@ Status AddPReluWeight(ModelBuilder& model_builder, const Node& node,
Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const Node& node,
const logging::Logger& logger) const {
std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
const auto& op_type(node.OpType());
if (op_type == "Sigmoid") {
layer->mutable_activation()->mutable_sigmoid();
} else if (op_type == "Tanh") {
layer->mutable_activation()->mutable_tanh();
} else if (op_type == "Relu") {
layer->mutable_activation()->mutable_relu();
} else if (op_type == "PRelu") {
auto* prelu = layer->mutable_activation()->mutable_prelu();
ORT_RETURN_IF_ERROR(AddPReluWeight(model_builder, node, logger, *prelu));
} else if (op_type == "LeakyRelu") {
NodeAttrHelper helper(node);
const auto alpha = helper.Get("alpha", 0.01f);
auto* leaky_relu = layer->mutable_activation()->mutable_leakyrelu();
leaky_relu->set_alpha(alpha);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
#if defined(COREML_ENABLE_MLPROGRAM)
if (model_builder.CreateMLProgram()) {
using namespace CoreML::Specification::MILSpec;
// https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.activation
std::string_view coreml_op_type;
if (op_type == "Sigmoid") {
coreml_op_type = "sigmoid";
} else if (op_type == "Tanh") {
coreml_op_type = "tanh";
} else if (op_type == "Relu") {
coreml_op_type = "relu";
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
}
std::unique_ptr<Operation> op = model_builder.CreateOperation(node, coreml_op_type);
AddOperationInput(*op, "x", node.InputDefs()[0]->Name());
AddOperationOutput(*op, *node.OutputDefs()[0]);
model_builder.AddOperation(std::move(op));
} else
#endif // (COREML_ENABLE_MLPROGRAM)
{
std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
if (op_type == "Sigmoid") {
layer->mutable_activation()->mutable_sigmoid();
} else if (op_type == "Tanh") {
layer->mutable_activation()->mutable_tanh();
} else if (op_type == "Relu") {
layer->mutable_activation()->mutable_relu();
} else if (op_type == "PRelu") {
auto* prelu = layer->mutable_activation()->mutable_prelu();
ORT_RETURN_IF_ERROR(AddPReluWeight(model_builder, node, logger, *prelu));
} else if (op_type == "LeakyRelu") {
NodeAttrHelper helper(node);
const auto alpha = helper.Get("alpha", 0.01f);
auto* leaky_relu = layer->mutable_activation()->mutable_leakyrelu();
leaky_relu->set_alpha(alpha);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
}
*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
model_builder.AddLayer(std::move(layer));
}
*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
model_builder.AddLayer(std::move(layer));
return Status::OK();
}
@ -165,9 +195,20 @@ bool IsPReluOpSupported(const Node& node, const OpBuilderInputParams& input_para
bool ActivationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
const logging::Logger& logger) const {
const auto& op_type = node.OpType();
if (op_type == "PRelu") {
return IsPReluOpSupported(node, input_params, logger);
#if defined(COREML_ENABLE_MLPROGRAM)
if (input_params.create_mlprogram) {
if (op_type == "PRelu" || op_type == "LeakyRelu") {
return false;
}
} else
#endif // (COREML_ENABLE_MLPROGRAM)
{
if (op_type == "PRelu") {
return IsPReluOpSupported(node, input_params, logger);
}
}
return true;
}

View file

@ -18,3 +18,5 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution
|ai.onnx:Relu||
|ai.onnx:Reshape||
|ai.onnx:Sub||
|ai.onnx:Sigmoid||
|ai:onnx:Tanh||