Handle case where the Loop 'M' and 'cond' inputs can be considered scalars but the rank doesn't match the subgraph. Use the subgraph rank when creating the MLValue instance for the subgraph input. (#1285)

This commit is contained in:
Scott McKay 2019-06-26 09:33:11 +10:00 committed by GitHub
parent c0cf2213bc
commit a462328d9d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 15 deletions

View file

@ -203,22 +203,31 @@ Status LoopImpl::Initialize() {
}
auto* max_trip_count_tensor = context_.Input<Tensor>(0);
auto iter_num_rank = max_trip_count_tensor ? max_trip_count_tensor->Shape().NumDimensions() : 0;
auto* cond_tensor = context_.Input<Tensor>(1);
auto condition_rank = cond_tensor ? cond_tensor->Shape().NumDimensions() : 0;
if (condition_rank >= 2 || iter_num_rank >= 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"'Loop' input 'M' and 'cond' should be a scalar tensor, but have ranks of ",
condition_rank, " and ", iter_num_rank);
if (max_trip_count_tensor) {
if (max_trip_count_tensor->Shape().Size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'Loop' input 'M' should be a scalar tensor. Got shape of ",
max_trip_count_tensor->Shape());
}
}
if (cond_tensor) {
if (cond_tensor->Shape().Size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'Loop' input 'cond' should be a scalar tensor. Got shape of ",
cond_tensor->Shape());
}
}
AllocatorPtr allocator;
status = context_.GetTempSpaceAllocator(&allocator);
ORT_RETURN_IF_ERROR(status);
condition_mlvalue_ = MakeScalarMLValue<bool>(allocator, condition_, condition_rank);
auto iter_num_rank = subgraph_inputs[0]->Shape()->dim_size();
auto condition_rank = subgraph_inputs[1]->Shape()->dim_size();
iter_num_mlvalue_ = MakeScalarMLValue<int64_t>(allocator, 0, iter_num_rank);
condition_mlvalue_ = MakeScalarMLValue<bool>(allocator, condition_, condition_rank);
subgraph_input_names_.reserve(num_subgraph_inputs_);
for (int i = 0; i < num_subgraph_inputs_; ++i) {

View file

@ -25,6 +25,8 @@ struct RunOptions {
bool mixed_execution_providers = false;
bool init_cond_1d_tensor = true;
bool init_iter_num_1d_tensor = true;
bool subgraph_cond_1d_tensor = true;
bool subgraph_iter_num_1d_tensor = true;
};
} // namespace
@ -91,8 +93,8 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options
bool use_null_typeproto = !include_types && !include_dim_value && !graph_output_shape_required;
bool is_init_cond_1d = options.init_cond_1d_tensor;
bool is_init_iter_num_1d = options.init_iter_num_1d_tensor;
bool is_cond_1d = options.subgraph_cond_1d_tensor;
bool is_iter_num_1d = options.subgraph_iter_num_1d_tensor;
Model model("Loop subgraph");
auto& graph = model.MainGraph();
@ -156,16 +158,16 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options
// graph inputs
auto& iter_num_in = graph.GetOrCreateNodeArg("iter_num_in",
is_init_iter_num_1d ? &int64_tensor_single_dim : &int64_scalar);
is_iter_num_1d ? &int64_tensor_single_dim : &int64_scalar);
auto& cond_in = graph.GetOrCreateNodeArg("cond_in",
is_init_cond_1d ? &bool_tensor_single_dim : &bool_scalar);
is_cond_1d ? &bool_tensor_single_dim : &bool_scalar);
auto& loop_var_0_in = graph.GetOrCreateNodeArg("loop_var_0_in", &float_tensor_single_dim);
auto& loop_var_1_in = graph.GetOrCreateNodeArg("loop_var_1_in", &float_tensor_single_dim);
auto& iter_num_float = graph.GetOrCreateNodeArg("iter_num_float",
is_init_iter_num_1d ? &float_tensor_single_dim : &float_scalar);
auto& iter_num_float_tensor = is_init_iter_num_1d ? iter_num_float
: graph.GetOrCreateNodeArg("iter_num_float_tensor", &float_tensor_single_dim);
is_iter_num_1d ? &float_tensor_single_dim : &float_scalar);
auto& iter_num_float_tensor = is_iter_num_1d ? iter_num_float
: graph.GetOrCreateNodeArg("iter_num_float_tensor", &float_tensor_single_dim);
// outer scope values. need type but not shape.
auto& outer_scope_0 = graph.GetOrCreateNodeArg("outer_scope_0", &float_tensor);
@ -206,7 +208,7 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options
}
// Unsqueeze iter_num_float, if initial iter_num is scalar.
if (!is_init_iter_num_1d) {
if (!is_iter_num_1d) {
auto& unsqueeze = graph.AddNode("iter_num_unsqueeze", "Unsqueeze",
"Unsqueeze iter_num_float to tensor of single dim",
{&iter_num_float}, {&iter_num_float_tensor});
@ -393,6 +395,8 @@ void ExitDueToCond(const RunOptions& options) {
\
options.init_iter_num_1d_tensor = iter_num_1d; \
options.init_cond_1d_tensor = cond_1d; \
options.subgraph_iter_num_1d_tensor = iter_num_1d; \
options.subgraph_cond_1d_tensor = cond_1d; \
\
ExitDueToCond(options); \
}
@ -406,6 +410,28 @@ TEST_EXIT_DUE_TO_COND(ExitDueToCond_DimsInSubGraph_ScalarIter, false, false, tru
TEST_EXIT_DUE_TO_COND(ExitDueToCond_DimsInSubGraph_ScalarCond, false, true, false);
TEST_EXIT_DUE_TO_COND(ExitDueToCond_DimsInSubGraph_ScalarBoth, false, false, false);
// check that a rank mismatch between the Loop 'M' and 'cond' inputs and the subgraph is handled gracefully
// if both equate to a scalar (rank 0 or rank 1 with shape of {1})
TEST(Loop, LoopSubgraphRankMismatch) {
RunOptions options{};
options.include_dim_values_in_main_graph = false;
options.include_dim_values_in_subgraph = false;
options.init_iter_num_1d_tensor = true;
options.init_cond_1d_tensor = true;
options.subgraph_cond_1d_tensor = false;
options.subgraph_cond_1d_tensor = false;
ExitDueToCond(options);
options.init_iter_num_1d_tensor = false;
options.init_cond_1d_tensor = false;
options.subgraph_cond_1d_tensor = true;
options.subgraph_cond_1d_tensor = true;
ExitDueToCond(options);
}
TEST(Loop, ExitDueToMaxIterations) {
int64_t max_iterations = 2;
const int64_t expected_num_iterations = 2;