diff --git a/test/quantization/pt2e/test_numeric_debugger.py b/test/quantization/pt2e/test_numeric_debugger.py index 3a8a75d3773..7808eb89257 100644 --- a/test/quantization/pt2e/test_numeric_debugger.py +++ b/test/quantization/pt2e/test_numeric_debugger.py @@ -25,8 +25,8 @@ from torch.testing._internal.common_quantization import TestHelperModules from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef, TestCase -def _extract_debug_handles(model) -> Dict[str, int]: - debug_handle_map: Dict[str, int] = {} +def _extract_debug_handles(model) -> Dict[torch.fx.Node, int]: + debug_handle_map: Dict[torch.fx.Node, int] = {} for node in model.graph.nodes: if ( @@ -187,53 +187,3 @@ class TestNumericDebugger(TestCase): for node_summary in comparison_results.values(): if len(node_summary.results) > 0: self.assertGreaterEqual(node_summary.results[0].sqnr, 35) - - def test_added_node_gets_unique_id(self) -> None: - m = TestHelperModules.Conv2dThenConv1d() - example_inputs = m.example_inputs() - m = capture_pre_autograd_graph(m, example_inputs) - assert isinstance(m, torch.fx.GraphModule) - generate_numeric_debug_handle(m) - ref_handles = _extract_debug_handles(m) - ref_counter = Counter(ref_handles.values()) - for k, v in ref_counter.items(): - self.assertEqual( - v, - 1, - msg=f"For handle {k}, there were {v} nodes with that handle, but expected only 1", - ) - - # Now that we have unique ids, add a new node into the graph and re-generate - # to make sure that the new node gets a unique id. - last_node = next(iter(reversed(m.graph.nodes))) - with m.graph.inserting_before(last_node): - arg = last_node.args[0] - self.assertIsInstance(arg, tuple) - arg = arg[0] - # Add a function that only requires a single tensor input. - n = m.graph.call_function(torch.ops.aten.relu.default, args=(arg,)) - arg.replace_all_uses_with(n, lambda x: x != n) - m.recompile() - - # Regenerate handles, make sure only the new relu node has a new id, and - # it doesn't clash with any of the existing ids. - generate_numeric_debug_handle(m) - handles_after_modification = _extract_debug_handles(m) - handles_counter = Counter(handles_after_modification.values()) - for name, handle in ref_handles.items(): - self.assertIn(name, handles_after_modification) - # Check that handle was unchanged. - self.assertEqual(handles_after_modification[name], handle) - # Check that total count was unchanged. - ref_count = ref_counter[handle] - after_count = handles_counter[handle] - self.assertEqual( - after_count, - ref_count, - msg=f"For handle {handle}, there were {after_count} nodes with that handle, but expected only {ref_count}", - ) - - # Check for relu specifically. Avoid hardcoding the handle id since it - # may change with future node ordering changes. - self.assertNotEqual(handles_after_modification["relu_default"], 0) - self.assertEqual(handles_counter[handles_after_modification["relu_default"]], 1) diff --git a/torch/ao/quantization/pt2e/_numeric_debugger.py b/torch/ao/quantization/pt2e/_numeric_debugger.py index 3ae57acc8cb..fedcf470a18 100644 --- a/torch/ao/quantization/pt2e/_numeric_debugger.py +++ b/torch/ao/quantization/pt2e/_numeric_debugger.py @@ -1,7 +1,7 @@ import copy import logging from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Sequence, Tuple +from typing import Dict, List, Optional, Sequence, Tuple import torch from torch.ao.ns.fx.utils import compute_sqnr @@ -19,16 +19,7 @@ def generate_numeric_debug_handle(graph_module: GraphModule) -> None: """Attach numeric_debug_handle_id for all nodes in the model except for placeholder node The graph nodes of input model is modified inplace. """ - unique_id = -1 - # Find the max ID that exists in the graph first, in case part of the graph - # has already been annotated. This way we guarantee there are no duplicate - # handle IDs. - for node in graph_module.graph.nodes: - unique_id = max( - unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, -1) - ) - unique_id += 1 - + unique_id = 0 for node in graph_module.graph.nodes: if node.op in ["output", "placeholder"]: continue @@ -143,17 +134,6 @@ class QuantizationComparisonResult: self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32) ) - def loss( - self, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] - ) -> torch.Tensor: - if self.actual.shape != self.ref.shape: - raise ValueError( - f"Cannot compare tensors with different shapes: {self.actual.shape} vs {self.ref.shape}" - ) - return loss_function( - self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32) - ) - def __repr__(self) -> str: # Don't include the tensors themselves as they are quite large to print # out. @@ -169,10 +149,6 @@ class QuantizationComparisonResult: if not isinstance(self.ref, torch.Tensor): raise ValueError(f"`self.ref` value must be a Tensor, got: {self.ref}") - if self.actual.shape != self.ref.shape: - raise ValueError( - f"Cannot compare tensors with different shapes: ref={self.ref.shape} vs actual={self.actual.shape}" - ) @dataclass(frozen=True) @@ -221,8 +197,8 @@ def extract_results_from_loggers( def compare_results( - ref_results: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]], - actual_results: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]], + ref_results: Dict[int, Tuple[str, object, List[torch.Tensor]]], + actual_results: Dict[int, Tuple[str, object, List[torch.Tensor]]], ) -> Dict[int, NodeAccuracySummary]: """Given two dict mapping from `debug_handle_id` (int) to list of tensors return a map from `debug_handle_id` to `NodeAccuracySummary` that contains @@ -244,25 +220,16 @@ def compare_results( ) continue actual_name, actual_stack, actual_stats = actual_results[debug_handle] - try: - results = [ - QuantizationComparisonResult(actual=a, ref=b) - for a, b in zip(actual_stats, ref_stats) - ] - except Exception as e: - # Add extra information for an exception from QuantizationComparisonResult - # if the shapes didn't match, to include the handle and the node names. - raise ValueError( - f"For numeric_debug_handle={debug_handle} from ref node {ref_name} and actual node {actual_name}" - ) from e - comparisons[debug_handle] = NodeAccuracySummary( handle=debug_handle, - actual_node_name=actual_name or "", + actual_node_name=actual_name, actual_module_stack=_module_stack_to_str(actual_stack), - ref_node_name=ref_name or "", + ref_node_name=ref_name, ref_module_stack=_module_stack_to_str(ref_stack), - results=results, + results=[ + QuantizationComparisonResult(actual=a, ref=b) + for a, b in zip(actual_stats, ref_stats) + ], ) return comparisons