mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Revert "[torch][ao] Add customizable loss function to NodeAccuracySummary (#136282)"
This reverts commit f3c54ccf8f.
Reverted https://github.com/pytorch/pytorch/pull/136282 on behalf of https://github.com/huydhn due to This breaks OSS, let revert it and land the revert internally then ([comment](https://github.com/pytorch/pytorch/pull/136282#issuecomment-2364219252))
This commit is contained in:
parent
15dba021bb
commit
df1eef9779
2 changed files with 12 additions and 95 deletions
|
|
@ -25,8 +25,8 @@ from torch.testing._internal.common_quantization import TestHelperModules
|
|||
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef, TestCase
|
||||
|
||||
|
||||
def _extract_debug_handles(model) -> Dict[str, int]:
|
||||
debug_handle_map: Dict[str, int] = {}
|
||||
def _extract_debug_handles(model) -> Dict[torch.fx.Node, int]:
|
||||
debug_handle_map: Dict[torch.fx.Node, int] = {}
|
||||
|
||||
for node in model.graph.nodes:
|
||||
if (
|
||||
|
|
@ -187,53 +187,3 @@ class TestNumericDebugger(TestCase):
|
|||
for node_summary in comparison_results.values():
|
||||
if len(node_summary.results) > 0:
|
||||
self.assertGreaterEqual(node_summary.results[0].sqnr, 35)
|
||||
|
||||
def test_added_node_gets_unique_id(self) -> None:
|
||||
m = TestHelperModules.Conv2dThenConv1d()
|
||||
example_inputs = m.example_inputs()
|
||||
m = capture_pre_autograd_graph(m, example_inputs)
|
||||
assert isinstance(m, torch.fx.GraphModule)
|
||||
generate_numeric_debug_handle(m)
|
||||
ref_handles = _extract_debug_handles(m)
|
||||
ref_counter = Counter(ref_handles.values())
|
||||
for k, v in ref_counter.items():
|
||||
self.assertEqual(
|
||||
v,
|
||||
1,
|
||||
msg=f"For handle {k}, there were {v} nodes with that handle, but expected only 1",
|
||||
)
|
||||
|
||||
# Now that we have unique ids, add a new node into the graph and re-generate
|
||||
# to make sure that the new node gets a unique id.
|
||||
last_node = next(iter(reversed(m.graph.nodes)))
|
||||
with m.graph.inserting_before(last_node):
|
||||
arg = last_node.args[0]
|
||||
self.assertIsInstance(arg, tuple)
|
||||
arg = arg[0]
|
||||
# Add a function that only requires a single tensor input.
|
||||
n = m.graph.call_function(torch.ops.aten.relu.default, args=(arg,))
|
||||
arg.replace_all_uses_with(n, lambda x: x != n)
|
||||
m.recompile()
|
||||
|
||||
# Regenerate handles, make sure only the new relu node has a new id, and
|
||||
# it doesn't clash with any of the existing ids.
|
||||
generate_numeric_debug_handle(m)
|
||||
handles_after_modification = _extract_debug_handles(m)
|
||||
handles_counter = Counter(handles_after_modification.values())
|
||||
for name, handle in ref_handles.items():
|
||||
self.assertIn(name, handles_after_modification)
|
||||
# Check that handle was unchanged.
|
||||
self.assertEqual(handles_after_modification[name], handle)
|
||||
# Check that total count was unchanged.
|
||||
ref_count = ref_counter[handle]
|
||||
after_count = handles_counter[handle]
|
||||
self.assertEqual(
|
||||
after_count,
|
||||
ref_count,
|
||||
msg=f"For handle {handle}, there were {after_count} nodes with that handle, but expected only {ref_count}",
|
||||
)
|
||||
|
||||
# Check for relu specifically. Avoid hardcoding the handle id since it
|
||||
# may change with future node ordering changes.
|
||||
self.assertNotEqual(handles_after_modification["relu_default"], 0)
|
||||
self.assertEqual(handles_counter[handles_after_modification["relu_default"]], 1)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import copy
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Tuple
|
||||
from typing import Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from torch.ao.ns.fx.utils import compute_sqnr
|
||||
|
|
@ -19,16 +19,7 @@ def generate_numeric_debug_handle(graph_module: GraphModule) -> None:
|
|||
"""Attach numeric_debug_handle_id for all nodes in the model except for placeholder node
|
||||
The graph nodes of input model is modified inplace.
|
||||
"""
|
||||
unique_id = -1
|
||||
# Find the max ID that exists in the graph first, in case part of the graph
|
||||
# has already been annotated. This way we guarantee there are no duplicate
|
||||
# handle IDs.
|
||||
for node in graph_module.graph.nodes:
|
||||
unique_id = max(
|
||||
unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, -1)
|
||||
)
|
||||
unique_id += 1
|
||||
|
||||
unique_id = 0
|
||||
for node in graph_module.graph.nodes:
|
||||
if node.op in ["output", "placeholder"]:
|
||||
continue
|
||||
|
|
@ -143,17 +134,6 @@ class QuantizationComparisonResult:
|
|||
self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
|
||||
)
|
||||
|
||||
def loss(
|
||||
self, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
if self.actual.shape != self.ref.shape:
|
||||
raise ValueError(
|
||||
f"Cannot compare tensors with different shapes: {self.actual.shape} vs {self.ref.shape}"
|
||||
)
|
||||
return loss_function(
|
||||
self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# Don't include the tensors themselves as they are quite large to print
|
||||
# out.
|
||||
|
|
@ -169,10 +149,6 @@ class QuantizationComparisonResult:
|
|||
|
||||
if not isinstance(self.ref, torch.Tensor):
|
||||
raise ValueError(f"`self.ref` value must be a Tensor, got: {self.ref}")
|
||||
if self.actual.shape != self.ref.shape:
|
||||
raise ValueError(
|
||||
f"Cannot compare tensors with different shapes: ref={self.ref.shape} vs actual={self.actual.shape}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
@ -221,8 +197,8 @@ def extract_results_from_loggers(
|
|||
|
||||
|
||||
def compare_results(
|
||||
ref_results: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]],
|
||||
actual_results: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]],
|
||||
ref_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
|
||||
actual_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
|
||||
) -> Dict[int, NodeAccuracySummary]:
|
||||
"""Given two dict mapping from `debug_handle_id` (int) to list of tensors
|
||||
return a map from `debug_handle_id` to `NodeAccuracySummary` that contains
|
||||
|
|
@ -244,25 +220,16 @@ def compare_results(
|
|||
)
|
||||
continue
|
||||
actual_name, actual_stack, actual_stats = actual_results[debug_handle]
|
||||
try:
|
||||
results = [
|
||||
QuantizationComparisonResult(actual=a, ref=b)
|
||||
for a, b in zip(actual_stats, ref_stats)
|
||||
]
|
||||
except Exception as e:
|
||||
# Add extra information for an exception from QuantizationComparisonResult
|
||||
# if the shapes didn't match, to include the handle and the node names.
|
||||
raise ValueError(
|
||||
f"For numeric_debug_handle={debug_handle} from ref node {ref_name} and actual node {actual_name}"
|
||||
) from e
|
||||
|
||||
comparisons[debug_handle] = NodeAccuracySummary(
|
||||
handle=debug_handle,
|
||||
actual_node_name=actual_name or "",
|
||||
actual_node_name=actual_name,
|
||||
actual_module_stack=_module_stack_to_str(actual_stack),
|
||||
ref_node_name=ref_name or "",
|
||||
ref_node_name=ref_name,
|
||||
ref_module_stack=_module_stack_to_str(ref_stack),
|
||||
results=results,
|
||||
results=[
|
||||
QuantizationComparisonResult(actual=a, ref=b)
|
||||
for a, b in zip(actual_stats, ref_stats)
|
||||
],
|
||||
)
|
||||
|
||||
return comparisons
|
||||
|
|
|
|||
Loading…
Reference in a new issue