mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Revert "[dynamo] save/restore system random state more carefully (#145750)"
This reverts commit e3d3f2b22e.
Reverted https://github.com/pytorch/pytorch/pull/145750 on behalf of https://github.com/eellison due to bisected perf regression ([comment](https://github.com/pytorch/pytorch/pull/145750#issuecomment-2620028414))
This commit is contained in:
parent
28982ceb3b
commit
3481c2aec4
9 changed files with 22 additions and 203 deletions
|
|
@ -1,5 +1,4 @@
|
|||
# Owner(s): ["module: dynamo"]
|
||||
import contextlib
|
||||
import math
|
||||
import random
|
||||
import unittest
|
||||
|
|
@ -256,74 +255,6 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
|
|||
y2 = opt_fn(inp, *get_rng())
|
||||
self.assertEqual(y1, y2)
|
||||
|
||||
def test_random_in_dynamo(self):
    """Check that `random` results match eager when Dynamo itself burns RNG.

    Three Dynamo entry points are wrapped so that every call to them first
    consumes one value from the global `random` generator. With those wrappers
    installed, each function below is run eagerly and then through
    `torch.compile(..., backend="eager")` from the same seed, and the results
    are required to be equal.
    """
    x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

    def f1(x):
        # simple test
        r1 = random.randint(1, 9)
        y = x + random.uniform(10, 20)
        r2 = random.randrange(0, 10)
        return y + r1, r2

    def f2(x):
        # test with graph breaks
        r1 = random.randint(1, 9)
        x = x + r1
        torch._dynamo.graph_break()
        r2 = random.randint(10, 19)
        x = x + r2
        return x, r1, r2

    def f3(x):
        # test consecutive calls
        return x + random.randint(1, 10)

    # The patches must be undone when the test finishes (pass or fail);
    # otherwise the RNG-burning wrappers stay installed for later tests.
    with contextlib.ExitStack() as exit_stack:

        def patch_fn_with_rng_burn(name):
            # `name` is always one of the hard-coded dotted paths below, so
            # eval() never sees external input here.
            orig_fn = eval(name)

            def bad(*args, **kwargs):
                # burn random call within dynamo
                random.random()
                return orig_fn(*args, **kwargs)

            exit_stack.enter_context(unittest.mock.patch(name, bad))

        patch_fn_with_rng_burn("torch._dynamo.eval_frame._maybe_set_eval_frame")
        patch_fn_with_rng_burn("torch._dynamo.convert_frame._compile")
        patch_fn_with_rng_burn(
            "torch._dynamo.symbolic_convert.InstructionTranslator.run"
        )

        random.seed(1)
        ref1 = f1(x)
        opt_f1 = torch.compile(f1, backend="eager", fullgraph=True)
        random.seed(1)
        res1 = opt_f1(x)
        self.assertEqual(ref1, res1)

        random.seed(2)
        ref2 = f2(x)
        opt_f2 = torch.compile(f2, backend="eager")
        random.seed(2)
        res2 = opt_f2(x)
        self.assertEqual(ref2, res2)

        random.seed(3)
        ref3 = f3(x)
        ref3_ = f3(x)
        opt_f3 = torch.compile(f3, backend="eager", fullgraph=True)
        random.seed(3)
        res3 = opt_f3(x)
        res3_ = opt_f3(x)
        self.assertEqual(ref3, res3)
        self.assertEqual(ref3_, res3_)
|
||||
|
||||
def test_builtin_getitem(self):
|
||||
# builtin getitem args[0] is python list and args[1] is unspec
|
||||
def fn(x, idx):
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import pstats
|
||||
import random
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
|
|
@ -192,14 +193,13 @@ def fx_forward_from_src_skip_result(
|
|||
return result
|
||||
|
||||
|
||||
# TODO it is possible to move more global state preservation to eval_frame.py/c.
|
||||
# See how we preserve Python random state.
|
||||
def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||
"""
|
||||
Context manager to:
|
||||
1) Save/restore torch.is_grad_enabled() state
|
||||
2) Save/restore torch random state
|
||||
3) Monkey patch torch.fx.graph_module._forward_from_src
|
||||
2) Save/restore python random state
|
||||
3) Save/restore torch random state
|
||||
4) Monkey patch torch.fx.graph_module._forward_from_src
|
||||
"""
|
||||
|
||||
@functools.wraps(fn)
|
||||
|
|
@ -217,6 +217,7 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
|
|||
prior_mobile_allocator_state = (
|
||||
torch._C._is_default_mobile_cpu_allocator_set()
|
||||
)
|
||||
py_rng_state = random.getstate()
|
||||
prior_dtype = torch.get_default_dtype()
|
||||
torch_rng_state = torch.random.get_rng_state()
|
||||
cuda_rng_state = None
|
||||
|
|
@ -244,6 +245,7 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
|
|||
torch.use_deterministic_algorithms(
|
||||
prior_deterministic, warn_only=prior_warn_only
|
||||
)
|
||||
random.setstate(py_rng_state)
|
||||
torch.random.set_rng_state(torch_rng_state)
|
||||
torch.set_default_dtype(prior_dtype)
|
||||
curr_mobile_allocator_state = (
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ import functools
|
|||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import sysconfig
|
||||
import textwrap
|
||||
|
|
@ -535,7 +534,6 @@ class _TorchDynamoContext:
|
|||
@functools.wraps(fn)
|
||||
def _fn(*args, **kwargs):
|
||||
prior = set_eval_frame(None)
|
||||
prev_rng_state = random.getstate()
|
||||
try:
|
||||
if is_fx_tracing():
|
||||
if config.error_on_nested_fx_trace:
|
||||
|
|
@ -569,8 +567,6 @@ class _TorchDynamoContext:
|
|||
_maybe_set_eval_frame(_callback_from_stance(callback))
|
||||
|
||||
try:
|
||||
# ignore random state updates since beginning of _fn
|
||||
random.setstate(prev_rng_state)
|
||||
return fn(*args, **kwargs)
|
||||
except ShortenTraceback as e:
|
||||
# Failures in the backend likely don't have useful
|
||||
|
|
@ -579,8 +575,6 @@ class _TorchDynamoContext:
|
|||
finally:
|
||||
# Restore the dynamic layer stack depth if necessary.
|
||||
set_eval_frame(None)
|
||||
# NB: assumes no random calls made between fn() and here
|
||||
prev_rng_state = random.getstate()
|
||||
torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
|
||||
saved_dynamic_layer_stack_depth
|
||||
)
|
||||
|
|
@ -590,10 +584,6 @@ class _TorchDynamoContext:
|
|||
cleanup()
|
||||
finally:
|
||||
_maybe_set_eval_frame(prior)
|
||||
# ignore random state updates:
|
||||
# - since beginning of _fn if fn was not called
|
||||
# - since end of fn if fn was called (even if exn occurs)
|
||||
random.setstate(prev_rng_state)
|
||||
|
||||
# hooks to properly handle inlining
|
||||
_fn._torchdynamo_inline = fn # type: ignore[attr-defined]
|
||||
|
|
@ -750,22 +740,18 @@ class DisableContext(_TorchDynamoContext):
|
|||
@functools.wraps(fn)
|
||||
def _fn(*args, **kwargs):
|
||||
prior = set_eval_frame(None)
|
||||
prev_rng_state = random.getstate()
|
||||
try:
|
||||
prior_skip_guard_eval_unsafe = set_skip_guard_eval_unsafe(
|
||||
_is_skip_guard_eval_unsafe_stance()
|
||||
)
|
||||
_maybe_set_eval_frame(_callback_from_stance(self.callback))
|
||||
try:
|
||||
random.setstate(prev_rng_state)
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
set_eval_frame(None)
|
||||
set_skip_guard_eval_unsafe(prior_skip_guard_eval_unsafe)
|
||||
prev_rng_state = random.getstate()
|
||||
finally:
|
||||
_maybe_set_eval_frame(prior)
|
||||
random.setstate(prev_rng_state)
|
||||
|
||||
_fn._torchdynamo_disable = True # type: ignore[attr-defined]
|
||||
|
||||
|
|
|
|||
|
|
@ -628,11 +628,9 @@ class OutputGraph:
|
|||
"""
|
||||
global_state = cast(
|
||||
dict[str, tuple[Callable[..., Any], bool]],
|
||||
(
|
||||
out
|
||||
if out is not None
|
||||
else self.tracing_context.global_context.global_state
|
||||
),
|
||||
out
|
||||
if out is not None
|
||||
else self.tracing_context.global_context.global_state,
|
||||
)
|
||||
|
||||
# TODO - Consider having a torch level API for torch_function_state. As
|
||||
|
|
|
|||
|
|
@ -1713,7 +1713,6 @@ class RandomVariable(VariableTracker):
|
|||
tx.output.side_effects.mutation(self)
|
||||
state = self.random.getstate()
|
||||
|
||||
# Generate new random object with the same state and call the method
|
||||
def call_random_meth(*args, **kwargs):
|
||||
r = random.Random()
|
||||
r.setstate(state)
|
||||
|
|
|
|||
|
|
@ -678,8 +678,6 @@ def call_random_fn(tx, fn, args, kwargs):
|
|||
args = [x.as_python_constant() for x in args]
|
||||
kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
|
||||
random_call_index = len(tx.output.random_calls)
|
||||
# NB: it is probably not important for the example_value to be exactly correct,
|
||||
# we just need the right type
|
||||
example_value = fn(*args, **kwargs)
|
||||
source = RandomValueSource(random_call_index)
|
||||
tx.output.random_calls.append((fn, args, kwargs))
|
||||
|
|
@ -876,12 +874,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
|||
and all(k.is_python_constant() for k in args)
|
||||
and all(v.is_python_constant() for v in kwargs.values())
|
||||
):
|
||||
return call_random_fn(
|
||||
tx,
|
||||
self.value,
|
||||
args,
|
||||
kwargs,
|
||||
)
|
||||
return call_random_fn(tx, self.value, args, kwargs)
|
||||
elif istype(self.value, types.MethodType):
|
||||
func = self.value.__func__
|
||||
obj = self.value.__self__
|
||||
|
|
|
|||
|
|
@ -9,14 +9,13 @@
|
|||
#include <torch/csrc/dynamo/debug_macros.h>
|
||||
#include <torch/csrc/dynamo/extra_state.h>
|
||||
#include <torch/csrc/dynamo/framelocals_mapping.h>
|
||||
#include <torch/csrc/dynamo/utils.h>
|
||||
#include <torch/csrc/utils/python_compat.h>
|
||||
|
||||
PyObject* guard_error_hook = NULL;
|
||||
const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup";
|
||||
|
||||
typedef struct {
|
||||
int active_dynamo_threads;
|
||||
int active_dynamo_threads;
|
||||
} ModuleState;
|
||||
|
||||
// static int active_dynamo_threads = 0;
|
||||
|
|
@ -296,7 +295,6 @@ static PyObject* dynamo_call_callback(
|
|||
PyObject* cache_entry_pyobj = CacheEntry_to_obj(cache_entry);
|
||||
PyObject* res = PyObject_CallFunction(
|
||||
callable, "OOO", frame, cache_entry_pyobj, frame_state);
|
||||
|
||||
Py_DECREF(frame);
|
||||
Py_DECREF(cache_entry_pyobj);
|
||||
return res;
|
||||
|
|
@ -539,17 +537,6 @@ static PyObject* dynamo_eval_custom_code(
|
|||
return result;
|
||||
}
|
||||
|
||||
static PyObject* random_state = NULL;
|
||||
|
||||
static void save_random_state() {
|
||||
Py_XSETREF(random_state, random_getstate(random_module()));
|
||||
}
|
||||
|
||||
static void restore_random_state() {
|
||||
DEBUG_NULL_CHECK(random_state);
|
||||
random_setstate(random_module(), random_state);
|
||||
}
|
||||
|
||||
static PyObject* dynamo__custom_eval_frame_shim(
|
||||
PyThreadState* tstate,
|
||||
THP_EVAL_API_FRAME_OBJECT* frame,
|
||||
|
|
@ -658,11 +645,6 @@ static PyObject* dynamo__custom_eval_frame(
|
|||
// in the shim.
|
||||
eval_frame_callback_set(Py_None);
|
||||
|
||||
// Preserve random state - restore before calling
|
||||
// dynamo_eval_[frame_default]/[custom_code]!
|
||||
// lookup or call_callback may burn RNG.
|
||||
save_random_state();
|
||||
|
||||
// A callback of Py_False indicates "run only" mode, the cache is checked, but
|
||||
// we never compile.
|
||||
// Also, if extra is marked as "cache_limit_hit", run in "run only" mode
|
||||
|
|
@ -703,7 +685,6 @@ static PyObject* dynamo__custom_eval_frame(
|
|||
DEBUG_TRACE("skip recursive %s", get_frame_name(frame));
|
||||
eval_frame_callback_set(Py_None);
|
||||
}
|
||||
restore_random_state();
|
||||
PyObject* ret = dynamo_eval_frame_default(tstate, frame, throw_flag);
|
||||
if (extra_state_cache_limit_hit(extra)) {
|
||||
eval_frame_callback_set(callback);
|
||||
|
|
@ -716,7 +697,6 @@ static PyObject* dynamo__custom_eval_frame(
|
|||
// Re-enable custom behavior
|
||||
eval_frame_callback_set(callback);
|
||||
*should_clear_frame = 1;
|
||||
restore_random_state();
|
||||
return dynamo_eval_custom_code(
|
||||
tstate, frame, cached_code, trace_annotation, throw_flag);
|
||||
}
|
||||
|
|
@ -749,7 +729,6 @@ static PyObject* dynamo__custom_eval_frame(
|
|||
eval_frame_callback_set(callback);
|
||||
*should_clear_frame = 1;
|
||||
framelocals_mapping_free(locals);
|
||||
restore_random_state();
|
||||
return dynamo_eval_custom_code(
|
||||
tstate, frame, cached_code, trace_annotation, throw_flag);
|
||||
}
|
||||
|
|
@ -783,7 +762,6 @@ static PyObject* dynamo__custom_eval_frame(
|
|||
// code.
|
||||
DEBUG_TRACE("create skip recursive %s", get_frame_name(frame));
|
||||
set_extra_state(F_CODE(frame), SKIP_CODE_RECURSIVE);
|
||||
restore_random_state();
|
||||
PyObject* r = dynamo_eval_frame_default(tstate, frame, throw_flag);
|
||||
// Re-enable custom behavior
|
||||
eval_frame_callback_set(callback);
|
||||
|
|
@ -792,7 +770,6 @@ static PyObject* dynamo__custom_eval_frame(
|
|||
// Dynamo returned cache_limit_hit_flag, so we should recursively skip code.
|
||||
DEBUG_TRACE("create cache limit hit %s", get_frame_name(frame));
|
||||
set_extra_state_cache_limit_hit(extra, true);
|
||||
restore_random_state();
|
||||
PyObject* r = dynamo_eval_frame_default(tstate, frame, throw_flag);
|
||||
// Re-enable custom behavior
|
||||
eval_frame_callback_set(callback);
|
||||
|
|
@ -814,7 +791,6 @@ static PyObject* dynamo__custom_eval_frame(
|
|||
// Re-enable custom behavior
|
||||
eval_frame_callback_set(callback);
|
||||
*should_clear_frame = 1;
|
||||
restore_random_state();
|
||||
return dynamo_eval_custom_code(
|
||||
tstate,
|
||||
frame,
|
||||
|
|
@ -827,7 +803,6 @@ static PyObject* dynamo__custom_eval_frame(
|
|||
set_extra_state(F_CODE(frame), SKIP_CODE);
|
||||
// Re-enable custom behavior
|
||||
eval_frame_callback_set(callback);
|
||||
restore_random_state();
|
||||
return dynamo_eval_frame_default(tstate, frame, throw_flag);
|
||||
}
|
||||
}
|
||||
|
|
@ -856,9 +831,7 @@ static PyTypeObject THPPyInterpreterFrameType = {
|
|||
|
||||
#endif // !(IS_PYTHON_3_14_PLUS)
|
||||
|
||||
static PyObject* increment_working_threads(
|
||||
PyThreadState* tstate,
|
||||
PyObject* module) {
|
||||
static PyObject* increment_working_threads(PyThreadState* tstate, PyObject* module) {
|
||||
ModuleState* state = PyModule_GetState(module);
|
||||
|
||||
if (state != NULL) {
|
||||
|
|
@ -871,13 +844,11 @@ static PyObject* increment_working_threads(
|
|||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject* decrement_working_threads(
|
||||
PyThreadState* tstate,
|
||||
PyObject* module) {
|
||||
static PyObject* decrement_working_threads(PyThreadState* tstate, PyObject* module) {
|
||||
ModuleState* state = PyModule_GetState(module);
|
||||
|
||||
if (state != NULL) {
|
||||
if (state->active_dynamo_threads > 0) {
|
||||
if (state->active_dynamo_threads > 0) {
|
||||
state->active_dynamo_threads = state->active_dynamo_threads - 1;
|
||||
if (state->active_dynamo_threads == 0) {
|
||||
enable_eval_frame_default(tstate);
|
||||
|
|
@ -888,10 +859,7 @@ static PyObject* decrement_working_threads(
|
|||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject* set_eval_frame(
|
||||
PyObject* new_callback,
|
||||
PyThreadState* tstate,
|
||||
PyObject* module) {
|
||||
static PyObject* set_eval_frame(PyObject* new_callback, PyThreadState* tstate, PyObject* module) {
|
||||
// Change the eval frame callback and return the old one
|
||||
// - None: disables TorchDynamo
|
||||
// - False: run-only mode (reuse existing compiles)
|
||||
|
|
@ -1014,13 +982,13 @@ static PyObject* raise_sigtrap(PyObject* dummy, PyObject* obj) {
|
|||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static int clear_state(PyObject* module) {
|
||||
ModuleState* state = PyModule_GetState(module);
|
||||
if (state) {
|
||||
state->active_dynamo_threads = 0;
|
||||
return 0;
|
||||
}
|
||||
return -1;
|
||||
static int clear_state(PyObject *module) {
|
||||
ModuleState* state = PyModule_GetState(module);
|
||||
if (state) {
|
||||
state->active_dynamo_threads = 0;
|
||||
return 0;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
static PyMethodDef _methods[] = {
|
||||
|
|
|
|||
|
|
@ -2,29 +2,6 @@
|
|||
|
||||
namespace torch::dynamo {
|
||||
|
||||
// random utilities for C dynamo
|
||||
|
||||
// random module reference
|
||||
py::object _random{py::none()};
|
||||
|
||||
PyObject* random_module() {
  // Import `random` on first use and keep it in the namespace-level `_random`
  // object; every call returns a borrowed pointer to that cached module.
  if (!_random.is_none()) {
    return _random.ptr();
  }
  _random = py::module_::import("random");
  return _random.ptr();
}
|
||||
|
||||
PyObject* random_getstate(PyObject* rng) {
  // Call rng.getstate() and transfer ownership of the result to the caller
  // (release() detaches it from the temporary py::object, so the returned
  // pointer is a new reference). `rng` itself is only borrowed.
  return py::handle(rng).attr("getstate")().release().ptr();
}
|
||||
|
||||
void random_setstate(PyObject* rng, PyObject* state) {
  // Call rng.setstate(state); both arguments are wrapped as non-owning
  // handles, so no reference counts are taken over from the caller.
  py::handle target(rng);
  py::handle saved(state);
  target.attr("setstate")(saved);
}
|
||||
|
||||
static std::array<PyMethodDef, 1> _methods = {{
|
||||
{nullptr,
|
||||
nullptr,
|
||||
|
|
|
|||
|
|
@ -1,18 +1,10 @@
|
|||
#pragma once
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
#include <torch/csrc/python_headers.h>
|
||||
// C2039 MSVC
|
||||
#include <pybind11/complex.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
#endif // __cplusplus
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
// The visibility attribute is to avoid a warning about storing a field in the
|
||||
// struct that has a different visibility (from pybind) than the struct.
|
||||
#ifdef _WIN32
|
||||
|
|
@ -22,32 +14,5 @@
|
|||
#endif
|
||||
|
||||
namespace torch::dynamo {
|
||||
|
||||
extern "C" {
|
||||
|
||||
#endif // __cplusplus
|
||||
|
||||
// reference to random module
|
||||
// returns borrowed reference
|
||||
PyObject* random_module();
|
||||
|
||||
// rng.getstate()
|
||||
// rng can be random module or random.Random object
|
||||
// rng: borrowed reference
|
||||
// returns new reference
|
||||
PyObject* random_getstate(PyObject* rng);
|
||||
|
||||
// rng.setstate(state)
|
||||
// rng can be random module or random.Random object
|
||||
// rng, state: borrowed references
|
||||
// no return value
|
||||
void random_setstate(PyObject* rng, PyObject* state);
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
} // extern "C"
|
||||
|
||||
PyObject* torch_c_dynamo_utils_init();
|
||||
} // namespace torch::dynamo
|
||||
|
||||
#endif // __cplusplus
|
||||
|
|
|
|||
Loading…
Reference in a new issue