[CoreML EP] Add Concat support (#6834)

* [CoreML EP] Add concat support

* Update comments
This commit is contained in:
Guoyu Wang 2021-03-01 13:35:44 -08:00 committed by GitHub
parent 2d6e10ba00
commit 5cf6606964
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 189 additions and 0 deletions

View file

@ -0,0 +1,88 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/coreml/builders/helper.h"
#include "core/providers/coreml/builders/model_builder.h"
#include "core/providers/coreml/builders/op_builder_factory.h"
#include "base_op_builder.h"
namespace onnxruntime {
namespace coreml {
class ConcatOpBuilder : public BaseOpBuilder {
// Add operator related
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
// Operator support related
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const logging::Logger& logger) const override;
};
// Add operator related
Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const Node& node,
const logging::Logger& logger) const {
std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(node);
layer->mutable_concat()->set_sequenceconcat(false);
for (const auto* input : node.InputDefs()) {
LOGS(logger, VERBOSE) << "input name " << input->Name();
*layer->mutable_input()->Add() = input->Name();
}
*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
model_builder.AddLayer(std::move(layer));
return Status::OK();
}
// Operator support related
bool ConcatOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
if (input_defs.size() < 2) {
LOGS(logger, VERBOSE) << "Concat only support 2+ inputs, actual number of input" << input_defs.size();
return false;
}
std::vector<int64_t> input_shape;
if (!GetShape(*input_defs[0], input_shape, logger))
return false;
auto rank = input_shape.size();
if (rank != 4) {
// For some reason, the concat in CoreML running on 3d tensor will concat on wrong axis
// Instead of concat on axis 0, it will concat on axis 1
// Disable Concat support for 3d tensor for now
// TODO, add ExpandDims and Squeeze, 3d -ExpandDims-> 4d -> Concat -Squeeze-> 3d
LOGS(logger, VERBOSE) << "Concat only support 4d shape for now, input is "
<< rank << "d shape";
return false;
}
NodeAttrHelper helper(node);
auto axis = static_cast<size_t>(HandleNegativeAxis(helper.Get("axis", 1), rank));
if (rank != axis + 3) {
LOGS(logger, VERBOSE) << "Concat only support axis to be -3, actual axis: " << axis
<< ", actual rank: " << rank;
return false;
}
return true;
}
void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(onnxruntime::make_unique<ConcatOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
}
} // namespace coreml
} // namespace onnxruntime

View file

@ -47,6 +47,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreatePoolOpBuilder("GlobalMaxPool", op_registrations);
}
{ // Concat
CreateConcatOpBuilder("Concat", op_registrations);
}
return op_registrations;
}

View file

@ -23,6 +23,7 @@ void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations
void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateBatchNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreatePoolOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

View file

@ -257,5 +257,101 @@ TEST(ConcatOpTest, Concat3D_3) {
test.Run();
}
TEST(ConcatOpTest, Concat4D_1) {
OpTester test("Concat");
test.AddAttribute("axis", int64_t{1});
std::vector<int64_t> dims{1, 1, 3, 3};
test.AddInput<float>("input1", dims,
{111.0f, 112.0f, 113.0f,
121.0f, 122.0f, 123.0f,
131.0f, 132.0f, 133.0f});
test.AddInput<float>("input2", dims,
{211.0f, 212.0f, 213.0f,
221.0f, 222.0f, 223.0f,
231.0f, 232.0f, 233.0f});
test.AddInput<float>("input3", dims,
{311.0f, 312.0f, 313.0f,
321.0f, 322.0f, 323.0f,
331.0f, 332.0f, 333.0f});
test.AddOutput<float>("concat_result", {1, 3, 3, 3},
{111.0f, 112.0f, 113.0f,
121.0f, 122.0f, 123.0f,
131.0f, 132.0f, 133.0f,
211.0f, 212.0f, 213.0f,
221.0f, 222.0f, 223.0f,
231.0f, 232.0f, 233.0f,
311.0f, 312.0f, 313.0f,
321.0f, 322.0f, 323.0f,
331.0f, 332.0f, 333.0f});
test.Run();
}
TEST(ConcatOpTest, Concat4D_1_negative_axis) {
OpTester test("Concat");
test.AddAttribute("axis", int64_t{-3});
std::vector<int64_t> dims{1, 1, 3, 3};
test.AddInput<float>("input1", dims,
{111.0f, 112.0f, 113.0f,
121.0f, 122.0f, 123.0f,
131.0f, 132.0f, 133.0f});
test.AddInput<float>("input2", dims,
{211.0f, 212.0f, 213.0f,
221.0f, 222.0f, 223.0f,
231.0f, 232.0f, 233.0f});
test.AddInput<float>("input3", dims,
{311.0f, 312.0f, 313.0f,
321.0f, 322.0f, 323.0f,
331.0f, 332.0f, 333.0f});
test.AddOutput<float>("concat_result", {1, 3, 3, 3},
{111.0f, 112.0f, 113.0f,
121.0f, 122.0f, 123.0f,
131.0f, 132.0f, 133.0f,
211.0f, 212.0f, 213.0f,
221.0f, 222.0f, 223.0f,
231.0f, 232.0f, 233.0f,
311.0f, 312.0f, 313.0f,
321.0f, 322.0f, 323.0f,
331.0f, 332.0f, 333.0f});
test.Run();
}
TEST(ConcatOpTest, Concat4D_2) {
OpTester test("Concat");
test.AddAttribute("axis", int64_t{2});
std::vector<int64_t> dims{1, 3, 1, 3};
test.AddInput<float>("input1", dims,
{111.0f, 112.0f, 113.0f,
211.0f, 212.0f, 213.0f,
311.0f, 312.0f, 313.0f});
test.AddInput<float>("input2", dims,
{121.0f, 122.0f, 123.0f,
221.0f, 222.0f, 223.0f,
321.0f, 322.0f, 323.0f});
test.AddInput<float>("input3", dims,
{131.0f, 132.0f, 133.0f,
231.0f, 232.0f, 233.0f,
331.0f, 332.0f, 333.0f});
test.AddOutput<float>("concat_result", {1, 3, 3, 3},
{111.0f, 112.0f, 113.0f,
121.0f, 122.0f, 123.0f,
131.0f, 132.0f, 133.0f,
211.0f, 212.0f, 213.0f,
221.0f, 222.0f, 223.0f,
231.0f, 232.0f, 233.0f,
311.0f, 312.0f, 313.0f,
321.0f, 322.0f, 323.0f,
331.0f, 332.0f, 333.0f});
test.Run();
}
} // namespace test
} // namespace onnxruntime