Fix batch size zero for QNNPACK linear_dynamic (#40588)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40588

Two bugs were preventing this from working.  One was a divide by zero
when multithreading was enabled, fixed similarly to the fix for static
quantized linear in the previous commit.  The other was the computation
of min and max to determine qparams.  FBGEMM uses [0,0] for [min,max] of
empty input, so do the same here.

Test Plan: Added a unit test.

Differential Revision: D22264415

Pulled By: dreiss

fbshipit-source-id: 6ca9cf48107dd998ef4834e5540279a8826bc754
This commit is contained in:
David Reiss 2020-06-29 16:26:57 -07:00 committed by Facebook GitHub Bot
parent 14145f9775
commit 21de450fcb
3 changed files with 22 additions and 2 deletions

View file

@ -241,8 +241,17 @@ at::Tensor PackedLinearWeightsQnnp::apply_dynamic_impl(at::Tensor input) {
// Calculate statistics for quantization of input Tensor
// TODO: optimized kernel
float x_min = input_contig.min().item<float>();
float x_max = input_contig.max().item<float>();
float x_min;
float x_max;
if (input.numel() > 0) {
x_min = input_contig.min().item<float>();
x_max = input_contig.max().item<float>();
} else {
// On empty input, no output data will be generated,
// so use arbitrary qparams.
x_min = 0;
x_max = 0;
}
auto q_params = quant_utils::ChooseQuantizationParams(
/*min=*/x_min,

View file

@ -100,6 +100,12 @@ enum pytorch_qnnp_status qnnpackLinearDynamic(
.ukernel = pytorch_qnnp_params.q8conv.gemm_dq,
};
if (output_size == 0) {
// pthreadpool can tolerate a range of 0, but not a tile of 0.
// We use output_size as a tile size, so bail here if it's 0.
return pytorch_qnnp_status_success;
}
pthreadpool_compute_4d_tiled(
threadpool,
(pthreadpool_function_4d_tiled_t)compute_q8gemm_dq,

View file

@ -2114,6 +2114,11 @@ class TestQuantizedOps(TestCase):
result = torch.ops.quantized.linear(qX, w_packed, 1.0, 0)
self.assertEqual(result.shape, (0, 2))
# dynamic linear
result = torch.ops.quantized.linear_dynamic(X, w_packed)
self.assertEqual(result.shape, (0, 2))
class TestDynamicQuantizedLinear(TestCase):
"""Tests the correctness of the dynamic quantized linear and linear_relu op."""