From a462328d9db70afac3eafffa4038d871b4aa1847 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 26 Jun 2019 09:33:11 +1000 Subject: [PATCH] Handle case where the Loop 'M' and 'cond' inputs can be considered scalars but the rank doesn't match the subgraph. Use the subgraph rank when creating the MLValue instance for the subgraph input. (#1285) --- .../core/providers/cpu/controlflow/loop.cc | 23 ++++++---- .../providers/cpu/controlflow/loop_test.cc | 42 +++++++++++++++---- 2 files changed, 50 insertions(+), 15 deletions(-) diff --git a/onnxruntime/core/providers/cpu/controlflow/loop.cc b/onnxruntime/core/providers/cpu/controlflow/loop.cc index ff29daec06..84d911fb0e 100644 --- a/onnxruntime/core/providers/cpu/controlflow/loop.cc +++ b/onnxruntime/core/providers/cpu/controlflow/loop.cc @@ -203,22 +203,31 @@ Status LoopImpl::Initialize() { } auto* max_trip_count_tensor = context_.Input(0); - auto iter_num_rank = max_trip_count_tensor ? max_trip_count_tensor->Shape().NumDimensions() : 0; auto* cond_tensor = context_.Input(1); - auto condition_rank = cond_tensor ? cond_tensor->Shape().NumDimensions() : 0; - if (condition_rank >= 2 || iter_num_rank >= 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "'Loop' input 'M' and 'cond' should be a scalar tensor, but have ranks of ", - condition_rank, " and ", iter_num_rank); + if (max_trip_count_tensor) { + if (max_trip_count_tensor->Shape().Size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'Loop' input 'M' should be a scalar tensor. Got shape of ", + max_trip_count_tensor->Shape()); + } + } + + if (cond_tensor) { + if (cond_tensor->Shape().Size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'Loop' input 'cond' should be a scalar tensor. Got shape of ", + cond_tensor->Shape()); + } } AllocatorPtr allocator; status = context_.GetTempSpaceAllocator(&allocator); ORT_RETURN_IF_ERROR(status); - condition_mlvalue_ = MakeScalarMLValue(allocator, condition_, condition_rank); + auto iter_num_rank = subgraph_inputs[0]->Shape()->dim_size(); + auto condition_rank = subgraph_inputs[1]->Shape()->dim_size(); + iter_num_mlvalue_ = MakeScalarMLValue(allocator, 0, iter_num_rank); + condition_mlvalue_ = MakeScalarMLValue(allocator, condition_, condition_rank); subgraph_input_names_.reserve(num_subgraph_inputs_); for (int i = 0; i < num_subgraph_inputs_; ++i) { diff --git a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc index 6f7c93ac2b..d4293450f1 100644 --- a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc @@ -25,6 +25,8 @@ struct RunOptions { bool mixed_execution_providers = false; bool init_cond_1d_tensor = true; bool init_iter_num_1d_tensor = true; + bool subgraph_cond_1d_tensor = true; + bool subgraph_iter_num_1d_tensor = true; }; } // namespace @@ -91,8 +93,8 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options bool use_null_typeproto = !include_types && !include_dim_value && !graph_output_shape_required; - bool is_init_cond_1d = options.init_cond_1d_tensor; - bool is_init_iter_num_1d = options.init_iter_num_1d_tensor; + bool is_cond_1d = options.subgraph_cond_1d_tensor; + bool is_iter_num_1d = options.subgraph_iter_num_1d_tensor; Model model("Loop subgraph"); auto& graph = model.MainGraph(); @@ -156,16 +158,16 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options // graph inputs auto& iter_num_in = graph.GetOrCreateNodeArg("iter_num_in", - is_init_iter_num_1d ? &int64_tensor_single_dim : &int64_scalar); + is_iter_num_1d ? &int64_tensor_single_dim : &int64_scalar); auto& cond_in = graph.GetOrCreateNodeArg("cond_in", - is_init_cond_1d ? &bool_tensor_single_dim : &bool_scalar); + is_cond_1d ? &bool_tensor_single_dim : &bool_scalar); auto& loop_var_0_in = graph.GetOrCreateNodeArg("loop_var_0_in", &float_tensor_single_dim); auto& loop_var_1_in = graph.GetOrCreateNodeArg("loop_var_1_in", &float_tensor_single_dim); auto& iter_num_float = graph.GetOrCreateNodeArg("iter_num_float", - is_init_iter_num_1d ? &float_tensor_single_dim : &float_scalar); - auto& iter_num_float_tensor = is_init_iter_num_1d ? iter_num_float - : graph.GetOrCreateNodeArg("iter_num_float_tensor", &float_tensor_single_dim); + is_iter_num_1d ? &float_tensor_single_dim : &float_scalar); + auto& iter_num_float_tensor = is_iter_num_1d ? iter_num_float + : graph.GetOrCreateNodeArg("iter_num_float_tensor", &float_tensor_single_dim); // outer scope values. need type but not shape. auto& outer_scope_0 = graph.GetOrCreateNodeArg("outer_scope_0", &float_tensor); @@ -206,7 +208,7 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options } // Unsqueeze iter_num_float, if initial iter_num is scalar. - if (!is_init_iter_num_1d) { + if (!is_iter_num_1d) { auto& unsqueeze = graph.AddNode("iter_num_unsqueeze", "Unsqueeze", "Unsqueeze iter_num_float to tensor of single dim", {&iter_num_float}, {&iter_num_float_tensor}); @@ -393,6 +395,8 @@ void ExitDueToCond(const RunOptions& options) { \ options.init_iter_num_1d_tensor = iter_num_1d; \ options.init_cond_1d_tensor = cond_1d; \ + options.subgraph_iter_num_1d_tensor = iter_num_1d; \ + options.subgraph_cond_1d_tensor = cond_1d; \ \ ExitDueToCond(options); \ } @@ -406,6 +410,28 @@ TEST_EXIT_DUE_TO_COND(ExitDueToCond_DimsInSubGraph_ScalarIter, false, false, tru TEST_EXIT_DUE_TO_COND(ExitDueToCond_DimsInSubGraph_ScalarCond, false, true, false); TEST_EXIT_DUE_TO_COND(ExitDueToCond_DimsInSubGraph_ScalarBoth, false, false, false); +// check that a rank mismatch between the Loop 'M' and 'cond' inputs and the subgraph is handled gracefully +// if both equate to a scalar (rank 0 or rank 1 with shape of {1}) +TEST(Loop, LoopSubgraphRankMismatch) { + RunOptions options{}; + options.include_dim_values_in_main_graph = false; + options.include_dim_values_in_subgraph = false; + + options.init_iter_num_1d_tensor = true; + options.init_cond_1d_tensor = true; + options.subgraph_cond_1d_tensor = false; + options.subgraph_cond_1d_tensor = false; + + ExitDueToCond(options); + + options.init_iter_num_1d_tensor = false; + options.init_cond_1d_tensor = false; + options.subgraph_cond_1d_tensor = true; + options.subgraph_cond_1d_tensor = true; + + ExitDueToCond(options); +} + TEST(Loop, ExitDueToMaxIterations) { int64_t max_iterations = 2; const int64_t expected_num_iterations = 2;