Revert "[torch][ao] Add customizable loss function to NodeAccuracySummary (#136282)"

This reverts commit f3c54ccf8f.

Reverted https://github.com/pytorch/pytorch/pull/136282 on behalf of https://github.com/huydhn due to This breaks OSS, let revert it and land the revert internally then ([comment](https://github.com/pytorch/pytorch/pull/136282#issuecomment-2364219252))
This commit is contained in:
PyTorch MergeBot 2024-09-20 17:49:06 +00:00
parent 15dba021bb
commit df1eef9779
2 changed files with 12 additions and 95 deletions

View file

@ -25,8 +25,8 @@ from torch.testing._internal.common_quantization import TestHelperModules
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef, TestCase
def _extract_debug_handles(model) -> Dict[str, int]:
debug_handle_map: Dict[str, int] = {}
def _extract_debug_handles(model) -> Dict[torch.fx.Node, int]:
debug_handle_map: Dict[torch.fx.Node, int] = {}
for node in model.graph.nodes:
if (
@ -187,53 +187,3 @@ class TestNumericDebugger(TestCase):
for node_summary in comparison_results.values():
if len(node_summary.results) > 0:
self.assertGreaterEqual(node_summary.results[0].sqnr, 35)
def test_added_node_gets_unique_id(self) -> None:
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
m = capture_pre_autograd_graph(m, example_inputs)
assert isinstance(m, torch.fx.GraphModule)
generate_numeric_debug_handle(m)
ref_handles = _extract_debug_handles(m)
ref_counter = Counter(ref_handles.values())
for k, v in ref_counter.items():
self.assertEqual(
v,
1,
msg=f"For handle {k}, there were {v} nodes with that handle, but expected only 1",
)
# Now that we have unique ids, add a new node into the graph and re-generate
# to make sure that the new node gets a unique id.
last_node = next(iter(reversed(m.graph.nodes)))
with m.graph.inserting_before(last_node):
arg = last_node.args[0]
self.assertIsInstance(arg, tuple)
arg = arg[0]
# Add a function that only requires a single tensor input.
n = m.graph.call_function(torch.ops.aten.relu.default, args=(arg,))
arg.replace_all_uses_with(n, lambda x: x != n)
m.recompile()
# Regenerate handles, make sure only the new relu node has a new id, and
# it doesn't clash with any of the existing ids.
generate_numeric_debug_handle(m)
handles_after_modification = _extract_debug_handles(m)
handles_counter = Counter(handles_after_modification.values())
for name, handle in ref_handles.items():
self.assertIn(name, handles_after_modification)
# Check that handle was unchanged.
self.assertEqual(handles_after_modification[name], handle)
# Check that total count was unchanged.
ref_count = ref_counter[handle]
after_count = handles_counter[handle]
self.assertEqual(
after_count,
ref_count,
msg=f"For handle {handle}, there were {after_count} nodes with that handle, but expected only {ref_count}",
)
# Check for relu specifically. Avoid hardcoding the handle id since it
# may change with future node ordering changes.
self.assertNotEqual(handles_after_modification["relu_default"], 0)
self.assertEqual(handles_counter[handles_after_modification["relu_default"]], 1)

View file

@ -1,7 +1,7 @@
import copy
import logging
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Sequence, Tuple
from typing import Dict, List, Optional, Sequence, Tuple
import torch
from torch.ao.ns.fx.utils import compute_sqnr
@ -19,16 +19,7 @@ def generate_numeric_debug_handle(graph_module: GraphModule) -> None:
"""Attach numeric_debug_handle_id for all nodes in the model except for placeholder node
The graph nodes of input model is modified inplace.
"""
unique_id = -1
# Find the max ID that exists in the graph first, in case part of the graph
# has already been annotated. This way we guarantee there are no duplicate
# handle IDs.
for node in graph_module.graph.nodes:
unique_id = max(
unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, -1)
)
unique_id += 1
unique_id = 0
for node in graph_module.graph.nodes:
if node.op in ["output", "placeholder"]:
continue
@ -143,17 +134,6 @@ class QuantizationComparisonResult:
self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
)
def loss(
self, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
) -> torch.Tensor:
if self.actual.shape != self.ref.shape:
raise ValueError(
f"Cannot compare tensors with different shapes: {self.actual.shape} vs {self.ref.shape}"
)
return loss_function(
self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
)
def __repr__(self) -> str:
# Don't include the tensors themselves as they are quite large to print
# out.
@ -169,10 +149,6 @@ class QuantizationComparisonResult:
if not isinstance(self.ref, torch.Tensor):
raise ValueError(f"`self.ref` value must be a Tensor, got: {self.ref}")
if self.actual.shape != self.ref.shape:
raise ValueError(
f"Cannot compare tensors with different shapes: ref={self.ref.shape} vs actual={self.actual.shape}"
)
@dataclass(frozen=True)
@ -221,8 +197,8 @@ def extract_results_from_loggers(
def compare_results(
ref_results: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]],
actual_results: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]],
ref_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
actual_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
) -> Dict[int, NodeAccuracySummary]:
"""Given two dict mapping from `debug_handle_id` (int) to list of tensors
return a map from `debug_handle_id` to `NodeAccuracySummary` that contains
@ -244,25 +220,16 @@ def compare_results(
)
continue
actual_name, actual_stack, actual_stats = actual_results[debug_handle]
try:
results = [
QuantizationComparisonResult(actual=a, ref=b)
for a, b in zip(actual_stats, ref_stats)
]
except Exception as e:
# Add extra information for an exception from QuantizationComparisonResult
# if the shapes didn't match, to include the handle and the node names.
raise ValueError(
f"For numeric_debug_handle={debug_handle} from ref node {ref_name} and actual node {actual_name}"
) from e
comparisons[debug_handle] = NodeAccuracySummary(
handle=debug_handle,
actual_node_name=actual_name or "",
actual_node_name=actual_name,
actual_module_stack=_module_stack_to_str(actual_stack),
ref_node_name=ref_name or "",
ref_node_name=ref_name,
ref_module_stack=_module_stack_to_str(ref_stack),
results=results,
results=[
QuantizationComparisonResult(actual=a, ref=b)
for a, b in zip(actual_stats, ref_stats)
],
)
return comparisons