mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-24 02:47:54 +00:00
[CoreML EP] Add DepthToSpace op support (#11468)
* initial impl of depthtospace coreml support * fix build * address pr comments * minor update * minor pr comments Co-authored-by: rachguo <rachguo@rachguos-Mini.attlocal.net> Co-authored-by: rachguo <rachguo@rachguos-Mac-mini.local>
This commit is contained in:
parent
a3f05da338
commit
4aef7e3aab
8 changed files with 113 additions and 13 deletions
|
|
@ -14,14 +14,13 @@ namespace coreml {
|
|||
|
||||
class ArgMaxOpBuilder : public BaseOpBuilder {
|
||||
// Add operator related
|
||||
#ifdef __APPLE__
|
||||
private:
|
||||
#ifdef __APPLE__
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
#endif
|
||||
|
||||
// Operator support related
|
||||
private:
|
||||
bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -16,13 +16,12 @@ namespace coreml {
|
|||
|
||||
class BinaryOpBuilder : public BaseOpBuilder {
|
||||
// Add operator related
|
||||
#ifdef __APPLE__
|
||||
private:
|
||||
#ifdef __APPLE__
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
#endif
|
||||
// Operator support related
|
||||
private:
|
||||
int GetMinSupportedOpSet(const Node& node) const override;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -15,13 +15,12 @@ namespace coreml {
|
|||
|
||||
class CastOpBuilder : public BaseOpBuilder {
|
||||
// Add operator related
|
||||
#ifdef __APPLE__
|
||||
private:
|
||||
#ifdef __APPLE__
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
#endif
|
||||
// Operator support related
|
||||
private:
|
||||
bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
|
@ -55,7 +54,7 @@ bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara
|
|||
|
||||
const auto& prec_node = node.InputEdgesBegin()->GetNode();
|
||||
|
||||
/*Cast node is only aimed for supporting argmax and we are only handling the case where an argmax
|
||||
/*Cast node is only aimed for supporting argmax and we are only handling the case where an argmax
|
||||
followed by a cast node. We need to check if the preceding node is an argmax and also if it's a
|
||||
supported argmax op type.*/
|
||||
if (prec_node.OpType() != "ArgMax") {
|
||||
|
|
|
|||
|
|
@ -16,14 +16,13 @@ namespace coreml {
|
|||
|
||||
class ConcatOpBuilder : public BaseOpBuilder {
|
||||
// Add operator related
|
||||
#ifdef __APPLE__
|
||||
private:
|
||||
#ifdef __APPLE__
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
#endif
|
||||
|
||||
// Operator support related
|
||||
private:
|
||||
bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -0,0 +1,100 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <core/common/safeint.h>
|
||||
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
#include "core/providers/coreml/builders/helper.h"
|
||||
#ifdef __APPLE__
|
||||
#include "core/providers/coreml/builders/model_builder.h"
|
||||
#endif
|
||||
#include "core/providers/coreml/builders/op_builder_factory.h"
|
||||
|
||||
#include "base_op_builder.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace coreml {
|
||||
|
||||
class DepthToSpaceOpBuilder : public BaseOpBuilder {
|
||||
// Add operator related
|
||||
private:
|
||||
#ifdef __APPLE__
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
#endif
|
||||
|
||||
// Operator support related
|
||||
bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related
|
||||
|
||||
#ifdef __APPLE__
|
||||
Status DepthToSpaceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
||||
const Node& node,
|
||||
const logging::Logger& /* logger */) const {
|
||||
std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
|
||||
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& output_defs = node.OutputDefs();
|
||||
const auto& input_name = input_defs[0]->Name();
|
||||
const auto& output_name = output_defs[0]->Name();
|
||||
|
||||
uint64_t blocksize = SafeInt<uint64_t>(node.GetAttributes().at("blocksize").i());
|
||||
|
||||
auto* coreml_depthtospace = layer->mutable_reorganizedata();
|
||||
coreml_depthtospace->set_blocksize(blocksize);
|
||||
coreml_depthtospace->set_mode(CoreML::Specification::ReorganizeDataLayerParams_ReorganizationType::
|
||||
ReorganizeDataLayerParams_ReorganizationType_DEPTH_TO_SPACE);
|
||||
|
||||
*layer->mutable_input()->Add() = input_name;
|
||||
*layer->mutable_output()->Add() = output_name;
|
||||
|
||||
model_builder.AddLayer(std::move(layer));
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
// Operator support related
|
||||
|
||||
bool DepthToSpaceOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
|
||||
std::vector<int64_t> input_shape;
|
||||
if (!GetShape(*input_defs[0], input_shape, logger)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto input_size = input_shape.size();
|
||||
if (input_size != 4) {
|
||||
LOGS(logger, VERBOSE) << "DepthToSpace only supports 4d shape, input is " << input_size << "d shape.";
|
||||
}
|
||||
|
||||
// CoreML spec ReorganizeDataLayer DEPTH_TO_SPACE mode only accepts input with one batch ([C, H, W]).
|
||||
if (input_shape[0] != 1) {
|
||||
LOGS(logger, VERBOSE) << "The batch size of DepthToSpace [" << input_shape[0] << "] is not supported.";
|
||||
return false;
|
||||
}
|
||||
|
||||
NodeAttrHelper helper(node);
|
||||
if (node.SinceVersion() >= 11) {
|
||||
// For now, only DCR mode DepthToSpace is supported
|
||||
const auto mode = helper.Get("mode", "DCR");
|
||||
if (mode != "DCR") {
|
||||
LOGS(logger, VERBOSE) << "The mode: " << mode << "of DepthToSpace is not supported in CoreML EP for now.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateDepthToSpaceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<DepthToSpaceOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
}
|
||||
|
||||
} // namespace coreml
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -17,14 +17,13 @@ namespace coreml {
|
|||
|
||||
class PoolOpBuilder : public BaseOpBuilder {
|
||||
// Add operator related
|
||||
#ifdef __APPLE__
|
||||
private:
|
||||
#ifdef __APPLE__
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
#endif
|
||||
|
||||
// Operator support related
|
||||
private:
|
||||
bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -41,6 +41,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
|
|||
CreateReshapeOpBuilder("Reshape", op_registrations);
|
||||
}
|
||||
|
||||
{ // DepthToSpace
|
||||
CreateDepthToSpaceOpBuilder("DepthToSpace", op_registrations);
|
||||
}
|
||||
|
||||
{ // Pool
|
||||
CreatePoolOpBuilder("GlobalAveragePool", op_registrations);
|
||||
CreatePoolOpBuilder("GlobalMaxPool", op_registrations);
|
||||
|
|
@ -86,4 +90,4 @@ const std::unordered_map<std::string, const IOpBuilder*>& GetOpBuilders() {
|
|||
}
|
||||
|
||||
} // namespace coreml
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations&
|
|||
void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateDepthToSpaceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
|
||||
void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreatePoolOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
|
|
@ -34,4 +35,4 @@ void CreateSqueezeOpBuilder(const std::string& op_type, OpBuilderRegistrations&
|
|||
void CreateArgMaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
} // namespace coreml
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue