diff --git a/orttraining/orttraining/python/training/ortmodule/_gradient_accumulation_manager.py b/orttraining/orttraining/python/training/ortmodule/_gradient_accumulation_manager.py index 4358379010..14ba4bdbb0 100644 --- a/orttraining/orttraining/python/training/ortmodule/_gradient_accumulation_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_gradient_accumulation_manager.py @@ -55,7 +55,7 @@ class GradientAccumulationManager(object): forward_outputs (OrtValueVector): List of outputs returned by forward function """ if not self.enabled: - return tuple(_utils._ortvalue_to_torch_tensor(forward_output, device) for forward_output in forward_outputs) + return tuple(_utils._ortvalue_to_torch_tensor(forward_outputs[i], device) for i in range(len(forward_outputs))) if self._update_cache: for i in range(self._cache_start, len(forward_outputs)): self.cache.insert(