mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
quantized tensor: add preliminary support for advanced indexing, try 2 (#49346)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49346 This is a less ambitious redo of https://github.com/pytorch/pytorch/pull/49129/. We make the ``` xq_slice = xq[:, [0], :, :] ``` indexing syntax work if `xq` is a quantized Tensor. For now, we are making the code not crash, with an inefficient `dq -> index -> q` implementation. A future PR can optimize performance by removing the unnecessary memory copies (which will require some non-trivial changes to TensorIterator). Test Plan: ``` python test/test_quantization.py TestQuantizedOps.test_advanced_indexing ``` Imported from OSS Reviewed By: jerryzh168 Differential Revision: D25539365 fbshipit-source-id: 98485875aaaf5743e1a940e170258057691be4fa
This commit is contained in:
parent
8954eb3f72
commit
a9137aeb06
4 changed files with 63 additions and 0 deletions
|
|
@ -290,6 +290,27 @@ Tensor index(const Tensor & self, TensorList indices) {
|
|||
return iter.output();
|
||||
}
|
||||
|
||||
// Advanced indexing (e.g. x[:, [0], :, :]) for per-Tensor quantized Tensors.
//
// For now this is a naive dq -> index -> q implementation: dequantize,
// run the regular (fp32) index kernel, then re-quantize the result with
// the input's scale / zero_point.
// TODO(future PR): improve performance by removing the copies (which
// will require some non-trivial changes to TensorIterator).
Tensor quantized_index(const Tensor & self, TensorList indices) {
  TORCH_INTERNAL_ASSERT(
      self.qscheme() == c10::kPerTensorAffine ||
      self.qscheme() == c10::kPerTensorSymmetric,
      "Indexing is only supported for per-Tensor quantized Tensors.");

  // Validate the index count up front, before paying for the dequantize
  // copy; the original order dequantized first and then rejected the call.
  TORCH_CHECK_INDEX(
      indices.size() <= (size_t)self.dim(),
      "too many indices for tensor of dimension ", self.dim(),
      " (got ", indices.size(), ")");

  // dequantize() returns by value; the const ref extends its lifetime.
  const auto& self_dq = self.dequantize();

  // Reuse the regular advanced-indexing machinery on the fp32 tensor.
  auto info = make_info(self_dq, indices);
  auto iter = make_index_iterator(info);
  index_stub(iter.device_type(), iter, info.indexed_sizes, info.indexed_strides);
  at::Tensor res = iter.output();

  // Re-quantize with the input's qparams so the result matches self's
  // quantization scheme and dtype.
  return at::quantize_per_tensor(
      res, self.q_scale(), self.q_zero_point(), self.scalar_type());
}
|
||||
|
||||
Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices) {
|
||||
TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
|
||||
at::assert_no_internal_overlap(result);
|
||||
|
|
|
|||
|
|
@ -2200,6 +2200,7 @@
|
|||
variants: function, method
|
||||
dispatch:
|
||||
CPU, CUDA: index
|
||||
QuantizedCPU: quantized_index
|
||||
# NB: This function is special-cased in tools/autograd/gen_variable_type.py
|
||||
# NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp:
|
||||
# - Tensor Tensor::index(ArrayRef<TensorIndex> indices)
|
||||
|
|
|
|||
|
|
@ -691,6 +691,8 @@ inline DeviceType computeDeviceType(DispatchKey tid) {
|
|||
return DeviceType::Vulkan;
|
||||
} else if (tid == DispatchKey::Metal) {
|
||||
return DeviceType::Metal;
|
||||
} else if (tid == DispatchKey::QuantizedCPU) {
|
||||
return DeviceType::CPU;
|
||||
} else {
|
||||
AT_ASSERTM(false, "Unknown DispatchKey: ", tid);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2274,6 +2274,45 @@ class TestQuantizedOps(TestCase):
|
|||
result = torch.ops.quantized.linear_dynamic(X, w_packed)
|
||||
self.assertEqual(result.shape, (0, 2))
|
||||
|
||||
def test_advanced_indexing(self):
    """
    Verifies that the x[:, [0], :, :] syntax works for quantized tensors.
    """
    # The quantized kernel is implemented as dq -> index -> q, so indexing
    # the dequantized tensor and re-quantizing with the same qparams must
    # give a result identical to indexing the quantized tensor directly.
    # slice(None) is the programmatic spelling of ':' in x[...] syntax.
    index_patterns = (
        # single dim, single index
        (slice(None), [0], slice(None), slice(None)),
        # multiple dim, single index
        (slice(None), [0], [2], slice(None)),
        # single dim, multiple indices
        (slice(None), [2, 0, 1], slice(None), slice(None)),
        # multiple dim, multiple indices
        (slice(None), [2, 0, 1], slice(None), [1]),
    )
    for dtype in (torch.qint8, torch.quint8, torch.qint32):
        scale = 0.1
        zp = 0
        x_q = torch.quantize_per_tensor(
            torch.randn(1, 4, 4, 4), scale, zp, dtype)
        # reference path goes through fp32
        x_fp32 = x_q.dequantize()
        for idx in index_patterns:
            result = x_q[idx]
            expected = torch.quantize_per_tensor(
                x_fp32[idx], scale, zp, dtype)
            self.assertEqual(result, expected)
|
||||
|
||||
|
||||
class TestDynamicQuantizedLinear(TestCase):
|
||||
|
|
|
|||
Loading…
Reference in a new issue