Revert "[dynamo] save/restore system random state more carefully (#145750)"

This reverts commit e3d3f2b22e.

Reverted https://github.com/pytorch/pytorch/pull/145750 on behalf of https://github.com/eellison due to bisected perf regression ([comment](https://github.com/pytorch/pytorch/pull/145750#issuecomment-2620028414))
This commit is contained in:
PyTorch MergeBot 2025-01-28 20:51:07 +00:00
parent 28982ceb3b
commit 3481c2aec4
9 changed files with 22 additions and 203 deletions

View file

@ -1,5 +1,4 @@
# Owner(s): ["module: dynamo"]
import contextlib
import math
import random
import unittest
@ -256,74 +255,6 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
y2 = opt_fn(inp, *get_rng())
self.assertEqual(y1, y2)
def test_random_in_dynamo(self):
    """Verify compiled code sees the same ``random`` stream as eager code.

    Three Dynamo entry points are wrapped so that every call into them
    consumes one value from the global ``random`` generator.  For each
    scenario the generator is seeded, the function is run eagerly, the
    generator is re-seeded with the same value, and the compiled function
    is run; both runs must produce equal results.
    """
    x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

    # Using the ExitStack as a context manager guarantees that every
    # mock.patch installed below is undone when the test finishes (pass or
    # fail), so the RNG-burning wrappers cannot leak into other tests.
    with contextlib.ExitStack() as exit_stack:

        def patch_fn_with_rng_burn(name):
            # ``name`` is always one of the hard-coded dotted paths used
            # below, so resolving it with eval() is acceptable here.
            orig_fn = eval(name)

            def bad(*args, **kwargs):
                # burn random call within dynamo
                random.random()
                return orig_fn(*args, **kwargs)

            exit_stack.enter_context(unittest.mock.patch(name, bad))

        patch_fn_with_rng_burn("torch._dynamo.eval_frame._maybe_set_eval_frame")
        patch_fn_with_rng_burn("torch._dynamo.convert_frame._compile")
        patch_fn_with_rng_burn(
            "torch._dynamo.symbolic_convert.InstructionTranslator.run"
        )

        def f1(x):
            # simple test
            r1 = random.randint(1, 9)
            y = x + random.uniform(10, 20)
            r2 = random.randrange(0, 10)
            return y + r1, r2

        random.seed(1)
        ref1 = f1(x)
        opt_f1 = torch.compile(f1, backend="eager", fullgraph=True)
        random.seed(1)
        res1 = opt_f1(x)
        self.assertEqual(ref1, res1)

        def f2(x):
            # test with graph breaks
            r1 = random.randint(1, 9)
            x = x + r1
            torch._dynamo.graph_break()
            r2 = random.randint(10, 19)
            x = x + r2
            return x, r1, r2

        random.seed(2)
        ref2 = f2(x)
        opt_f2 = torch.compile(f2, backend="eager")
        random.seed(2)
        res2 = opt_f2(x)
        self.assertEqual(ref2, res2)

        def f3(x):
            # test consecutive calls
            return x + random.randint(1, 10)

        random.seed(3)
        ref3 = f3(x)
        ref3_ = f3(x)
        opt_f3 = torch.compile(f3, backend="eager", fullgraph=True)
        random.seed(3)
        res3 = opt_f3(x)
        res3_ = opt_f3(x)
        self.assertEqual(ref3, res3)
        self.assertEqual(ref3_, res3_)
def test_builtin_getitem(self):
# builtin getitem args[0] is python list and args[1] is unspec
def fn(x, idx):

View file

@ -12,6 +12,7 @@ import json
import logging
import os
import pstats
import random
import subprocess
import sys
import threading
@ -192,14 +193,13 @@ def fx_forward_from_src_skip_result(
return result
# TODO it is possible to move more global state preservation to eval_frame.py/c.
# See how we preserve Python random state.
def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
"""
Context manager to:
1) Save/restore torch.is_grad_enabled() state
2) Save/restore torch random state
3) Monkey patch torch.fx.graph_module._forward_from_src
2) Save/restore python random state
3) Save/restore torch random state
4) Monkey patch torch.fx.graph_module._forward_from_src
"""
@functools.wraps(fn)
@ -217,6 +217,7 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
prior_mobile_allocator_state = (
torch._C._is_default_mobile_cpu_allocator_set()
)
py_rng_state = random.getstate()
prior_dtype = torch.get_default_dtype()
torch_rng_state = torch.random.get_rng_state()
cuda_rng_state = None
@ -244,6 +245,7 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
torch.use_deterministic_algorithms(
prior_deterministic, warn_only=prior_warn_only
)
random.setstate(py_rng_state)
torch.random.set_rng_state(torch_rng_state)
torch.set_default_dtype(prior_dtype)
curr_mobile_allocator_state = (

View file

@ -16,7 +16,6 @@ import functools
import inspect
import logging
import os
import random
import sys
import sysconfig
import textwrap
@ -535,7 +534,6 @@ class _TorchDynamoContext:
@functools.wraps(fn)
def _fn(*args, **kwargs):
prior = set_eval_frame(None)
prev_rng_state = random.getstate()
try:
if is_fx_tracing():
if config.error_on_nested_fx_trace:
@ -569,8 +567,6 @@ class _TorchDynamoContext:
_maybe_set_eval_frame(_callback_from_stance(callback))
try:
# ignore random state updates since beginning of _fn
random.setstate(prev_rng_state)
return fn(*args, **kwargs)
except ShortenTraceback as e:
# Failures in the backend likely don't have useful
@ -579,8 +575,6 @@ class _TorchDynamoContext:
finally:
# Restore the dynamic layer stack depth if necessary.
set_eval_frame(None)
# NB: assumes no random calls made between fn() and here
prev_rng_state = random.getstate()
torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
saved_dynamic_layer_stack_depth
)
@ -590,10 +584,6 @@ class _TorchDynamoContext:
cleanup()
finally:
_maybe_set_eval_frame(prior)
# ignore random state updates:
# - since beginning of _fn if fn was not called
# - since end of fn if fn was called (even if exn occurs)
random.setstate(prev_rng_state)
# hooks to properly handle inlining
_fn._torchdynamo_inline = fn # type: ignore[attr-defined]
@ -750,22 +740,18 @@ class DisableContext(_TorchDynamoContext):
@functools.wraps(fn)
def _fn(*args, **kwargs):
prior = set_eval_frame(None)
prev_rng_state = random.getstate()
try:
prior_skip_guard_eval_unsafe = set_skip_guard_eval_unsafe(
_is_skip_guard_eval_unsafe_stance()
)
_maybe_set_eval_frame(_callback_from_stance(self.callback))
try:
random.setstate(prev_rng_state)
return fn(*args, **kwargs)
finally:
set_eval_frame(None)
set_skip_guard_eval_unsafe(prior_skip_guard_eval_unsafe)
prev_rng_state = random.getstate()
finally:
_maybe_set_eval_frame(prior)
random.setstate(prev_rng_state)
_fn._torchdynamo_disable = True # type: ignore[attr-defined]

View file

@ -628,11 +628,9 @@ class OutputGraph:
"""
global_state = cast(
dict[str, tuple[Callable[..., Any], bool]],
(
out
if out is not None
else self.tracing_context.global_context.global_state
),
out
if out is not None
else self.tracing_context.global_context.global_state,
)
# TODO - Consider having a torch level API for torch_function_state. As

View file

@ -1713,7 +1713,6 @@ class RandomVariable(VariableTracker):
tx.output.side_effects.mutation(self)
state = self.random.getstate()
# Generate new random object with the same state and call the method
def call_random_meth(*args, **kwargs):
r = random.Random()
r.setstate(state)

View file

@ -678,8 +678,6 @@ def call_random_fn(tx, fn, args, kwargs):
args = [x.as_python_constant() for x in args]
kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
random_call_index = len(tx.output.random_calls)
# NB: it is probably not important for the example_value to be exactly correct,
# we just need the right type
example_value = fn(*args, **kwargs)
source = RandomValueSource(random_call_index)
tx.output.random_calls.append((fn, args, kwargs))
@ -876,12 +874,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
and all(k.is_python_constant() for k in args)
and all(v.is_python_constant() for v in kwargs.values())
):
return call_random_fn(
tx,
self.value,
args,
kwargs,
)
return call_random_fn(tx, self.value, args, kwargs)
elif istype(self.value, types.MethodType):
func = self.value.__func__
obj = self.value.__self__

View file

@ -9,14 +9,13 @@
#include <torch/csrc/dynamo/debug_macros.h>
#include <torch/csrc/dynamo/extra_state.h>
#include <torch/csrc/dynamo/framelocals_mapping.h>
#include <torch/csrc/dynamo/utils.h>
#include <torch/csrc/utils/python_compat.h>
PyObject* guard_error_hook = NULL;
const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup";
typedef struct {
int active_dynamo_threads;
int active_dynamo_threads;
} ModuleState;
// static int active_dynamo_threads = 0;
@ -296,7 +295,6 @@ static PyObject* dynamo_call_callback(
PyObject* cache_entry_pyobj = CacheEntry_to_obj(cache_entry);
PyObject* res = PyObject_CallFunction(
callable, "OOO", frame, cache_entry_pyobj, frame_state);
Py_DECREF(frame);
Py_DECREF(cache_entry_pyobj);
return res;
@ -539,17 +537,6 @@ static PyObject* dynamo_eval_custom_code(
return result;
}
// Most recent snapshot of the Python `random` module state, or NULL if no
// snapshot has been taken yet.
static PyObject* random_state = NULL;

// Snapshot the current state of the Python `random` module into
// `random_state`, releasing whatever snapshot was stored before (if any).
static void save_random_state() {
  PyObject* snapshot = random_getstate(random_module());
  Py_XSETREF(random_state, snapshot);
}
static void restore_random_state() {
DEBUG_NULL_CHECK(random_state);
random_setstate(random_module(), random_state);
}
static PyObject* dynamo__custom_eval_frame_shim(
PyThreadState* tstate,
THP_EVAL_API_FRAME_OBJECT* frame,
@ -658,11 +645,6 @@ static PyObject* dynamo__custom_eval_frame(
// in the shim.
eval_frame_callback_set(Py_None);
// Preserve random state - restore before calling
// dynamo_eval_[frame_default]/[custom_code]!
// lookup or call_callback may burn RNG.
save_random_state();
// A callback of Py_False indicates "run only" mode, the cache is checked, but
// we never compile.
// Also, if extra is marked as "cache_limit_hit", run in "run only" mode
@ -703,7 +685,6 @@ static PyObject* dynamo__custom_eval_frame(
DEBUG_TRACE("skip recursive %s", get_frame_name(frame));
eval_frame_callback_set(Py_None);
}
restore_random_state();
PyObject* ret = dynamo_eval_frame_default(tstate, frame, throw_flag);
if (extra_state_cache_limit_hit(extra)) {
eval_frame_callback_set(callback);
@ -716,7 +697,6 @@ static PyObject* dynamo__custom_eval_frame(
// Re-enable custom behavior
eval_frame_callback_set(callback);
*should_clear_frame = 1;
restore_random_state();
return dynamo_eval_custom_code(
tstate, frame, cached_code, trace_annotation, throw_flag);
}
@ -749,7 +729,6 @@ static PyObject* dynamo__custom_eval_frame(
eval_frame_callback_set(callback);
*should_clear_frame = 1;
framelocals_mapping_free(locals);
restore_random_state();
return dynamo_eval_custom_code(
tstate, frame, cached_code, trace_annotation, throw_flag);
}
@ -783,7 +762,6 @@ static PyObject* dynamo__custom_eval_frame(
// code.
DEBUG_TRACE("create skip recursive %s", get_frame_name(frame));
set_extra_state(F_CODE(frame), SKIP_CODE_RECURSIVE);
restore_random_state();
PyObject* r = dynamo_eval_frame_default(tstate, frame, throw_flag);
// Re-enable custom behavior
eval_frame_callback_set(callback);
@ -792,7 +770,6 @@ static PyObject* dynamo__custom_eval_frame(
// Dynamo returned cache_limit_hit_flag, so we should recursively skip code.
DEBUG_TRACE("create cache limit hit %s", get_frame_name(frame));
set_extra_state_cache_limit_hit(extra, true);
restore_random_state();
PyObject* r = dynamo_eval_frame_default(tstate, frame, throw_flag);
// Re-enable custom behavior
eval_frame_callback_set(callback);
@ -814,7 +791,6 @@ static PyObject* dynamo__custom_eval_frame(
// Re-enable custom behavior
eval_frame_callback_set(callback);
*should_clear_frame = 1;
restore_random_state();
return dynamo_eval_custom_code(
tstate,
frame,
@ -827,7 +803,6 @@ static PyObject* dynamo__custom_eval_frame(
set_extra_state(F_CODE(frame), SKIP_CODE);
// Re-enable custom behavior
eval_frame_callback_set(callback);
restore_random_state();
return dynamo_eval_frame_default(tstate, frame, throw_flag);
}
}
@ -856,9 +831,7 @@ static PyTypeObject THPPyInterpreterFrameType = {
#endif // !(IS_PYTHON_3_14_PLUS)
static PyObject* increment_working_threads(
PyThreadState* tstate,
PyObject* module) {
static PyObject* increment_working_threads(PyThreadState* tstate, PyObject* module) {
ModuleState* state = PyModule_GetState(module);
if (state != NULL) {
@ -871,13 +844,11 @@ static PyObject* increment_working_threads(
Py_RETURN_NONE;
}
static PyObject* decrement_working_threads(
PyThreadState* tstate,
PyObject* module) {
static PyObject* decrement_working_threads(PyThreadState* tstate, PyObject* module) {
ModuleState* state = PyModule_GetState(module);
if (state != NULL) {
if (state->active_dynamo_threads > 0) {
if (state->active_dynamo_threads > 0) {
state->active_dynamo_threads = state->active_dynamo_threads - 1;
if (state->active_dynamo_threads == 0) {
enable_eval_frame_default(tstate);
@ -888,10 +859,7 @@ static PyObject* decrement_working_threads(
Py_RETURN_NONE;
}
static PyObject* set_eval_frame(
PyObject* new_callback,
PyThreadState* tstate,
PyObject* module) {
static PyObject* set_eval_frame(PyObject* new_callback, PyThreadState* tstate, PyObject* module) {
// Change the eval frame callback and return the old one
// - None: disables TorchDynamo
// - False: run-only mode (reuse existing compiles)
@ -1014,13 +982,13 @@ static PyObject* raise_sigtrap(PyObject* dummy, PyObject* obj) {
Py_RETURN_NONE;
}
static int clear_state(PyObject* module) {
ModuleState* state = PyModule_GetState(module);
if (state) {
state->active_dynamo_threads = 0;
return 0;
}
return -1;
static int clear_state(PyObject *module) {
ModuleState* state = PyModule_GetState(module);
if (state) {
state->active_dynamo_threads = 0;
return 0;
}
return -1;
}
static PyMethodDef _methods[] = {

View file

@ -2,29 +2,6 @@
namespace torch::dynamo {
// random utilities for C dynamo
// Lazily populated handle to the Python `random` module; holds None until
// random_module() has been called for the first time.
py::object _random{py::none()};

// Return a pointer to the Python `random` module, importing it on first use
// and caching the module object in `_random`. No reference is transferred to
// the caller.
PyObject* random_module() {
  if (!_random.is_none()) {
    return _random.ptr();
  }
  _random = py::module_::import("random");
  return _random.ptr();
}
// Call `rng.getstate()` and return the result.
// `rng` is not consumed; ownership of the returned state object is released
// to the caller.
PyObject* random_getstate(PyObject* rng) {
  return py::handle(rng).attr("getstate")().release().ptr();
}
// Call `rng.setstate(state)`.
// Both arguments are wrapped as non-owning handles, so neither reference is
// consumed.
void random_setstate(PyObject* rng, PyObject* state) {
  py::handle target(rng);
  py::handle new_state(state);
  target.attr("setstate")(new_state);
}
static std::array<PyMethodDef, 1> _methods = {{
{nullptr,
nullptr,

View file

@ -1,18 +1,10 @@
#pragma once
#ifdef __cplusplus
#include <torch/csrc/python_headers.h>
// C2039 MSVC
#include <pybind11/complex.h>
#include <torch/csrc/utils/pybind.h>
#endif // __cplusplus
#include <Python.h>
#ifdef __cplusplus
// The visibility attribute is to avoid a warning about storing a field in the
// struct that has a different visibility (from pybind) than the struct.
#ifdef _WIN32
@ -22,32 +14,5 @@
#endif
namespace torch::dynamo {
extern "C" {
#endif // __cplusplus
// reference to random module
// returns borrowed reference
PyObject* random_module();
// rng.getstate()
// rng can be random module or random.Random object
// rng: borrowed reference
// returns new reference
PyObject* random_getstate(PyObject* rng);
// rng.setstate(state)
// rng can be random module or random.Random object
// rng, state: borrowed references
// no return value
void random_setstate(PyObject* rng, PyObject* state);
#ifdef __cplusplus
} // extern "C"
PyObject* torch_c_dynamo_utils_init();
} // namespace torch::dynamo
#endif // __cplusplus