mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Return better output shape for Loop with zero iterations (#1233)
* Attempt to provide the correct rank for an output from a Loop node when there are no iterations.
For a loop output (vs. loop carried dependency) the first dimension is the iteration count so will have a value of 0 and the output size will be zero. Use the rank of the matching subgraph output if available.
If the subgraph output rank is not available output a warning and use a rank 1 shape of {0}.
This commit is contained in:
parent
a4148c85a5
commit
6477d4e756
2 changed files with 61 additions and 21 deletions
|
|
@ -399,12 +399,33 @@ Status LoopImpl::Execute(FeedsFetchesManager* ffm, const FeedsFetchesManager* ca
|
|||
copy_tensor_from_mlvalue_to_output(feeds[i + 2], i); // skip iter# and cond
|
||||
}
|
||||
|
||||
// create empty outputs for loop outputs
|
||||
TensorShape empty;
|
||||
// create empty outputs for loop outputs using the subgraph output shapes for the rank
|
||||
auto& graph_outputs = subgraph_.GetOutputs();
|
||||
|
||||
for (int i = num_loop_carried_vars_; i < num_outputs_; ++i) {
|
||||
ORT_IGNORE_RETURN_VALUE(context_.Output(i, empty));
|
||||
std::vector<int64_t> output_dims;
|
||||
output_dims.push_back(0); // num iterations is first dim
|
||||
|
||||
// get shape from subgraph output if possible to attempt to have the correct rank
|
||||
auto* graph_output = graph_outputs.at(i + 1); // + 1 as first subgraph output is condition value
|
||||
auto* graph_output_shape = graph_output->Shape();
|
||||
|
||||
if (graph_output_shape) {
|
||||
output_dims.reserve(graph_output_shape->dim_size() + 1);
|
||||
|
||||
auto dims = onnxruntime::utils::GetTensorShapeFromTensorShapeProto(*graph_output_shape);
|
||||
std::copy(dims.cbegin(), dims.cend(), std::back_inserter(output_dims));
|
||||
} else {
|
||||
// TODO: We could try and call ExecuteGraph to get the output shape from fetches so the rank is correct,
|
||||
// however that could still fail as we would potentially be passing in invalid data.
|
||||
// Until we know this is required just output a warning and return the rank 1 empty output.
|
||||
LOGS(context_.Logger(), WARNING) << "Loop had zero iterations and the shape of subgraph output " << i + 1
|
||||
<< " was not found. Defaulting to a rank 1 shape of {0}.";
|
||||
}
|
||||
|
||||
ORT_IGNORE_RETURN_VALUE(context_.Output(i, TensorShape(output_dims)));
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -26,7 +26,8 @@ struct RunOptions {
|
|||
bool init_cond_1d_tensor = true;
|
||||
bool init_iter_num_1d_tensor = true;
|
||||
};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options);
|
||||
|
||||
static const float kOuterNodeAddValue = 3.f;
|
||||
|
|
@ -103,7 +104,6 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options
|
|||
Concats the iter_num to loop_var_1_in (test loop var that changes shape) so each iteration appends the iter_num
|
||||
to loop_var_1
|
||||
Loop output is the iter_num and sum for that iteration, so each iteration adds a pair to the overall output
|
||||
Inputs require Identity nodes to fix their order.
|
||||
|
||||
Inputs: iter_num, cond_in, loop_var_in
|
||||
|
||||
|
|
@ -342,7 +342,7 @@ void RunTest(int64_t max_iterations,
|
|||
}
|
||||
|
||||
test.AddInput<float>("loop_var_0_orig", {1}, {0.f});
|
||||
test.AddInput<float>("loop_var_0_orig", {1}, {0.f});
|
||||
test.AddInput<float>("loop_var_1_orig", {1}, {0.f});
|
||||
|
||||
test.AddOutput<float>("loop_var_0_final", {1}, {loop_var_0_final});
|
||||
test.AddOutput<float>("loop_var_1_final", loop_var_1_final_shape, loop_var_1_final);
|
||||
|
|
@ -357,7 +357,7 @@ void RunTest(int64_t max_iterations,
|
|||
|
||||
test.Run(expect_result, failure_message, {kTensorrtExecutionProvider}, nullptr, &execution_providers);
|
||||
} else {
|
||||
test.Run(expect_result, failure_message, {kTensorrtExecutionProvider});// Disable TensorRT because of unsupported data type INT64
|
||||
test.Run(expect_result, failure_message, {kTensorrtExecutionProvider}); // Disable TensorRT because of unsupported data type INT64
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -384,17 +384,17 @@ void ExitDueToCond(const RunOptions& options) {
|
|||
options);
|
||||
}
|
||||
|
||||
#define TEST_EXIT_DUE_TO_COND(name, dim_in_main_graph, iter_num_1d, cond_1d) \
|
||||
TEST(Loop, name) { \
|
||||
RunOptions options{}; \
|
||||
options.include_dim_values_in_main_graph = dim_in_main_graph; \
|
||||
options.include_dim_values_in_subgraph = !dim_in_main_graph; \
|
||||
options.include_types_in_subgraph = false; \
|
||||
\
|
||||
options.init_iter_num_1d_tensor = iter_num_1d; \
|
||||
options.init_cond_1d_tensor = cond_1d; \
|
||||
\
|
||||
ExitDueToCond(options); \
|
||||
#define TEST_EXIT_DUE_TO_COND(name, dim_in_main_graph, iter_num_1d, cond_1d) \
|
||||
TEST(Loop, name) { \
|
||||
RunOptions options{}; \
|
||||
options.include_dim_values_in_main_graph = dim_in_main_graph; \
|
||||
options.include_dim_values_in_subgraph = !dim_in_main_graph; \
|
||||
options.include_types_in_subgraph = false; \
|
||||
\
|
||||
options.init_iter_num_1d_tensor = iter_num_1d; \
|
||||
options.init_cond_1d_tensor = cond_1d; \
|
||||
\
|
||||
ExitDueToCond(options); \
|
||||
}
|
||||
|
||||
TEST_EXIT_DUE_TO_COND(ExitDueToCond_DimsInMainGraph, true, true, true);
|
||||
|
|
@ -426,6 +426,25 @@ TEST(Loop, ExitDueToMaxIterations) {
|
|||
{});
|
||||
}
|
||||
|
||||
TEST(Loop, ZeroIterations) {
|
||||
int64_t max_iterations = 0;
|
||||
|
||||
float loop_var_0_final = 0.f;
|
||||
|
||||
std::vector<int64_t> loop_var_1_final_shape{1};
|
||||
std::vector<float> loop_var_1_final{0.f};
|
||||
|
||||
// zero iterations so first dim value is 0. also checking rank is correct.
|
||||
std::vector<int64_t> loop_out_0_final_shape{0, 0};
|
||||
std::vector<float> loop_out_0_final{};
|
||||
|
||||
RunTest(max_iterations,
|
||||
loop_var_0_final,
|
||||
loop_var_1_final_shape, loop_var_1_final,
|
||||
loop_out_0_final_shape, loop_out_0_final,
|
||||
{});
|
||||
}
|
||||
|
||||
TEST(Loop, InfiniteLoopTermination) {
|
||||
auto create_subgraph = [](const RunOptions&) {
|
||||
Model model("Infinite Loop subgraph");
|
||||
|
|
@ -518,8 +537,8 @@ TEST(Loop, InfiniteLoopTermination) {
|
|||
std::future<void> terminator_result = task.get_future();
|
||||
std::thread terminator_thread{std::move(task)};
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectFailure, "Exiting due to terminate flag being set to true", {kTensorrtExecutionProvider},
|
||||
&session_run_options);// Disable TensorRT on unsupported data type BOOL
|
||||
test.Run(OpTester::ExpectResult::kExpectFailure, "Exiting due to terminate flag being set to true",
|
||||
{kTensorrtExecutionProvider}, &session_run_options); // Disable TensorRT on unsupported data type BOOL
|
||||
|
||||
// call get to propagate any exception
|
||||
terminator_result.get();
|
||||
|
|
|
|||
Loading…
Reference in a new issue