mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Handle case where the Loop 'M' and 'cond' inputs can be considered scalars but the rank doesn't match the subgraph. Use the subgraph rank when creating the MLValue instance for the subgraph input. (#1285)
This commit is contained in:
parent
c0cf2213bc
commit
a462328d9d
2 changed files with 50 additions and 15 deletions
|
|
@ -203,22 +203,31 @@ Status LoopImpl::Initialize() {
|
|||
}
|
||||
|
||||
auto* max_trip_count_tensor = context_.Input<Tensor>(0);
|
||||
auto iter_num_rank = max_trip_count_tensor ? max_trip_count_tensor->Shape().NumDimensions() : 0;
|
||||
auto* cond_tensor = context_.Input<Tensor>(1);
|
||||
auto condition_rank = cond_tensor ? cond_tensor->Shape().NumDimensions() : 0;
|
||||
|
||||
if (condition_rank >= 2 || iter_num_rank >= 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
|
||||
"'Loop' input 'M' and 'cond' should be a scalar tensor, but have ranks of ",
|
||||
condition_rank, " and ", iter_num_rank);
|
||||
if (max_trip_count_tensor) {
|
||||
if (max_trip_count_tensor->Shape().Size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'Loop' input 'M' should be a scalar tensor. Got shape of ",
|
||||
max_trip_count_tensor->Shape());
|
||||
}
|
||||
}
|
||||
|
||||
if (cond_tensor) {
|
||||
if (cond_tensor->Shape().Size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'Loop' input 'cond' should be a scalar tensor. Got shape of ",
|
||||
cond_tensor->Shape());
|
||||
}
|
||||
}
|
||||
|
||||
AllocatorPtr allocator;
|
||||
status = context_.GetTempSpaceAllocator(&allocator);
|
||||
ORT_RETURN_IF_ERROR(status);
|
||||
|
||||
condition_mlvalue_ = MakeScalarMLValue<bool>(allocator, condition_, condition_rank);
|
||||
auto iter_num_rank = subgraph_inputs[0]->Shape()->dim_size();
|
||||
auto condition_rank = subgraph_inputs[1]->Shape()->dim_size();
|
||||
|
||||
iter_num_mlvalue_ = MakeScalarMLValue<int64_t>(allocator, 0, iter_num_rank);
|
||||
condition_mlvalue_ = MakeScalarMLValue<bool>(allocator, condition_, condition_rank);
|
||||
|
||||
subgraph_input_names_.reserve(num_subgraph_inputs_);
|
||||
for (int i = 0; i < num_subgraph_inputs_; ++i) {
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ struct RunOptions {
|
|||
bool mixed_execution_providers = false;
|
||||
bool init_cond_1d_tensor = true;
|
||||
bool init_iter_num_1d_tensor = true;
|
||||
bool subgraph_cond_1d_tensor = true;
|
||||
bool subgraph_iter_num_1d_tensor = true;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
|
@ -91,8 +93,8 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options
|
|||
|
||||
bool use_null_typeproto = !include_types && !include_dim_value && !graph_output_shape_required;
|
||||
|
||||
bool is_init_cond_1d = options.init_cond_1d_tensor;
|
||||
bool is_init_iter_num_1d = options.init_iter_num_1d_tensor;
|
||||
bool is_cond_1d = options.subgraph_cond_1d_tensor;
|
||||
bool is_iter_num_1d = options.subgraph_iter_num_1d_tensor;
|
||||
|
||||
Model model("Loop subgraph");
|
||||
auto& graph = model.MainGraph();
|
||||
|
|
@ -156,16 +158,16 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options
|
|||
|
||||
// graph inputs
|
||||
auto& iter_num_in = graph.GetOrCreateNodeArg("iter_num_in",
|
||||
is_init_iter_num_1d ? &int64_tensor_single_dim : &int64_scalar);
|
||||
is_iter_num_1d ? &int64_tensor_single_dim : &int64_scalar);
|
||||
auto& cond_in = graph.GetOrCreateNodeArg("cond_in",
|
||||
is_init_cond_1d ? &bool_tensor_single_dim : &bool_scalar);
|
||||
is_cond_1d ? &bool_tensor_single_dim : &bool_scalar);
|
||||
auto& loop_var_0_in = graph.GetOrCreateNodeArg("loop_var_0_in", &float_tensor_single_dim);
|
||||
auto& loop_var_1_in = graph.GetOrCreateNodeArg("loop_var_1_in", &float_tensor_single_dim);
|
||||
|
||||
auto& iter_num_float = graph.GetOrCreateNodeArg("iter_num_float",
|
||||
is_init_iter_num_1d ? &float_tensor_single_dim : &float_scalar);
|
||||
auto& iter_num_float_tensor = is_init_iter_num_1d ? iter_num_float
|
||||
: graph.GetOrCreateNodeArg("iter_num_float_tensor", &float_tensor_single_dim);
|
||||
is_iter_num_1d ? &float_tensor_single_dim : &float_scalar);
|
||||
auto& iter_num_float_tensor = is_iter_num_1d ? iter_num_float
|
||||
: graph.GetOrCreateNodeArg("iter_num_float_tensor", &float_tensor_single_dim);
|
||||
|
||||
// outer scope values. need type but not shape.
|
||||
auto& outer_scope_0 = graph.GetOrCreateNodeArg("outer_scope_0", &float_tensor);
|
||||
|
|
@ -206,7 +208,7 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options
|
|||
}
|
||||
|
||||
// Unsqueeze iter_num_float, if initial iter_num is scalar.
|
||||
if (!is_init_iter_num_1d) {
|
||||
if (!is_iter_num_1d) {
|
||||
auto& unsqueeze = graph.AddNode("iter_num_unsqueeze", "Unsqueeze",
|
||||
"Unsqueeze iter_num_float to tensor of single dim",
|
||||
{&iter_num_float}, {&iter_num_float_tensor});
|
||||
|
|
@ -393,6 +395,8 @@ void ExitDueToCond(const RunOptions& options) {
|
|||
\
|
||||
options.init_iter_num_1d_tensor = iter_num_1d; \
|
||||
options.init_cond_1d_tensor = cond_1d; \
|
||||
options.subgraph_iter_num_1d_tensor = iter_num_1d; \
|
||||
options.subgraph_cond_1d_tensor = cond_1d; \
|
||||
\
|
||||
ExitDueToCond(options); \
|
||||
}
|
||||
|
|
@ -406,6 +410,28 @@ TEST_EXIT_DUE_TO_COND(ExitDueToCond_DimsInSubGraph_ScalarIter, false, false, tru
|
|||
TEST_EXIT_DUE_TO_COND(ExitDueToCond_DimsInSubGraph_ScalarCond, false, true, false);
|
||||
TEST_EXIT_DUE_TO_COND(ExitDueToCond_DimsInSubGraph_ScalarBoth, false, false, false);
|
||||
|
||||
// check that a rank mismatch between the Loop 'M' and 'cond' inputs and the subgraph is handled gracefully
|
||||
// if both equate to a scalar (rank 0 or rank 1 with shape of {1})
|
||||
TEST(Loop, LoopSubgraphRankMismatch) {
|
||||
RunOptions options{};
|
||||
options.include_dim_values_in_main_graph = false;
|
||||
options.include_dim_values_in_subgraph = false;
|
||||
|
||||
options.init_iter_num_1d_tensor = true;
|
||||
options.init_cond_1d_tensor = true;
|
||||
options.subgraph_cond_1d_tensor = false;
|
||||
options.subgraph_cond_1d_tensor = false;
|
||||
|
||||
ExitDueToCond(options);
|
||||
|
||||
options.init_iter_num_1d_tensor = false;
|
||||
options.init_cond_1d_tensor = false;
|
||||
options.subgraph_cond_1d_tensor = true;
|
||||
options.subgraph_cond_1d_tensor = true;
|
||||
|
||||
ExitDueToCond(options);
|
||||
}
|
||||
|
||||
TEST(Loop, ExitDueToMaxIterations) {
|
||||
int64_t max_iterations = 2;
|
||||
const int64_t expected_num_iterations = 2;
|
||||
|
|
|
|||
Loading…
Reference in a new issue