From 6e3fdd77ca7ae4e0c2ac32a91c230e556b6db0d3 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Tue, 23 Jun 2020 08:56:45 -0700 Subject: [PATCH] quant docs: add and clean up GroupNorm (#40343) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/40343 Cleans up the quantized GroupNorm docstring and adds it to quantization docs. Test Plan: * build on Mac OS and inspect Differential Revision: D22152635 Pulled By: vkuzo fbshipit-source-id: 5553b841c7a5d77f1467f0c40657db9e5d730a12 --- docs/source/quantization.rst | 6 ++++++ torch/nn/quantized/modules/normalization.py | 7 ++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst index 21fbdbabb56..078678253e6 100644 --- a/docs/source/quantization.rst +++ b/docs/source/quantization.rst @@ -355,6 +355,7 @@ Quantized version of standard NN layers. quantized representation of 6 * :class:`~torch.nn.quantized.Hardswish` — Hardswish * :class:`~torch.nn.quantized.LayerNorm` — LayerNorm. *Note: performance on ARM is not optimized*. +* :class:`~torch.nn.quantized.GroupNorm` — GroupNorm. *Note: performance on ARM is not optimized*. ``torch.nn.quantized.dynamic`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -793,6 +794,11 @@ LayerNorm .. autoclass:: LayerNorm :members: +GroupNorm +~~~~~~~~~~~~~~~ +.. autoclass:: GroupNorm + :members: + torch.nn.quantized.dynamic ---------------------------- diff --git a/torch/nn/quantized/modules/normalization.py b/torch/nn/quantized/modules/normalization.py index dab687ce84e..43619942830 100644 --- a/torch/nn/quantized/modules/normalization.py +++ b/torch/nn/quantized/modules/normalization.py @@ -42,7 +42,12 @@ class LayerNorm(torch.nn.LayerNorm): return new_mod class GroupNorm(torch.nn.GroupNorm): - r"""This is the quantized version of `torch.nn.GroupNorm`. + r"""This is the quantized version of :class:`~torch.nn.GroupNorm`. + + Additional args: + * **scale** - quantization scale of the output, type: double. + * **zero_point** - quantization zero point of the output, type: long. + """ __constants__ = ['num_groups', 'num_channels', 'eps', 'affine']