diff --git a/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc new file mode 100644 index 0000000000..1ba4c4d5a1 --- /dev/null +++ b/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace coreml { + +class ConcatOpBuilder : public BaseOpBuilder { + // Add operator related + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related + private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const override; +}; + +// Add operator related + +Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + std::unique_ptr layer = CreateNNLayer(node); + + layer->mutable_concat()->set_sequenceconcat(false); + + for (const auto* input : node.InputDefs()) { + LOGS(logger, VERBOSE) << "input name " << input->Name(); + *layer->mutable_input()->Add() = input->Name(); + } + + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + + model_builder.AddLayer(std::move(layer)); + return Status::OK(); +} + +// Operator support related +bool ConcatOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + if (input_defs.size() < 2) { + LOGS(logger, VERBOSE) << "Concat only support 2+ inputs, actual number of input" << input_defs.size(); + return false; + } + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + auto rank = input_shape.size(); + if (rank != 4) { + // For some reason, the concat in CoreML running on 3d tensor will concat on wrong axis + // Instead of concat on axis 0, it will concat on axis 1 + // Disable Concat support for 3d tensor for now + // TODO, add ExpandDims and Squeeze, 3d -ExpandDims-> 4d -> Concat -Squeeze-> 3d + LOGS(logger, VERBOSE) << "Concat only support 4d shape for now, input is " + << rank << "d shape"; + return false; + } + + NodeAttrHelper helper(node); + auto axis = static_cast(HandleNegativeAxis(helper.Get("axis", 1), rank)); + if (rank != axis + 3) { + LOGS(logger, VERBOSE) << "Concat only support axis to be -3, actual axis: " << axis + << ", actual rank: " << rank; + return false; + } + + return true; +} + +void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(onnxruntime::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace coreml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc b/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc index 237c8825a7..7610c8ac19 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc @@ -47,6 +47,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreatePoolOpBuilder("GlobalMaxPool", op_registrations); } + { // Concat + CreateConcatOpBuilder("Concat", op_registrations); + } + return op_registrations; } diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h index 5540c7cb94..b9485c3927 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h @@ -23,6 +23,7 @@ void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateBatchNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreatePoolOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); diff --git a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc index bcece05a01..435049d94c 100644 --- a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc @@ -257,5 +257,101 @@ TEST(ConcatOpTest, Concat3D_3) { test.Run(); } +TEST(ConcatOpTest, Concat4D_1) { + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{1}); + + std::vector dims{1, 1, 3, 3}; + test.AddInput("input1", dims, + {111.0f, 112.0f, 113.0f, + 121.0f, 122.0f, 123.0f, + 131.0f, 132.0f, 133.0f}); + test.AddInput("input2", dims, + {211.0f, 212.0f, 213.0f, + 221.0f, 222.0f, 223.0f, + 231.0f, 232.0f, 233.0f}); + test.AddInput("input3", dims, + {311.0f, 312.0f, 313.0f, + 321.0f, 322.0f, 323.0f, + 331.0f, 332.0f, 333.0f}); + test.AddOutput("concat_result", {1, 3, 3, 3}, + {111.0f, 112.0f, 113.0f, + 121.0f, 122.0f, 123.0f, + 131.0f, 132.0f, 133.0f, + + 211.0f, 212.0f, 213.0f, + 221.0f, 222.0f, 223.0f, + 231.0f, 232.0f, 233.0f, + + 311.0f, 312.0f, 313.0f, + 321.0f, 322.0f, 323.0f, + 331.0f, 332.0f, 333.0f}); + test.Run(); +} + +TEST(ConcatOpTest, Concat4D_1_negative_axis) { + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{-3}); + + std::vector dims{1, 1, 3, 3}; + test.AddInput("input1", dims, + {111.0f, 112.0f, 113.0f, + 121.0f, 122.0f, 123.0f, + 131.0f, 132.0f, 133.0f}); + test.AddInput("input2", dims, + {211.0f, 212.0f, 213.0f, + 221.0f, 222.0f, 223.0f, + 231.0f, 232.0f, 233.0f}); + test.AddInput("input3", dims, + {311.0f, 312.0f, 313.0f, + 321.0f, 322.0f, 323.0f, + 331.0f, 332.0f, 333.0f}); + test.AddOutput("concat_result", {1, 3, 3, 3}, + {111.0f, 112.0f, 113.0f, + 121.0f, 122.0f, 123.0f, + 131.0f, 132.0f, 133.0f, + + 211.0f, 212.0f, 213.0f, + 221.0f, 222.0f, 223.0f, + 231.0f, 232.0f, 233.0f, + + 311.0f, 312.0f, 313.0f, + 321.0f, 322.0f, 323.0f, + 331.0f, 332.0f, 333.0f}); + test.Run(); +} + +TEST(ConcatOpTest, Concat4D_2) { + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{2}); + + std::vector dims{1, 3, 1, 3}; + test.AddInput("input1", dims, + {111.0f, 112.0f, 113.0f, + 211.0f, 212.0f, 213.0f, + 311.0f, 312.0f, 313.0f}); + test.AddInput("input2", dims, + {121.0f, 122.0f, 123.0f, + 221.0f, 222.0f, 223.0f, + 321.0f, 322.0f, 323.0f}); + test.AddInput("input3", dims, + {131.0f, 132.0f, 133.0f, + 231.0f, 232.0f, 233.0f, + 331.0f, 332.0f, 333.0f}); + test.AddOutput("concat_result", {1, 3, 3, 3}, + {111.0f, 112.0f, 113.0f, + 121.0f, 122.0f, 123.0f, + 131.0f, 132.0f, 133.0f, + + 211.0f, 212.0f, 213.0f, + 221.0f, 222.0f, 223.0f, + 231.0f, 232.0f, 233.0f, + + 311.0f, 312.0f, 313.0f, + 321.0f, 322.0f, 323.0f, + 331.0f, 332.0f, 333.0f}); + test.Run(); +} + } // namespace test } // namespace onnxruntime