mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Keep all_finite tensor on CPU when using PyTorch Frontend (#5371)
This commit is contained in:
parent
c2c78399ee
commit
498f94668d
1 changed file with 6 additions and 1 deletion
|
|
@@ -801,7 +801,12 @@ class ORTTrainer(object):
|
|||
outputs_desc_resolved = self._resolve_symbolic_dimensions(inputs, inputs_desc, outputs_desc)
|
||||
result = {}
|
||||
for output_desc in outputs_desc_resolved:
|
||||
torch_tensor = torch.zeros(output_desc.shape, device=self.options.device.id,
|
||||
target_device = self.options.device.id
|
||||
if self.options.mixed_precision.enabled and output_desc.name == self.model_desc.all_finite.name:
|
||||
# Keep all finite flag on CPU to match backend implementation
|
||||
# This prevents CPU -> GPU -> CPU copies between frontend and backend
|
||||
target_device = 'cpu'
|
||||
torch_tensor = torch.zeros(output_desc.shape, device=target_device,
|
||||
dtype=output_desc.dtype_amp if output_desc.dtype_amp else output_desc.dtype)
|
||||
iobinding.bind_output(output_desc.name, torch_tensor.device.type, _utils.get_device_index(self.options.device.id),
|
||||
_utils.dtype_torch_to_numpy(torch_tensor.dtype),
|
||||
|
|
|
|||
Loading…
Reference in a new issue