[CoreML EP] Fix ArgMaxOpBuilder::AddToModelBuilderImpl() nullptr Node access. (#21797)

This commit is contained in:
Edward Chen 2024-08-23 10:19:53 -07:00 committed by GitHub
parent 4c4ae1e490
commit 5726318ec0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 87 additions and 24 deletions

View file

@ -117,13 +117,14 @@ class ValidNodes {
return (current_ != other.current_);
}
void operator++() {
NodeIterator<TIterator>& operator++() {
if (current_ < end_) {
while (++current_ != end_) {
if (*current_ != nullptr && (!apply_filter_ || (*filter_func_)((*current_)->Index()) == false))
break;
}
}
return *this;
}
NodeIterator<TIterator> operator++(int) {

View file

@ -38,13 +38,14 @@ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
// 2. Otherwise, we add Argmax layer normally
if (node.GetOutputEdgesCount() == 1) {
auto it = node.OutputEdgesBegin();
const auto* succ_node(graph_viewer.GetNode(it->GetNode().Index()));
const auto* next_node_in_partition = graph_viewer.GetNode(it->GetNode().Index());
// If Argmax's successive node is a Cast from int64 to int32 output
// The 'cast to' type is checked in operator supported related, omit the check here
if (succ_node->OpType() == "Cast") {
// The 'cast to' type is checked when determining operator support (see CastOpBuilder::IsOpSupportedImpl())
// so we omit the check here
if (next_node_in_partition != nullptr && next_node_in_partition->OpType() == "Cast") {
// Skip the cast's input/argmax's output
*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
*layer->mutable_output()->Add() = succ_node->OutputDefs()[0]->Name();
*layer->mutable_output()->Add() = next_node_in_partition->OutputDefs()[0]->Name();
model_builder.AddLayer(std::move(layer));
return Status::OK();
}

View file

@ -36,11 +36,6 @@ bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara
return false;
}
if (node.GetInputEdgesCount() > 1) {
LOGS(logger, VERBOSE) << "Multiple nodes producing Cast's input.";
return false;
}
const auto& prec_node = node.InputEdgesBegin()->GetNode();
/*Cast node is only aimed for supporting argmax and we are only handling the case where an argmax

View file

@ -2,6 +2,8 @@
// Licensed under the MIT License.
#include "core/common/logging/logging.h"
#include "core/graph/graph.h"
#include "core/graph/graph_viewer.h"
#include "core/providers/coreml/coreml_execution_provider.h"
#include "core/providers/coreml/coreml_provider_factory.h"
#include "core/session/inference_session.h"
@ -92,7 +94,7 @@ TEST(CoreMLExecutionProviderTest, FunctionTest) {
feeds.insert(std::make_pair("Y", ml_value_y));
feeds.insert(std::make_pair("Z", ml_value_z));
RunAndVerifyOutputsWithEP(model_file_name, "CoreMLExecutionProviderTest.FunctionTest",
RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(),
MakeCoreMLExecutionProvider(),
feeds);
#else
@ -118,9 +120,50 @@ TEST(CoreMLExecutionProviderTest, ArgMaxCastTest) {
NameMLValMap feeds;
feeds.insert(std::make_pair("X", ml_value_x));
RunAndVerifyOutputsWithEP(model_file_name, "CoreMLExecutionProviderTest.ArgMaxCastTest",
EPVerificationParams verification_params{};
verification_params.ep_node_assignment = ExpectedEPNodeAssignment::All;
RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(),
MakeCoreMLExecutionProvider(),
feeds);
feeds,
verification_params);
#else
TestModelLoad(model_file_name, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All);
#endif
}
TEST(CoreMLExecutionProviderTest, ArgMaxUnsupportedCastTest) {
const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/coreml_argmax_unsupported_cast_test.onnx");
#if defined(__APPLE__)
std::vector<int64_t> dims_mul_x = {3, 2, 2};
std::vector<float> values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
OrtValue ml_value_x;
AllocatorPtr allocator = std::make_shared<CPUAllocator>();
CreateMLValue<float>(allocator, dims_mul_x, values_mul_x, &ml_value_x);
NameMLValMap feeds;
feeds.insert(std::make_pair("X", ml_value_x));
// Graph check handed to RunAndVerifyOutputsWithEP through verification_params.graph_verifier.
// It asserts that the graph holds exactly two nodes and that the second node in topological
// order is a Cast assigned to the CPU execution provider.
const std::function<void(const Graph&)> graph_verifier = [](const Graph& graph) {
GraphViewer graph_viewer{graph};
const auto& node_indices_in_order = graph_viewer.GetNodesInTopologicalOrder();
// exactly two nodes are expected; the assert below also guards the [1] access that follows
ASSERT_EQ(node_indices_in_order.size(), size_t{2});
// second node should be an unsupported Cast
const auto* cast_node = graph.GetNode(node_indices_in_order[1]);
ASSERT_NE(cast_node, nullptr);
ASSERT_EQ(cast_node->OpType(), "Cast");
// the Cast must be assigned to the CPU execution provider
ASSERT_EQ(cast_node->GetExecutionProviderType(), kCpuExecutionProvider);
};
EPVerificationParams verification_params{};
verification_params.ep_node_assignment = ExpectedEPNodeAssignment::Some;
verification_params.graph_verifier = &graph_verifier;
RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(),
MakeCoreMLExecutionProvider(),
feeds,
verification_params);
#else
TestModelLoad(model_file_name, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::Some);
#endif
@ -184,7 +227,7 @@ TEST(CoreMLExecutionProviderTest, TestOrtFormatModel) {
NameMLValMap feeds;
feeds.insert(std::make_pair("Input3", ml_value));
RunAndVerifyOutputsWithEP(model_file_name, "CoreMLExecutionProviderTest.TestOrtFormatModel",
RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(),
MakeCoreMLExecutionProvider(),
feeds);
#else

View file

@ -1,4 +1,5 @@
:ト

:ト
F
Xargmax_output_int64argmax"ArgMax*
axis *
@ -15,4 +16,4 @@ F



B
B

View file

@ -1,16 +1,18 @@
import onnx
from onnx import TensorProto, helper
# CoreML EP currently handles a special case for supporting ArgMax op
# Please see in <repo_root>/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc and
# <repo_root>/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc
# We have this separated test script to generate graph for the case: An ArgMax followed by a Cast to int32 type
# CoreML EP currently handles a special case for supporting ArgMax followed by a Cast to int32.
# Please see <repo_root>/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc and
# <repo_root>/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc.
# This script generates graphs for these cases:
# - An ArgMax followed by a supported Cast to int32 type
# - An ArgMax followed by an unsupported Cast to a type other than int32
def GenerateModel(model_name): # noqa: N802
def GenerateModel(model_name, cast_to_dtype): # noqa: N802
nodes = [
helper.make_node("ArgMax", ["X"], ["argmax_output_int64"], "argmax", axis=1, keepdims=1),
helper.make_node("Cast", ["argmax_output_int64"], ["Y"], "cast", to=6), # cast to int32 type
helper.make_node("Cast", ["argmax_output_int64"], ["Y"], "cast", to=cast_to_dtype),
]
graph = helper.make_graph(
@ -20,7 +22,7 @@ def GenerateModel(model_name): # noqa: N802
helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2, 2]),
],
[ # output
helper.make_tensor_value_info("Y", TensorProto.INT32, [3, 1, 2]),
helper.make_tensor_value_info("Y", cast_to_dtype, [3, 1, 2]),
],
)
@ -29,4 +31,5 @@ def GenerateModel(model_name): # noqa: N802
if __name__ == "__main__":
GenerateModel("coreml_argmax_cast_test.onnx")
GenerateModel("coreml_argmax_cast_test.onnx", TensorProto.INT32)
GenerateModel("coreml_argmax_unsupported_cast_test.onnx", TensorProto.UINT32)

View file

@ -0,0 +1,19 @@


F
Xargmax_output_int64argmax"ArgMax*
axis *
keepdims 
/
argmax_output_int64Ycast"Cast*
to  CoreML_ArgMax_Cast_TestZ
X



b
Y
 


B