fx quant: refactor qconfig setting out of find_matches

Summary:

Refactors the `find_matches` function so that it only finds subgraph
matches and no longer assigns qconfigs to them. The qconfig assignment
moves outside of the function, to the caller. No logic change.

This will be useful for prototyping future tools that quantize
parts of a model. Those tools will need to know the matches
and will reuse the `find_matches` function,
but they will assign their own qconfigs to the matches using a different
strategy.
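Concretely, callers that used to get qconfig-annotated matches out of
`find_matches` now attach the qconfigs themselves. A minimal sketch of the
new calling pattern, mirroring the `prepare` change in the diff below (it
assumes `model`, `modules`, `patterns`, `root_node_getter_mapping`,
`qconfig_map`, and the standalone/custom module lists are already set up as
they are in `prepare`):

```python
# find_matches no longer takes qconfig_map; it returns matches with no
# qconfigs attached.
matches_without_qconfig = find_matches(
    model.graph, modules, patterns, root_node_getter_mapping,
    standalone_module_names, standalone_module_classes, custom_module_classes)

# The caller attaches qconfigs itself. prepare keys off node names via
# qconfig_map; other tools can substitute any assignment strategy here.
matches = {
    node_name: (*match, qconfig_map[node_name])
    for node_name, match in matches_without_qconfig.items()
}
```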

Test plan:

```
python test/test_quantization.py -k Fx
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79713

Approved by: https://github.com/jerryzh168
Author: Vasiliy Kuznetsov
Date: 2022-06-17 08:21:01 -04:00
Committed by: PyTorch MergeBot
parent 1b25aa6786
commit 7b4e92acef
5 changed files with 24 additions and 34 deletions

test/quantization/ao_migration/test_quantization_fx.py

@@ -56,7 +56,6 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
     def test_function_import_fx_pattern_utils(self):
         function_list = [
             'QuantizeHandler',
-            'MatchResult',
             'register_fusion_pattern',
             'get_default_fusion_patterns',
             'register_quant_pattern',

torch/ao/quantization/fx/match_utils.py

@@ -28,8 +28,10 @@ __all__ = [
 ]
 
-MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
-                    QConfigAny]
+MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler]
+
+_MatchResultWithQConfig = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
+                                QConfigAny]
 
 # Note: The order of patterns is important! match function will take whatever is matched first, so we'll
 # need to put the fusion patterns before single patterns. For example, add_relu should be registered come before relu.
@@ -83,7 +85,6 @@ def find_matches(
         modules: Dict[str, torch.nn.Module],
         patterns: Dict[Pattern, QuantizeHandler],
         root_node_getter_mapping: Dict[Pattern, Callable],
-        qconfig_map: Dict[str, QConfigAny],
         standalone_module_names: List[str] = None,
         standalone_module_classes: List[Type] = None,
         custom_module_classes: List[Any] = None) -> Dict[str, MatchResult]:
@@ -127,14 +128,13 @@ def find_matches(
             node_pattern,
             matched_node_pattern,
             pattern,
-            match_value,
-            qconfig):
+            match_value):
         if isinstance(node_pattern, Node):
             match_map[node_pattern.name] = (
-                last_node, matched_node_pattern, pattern, match_value, qconfig)
+                last_node, matched_node_pattern, pattern, match_value)
         else:
             for n in node_pattern:
-                _recursive_record_node_in_match_map(last_node, match_map, n, matched_node_pattern, pattern, match_value, qconfig)
+                _recursive_record_node_in_match_map(last_node, match_map, n, matched_node_pattern, pattern, match_value)
 
     # TODO: 1. merge with fuse matcher 2. document the code
     def record_match(
@@ -193,8 +193,7 @@ def find_matches(
                     # this is a part of the value corresponding to the node
                     matched_node_pattern,
                     pattern,
-                    quantize_handler,
-                    qconfig_map[node.name])
+                    quantize_handler)
                 break
 
     # add custom module instances to the match result
@@ -202,10 +201,8 @@ def find_matches(
     for node in graph.nodes:
         if node.op == 'call_module' and \
            type(modules[node.target]) in custom_module_classes:
-            custom_module_qconfig = qconfig_map[node.name]
             match_map[node.name] = (
-                node, node, None, QuantizeHandler(node, modules, is_custom_module=True),
-                custom_module_qconfig)
+                node, node, None, QuantizeHandler(node, modules, is_custom_module=True))
 
     def is_standalone_module(node_target: str, modules: Dict[str, torch.nn.Module]):
         assert modules is not None
@@ -220,10 +217,8 @@ def find_matches(
                 (is_standalone_module(node.target, modules) or
                  is_observed_standalone_module(modules[node.target])):
             # add node to matched nodes
-            standalone_module_qconfig = qconfig_map[node.name]
             match_map[node.name] = (
                 node, node, None,
-                QuantizeHandler(node, modules, is_standalone_module=True),
-                standalone_module_qconfig)
+                QuantizeHandler(node, modules, is_standalone_module=True))
 
     return match_map
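In terms of the two tuple aliases defined above, the change amounts to
dropping the trailing qconfig from what `find_matches` returns. A small
unpacking sketch with dummy stand-in values (in the real code the entries
come from `find_matches` and `prepare`, respectively):

```python
# Dummy stand-ins so the unpacking below runs on its own; real entries hold
# a Node, a list of matched nodes, a pattern, and a QuantizeHandler.
match_map = {"linear1": ("node", ["node"], None, "handler")}
matches = {"linear1": ("node", ["node"], None, "handler", "qconfig")}

# MatchResult, as returned by find_matches: no qconfig attached.
last_node, matched_node_pattern, pattern, quantize_handler = match_map["linear1"]

# _MatchResultWithQConfig, as built by prepare: qconfig appended last.
last_node, matched_node_pattern, pattern, quantize_handler, qconfig = matches["linear1"]
```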

torch/ao/quantization/fx/pattern_utils.py

@@ -1,10 +1,6 @@
 from collections import OrderedDict
-from typing import Dict, Any, Tuple, List, Optional
-from torch.fx.graph import (
-    Node,
-)
+from typing import Dict, Any
 from torch.ao.quantization.quantization_types import Pattern
-from ..qconfig import QConfigAny
 from ..fake_quantize import FixedQParamsFakeQuantize
 # from .quantization_patterns import BinaryOpQuantizeHandler
 from ..observer import ObserverBase
@@ -13,9 +9,6 @@ import copy
 
 # TODO(future PR): fix the typing on QuantizeHandler (currently a circular dependency)
 QuantizeHandler = Any
 
-MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
-                    QConfigAny]
-
 # pattern for conv bn fusion
 DEFAULT_FUSION_PATTERNS = OrderedDict()

torch/ao/quantization/fx/prepare.py

@@ -48,11 +48,11 @@ from .graph_module import (
 )
 from .pattern_utils import (
-    MatchResult,
     sorted_patterns_dict,
 )
 from .match_utils import (
+    _MatchResultWithQConfig,
     find_matches,
 )
@@ -732,7 +732,7 @@ def maybe_insert_output_observer_for_node(
         model: torch.nn.Module,
         modules: Dict[str, torch.nn.Module],
         graph: Graph,
-        matches: Dict[str, MatchResult],
+        matches: Dict[str, _MatchResultWithQConfig],
         node_name_to_target_dtype: Dict[str, Dict[str, Optional[Union[torch.dtype, type]]]],
         matched_pattern: Any,
         qhandler: Optional[QuantizeHandler],
@@ -885,7 +885,7 @@ def maybe_propagate_dtype_for_node(
         node: Node,
         target_dtype: Union[torch.dtype, type],
         node_name_to_target_dtype: Dict[str, Dict[str, Optional[Union[torch.dtype, type]]]],
-        matches: Dict[str, MatchResult],
+        matches: Dict[str, _MatchResultWithQConfig],
 ) -> None:
     """
     Assigns `target_dtype` to `node`. If `node` is a general tensor shape op
@@ -907,7 +907,7 @@
 def propagate_dtypes_for_known_nodes(
         graph: Graph,
         node_name_to_target_dtype: Dict[str, Dict[str, Optional[Union[torch.dtype, type]]]],
-        matches: Dict[str, MatchResult],
+        matches: Dict[str, _MatchResultWithQConfig],
 ) -> None:
     """
     Currently we assume that inputs to the graph are either `torch.float` or
@@ -1062,7 +1062,7 @@ def swap_custom_module_to_observed(
 def insert_observers_for_model(
         model: GraphModule,
         modules: Dict[str, torch.nn.Module],
-        matches: Dict[str, MatchResult],
+        matches: Dict[str, _MatchResultWithQConfig],
         qconfig_map: Dict[str, QConfigAny],
         graph: Graph,
         prepare_custom_config: PrepareCustomConfig,
@@ -1506,10 +1506,16 @@ def prepare(
     standalone_module_classes = list(prepare_custom_config.standalone_module_classes.keys())
     custom_module_classes = get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping)
 
-    matches = find_matches(
-        model.graph, modules, patterns, root_node_getter_mapping, qconfig_map,
+    matches_without_qconfig = find_matches(
+        model.graph, modules, patterns, root_node_getter_mapping,
         standalone_module_names, standalone_module_classes, custom_module_classes)
 
+    # map qconfig instances to matches
+    matches = {}
+    for node_name, match_without_qconfig in matches_without_qconfig.items():
+        match_with_qconfig = (*match_without_qconfig, qconfig_map[node_name])
+        matches[node_name] = match_with_qconfig
+
     input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes
     output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes
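As the summary notes, a future tool could reuse `find_matches` and swap in
its own assignment step in place of the loop above. A hypothetical sketch;
`my_qconfig` and the name-based predicate are illustrative only, not part
of this PR:

```python
# Hypothetical strategy: attach a qconfig only to matches the tool decides
# to quantize; everything else gets None and stays unquantized.
# `matches_without_qconfig` is the find_matches result from above;
# `my_qconfig` is an illustrative QConfig instance.
def should_quantize(node_name, match):
    return node_name.startswith("linear")  # illustrative predicate

matches = {}
for node_name, match in matches_without_qconfig.items():
    qconfig = my_qconfig if should_quantize(node_name, match) else None
    matches[node_name] = (*match, qconfig)
```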

torch/quantization/fx/pattern_utils.py

@@ -8,7 +8,6 @@ here.
 """
 from torch.ao.quantization.fx.pattern_utils import (
     QuantizeHandler,
-    MatchResult,
     register_fusion_pattern,
     get_default_fusion_patterns,
     register_quant_pattern,
@@ -17,7 +16,6 @@ from torch.ao.quantization.fx.pattern_utils import (
 )
 
 # QuantizeHandler.__module__ = _NAMESPACE
-MatchResult.__module__ = "torch.quantization.fx.pattern_utils"
 register_fusion_pattern.__module__ = "torch.quantization.fx.pattern_utils"
 get_default_fusion_patterns.__module__ = "torch.quantization.fx.pattern_utils"
 register_quant_pattern.__module__ = "torch.quantization.fx.pattern_utils"
@@ -26,7 +24,6 @@ get_default_output_activation_post_process_map.__module__ = "torch.quantization.fx.pattern_utils"
 # __all__ = [
 #     "QuantizeHandler",
-#     "MatchResult",
 #     "register_fusion_pattern",
 #     "get_default_fusion_patterns",
 #     "register_quant_pattern",