From 9d44b3d110ab3b4e6f6ad376858cdfd2ba4faa1b Mon Sep 17 00:00:00 2001 From: Zafar Date: Wed, 18 May 2022 15:43:04 -0700 Subject: [PATCH] [quant][refactor] Remove the base class from __all__ In general, if we are expecting the users to use the base class, such as `_ConvNd`, we should rename it to something like `BaseConv`. However, because this base class is only used inside of the AO packages, there is no need to expose it to the users. Test Plan: ``` python test/test_quantization.py python test/test_module_init.py ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/77344 Approved by: https://github.com/jerryzh168 --- test/test_module_init.py | 1 - torch/nn/quantized/modules/__init__.py | 3 +-- torch/testing/_internal/common_modules.py | 1 - 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/test/test_module_init.py b/test/test_module_init.py index b568f210e55..8dcbc694dd9 100644 --- a/test/test_module_init.py +++ b/test/test_module_init.py @@ -365,7 +365,6 @@ def generate_tests(test_cls, constructor_arg_db): torch.nn.Module, torch.nn.Container, # deprecated torch.nn.NLLLoss2d, # deprecated - torch.nn.quantized._ConvNd, # base class in __all__ for some reason # TODO: Remove these 2 from this list once the ASan issue is fixed. # See https://github.com/pytorch/pytorch/issues/55396 torch.nn.quantized.Embedding, diff --git a/torch/nn/quantized/modules/__init__.py b/torch/nn/quantized/modules/__init__.py index 4a899ef2607..8004b52cc65 100644 --- a/torch/nn/quantized/modules/__init__.py +++ b/torch/nn/quantized/modules/__init__.py @@ -6,7 +6,7 @@ from .dropout import Dropout from .batchnorm import BatchNorm2d, BatchNorm3d from .normalization import LayerNorm, GroupNorm, InstanceNorm1d, \ InstanceNorm2d, InstanceNorm3d -from .conv import _ConvNd, Conv1d, Conv2d, Conv3d +from .conv import Conv1d, Conv2d, Conv3d from .conv import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d from .linear import Linear from .embedding_ops import Embedding, EmbeddingBag @@ -91,7 +91,6 @@ class DeQuantize(torch.nn.Module): __all__ = [ 'BatchNorm2d', 'BatchNorm3d', - '_ConvNd', 'Conv1d', 'Conv2d', 'Conv3d', diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index 9917988817a..63144060efc 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -31,7 +31,6 @@ MODULES_TO_SKIP: Set[Type] = { torch.nn.Module, # abstract base class torch.nn.Container, # deprecated torch.nn.NLLLoss2d, # deprecated - torch.nn.quantized.modules._ConvNd, # abstract base class torch.nn.quantized.MaxPool2d, # aliases to nn.MaxPool2d }