diff --git a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc index c8769cf8e8..cd3dfddccc 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc @@ -14,14 +14,13 @@ namespace coreml { class ArgMaxOpBuilder : public BaseOpBuilder { // Add operator related -#ifdef __APPLE__ private: +#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; #endif // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; diff --git a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc index 894039936a..ed6dee8e16 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc @@ -16,13 +16,12 @@ namespace coreml { class BinaryOpBuilder : public BaseOpBuilder { // Add operator related -#ifdef __APPLE__ private: +#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; #endif // Operator support related - private: int GetMinSupportedOpSet(const Node& node) const override; }; diff --git a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc index 7475db5dfd..35f3b22ab5 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc @@ -15,13 +15,12 @@ namespace coreml { class CastOpBuilder : public BaseOpBuilder { // Add operator related -#ifdef __APPLE__ private: +#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; #endif // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; @@ -55,7 +54,7 @@ bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara const auto& prec_node = node.InputEdgesBegin()->GetNode(); - /*Cast node is only aimed for supporting argmax and we are only handling the case where an argmax + /*Cast node is only aimed for supporting argmax and we are only handling the case where an argmax followed by a cast node. We need to check if the preceding node is an argmax and also if it's a supported argmax op type.*/ if (prec_node.OpType() != "ArgMax") { diff --git a/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc index 3992a25a95..c5fb4a0dcf 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc @@ -16,14 +16,13 @@ namespace coreml { class ConcatOpBuilder : public BaseOpBuilder { // Add operator related -#ifdef __APPLE__ private: +#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; #endif // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; diff --git a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc new file mode 100644 index 0000000000..96b2daa310 --- /dev/null +++ b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/providers/shared/utils/utils.h" +#include "core/providers/coreml/builders/helper.h" +#ifdef __APPLE__ +#include "core/providers/coreml/builders/model_builder.h" +#endif +#include "core/providers/coreml/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace coreml { + +class DepthToSpaceOpBuilder : public BaseOpBuilder { + // Add operator related + private: +#ifdef __APPLE__ + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; +#endif + + // Operator support related + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; +}; + +// Add operator related + +#ifdef __APPLE__ +Status DepthToSpaceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& /* logger */) const { + std::unique_ptr layer = CreateNNLayer(model_builder, node); + + const auto& input_defs = node.InputDefs(); + const auto& output_defs = node.OutputDefs(); + const auto& input_name = input_defs[0]->Name(); + const auto& output_name = output_defs[0]->Name(); + + uint64_t blocksize = SafeInt(node.GetAttributes().at("blocksize").i()); + + auto* coreml_depthtospace = layer->mutable_reorganizedata(); + coreml_depthtospace->set_blocksize(blocksize); + coreml_depthtospace->set_mode(CoreML::Specification::ReorganizeDataLayerParams_ReorganizationType:: + ReorganizeDataLayerParams_ReorganizationType_DEPTH_TO_SPACE); + + *layer->mutable_input()->Add() = input_name; + *layer->mutable_output()->Add() = output_name; + + model_builder.AddLayer(std::move(layer)); + return Status::OK(); +} +#endif + +// Operator support related + +bool DepthToSpaceOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + return false; + } + + const auto input_size = input_shape.size(); + if (input_size != 4) { + LOGS(logger, VERBOSE) << "DepthToSpace only supports 4d shape, input is " << input_size << "d shape."; + } + + // CoreML spec ReorganizeDataLayer DEPTH_TO_SPACE mode only accepts input with one batch ([C, H, W]). + if (input_shape[0] != 1) { + LOGS(logger, VERBOSE) << "The batch size of DepthToSpace [" << input_shape[0] << "] is not supported."; + return false; + } + + NodeAttrHelper helper(node); + if (node.SinceVersion() >= 11) { + // For now, only DCR mode DepthToSpace is supported + const auto mode = helper.Get("mode", "DCR"); + if (mode != "DCR") { + LOGS(logger, VERBOSE) << "The mode: " << mode << "of DepthToSpace is not supported in CoreML EP for now."; + return false; + } + } + + return true; +} + +void CreateDepthToSpaceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace coreml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc index cc8b0f1083..17dc90b56e 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc @@ -17,14 +17,13 @@ namespace coreml { class PoolOpBuilder : public BaseOpBuilder { // Add operator related -#ifdef __APPLE__ private: +#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; #endif // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc b/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc index 8592433494..2728fac279 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc @@ -41,6 +41,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateReshapeOpBuilder("Reshape", op_registrations); } + { // DepthToSpace + CreateDepthToSpaceOpBuilder("DepthToSpace", op_registrations); + } + { // Pool CreatePoolOpBuilder("GlobalAveragePool", op_registrations); CreatePoolOpBuilder("GlobalMaxPool", op_registrations); @@ -86,4 +90,4 @@ const std::unordered_map& GetOpBuilders() { } } // namespace coreml -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h index 78ebab64b0..72c5ce7b27 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h @@ -26,6 +26,7 @@ void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateDepthToSpaceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreatePoolOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); @@ -34,4 +35,4 @@ void CreateSqueezeOpBuilder(const std::string& op_type, OpBuilderRegistrations& void CreateArgMaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); } // namespace coreml -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime