mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Fix CoreML Flatten handling of axis attribute (#16046)
### Description <!-- Describe your changes. --> The CoreML EP implementation was not reading the axis attribute correctly (it looked up the name "axis " with a trailing space instead of "axis"), causing an incorrect output shape to be produced for a Flatten node. That issue is hidden because the Tensor the output is written to is created by the CoreML EP using the inferred output shape (which is correct), and we provide the Tensor's buffer, but not its shape, when executing the CoreML model. As Flatten doesn't change or move any data, nothing breaks when we test with only a Flatten node in the model. This change fixes the attribute name and adds a test that uses a model with a Flatten followed by a Mul which requires broadcasting. Both nodes are handled by CoreML, so if the axis is not correctly processed, the output from Flatten will not be broadcastable and the CoreML model execution will fail. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Bug fix.
This commit is contained in:
parent
1b498e414f
commit
55c3f4b28f
3 changed files with 35 additions and 1 deletions
|
|
@ -41,7 +41,7 @@ Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
auto* coreml_flatten = layer->mutable_flattento2d();
|
||||
|
||||
NodeAttrHelper helper(node);
|
||||
const int64_t axis = helper.Get("axis ", 1);
|
||||
const int64_t axis = helper.Get("axis", 1);
|
||||
coreml_flatten->set_axis(axis);
|
||||
|
||||
*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "core/session/environment.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -64,5 +66,37 @@ TEST_F(FlattenOpTest, Flatten_neg_axis3) {
|
|||
test_.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
|
||||
}
|
||||
|
||||
// Regression test primarily for CoreML.
|
||||
// The CoreML EP implementation was not reading the axis attribute correctly, causing an incorrect output shape to be
|
||||
// produced for a Flatten node. That issue gets hidden as the Tensor to write the output to is created by the
|
||||
// CoreML EP using the inferred output shape (which is correct) and we provide the Tensor's buffer but not the shape
|
||||
// when executing the CoreML model. As the flatten isn't changing or moving any data, nothing breaks when we test
|
||||
// with only a Flatten node in the model.
|
||||
//
|
||||
// This test uses a model with a Flatten followed by a Mul which requires broadcasting. Both nodes are handled by
|
||||
// CoreML, so if the axis is not correctly processed the output from Flatten will not be broadcastable and the CoreML
|
||||
// model execution will fail.
|
||||
TEST(FlattenOpModelTest, Flatten_broadcast) {
  // Load the pre-built test model from disk and hand it to OpTester, rather than having OpTester
  // build a single-node model.
  auto model_uri = ORT_TSTR("testdata/flatten_broadcast.onnx");

  std::shared_ptr<Model> model;
  auto status = Model::Load(model_uri, model, nullptr, GetEnvironment().GetLoggingManager()->DefaultLogger());
  // Fail here with the loader's message if the model could not be loaded, instead of passing a
  // null model to the tester and failing later with a less direct error.
  ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();

  OpTester tester("flatten_broadcast");
  tester.SetModelCache(model);

  tester.AddInput<float>("X", {4}, {0.f, 1.f, 2.f, 3.f});
  tester.AddInput<float>("Y", {3, 4},
                         {0.f, 1.f, 2.f, 3.f,
                          4.f, 5.f, 6.f, 7.f,
                          8.f, 9.f, 10.f, 11.f});
  // Expected values equal each row of Y multiplied element-wise by X.
  tester.AddOutput<float>("Z", {3, 4},
                          {0.f, 1.f, 4.f, 9.f,
                           0.f, 5.f, 12.f, 21.f,
                           0.f, 9.f, 20.f, 33.f});

  // disable TRT as it does not support axis=0 as used by the model
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/flatten_broadcast.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/flatten_broadcast.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue