mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Update comments
This commit is contained in:
parent
eef36d2fbf
commit
720aca581a
2 changed files with 16 additions and 12 deletions
|
|
@ -772,7 +772,8 @@ Status ScanImpl::IterateSequence(std::vector<LoopStateVariable>& loop_state_vari
|
|||
|
||||
fetches.clear();
|
||||
|
||||
bool copy_fetch_to_iter = false;
|
||||
// one or more outputs have symbolic dimensions and need the first fetch to be copied to the OutputIterator
|
||||
bool have_symbolic_dim_in_output = false;
|
||||
|
||||
for (int output = 0, end = num_variadic_outputs_; output < end; ++output) {
|
||||
if (output < num_loop_state_variables_) {
|
||||
|
|
@ -784,12 +785,11 @@ Status ScanImpl::IterateSequence(std::vector<LoopStateVariable>& loop_state_vari
|
|||
auto& mlvalue = *iterator;
|
||||
fetches.push_back(mlvalue);
|
||||
|
||||
// If there is a dynamic shape in an output we need to copy it back to the OutputIterator
|
||||
// so it can setup the overall output and avoid copies for all other output values.
|
||||
// The mlvalue in the iterator will point to data once we have the overall output initialized.
|
||||
// Check current value as we don't want to unset copy_fetch_to_iter if it is true.
|
||||
if (!copy_fetch_to_iter)
|
||||
copy_fetch_to_iter = (seq_no == 0) && (mlvalue.IsAllocated() == false);
|
||||
// mlvalue.IsAllocated will be false when the OutputIterator is using a temporary MLValue
|
||||
// and not the overall output buffer.
|
||||
have_symbolic_dim_in_output = seq_no == 0 &&
|
||||
(mlvalue.IsAllocated() == false ||
|
||||
have_symbolic_dim_in_output); // don't unset
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -810,8 +810,9 @@ Status ScanImpl::IterateSequence(std::vector<LoopStateVariable>& loop_state_vari
|
|||
for (int output = num_loop_state_variables_; output < num_variadic_outputs_; ++output) {
|
||||
auto& iterator = *output_iterators_[output];
|
||||
|
||||
// copy the data from fetches to the iterator so it can setup the overall output
|
||||
if (copy_fetch_to_iter && (*iterator).IsAllocated() == false) {
|
||||
// copy data from the fetch to the iterator so it can setup the overall output when the iterator is incremented.
|
||||
// if the iterator is already using the overall output buffer IsAllocated() will be true and no copy is required.
|
||||
if (have_symbolic_dim_in_output && (*iterator).IsAllocated() == false) {
|
||||
*iterator = fetches[output];
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -665,6 +665,8 @@ TEST(Scan, MixedTypeInputs) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
// create a subgraph that will have unknown dimensions in both the loop state variable and output
|
||||
// after shape inferencing.
|
||||
TEST(Scan, UnknownDimInSubgraphOutput) {
|
||||
Model model("ScanBody");
|
||||
auto& graph = model.MainGraph();
|
||||
|
|
@ -702,12 +704,13 @@ TEST(Scan, UnknownDimInSubgraphOutput) {
|
|||
|
||||
test.AddAttribute("body", scan_body);
|
||||
test.AddAttribute<int64_t>("num_scan_inputs", 1);
|
||||
test.AddMissingOptionalInput<int64_t>();
|
||||
|
||||
// we add a symbolic dimension to bot the initial state and the scan input so we test the path that handles loop
|
||||
// state variables (prior to execution) and the path that handles subgraph outputs (post first execution).
|
||||
// we add a symbolic dimension to both the initial state and the scan input so we test
|
||||
// the path that handles loop state variables (OutputIterator::Initialize) and
|
||||
// the path that handles subgraph outputs (OutputIterator::MakeConcrete).
|
||||
// Note that we cross the values over in the subgraph, so the symbolic dimension in
|
||||
// initial_state_1 affects scan_out_1, and the symbolic dimension in scan_input_1 affects state_out_1.
|
||||
test.AddMissingOptionalInput<int64_t>();
|
||||
test.AddShapeToTensorData(true, 1); // add shape and symbolic dim in dim 1 for initial_state_1
|
||||
test.AddInput<float>("initial_state_1", state_shape, {0.0});
|
||||
test.AddShapeToTensorData(true, 2); // add shape and symbolic dim in dim 2 for scan_input_1
|
||||
|
|
|
|||
Loading…
Reference in a new issue