diff --git a/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc index d15db37fdd..4b6da84ff5 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc @@ -41,7 +41,7 @@ Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, auto* coreml_flatten = layer->mutable_flattento2d(); NodeAttrHelper helper(node); - const int64_t axis = helper.Get("axis ", 1); + const int64_t axis = helper.Get("axis", 1); coreml_flatten->set_axis(axis); *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); diff --git a/onnxruntime/test/providers/cpu/nn/flatten_op_test.cc b/onnxruntime/test/providers/cpu/nn/flatten_op_test.cc index 7846b27713..a8dd7f24e7 100644 --- a/onnxruntime/test/providers/cpu/nn/flatten_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/flatten_op_test.cc @@ -2,6 +2,8 @@ // Licensed under the MIT License. #include "gtest/gtest.h" + +#include "core/session/environment.h" #include "test/providers/provider_test_utils.h" namespace onnxruntime { @@ -64,5 +66,37 @@ TEST_F(FlattenOpTest, Flatten_neg_axis3) { test_.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +// Regression test primarily for CoreML. +// The CoreML EP implementation was not reading the axis attribute correctly causing an incorrect output shape to be +// produced for a Flatten node. That issue gets hidden as the Tensor to write the output to is created by the +// CoreML EP using the inferred output shape (which is correct) and we provide the Tensor's buffer but not the shape +// when executing the CoreML model. As the flatten isn't changing or moving any data nothing breaks when we test +// with only a Flatten node in the model. +// +// This test uses a model with a Flatten followed by a Mul which requires broadcasting. Both nodes are handled by +// CoreML, so if the axis is not correctly processed the output from Flatten will not be broadcastable and the CoreML +// model execution will fail. +TEST(FlattenOpModelTest, Flatten_broadcast) { + auto model_uri = ORT_TSTR("testdata/flatten_broadcast.onnx"); + + std::shared_ptr model; + auto status = Model::Load(model_uri, model, nullptr, GetEnvironment().GetLoggingManager()->DefaultLogger()); + + OpTester tester("flatten_broadcast"); + tester.SetModelCache(model); + + tester.AddInput("X", {4}, {0.f, 1.f, 2.f, 3.f}); + tester.AddInput("Y", {3, 4}, + {0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, + 8.f, 9.f, 10.f, 11.f}); + tester.AddOutput("Z", {3, 4}, + {0.f, 1.f, 4.f, 9.f, + 0.f, 5.f, 12.f, 21.f, + 0.f, 9.f, 20.f, 33.f}); + + // disable TRT as it does not support axis=0 as used by the model + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/flatten_broadcast.onnx b/onnxruntime/test/testdata/flatten_broadcast.onnx new file mode 100644 index 0000000000..098c0821e6 Binary files /dev/null and b/onnxruntime/test/testdata/flatten_broadcast.onnx differ