mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Quantized Tensor support copy (#28612)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/28612 att Test Plan: python test/test_quantized_tensor.py Imported from OSS Differential Revision: D18255247 fbshipit-source-id: 814b12640fdf9d79b27482ee642ce430dbaeea68
This commit is contained in:
parent
41e42c34d6
commit
23193c155f
9 changed files with 64 additions and 4 deletions
|
|
@ -7,6 +7,7 @@
|
|||
#define Real QInt32
|
||||
#define RealUnderlying Int
|
||||
#define THQUANTIZED
|
||||
#define THQINT32
|
||||
#define TH_REAL_IS_BYTE
|
||||
#line 1 TH_GENERIC_FILE
|
||||
#include TH_GENERIC_FILE
|
||||
|
|
@ -15,6 +16,7 @@
|
|||
#undef Real
|
||||
#undef RealUnderlying
|
||||
#undef TH_REAL_IS_BYTE
|
||||
#undef THQINT32
|
||||
#undef THQUANTIZED
|
||||
|
||||
#ifndef THGenerateManyTypes
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
#define Real QInt8
|
||||
#define RealUnderlying Char
|
||||
#define THQUANTIZED
|
||||
#define THQINT8
|
||||
#define TH_REAL_IS_BYTE
|
||||
#line 1 TH_GENERIC_FILE
|
||||
#include TH_GENERIC_FILE
|
||||
|
|
@ -15,6 +16,7 @@
|
|||
#undef Real
|
||||
#undef RealUnderlying
|
||||
#undef TH_REAL_IS_BYTE
|
||||
#undef THQINT8
|
||||
#undef THQUANTIZED
|
||||
|
||||
#ifndef THGenerateManyTypes
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
#define Real QUInt8
|
||||
#define RealUnderlying Byte
|
||||
#define THQUANTIZED
|
||||
#define THQUINT8
|
||||
#define TH_REAL_IS_BYTE
|
||||
#line 1 TH_GENERIC_FILE
|
||||
#include TH_GENERIC_FILE
|
||||
|
|
@ -15,6 +16,7 @@
|
|||
#undef Real
|
||||
#undef RealUnderlying
|
||||
#undef TH_REAL_IS_BYTE
|
||||
#undef THQUINT8
|
||||
#undef THQUANTIZED
|
||||
|
||||
#ifndef THGenerateManyTypes
|
||||
|
|
|
|||
|
|
@ -35,6 +35,9 @@
|
|||
#define THLongStorage THStorage
|
||||
#define THBoolStorage THStorage
|
||||
#define THBFloat16Storage THStorage
|
||||
#define THQUInt8Storage THStorage
|
||||
#define THQInt8Storage THStorage
|
||||
#define THQInt32Storage THStorage
|
||||
|
||||
TH_API scalar_t* THStorage_(data)(const THStorage*);
|
||||
TH_API ptrdiff_t THStorage_(size)(const THStorage*);
|
||||
|
|
|
|||
|
|
@ -35,5 +35,14 @@ IMPLEMENT_THStorage_COPY(Double)
|
|||
IMPLEMENT_THStorage_COPY(Half)
|
||||
IMPLEMENT_THStorage_COPY(Bool)
|
||||
IMPLEMENT_THStorage_COPY(BFloat16)
|
||||
#ifdef THQUINT8
|
||||
IMPLEMENT_THStorage_COPY(QUInt8)
|
||||
#endif
|
||||
#ifdef THQINT8
|
||||
IMPLEMENT_THStorage_COPY(QInt8)
|
||||
#endif
|
||||
#ifdef THQINT32
|
||||
IMPLEMENT_THStorage_COPY(QInt32)
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -14,5 +14,14 @@ TH_API void THStorage_(copyDouble)(THStorage *storage, struct THDoubleStorage *s
|
|||
TH_API void THStorage_(copyHalf)(THStorage *storage, struct THHalfStorage *src);
|
||||
TH_API void THStorage_(copyBool)(THStorage *storage, struct THBoolStorage *src);
|
||||
TH_API void THStorage_(copyBFloat16)(THStorage *storage, struct THBFloat16Storage *src);
|
||||
#ifdef THQUINT8
|
||||
TH_API void THStorage_(copyQUInt8)(THStorage *storage, struct THQUInt8Storage *src);
|
||||
#endif
|
||||
#ifdef THQINT8
|
||||
TH_API void THStorage_(copyQInt8)(THStorage *storage, struct THQInt8Storage *src);
|
||||
#endif
|
||||
#ifdef THQINT32
|
||||
TH_API void THStorage_(copyQInt32)(THStorage *storage, struct THQInt32Storage *src);
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import numpy as np
|
|||
|
||||
import torch
|
||||
import io
|
||||
from copy import deepcopy
|
||||
|
||||
from common_utils import TestCase, run_tests
|
||||
import tempfile
|
||||
|
|
@ -252,6 +253,13 @@ class TestQuantizedTensor(TestCase):
|
|||
q.copy_(q2)
|
||||
# check scale and zero_points has been copied
|
||||
self.assertEqual(q, q2)
|
||||
# deep copy
|
||||
scale, zero_point, dtype = 1.0, 2, torch.uint8
|
||||
q_int = torch.randint(0, 100, [3, 5], dtype=dtype)
|
||||
scale, zero_point = 2.0, 3
|
||||
q = torch._make_per_tensor_quantized_tensor(q_int, scale=scale, zero_point=zero_point)
|
||||
qc = deepcopy(q)
|
||||
self.assertEqual(qc, q)
|
||||
|
||||
def test_qtensor_clone(self):
|
||||
numel = 10
|
||||
|
|
@ -322,7 +330,6 @@ class TestQuantizedTensor(TestCase):
|
|||
c = b.reshape(1, 4, 2, 3)
|
||||
|
||||
def test_qscheme_pickle(self):
|
||||
|
||||
f = Foo()
|
||||
buf = io.BytesIO()
|
||||
torch.save(f, buf)
|
||||
|
|
@ -332,5 +339,6 @@ class TestQuantizedTensor(TestCase):
|
|||
|
||||
self.assertEqual(f2.qscheme, torch.per_tensor_symmetric)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -318,6 +318,15 @@ void THPStorage_(initCopyMethods)()
|
|||
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPDoubleStorageType, h, &THWStorage_(copyDouble));
|
||||
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPBoolStorageType, h, &THWStorage_(copyBool));
|
||||
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPBFloat16StorageType, h, &THWStorage_(copyBFloat16));
|
||||
#ifdef THQUINT8
|
||||
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPQUInt8StorageType, h, &THWStorage_(copyQUInt8));
|
||||
#endif
|
||||
#ifdef THQINT8
|
||||
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPQInt8StorageType, h, &THWStorage_(copyQInt8));
|
||||
#endif
|
||||
#ifdef THQINT32
|
||||
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPQInt32StorageType, h, &THWStorage_(copyQInt32));
|
||||
#endif
|
||||
#ifdef THC_GENERIC_FILE
|
||||
// copy from GPU types
|
||||
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPByteStorageType, h, &THWStorage_(copyCudaByte));
|
||||
|
|
|
|||
|
|
@ -50,10 +50,26 @@ class Tensor(torch._C._TensorBase):
|
|||
new_tensor = self.clone()
|
||||
else:
|
||||
new_storage = self.storage().__deepcopy__(memo)
|
||||
new_tensor = self.new()
|
||||
new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride())
|
||||
if self.is_quantized:
|
||||
if self.qscheme() == torch.per_tensor_affine:
|
||||
quantizer_params = self.qscheme(), self.q_scale(), self.q_zero_point()
|
||||
elif self.qscheme() == torch.per_channel_affine:
|
||||
quantizer_params = self.qscheme(), self.q_per_channel_scales(), self.q_per_channel_zero_points(), self.q_per_channel_axis()
|
||||
else:
|
||||
raise RuntimeError("Unsupported qscheme {} in deepcopy".format(self.qscheme()))
|
||||
new_tensor = torch._utils._rebuild_qtensor(
|
||||
new_storage,
|
||||
self.storage_offset(),
|
||||
self.size(),
|
||||
self.stride(),
|
||||
quantizer_params,
|
||||
self.requires_grad,
|
||||
self._backward_hooks)
|
||||
else:
|
||||
new_tensor = self.new()
|
||||
new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride())
|
||||
new_tensor.requires_grad = self.requires_grad
|
||||
memo[id(self)] = new_tensor
|
||||
new_tensor.requires_grad = self.requires_grad
|
||||
return new_tensor
|
||||
|
||||
def __reduce_ex__(self, proto):
|
||||
|
|
|
|||
Loading…
Reference in a new issue