Return better output shape for Loop with zero iterations (#1233)

* Attempt to provide the correct rank for an output from a Loop node when there are no iterations.

For a loop output (vs. loop carried dependency) the first dimension is the iteration count so will have a value of 0 and the output size will be zero. Use the rank of the matching subgraph output if available.

If the subgraph output rank is not available output a warning and use a rank 1 shape of {0}.
This commit is contained in:
Scott McKay 2019-06-19 07:31:13 +10:00 committed by GitHub
parent a4148c85a5
commit 6477d4e756
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 61 additions and 21 deletions

View file

@ -399,12 +399,33 @@ Status LoopImpl::Execute(FeedsFetchesManager* ffm, const FeedsFetchesManager* ca
copy_tensor_from_mlvalue_to_output(feeds[i + 2], i); // skip iter# and cond
}
// create empty outputs for loop outputs
TensorShape empty;
// create empty outputs for loop outputs using the subgraph output shapes for the rank
auto& graph_outputs = subgraph_.GetOutputs();
for (int i = num_loop_carried_vars_; i < num_outputs_; ++i) {
ORT_IGNORE_RETURN_VALUE(context_.Output(i, empty));
std::vector<int64_t> output_dims;
output_dims.push_back(0); // num iterations is first dim
// get shape from subgraph output if possible to attempt to have the correct rank
auto* graph_output = graph_outputs.at(i + 1); // + 1 as first subgraph output is condition value
auto* graph_output_shape = graph_output->Shape();
if (graph_output_shape) {
output_dims.reserve(graph_output_shape->dim_size() + 1);
auto dims = onnxruntime::utils::GetTensorShapeFromTensorShapeProto(*graph_output_shape);
std::copy(dims.cbegin(), dims.cend(), std::back_inserter(output_dims));
} else {
// TODO: We could try and call ExecuteGraph to get the output shape from fetches so the rank is correct,
// however that could still fail as we would potentially be passing in invalid data.
// Until we know this is required just output a warning and return the rank 1 empty output.
LOGS(context_.Logger(), WARNING) << "Loop had zero iterations and the shape of subgraph output " << i + 1
<< " was not found. Defaulting to a rank 1 shape of {0}.";
}
ORT_IGNORE_RETURN_VALUE(context_.Output(i, TensorShape(output_dims)));
}
}
return status;
}
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -26,7 +26,8 @@ struct RunOptions {
bool init_cond_1d_tensor = true;
bool init_iter_num_1d_tensor = true;
};
}
} // namespace
static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options);
static const float kOuterNodeAddValue = 3.f;
@ -103,7 +104,6 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options
Concats the iter_num to loop_var_1_in (test loop var that changes shape) so each iteration appends the iter_num
to loop_var_1
Loop output is the iter_num and sum for that iteration, so each iteration adds a pair to the overall output
Inputs require Identity nodes to fix their order.
Inputs: iter_num, cond_in, loop_var_in
@ -342,7 +342,7 @@ void RunTest(int64_t max_iterations,
}
test.AddInput<float>("loop_var_0_orig", {1}, {0.f});
test.AddInput<float>("loop_var_0_orig", {1}, {0.f});
test.AddInput<float>("loop_var_1_orig", {1}, {0.f});
test.AddOutput<float>("loop_var_0_final", {1}, {loop_var_0_final});
test.AddOutput<float>("loop_var_1_final", loop_var_1_final_shape, loop_var_1_final);
@ -357,7 +357,7 @@ void RunTest(int64_t max_iterations,
test.Run(expect_result, failure_message, {kTensorrtExecutionProvider}, nullptr, &execution_providers);
} else {
test.Run(expect_result, failure_message, {kTensorrtExecutionProvider});// Disable TensorRT because of unsupported data type INT64
test.Run(expect_result, failure_message, {kTensorrtExecutionProvider}); // Disable TensorRT because of unsupported data type INT64
}
}
@ -384,17 +384,17 @@ void ExitDueToCond(const RunOptions& options) {
options);
}
#define TEST_EXIT_DUE_TO_COND(name, dim_in_main_graph, iter_num_1d, cond_1d) \
TEST(Loop, name) { \
RunOptions options{}; \
options.include_dim_values_in_main_graph = dim_in_main_graph; \
options.include_dim_values_in_subgraph = !dim_in_main_graph; \
options.include_types_in_subgraph = false; \
\
options.init_iter_num_1d_tensor = iter_num_1d; \
options.init_cond_1d_tensor = cond_1d; \
\
ExitDueToCond(options); \
#define TEST_EXIT_DUE_TO_COND(name, dim_in_main_graph, iter_num_1d, cond_1d) \
TEST(Loop, name) { \
RunOptions options{}; \
options.include_dim_values_in_main_graph = dim_in_main_graph; \
options.include_dim_values_in_subgraph = !dim_in_main_graph; \
options.include_types_in_subgraph = false; \
\
options.init_iter_num_1d_tensor = iter_num_1d; \
options.init_cond_1d_tensor = cond_1d; \
\
ExitDueToCond(options); \
}
TEST_EXIT_DUE_TO_COND(ExitDueToCond_DimsInMainGraph, true, true, true);
@ -426,6 +426,25 @@ TEST(Loop, ExitDueToMaxIterations) {
{});
}
TEST(Loop, ZeroIterations) {
int64_t max_iterations = 0;
float loop_var_0_final = 0.f;
std::vector<int64_t> loop_var_1_final_shape{1};
std::vector<float> loop_var_1_final{0.f};
// zero iterations so first dim value is 0. also checking rank is correct.
std::vector<int64_t> loop_out_0_final_shape{0, 0};
std::vector<float> loop_out_0_final{};
RunTest(max_iterations,
loop_var_0_final,
loop_var_1_final_shape, loop_var_1_final,
loop_out_0_final_shape, loop_out_0_final,
{});
}
TEST(Loop, InfiniteLoopTermination) {
auto create_subgraph = [](const RunOptions&) {
Model model("Infinite Loop subgraph");
@ -518,8 +537,8 @@ TEST(Loop, InfiniteLoopTermination) {
std::future<void> terminator_result = task.get_future();
std::thread terminator_thread{std::move(task)};
test.Run(OpTester::ExpectResult::kExpectFailure, "Exiting due to terminate flag being set to true", {kTensorrtExecutionProvider},
&session_run_options);// Disable TensorRT on unsupported data type BOOL
test.Run(OpTester::ExpectResult::kExpectFailure, "Exiting due to terminate flag being set to true",
{kTensorrtExecutionProvider}, &session_run_options); // Disable TensorRT on unsupported data type BOOL
// call get to propagate any exception
terminator_result.get();