mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[ao][sparsity] Base class for Data Sparsifier
Base Data Sparsifier class for all Data sparsifiers. The abstract class accepts raw torch tensors / embedding / embedding bags (refer to SUPPORTED_TYPES above) to prepare for sparsification. In this case, mask (and parametrizations) is owned by the class and not by the user. Specifically, the container object inside the class maintains the mask and parametrizations of the input data Test Plan: ```python test/test_ao_sparsity.py TestBaseDataSparsifier``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/79251 Approved by: https://github.com/z-a-f, https://github.com/HDCharles
This commit is contained in:
parent
2539efb842
commit
15828bcfd7
6 changed files with 97 additions and 0 deletions
|
|
@ -1046,5 +1046,6 @@ Please take a look at `Limitations of Symbolic Tracing <https://docs-preview.pyt
|
|||
.. py:module:: torch.ao.sparsity
|
||||
.. py:module:: torch.ao.sparsity.experimental
|
||||
.. py:module:: torch.ao.sparsity.experimental.pruner
|
||||
.. py:module:: torch.ao.sparsity.experimental.data_sparsifier
|
||||
.. py:module:: torch.ao.sparsity.scheduler
|
||||
.. py:module:: torch.ao.sparsity.sparsifier
|
||||
|
|
|
|||
13
test/ao/sparsity/test_data_sparsifier.py
Normal file
13
test/ao/sparsity/test_data_sparsifier.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Owner(s): ["module: unknown"]
|
||||
|
||||
import logging
|
||||
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
|
||||
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
|
||||
|
||||
|
||||
class TestBaseDataSparsifier(TestCase):
|
||||
def test_constructor(self):
|
||||
pass # Nothing to test so far
|
||||
|
|
@ -24,5 +24,8 @@ from ao.sparsity.test_scheduler import TestScheduler # noqa: F401
|
|||
# Composability
|
||||
from ao.sparsity.test_composability import TestComposability # noqa: F401
|
||||
|
||||
# Data Sparsifier
|
||||
from ao.sparsity.test_data_sparsifier import TestBaseDataSparsifier # noqa: F401
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -26,3 +26,6 @@ from .experimental.pruner.parametrization import BiasHook
|
|||
|
||||
# Pruner
|
||||
from .experimental.pruner.base_pruner import BasePruner
|
||||
|
||||
# Data Sparsifier
|
||||
from .experimental.data_sparsifier.base_data_sparsifier import BaseDataSparsifier
|
||||
|
|
|
|||
|
|
@ -0,0 +1,77 @@
|
|||
import abc
|
||||
from typing import Optional, Tuple, List, Any, Dict
|
||||
from ...sparsifier import base_sparsifier
|
||||
from collections import defaultdict
|
||||
from torch import nn
|
||||
__all__ = ['BaseDataSparsifier']
|
||||
|
||||
|
||||
class _Container(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
||||
class BaseDataSparsifier(base_sparsifier.BaseSparsifier):
|
||||
r"""
|
||||
Base Data Sparsifier class for all Data sparsifiers.
|
||||
The abstract class accepts raw torch tensors / embedding / embedding bags (refer to SUPPORTED_TYPES above)
|
||||
to prepare for sparsification.
|
||||
In this case, mask (and parametrizations) is owned by the class and not by the user.
|
||||
Specifically, the container object inside the class maintains the mask and parametrizations of the input data
|
||||
|
||||
Args:
|
||||
data_list (list of tuples)
|
||||
list of (name, data) tuples to sparsify. Lookup SUPPORTED_TYPES
|
||||
for type of data. Internally, a container module handles the data sparsification.
|
||||
|
||||
defaults (dict)
|
||||
default configurations will be attached to the
|
||||
configuration. Only the keys that don't exist in the `config` will
|
||||
be updated.
|
||||
"""
|
||||
def __init__(self, data_list: Optional[List[Tuple[str, Any]]] = None, **defaults):
|
||||
super().__init__(defaults=defaults)
|
||||
|
||||
self._container = _Container()
|
||||
|
||||
self.data_groups: Dict[str, Dict] = defaultdict(dict) # name -> {**config}
|
||||
if data_list is not None:
|
||||
# add data with default config here
|
||||
[self.add_data(name, data, **self.defaults) for name, data in data_list]
|
||||
|
||||
def prepare(self):
|
||||
raise NotImplementedError("this function is undefined for this class")
|
||||
|
||||
def add_data(self, name: str, data, **config):
|
||||
r""" Configures and parametrizes the internal container model with name and data
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_data(self, name: str):
|
||||
r"""Returns weight tensor (or data) based on the input name.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
r"""String representation of an object when printed
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_mask(self, name: str):
|
||||
r"""Returns the mask currently associated with the named tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def squash_mask(self, *args, **kwargs):
|
||||
r"""Squashes the sparse masks into the appropriate tensors.
|
||||
"""
|
||||
pass
|
||||
|
||||
def step(self):
|
||||
r"""Updates the mask for all the named data.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_mask(self, name, data, **kwargs):
|
||||
pass
|
||||
Loading…
Reference in a new issue