mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
[CoreML EP] Add Concat support (#6834)
* [CoreML EP] Add concat support * Update comments
This commit is contained in:
parent
2d6e10ba00
commit
5cf6606964
4 changed files with 189 additions and 0 deletions
|
|
@ -0,0 +1,88 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
#include "core/providers/coreml/builders/helper.h"
|
||||
#include "core/providers/coreml/builders/model_builder.h"
|
||||
#include "core/providers/coreml/builders/op_builder_factory.h"
|
||||
|
||||
#include "base_op_builder.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace coreml {
|
||||
|
||||
class ConcatOpBuilder : public BaseOpBuilder {
|
||||
// Add operator related
|
||||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
|
||||
// Operator support related
|
||||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related
|
||||
|
||||
Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
||||
const Node& node,
|
||||
const logging::Logger& logger) const {
|
||||
std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(node);
|
||||
|
||||
layer->mutable_concat()->set_sequenceconcat(false);
|
||||
|
||||
for (const auto* input : node.InputDefs()) {
|
||||
LOGS(logger, VERBOSE) << "input name " << input->Name();
|
||||
*layer->mutable_input()->Add() = input->Name();
|
||||
}
|
||||
|
||||
*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
|
||||
|
||||
model_builder.AddLayer(std::move(layer));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Operator support related
|
||||
bool ConcatOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
if (input_defs.size() < 2) {
|
||||
LOGS(logger, VERBOSE) << "Concat only support 2+ inputs, actual number of input" << input_defs.size();
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<int64_t> input_shape;
|
||||
if (!GetShape(*input_defs[0], input_shape, logger))
|
||||
return false;
|
||||
|
||||
auto rank = input_shape.size();
|
||||
if (rank != 4) {
|
||||
// For some reason, the concat in CoreML running on 3d tensor will concat on wrong axis
|
||||
// Instead of concat on axis 0, it will concat on axis 1
|
||||
// Disable Concat support for 3d tensor for now
|
||||
// TODO, add ExpandDims and Squeeze, 3d -ExpandDims-> 4d -> Concat -Squeeze-> 3d
|
||||
LOGS(logger, VERBOSE) << "Concat only support 4d shape for now, input is "
|
||||
<< rank << "d shape";
|
||||
return false;
|
||||
}
|
||||
|
||||
NodeAttrHelper helper(node);
|
||||
auto axis = static_cast<size_t>(HandleNegativeAxis(helper.Get("axis", 1), rank));
|
||||
if (rank != axis + 3) {
|
||||
LOGS(logger, VERBOSE) << "Concat only support axis to be -3, actual axis: " << axis
|
||||
<< ", actual rank: " << rank;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(onnxruntime::make_unique<ConcatOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
}
|
||||
|
||||
} // namespace coreml
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -47,6 +47,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
|
|||
CreatePoolOpBuilder("GlobalMaxPool", op_registrations);
|
||||
}
|
||||
|
||||
{ // Concat
|
||||
CreateConcatOpBuilder("Concat", op_registrations);
|
||||
}
|
||||
|
||||
return op_registrations;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations
|
|||
void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateBatchNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
|
||||
void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreatePoolOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
|
|
|
|||
|
|
@ -257,5 +257,101 @@ TEST(ConcatOpTest, Concat3D_3) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ConcatOpTest, Concat4D_1) {
|
||||
OpTester test("Concat");
|
||||
test.AddAttribute("axis", int64_t{1});
|
||||
|
||||
std::vector<int64_t> dims{1, 1, 3, 3};
|
||||
test.AddInput<float>("input1", dims,
|
||||
{111.0f, 112.0f, 113.0f,
|
||||
121.0f, 122.0f, 123.0f,
|
||||
131.0f, 132.0f, 133.0f});
|
||||
test.AddInput<float>("input2", dims,
|
||||
{211.0f, 212.0f, 213.0f,
|
||||
221.0f, 222.0f, 223.0f,
|
||||
231.0f, 232.0f, 233.0f});
|
||||
test.AddInput<float>("input3", dims,
|
||||
{311.0f, 312.0f, 313.0f,
|
||||
321.0f, 322.0f, 323.0f,
|
||||
331.0f, 332.0f, 333.0f});
|
||||
test.AddOutput<float>("concat_result", {1, 3, 3, 3},
|
||||
{111.0f, 112.0f, 113.0f,
|
||||
121.0f, 122.0f, 123.0f,
|
||||
131.0f, 132.0f, 133.0f,
|
||||
|
||||
211.0f, 212.0f, 213.0f,
|
||||
221.0f, 222.0f, 223.0f,
|
||||
231.0f, 232.0f, 233.0f,
|
||||
|
||||
311.0f, 312.0f, 313.0f,
|
||||
321.0f, 322.0f, 323.0f,
|
||||
331.0f, 332.0f, 333.0f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ConcatOpTest, Concat4D_1_negative_axis) {
|
||||
OpTester test("Concat");
|
||||
test.AddAttribute("axis", int64_t{-3});
|
||||
|
||||
std::vector<int64_t> dims{1, 1, 3, 3};
|
||||
test.AddInput<float>("input1", dims,
|
||||
{111.0f, 112.0f, 113.0f,
|
||||
121.0f, 122.0f, 123.0f,
|
||||
131.0f, 132.0f, 133.0f});
|
||||
test.AddInput<float>("input2", dims,
|
||||
{211.0f, 212.0f, 213.0f,
|
||||
221.0f, 222.0f, 223.0f,
|
||||
231.0f, 232.0f, 233.0f});
|
||||
test.AddInput<float>("input3", dims,
|
||||
{311.0f, 312.0f, 313.0f,
|
||||
321.0f, 322.0f, 323.0f,
|
||||
331.0f, 332.0f, 333.0f});
|
||||
test.AddOutput<float>("concat_result", {1, 3, 3, 3},
|
||||
{111.0f, 112.0f, 113.0f,
|
||||
121.0f, 122.0f, 123.0f,
|
||||
131.0f, 132.0f, 133.0f,
|
||||
|
||||
211.0f, 212.0f, 213.0f,
|
||||
221.0f, 222.0f, 223.0f,
|
||||
231.0f, 232.0f, 233.0f,
|
||||
|
||||
311.0f, 312.0f, 313.0f,
|
||||
321.0f, 322.0f, 323.0f,
|
||||
331.0f, 332.0f, 333.0f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ConcatOpTest, Concat4D_2) {
|
||||
OpTester test("Concat");
|
||||
test.AddAttribute("axis", int64_t{2});
|
||||
|
||||
std::vector<int64_t> dims{1, 3, 1, 3};
|
||||
test.AddInput<float>("input1", dims,
|
||||
{111.0f, 112.0f, 113.0f,
|
||||
211.0f, 212.0f, 213.0f,
|
||||
311.0f, 312.0f, 313.0f});
|
||||
test.AddInput<float>("input2", dims,
|
||||
{121.0f, 122.0f, 123.0f,
|
||||
221.0f, 222.0f, 223.0f,
|
||||
321.0f, 322.0f, 323.0f});
|
||||
test.AddInput<float>("input3", dims,
|
||||
{131.0f, 132.0f, 133.0f,
|
||||
231.0f, 232.0f, 233.0f,
|
||||
331.0f, 332.0f, 333.0f});
|
||||
test.AddOutput<float>("concat_result", {1, 3, 3, 3},
|
||||
{111.0f, 112.0f, 113.0f,
|
||||
121.0f, 122.0f, 123.0f,
|
||||
131.0f, 132.0f, 133.0f,
|
||||
|
||||
211.0f, 212.0f, 213.0f,
|
||||
221.0f, 222.0f, 223.0f,
|
||||
231.0f, 232.0f, 233.0f,
|
||||
|
||||
311.0f, 312.0f, 313.0f,
|
||||
321.0f, 322.0f, 323.0f,
|
||||
331.0f, 332.0f, 333.0f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue