diff --git a/onnxruntime/core/providers/cpu/controlflow/scan.cc b/onnxruntime/core/providers/cpu/controlflow/scan.cc index 0504fe5a2f..2952d5f92e 100644 --- a/onnxruntime/core/providers/cpu/controlflow/scan.cc +++ b/onnxruntime/core/providers/cpu/controlflow/scan.cc @@ -119,8 +119,10 @@ class LoopStateVariable { }; /* -Class that co-ordinates writes to slices of the overall Scan output. -It will directly update the data returned by OpKernelContextInternal.Output(i). +Class that co-ordinates writing to slices of the overall Scan output buffer returned by OpKernelContext.Output(i). +If the subgraph has a symbolic dimension in an output, it will use a temporary MLValue for the first execution +in order to discover the output shape. Once the shape is known, it will switch to using the overall output buffer +to avoid copies. */ class OutputIterator { public: @@ -136,6 +138,7 @@ class OutputIterator { MLValue& operator*(); OutputIterator& operator++(); + // set the output for the current iteration to zeros. Used for short sequence lengths void ZeroOutCurrent() { auto* tensor = (**this).GetMutable(); memset(tensor->MutableDataRaw(), 0, tensor->Size()); @@ -160,7 +163,10 @@ class OutputIterator { int64_t num_iterations_; int64_t cur_iteration_; + // is the final shape concrete, or does it have symbolic dimensions bool is_concrete_shape_; + + // one or more slicers for writing to the output std::vector::Iterator> slicer_iterators_; std::vector::Iterator>::iterator cur_slicer_iterator_; @@ -302,6 +308,7 @@ void LoopStateVariable::Next() { ++iteration_num_; } +// fill in a symbolic dimension in the overall output using the output shape from an iteration of the subgraph static Status MakeShapeConcrete(const TensorShape& per_iteration_shape, TensorShape& final_shape) { auto num_dims_per_iteration = per_iteration_shape.NumDimensions(); auto final_shape_offset = final_shape.NumDimensions() - num_dims_per_iteration;