From c579ebfbc3f07aa5636093ef033ff4187074a7d1 Mon Sep 17 00:00:00 2001 From: mindest <30493312+mindest@users.noreply.github.com> Date: Tue, 9 Nov 2021 08:33:50 +0800 Subject: [PATCH] change a for iteration (#9678) Co-authored-by: Min Lin --- .../python/training/ortmodule/_gradient_accumulation_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_gradient_accumulation_manager.py b/orttraining/orttraining/python/training/ortmodule/_gradient_accumulation_manager.py index 4358379010..14ba4bdbb0 100644 --- a/orttraining/orttraining/python/training/ortmodule/_gradient_accumulation_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_gradient_accumulation_manager.py @@ -55,7 +55,7 @@ class GradientAccumulationManager(object): forward_outputs (OrtValueVector): List of outputs returned by forward function """ if not self.enabled: - return tuple(_utils._ortvalue_to_torch_tensor(forward_output, device) for forward_output in forward_outputs) + return tuple(_utils._ortvalue_to_torch_tensor(forward_outputs[i], device) for i in range(len(forward_outputs))) if self._update_cache: for i in range(self._cache_start, len(forward_outputs)): self.cache.insert(