diff --git a/include/onnxruntime/core/graph/graph_nodes.h b/include/onnxruntime/core/graph/graph_nodes.h index 4fa2848a1d..aab5f2699d 100644 --- a/include/onnxruntime/core/graph/graph_nodes.h +++ b/include/onnxruntime/core/graph/graph_nodes.h @@ -117,13 +117,14 @@ class ValidNodes { return (current_ != other.current_); } - void operator++() { + NodeIterator& operator++() { if (current_ < end_) { while (++current_ != end_) { if (*current_ != nullptr && (!apply_filter_ || (*filter_func_)((*current_)->Index()) == false)) break; } } + return *this; } NodeIterator operator++(int) { diff --git a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc index e9a8176c83..bc8b2d1a35 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc @@ -38,13 +38,14 @@ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // 2. Otherwise, we add Argmax layer normally if (node.GetOutputEdgesCount() == 1) { auto it = node.OutputEdgesBegin(); - const auto* succ_node(graph_viewer.GetNode(it->GetNode().Index())); + const auto* next_node_in_partition = graph_viewer.GetNode(it->GetNode().Index()); // If Argmax's successive node is a Cast from int64 to int32 output - // The 'cast to' type is checked in operator supported related, omit the check here - if (succ_node->OpType() == "Cast") { + // The 'cast to' type is checked when determining operator support (see CastOpBuilder::IsOpSupportedImpl()) + // so we omit the check here + if (next_node_in_partition != nullptr && next_node_in_partition->OpType() == "Cast") { // Skip the cast's input/argmax's output *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - *layer->mutable_output()->Add() = succ_node->OutputDefs()[0]->Name(); + *layer->mutable_output()->Add() = next_node_in_partition->OutputDefs()[0]->Name(); model_builder.AddLayer(std::move(layer)); return Status::OK(); } diff --git a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc index 70053c2c60..fc8879abbe 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc @@ -36,11 +36,6 @@ bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara return false; } - if (node.GetInputEdgesCount() > 1) { - LOGS(logger, VERBOSE) << "Multiple nodes producing Cast's input."; - return false; - } - const auto& prec_node = node.InputEdgesBegin()->GetNode(); /*Cast node is only aimed for supporting argmax and we are only handling the case where an argmax diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc index 0f068ba48d..daa24db134 100644 --- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc +++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc @@ -2,6 +2,8 @@ // Licensed under the MIT License. #include "core/common/logging/logging.h" +#include "core/graph/graph.h" +#include "core/graph/graph_viewer.h" #include "core/providers/coreml/coreml_execution_provider.h" #include "core/providers/coreml/coreml_provider_factory.h" #include "core/session/inference_session.h" @@ -92,7 +94,7 @@ TEST(CoreMLExecutionProviderTest, FunctionTest) { feeds.insert(std::make_pair("Y", ml_value_y)); feeds.insert(std::make_pair("Z", ml_value_z)); - RunAndVerifyOutputsWithEP(model_file_name, "CoreMLExecutionProviderTest.FunctionTest", + RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(), MakeCoreMLExecutionProvider(), feeds); #else @@ -118,9 +120,50 @@ TEST(CoreMLExecutionProviderTest, ArgMaxCastTest) { NameMLValMap feeds; feeds.insert(std::make_pair("X", ml_value_x)); - RunAndVerifyOutputsWithEP(model_file_name, "CoreMLExecutionProviderTest.ArgMaxCastTest", + EPVerificationParams verification_params{}; + verification_params.ep_node_assignment = ExpectedEPNodeAssignment::All; + + RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(), MakeCoreMLExecutionProvider(), - feeds); + feeds, + verification_params); +#else + TestModelLoad(model_file_name, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, ArgMaxUnsupportedCastTest) { + const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/coreml_argmax_unsupported_cast_test.onnx"); + +#if defined(__APPLE__) + std::vector dims_mul_x = {3, 2, 2}; + std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; + OrtValue ml_value_x; + AllocatorPtr allocator = std::make_shared(); + CreateMLValue(allocator, dims_mul_x, values_mul_x, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + const std::function graph_verifier = [](const Graph& graph) { + GraphViewer graph_viewer{graph}; + const auto& node_indices_in_order = graph_viewer.GetNodesInTopologicalOrder(); + ASSERT_EQ(node_indices_in_order.size(), size_t{2}); + // second node should be an unsupported Cast + const auto* cast_node = graph.GetNode(node_indices_in_order[1]); + ASSERT_NE(cast_node, nullptr); + ASSERT_EQ(cast_node->OpType(), "Cast"); + ASSERT_EQ(cast_node->GetExecutionProviderType(), kCpuExecutionProvider); + }; + + EPVerificationParams verification_params{}; + verification_params.ep_node_assignment = ExpectedEPNodeAssignment::Some; + verification_params.graph_verifier = &graph_verifier; + + RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(), + MakeCoreMLExecutionProvider(), + feeds, + verification_params); #else TestModelLoad(model_file_name, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::Some); #endif @@ -184,7 +227,7 @@ TEST(CoreMLExecutionProviderTest, TestOrtFormatModel) { NameMLValMap feeds; feeds.insert(std::make_pair("Input3", ml_value)); - RunAndVerifyOutputsWithEP(model_file_name, "CoreMLExecutionProviderTest.TestOrtFormatModel", + RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(), MakeCoreMLExecutionProvider(), feeds); #else diff --git a/onnxruntime/test/testdata/coreml_argmax_cast_test.onnx b/onnxruntime/test/testdata/coreml_argmax_cast_test.onnx index db806f296a..931bd30dbe 100644 --- a/onnxruntime/test/testdata/coreml_argmax_cast_test.onnx +++ b/onnxruntime/test/testdata/coreml_argmax_cast_test.onnx @@ -1,4 +1,5 @@ -:Ä + +:Ä F Xargmax_output_int64argmax"ArgMax* axis * @@ -15,4 +16,4 @@ F    -B \ No newline at end of file +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/coreml_argmax_cast_test.py b/onnxruntime/test/testdata/coreml_argmax_cast_test.py index acf24ac379..6cc2531113 100644 --- a/onnxruntime/test/testdata/coreml_argmax_cast_test.py +++ b/onnxruntime/test/testdata/coreml_argmax_cast_test.py @@ -1,16 +1,18 @@ import onnx from onnx import TensorProto, helper -# CoreML EP currently handles a special case for supporting ArgMax op -# Please see in /onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc and -# /onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc -# We have this separated test script to generate graph for the case: An ArgMax followed by a Cast to int32 type +# CoreML EP currently handles a special case for supporting ArgMax followed by a Cast to int32. +# Please see /onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc and +# /onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc. +# This script generates graphs for these cases: +# - An ArgMax followed by a supported Cast to int32 type +# - An ArgMax followed by an unsupported Cast to a type other than int32 -def GenerateModel(model_name): # noqa: N802 +def GenerateModel(model_name, cast_to_dtype): # noqa: N802 nodes = [ helper.make_node("ArgMax", ["X"], ["argmax_output_int64"], "argmax", axis=1, keepdims=1), - helper.make_node("Cast", ["argmax_output_int64"], ["Y"], "cast", to=6), # cast to int32 type + helper.make_node("Cast", ["argmax_output_int64"], ["Y"], "cast", to=cast_to_dtype), ] graph = helper.make_graph( @@ -20,7 +22,7 @@ def GenerateModel(model_name): # noqa: N802 helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2, 2]), ], [ # output - helper.make_tensor_value_info("Y", TensorProto.INT32, [3, 1, 2]), + helper.make_tensor_value_info("Y", cast_to_dtype, [3, 1, 2]), ], ) @@ -29,4 +31,5 @@ def GenerateModel(model_name): # noqa: N802 if __name__ == "__main__": - GenerateModel("coreml_argmax_cast_test.onnx") + GenerateModel("coreml_argmax_cast_test.onnx", TensorProto.INT32) + GenerateModel("coreml_argmax_unsupported_cast_test.onnx", TensorProto.UINT32) diff --git a/onnxruntime/test/testdata/coreml_argmax_unsupported_cast_test.onnx b/onnxruntime/test/testdata/coreml_argmax_unsupported_cast_test.onnx new file mode 100644 index 0000000000..d5aea9110c --- /dev/null +++ b/onnxruntime/test/testdata/coreml_argmax_unsupported_cast_test.onnx @@ -0,0 +1,19 @@ + +:Ä +F +Xargmax_output_int64argmax"ArgMax* +axis * +keepdims  +/ +argmax_output_int64Ycast"Cast* +to  CoreML_ArgMax_Cast_TestZ +X + + + +b +Y +  + + +B \ No newline at end of file