mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Fix batch size zero for QNNPACK linear_dynamic (#40588)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/40588 Two bugs were preventing this from working. One was a divide by zero when multithreading was enabled, fixed similarly to the fix for static quantized linear in the previous commit. The other was computation of min and max to determine qparams. FBGEMM uses [0,0] for [min,max] of empty input, do the same. Test Plan: Added a unit test. Differential Revision: D22264415 Pulled By: dreiss fbshipit-source-id: 6ca9cf48107dd998ef4834e5540279a8826bc754
This commit is contained in:
parent
14145f9775
commit
21de450fcb
3 changed files with 22 additions and 2 deletions
|
|
@@ -241,8 +241,17 @@ at::Tensor PackedLinearWeightsQnnp::apply_dynamic_impl(at::Tensor input) {
|
|||
|
||||
// Calculate statistics for quantization of input Tensor
|
||||
// TODO: optimized kernel
|
||||
float x_min = input_contig.min().item<float>();
|
||||
float x_max = input_contig.max().item<float>();
|
||||
float x_min;
|
||||
float x_max;
|
||||
if (input.numel() > 0) {
|
||||
x_min = input_contig.min().item<float>();
|
||||
x_max = input_contig.max().item<float>();
|
||||
} else {
|
||||
// On empty input, no output data will be generated,
|
||||
// so use arbitrary qparams.
|
||||
x_min = 0;
|
||||
x_max = 0;
|
||||
}
|
||||
|
||||
auto q_params = quant_utils::ChooseQuantizationParams(
|
||||
/*min=*/x_min,
|
||||
|
|
|
|||
|
|
@@ -100,6 +100,12 @@ enum pytorch_qnnp_status qnnpackLinearDynamic(
|
|||
.ukernel = pytorch_qnnp_params.q8conv.gemm_dq,
|
||||
};
|
||||
|
||||
if (output_size == 0) {
|
||||
// pthreadpool can tolerate a range of 0, but not a tile of 0.
|
||||
// We use output_size as a tile size, so bail here if it's 0.
|
||||
return pytorch_qnnp_status_success;
|
||||
}
|
||||
|
||||
pthreadpool_compute_4d_tiled(
|
||||
threadpool,
|
||||
(pthreadpool_function_4d_tiled_t)compute_q8gemm_dq,
|
||||
|
|
|
|||
|
|
@@ -2114,6 +2114,11 @@ class TestQuantizedOps(TestCase):
|
|||
result = torch.ops.quantized.linear(qX, w_packed, 1.0, 0)
|
||||
self.assertEqual(result.shape, (0, 2))
|
||||
|
||||
# dynamic linear
|
||||
result = torch.ops.quantized.linear_dynamic(X, w_packed)
|
||||
self.assertEqual(result.shape, (0, 2))
|
||||
|
||||
|
||||
|
||||
class TestDynamicQuantizedLinear(TestCase):
|
||||
"""Tests the correctness of the dynamic quantized linear and linear_relu op."""
|
||||
|
|
|
|||
Loading…
Reference in a new issue