diff --git a/onnxruntime/core/providers/cpu/controlflow/loop.cc b/onnxruntime/core/providers/cpu/controlflow/loop.cc index ff29daec06..84d911fb0e 100644 --- a/onnxruntime/core/providers/cpu/controlflow/loop.cc +++ b/onnxruntime/core/providers/cpu/controlflow/loop.cc @@ -203,22 +203,31 @@ Status LoopImpl::Initialize() { } auto* max_trip_count_tensor = context_.Input(0); - auto iter_num_rank = max_trip_count_tensor ? max_trip_count_tensor->Shape().NumDimensions() : 0; auto* cond_tensor = context_.Input(1); - auto condition_rank = cond_tensor ? cond_tensor->Shape().NumDimensions() : 0; - if (condition_rank >= 2 || iter_num_rank >= 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "'Loop' input 'M' and 'cond' should be a scalar tensor, but have ranks of ", - condition_rank, " and ", iter_num_rank); + if (max_trip_count_tensor) { + if (max_trip_count_tensor->Shape().Size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'Loop' input 'M' should be a scalar tensor. Got shape of ", + max_trip_count_tensor->Shape()); + } + } + + if (cond_tensor) { + if (cond_tensor->Shape().Size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'Loop' input 'cond' should be a scalar tensor. Got shape of ", + cond_tensor->Shape()); + } } AllocatorPtr allocator; status = context_.GetTempSpaceAllocator(&allocator); ORT_RETURN_IF_ERROR(status); - condition_mlvalue_ = MakeScalarMLValue(allocator, condition_, condition_rank); + auto iter_num_rank = subgraph_inputs[0]->Shape()->dim_size(); + auto condition_rank = subgraph_inputs[1]->Shape()->dim_size(); + iter_num_mlvalue_ = MakeScalarMLValue(allocator, 0, iter_num_rank); + condition_mlvalue_ = MakeScalarMLValue(allocator, condition_, condition_rank); subgraph_input_names_.reserve(num_subgraph_inputs_); for (int i = 0; i < num_subgraph_inputs_; ++i) { diff --git a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc index 6f7c93ac2b..d4293450f1 100644 --- a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc @@ -25,6 +25,8 @@ struct RunOptions { bool mixed_execution_providers = false; bool init_cond_1d_tensor = true; bool init_iter_num_1d_tensor = true; + bool subgraph_cond_1d_tensor = true; + bool subgraph_iter_num_1d_tensor = true; }; } // namespace @@ -91,8 +93,8 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options bool use_null_typeproto = !include_types && !include_dim_value && !graph_output_shape_required; - bool is_init_cond_1d = options.init_cond_1d_tensor; - bool is_init_iter_num_1d = options.init_iter_num_1d_tensor; + bool is_cond_1d = options.subgraph_cond_1d_tensor; + bool is_iter_num_1d = options.subgraph_iter_num_1d_tensor; Model model("Loop subgraph"); auto& graph = model.MainGraph(); @@ -156,16 +158,16 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options // graph inputs auto& iter_num_in = graph.GetOrCreateNodeArg("iter_num_in", - is_init_iter_num_1d ? &int64_tensor_single_dim : &int64_scalar); + is_iter_num_1d ? &int64_tensor_single_dim : &int64_scalar); auto& cond_in = graph.GetOrCreateNodeArg("cond_in", - is_init_cond_1d ? &bool_tensor_single_dim : &bool_scalar); + is_cond_1d ? &bool_tensor_single_dim : &bool_scalar); auto& loop_var_0_in = graph.GetOrCreateNodeArg("loop_var_0_in", &float_tensor_single_dim); auto& loop_var_1_in = graph.GetOrCreateNodeArg("loop_var_1_in", &float_tensor_single_dim); auto& iter_num_float = graph.GetOrCreateNodeArg("iter_num_float", - is_init_iter_num_1d ? &float_tensor_single_dim : &float_scalar); - auto& iter_num_float_tensor = is_init_iter_num_1d ? iter_num_float - : graph.GetOrCreateNodeArg("iter_num_float_tensor", &float_tensor_single_dim); + is_iter_num_1d ? &float_tensor_single_dim : &float_scalar); + auto& iter_num_float_tensor = is_iter_num_1d ? iter_num_float + : graph.GetOrCreateNodeArg("iter_num_float_tensor", &float_tensor_single_dim); // outer scope values. need type but not shape. auto& outer_scope_0 = graph.GetOrCreateNodeArg("outer_scope_0", &float_tensor); @@ -206,7 +208,7 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options } // Unsqueeze iter_num_float, if initial iter_num is scalar. - if (!is_init_iter_num_1d) { + if (!is_iter_num_1d) { auto& unsqueeze = graph.AddNode("iter_num_unsqueeze", "Unsqueeze", "Unsqueeze iter_num_float to tensor of single dim", {&iter_num_float}, {&iter_num_float_tensor}); @@ -393,6 +395,8 @@ void ExitDueToCond(const RunOptions& options) { \ options.init_iter_num_1d_tensor = iter_num_1d; \ options.init_cond_1d_tensor = cond_1d; \ + options.subgraph_iter_num_1d_tensor = iter_num_1d; \ + options.subgraph_cond_1d_tensor = cond_1d; \ \ ExitDueToCond(options); \ } @@ -406,6 +410,28 @@ TEST_EXIT_DUE_TO_COND(ExitDueToCond_DimsInSubGraph_ScalarIter, false, false, tru TEST_EXIT_DUE_TO_COND(ExitDueToCond_DimsInSubGraph_ScalarCond, false, true, false); TEST_EXIT_DUE_TO_COND(ExitDueToCond_DimsInSubGraph_ScalarBoth, false, false, false); +// check that a rank mismatch between the Loop 'M' and 'cond' inputs and the subgraph is handled gracefully +// if both equate to a scalar (rank 0 or rank 1 with shape of {1}) +TEST(Loop, LoopSubgraphRankMismatch) { + RunOptions options{}; + options.include_dim_values_in_main_graph = false; + options.include_dim_values_in_subgraph = false; + + options.init_iter_num_1d_tensor = true; + options.init_cond_1d_tensor = true; + options.subgraph_cond_1d_tensor = false; + options.subgraph_cond_1d_tensor = false; + + ExitDueToCond(options); + + options.init_iter_num_1d_tensor = false; + options.init_cond_1d_tensor = false; + options.subgraph_cond_1d_tensor = true; + options.subgraph_cond_1d_tensor = true; + + ExitDueToCond(options); +} + TEST(Loop, ExitDueToMaxIterations) { int64_t max_iterations = 2; const int64_t expected_num_iterations = 2;