From a73c6a45b6cd25e6aa63f565c19a0ebd7ca0f24e Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Thu, 16 Dec 2021 15:00:48 -0800
Subject: [PATCH] [reland][quant][graphmode][fx] Enable fuse handler for sequence of 3 ops (#70006)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70006

reland: fixing some mypy errors that were missed before

This PR enables the fuse handler for sequences of three ops and merges all fuse handlers into one.

TODO: we can also move this to the backend_config_dict folder

Test Plan:
regression fusion test

```
python test/test_quantization.py TestFuseFx
```

Imported from OSS

Reviewed By: supriyar

Differential Revision: D33144606

fbshipit-source-id: ca34f282018a0fb4d04c7e35119eaf2d64258e78
---
 .../ao_migration/test_quantization_fx.py      |   3 +-
 .../ao/quantization/fuser_method_mappings.py  |   2 +-
 .../fx/backend_config_dict/fuse_handler.py    |   4 +-
 torch/ao/quantization/fx/fuse.py              |  11 +-
 torch/ao/quantization/fx/fusion_patterns.py   | 133 ++++++------
 torch/ao/quantization/fx/prepare.py           |   2 +-
 torch/quantization/fx/fusion_patterns.py      |   3 +-
 7 files changed, 58 insertions(+), 100 deletions(-)

diff --git a/test/quantization/ao_migration/test_quantization_fx.py b/test/quantization/ao_migration/test_quantization_fx.py
index 62920d9a93c..866f443fe8b 100644
--- a/test/quantization/ao_migration/test_quantization_fx.py
+++ b/test/quantization/ao_migration/test_quantization_fx.py
@@ -166,8 +166,7 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
     def test_function_import_fx_fusion_patterns(self):
         function_list = [
             'FuseHandler',
-            'ConvOrLinearBNReLUFusion',
-            'ModuleReLUFusion'
+            'DefaultFuseHandler'
         ]
         self._test_function_import('fx.fusion_patterns', function_list)

diff --git a/torch/ao/quantization/fuser_method_mappings.py b/torch/ao/quantization/fuser_method_mappings.py
index 2359b418fc5..9eddbbc17b1 100644
--- a/torch/ao/quantization/fuser_method_mappings.py
+++ b/torch/ao/quantization/fuser_method_mappings.py
@@ -148,7 +148,7 @@ DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] = {
     (nn.ReLU, (nn.BatchNorm1d, nn.Conv1d)): reverse3(fuse_conv_bn_relu),
     (nn.BatchNorm2d, nn.Conv2d): reverse2(fuse_conv_bn),
     (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): reverse3(fuse_conv_bn_relu),
-    (nn.BatchNorm3d, nn.Conv2d): reverse2(fuse_conv_bn),
+    (nn.BatchNorm3d, nn.Conv3d): reverse2(fuse_conv_bn),
     (nn.ReLU, (nn.BatchNorm3d, nn.Conv3d)): reverse3(fuse_conv_bn_relu),
     (nn.ReLU, nn.Conv1d): reverse2(nni.ConvReLU1d),
     (nn.ReLU, nn.Conv2d): reverse2(nni.ConvReLU2d),
diff --git a/torch/ao/quantization/fx/backend_config_dict/fuse_handler.py b/torch/ao/quantization/fx/backend_config_dict/fuse_handler.py
index 8ed40f79a25..2e37c877526 100644
--- a/torch/ao/quantization/fx/backend_config_dict/fuse_handler.py
+++ b/torch/ao/quantization/fx/backend_config_dict/fuse_handler.py
@@ -1,5 +1,5 @@
-from ..fusion_patterns import ModuleReLUFusion
+from ..fusion_patterns import DefaultFuseHandler

 # TODO: move ModuleReLUFusion here
 def get_fuse_handler_cls():
-    return ModuleReLUFusion
+    return DefaultFuseHandler
diff --git a/torch/ao/quantization/fx/fuse.py b/torch/ao/quantization/fx/fuse.py
index ea918e9905c..5da00827132 100644
--- a/torch/ao/quantization/fx/fuse.py
+++ b/torch/ao/quantization/fx/fuse.py
@@ -56,15 +56,20 @@ class Fuser:
         def load_arg(a):
             return map_arg(a, lambda node: env[node.name])

+        def get_root_node(node_pattern):
+            while not isinstance(node_pattern[-1], Node):
+                node_pattern = node_pattern[-1]
+            return node_pattern[-1]
+
         for node in input_graph.nodes:
            maybe_last_node, pattern, matched_node_pattern, obj = \
                fusion_pairs.get(node.name, (None, None, None, None))
            if maybe_last_node is node:
                assert obj is not None
                # TODO: currently we hard code the root node, which only works for
-                # a tuple of two nodes, we want to make this more general to
-                # support more complex patterns
-                root_node = matched_node_pattern[-1]  # type: ignore[index]
+                # a sequence of ops and assume the root node is the last node,
+                # we want to make this more general to support more complex patterns
+                root_node = get_root_node(matched_node_pattern)  # type: ignore[index]
                env[node.name] = obj.fuse(
                    self, load_arg, root_node, matched_node_pattern,  # type: ignore[arg-type]
                    fuse_custom_config_dict, fuser_method_mapping)
diff --git a/torch/ao/quantization/fx/fusion_patterns.py b/torch/ao/quantization/fx/fusion_patterns.py
index a37526cc4ac..d86c1cd4e59 100644
--- a/torch/ao/quantization/fx/fusion_patterns.py
+++ b/torch/ao/quantization/fx/fusion_patterns.py
@@ -5,10 +5,9 @@ from .pattern_utils import (
 )
 from .utils import _parent_name
 from .quantization_types import QuantizerCls, NodePattern, Pattern
-from ..fuser_method_mappings import get_fuser_method
 from ..fuser_method_mappings import get_fuser_method_new
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Optional, Union
+from typing import Any, Callable, Dict, Optional, Union, List
 from .match_utils import MatchAllNode

 # ----------------------------
@@ -32,80 +31,6 @@ class FuseHandler(ABC):
                  fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]]) -> Node:
         pass

-@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
-@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
-@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
-@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
-@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
-@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
-@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Conv1d))
-@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d))
-@register_fusion_pattern((torch.nn.BatchNorm3d, torch.nn.Conv3d))
-@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Linear))
-class ConvOrLinearBNReLUFusion(FuseHandler):
-    def __init__(self, quantizer: QuantizerCls, node: Node):
-        super().__init__(quantizer, node)
-        self.relu_node = None
-        self.bn_node = None
-        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
-           (node.op == 'call_module' and type(quantizer.modules[node.target]) == torch.nn.ReLU):
-            self.relu_node = node
-            assert isinstance(node.args[0], Node)
-            node = node.args[0]
-        assert node.op == 'call_module'
-        if type(quantizer.modules[node.target]) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]:
-            self.bn_node = node
-            self.bn = quantizer.modules[self.bn_node.target]
-            assert isinstance(node.args[0], Node)
-            node = node.args[0]
-        assert node.op == 'call_module'
-        self.conv_or_linear_node = node
-        self.conv_or_linear = quantizer.modules[self.conv_or_linear_node.target]
-
-    def fuse(self,
-             quantizer: QuantizerCls,
-             load_arg: Callable,
-             root_node: Node,
-             matched_node_pattern: NodePattern,
-             fuse_custom_config_dict: Dict[str, Any],
-             fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]]) -> Node:
-        additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
-        op_list = []
-        if self.relu_node is not None:
-            # since relu can be used multiple times, we'll need to create a relu module for each match
-            if self.relu_node.op == 'call_module':
-                relu = torch.nn.ReLU(quantizer.modules[self.relu_node.target].inplace)
-            else:
-                # TODO: get inplace argument from functional
-                relu = torch.nn.ReLU()
-            op_list.append(relu)
-            relu.training = self.conv_or_linear.training
-            if self.bn_node is not None:
-                op_list.append(self.bn)
-            op_list.append(self.conv_or_linear)
-        else:
-            assert self.bn_node is not None
-            op_list.append(self.bn)
-            op_list.append(self.conv_or_linear)
-
-        # the modules are added in order of relu - bn - conv_or_linear
-        # so we need to correct it
-        op_list.reverse()
-        op_type_list = tuple(type(m) for m in op_list)
-        conv_or_linear_parent_name, conv_or_linear_name = _parent_name(self.conv_or_linear_node.target)
-        fuser_method = get_fuser_method(op_type_list, additional_fuser_method_mapping)
-        if fuser_method is None:
-            raise NotImplementedError("Cannot fuse modules: {}".format(op_type_list))
-        fused = fuser_method(*op_list)
-        setattr(quantizer.modules[conv_or_linear_parent_name], conv_or_linear_name, fused)
-
-        # TODO: do we need to make sure bn is only used once?
-        if self.bn_node is not None:
-            parent_name, name = _parent_name(self.bn_node.target)
-            setattr(quantizer.modules[parent_name], name, torch.nn.Identity())
-        # relu may be used multiple times, so we don't set relu to identity
-        return quantizer.fused_graph.node_copy(self.conv_or_linear_node, load_arg)
-
 @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv1d))
 @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv2d))
 @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv3d))
@@ -118,14 +43,25 @@ class ConvOrLinearBNReLUFusion(FuseHandler):
 @register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm2d))
 @register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm3d))
 @register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm3d))
-class ModuleReLUFusion(FuseHandler):
+@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Conv1d))
+@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d))
+@register_fusion_pattern((torch.nn.BatchNorm3d, torch.nn.Conv3d))
+@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Linear))
+@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
+@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
+@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
+@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
+@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
+@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
+class DefaultFuseHandler(FuseHandler):
     def __init__(
             self,
             quantizer: QuantizerCls,
             node: Node):
         super().__init__(quantizer, node)

-    def fuse(self, quantizer: QuantizerCls,
+    def fuse(self,
+             quantizer: QuantizerCls,
              load_arg: Callable,
              root_node: Node,
              matched_node_pattern: NodePattern,
@@ -136,26 +72,45 @@ class ModuleReLUFusion(FuseHandler):
         root_module = quantizer.modules[root_node.target]
         assert len(additional_fuser_method_mapping) == 0, "Fusion implementation is "
         "undergoing changes, additoinal_fuser_method_mapping is not supported currently."
-        def get_module(n):
-            if n.op == "call_module":
-                return quantizer.modules[n.target]
-            elif n.op == "call_function" and n.target == torch.nn.functional.relu:
-                relu = torch.nn.ReLU()
-                relu.training = root_module.training
-                return relu
-            return MatchAllNode
+        def get_modules(pattern, modules):
+            """ Given a node pattern, extract the corresponding modules
+            e.g. input: (relu_node, (bn_node, conv_node))
+            output: (relu_module, (bn_module, conv_module))
+            """
+            if isinstance(pattern, (tuple, list)):
+                n, *args = pattern
+                get_modules(n, modules)
+                arg_modules: List[torch.nn.Module] = []
+                for a in args:
+                    get_modules(a, arg_modules)
+                arg_modules = tuple(arg_modules) if len(arg_modules) > 1 else arg_modules[0]  # type: ignore[assignment]
+                modules.append(arg_modules)
+            else:
+                n = pattern
+                if n.op == "call_module":
+                    modules.append(quantizer.modules[n.target])
+                elif n.op == "call_function" and n.target == torch.nn.functional.relu:
+                    relu = torch.nn.ReLU()
+                    relu.training = root_module.training
+                    modules.append(relu)
+                else:
+                    modules.append(MatchAllNode)
+            return tuple(modules)

-        matched_modules = tuple(map(get_module, matched_node_pattern))
         # since relu can be used multiple times, we'll need to create a relu module for each match
+        matched_modules = get_modules(matched_node_pattern, [])

-        def get_type(m):
+        def get_matched_types(m):
+            if isinstance(m, tuple):
+                return tuple(map(get_matched_types, m))
             return type(m)

-        matched_module_types = tuple(map(get_type, matched_modules))
+        matched_module_types = get_matched_types(matched_modules)
         module_parent_name, module_name = _parent_name(root_node.target)
         fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping)
         # TODO: change the signature for fuser_method to take matched module patterns
         # as input
         fused_module = fuser_method(*matched_modules)
+        # TODO: maybe add a pass to cleanup bn modules?
         setattr(quantizer.modules[module_parent_name], module_name, fused_module)
         return quantizer.fused_graph.node_copy(root_node, load_arg)
diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py
index 9ae76a22b1a..fd58e689eb8 100644
--- a/torch/ao/quantization/fx/prepare.py
+++ b/torch/ao/quantization/fx/prepare.py
@@ -1322,7 +1322,7 @@ def prepare(
     #   'linear': Linear(...),
     #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
     # }
-    modules = dict(model.named_modules())
+    modules = dict(model.named_modules(remove_duplicate=False))
     # fill qconfig_map, a map from node name to qconfig, used in find_matches
     equalization_qconfig_map = generate_qconfig_map(model, modules, model.graph, equalization_qconfig_dict, node_name_to_scope)
diff --git a/torch/quantization/fx/fusion_patterns.py b/torch/quantization/fx/fusion_patterns.py
index ed9b3116f05..967c1be07af 100644
--- a/torch/quantization/fx/fusion_patterns.py
+++ b/torch/quantization/fx/fusion_patterns.py
@@ -8,6 +8,5 @@ here.
 """
 from torch.ao.quantization.fx.fusion_patterns import (
     FuseHandler,
-    ConvOrLinearBNReLUFusion,
-    ModuleReLUFusion
+    DefaultFuseHandler,
 )
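
Illustrative note (not part of the patch): the sketch below shows the kind of three-op sequence this change targets, a Conv2d -> BatchNorm2d -> ReLU chain matching the nested pattern (torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)) now handled by DefaultFuseHandler. It assumes the torch.ao.quantization.quantize_fx.fuse_fx entry point as it exists around this commit; the model and class names are made up for the example.

```
# Illustrative sketch only; not part of the patch. Names below are made up.
import torch
import torch.nn as nn
from torch.ao.quantization.quantize_fx import fuse_fx

class ConvBnReLU(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU, the three-op sequence matched by
    the nested pattern (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

m = ConvBnReLU().eval()  # fuse_fx expects eval mode for conv+bn folding
# fuse_fx symbolically traces the model and runs the Fuser above; the matched
# three-op pattern should be replaced by a single fused module (ConvReLU2d
# with the batch norm folded into the conv).
fused = fuse_fx(m)
print(fused)
```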