From 55c3f4b28f2a5e4118df7ef44cc8e7ebe7ae7cf4 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 24 May 2023 08:27:32 +1000 Subject: [PATCH] Fix CoreML Flatten handling of axis attribute (#16046) ### Description The CoreML EP implementation was not reading the axis attribute correctly causing an incorrect output shape to be produced for a Flatten node. That issue gets hidden as the Tensor to write the output to is created by the CoreML EP using the inferred output shape (which is correct) and we provide the Tensor's buffer but not the shape when executing the CoreML model. As the flatten isn't changing or moving any data nothing breaks when we test with only a Flatten node in the model. Fix the attribute name and add a test that uses a model with a Flatten followed by a Mul which requires broadcasting. Both nodes are handled by CoreML, so if the axis is not correctly processed the output from Flatten will not be broadcastable and the CoreML model execution will fail. ### Motivation and Context Bug fix. --- .../builders/impl/flatten_op_builder.cc | 2 +- .../test/providers/cpu/nn/flatten_op_test.cc | 34 ++++++++++++++++++ .../test/testdata/flatten_broadcast.onnx | Bin 0 -> 156 bytes 3 files changed, 35 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/test/testdata/flatten_broadcast.onnx diff --git a/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc index d15db37fdd..4b6da84ff5 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc @@ -41,7 +41,7 @@ Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, auto* coreml_flatten = layer->mutable_flattento2d(); NodeAttrHelper helper(node); - const int64_t axis = helper.Get("axis ", 1); + const int64_t axis = helper.Get("axis", 1); coreml_flatten->set_axis(axis); *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); diff --git a/onnxruntime/test/providers/cpu/nn/flatten_op_test.cc b/onnxruntime/test/providers/cpu/nn/flatten_op_test.cc index 7846b27713..a8dd7f24e7 100644 --- a/onnxruntime/test/providers/cpu/nn/flatten_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/flatten_op_test.cc @@ -2,6 +2,8 @@ // Licensed under the MIT License. #include "gtest/gtest.h" + +#include "core/session/environment.h" #include "test/providers/provider_test_utils.h" namespace onnxruntime { @@ -64,5 +66,37 @@ TEST_F(FlattenOpTest, Flatten_neg_axis3) { test_.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +// Regression test primarily for CoreML. +// The CoreML EP implementation was not reading the axis attribute correctly causing an incorrect output shape to be +// produced for a Flatten node. That issue gets hidden as the Tensor to write the output to is created by the +// CoreML EP using the inferred output shape (which is correct) and we provide the Tensor's buffer but not the shape +// when executing the CoreML model. As the flatten isn't changing or moving any data nothing breaks when we test +// with only a Flatten node in the model. +// +// This test uses a model with a Flatten followed by a Mul which requires broadcasting. Both nodes are handled by +// CoreML, so if the axis is not correctly processed the output from Flatten will not be broadcastable and the CoreML +// model execution will fail. +TEST(FlattenOpModelTest, Flatten_broadcast) { + auto model_uri = ORT_TSTR("testdata/flatten_broadcast.onnx"); + + std::shared_ptr model; + auto status = Model::Load(model_uri, model, nullptr, GetEnvironment().GetLoggingManager()->DefaultLogger()); + + OpTester tester("flatten_broadcast"); + tester.SetModelCache(model); + + tester.AddInput("X", {4}, {0.f, 1.f, 2.f, 3.f}); + tester.AddInput("Y", {3, 4}, + {0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, + 8.f, 9.f, 10.f, 11.f}); + tester.AddOutput("Z", {3, 4}, + {0.f, 1.f, 4.f, 9.f, + 0.f, 5.f, 12.f, 21.f, + 0.f, 9.f, 20.f, 33.f}); + + // disable TRT as it does not support axis=0 as used by the model + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/flatten_broadcast.onnx b/onnxruntime/test/testdata/flatten_broadcast.onnx new file mode 100644 index 0000000000000000000000000000000000000000..098c0821e6fbd664303fcd458a510b5ac8676b48 GIT binary patch literal 156 zcmd=jVvG=CiHJ8!QDS$?Nh~Qz&C}xMVo9vXES6wcz{q69z$MHD zmgi!O6k?20V)iY~v0@Ms0~rvX2GJ3pRFt2XlAKsv62%WPkBf_ggHecui;06JN*JV% akBbK=zyTCs1_>qs1*6afomjXS1Ox$F8yoxp literal 0 HcmV?d00001