Keep all_finite tensor on CPU when using PyTorch Frontend (#5371)

This commit is contained in:
Suffian Khan 2020-10-08 15:47:18 -07:00 committed by GitHub
parent c2c78399ee
commit 498f94668d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -801,7 +801,12 @@ class ORTTrainer(object):
outputs_desc_resolved = self._resolve_symbolic_dimensions(inputs, inputs_desc, outputs_desc)
result = {}
for output_desc in outputs_desc_resolved:
torch_tensor = torch.zeros(output_desc.shape, device=self.options.device.id,
target_device = self.options.device.id
if self.options.mixed_precision.enabled and output_desc.name == self.model_desc.all_finite.name:
# Keep all finite flag on CPU to match backend implementation
# This prevents CPU -> GPU -> CPU copies between frontend and backend
target_device = 'cpu'
torch_tensor = torch.zeros(output_desc.shape, device=target_device,
dtype=output_desc.dtype_amp if output_desc.dtype_amp else output_desc.dtype)
iobinding.bind_output(output_desc.name, torch_tensor.device.type, _utils.get_device_index(self.options.device.id),
_utils.dtype_torch_to_numpy(torch_tensor.dtype),