mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
[CoreML EP] Fix ArgMaxOpBuilder::AddToModelBuilderImpl() nullptr Node access. (#21797)
This commit is contained in:
parent
4c4ae1e490
commit
5726318ec0
7 changed files with 87 additions and 24 deletions
|
|
@ -117,13 +117,14 @@ class ValidNodes {
|
|||
return (current_ != other.current_);
|
||||
}
|
||||
|
||||
void operator++() {
|
||||
NodeIterator<TIterator>& operator++() {
|
||||
if (current_ < end_) {
|
||||
while (++current_ != end_) {
|
||||
if (*current_ != nullptr && (!apply_filter_ || (*filter_func_)((*current_)->Index()) == false))
|
||||
break;
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
NodeIterator<TIterator> operator++(int) {
|
||||
|
|
|
|||
|
|
@ -38,13 +38,14 @@ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
// 2. Otherwise, we add Argmax layer normally
|
||||
if (node.GetOutputEdgesCount() == 1) {
|
||||
auto it = node.OutputEdgesBegin();
|
||||
const auto* succ_node(graph_viewer.GetNode(it->GetNode().Index()));
|
||||
const auto* next_node_in_partition = graph_viewer.GetNode(it->GetNode().Index());
|
||||
// If Argmax's successive node is a Cast from int64 to int32 output
|
||||
// The 'cast to' type is checked in operator supported related, omit the check here
|
||||
if (succ_node->OpType() == "Cast") {
|
||||
// The 'cast to' type is checked when determining operator support (see CastOpBuilder::IsOpSupportedImpl())
|
||||
// so we omit the check here
|
||||
if (next_node_in_partition != nullptr && next_node_in_partition->OpType() == "Cast") {
|
||||
// Skip the cast's input/argmax's output
|
||||
*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
|
||||
*layer->mutable_output()->Add() = succ_node->OutputDefs()[0]->Name();
|
||||
*layer->mutable_output()->Add() = next_node_in_partition->OutputDefs()[0]->Name();
|
||||
model_builder.AddLayer(std::move(layer));
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,11 +36,6 @@ bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara
|
|||
return false;
|
||||
}
|
||||
|
||||
if (node.GetInputEdgesCount() > 1) {
|
||||
LOGS(logger, VERBOSE) << "Multiple nodes producing Cast's input.";
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& prec_node = node.InputEdgesBegin()->GetNode();
|
||||
|
||||
/*Cast node is only aimed for supporting argmax and we are only handling the case where an argmax
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/graph/graph.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/providers/coreml/coreml_execution_provider.h"
|
||||
#include "core/providers/coreml/coreml_provider_factory.h"
|
||||
#include "core/session/inference_session.h"
|
||||
|
|
@ -92,7 +94,7 @@ TEST(CoreMLExecutionProviderTest, FunctionTest) {
|
|||
feeds.insert(std::make_pair("Y", ml_value_y));
|
||||
feeds.insert(std::make_pair("Z", ml_value_z));
|
||||
|
||||
RunAndVerifyOutputsWithEP(model_file_name, "CoreMLExecutionProviderTest.FunctionTest",
|
||||
RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(),
|
||||
MakeCoreMLExecutionProvider(),
|
||||
feeds);
|
||||
#else
|
||||
|
|
@ -118,9 +120,50 @@ TEST(CoreMLExecutionProviderTest, ArgMaxCastTest) {
|
|||
NameMLValMap feeds;
|
||||
feeds.insert(std::make_pair("X", ml_value_x));
|
||||
|
||||
RunAndVerifyOutputsWithEP(model_file_name, "CoreMLExecutionProviderTest.ArgMaxCastTest",
|
||||
EPVerificationParams verification_params{};
|
||||
verification_params.ep_node_assignment = ExpectedEPNodeAssignment::All;
|
||||
|
||||
RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(),
|
||||
MakeCoreMLExecutionProvider(),
|
||||
feeds);
|
||||
feeds,
|
||||
verification_params);
|
||||
#else
|
||||
TestModelLoad(model_file_name, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(CoreMLExecutionProviderTest, ArgMaxUnsupportedCastTest) {
|
||||
const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/coreml_argmax_unsupported_cast_test.onnx");
|
||||
|
||||
#if defined(__APPLE__)
|
||||
std::vector<int64_t> dims_mul_x = {3, 2, 2};
|
||||
std::vector<float> values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
|
||||
OrtValue ml_value_x;
|
||||
AllocatorPtr allocator = std::make_shared<CPUAllocator>();
|
||||
CreateMLValue<float>(allocator, dims_mul_x, values_mul_x, &ml_value_x);
|
||||
|
||||
NameMLValMap feeds;
|
||||
feeds.insert(std::make_pair("X", ml_value_x));
|
||||
|
||||
const std::function<void(const Graph&)> graph_verifier = [](const Graph& graph) {
|
||||
GraphViewer graph_viewer{graph};
|
||||
const auto& node_indices_in_order = graph_viewer.GetNodesInTopologicalOrder();
|
||||
ASSERT_EQ(node_indices_in_order.size(), size_t{2});
|
||||
// second node should be an unsupported Cast
|
||||
const auto* cast_node = graph.GetNode(node_indices_in_order[1]);
|
||||
ASSERT_NE(cast_node, nullptr);
|
||||
ASSERT_EQ(cast_node->OpType(), "Cast");
|
||||
ASSERT_EQ(cast_node->GetExecutionProviderType(), kCpuExecutionProvider);
|
||||
};
|
||||
|
||||
EPVerificationParams verification_params{};
|
||||
verification_params.ep_node_assignment = ExpectedEPNodeAssignment::Some;
|
||||
verification_params.graph_verifier = &graph_verifier;
|
||||
|
||||
RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(),
|
||||
MakeCoreMLExecutionProvider(),
|
||||
feeds,
|
||||
verification_params);
|
||||
#else
|
||||
TestModelLoad(model_file_name, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::Some);
|
||||
#endif
|
||||
|
|
@ -184,7 +227,7 @@ TEST(CoreMLExecutionProviderTest, TestOrtFormatModel) {
|
|||
NameMLValMap feeds;
|
||||
feeds.insert(std::make_pair("Input3", ml_value));
|
||||
|
||||
RunAndVerifyOutputsWithEP(model_file_name, "CoreMLExecutionProviderTest.TestOrtFormatModel",
|
||||
RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(),
|
||||
MakeCoreMLExecutionProvider(),
|
||||
feeds);
|
||||
#else
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
:ト
|
||||
|
||||
:ト
|
||||
F
|
||||
Xargmax_output_int64argmax"ArgMax*
|
||||
axis *
|
||||
|
|
@ -15,4 +16,4 @@ F
|
|||
|
||||
|
||||
|
||||
B
|
||||
B
|
||||
|
|
@ -1,16 +1,18 @@
|
|||
import onnx
|
||||
from onnx import TensorProto, helper
|
||||
|
||||
# CoreML EP currently handles a special case for supporting ArgMax op
|
||||
# Please see in <repo_root>/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc and
|
||||
# <repo_root>/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc
|
||||
# We have this separated test script to generate graph for the case: An ArgMax followed by a Cast to int32 type
|
||||
# CoreML EP currently handles a special case for supporting ArgMax followed by a Cast to int32.
|
||||
# Please see <repo_root>/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc and
|
||||
# <repo_root>/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc.
|
||||
# This script generates graphs for these cases:
|
||||
# - An ArgMax followed by a supported Cast to int32 type
|
||||
# - An ArgMax followed by an unsupported Cast to a type other than int32
|
||||
|
||||
|
||||
def GenerateModel(model_name): # noqa: N802
|
||||
def GenerateModel(model_name, cast_to_dtype): # noqa: N802
|
||||
nodes = [
|
||||
helper.make_node("ArgMax", ["X"], ["argmax_output_int64"], "argmax", axis=1, keepdims=1),
|
||||
helper.make_node("Cast", ["argmax_output_int64"], ["Y"], "cast", to=6), # cast to int32 type
|
||||
helper.make_node("Cast", ["argmax_output_int64"], ["Y"], "cast", to=cast_to_dtype),
|
||||
]
|
||||
|
||||
graph = helper.make_graph(
|
||||
|
|
@ -20,7 +22,7 @@ def GenerateModel(model_name): # noqa: N802
|
|||
helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2, 2]),
|
||||
],
|
||||
[ # output
|
||||
helper.make_tensor_value_info("Y", TensorProto.INT32, [3, 1, 2]),
|
||||
helper.make_tensor_value_info("Y", cast_to_dtype, [3, 1, 2]),
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -29,4 +31,5 @@ def GenerateModel(model_name): # noqa: N802
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
GenerateModel("coreml_argmax_cast_test.onnx")
|
||||
GenerateModel("coreml_argmax_cast_test.onnx", TensorProto.INT32)
|
||||
GenerateModel("coreml_argmax_unsupported_cast_test.onnx", TensorProto.UINT32)
|
||||
|
|
|
|||
19
onnxruntime/test/testdata/coreml_argmax_unsupported_cast_test.onnx
vendored
Normal file
19
onnxruntime/test/testdata/coreml_argmax_unsupported_cast_test.onnx
vendored
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
|
||||
:Ä
|
||||
F
|
||||
Xargmax_output_int64argmax"ArgMax*
|
||||
axis *
|
||||
keepdims
|
||||
/
|
||||
argmax_output_int64Ycast"Cast*
|
||||
to CoreML_ArgMax_Cast_TestZ
|
||||
X
|
||||
|
||||
|
||||
|
||||
b
|
||||
Y
|
||||
|
||||
|
||||
|
||||
B
|
||||
Loading…
Reference in a new issue