[CoreML EP] Add DepthToSpace op support (#11468)

* initial impl of depthtospace coreml support

* fix build

* address pr comments

* minor update

* minor pr comments

Co-authored-by: rachguo <rachguo@rachguos-Mini.attlocal.net>
Co-authored-by: rachguo <rachguo@rachguos-Mac-mini.local>
This commit is contained in:
Rachel Guo 2022-05-12 13:48:51 -07:00 committed by GitHub
parent a3f05da338
commit 4aef7e3aab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 113 additions and 13 deletions

View file

@ -14,14 +14,13 @@ namespace coreml {
class ArgMaxOpBuilder : public BaseOpBuilder {
// Add operator related
#ifdef __APPLE__
private:
#ifdef __APPLE__
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
#endif
// Operator support related
private:
bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
const logging::Logger& logger) const override;
};

View file

@ -16,13 +16,12 @@ namespace coreml {
class BinaryOpBuilder : public BaseOpBuilder {
// Add operator related
#ifdef __APPLE__
private:
#ifdef __APPLE__
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
#endif
// Operator support related
private:
int GetMinSupportedOpSet(const Node& node) const override;
};

View file

@ -15,13 +15,12 @@ namespace coreml {
class CastOpBuilder : public BaseOpBuilder {
// Add operator related
#ifdef __APPLE__
private:
#ifdef __APPLE__
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
#endif
// Operator support related
private:
bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
const logging::Logger& logger) const override;
};
@ -55,7 +54,7 @@ bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara
const auto& prec_node = node.InputEdgesBegin()->GetNode();
/*Cast node is only aimed for supporting argmax and we are only handling the case where an argmax
/*Cast node is only aimed for supporting argmax and we are only handling the case where an argmax
followed by a cast node. We need to check if the preceding node is an argmax and also if it's a
supported argmax op type.*/
if (prec_node.OpType() != "ArgMax") {

View file

@ -16,14 +16,13 @@ namespace coreml {
class ConcatOpBuilder : public BaseOpBuilder {
// Add operator related
#ifdef __APPLE__
private:
#ifdef __APPLE__
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
#endif
// Operator support related
private:
bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
const logging::Logger& logger) const override;
};

View file

@ -0,0 +1,100 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <core/common/safeint.h>
#include "core/providers/shared/utils/utils.h"
#include "core/providers/coreml/builders/helper.h"
#ifdef __APPLE__
#include "core/providers/coreml/builders/model_builder.h"
#endif
#include "core/providers/coreml/builders/op_builder_factory.h"
#include "base_op_builder.h"
namespace onnxruntime {
namespace coreml {
class DepthToSpaceOpBuilder : public BaseOpBuilder {
// Add operator related
private:
#ifdef __APPLE__
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
#endif
// Operator support related
bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
const logging::Logger& logger) const override;
};
// Add operator related
#ifdef __APPLE__
Status DepthToSpaceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const Node& node,
const logging::Logger& /* logger */) const {
std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
const auto& input_defs = node.InputDefs();
const auto& output_defs = node.OutputDefs();
const auto& input_name = input_defs[0]->Name();
const auto& output_name = output_defs[0]->Name();
uint64_t blocksize = SafeInt<uint64_t>(node.GetAttributes().at("blocksize").i());
auto* coreml_depthtospace = layer->mutable_reorganizedata();
coreml_depthtospace->set_blocksize(blocksize);
coreml_depthtospace->set_mode(CoreML::Specification::ReorganizeDataLayerParams_ReorganizationType::
ReorganizeDataLayerParams_ReorganizationType_DEPTH_TO_SPACE);
*layer->mutable_input()->Add() = input_name;
*layer->mutable_output()->Add() = output_name;
model_builder.AddLayer(std::move(layer));
return Status::OK();
}
#endif
// Operator support related
bool DepthToSpaceOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
std::vector<int64_t> input_shape;
if (!GetShape(*input_defs[0], input_shape, logger)) {
return false;
}
const auto input_size = input_shape.size();
if (input_size != 4) {
LOGS(logger, VERBOSE) << "DepthToSpace only supports 4d shape, input is " << input_size << "d shape.";
}
// CoreML spec ReorganizeDataLayer DEPTH_TO_SPACE mode only accepts input with one batch ([C, H, W]).
if (input_shape[0] != 1) {
LOGS(logger, VERBOSE) << "The batch size of DepthToSpace [" << input_shape[0] << "] is not supported.";
return false;
}
NodeAttrHelper helper(node);
if (node.SinceVersion() >= 11) {
// For now, only DCR mode DepthToSpace is supported
const auto mode = helper.Get("mode", "DCR");
if (mode != "DCR") {
LOGS(logger, VERBOSE) << "The mode: " << mode << "of DepthToSpace is not supported in CoreML EP for now.";
return false;
}
}
return true;
}
void CreateDepthToSpaceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<DepthToSpaceOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
}
} // namespace coreml
} // namespace onnxruntime

View file

@ -17,14 +17,13 @@ namespace coreml {
class PoolOpBuilder : public BaseOpBuilder {
// Add operator related
#ifdef __APPLE__
private:
#ifdef __APPLE__
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
#endif
// Operator support related
private:
bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
const logging::Logger& logger) const override;
};

View file

@ -41,6 +41,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateReshapeOpBuilder("Reshape", op_registrations);
}
{ // DepthToSpace
CreateDepthToSpaceOpBuilder("DepthToSpace", op_registrations);
}
{ // Pool
CreatePoolOpBuilder("GlobalAveragePool", op_registrations);
CreatePoolOpBuilder("GlobalMaxPool", op_registrations);
@ -86,4 +90,4 @@ const std::unordered_map<std::string, const IOpBuilder*>& GetOpBuilders() {
}
} // namespace coreml
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -26,6 +26,7 @@ void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations&
void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateDepthToSpaceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreatePoolOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
@ -34,4 +35,4 @@ void CreateSqueezeOpBuilder(const std::string& op_type, OpBuilderRegistrations&
void CreateArgMaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
} // namespace coreml
} // namespace onnxruntime
} // namespace onnxruntime