diff --git a/aten/src/ATen/native/Fill.cpp b/aten/src/ATen/native/Fill.cpp index d02dc781ed1..73f7dcd6192 100644 --- a/aten/src/ATen/native/Fill.cpp +++ b/aten/src/ATen/native/Fill.cpp @@ -19,6 +19,13 @@ namespace { } // namspace Tensor& fill_out(Tensor& self, Scalar value) { + if (self.is_quantized()) { + at::Tensor out = at::ones(self.sizes()).to(kFloat) * value; + out = out.to(self.device()); + // Trust the `copy_` to handle the quantization and the boundary chacks. + self.copy_(out); + return self; + } // When filling a number to 1-element CPU tensor, we want to skip // everything but manipulate data ptr directly. // Ideally this fast pass should be implemented in TensorIterator, diff --git a/test/quantization/test_quantized_tensor.py b/test/quantization/test_quantized_tensor.py index aa95373eca1..d0714bb72de 100644 --- a/test/quantization/test_quantized_tensor.py +++ b/test/quantization/test_quantized_tensor.py @@ -484,6 +484,35 @@ class TestQuantizedTensor(TestCase): # Check to make sure the scale and zero_point has been copied. self.assertEqual(q, q2) + def test_qtensor_fill(self): + numel = 10 + scale = 0.5 + zero_point = 10 + + ones = torch.ones(numel).to(torch.float) + + types = [torch.qint8, torch.quint8, torch.qint32] + fills = [-1, 1, 2**32] # positive, negative, overflow + + # `fill_` uses `copy_(float)`, which doesn't support CUDA + device = 'cpu' + ones = ones.to(device) + for qtype, fill_with in itertools.product(types, fills): + q_filled = torch._empty_affine_quantized( + [numel], scale=scale, zero_point=zero_point, device=device, + dtype=qtype) + q_filled.fill_(fill_with) + int_repr = torch.quantize_per_tensor(ones * fill_with, scale, + zero_point, qtype) + fill_with = int_repr.dequantize() + int_repr = int_repr.int_repr() + + self.assertEqual(q_filled.int_repr(), int_repr) + self.assertEqual(q_filled.dequantize(), fill_with) + # Make sure the scale and zero_point don't change + self.assertEqual(q_filled.q_scale(), scale) + self.assertEqual(q_filled.q_zero_point(), zero_point) + def test_qtensor_view(self): scale, zero_point, dtype = 1.0, 2, torch.uint8 for device in get_supported_device_types():