Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00
[ao] Added function to determine whether dynamic or static quantization is appropriate
Summary: This PR adds the _detect_dynamic_vs_static function. It takes in a prepared FX graph model that already has ModelReportObservers built into it, uses the collected information to determine whether the input and output of each supported module are stationary or non-stationary, and provides feedback on whether to make Linear modules static or dynamic based on this information. This PR will be followed up soon with another PR that more rigorously tests the end-to-end performance of the whole system, which is primarily how the function added here will be exercised; that is why this PR contains only one test.

Test Plan: python test/quantization/fx/test_model_report_fx.py TestFxModelReportDetectDynamicStatic

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79326
Approved by: https://github.com/HDCharles
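A minimal usage sketch of the new API, condensed from the test added in this PR. The toy Sequential model and the random calibration batches are illustrative only, and the observer-insertion step is elided because this PR performs it by hand with FX graph surgery; the planned end-to-end tool is meant to automate it:

    import torch
    import torch.ao.quantization.quantize_fx as quantize_fx
    from torch.ao.quantization import QConfigMapping
    from torch.ao.quantization.fx._model_report._detector import _detect_dynamic_vs_static

    # toy stand-in for any model that contains Linear modules
    model = torch.nn.Sequential(torch.nn.Linear(3, 3))

    torch.backends.quantized.engine = "fbgemm"
    q_config_mapping = QConfigMapping()
    q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig("fbgemm"))
    model_prep = quantize_fx.prepare_fx(model, q_config_mapping, torch.randn(1, 3))

    # ... insert a ModelReportObserver before and after each Linear of interest here
    # (see the manual pre/post observer insertion in the test diff below) ...

    # calibrate so the observers can record per-batch and per-epoch activation ranges
    for _ in range(10):
        model_prep(torch.randn(1, 3))

    # returns a human-readable report plus a per-module dict holding the pre/post
    # comparison statistics, their stationary/non-stationary classifications, and
    # whether dynamic quantization is recommended
    report_str, report_dict = _detect_dynamic_vs_static(model_prep, tolerance=0.5)
    print(report_str)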
This commit is contained in:
parent e10b762537
commit 38952d9350

3 changed files with 300 additions and 90 deletions
test/quantization/fx/test_model_report_fx.py
@@ -2,27 +2,23 @@
# Owner(s): ["oncall: quantization"]

import torch
import torch.ao.quantization.quantize_fx
from torch.ao.quantization import QConfig, QConfigMapping
from torch.ao.quantization.fx._model_report._detector import _detect_per_channel
import torch.ao.quantization.quantize_fx as quantize_fx
import torch.nn.functional as F
from torch.ao.quantization.fx._model_report.model_report_observer import (
    ModelReportObserver,
)
from torch.ao.quantization.observer import (
    default_per_channel_weight_observer,
    HistogramObserver,
)
from torch.ao.quantization import QConfig, QConfigMapping
from torch.ao.quantization.fx._model_report._detector import _detect_dynamic_vs_static, _detect_per_channel
from torch.ao.quantization.fx._model_report.model_report_observer import ModelReportObserver
from torch.ao.quantization.observer import HistogramObserver, default_per_channel_weight_observer
from torch.nn.intrinsic.modules.fused import ConvReLU2d, LinearReLU
from torch.testing._internal.common_quantization import (
    ConvModel,
    QuantizationTestCase,
    SingleLayerLinearModel,
    TwoLayerLinearModel,
    skipIfNoFBGEMM,
    skipIfNoQNNPACK,
    TwoLayerLinearModel,
)


"""
Partition of input domain:
@@ -41,7 +37,9 @@ There are possible changes / suggestions, there are no changes / suggestions
"""

# Default output for string if no optimizations are possible
DEFAULT_NO_OPTIMS_ANSWER_STRING = "Further Optimizations for backend {}: \nNo further per_channel optimizations possible."
DEFAULT_NO_OPTIMS_ANSWER_STRING = (
    "Further Optimizations for backend {}: \nNo further per_channel optimizations possible."
)

# Example Sequential Model with multiple Conv and Linear with nesting involved
NESTED_CONV_LINEAR_EXAMPLE = torch.nn.Sequential(
@@ -71,14 +69,12 @@ FUSION_CONV_LINEAR_EXAMPLE = torch.nn.Sequential(
)


class TestModelReportFxDetector(QuantizationTestCase):
class TestFxModelReportDetector(QuantizationTestCase):

    """Prepares and calibrates the model"""

    def _prepare_model_and_run_input(self, model, q_config_mapping, input):
        model_prep = torch.ao.quantization.quantize_fx.prepare_fx(
            model, q_config_mapping, input
        )  # prep model
        model_prep = torch.ao.quantization.quantize_fx.prepare_fx(model, q_config_mapping, input)  # prep model
        model_prep(input).sum()  # calibrate the model
        return model_prep
@@ -95,14 +91,10 @@ class TestModelReportFxDetector(QuantizationTestCase):
        torch.backends.quantized.engine = "onednn"

        q_config_mapping = QConfigMapping()
        q_config_mapping.set_global(
            torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)
        )
        q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))

        input = torch.randn(1, 3, 10, 10)
        prepared_model = self._prepare_model_and_run_input(
            ConvModel(), q_config_mapping, input
        )
        prepared_model = self._prepare_model_and_run_input(ConvModel(), q_config_mapping, input)

        # run the detector
        optims_str, per_channel_info = _detect_per_channel(prepared_model)
@@ -119,9 +111,7 @@ class TestModelReportFxDetector(QuantizationTestCase):
            per_channel_info["per_channel_status"]["conv"]["per_channel_supported"],
            True,
        )
        self.assertEqual(
            per_channel_info["per_channel_status"]["conv"]["per_channel_used"], True
        )
        self.assertEqual(per_channel_info["per_channel_status"]["conv"]["per_channel_used"], True)

    """Case includes:
        Multiple conv or linear
@@ -137,9 +127,7 @@ class TestModelReportFxDetector(QuantizationTestCase):
        torch.backends.quantized.engine = "qnnpack"

        q_config_mapping = QConfigMapping()
        q_config_mapping.set_global(
            torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)
        )
        q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))

        prepared_model = self._prepare_model_and_run_input(
            TwoLayerLinearModel(),
@@ -238,9 +226,7 @@ class TestModelReportFxDetector(QuantizationTestCase):
            elif "conv" in key:
                self.assertEqual(module_entry["per_channel_used"], True)
            else:
                raise ValueError(
                    "Should only contain conv and linear layers as key values"
                )
                raise ValueError("Should only contain conv and linear layers as key values")

    """Case includes:
        Multiple conv or linear
@@ -256,9 +242,7 @@ class TestModelReportFxDetector(QuantizationTestCase):
        torch.backends.quantized.engine = "qnnpack"

        q_config_mapping = QConfigMapping()
        q_config_mapping.set_global(
            torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)
        )
        q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))

        prepared_model = self._prepare_model_and_run_input(
            NESTED_CONV_LINEAR_EXAMPLE,
@@ -299,9 +283,7 @@ class TestModelReportFxDetector(QuantizationTestCase):
        torch.backends.quantized.engine = "qnnpack"

        q_config_mapping = QConfigMapping()
        q_config_mapping.set_global(
            torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)
        )
        q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))

        prepared_model = self._prepare_model_and_run_input(
            LAZY_CONV_LINEAR_EXAMPLE,
@@ -342,9 +324,7 @@ class TestModelReportFxDetector(QuantizationTestCase):
        torch.backends.quantized.engine = "fbgemm"

        q_config_mapping = QConfigMapping()
        q_config_mapping.set_global(
            torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)
        )
        q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))

        prepared_model = self._prepare_model_and_run_input(
            FUSION_CONV_LINEAR_EXAMPLE,
@@ -409,9 +389,7 @@ class TestModelReportFxDetector(QuantizationTestCase):

        # model must be in eval mode for fusion
        model_fp32.eval()
        model_fp32_fused = torch.quantization.fuse_modules(
            model_fp32, [["conv", "bn", "relu"]]
        )
        model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [["conv", "bn", "relu"]])

        # model must be set to train mode for QAT logic to work
        model_fp32_fused.train()
@@ -454,7 +432,7 @@ Partition on Output
"""


class TestModelReportObserver(QuantizationTestCase):
class TestFxModelReportObserver(QuantizationTestCase):
    class NestedModifiedSingleLayerLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
@@ -488,12 +466,8 @@ class TestModelReportObserver(QuantizationTestCase):
            getattr(model, "obs1").average_batch_activation_range,
            torch.tensor(float(0)),
        )
        self.assertEqual(
            getattr(model, "obs1").epoch_activation_min, torch.tensor(float("inf"))
        )
        self.assertEqual(
            getattr(model, "obs1").epoch_activation_max, torch.tensor(float("-inf"))
        )
        self.assertEqual(getattr(model, "obs1").epoch_activation_min, torch.tensor(float("inf")))
        self.assertEqual(getattr(model, "obs1").epoch_activation_max, torch.tensor(float("-inf")))

        # loop through the batches and run through
        for index, batch in enumerate(split_up_data):
@@ -503,9 +477,7 @@ class TestModelReportObserver(QuantizationTestCase):

            # get general info about the batch and the model to use later
            batch_min, batch_max = torch.aminmax(batch)
            current_average_range = getattr(
                model, "obs1"
            ).average_batch_activation_range
            current_average_range = getattr(model, "obs1").average_batch_activation_range
            current_epoch_min = getattr(model, "obs1").epoch_activation_min
            current_epoch_max = getattr(model, "obs1").epoch_activation_max

@@ -513,9 +485,9 @@ class TestModelReportObserver(QuantizationTestCase):
            model(ex_input)

            # check that average batch activation range updated correctly
            correct_updated_value = (
                current_average_range * num_tracked_so_far + (batch_max - batch_min)
            ) / (num_tracked_so_far + 1)
            correct_updated_value = (current_average_range * num_tracked_so_far + (batch_max - batch_min)) / (
                num_tracked_so_far + 1
            )
            self.assertEqual(
                getattr(model, "obs1").average_batch_activation_range,
                correct_updated_value,
@@ -659,9 +631,9 @@ class TestModelReportObserver(QuantizationTestCase):
                x = self.relu(x)
                return x

        class ThreeOps(torch.nn.Module):
        class ModifiedThreeOps(torch.nn.Module):
            def __init__(self, batch_norm_dim):
                super(ThreeOps, self).__init__()
                super(ModifiedThreeOps, self).__init__()
                self.obs1 = ModelReportObserver()
                self.linear = torch.nn.Linear(7, 3, 2)
                self.obs2 = ModelReportObserver()
@@ -688,9 +660,9 @@ class TestModelReportObserver(QuantizationTestCase):
                super(HighDimensionNet, self).__init__()
                self.obs1 = ModelReportObserver()
                self.fc1 = torch.nn.Linear(3, 7)
                self.block1 = ThreeOps(3)
                self.block1 = ModifiedThreeOps(3)
                self.fc2 = torch.nn.Linear(3, 7)
                self.block2 = ThreeOps(3)
                self.block2 = ModifiedThreeOps(3)
                self.fc3 = torch.nn.Linear(3, 7)

            def forward(self, x):
@@ -706,7 +678,12 @@ class TestModelReportObserver(QuantizationTestCase):

        # the purpose of this test is to give the observers a variety of data examples
        # initialize the model
        models = [self.NestedModifiedSingleLayerLinear(), LargerIncludeNestModel(), ThreeOps(2), HighDimensionNet()]
        models = [
            self.NestedModifiedSingleLayerLinear(),
            LargerIncludeNestModel(),
            ModifiedThreeOps(2),
            HighDimensionNet(),
        ]

        # get some number of epochs and batches
        num_epochs = 10
@@ -723,3 +700,100 @@ class TestModelReportObserver(QuantizationTestCase):
        # run it through the model and do general tests
        for index, model in enumerate(models):
            self.run_model_and_common_checks(model, inputs[index], num_epochs, num_batches)


"""
Partition on domain / things to test

There is only a single test case for now.

This will be more thoroughly tested with the implementation of the full end to end tool coming soon.
"""


class TestFxModelReportDetectDynamicStatic(QuantizationTestCase):
    @skipIfNoFBGEMM
    def test_nested_detection_case(self):
        class SingleLinear(torch.nn.Module):
            def __init__(self):
                super(SingleLinear, self).__init__()
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, x):
                x = self.linear(x)
                return x

        class TwoBlockNet(torch.nn.Module):
            def __init__(self):
                super(TwoBlockNet, self).__init__()
                self.block1 = SingleLinear()
                self.block2 = SingleLinear()

            def forward(self, x):
                x = self.block1(x)
                y = self.block2(x)
                z = x + y
                z = F.relu(z)
                return z

        # create model, example input, and qconfig mapping
        torch.backends.quantized.engine = "fbgemm"
        model = TwoBlockNet()
        example_input = torch.randint(-10, 0, (1, 3, 3, 3))
        example_input = example_input.to(torch.float)
        q_config_mapping = QConfigMapping()
        q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig("fbgemm"))

        # prep model and select observer
        model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input)
        obs_ctr = ModelReportObserver

        # find layer to attach to and store
        linear_fqn = "block2.linear"  # fqn of target linear

        target_linear = None
        for node in model_prep.graph.nodes:
            if node.target == linear_fqn:
                target_linear = node
                break

        # insert into both module and graph pre and post

        # set up to insert before target_linear (pre_observer)
        with model_prep.graph.inserting_before(target_linear):
            obs_to_insert = obs_ctr()
            pre_obs_fqn = linear_fqn + ".model_report_pre_observer"
            model_prep.add_submodule(pre_obs_fqn, obs_to_insert)
            model_prep.graph.create_node(op="call_module", target=pre_obs_fqn, args=target_linear.args)

        # set up and insert after the target_linear (post_observer)
        with model_prep.graph.inserting_after(target_linear):
            obs_to_insert = obs_ctr()
            post_obs_fqn = linear_fqn + ".model_report_post_observer"
            model_prep.add_submodule(post_obs_fqn, obs_to_insert)
            model_prep.graph.create_node(op="call_module", target=post_obs_fqn, args=(target_linear,))

        # need to recompile module after submodule added and pass input through
        model_prep.recompile()

        num_iterations = 10
        for i in range(num_iterations):
            if i % 2 == 0:
                example_input = torch.randint(-10, 0, (1, 3, 3, 3)).to(torch.float)
            else:
                example_input = torch.randint(0, 10, (1, 3, 3, 3)).to(torch.float)
            model_prep(example_input)

        # run it through the dynamic vs static detector
        dynam_vs_stat_str, dynam_vs_stat_dict = _detect_dynamic_vs_static(model_prep, tolerance=0.5)

        # one of the stats should be stationary, and the other non-stationary
        # as a result, dynamic should be recommended
        data_dist_info = [
            dynam_vs_stat_dict[linear_fqn]["pre_observer_data_dist"],
            dynam_vs_stat_dict[linear_fqn]["post_observer_data_dist"],
        ]

        self.assertTrue("stationary" in data_dist_info)
        self.assertTrue("non-stationary" in data_dist_info)
        self.assertTrue(dynam_vs_stat_dict[linear_fqn]["dynamic_recommended"])
test/test_quantization.py
@@ -84,8 +84,9 @@ except ImportError:

# Test the model report module
try:
    from quantization.fx.test_model_report_fx import TestModelReportFxDetector  # noqa: F401
    from quantization.fx.test_model_report_fx import TestModelReportObserver  # noqa: F401
    from quantization.fx.test_model_report_fx import TestFxModelReportDetector  # noqa: F401
    from quantization.fx.test_model_report_fx import TestFxModelReportObserver  # noqa: F401
    from quantization.fx.test_model_report_fx import TestFxModelReportDetectDynamicStatic  # noqa: F401
except ImportError:
    pass
torch/ao/quantization/fx/_model_report/_detector.py
@@ -3,11 +3,13 @@ from typing import Any, Dict, Set, Tuple
import torch
import torch.nn as nn
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.fx.graph_module import GraphModule
from torch.ao.quantization.observer import ObserverBase
from torch.ao.quantization.qconfig import QConfig
from torch.nn.qat.modules.conv import _ConvNd as QatConvNd
from torch.nn.qat.modules.linear import Linear as QatLinear


# Default map for representing supported per channel quantization modules for different backends
DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: Dict[str, Set[Any]] = {
    "fbgemm": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, QatLinear, QatConvNd]),
@@ -35,15 +37,9 @@ def _detect_per_channel(model: nn.Module) -> Tuple[str, Dict[str, Any]]:
    backend_chosen = torch.backends.quantized.engine
    supported_modules = set([])
    if backend_chosen in DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES:
        supported_modules = DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES[
            backend_chosen
        ]
        supported_modules = DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES[backend_chosen]
    else:
        raise ValueError(
            "Not configured to work with {}. Try a different default backend".format(
                backend_chosen
            )
        )
        raise ValueError("Not configured to work with {}. Try a different default backend".format(backend_chosen))

    # store information on submodules and if per_channel quantization is supported and used as well as qconfig information
    per_channel_info = {"backend": backend_chosen, "per_channel_status": {}}
@@ -66,11 +62,7 @@ def _detect_per_channel(model: nn.Module) -> Tuple[str, Dict[str, Any]]:
        # asserts for MyPy
        assert isinstance(fqn, str) and isinstance(per_channel_info["per_channel_status"], dict)

        is_in_include_list = (
            True
            if sum(list(map(lambda x: isinstance(module, x), supported_modules))) > 0
            else False
        )
        is_in_include_list = sum(list(map(lambda x: isinstance(module, x), supported_modules))) > 0

        # check if the module per_channel is supported
        # based on backend
@@ -85,22 +77,15 @@ def _detect_per_channel(model: nn.Module) -> Tuple[str, Dict[str, Any]]:

            # this object should either be fake quant or observer
            q_or_s_obj = module.qconfig.weight.p.func()
            assert isinstance(q_or_s_obj, FakeQuantize) or isinstance(
                q_or_s_obj, ObserverBase
            )
            assert isinstance(q_or_s_obj, FakeQuantize) or isinstance(q_or_s_obj, ObserverBase)

            per_channel_used = False  # will be true if found in qconfig

            if hasattr(
                q_or_s_obj, "ch_axis"
            ):  # then we know that per_channel quantization is used
            if hasattr(q_or_s_obj, "ch_axis"):  # then we know that per_channel quantization is used

                # all fake quants have channel axis so need to check is_per_channel
                if isinstance(q_or_s_obj, FakeQuantize):
                    if (
                        hasattr(q_or_s_obj, "is_per_channel")
                        and q_or_s_obj.is_per_channel
                    ):
                    if hasattr(q_or_s_obj, "is_per_channel") and q_or_s_obj.is_per_channel:
                        per_channel_used = True
                elif isinstance(q_or_s_obj, ObserverBase):
                    # should be an observer otherwise
@@ -117,9 +102,7 @@ def _detect_per_channel(model: nn.Module) -> Tuple[str, Dict[str, Any]]:
    _detect_per_channel_helper(model)

    # String to let the user know of further optimizations
    further_optims_str = "Further Optimizations for backend {}: \n".format(
        backend_chosen
    )
    further_optims_str = "Further Optimizations for backend {}: \n".format(backend_chosen)

    # assert for MyPy check
    assert isinstance(per_channel_info["per_channel_status"], dict)
@@ -134,9 +117,161 @@ def _detect_per_channel(model: nn.Module) -> Tuple[str, Dict[str, Any]]:
    )

    if optimizations_possible:
        further_optims_str += "To use per_channel quantization, make sure the qconfig has a per_channel weight observer."
        further_optims_str += (
            "To use per_channel quantization, make sure the qconfig has a per_channel weight observer."
        )
    else:
        further_optims_str += "No further per_channel optimizations possible."

    # return the string and the dictionary form of same information
    return (further_optims_str, per_channel_info)


# names for the pre and post observers that are inserted
DEFAULT_PRE_OBSERVER_NAME = "model_report_pre_observer"
DEFAULT_POST_OBSERVER_NAME = "model_report_post_observer"

# naming conventions for stationary vs non-stationary data
DEFAULT_STATIONARY = "stationary"
DEFAULT_NON_STATIONARY = "non-stationary"

# modules that are supported both dynamic and static for this report function
DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED = set([nn.Linear])


def _detect_dynamic_vs_static(model: GraphModule, tolerance=0.5) -> Tuple[str, Dict[str, Any]]:
    """
    Determines whether dynamic or static quantization is more appropriate for a given module.

    Takes advantage of the ModelReportObserver that records range information.
    Stationary distributions of data are strictly above the tolerance level for the comparison statistic:

        S = average_batch_activation_range / epoch_activation_range

    Non-stationary distributions are below or at the tolerance level for this metric.

    If the distribution of data right after the module is non-stationary, recommend dynamic quantization.
    Otherwise, recommend static quantization.

    This will then generate suggestions for dynamic vs static quantization focused around Linear modules.

    Args:
        model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers around layers of interest
        tolerance (float, optional): The threshold above which the S metric is classified as stationary, and at or below which it is non-stationary. Default: 0.5

    Returns a tuple with two elements:
        String report of whether dynamic or static quantization is recommended for certain modules
        Dictionary mapping modules with ModelReportObservers around them to:
            whether dynamic quantization is recommended
            their S metric of input to module
            whether input to module is stationary or non-stationary
            their S metric of output of module
            whether output of module is stationary or non-stationary
            the tolerance level used to decide whether input/output is stationary or non-stationary
    """

    # store modules dynamic vs static information
    module_dynamic_static_info = {}

    # This for loop goes through the modules, and extracts all relevant information into module_dynamic_static_info
    # This information primarily includes whether the data distributions around a supported module are stationary or not
    # Based on this, it is recorded whether dynamic or static quantization is recommended

    # loop through all submodules, including nested ones
    for fqn, module in model.named_modules():

        # check to see if module is of a supported type
        is_supported_type = sum(list(map(lambda x: isinstance(module, x), DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED))) > 0

        # if module is Linear and has the ModelReportObserver attached to it
        if is_supported_type and hasattr(module, DEFAULT_PRE_OBSERVER_NAME) and hasattr(module, DEFAULT_POST_OBSERVER_NAME):
            # get pre and post observers for the module
            pre_obs = getattr(module, DEFAULT_PRE_OBSERVER_NAME)
            post_obs = getattr(module, DEFAULT_POST_OBSERVER_NAME)

            # get the statistics for each module
            pre_stat = pre_obs.get_batch_to_epoch_ratio()
            post_stat = post_obs.get_batch_to_epoch_ratio()

            # record module, pre and post stat, and whether to do dynamic or static based off it
            # true if post observer data distribution is non-stationary, false if it's stationary
            dynamic_recommended = post_stat <= tolerance

            # specify the classifications for whether data distributions are considered stationary or non-stationary
            pre_obs_dist_classification = DEFAULT_STATIONARY if pre_stat > tolerance else DEFAULT_NON_STATIONARY
            post_obs_dist_classification = DEFAULT_STATIONARY if post_stat > tolerance else DEFAULT_NON_STATIONARY

            # store the set of important information for this module
            module_info = {
                "tolerance": tolerance,
                "dynamic_recommended": dynamic_recommended,
                "pre_observer_comp_stat": pre_stat,
                "pre_observer_data_dist": pre_obs_dist_classification,
                "post_observer_comp_stat": post_stat,
                "post_observer_data_dist": post_obs_dist_classification,
            }

            module_dynamic_static_info[fqn] = module_info

    dynamic_vs_static_string = "Dynamic vs. Static Quantization suggestions: \n"

    # This for loop goes through the information collected in module_dynamic_static_info and:
    # Populates the string-based report with the information from module_dynamic_static_info
    # Compiles the complete report by appending relevant formatted strings

    for module_fqn in module_dynamic_static_info.keys():

        module_info = module_dynamic_static_info[module_fqn]
        suggestion_string_template = "For module {} it is suggested to use {} quantization because {}.\n"

        # decide what string formatting values will be
        quantization_type = ""

        quantization_reasoning = "the distribution of data before {} is {} and the distribution after is {}."
        dynamic_benefit = " You will get more accurate results if you use dynamic quantization"
        static_benefit = " You can increase model efficiency if you use static quantization"
        benefit_str = ""

        # strings for when a dynamic quantize per tensor layer is needed
        recommend_per_tensor = " We recommend to add a {} before this module if it is static."
        rec_lay_to_add = "dynamic quantize per tensor layer"
        dynamic_per_tensor_string = recommend_per_tensor.format(rec_lay_to_add)
        dynamic_per_tensor_reasoning_string = (
            " This is because the input to this module has a non-stationary distribution."
        )

        # start composing explanation
        if module_info["dynamic_recommended"]:
            quantization_type = "dynamic"
            benefit_str = dynamic_benefit
        else:
            quantization_type = "static"
            benefit_str = static_benefit

        # now set the quantization explanation string
        quantization_reasoning = (
            quantization_reasoning.format(
                module_fqn, module_info["pre_observer_data_dist"], module_info["post_observer_data_dist"]
            )
            + benefit_str
        )

        # if we have a non-stationary input -> linear -> stationary output, we suggested static
        # however, we also want to recommend adding a dynamic quantize per tensor layer right before it if this change is made
        if (
            module_info["pre_observer_data_dist"] == DEFAULT_NON_STATIONARY
            and module_info["post_observer_data_dist"] == DEFAULT_STATIONARY
        ):
            quantization_reasoning = (
                quantization_reasoning + dynamic_per_tensor_string + dynamic_per_tensor_reasoning_string
            )

        # format the overall suggestion string with the specific inputs
        module_suggestion_string = suggestion_string_template.format(
            module_fqn, quantization_type, quantization_reasoning
        )

        # append to overall suggestion
        dynamic_vs_static_string += module_suggestion_string

    # return the string as well as the dictionary of information
    return (dynamic_vs_static_string, module_dynamic_static_info)
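
As a worked example of the S metric defined in the docstring of _detect_dynamic_vs_static above (with illustrative numbers, not taken from a real run): suppose an observer sees two calibration batches, one with activations spanning [-10, 0) and one spanning [0, 10). Each batch has a range of about 10, so average_batch_activation_range is roughly 10, while the epoch as a whole spans roughly [-10, 10), so epoch_activation_range is roughly 20. That gives S = 10 / 20 = 0.5, which is at or below the default tolerance of 0.5, so the distribution is classified as non-stationary; if this came from the post-observer of a Linear, dynamic quantization would be recommended. By contrast, batches that each span the full [-10, 10) range give S = 20 / 20 = 1.0 > 0.5, a stationary classification, and hence a static recommendation. The alternating-range calibration loop in the new test above exercises exactly this behavior.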