[ao][sparsity] Base class for Data Sparsifier

Base Data Sparsifier class for all Data sparsifiers.
The abstract class accepts raw torch tensors / embedding / embedding bags (refer to SUPPORTED_TYPES above)
to prepare for sparsification.
In this case, mask (and parametrizations) is owned by the class and not by the user.
Specifically, the container object inside the class maintains the mask and parametrizations of the input data

Test Plan:
```python test/test_ao_sparsity.py TestBaseDataSparsifier```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79251

Approved by: https://github.com/z-a-f, https://github.com/HDCharles
This commit is contained in:
macandro96 2022-06-15 21:30:00 +00:00 committed by PyTorch MergeBot
parent 2539efb842
commit 15828bcfd7
6 changed files with 97 additions and 0 deletions

View file

@ -1046,5 +1046,6 @@ Please take a look at `Limitations of Symbolic Tracing <https://docs-preview.pyt
.. py:module:: torch.ao.sparsity
.. py:module:: torch.ao.sparsity.experimental
.. py:module:: torch.ao.sparsity.experimental.pruner
.. py:module:: torch.ao.sparsity.experimental.data_sparsifier
.. py:module:: torch.ao.sparsity.scheduler
.. py:module:: torch.ao.sparsity.sparsifier

View file

@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
# Owner(s): ["module: unknown"]
import logging
from torch.testing._internal.common_utils import TestCase
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
class TestBaseDataSparsifier(TestCase):
def test_constructor(self):
pass # Nothing to test so far

View file

@ -24,5 +24,8 @@ from ao.sparsity.test_scheduler import TestScheduler # noqa: F401
# Composability
from ao.sparsity.test_composability import TestComposability # noqa: F401
# Data Sparsifier
from ao.sparsity.test_data_sparsifier import TestBaseDataSparsifier # noqa: F401
if __name__ == '__main__':
run_tests()

View file

@ -26,3 +26,6 @@ from .experimental.pruner.parametrization import BiasHook
# Pruner
from .experimental.pruner.base_pruner import BasePruner
# Data Sparsifier
from .experimental.data_sparsifier.base_data_sparsifier import BaseDataSparsifier

View file

@ -0,0 +1,77 @@
import abc
from typing import Optional, Tuple, List, Any, Dict
from ...sparsifier import base_sparsifier
from collections import defaultdict
from torch import nn
__all__ = ['BaseDataSparsifier']
class _Container(nn.Module):
def __init__(self):
super().__init__()
class BaseDataSparsifier(base_sparsifier.BaseSparsifier):
r"""
Base Data Sparsifier class for all Data sparsifiers.
The abstract class accepts raw torch tensors / embedding / embedding bags (refer to SUPPORTED_TYPES above)
to prepare for sparsification.
In this case, mask (and parametrizations) is owned by the class and not by the user.
Specifically, the container object inside the class maintains the mask and parametrizations of the input data
Args:
data_list (list of tuples)
list of (name, data) tuples to sparsify. Lookup SUPPORTED_TYPES
for type of data. Internally, a container module handles the data sparsification.
defaults (dict)
default configurations will be attached to the
configuration. Only the keys that don't exist in the `config` will
be updated.
"""
def __init__(self, data_list: Optional[List[Tuple[str, Any]]] = None, **defaults):
super().__init__(defaults=defaults)
self._container = _Container()
self.data_groups: Dict[str, Dict] = defaultdict(dict) # name -> {**config}
if data_list is not None:
# add data with default config here
[self.add_data(name, data, **self.defaults) for name, data in data_list]
def prepare(self):
raise NotImplementedError("this function is undefined for this class")
def add_data(self, name: str, data, **config):
r""" Configures and parametrizes the internal container model with name and data
"""
pass
def get_data(self, name: str):
r"""Returns weight tensor (or data) based on the input name.
"""
pass
def __repr__(self):
r"""String representation of an object when printed
"""
pass
def get_mask(self, name: str):
r"""Returns the mask currently associated with the named tensor.
"""
pass
def squash_mask(self, *args, **kwargs):
r"""Squashes the sparse masks into the appropriate tensors.
"""
pass
def step(self):
r"""Updates the mask for all the named data.
"""
pass
@abc.abstractmethod
def update_mask(self, name, data, **kwargs):
pass