[quant] fill_ path for quantized tensors (#43303)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43303

Test Plan: Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D23231947

Pulled By: z-a-f

fbshipit-source-id: fd5110ff15a073f326ef590436f8c6e5a2608324
This commit is contained in:
Zafar 2020-09-08 21:32:33 -07:00 committed by Facebook GitHub Bot
parent 4aacfab221
commit 1d01fcdc24
2 changed files with 36 additions and 0 deletions

View file

@ -19,6 +19,13 @@ namespace {
} // namspace
Tensor& fill_out(Tensor& self, Scalar value) {
if (self.is_quantized()) {
at::Tensor out = at::ones(self.sizes()).to(kFloat) * value;
out = out.to(self.device());
// Trust the `copy_` to handle the quantization and the boundary chacks.
self.copy_(out);
return self;
}
// When filling a number to 1-element CPU tensor, we want to skip
// everything but manipulate data ptr directly.
// Ideally this fast pass should be implemented in TensorIterator,

View file

@ -484,6 +484,35 @@ class TestQuantizedTensor(TestCase):
# Check to make sure the scale and zero_point has been copied.
self.assertEqual(q, q2)
def test_qtensor_fill(self):
numel = 10
scale = 0.5
zero_point = 10
ones = torch.ones(numel).to(torch.float)
types = [torch.qint8, torch.quint8, torch.qint32]
fills = [-1, 1, 2**32] # positive, negative, overflow
# `fill_` uses `copy_(float)`, which doesn't support CUDA
device = 'cpu'
ones = ones.to(device)
for qtype, fill_with in itertools.product(types, fills):
q_filled = torch._empty_affine_quantized(
[numel], scale=scale, zero_point=zero_point, device=device,
dtype=qtype)
q_filled.fill_(fill_with)
int_repr = torch.quantize_per_tensor(ones * fill_with, scale,
zero_point, qtype)
fill_with = int_repr.dequantize()
int_repr = int_repr.int_repr()
self.assertEqual(q_filled.int_repr(), int_repr)
self.assertEqual(q_filled.dequantize(), fill_with)
# Make sure the scale and zero_point don't change
self.assertEqual(q_filled.q_scale(), scale)
self.assertEqual(q_filled.q_zero_point(), zero_point)
def test_qtensor_view(self):
scale, zero_point, dtype = 1.0, 2, torch.uint8
for device in get_supported_device_types():