diff --git a/test/quantization/test_numeric_suite_fx.py b/test/quantization/test_numeric_suite_fx.py index e0bec6bac7b..a9e012eaa52 100644 --- a/test/quantization/test_numeric_suite_fx.py +++ b/test/quantization/test_numeric_suite_fx.py @@ -32,20 +32,11 @@ from torch.testing._internal.common_quantization import ( skip_if_no_torchvision, test_only_eval_fn, ) -from torch.testing._internal.common_quantization import NodeSpec as ns from torch.testing._internal.common_quantized import override_qengines from torch.quantization.ns.graph_matcher import ( get_matching_node_pairs, GraphMatchingException, ) -from torch.quantization.ns.numeric_suite_core_apis_fx import ( - compare_weights, - prepare_model_outputs, - OutputLogger, - prepare_model_with_stubs, - get_matching_activations, - get_matching_activations_a_shadows_b, -) class TestGraphModeNumericSuite(QuantizationTestCase): @@ -577,197 +568,3 @@ class TestFXGraphMatcherModels(QuantizationTestCase): mq = convert_fx(mp_copy) # assume success if no exceptions results = get_matching_node_pairs(mp, mq) - -class TestFXNumericSuiteCoreAPIs(QuantizationTestCase): - - @override_qengines - def test_compare_weights_mod(self): - m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)).eval() - mp = prepare_fx(m, {'': torch.quantization.default_qconfig}) - # TODO(future PR): prevent the need for copying here, we can copy the - # modules but should reuse the underlying tensors - mp_copy = copy.deepcopy(mp) - mq = convert_fx(mp_copy) - results = compare_weights('fp32_prepared', mp, 'int8', mq) - self.assertTrue(len(results) == 2) - self.assert_ns_weight_compare_dict_valid(results) - - @override_qengines - def test_compare_weights_fun(self): - class M(nn.Module): - def __init__(self): - super().__init__() - self.w = nn.Parameter(torch.Tensor(4, 1)) - self.b = nn.Parameter(torch.Tensor(4)) - torch.nn.init.kaiming_uniform_(self.w, a=math.sqrt(5)) - - def forward(self, x): - return F.linear(x, self.w, self.b) - - m = M().eval() - mp = prepare_fx(m, {'': torch.quantization.default_qconfig}) - mp(torch.randn(1, 1)) - # TODO(future PR): prevent the need for copying here, we can copy the - # modules but should reuse the underlying tensors - mp_copy = copy.deepcopy(mp) - mq = convert_fx(mp_copy) - results = compare_weights('fp32_prepared', mp, 'int8', mq) - self.assertTrue(len(results) == 1) - self.assert_ns_weight_compare_dict_valid(results) - - @override_qengines - def test_match_activations_mod(self): - m = nn.Sequential( - torch.quantization.QuantStub(), - nn.Conv2d(1, 1, 1), - nn.Conv2d(1, 1, 1), - ).eval() - mp = prepare_fx(m, {'': torch.quantization.default_qconfig}) - mp(torch.randn(2, 1, 2, 2)) - # TODO(future PR): prevent the need for copying here, we can copy the - # modules but should reuse the underlying tensors - mp_copy = copy.deepcopy(mp) - mq = convert_fx(mp_copy) - - mp_ns, mq_ns = prepare_model_outputs( - 'fp32_prepared', mp, 'int8', mq, OutputLogger) - - expected_occurrence = { - ns.call_module(OutputLogger): 2, - } - self.checkGraphModuleNodes( - mp_ns, expected_node_occurrence=expected_occurrence) - self.checkGraphModuleNodes( - mq_ns, expected_node_occurrence=expected_occurrence) - - # TODO(before land): test both scripted and non-scripted - mp_ns = torch.jit.script(mp_ns) - mq_ns = torch.jit.script(mq_ns) - - # calibrate - input_fp32 = torch.randn(2, 1, 2, 2) - mp_ns(input_fp32) - mq_ns(input_fp32) - - # check activation result correctness - act_compare_dict = get_matching_activations(mp_ns, mq_ns, OutputLogger) - self.assertTrue(len(act_compare_dict) == 2) - self.assert_ns_logger_act_compare_dict_valid(act_compare_dict) - - @override_qengines - def test_match_activations_fun(self): - class M(nn.Module): - def __init__(self): - super().__init__() - self.w1 = nn.Parameter(torch.Tensor(4, 4)) - self.b1 = nn.Parameter(torch.Tensor(4)) - self.w2 = nn.Parameter(torch.Tensor(4, 4)) - self.b2 = nn.Parameter(torch.Tensor(4)) - torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5)) - torch.nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5)) - - def forward(self, x): - x = F.linear(x, self.w1, self.b1) - x = F.linear(x, self.w2, self.b2) - return x - - m = M().eval() - mp = prepare_fx(m, {'': torch.quantization.default_qconfig}) - mp(torch.randn(4, 4)) - # TODO(future PR): prevent the need for copying here, we can copy the - # modules but should reuse the underlying tensors - mp_copy = copy.deepcopy(mp) - mq = convert_fx(mp_copy) - - mp_ns, mq_ns = prepare_model_outputs( - 'fp32_prepared', mp, 'int8', mq, OutputLogger) - - expected_occurrence = { - ns.call_module(OutputLogger): 2, - } - self.checkGraphModuleNodes( - mp_ns, expected_node_occurrence=expected_occurrence) - self.checkGraphModuleNodes( - mq_ns, expected_node_occurrence=expected_occurrence) - - # TODO(before land): test both scripted and non-scripted - mp_ns = torch.jit.script(mp_ns) - mq_ns = torch.jit.script(mq_ns) - - # calibrate - input_fp32 = torch.randn(4, 4) - mp_ns(input_fp32) - mq_ns(input_fp32) - - # check activation result correctness - act_compare_dict = get_matching_activations(mp_ns, mq_ns, OutputLogger) - self.assertTrue(len(act_compare_dict) == 2) - self.assert_ns_logger_act_compare_dict_valid(act_compare_dict) - - @override_qengines - def test_prepare_model_with_stubs_mod(self): - m = nn.Sequential( - nn.Conv2d(1, 1, 1), - nn.Conv2d(1, 1, 1), - ).eval() - mp = prepare_fx(m, {'': torch.quantization.default_qconfig}) - mp(torch.randn(1, 1, 4, 4)) - # TODO(future PR): prevent the need for copying here, we can copy the - # modules but should reuse the underlying tensors - mp_copy = copy.deepcopy(mp) - mq = convert_fx(mp_copy) - - mp_shadows_mq = prepare_model_with_stubs('fp32_prepared', mp, 'int8', mq, OutputLogger) - - # TODO(before land): test both scripted and non-scripted - mp_shadows_mq = torch.jit.script(mp_shadows_mq) - - # calibrate - input_fp32 = torch.randn(1, 1, 4, 4) - mp_shadows_mq(input_fp32) - - # check activation result correctness - act_compare_dict = get_matching_activations_a_shadows_b( - mp_shadows_mq, OutputLogger) - self.assertTrue(len(act_compare_dict) == 2) - self.assert_ns_logger_act_compare_dict_valid(act_compare_dict) - - @override_qengines - def test_prepare_model_with_stubs_fun(self): - class M(nn.Module): - def __init__(self): - super().__init__() - self.w1 = nn.Parameter(torch.Tensor(4, 4)) - self.b1 = nn.Parameter(torch.Tensor(4)) - self.w2 = nn.Parameter(torch.Tensor(4, 4)) - self.b2 = nn.Parameter(torch.Tensor(4)) - torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5)) - torch.nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5)) - - def forward(self, x): - x = F.linear(x, self.w1, self.b1) - x = F.linear(x, self.w2, self.b2) - return x - - m = M().eval() - mp = prepare_fx(m, {'': torch.quantization.default_qconfig}) - mp(torch.randn(4, 4)) - # TODO(future PR): prevent the need for copying here, we can copy the - # modules but should reuse the underlying tensors - mp_copy = copy.deepcopy(mp) - mq = convert_fx(mp_copy) - - mp_shadows_mq = prepare_model_with_stubs('fp32_prepared', mp, 'int8', mq, OutputLogger) - - # TODO(before land): test both scripted and non-scripted - mp_shadows_mq = torch.jit.script(mp_shadows_mq) - - # calibrate - input_fp32 = torch.randn(4, 4) - mp_shadows_mq(input_fp32) - - # check activation result correctness - act_compare_dict = get_matching_activations_a_shadows_b( - mp_shadows_mq, OutputLogger) - self.assertTrue(len(act_compare_dict) == 2) - self.assert_ns_logger_act_compare_dict_valid(act_compare_dict) diff --git a/test/test_quantization.py b/test/test_quantization.py index a1543766667..7214f236443 100644 --- a/test/test_quantization.py +++ b/test/test_quantization.py @@ -80,7 +80,6 @@ try: from quantization.test_numeric_suite_fx import TestGraphModeNumericSuite # noqa: F401 from quantization.test_numeric_suite_fx import TestFXGraphMatcher # noqa: F401 from quantization.test_numeric_suite_fx import TestFXGraphMatcherModels # noqa: F401 - from quantization.test_numeric_suite_fx import TestFXNumericSuiteCoreAPIs # noqa: F401 except ImportError: pass diff --git a/torch/quantization/ns/graph_matcher.py b/torch/quantization/ns/graph_matcher.py index 0a691458142..0fd03c77c55 100644 --- a/torch/quantization/ns/graph_matcher.py +++ b/torch/quantization/ns/graph_matcher.py @@ -11,9 +11,16 @@ toq = torch.ops.quantized from torch.fx import GraphModule from torch.fx.graph import Graph, Node -from .utils import getattr_from_fqn +from typing import Dict, Tuple, List, Optional, Set, Callable, Any -from typing import Dict, Tuple, List, Optional, Set, Callable +# TODO(before land): delete this +def _print_node(node: Optional[Node]) -> None: + if node is None: + print(None) + else: + print( + node, ', target:', node.target, ', op:', node.op, + ', args:', node.args, ', kwargs:', node.kwargs) def _get_output_nodes(g: Graph) -> List[Node]: return [n for n in g.nodes if n.op == 'output'] @@ -82,6 +89,16 @@ def get_non_matchable_modules() -> Set[Callable]: torch.quantization.FakeQuantizeBase, ]) +def _getattr_from_fqn(gm: GraphModule, fqn: str) -> Any: + """ + Given a gm and a fqn such as "foo.bar.baz", returns gm.foo.bar.baz. + """ + fqn_parts = fqn.split(".") + cur_val = gm + for part in fqn_parts: + cur_val = getattr(cur_val, part) + return cur_val + class _NSGraphMatchableNodesIterator: """ Iterates through the graph of gm, starting with the output nodes @@ -136,7 +153,7 @@ class _NSGraphMatchableNodesIterator: elif node.op == 'call_module': assert isinstance(node.target, str) # target_mod = getattr(self.gm, node.target) - target_mod = getattr_from_fqn(self.gm, node.target) + target_mod = _getattr_from_fqn(self.gm, node.target) return not \ any(isinstance(target_mod, t) # type: ignore for t in self.non_matchable_modules) @@ -170,9 +187,9 @@ def _node_a_related_to_b( elif node_a.op == 'call_module': # for call_module, we need to look up the modules to do the type check assert isinstance(node_a.target, str) - mod_a = getattr_from_fqn(gm_a, node_a.target) + mod_a = _getattr_from_fqn(gm_a, node_a.target) assert isinstance(node_b.target, str) - mod_b = getattr_from_fqn(gm_b, node_b.target) + mod_b = _getattr_from_fqn(gm_b, node_b.target) # modules with equivalent types always match (i.e. nn.Conv2d and nn.Conv2d) if type(mod_a) == type(mod_b): return True @@ -196,7 +213,7 @@ def _get_node_target_type(node: Node, gm: GraphModule) -> Optional[Callable]: return node.target # type: ignore elif node.op == 'call_module': assert isinstance(node.target, str) - mod = getattr_from_fqn(gm, node.target) + mod = _getattr_from_fqn(gm, node.target) return type(mod) return None @@ -263,6 +280,13 @@ def get_matching_node_pairs( except StopIteration: pass + # TODO(before land): remove + if False: + print('a') + _print_node(cur_node_a) + print('b') + _print_node(cur_node_b) + # look up types of a and b for useful error messages type_a, type_b = None, None if cur_node_a is not None: diff --git a/torch/quantization/ns/graph_passes.py b/torch/quantization/ns/graph_passes.py deleted file mode 100644 index 385f07e970c..00000000000 --- a/torch/quantization/ns/graph_passes.py +++ /dev/null @@ -1,346 +0,0 @@ -import torch -from torch.fx import GraphModule, map_arg -from torch.fx.graph import Graph, Node -from torch.quantization.fx.quantize import is_activation_post_process -from torch.quantization.fx.utils import get_new_attr_name_with_prefix - -from .utils import ( - get_node_io_type, - getattr_from_fqn, - print_node, - NodeIOType, - return_first_non_observer_node, -) - -from typing import Dict, Tuple, Callable, List, Any, Optional - -def _insert_logger_after_node( - node: Node, - gm: GraphModule, - logger_cls: Callable, - logger_node_name_suffix: str, - model_name: str, - other_node_name: Optional[str] = None, -) -> Node: - """ - Given a starting graph of - - prev_node -> node -> next_node - - This function creates a new logger_cls obj and adds it - after node, resulting in - - prev_node -> node -> logger_obj -> next_node - """ - # create new name - logger_node_name = \ - get_new_attr_name_with_prefix(node.name + logger_node_name_suffix)(gm) - # create the logger object - logger_obj = logger_cls(node.name, model_name, other_node_name) - # attach the logger object to the parent module - setattr(gm, logger_node_name, logger_obj) - logger_node = node.graph.create_node( - 'call_module', logger_node_name, (node,), {}) - return logger_node - -def remove_observers_add_loggers( - gm: GraphModule, - nodes_to_instrument: List[Node], - logger_cls: Callable, - model_name: str, -) -> GraphModule: - """ - Takes the graph of gm, removes all observers, adds loggers to the output - of each node in nodes_to_instrument. Returns a GraphModule with the new - graph. - """ - - new_graph = Graph() - env: Dict[str, Any] = {} - modules = dict(gm.named_modules()) - - def load_arg(a): - return map_arg(a, lambda node: env[node.name]) - - for node in gm.graph.nodes: - if node.op == 'output': - new_graph.output(map_arg(node.args[0], load_arg)) - continue - - if node.op == 'call_module' and is_activation_post_process(modules[node.target]): - # remove activation post process node - env[node.name] = env[node.args[0].name] - - elif node in nodes_to_instrument: - # ensure env is populated with base node - env[node.name] = new_graph.node_copy(node, load_arg) - # add the logger after the base node - env[node.name] = _insert_logger_after_node( - env[node.name], gm, logger_cls, '_ns_logger_', model_name) - - else: - env[node.name] = new_graph.node_copy(node, load_arg) - - new_gm = GraphModule(gm, new_graph) - return new_gm - -def _insert_dtype_cast_after_node( - node_a: Node, - node_c: Node, - prev_node_c: Node, - gm_a: GraphModule, - gm_b: GraphModule, - node_name_prefix: str, -) -> Node: - """ - Given a starting graph C (derived from graph B) of - - ... -> prev_node_c -> node_c -> ... - - And a corresponding related node_a, inserts the correct dtype - cast node after prev_node_c to cast into the dtype expected - by node_a, resulting in: - - dtype_cast - / - ... -> prev_node_c -> node_c -> ... - - For example, if node_c is an int8 op and node_a is an fp32 op, this function - will insert a dequant. - """ - dtype_cast_op = None - node_io_type_a = get_node_io_type(node_a, gm_a) - node_io_type_c = get_node_io_type(node_c, gm_b) - - if node_io_type_a == NodeIOType.FP32 and node_io_type_c == NodeIOType.INT8: - dtype_cast_op = torch.dequantize - else: - raise AssertionError( - f"dtype cast from {node_io_type_c} to {node_io_type_a} needs to be implemented") - - new_dtype_cast_name = \ - get_new_attr_name_with_prefix(node_name_prefix)(gm_b) - return prev_node_c.graph.create_node( - 'call_function', dtype_cast_op, (prev_node_c,), {}, - new_dtype_cast_name) - -def _insert_copy_of_node_a_after_input_node_c( - input_node_c: Node, - node_a: Node, - gm_a: GraphModule, - gm_b: GraphModule, - node_name_prefix: str, -) -> Node: - """ - Assume that node_a from graph_a has - args (input, arg1, ...), and - kwargs {kw0: kwarg0, ...} - - Copies the underlying values of arg1..argn and kwarg0..kwargn into gm_b, - and creates the corresponding nodes in graph_c. Note: observers are ignored, - so if an arg is an observer we navigate up until we find a non-observer parent. - - If node_a is a call_module, points the module pointed to by node_a to gm_b. - - Creates the copy of node_a in graph_c, with input as the first arg, - and all other args and kwargs pointing to the copies of the objects - in gm_b created above. - - An example in pictures: - - graph A: - ======== - - input -------------> node_a - / / - weight -> weight_obs / - / - bias ---------------- - - graph C (derived from B): - ========================= - - input_node_c --> node_a_copy - / / - weight_copy ----/ / - / - bias_copy ------/ - """ - graph_c = input_node_c.graph - - # generically handle all args and kwargs except for the input - # Note: this hasn't been tested with many ops, logic may change. - new_args = [] - # assumes that the first arg is the input - for node_a_arg in node_a.args[1:]: - if isinstance(node_a_arg, Node): - arg_a = return_first_non_observer_node(node_a_arg, gm_a) - arg_a_copy_name = \ - get_new_attr_name_with_prefix(arg_a.name + '_shadow_copy_')(gm_b) # type: ignore - arg_a_obj = getattr_from_fqn(gm_a, arg_a.target) # type: ignore - setattr(gm_b, arg_a_copy_name, arg_a_obj.detach()) - node_a_arg_copy = graph_c.create_node( - 'get_attr', arg_a_copy_name, (), {}, arg_a_copy_name) - new_args.append(node_a_arg_copy) - else: - raise AssertionError( - f"handling for arg of type {type(node_a_arg)} is not implemented") - - new_kwargs = {} - for node_a_k, node_a_kwarg in node_a.kwargs.items(): - kwarg_a_copy_name = \ - get_new_attr_name_with_prefix(node_a_kwarg.name + '_shadow_copy_')(gm_b) # type: ignore - kwarg_a_obj = getattr_from_fqn(gm_a, node_a_kwarg.target) # type: ignore - setattr(gm_b, kwarg_a_copy_name, kwarg_a_obj.detach()) - node_a_kwarg_copy = graph_c.create_node( - 'get_attr', kwarg_a_copy_name, (), {}, kwarg_a_copy_name) - new_kwargs[node_a_k] = node_a_kwarg_copy - - node_a_shadows_c_name = \ - get_new_attr_name_with_prefix(node_name_prefix)(gm_b) - - if node_a.op == 'call_module': - # if target is a module, we point to the module from gm_b - new_mod_copy_name = \ - get_new_attr_name_with_prefix(node_name_prefix)(gm_b) - # fetch the corresponding module from gm_a - assert isinstance(node_a.target, str) - mod_a = getattr_from_fqn(gm_a, node_a.target) - setattr(gm_b, new_mod_copy_name, mod_a) - node_a_shadows_c = graph_c.create_node( - node_a.op, new_mod_copy_name, (input_node_c, *new_args), - new_kwargs, node_a_shadows_c_name) # type: ignore - return node_a_shadows_c - else: - assert node_a.op == 'call_function' - node_a_shadows_c = graph_c.create_node( - node_a.op, node_a.target, (input_node_c, *new_args), - new_kwargs, node_a_shadows_c_name) # type: ignore - return node_a_shadows_c - -def create_a_shadows_b( - name_a: str, - gm_a: GraphModule, - name_b: str, - gm_b: GraphModule, - matched_node_pairs: Dict[str, Tuple[Node, Node]], - logger_cls: Callable, -) -> GraphModule: - """ - Creates a new GraphModule consisting of the graph of C, with the meaningful - nodes of A shadowing the corresponding nodes of B. For example, - - Graph A: - a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2 - - Graph B: - b0 -> op0_int8 -> b1 -> op1_int8 -> b2 - - matched_node_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)} - - Graph C (A shadows B): - - / dequant0 -> op0_fp32 -> logger_a_0 / dequant_1 -> op1_fp32 -> logger_a_1 - / / - b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1 - - In a nutshell, this function does the following for each node pair: - * copies the necessary attributes and modules from gm_a to gm_b, - keeping names unique - * adds a dtype cast op (dequant, quant, etc) - * adds a copy of node_a in gm_b's graph - * adds loggers to the outputs of node_a and node_b - """ - - # graph_c is the graph created from copying the nodes of graph_b and inserting - # the shadows with the nodes copied from graph_a - graph_c = Graph() - env_c: Dict[str, Any] = {} - modules = dict(gm_b.named_modules()) - - def load_arg(a): - return map_arg(a, lambda node: env_c[node.name]) - - nodes_to_instrument_b_to_a = {} - for match_name, (node_a, node_b) in matched_node_pairs.items(): - nodes_to_instrument_b_to_a[node_b] = node_a - - for node_b in gm_b.graph.nodes: - if node_b.op == 'output': - graph_c.output(map_arg(node_b.args[0], load_arg)) - continue - - if node_b.op == 'call_module' and is_activation_post_process(modules[node_b.target]): - # remove activation post process node - env_c[node_b.name] = env_c[node_b.args[0].name] # type: ignore - - elif node_b in nodes_to_instrument_b_to_a: - node_a = nodes_to_instrument_b_to_a[node_b] - if False: - print('b') - print_node(node_b) - print('a') - print_node(node_a) - - # ensure env_c is populated with base node - env_c[node_b.name] = graph_c.node_copy(node_b, load_arg) - node_c = env_c[node_b.name] - - # after this point, - # - # node_a is the original node from graph_a, with parent module gm_a - # node_b is the original node from graph_b, with parent module gm_b - # node_c is the copy of node_b in graph_c - # - # subgraph so far: - # - # node_c - - # cast dtype from the dtype of node_c's input to the dtype of - # node_a's input (dequant, etc) - dtype_cast_node = _insert_dtype_cast_after_node( - node_a, node_c, node_c.args[0], gm_a, gm_b, node_b.name + '_dtype_cast_') - env_c[dtype_cast_node.name] = dtype_cast_node - # subgraph so far: - # - # dtype_cast_node - # / - # node_c - - # hook up the new mod_a copy to be in the graph, receiving the - # same inputs as mod_b does, with dtype cast to match a - node_a_shadows_c = _insert_copy_of_node_a_after_input_node_c( - env_c[dtype_cast_node.name], - node_a, gm_a, gm_b, node_c.name + '_shadow_copy_') - env_c[node_a_shadows_c.name] = node_a_shadows_c - # subgraph so far: - # - # dtype_cast_node --> node_a_copy(args/kwargs not shown) - # / - # node_c - - # hook up a logger to the mod_b copy - env_c[node_b.name] = _insert_logger_after_node( - env_c[node_b.name], gm_b, logger_cls, '_ns_logger_b_', name_b) - # subgraph so far: - # - # dtype_cast_node --> node_a_copy - # / - # node_c --> logger_c - - # hook up a logger to the mod_a copy - # Note: we pass node_b.name to this logger, for easy matching later - env_c[node_a_shadows_c.name] = _insert_logger_after_node( - env_c[node_a_shadows_c.name], gm_b, logger_cls, '_ns_logger_a_', name_a, - node_b.name) - # subgraph so far: - # - # dtype_cast_node --> node_a_copy --> logger_a - # / - # node_c --> logger_c - - else: - env_c[node_b.name] = graph_c.node_copy(node_b, load_arg) - - gm_c = GraphModule(gm_b, graph_c) - return gm_c diff --git a/torch/quantization/ns/numeric_suite_core_apis_fx.py b/torch/quantization/ns/numeric_suite_core_apis_fx.py deleted file mode 100644 index dd8bea3285c..00000000000 --- a/torch/quantization/ns/numeric_suite_core_apis_fx.py +++ /dev/null @@ -1,241 +0,0 @@ -import collections - -import torch -import torch.nn as nn -import torch.nn.functional as F -toq = torch.ops.quantized -from torch.fx import GraphModule -from torch.quantization.ns.graph_matcher import ( - get_matching_node_pairs, - get_type_a_related_to_b, -) - -from .utils import ( - getattr_from_fqn, -) - -from .weight_utils import ( - get_conv_mod_weight, - get_linear_fun_weight, -) - -from .graph_passes import ( - remove_observers_add_loggers, - create_a_shadows_b, -) - -from typing import Dict, Tuple, Callable, List, Optional - - -# Note: this is not a user facing API -# TODO(future PR): wrap this in a user facing API which does not -# expose FX types. -def compare_weights( - name_a: str, - gm_a: GraphModule, - name_b: str, - gm_b: GraphModule, -) -> Dict[str, Dict[str, torch.Tensor]]: - type_a_related_to_b = get_type_a_related_to_b() - matched_node_pairs = get_matching_node_pairs(gm_a, gm_b) - - results = {} - - for match_name, match in matched_node_pairs.items(): - - node_a, node_b = match - assert node_a.op == node_b.op and \ - node_a.op in ('call_function', 'call_module') - - if node_a.op == 'call_function': - - # linear - # TODO(future PR): other function types - a_related_to_linear = node_a.target in (F.linear,) or \ - (node_a.target, F.linear) in type_a_related_to_b - - if a_related_to_linear: - weight_a = get_linear_fun_weight(node_a, gm_a) - weight_b = get_linear_fun_weight(node_b, gm_b) - - results[match_name] = { - name_a: weight_a, - name_b: weight_b, - } - - else: # call_module - # for call_module, we need to look up the modules to do the type check - assert isinstance(node_a.target, str) - mod_a = getattr_from_fqn(gm_a, node_a.target) - assert isinstance(node_b.target, str) - mod_b = getattr_from_fqn(gm_b, node_b.target) - - # check that A is one the modules we need - # assume B is related (this is done by graph matcher) - a_related_to_conv2d_mod = isinstance(mod_a, nn.Conv2d) or \ - (type(mod_a), nn.Conv2d) in type_a_related_to_b - - # TODO(future PR): other module types - if a_related_to_conv2d_mod: - weight_a = get_conv_mod_weight(mod_a) - weight_b = get_conv_mod_weight(mod_b) - results[match_name] = { - name_a: weight_a, - name_b: weight_b, - } - - return results - - -class OutputLogger(nn.Module): - stats: List[torch.Tensor] - - def __init__( - self, - node_name: str, - model_name: str, - other_node_name: Optional[str] = None, - ): - super().__init__() - self.stats: List[torch.Tensor] = [] - # name of the node whose output this Logger is capturing - self.node_name = node_name - # name of the model from which the node originated from - self.model_name = model_name - # name of the other node with a matching Logger - # used to link node_a_copy -> logger_a to node_c -> logger_c - # in a_shadows_b - self.other_node_name = other_node_name - - def forward(self, x: torch.Tensor): - self.stats.append(x.detach()) - return x - - def __repr__(self): - return f"OutputLogger(node_name={self.node_name}, model_name={self.model_name}, other_node_name={self.other_node_name})" - -# Note: this is not a user facing API -# TODO(future PR): wrap this in a user facing API which does not -# expose FX types. -def prepare_model_outputs( - name_a: str, - gm_a: GraphModule, - name_b: str, - gm_b: GraphModule, - logger_cls: Callable, -) -> Tuple[GraphModule, GraphModule]: - - matched_node_pairs = get_matching_node_pairs(gm_a, gm_b) - - nodes_to_instrument_a = [] - nodes_to_instrument_b = [] - for match_name, (node_a, node_b,) in matched_node_pairs.items(): - # TODO(future PR): do not observe pairs of nodes we do not care - # about (both fp32, denylist, etc) - nodes_to_instrument_a.append(node_a) - nodes_to_instrument_b.append(node_b) - - gm_a = remove_observers_add_loggers(gm_a, nodes_to_instrument_a, logger_cls, name_a) - gm_b = remove_observers_add_loggers(gm_b, nodes_to_instrument_b, logger_cls, name_b) - return (gm_a, gm_b) - -# Note: this is not a user facing API -# TODO(future PR): wrap this in a user facing API which does not -# expose FX types. -# TODO(future PR): align on naming -# this is equivalent of just the comparison extraction part of `ns.compare_model_outputs` -def get_matching_activations( - gm_a: GraphModule, - gm_b: GraphModule, - logger_cls: Callable, -) -> Dict[str, Dict[str, List[torch.Tensor]]]: - """ - Same thing as ns.get_matching_activations, but for FX models prepared with - this module. - - TODO(future PR): real docblock - - Output format: - - { - 'layer1.stats': { - 'name_a': [torch.Tensor(...), ...], - 'name_b': [torch.Tensor(...), ...], - }, - ... - } - - Note, there are three differences from the output format of Eager NS: - 1. `name_a` and `name_b` are used instead of hardcoding names - to `float` and `quantized`. - 2. Lists of Tensors are returned instead of individual Tensors, to unify - the return type for calibrating with 1 input vs N inputs. - 3. `logger_cls` is included in the API for easy result extraction - """ - results: Dict[str, Dict[str, List[torch.Tensor]]] = \ - collections.defaultdict(dict) - for gm in (gm_a, gm_b): - for gm_name, mod in gm.named_modules(): - # TODO(future PR): better check when scripted - is_logger = ( - isinstance(mod, logger_cls) # type: ignore - or ( - isinstance(mod, torch.jit.RecursiveScriptModule) - and mod.original_name == 'OutputLogger' - ) - ) - if is_logger: - results[mod.node_name + '.stats'][mod.model_name] = mod.stats - return dict(results) - -# Note: this is not a user facing API -# TODO(future PR): wrap this in a user facing API which does not -# expose FX types. -def prepare_model_with_stubs( - name_a: str, - gm_a: GraphModule, - name_b: str, - gm_b: GraphModule, - logger_cls: Callable, -) -> GraphModule: - """ - Same thing as prepare_model_outputs, but for an `a_shadows_b` model. - TODO(future PR): real docblock - """ - matched_node_pairs = get_matching_node_pairs(gm_a, gm_b) - gm_a_shadows_b = create_a_shadows_b( - name_a, gm_a, name_b, gm_b, matched_node_pairs, logger_cls) - return gm_a_shadows_b - -# Note: this is not a user facing API -# TODO(future PR): wrap this in a user facing API which does not -# expose FX types. -# TODO(future PR): align on naming -# this is equivalent of just the comparison extraction part of `ns.compare_model_stub` -def get_matching_activations_a_shadows_b( - gm_a_shadows_b: GraphModule, - logger_cls: Callable, -) -> Dict[str, Dict[str, List[torch.Tensor]]]: - """ - Same thing as get_matching_activations, but for an `a_shadows_b` model. - TODO(future PR): real docblock - """ - results: Dict[str, Dict[str, List[torch.Tensor]]] = \ - collections.defaultdict(dict) - for name, mod in gm_a_shadows_b.named_modules(): - # TODO(future PR): better check when scripted - is_logger = ( - isinstance(mod, logger_cls) # type: ignore - or ( - isinstance(mod, torch.jit.RecursiveScriptModule) - and mod.original_name == 'OutputLogger' - ) - ) - if is_logger: - # If logger_obj.other_node_name is populated, then this logger - # is from model A, and other_node_name is the name from model B. - if mod.other_node_name is None: - results[mod.node_name + '.stats'][mod.model_name] = mod.stats - else: - results[mod.other_node_name + '.stats'][mod.model_name] = mod.stats - return dict(results) diff --git a/torch/quantization/ns/utils.py b/torch/quantization/ns/utils.py deleted file mode 100644 index 3bc4b9ee1d4..00000000000 --- a/torch/quantization/ns/utils.py +++ /dev/null @@ -1,84 +0,0 @@ -import enum - -from torch.fx import GraphModule -from torch.fx.graph import Node -from torch.quantization.fx.quantize import is_activation_post_process - -from typing import Optional, Any - -# TODO(future PR): delete this after FX has a util for it -def print_node(node: Optional[Node]) -> None: - if node is None: - print(None) - else: - print( - node, ', target:', node.target, ', op:', node.op, - ', args:', node.args, ', kwargs:', node.kwargs) - -def getattr_from_fqn(gm: GraphModule, fqn: str) -> Any: - """ - Given a gm and a fqn such as "foo.bar.baz", returns gm.foo.bar.baz. - """ - fqn_parts = fqn.split(".") - cur_val = gm - for part in fqn_parts: - cur_val = getattr(cur_val, part) - return cur_val - -class NodeIOType(enum.Enum): - FP32 = enum.auto() # all inputs and outputs fp32 - INT8 = enum.auto() # all inputs and outputs int8 - # TODO(future PRs): dynamic quant, fake quant, etc - - -def get_node_io_type(node: Node, gm: GraphModule) -> NodeIOType: - if node.op == 'call_function': - fp32_fun_target_names = ('torch.nn.functional', 'torch.nn') - int8_fun_target_names = ('torch._ops.quantized',) - # For now, hacky check to see which op is in which namespace - # TODO(future PR): use a real mapping - if node.target.__module__ in fp32_fun_target_names: - return NodeIOType.FP32 - else: - assert node.target.__module__ in int8_fun_target_names, \ - 'unknown node target %s' % node.target - return NodeIOType.INT8 - else: - assert node.op == 'call_module' - assert isinstance(node.target, str) - mod = getattr_from_fqn(gm, node.target) - # For now, hacky check to see which mod is in which namespace - # TODO(future PR): use a real mapping - if mod.__module__.startswith('torch.nn.modules'): - return NodeIOType.FP32 - else: - assert mod.__module__.startswith('torch.nn.q'), \ - 'unknown node target %s' % mod - return NodeIOType.INT8 - -def return_first_non_observer_node( - node: Node, - gm: GraphModule, -) -> Node: - """ - If node is not an observer, returns it. If node is an observer, - navigates up the graph and returns the first parent which is not an - observer. For example, - - graph: (node_non_obs), node = node_non_obs : returns node_non_obs - graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs - graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs - """ - node_obj = getattr_from_fqn(gm, node.target) # type: ignore - if is_activation_post_process(node_obj): - assert len(node.args) == 1 - assert isinstance(node.args[0], Node) - node = node.args[0] - # code duplication intended, not worth refactoring - assert isinstance(node.target, str) - node_obj = getattr_from_fqn(gm, node.target) - if is_activation_post_process(node_obj): - assert len(node.args) == 1 - assert isinstance(node.args[0], Node) - node = node.args[0] - return node diff --git a/torch/quantization/ns/weight_utils.py b/torch/quantization/ns/weight_utils.py deleted file mode 100644 index 1beed277ff3..00000000000 --- a/torch/quantization/ns/weight_utils.py +++ /dev/null @@ -1,45 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -toq = torch.ops.quantized -from torch.fx import GraphModule -from torch.fx.graph import Node - -from .utils import getattr_from_fqn - -def get_conv_mod_weight(mod: nn.Module) -> torch.Tensor: - # TODO(future PR): make more generic, handle everything - if isinstance(mod, nn.Conv2d): - return mod.weight.detach() - else: - return mod._weight_bias()[0] # type: ignore - -def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor: - # TODO(future PR): better docblock, with example FX IR - if node.target in (F.linear,): - # traverse backwards from the weight arg, accounting for - # any observers - weight_arg_node = node.args[1] - # print_node(weight_arg_node) - assert isinstance(weight_arg_node, Node) - weight_node = weight_arg_node.args[0] - # print_node(weight_node) - # TODO(future PR): currently assumes 1 observer, handle arbitrary - # levels of observation, from 0 to N - assert isinstance(weight_node, Node) - assert weight_node.op == 'get_attr' - weight = getattr_from_fqn(gm, weight_node.target) # type: ignore - return weight.detach() - - else: - assert node.target in (toq.linear,) - # packed weight is arg 1 - packed_weight_node = node.args[1] - assert isinstance(packed_weight_node, Node) - assert packed_weight_node.op == 'get_attr' - packed_weight = getattr_from_fqn(gm, packed_weight_node.target) # type: ignore - # TODO(future PR): why does packed_weight.unpack() not work? - # TODO(future PR): discuss if we even need to unpack, or if the - # caller can handle the unpacking - (weight, _bias), _name = packed_weight.__getstate__() - return weight diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index f328a35e5e8..c0fd4ffe16c 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -42,7 +42,7 @@ import os import unittest import numpy as np from torch.testing import FileCheck -from typing import Callable, Tuple, Dict, List +from typing import Callable, Tuple, Dict class NodeSpec: ''' Used for checking GraphModule Node @@ -558,7 +558,6 @@ class QuantizationTestCase(TestCase): nodes_in_graph[n] = 1 if expected_node is not None: - print('expected_node', expected_node) self.assertTrue(expected_node in nodes_in_graph, 'node:' + str(expected_node) + ' not found in the graph module') @@ -655,49 +654,6 @@ class QuantizationTestCase(TestCase): (k, (expected_type_a, expected_type_b), (actual_type_a, actual_type_b)) ) - def assert_ns_weight_compare_dict_valid( - self, - weight_compare_dict: Dict[str, Dict[str, torch.Tensor]], - ) -> None: - """ - Verifieds that the weight_compare dict (output of Numeric Suite - weight matching APIs) is valid: - 1. for each layer, results are recorded for two models - 2. shapes of each pair of weights match - """ - for layer_name, layer_data in weight_compare_dict.items(): - self.assertTrue( - len(layer_data) == 2, - f"Layer {layer_name} does not have exactly two model results.") - k0, k1 = layer_data.keys() - self.assertTrue( - layer_data[k0].shape == layer_data[k1].shape, - f"Layer {layer_name}, {k0} and {k1} have a shape mismatch.") - - def assert_ns_logger_act_compare_dict_valid( - self, - act_compare_dict: Dict[str, Dict[str, List[torch.Tensor]]], - ) -> None: - """ - Verifies that the act_compare_dict (output of Numeric Suite - activation matching APIs) is valid: - 1. for each layer, results are recorded for two models - 2. number of seen tensors match - 3. shapes of each pair of seen tensors match - """ - for layer_name, layer_data in act_compare_dict.items(): - self.assertTrue( - len(layer_data) == 2, - f"Layer {layer_name} does not have exactly two model results.") - k0, k1 = layer_data.keys() - self.assertTrue( - len(layer_data[k0]) == len(layer_data[k1]), - f"Layer {layer_name}, {k0} and {k1} do not have the same number of seen Tensors.") - for idx in range(len(layer_data[k0])): - self.assertTrue( - layer_data[k0][idx].shape == layer_data[k1][idx].shape, - f"Layer {layer_name}, {k0} and {k1} have a shape mismatch at idx {idx}.") - def checkGraphModeFxOp(self, model, inputs, quant_type, expected_node=None, expected_node_occurrence=None,