mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
fx quant: refactor qconfig setting out of find_matches
Summary: Refactors `find_matches` function to only find subgraph matches and not assign qconfigs to them. Moves the qconfig assignment outside of the function. No logic change. This will be useful for prototyping future tools for quantizing parts of the model. These tools will need to know the matches and will reuse the `find_matches` function, but they will assign their own qconfigs to them using a different strategy. Test plan: ``` python test/test_quantization.py -k Fx ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/79713 Approved by: https://github.com/jerryzh168
This commit is contained in:
parent
1b25aa6786
commit
7b4e92acef
5 changed files with 24 additions and 34 deletions
|
|
@ -56,7 +56,6 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
|
|||
def test_function_import_fx_pattern_utils(self):
|
||||
function_list = [
|
||||
'QuantizeHandler',
|
||||
'MatchResult',
|
||||
'register_fusion_pattern',
|
||||
'get_default_fusion_patterns',
|
||||
'register_quant_pattern',
|
||||
|
|
|
|||
|
|
@ -28,8 +28,10 @@ __all__ = [
|
|||
]
|
||||
|
||||
|
||||
MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
|
||||
QConfigAny]
|
||||
MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler]
|
||||
|
||||
_MatchResultWithQConfig = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
|
||||
QConfigAny]
|
||||
|
||||
# Note: The order of patterns is important! match function will take whatever is matched first, so we'll
|
||||
# need to put the fusion patterns before single patterns. For example, add_relu should be registered before relu.
|
||||
|
|
@ -83,7 +85,6 @@ def find_matches(
|
|||
modules: Dict[str, torch.nn.Module],
|
||||
patterns: Dict[Pattern, QuantizeHandler],
|
||||
root_node_getter_mapping: Dict[Pattern, Callable],
|
||||
qconfig_map: Dict[str, QConfigAny],
|
||||
standalone_module_names: List[str] = None,
|
||||
standalone_module_classes: List[Type] = None,
|
||||
custom_module_classes: List[Any] = None) -> Dict[str, MatchResult]:
|
||||
|
|
@ -127,14 +128,13 @@ def find_matches(
|
|||
node_pattern,
|
||||
matched_node_pattern,
|
||||
pattern,
|
||||
match_value,
|
||||
qconfig):
|
||||
match_value):
|
||||
if isinstance(node_pattern, Node):
|
||||
match_map[node_pattern.name] = (
|
||||
last_node, matched_node_pattern, pattern, match_value, qconfig)
|
||||
last_node, matched_node_pattern, pattern, match_value)
|
||||
else:
|
||||
for n in node_pattern:
|
||||
_recursive_record_node_in_match_map(last_node, match_map, n, matched_node_pattern, pattern, match_value, qconfig)
|
||||
_recursive_record_node_in_match_map(last_node, match_map, n, matched_node_pattern, pattern, match_value)
|
||||
|
||||
# TODO: 1. merge with fuse matcher 2. document the code
|
||||
def record_match(
|
||||
|
|
@ -193,8 +193,7 @@ def find_matches(
|
|||
# this is a part of the value corresponding to the node
|
||||
matched_node_pattern,
|
||||
pattern,
|
||||
quantize_handler,
|
||||
qconfig_map[node.name])
|
||||
quantize_handler)
|
||||
break
|
||||
|
||||
# add custom module instances to the match result
|
||||
|
|
@ -202,10 +201,8 @@ def find_matches(
|
|||
for node in graph.nodes:
|
||||
if node.op == 'call_module' and \
|
||||
type(modules[node.target]) in custom_module_classes:
|
||||
custom_module_qconfig = qconfig_map[node.name]
|
||||
match_map[node.name] = (
|
||||
node, node, None, QuantizeHandler(node, modules, is_custom_module=True),
|
||||
custom_module_qconfig)
|
||||
node, node, None, QuantizeHandler(node, modules, is_custom_module=True))
|
||||
|
||||
def is_standalone_module(node_target: str, modules: Dict[str, torch.nn.Module]):
|
||||
assert modules is not None
|
||||
|
|
@ -220,10 +217,8 @@ def find_matches(
|
|||
(is_standalone_module(node.target, modules) or
|
||||
is_observed_standalone_module(modules[node.target])):
|
||||
# add node to matched nodes
|
||||
standalone_module_qconfig = qconfig_map[node.name]
|
||||
match_map[node.name] = (
|
||||
node, node, None,
|
||||
QuantizeHandler(node, modules, is_standalone_module=True),
|
||||
standalone_module_qconfig)
|
||||
QuantizeHandler(node, modules, is_standalone_module=True))
|
||||
|
||||
return match_map
|
||||
|
|
|
|||
|
|
@ -1,10 +1,6 @@
|
|||
from collections import OrderedDict
|
||||
from typing import Dict, Any, Tuple, List, Optional
|
||||
from torch.fx.graph import (
|
||||
Node,
|
||||
)
|
||||
from typing import Dict, Any
|
||||
from torch.ao.quantization.quantization_types import Pattern
|
||||
from ..qconfig import QConfigAny
|
||||
from ..fake_quantize import FixedQParamsFakeQuantize
|
||||
# from .quantization_patterns import BinaryOpQuantizeHandler
|
||||
from ..observer import ObserverBase
|
||||
|
|
@ -13,9 +9,6 @@ import copy
|
|||
# TODO(future PR): fix the typing on QuantizeHandler (currently a circular dependency)
|
||||
QuantizeHandler = Any
|
||||
|
||||
MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
|
||||
QConfigAny]
|
||||
|
||||
# pattern for conv bn fusion
|
||||
DEFAULT_FUSION_PATTERNS = OrderedDict()
|
||||
def register_fusion_pattern(pattern):
|
||||
|
|
|
|||
|
|
@ -48,11 +48,11 @@ from .graph_module import (
|
|||
)
|
||||
|
||||
from .pattern_utils import (
|
||||
MatchResult,
|
||||
sorted_patterns_dict,
|
||||
)
|
||||
|
||||
from .match_utils import (
|
||||
_MatchResultWithQConfig,
|
||||
find_matches,
|
||||
)
|
||||
|
||||
|
|
@ -732,7 +732,7 @@ def maybe_insert_output_observer_for_node(
|
|||
model: torch.nn.Module,
|
||||
modules: Dict[str, torch.nn.Module],
|
||||
graph: Graph,
|
||||
matches: Dict[str, MatchResult],
|
||||
matches: Dict[str, _MatchResultWithQConfig],
|
||||
node_name_to_target_dtype: Dict[str, Dict[str, Optional[Union[torch.dtype, type]]]],
|
||||
matched_pattern: Any,
|
||||
qhandler: Optional[QuantizeHandler],
|
||||
|
|
@ -885,7 +885,7 @@ def maybe_propagate_dtype_for_node(
|
|||
node: Node,
|
||||
target_dtype: Union[torch.dtype, type],
|
||||
node_name_to_target_dtype: Dict[str, Dict[str, Optional[Union[torch.dtype, type]]]],
|
||||
matches: Dict[str, MatchResult],
|
||||
matches: Dict[str, _MatchResultWithQConfig],
|
||||
) -> None:
|
||||
"""
|
||||
Assigns `target_dtype` to `node`. If `node` is a general tensor shape op
|
||||
|
|
@ -907,7 +907,7 @@ def maybe_propagate_dtype_for_node(
|
|||
def propagate_dtypes_for_known_nodes(
|
||||
graph: Graph,
|
||||
node_name_to_target_dtype: Dict[str, Dict[str, Optional[Union[torch.dtype, type]]]],
|
||||
matches: Dict[str, MatchResult],
|
||||
matches: Dict[str, _MatchResultWithQConfig],
|
||||
) -> None:
|
||||
"""
|
||||
Currently we assume that inputs to the graph are either `torch.float` or
|
||||
|
|
@ -1062,7 +1062,7 @@ def swap_custom_module_to_observed(
|
|||
def insert_observers_for_model(
|
||||
model: GraphModule,
|
||||
modules: Dict[str, torch.nn.Module],
|
||||
matches: Dict[str, MatchResult],
|
||||
matches: Dict[str, _MatchResultWithQConfig],
|
||||
qconfig_map: Dict[str, QConfigAny],
|
||||
graph: Graph,
|
||||
prepare_custom_config: PrepareCustomConfig,
|
||||
|
|
@ -1506,10 +1506,16 @@ def prepare(
|
|||
standalone_module_classes = list(prepare_custom_config.standalone_module_classes.keys())
|
||||
|
||||
custom_module_classes = get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping)
|
||||
matches = find_matches(
|
||||
model.graph, modules, patterns, root_node_getter_mapping, qconfig_map,
|
||||
matches_without_qconfig = find_matches(
|
||||
model.graph, modules, patterns, root_node_getter_mapping,
|
||||
standalone_module_names, standalone_module_classes, custom_module_classes)
|
||||
|
||||
# map qconfig instances to matches
|
||||
matches = {}
|
||||
for node_name, match_without_qconfig in matches_without_qconfig.items():
|
||||
match_with_qconfig = (*match_without_qconfig, qconfig_map[node_name])
|
||||
matches[node_name] = match_with_qconfig
|
||||
|
||||
input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes
|
||||
output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ here.
|
|||
"""
|
||||
from torch.ao.quantization.fx.pattern_utils import (
|
||||
QuantizeHandler,
|
||||
MatchResult,
|
||||
register_fusion_pattern,
|
||||
get_default_fusion_patterns,
|
||||
register_quant_pattern,
|
||||
|
|
@ -17,7 +16,6 @@ from torch.ao.quantization.fx.pattern_utils import (
|
|||
)
|
||||
|
||||
# QuantizeHandler.__module__ = _NAMESPACE
|
||||
MatchResult.__module__ = "torch.quantization.fx.pattern_utils"
|
||||
register_fusion_pattern.__module__ = "torch.quantization.fx.pattern_utils"
|
||||
get_default_fusion_patterns.__module__ = "torch.quantization.fx.pattern_utils"
|
||||
register_quant_pattern.__module__ = "torch.quantization.fx.pattern_utils"
|
||||
|
|
@ -26,7 +24,6 @@ get_default_output_activation_post_process_map.__module__ = "torch.quantization.
|
|||
|
||||
# __all__ = [
|
||||
# "QuantizeHandler",
|
||||
# "MatchResult",
|
||||
# "register_fusion_pattern",
|
||||
# "get_default_fusion_patterns",
|
||||
# "register_quant_pattern",
|
||||
|
|
|
|||
Loading…
Reference in a new issue