mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
torch/ao/quantization/utils.py: Moving eps to targeted device to avoid device mismatch issue (#135204)
MOTIVATION We recently verified some quantization tests on devices other than cpu (eg. CUDA and Intel Gaudi devices identified as 'hpu'). We noticed a device mismatch error as eps is a tensor created on cpu but other tensors (min_val_neg, max_val_pos, scale, zero_point) are moved to the targeted _device_. CHANGES Move eps to _device_ of other tensors. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135204 Approved by: https://github.com/jgong5, https://github.com/jerryzh168
This commit is contained in:
parent
cef6c3dcb0
commit
3361908fc5
1 changed files with 1 additions and 0 deletions
|
|
@ -649,6 +649,7 @@ def determine_qparams(
|
|||
device = min_val_neg.device
|
||||
scale = torch.ones(min_val_neg.size(), dtype=torch.double, device=device)
|
||||
zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
|
||||
eps = eps.to(device)
|
||||
|
||||
if qscheme == torch.per_tensor_symmetric or qscheme == torch.per_channel_symmetric:
|
||||
max_val_pos = torch.max(-min_val_neg, max_val_pos)
|
||||
|
|
|
|||
Loading…
Reference in a new issue