From 15828bcfd7c8f013cd9dff3ef7b71655925e16fa Mon Sep 17 00:00:00 2001 From: macandro96 Date: Wed, 15 Jun 2022 21:30:00 +0000 Subject: [PATCH] [ao][sparsity] Base class for Data Sparsifier Base Data Sparsifier class for all Data sparsifiers. The abstract class accepts raw torch tensors / embedding / embedding bags (refer to SUPPORTED_TYPES above) to prepare for sparsification. In this case, mask (and parametrizations) is owned by the class and not by the user. Specifically, the container object inside the class maintains the mask and parametrizations of the input data Test Plan: ```python test/test_ao_sparsity.py TestBaseDataSparsifier``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/79251 Approved by: https://github.com/z-a-f, https://github.com/HDCharles --- docs/source/quantization.rst | 1 + test/ao/sparsity/test_data_sparsifier.py | 13 ++++ test/test_ao_sparsity.py | 3 + torch/ao/sparsity/__init__.py | 3 + .../experimental/data_sparsifier/__init__.py | 0 .../data_sparsifier/base_data_sparsifier.py | 77 +++++++++++++++++++ 6 files changed, 97 insertions(+) create mode 100644 test/ao/sparsity/test_data_sparsifier.py create mode 100644 torch/ao/sparsity/experimental/data_sparsifier/__init__.py create mode 100644 torch/ao/sparsity/experimental/data_sparsifier/base_data_sparsifier.py diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst index 5bdd6948761..3be4d5a390e 100644 --- a/docs/source/quantization.rst +++ b/docs/source/quantization.rst @@ -1046,5 +1046,6 @@ Please take a look at `Limitations of Symbolic Tracing {**config} + if data_list is not None: + # add data with default config here + [self.add_data(name, data, **self.defaults) for name, data in data_list] + + def prepare(self): + raise NotImplementedError("this function is undefined for this class") + + def add_data(self, name: str, data, **config): + r""" Configures and parametrizes the internal container model with name and data + """ + pass + + def get_data(self, name: str): + r"""Returns weight tensor (or data) based on the input name. + """ + pass + + def __repr__(self): + r"""String representation of an object when printed + """ + pass + + def get_mask(self, name: str): + r"""Returns the mask currently associated with the named tensor. + """ + pass + + def squash_mask(self, *args, **kwargs): + r"""Squashes the sparse masks into the appropriate tensors. + """ + pass + + def step(self): + r"""Updates the mask for all the named data. + """ + pass + + @abc.abstractmethod + def update_mask(self, name, data, **kwargs): + pass