Update comments

This commit is contained in:
Scott McKay 2018-11-26 08:22:28 +10:00
parent eef36d2fbf
commit 720aca581a
2 changed files with 16 additions and 12 deletions

View file

@ -772,7 +772,8 @@ Status ScanImpl::IterateSequence(std::vector<LoopStateVariable>& loop_state_vari
fetches.clear();
bool copy_fetch_to_iter = false;
// one or more outputs have symbolic dimensions and need the first fetch to be copied to the OutputIterator
bool have_symbolic_dim_in_output = false;
for (int output = 0, end = num_variadic_outputs_; output < end; ++output) {
if (output < num_loop_state_variables_) {
@ -784,12 +785,11 @@ Status ScanImpl::IterateSequence(std::vector<LoopStateVariable>& loop_state_vari
auto& mlvalue = *iterator;
fetches.push_back(mlvalue);
// If there is a dynamic shape in an output we need to copy it back to the OutputIterator
// so it can setup the overall output and avoid copies for all other output values.
// The mlvalue in the iterator will point to data once we have the overall output initialized.
// Check current value as we don't want to unset copy_fetch_to_iter if it is true.
if (!copy_fetch_to_iter)
copy_fetch_to_iter = (seq_no == 0) && (mlvalue.IsAllocated() == false);
// mlvalue.IsAllocated will be false when the OutputIterator is using a temporary MLValue
// and not the overall output buffer.
have_symbolic_dim_in_output = seq_no == 0 &&
(mlvalue.IsAllocated() == false ||
have_symbolic_dim_in_output); // don't unset
}
}
@ -810,8 +810,9 @@ Status ScanImpl::IterateSequence(std::vector<LoopStateVariable>& loop_state_vari
for (int output = num_loop_state_variables_; output < num_variadic_outputs_; ++output) {
auto& iterator = *output_iterators_[output];
// copy the data from fetches to the iterator so it can setup the overall output
if (copy_fetch_to_iter && (*iterator).IsAllocated() == false) {
// copy data from the fetch to the iterator so it can setup the overall output when the iterator is incremented.
// if the iterator is already using the overall output buffer IsAllocated() will be true and no copy is required.
if (have_symbolic_dim_in_output && (*iterator).IsAllocated() == false) {
*iterator = fetches[output];
}

View file

@ -665,6 +665,8 @@ TEST(Scan, MixedTypeInputs) {
test.Run();
}
// create a subgraph that will have unknown dimensions in both the loop state variable and output
// after shape inferencing.
TEST(Scan, UnknownDimInSubgraphOutput) {
Model model("ScanBody");
auto& graph = model.MainGraph();
@ -702,12 +704,13 @@ TEST(Scan, UnknownDimInSubgraphOutput) {
test.AddAttribute("body", scan_body);
test.AddAttribute<int64_t>("num_scan_inputs", 1);
test.AddMissingOptionalInput<int64_t>();
// we add a symbolic dimension to bot the initial state and the scan input so we test the path that handles loop
// state variables (prior to execution) and the path that handles subgraph outputs (post first execution).
// we add a symbolic dimension to both the initial state and the scan input so we test
// the path that handles loop state variables (OutputIterator::Initialize) and
// the path that handles subgraph outputs (OutputIterator::MakeConcrete).
// Note that we cross the values over in the subgraph, so the symbolic dimension in
// initial_state_1 affects scan_out_1, and the symbolic dimension in scan_input_1 affects state_out_1.
test.AddMissingOptionalInput<int64_t>();
test.AddShapeToTensorData(true, 1); // add shape and symbolic dim in dim 1 for initial_state_1
test.AddInput<float>("initial_state_1", state_shape, {0.0});
test.AddShapeToTensorData(true, 2); // add shape and symbolic dim in dim 2 for scan_input_1