mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
[CoreML EP] Add Flatten Op and LRN Op support (#14857)
### Description <!-- Describe your changes. --> As title. CoreML Spec for reference: https://apple.github.io/coremltools/mlmodel/Format/NeuralNetwork.html#flattento2dlayerparams https://apple.github.io/coremltools/mlmodel/Format/NeuralNetwork.html#lrnlayerparams ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Fill CoreML Clipchamp usage gaps. --------- Co-authored-by: rachguo <rachguo@rachguos-Mini.attlocal.net>
This commit is contained in:
parent
bf35ad2aa3
commit
7cd4b334a9
4 changed files with 185 additions and 0 deletions
|
|
@ -0,0 +1,94 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
#include "core/providers/coreml/builders/helper.h"
|
||||
#ifdef __APPLE__
|
||||
#include "core/providers/coreml/builders/model_builder.h"
|
||||
#endif
|
||||
#include "core/providers/coreml/builders/op_builder_factory.h"
|
||||
|
||||
#include "base_op_builder.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace coreml {
|
||||
|
||||
class LRNOpBuilder : public BaseOpBuilder {
|
||||
// Add operator related
|
||||
#ifdef __APPLE__
|
||||
private:
|
||||
[[nodiscard]] Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override;
|
||||
#endif
|
||||
|
||||
// Operator support related
|
||||
private:
|
||||
bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related
|
||||
|
||||
#ifdef __APPLE__
|
||||
|
||||
Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
||||
const Node& node,
|
||||
const logging::Logger& logger) const {
|
||||
std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
|
||||
|
||||
auto* coreml_lrn = layer->mutable_lrn();
|
||||
|
||||
NodeAttrHelper helper(node);
|
||||
const auto alpha = helper.Get("alpha", 0.0001f);
|
||||
const auto beta = helper.Get("beta", 0.75f);
|
||||
const auto bias = helper.Get("bias", 1.0f); // k
|
||||
const auto size = helper.Get("size", 1); // localSize
|
||||
|
||||
coreml_lrn->set_alpha(alpha);
|
||||
coreml_lrn->set_beta(beta);
|
||||
coreml_lrn->set_localsize(size);
|
||||
coreml_lrn->set_k(bias);
|
||||
|
||||
*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
|
||||
*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
|
||||
|
||||
model_builder.AddLayer(std::move(layer));
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
// Operator support related
|
||||
|
||||
bool LRNOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
|
||||
std::vector<int64_t> input_shape;
|
||||
if (!GetShape(*input_defs[0], input_shape, logger))
|
||||
return false;
|
||||
|
||||
if (input_shape.empty()) {
|
||||
LOGS(logger, VERBOSE) << "LRN does not support empty input shape";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Note: For higher ranks ( > 3), CoreML LRN treats all leading dimensions as the batch,
|
||||
// which differs from ONNX LRN. Only support the case - input rank equals 3 or 4 here.
|
||||
// CoreML Spec:https://apple.github.io/coremltools/mlmodel/Format/NeuralNetwork.html#lrnlayerparams
|
||||
const auto input_rank = input_shape.size();
|
||||
if (input_rank != 3 && input_rank != 4) {
|
||||
LOGS(logger, VERBOSE) << "LRN only supports input rank equals to 3 or 4, input rank is "
|
||||
<< input_rank;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<LRNOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
}
|
||||
|
||||
} // namespace coreml
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
#include "core/providers/coreml/builders/helper.h"
|
||||
|
||||
#ifdef __APPLE__
|
||||
#include "core/providers/coreml/builders/model_builder.h"
|
||||
#endif
|
||||
#include "core/providers/coreml/builders/op_builder_factory.h"
|
||||
|
||||
#include "base_op_builder.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace coreml {
|
||||
|
||||
class FlattenOpBuilder : public BaseOpBuilder {
|
||||
// Add operator related
|
||||
#ifdef __APPLE__
|
||||
private:
|
||||
[[nodiscard]] Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override;
|
||||
#endif
|
||||
|
||||
// Operator support related
|
||||
private:
|
||||
bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related
|
||||
|
||||
#ifdef __APPLE__
|
||||
|
||||
Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
||||
const Node& node,
|
||||
const logging::Logger& logger) const {
|
||||
std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
|
||||
|
||||
// Note: ONNX Flatten corresponds to CoreML FlattenTo2DLayerParams
|
||||
auto* coreml_flatten = layer->mutable_flattento2d();
|
||||
|
||||
NodeAttrHelper helper(node);
|
||||
const int64_t axis = helper.Get("axis ", 1);
|
||||
coreml_flatten->set_axis(axis);
|
||||
|
||||
*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
|
||||
*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
|
||||
|
||||
model_builder.AddLayer(std::move(layer));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
// Operator support related
|
||||
|
||||
bool FlattenOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
|
||||
std::vector<int64_t> input_shape;
|
||||
if (!GetShape(*input_defs[0], input_shape, logger))
|
||||
return false;
|
||||
|
||||
if (input_shape.empty()) {
|
||||
LOGS(logger, VERBOSE) << "Flatten does not support empty input shape";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<FlattenOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
}
|
||||
|
||||
} // namespace coreml
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -82,6 +82,14 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
|
|||
CreateCastOpBuilder("Cast", op_registrations);
|
||||
}
|
||||
|
||||
{ // Flatten
|
||||
CreateFlattenOpBuilder("Flatten", op_registrations);
|
||||
}
|
||||
|
||||
{ // LRN
|
||||
CreateLRNOpBuilder("LRN", op_registrations);
|
||||
}
|
||||
|
||||
return op_registrations;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -32,5 +32,8 @@ void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_
|
|||
void CreateSqueezeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateArgMaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
|
||||
} // namespace coreml
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue