[quant] Input-weight equalization - branch support (#62366)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62366

For models containing branches, we are currently unable to equalize the layers at the branch point in the graph.

For example, given this graph:
```
     conv2
    /     \
x -> conv1 -> add
```

During prepare, we ignore the branched layers (conv1 and conv2) and do not insert equalization observers for them. A warning message is also printed listing the layers that could not be equalized.
```
                        conv2 -> out_quant_obs2
                       /                       \
x -> input_quant_obs -> conv1 -> out_quant_obs1 -> add
```
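
As a sketch of how this surfaces to a user (the equalization qconfig lives in an internal module, so the exact import path below is an assumption and may differ across versions):
```
import torch
import torch.nn as nn
from torch.quantization import default_qconfig
from torch.quantization.quantize_fx import prepare_fx
# Internal module; this import path is an assumption for this version of the codebase
from torch.quantization.fx._equalize import default_equalization_qconfig

class BranchingModel(nn.Module):
    """Both linear1 and linear2 consume x, so they sit at a branch point."""
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(5, 5)
        self.linear2 = nn.Linear(5, 5)

    def forward(self, x):
        return torch.add(self.linear1(x), self.linear2(x))

m = BranchingModel().eval()

# With this change, prepare_fx warns that linear1 and linear2 cannot be
# equalized (each is part of a branch) and skips inserting input
# equalization observers in front of them; quantization observers are
# still inserted as usual.
prepared = prepare_fx(
    m,
    {"": default_qconfig},
    equalization_qconfig_dict={"": default_equalization_qconfig},
)
```
This mirrors the new `test_input_weight_equalization_branching` test added below.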

Test Plan:
`python test/test_quantization.py TestEqualizeFx.test_input_weight_equalization_prepare`

Imported from OSS

Reviewed By: malfet, supriyar

Differential Revision: D29982585

fbshipit-source-id: 706297e7f1861975998dfa83e7ca59af09d80618

test/quantization/fx/test_equalize_fx.py

```
@@ -49,7 +49,9 @@ from hypothesis import given
 from hypothesis import strategies as st


-qconfig_dict = {
+default_qconfig_dict = {"": default_qconfig}
+
+specific_qconfig_dict = {
     "": None,
     "object_type": [(nn.Linear, default_qconfig),
                     (F.linear, default_qconfig),
@@ -268,9 +270,61 @@ class TestEqualizeFx(QuantizationTestCase):

         for (M, node_occurrence) in tests:
             m = M().eval()
-            prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+            prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
             self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence)

+    def test_input_weight_equalization_branching(self):
+        """ Tests that graphs containing branches are prepared correctly.
+        Specifically, equalization observers should not be inserted in front of
+        branches in which both initial layers in the branches plan to be
+        quantized.
+        """
+
+        # Tests that we do not add an equalization observer because both
+        # initial nodes in the branch contain layers that need to be equalized.
+        # Note that this should print out 2 warning messages for not being able
+        # to equalize layers linear1 and linear2 because they are part of a branch.
+        class TestBranchingWithoutEqualizationModel(nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear1 = nn.Linear(5, 5)
+                self.linear2 = nn.Linear(5, 5)
+
+            def forward(self, x):
+                y = self.linear1(x)
+                z = self.linear2(x)
+                return torch.add(y, z)
+
+        no_eq_branching_node_occurrence = {
+            ns.call_module(_InputEqualizationObserver): 0,
+            ns.call_module(MinMaxObserver): 3,
+        }
+
+        m = TestBranchingWithoutEqualizationModel().eval()
+        prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+        self.checkGraphModuleNodes(prepared, expected_node_occurrence=no_eq_branching_node_occurrence)
+
+        # Tests that we will add an equalization observer because there is only
+        # one initial node in the branch that needs to be equalized
+        class TestBranchingWithEqualizationModel(nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear1 = nn.Linear(5, 5)
+
+            def forward(self, x):
+                y = self.linear1(x)
+                z = torch.add(x, 5)
+                return torch.add(y, z)
+
+        eq_branching_node_occurrence = {
+            ns.call_module(_InputEqualizationObserver): 1,
+            ns.call_module(MinMaxObserver): 2,
+        }
+
+        m = TestBranchingWithEqualizationModel().eval()
+        prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+        self.checkGraphModuleNodes(prepared, expected_node_occurrence=eq_branching_node_occurrence)
+
     @skipIfNoFBGEMM
     def test_input_weight_equalization_convert(self):
         """ Tests that the modified model for equalization (before quantization)
@@ -295,13 +349,17 @@ class TestEqualizeFx(QuantizationTestCase):
             elif ndim == 4:
                 x = torch.rand((16, 3, 224, 224))

-            prepared = prepare_fx(copy.deepcopy(m), qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+            prepared = prepare_fx(
+                copy.deepcopy(m),
+                specific_qconfig_dict,
+                equalization_qconfig_dict=default_equalization_qconfig_dict
+            )
             output = prepared(x)

             convert_ref = _convert_equalization_ref(prepared)
             convert_ref_output = convert_ref(x)

-            prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+            prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
             prepared(x)
             convert_fx(prepared)  # Check if compile
             self.assertEqual(output, convert_ref_output)
@@ -349,7 +407,7 @@ class TestEqualizeFx(QuantizationTestCase):
             m = M().eval()
             exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy())

-            prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+            prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
             prepared(x)
             convert_ref = _convert_equalization_ref(prepared)
             convert_ref(x)
@@ -398,7 +456,7 @@ class TestEqualizeFx(QuantizationTestCase):
             exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy())
             exp_weights, exp_bias = self.get_expected_weights_bias(m, x.detach().numpy(), exp_eq_scales)

-            prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+            prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
             prepared(x)
             convert_ref = _convert_equalization_ref(prepared)
             convert_ref(x)
@@ -454,7 +512,7 @@ class TestEqualizeFx(QuantizationTestCase):
             exp_inp_act_vals = self.get_expected_inp_act_vals(m, x, exp_eq_scales, exp_weights, exp_bias)
             exp_weight_act_vals = self.get_expected_weight_act_vals(exp_weights)

-            prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+            prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
             prepared(x)
             convert_ref = _convert_equalization_ref(prepared)
             convert_ref(x)
@@ -695,7 +753,7 @@ class TestEqualizeFx(QuantizationTestCase):
             elif ndim == 4:
                 x = torch.rand((16, 3, 224, 224))

-            prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+            prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
             prepared(x)
             equalized_quantized_model = convert_fx(prepared)
@@ -716,13 +774,17 @@ class TestEqualizeFx(QuantizationTestCase):
             m = M().eval()

             # No equalization
-            prepared = prepare_fx(copy.deepcopy(m), qconfig_dict, equalization_qconfig_dict={})
+            prepared = prepare_fx(copy.deepcopy(m), specific_qconfig_dict, equalization_qconfig_dict={})
             prepared(x)
             quantized = convert_fx(prepared)  # Check if compile
             quantized_output = quantized(x)

             # With equalization
-            prepared = prepare_fx(copy.deepcopy(m), qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+            prepared = prepare_fx(
+                copy.deepcopy(m),
+                specific_qconfig_dict,
+                equalization_qconfig_dict=default_equalization_qconfig_dict
+            )
             prepared(x)
             equalized_and_quantized = convert_fx(prepared)  # Check if compile
             equalized_and_quantized_output = equalized_and_quantized(x)
```

torch/quantization/fx/prepare.py

```
@@ -1,11 +1,13 @@
 import torch
 import operator
+import warnings

 from torch.fx import (
     GraphModule,
 )
 from torch.quantization import (
     propagate_qconfig_,
+    ObserverBase,
 )
 from torch.fx.graph import (
     Graph,
@@ -485,6 +487,7 @@ def maybe_insert_input_equalization_observers_for_node(
     modules: Dict[str, torch.nn.Module],
     graph: Graph,
     node_name_to_target_dtype: Dict[str, Any],
+    is_branch: bool,
 ) -> None:
     """
     If `node` needs to be equalized, find the input/weight observers it needs in
@@ -495,6 +498,12 @@ def maybe_insert_input_equalization_observers_for_node(
     if equalization_qconfig is None or not node_supports_equalization(node, modules):
         return

+    if is_branch:
+        warnings.warn(
+            f"Cannot equalize {node} because it is part of a branch."
+        )
+        return
+
     new_args = []
     for arg in node.args:
         if not isinstance(arg, Node) or node_arg_is_bias(node, arg):
@@ -940,6 +949,32 @@ def insert_observers_for_model(
             if not skip_inserting_observers:
                 modules = dict(model.named_modules(remove_duplicate=False))
                 if node.op != 'output':
+                    # This is currently only used for equalization.
+                    # Checks if the current node is in a branch in which the
+                    # first two layers are both being quantized.
+                    #
+                    # ex.       conv2
+                    #          /
+                    #      x -> conv1
+                    #
+                    # If this is the case, we will not apply equalization to
+                    # the initial two layers.
+                    is_quantized_branch = False
+                    if (
+                        len(node.args) > 0 and
+                        isinstance(node.args[0], Node) and
+                        len(node.args[0].users) > 1
+                    ):
+                        for user in node.args[0].users:
+                            # Checks if there exists another user being quantized
+                            is_user_quantized = (
+                                qconfig_map.get(user.name, None) is not None or
+                                (user.op == 'call_module' and isinstance(modules[str(user.target)], ObserverBase))
+                            )
+                            if user != node and is_user_quantized:
+                                is_quantized_branch = True
+
                     # this modifies node inplace
                     maybe_insert_input_observers_for_node(
                         node, qconfig, model, modules, graph,
@@ -949,7 +984,7 @@ def insert_observers_for_model(
                     # Insert equalization input observers if needed
                     maybe_insert_input_equalization_observers_for_node(
                         node, equalization_qconfig, model, modules, graph,
-                        node_name_to_target_dtype)
+                        node_name_to_target_dtype, is_quantized_branch)

                     is_last_node_of_pattern = root_node is node
                     is_general_tensor_value_op = \
```