From 7cd4b334a9102ff661c7d33e67a4856bc7e34daf Mon Sep 17 00:00:00 2001 From: Rachel Guo <35738743+YUNQIUGUO@users.noreply.github.com> Date: Thu, 2 Mar 2023 09:43:15 -0800 Subject: [PATCH] [CoreML EP] Add Flatten Op and LRN Op support (#14857) ### Description As title. CoreML Spec for reference: https://apple.github.io/coremltools/mlmodel/Format/NeuralNetwork.html#flattento2dlayerparams https://apple.github.io/coremltools/mlmodel/Format/NeuralNetwork.html#lrnlayerparams ### Motivation and Context Fill CoreML Clipchamp usage gaps. --------- Co-authored-by: rachguo --- .../coreml/builders/impl/LRN_op_builder.cc | 94 +++++++++++++++++++ .../builders/impl/flatten_op_builder.cc | 80 ++++++++++++++++ .../coreml/builders/op_builder_factory.cc | 8 ++ .../coreml/builders/op_builder_factory.h | 3 + 4 files changed, 185 insertions(+) create mode 100644 onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc create mode 100644 onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc diff --git a/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc new file mode 100644 index 0000000000..3ace8e1fc3 --- /dev/null +++ b/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/shared/utils/utils.h" +#include "core/providers/coreml/builders/helper.h" +#ifdef __APPLE__ +#include "core/providers/coreml/builders/model_builder.h" +#endif +#include "core/providers/coreml/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace coreml { + +class LRNOpBuilder : public BaseOpBuilder { + // Add operator related +#ifdef __APPLE__ + private: + [[nodiscard]] Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override; +#endif + + // Operator support related + private: + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; +}; + +// Add operator related + +#ifdef __APPLE__ + +Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + std::unique_ptr layer = CreateNNLayer(model_builder, node); + + auto* coreml_lrn = layer->mutable_lrn(); + + NodeAttrHelper helper(node); + const auto alpha = helper.Get("alpha", 0.0001f); + const auto beta = helper.Get("beta", 0.75f); + const auto bias = helper.Get("bias", 1.0f); // k + const auto size = helper.Get("size", 1); // localSize + + coreml_lrn->set_alpha(alpha); + coreml_lrn->set_beta(beta); + coreml_lrn->set_localsize(size); + coreml_lrn->set_k(bias); + + *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + + model_builder.AddLayer(std::move(layer)); + return Status::OK(); +} +#endif + +// Operator support related + +bool LRNOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + if (input_shape.empty()) { + LOGS(logger, VERBOSE) << "LRN does not support empty input shape"; + return false; + } + + // Note: For higher ranks ( > 3), CoreML LRN treats all leading dimensions as the batch, + // which differs from ONNX LRN. Only support the case - input rank equals 3 or 4 here. + // CoreML Spec:https://apple.github.io/coremltools/mlmodel/Format/NeuralNetwork.html#lrnlayerparams + const auto input_rank = input_shape.size(); + if (input_rank != 3 && input_rank != 4) { + LOGS(logger, VERBOSE) << "LRN only supports input rank equals to 3 or 4, input rank is " + << input_rank; + return false; + } + + return true; +} + +void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace coreml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc new file mode 100644 index 0000000000..d15db37fdd --- /dev/null +++ b/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/shared/utils/utils.h" +#include "core/providers/coreml/builders/helper.h" + +#ifdef __APPLE__ +#include "core/providers/coreml/builders/model_builder.h" +#endif +#include "core/providers/coreml/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace coreml { + +class FlattenOpBuilder : public BaseOpBuilder { + // Add operator related +#ifdef __APPLE__ + private: + [[nodiscard]] Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override; +#endif + + // Operator support related + private: + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; +}; + +// Add operator related + +#ifdef __APPLE__ + +Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + std::unique_ptr layer = CreateNNLayer(model_builder, node); + + // Note: ONNX Flatten corresponds to CoreML FlattenTo2DLayerParams + auto* coreml_flatten = layer->mutable_flattento2d(); + + NodeAttrHelper helper(node); + const int64_t axis = helper.Get("axis ", 1); + coreml_flatten->set_axis(axis); + + *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + + model_builder.AddLayer(std::move(layer)); + + return Status::OK(); +} +#endif + +// Operator support related + +bool FlattenOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + if (input_shape.empty()) { + LOGS(logger, VERBOSE) << "Flatten does not support empty input shape"; + return false; + } + + return true; +} + +void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace coreml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc b/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc index d6a488b6f9..d78564fb9b 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc @@ -82,6 +82,14 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateCastOpBuilder("Cast", op_registrations); } + { // Flatten + CreateFlattenOpBuilder("Flatten", op_registrations); + } + + { // LRN + CreateLRNOpBuilder("LRN", op_registrations); + } + return op_registrations; } diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h index 96536ff104..7bc7ccc377 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h @@ -32,5 +32,8 @@ void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_ void CreateSqueezeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateArgMaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); + } // namespace coreml } // namespace onnxruntime