mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
speculation_log: Raise a unique error for divergence issues (#144785)
This is primarily sent for discussion and to see what tests fail due to this. The idea is that rather than capturing this as a regex on the fail_reason, just give it a unique failure type Pull Request resolved: https://github.com/pytorch/pytorch/pull/144785 Approved by: https://github.com/ezyang
This commit is contained in:
parent
b90231a189
commit
926f9056a9
2 changed files with 18 additions and 2 deletions
|
|
@ -6,6 +6,8 @@ import torch._dynamo.test_case
|
|||
import torch._functorch.config
|
||||
import torch.nn
|
||||
import torch.utils.checkpoint
|
||||
from torch._dynamo.bytecode_transformation import Instruction
|
||||
from torch._dynamo.symbolic_convert import SpeculationLog, SpeculationLogDivergence
|
||||
|
||||
|
||||
class ExceptionTests(torch._dynamo.test_case.TestCase):
|
||||
|
|
@ -402,6 +404,13 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertEqual(ref[0], res[0])
|
||||
self.assertEqual(ref[1], res[1])
|
||||
|
||||
def test_speculation_exception(self):
|
||||
log = SpeculationLog()
|
||||
log.next("fake", 555, "fake", Instruction(1, "fake", 1, 1))
|
||||
log.restart()
|
||||
with self.assertRaises(SpeculationLogDivergence):
|
||||
log.next("bad", 58, "bad", Instruction(2, "different", 2, 2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -213,11 +213,13 @@ class SpeculationLog:
|
|||
f"Previous instruction: {prev_entry.filename}:{prev_entry.lineno}"
|
||||
f"({prev_entry.inst.opname} @ {prev_entry.instruction_pointer})\n"
|
||||
)
|
||||
assert (
|
||||
if not (
|
||||
entry.instruction_pointer == instruction_pointer
|
||||
and entry.filename == filename
|
||||
and entry.lineno == lineno
|
||||
), f"""
|
||||
):
|
||||
raise SpeculationLogDivergence(
|
||||
f"""
|
||||
SpeculationLog diverged at index {self.index} (log had {len(self.entries)} entries):
|
||||
- Expected: {entry.filename}:{entry.lineno} ({entry.inst.opname} at ip={entry.instruction_pointer})
|
||||
- Actual: {filename}:{lineno} ({inst.opname} at ip={instruction_pointer})
|
||||
|
|
@ -235,6 +237,7 @@ do this for graph breaks, you will infinite loop).
|
|||
|
||||
Otherwise, please submit a bug report, ideally including the contents of TORCH_LOGS=+dynamo
|
||||
"""
|
||||
)
|
||||
self.index += 1
|
||||
return entry
|
||||
|
||||
|
|
@ -319,6 +322,10 @@ class BlockStackEntry:
|
|||
return self.with_context.exit(tx)
|
||||
|
||||
|
||||
class SpeculationLogDivergence(AssertionError):
|
||||
pass
|
||||
|
||||
|
||||
class ReturnValueOp(Exception):
|
||||
pass
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue