Mirror of https://github.com/saymrwulf/pytorch.git
[quant] Input-weight equalization - branch support (#62366)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62366

In the case of models with branches, we are unable to equalize the branching part of the graph. For example, given this graph:

```
       conv2
      /      \
x -> conv1 -> add
```

after prepare, we skip the branched layers (conv1 and conv2) and do not insert equalization observers for them; the ordinary quantization observers are still inserted. A warning message is also printed for each layer that cannot be equalized, and the prepared graph looks like this:

```
                            conv2 -> out_quant_obs2
                           /                        \
x -> input_quant_obs -> conv1 -> out_quant_obs1 -> add
```

Test Plan: `python test/test_quantization.py TestEqualizeFx.test_input_weight_equalization_prepare`

Imported from OSS

Reviewed By: malfet, supriyar

Differential Revision: D29982585

fbshipit-source-id: 706297e7f1861975998dfa83e7ca59af09d80618
parent 62a90c227f
commit 91ef19309e

2 changed files with 108 additions and 11 deletions
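To make the new behavior concrete before reading the diffs, here is a minimal sketch of a model that triggers the branch warning. It assumes the `torch.quantization` module layout of this era (these FX internals later moved under `torch.ao.quantization`); the model and variable names are illustrative, not part of the commit.

```python
import torch
import torch.nn as nn
from torch.quantization import default_qconfig
from torch.quantization.quantize_fx import prepare_fx
from torch.quantization.fx._equalize import default_equalization_qconfig

class BranchingModel(nn.Module):
    """Both linear layers consume the same input x, forming a branch."""
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(5, 5)
        self.linear2 = nn.Linear(5, 5)

    def forward(self, x):
        # x fans out to two quantizable layers whose outputs are re-joined.
        return torch.add(self.linear1(x), self.linear2(x))

m = BranchingModel().eval()

# With this commit, prepare_fx skips input-weight equalization for
# linear1 and linear2 (warning once per layer) instead of mishandling
# the branch; the standard quantization observers are still inserted.
prepared = prepare_fx(
    m,
    {"": default_qconfig},
    equalization_qconfig_dict={"": default_equalization_qconfig},
)
```

Calling `prepared(torch.rand(4, 5))` and then `convert_fx(prepared)` should quantize the model as usual, just without equalization scales for the two branched layers.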
test/quantization/fx/test_equalize_fx.py

```diff
@@ -49,7 +49,9 @@ from hypothesis import given
 from hypothesis import strategies as st
 
 
-qconfig_dict = {
+default_qconfig_dict = {"": default_qconfig}
+
+specific_qconfig_dict = {
     "": None,
     "object_type": [(nn.Linear, default_qconfig),
                     (F.linear, default_qconfig),
@@ -268,9 +270,61 @@ class TestEqualizeFx(QuantizationTestCase):
 
         for (M, node_occurrence) in tests:
             m = M().eval()
-            prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+            prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
             self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence)
 
+    def test_input_weight_equalization_branching(self):
+        """ Tests that graphs containing branches are prepared correctly.
+        Specifically, equalization observers should not be inserted in front of
+        branches in which both initial layers in the branches plan to be
+        quantized.
+        """
+
+        # Tests that we do not add an equalization observer due to both initial
+        # nodes in the branch containing layers that need to be equalized.
+        # Note that this should print out 2 warning messages for not being able
+        # to equalize layers linear1 and linear2 because they are part of a branch
+        class TestBranchingWithoutEqualizationModel(nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear1 = nn.Linear(5, 5)
+                self.linear2 = nn.Linear(5, 5)
+
+            def forward(self, x):
+                y = self.linear1(x)
+                z = self.linear2(x)
+                return torch.add(y, z)
+
+        no_eq_branching_node_occurrence = {
+            ns.call_module(_InputEqualizationObserver): 0,
+            ns.call_module(MinMaxObserver): 3,
+        }
+
+        m = TestBranchingWithoutEqualizationModel().eval()
+        prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+        self.checkGraphModuleNodes(prepared, expected_node_occurrence=no_eq_branching_node_occurrence)
+
+        # Tests that we will add an equalization observer because there is only
+        # one initial node in the branch that needs to be equalized
+        class TestBranchingWithEqualizationModel(nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear1 = nn.Linear(5, 5)
+
+            def forward(self, x):
+                y = self.linear1(x)
+                z = torch.add(x, 5)
+                return torch.add(y, z)
+
+        eq_branching_node_occurrence = {
+            ns.call_module(_InputEqualizationObserver): 1,
+            ns.call_module(MinMaxObserver): 2,
+        }
+
+        m = TestBranchingWithEqualizationModel().eval()
+        prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+        self.checkGraphModuleNodes(prepared, expected_node_occurrence=eq_branching_node_occurrence)
+
     @skipIfNoFBGEMM
     def test_input_weight_equalization_convert(self):
         """ Tests that the modified model for equalization (before quantization)
@@ -295,13 +349,17 @@ class TestEqualizeFx(QuantizationTestCase):
         elif ndim == 4:
             x = torch.rand((16, 3, 224, 224))
 
-        prepared = prepare_fx(copy.deepcopy(m), qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+        prepared = prepare_fx(
+            copy.deepcopy(m),
+            specific_qconfig_dict,
+            equalization_qconfig_dict=default_equalization_qconfig_dict
+        )
         output = prepared(x)
 
         convert_ref = _convert_equalization_ref(prepared)
         convert_ref_output = convert_ref(x)
 
-        prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+        prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
         prepared(x)
         convert_fx(prepared)  # Check if compile
         self.assertEqual(output, convert_ref_output)
@@ -349,7 +407,7 @@ class TestEqualizeFx(QuantizationTestCase):
         m = M().eval()
         exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy())
 
-        prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+        prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
         prepared(x)
         convert_ref = _convert_equalization_ref(prepared)
         convert_ref(x)
@@ -398,7 +456,7 @@ class TestEqualizeFx(QuantizationTestCase):
         exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy())
         exp_weights, exp_bias = self.get_expected_weights_bias(m, x.detach().numpy(), exp_eq_scales)
 
-        prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+        prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
         prepared(x)
         convert_ref = _convert_equalization_ref(prepared)
         convert_ref(x)
@@ -454,7 +512,7 @@ class TestEqualizeFx(QuantizationTestCase):
         exp_inp_act_vals = self.get_expected_inp_act_vals(m, x, exp_eq_scales, exp_weights, exp_bias)
         exp_weight_act_vals = self.get_expected_weight_act_vals(exp_weights)
 
-        prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+        prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
         prepared(x)
         convert_ref = _convert_equalization_ref(prepared)
         convert_ref(x)
@@ -695,7 +753,7 @@ class TestEqualizeFx(QuantizationTestCase):
         elif ndim == 4:
             x = torch.rand((16, 3, 224, 224))
 
-        prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+        prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
         prepared(x)
         equalized_quantized_model = convert_fx(prepared)
 
@@ -716,13 +774,17 @@ class TestEqualizeFx(QuantizationTestCase):
         m = M().eval()
 
         # No equalization
-        prepared = prepare_fx(copy.deepcopy(m), qconfig_dict, equalization_qconfig_dict={})
+        prepared = prepare_fx(copy.deepcopy(m), specific_qconfig_dict, equalization_qconfig_dict={})
         prepared(x)
         quantized = convert_fx(prepared)  # Check if compile
         quantized_output = quantized(x)
 
         # With equalization
-        prepared = prepare_fx(copy.deepcopy(m), qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+        prepared = prepare_fx(
+            copy.deepcopy(m),
+            specific_qconfig_dict,
+            equalization_qconfig_dict=default_equalization_qconfig_dict
+        )
         prepared(x)
         equalized_and_quantized = convert_fx(prepared)  # Check if compile
         equalized_and_quantized_output = equalized_and_quantized(x)
```
torch/quantization/fx/prepare.py

```diff
@@ -1,11 +1,13 @@
 import torch
 import operator
+import warnings
 from torch.fx import (
     GraphModule,
 )
 
 from torch.quantization import (
     propagate_qconfig_,
+    ObserverBase,
 )
 from torch.fx.graph import (
     Graph,
@@ -485,6 +487,7 @@ def maybe_insert_input_equalization_observers_for_node(
     modules: Dict[str, torch.nn.Module],
     graph: Graph,
     node_name_to_target_dtype: Dict[str, Any],
+    is_branch: bool,
 ) -> None:
     """
     If `node` needs to be equalized, find the input/weight observers it needs in
@@ -495,6 +498,12 @@ def maybe_insert_input_equalization_observers_for_node(
     if equalization_qconfig is None or not node_supports_equalization(node, modules):
         return
 
+    if is_branch:
+        warnings.warn(
+            f"Cannot equalize {node} because it is part of a branch."
+        )
+        return
+
     new_args = []
     for arg in node.args:
         if not isinstance(arg, Node) or node_arg_is_bias(node, arg):
@@ -940,6 +949,32 @@ def insert_observers_for_model(
             if not skip_inserting_observers:
                 modules = dict(model.named_modules(remove_duplicate=False))
                 if node.op != 'output':
+
+                    # This is currently only used for equalization.
+                    # Checks if the current node is in a branch in which the two
+                    # first layers are both being quantized.
+                    #
+                    # ex.       conv2
+                    #          /
+                    #      x -> conv1
+                    #
+                    # If this is the case, we will not apply equalization to the
+                    # initial two layers.
+                    is_quantized_branch = False
+                    if (
+                        len(node.args) > 0 and
+                        isinstance(node.args[0], Node) and
+                        len(node.args[0].users) > 1
+                    ):
+                        for user in node.args[0].users:
+                            # Checks if there exists another user being quantized
+                            is_user_quantized = (
+                                qconfig_map.get(user.name, None) is not None or
+                                (user.op == 'call_module' and isinstance(modules[str(user.target)], ObserverBase))
+                            )
+                            if user != node and is_user_quantized:
+                                is_quantized_branch = True
+
                     # this modifies node inplace
                     maybe_insert_input_observers_for_node(
                         node, qconfig, model, modules, graph,
@@ -949,7 +984,7 @@ def insert_observers_for_model(
                     # Insert equalization input observers if needed
                     maybe_insert_input_equalization_observers_for_node(
                         node, equalization_qconfig, model, modules, graph,
-                        node_name_to_target_dtype)
+                        node_name_to_target_dtype, is_quantized_branch)
 
                     is_last_node_of_pattern = root_node is node
                     is_general_tensor_value_op = \
```
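For orientation, the branch check added to `insert_observers_for_model` boils down to the predicate sketched below. This is a standalone restatement under assumed names (`qconfig_map` as the node-name-to-qconfig dict and `modules` from `model.named_modules()`), not code lifted verbatim from the commit.

```python
from typing import Any, Dict
import torch
from torch.fx import Node
from torch.quantization import ObserverBase

def node_is_on_quantized_branch(node: Node,
                                qconfig_map: Dict[str, Any],
                                modules: Dict[str, torch.nn.Module]) -> bool:
    """True if node's first input fans out to more than one user and at
    least one *other* user is being quantized (it has a qconfig or is an
    already-inserted observer module)."""
    if not node.args or not isinstance(node.args[0], Node):
        return False
    parent = node.args[0]
    if len(parent.users) <= 1:
        return False  # the input does not fan out: no branch
    for user in parent.users:
        if user is node:
            continue
        has_qconfig = qconfig_map.get(user.name, None) is not None
        is_observer = (user.op == 'call_module'
                       and isinstance(modules[str(user.target)], ObserverBase))
        if has_qconfig or is_observer:
            return True
    return False
```

When this predicate holds, `maybe_insert_input_equalization_observers_for_node` is called with `is_branch=True` and returns early after warning, which is exactly what the two new tests assert through their `_InputEqualizationObserver` occurrence counts.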