mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[quant] fill_ path for quantized tensors (#43303)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43303 Test Plan: Imported from OSS Reviewed By: raghuramank100 Differential Revision: D23231947 Pulled By: z-a-f fbshipit-source-id: fd5110ff15a073f326ef590436f8c6e5a2608324
This commit is contained in:
parent
4aacfab221
commit
1d01fcdc24
2 changed files with 36 additions and 0 deletions
|
|
@ -19,6 +19,13 @@ namespace {
|
|||
} // namespace
|
||||
|
||||
Tensor& fill_out(Tensor& self, Scalar value) {
|
||||
if (self.is_quantized()) {
|
||||
at::Tensor out = at::ones(self.sizes()).to(kFloat) * value;
|
||||
out = out.to(self.device());
|
||||
// Trust the `copy_` to handle the quantization and the boundary checks.
|
||||
self.copy_(out);
|
||||
return self;
|
||||
}
|
||||
// When filling a number to 1-element CPU tensor, we want to skip
|
||||
// everything but manipulate data ptr directly.
|
||||
// Ideally this fast pass should be implemented in TensorIterator,
|
||||
|
|
|
|||
|
|
@ -484,6 +484,35 @@ class TestQuantizedTensor(TestCase):
|
|||
# Check to make sure the scale and zero_point has been copied.
|
||||
self.assertEqual(q, q2)
|
||||
|
||||
def test_qtensor_fill(self):
|
||||
numel = 10
|
||||
scale = 0.5
|
||||
zero_point = 10
|
||||
|
||||
ones = torch.ones(numel).to(torch.float)
|
||||
|
||||
types = [torch.qint8, torch.quint8, torch.qint32]
|
||||
fills = [-1, 1, 2**32] # positive, negative, overflow
|
||||
|
||||
# `fill_` uses `copy_(float)`, which doesn't support CUDA
|
||||
device = 'cpu'
|
||||
ones = ones.to(device)
|
||||
for qtype, fill_with in itertools.product(types, fills):
|
||||
q_filled = torch._empty_affine_quantized(
|
||||
[numel], scale=scale, zero_point=zero_point, device=device,
|
||||
dtype=qtype)
|
||||
q_filled.fill_(fill_with)
|
||||
int_repr = torch.quantize_per_tensor(ones * fill_with, scale,
|
||||
zero_point, qtype)
|
||||
fill_with = int_repr.dequantize()
|
||||
int_repr = int_repr.int_repr()
|
||||
|
||||
self.assertEqual(q_filled.int_repr(), int_repr)
|
||||
self.assertEqual(q_filled.dequantize(), fill_with)
|
||||
# Make sure the scale and zero_point don't change
|
||||
self.assertEqual(q_filled.q_scale(), scale)
|
||||
self.assertEqual(q_filled.q_zero_point(), zero_point)
|
||||
|
||||
def test_qtensor_view(self):
|
||||
scale, zero_point, dtype = 1.0, 2, torch.uint8
|
||||
for device in get_supported_device_types():
|
||||
|
|
|
|||
Loading…
Reference in a new issue