From 498f94668da2443df191e11bc39442208a778016 Mon Sep 17 00:00:00 2001
From: Suffian Khan
Date: Thu, 8 Oct 2020 15:47:18 -0700
Subject: [PATCH] Keep all_finite tensor on CPU when using PyTorch Frontend
 (#5371)

---
 orttraining/orttraining/python/training/orttrainer.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py
index bbb24e1c9d..1361a73a90 100644
--- a/orttraining/orttraining/python/training/orttrainer.py
+++ b/orttraining/orttraining/python/training/orttrainer.py
@@ -801,7 +801,12 @@ class ORTTrainer(object):
         outputs_desc_resolved = self._resolve_symbolic_dimensions(inputs, inputs_desc, outputs_desc)
         result = {}
         for output_desc in outputs_desc_resolved:
-            torch_tensor = torch.zeros(output_desc.shape, device=self.options.device.id,
+            target_device = self.options.device.id
+            if self.options.mixed_precision.enabled and output_desc.name == self.model_desc.all_finite.name:
+                # Keep all finite flag on CPU to match backend implementation
+                # This prevents CPU -> GPU -> CPU copies between frontend and backend
+                target_device = 'cpu'
+            torch_tensor = torch.zeros(output_desc.shape, device=target_device,
                 dtype=output_desc.dtype_amp if output_desc.dtype_amp else output_desc.dtype)
             iobinding.bind_output(output_desc.name, torch_tensor.device.type, _utils.get_device_index(self.options.device.id),
                                   _utils.dtype_torch_to_numpy(torch_tensor.dtype),