diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index 547839569d6..79fdb0a37ad 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -1,5 +1,4 @@ # Owner(s): ["module: dynamo"] -import contextlib import math import random import unittest @@ -256,74 +255,6 @@ class UnspecTests(torch._dynamo.test_case.TestCase): y2 = opt_fn(inp, *get_rng()) self.assertEqual(y1, y2) - def test_random_in_dynamo(self): - # test that system random calls still work even - # if Dynamo calls random methods. - - exit_stack = contextlib.ExitStack() - - def patch_fn_with_rng_burn(name): - orig_fn = eval(name) - - def bad(*args, **kwargs): - # burn random call within dynamo - random.random() - return orig_fn(*args, **kwargs) - - exit_stack.enter_context(unittest.mock.patch(name, bad)) - - x = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) - - patch_fn_with_rng_burn("torch._dynamo.eval_frame._maybe_set_eval_frame") - patch_fn_with_rng_burn("torch._dynamo.convert_frame._compile") - patch_fn_with_rng_burn( - "torch._dynamo.symbolic_convert.InstructionTranslator.run" - ) - - def f1(x): - # simple test - r1 = random.randint(1, 9) - y = x + random.uniform(10, 20) - r2 = random.randrange(0, 10) - return y + r1, r2 - - random.seed(1) - ref1 = f1(x) - opt_f1 = torch.compile(f1, backend="eager", fullgraph=True) - random.seed(1) - res1 = opt_f1(x) - self.assertEqual(ref1, res1) - - def f2(x): - # test with graph breaks - r1 = random.randint(1, 9) - x = x + r1 - torch._dynamo.graph_break() - r2 = random.randint(10, 19) - x = x + r2 - return x, r1, r2 - - random.seed(2) - ref2 = f2(x) - opt_f2 = torch.compile(f2, backend="eager") - random.seed(2) - res2 = opt_f2(x) - self.assertEqual(ref2, res2) - - def f3(x): - # test consecutive calls - return x + random.randint(1, 10) - - random.seed(3) - ref3 = f3(x) - ref3_ = f3(x) - opt_f3 = torch.compile(f3, backend="eager", fullgraph=True) - random.seed(3) - res3 = opt_f3(x) - res3_ = opt_f3(x) - self.assertEqual(ref3, res3) - self.assertEqual(ref3_, res3_) - def test_builtin_getitem(self): # builtin getitem args[0] is python list and args[1] is unspec def fn(x, idx): diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 038eafa6ef8..7d68861ec62 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -12,6 +12,7 @@ import json import logging import os import pstats +import random import subprocess import sys import threading @@ -192,14 +193,13 @@ def fx_forward_from_src_skip_result( return result -# TODO it is possible to move more global state preservation to eval_frame.py/c. -# See how we preserve Python random state. def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]: """ Context manager to: 1) Save/restore torch.is_grad_enabled() state - 2) Save/restore torch random state - 3) Monkey patch torch.fx.graph_module._forward_from_src + 2) Save/restore python random state + 3) Save/restore torch random state + 4) Monkey patch torch.fx.graph_module._forward_from_src """ @functools.wraps(fn) @@ -217,6 +217,7 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]: prior_mobile_allocator_state = ( torch._C._is_default_mobile_cpu_allocator_set() ) + py_rng_state = random.getstate() prior_dtype = torch.get_default_dtype() torch_rng_state = torch.random.get_rng_state() cuda_rng_state = None @@ -244,6 +245,7 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]: torch.use_deterministic_algorithms( prior_deterministic, warn_only=prior_warn_only ) + random.setstate(py_rng_state) torch.random.set_rng_state(torch_rng_state) torch.set_default_dtype(prior_dtype) curr_mobile_allocator_state = ( diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 90b21f3f845..c4ac8dc2a91 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -16,7 +16,6 @@ import functools import inspect import logging import os -import random import sys import sysconfig import textwrap @@ -535,7 +534,6 @@ class _TorchDynamoContext: @functools.wraps(fn) def _fn(*args, **kwargs): prior = set_eval_frame(None) - prev_rng_state = random.getstate() try: if is_fx_tracing(): if config.error_on_nested_fx_trace: @@ -569,8 +567,6 @@ class _TorchDynamoContext: _maybe_set_eval_frame(_callback_from_stance(callback)) try: - # ignore random state updates since beginning of _fn - random.setstate(prev_rng_state) return fn(*args, **kwargs) except ShortenTraceback as e: # Failures in the backend likely don't have useful @@ -579,8 +575,6 @@ class _TorchDynamoContext: finally: # Restore the dynamic layer stack depth if necessary. set_eval_frame(None) - # NB: assumes no random calls made between fn() and here - prev_rng_state = random.getstate() torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth( saved_dynamic_layer_stack_depth ) @@ -590,10 +584,6 @@ class _TorchDynamoContext: cleanup() finally: _maybe_set_eval_frame(prior) - # ignore random state updates: - # - since beginning of _fn if fn was not called - # - since end of fn if fn was called (even if exn occurs) - random.setstate(prev_rng_state) # hooks to properly handle inlining _fn._torchdynamo_inline = fn # type: ignore[attr-defined] @@ -750,22 +740,18 @@ class DisableContext(_TorchDynamoContext): @functools.wraps(fn) def _fn(*args, **kwargs): prior = set_eval_frame(None) - prev_rng_state = random.getstate() try: prior_skip_guard_eval_unsafe = set_skip_guard_eval_unsafe( _is_skip_guard_eval_unsafe_stance() ) _maybe_set_eval_frame(_callback_from_stance(self.callback)) try: - random.setstate(prev_rng_state) return fn(*args, **kwargs) finally: set_eval_frame(None) set_skip_guard_eval_unsafe(prior_skip_guard_eval_unsafe) - prev_rng_state = random.getstate() finally: _maybe_set_eval_frame(prior) - random.setstate(prev_rng_state) _fn._torchdynamo_disable = True # type: ignore[attr-defined] diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index e89d7689fe2..44806909956 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -628,11 +628,9 @@ class OutputGraph: """ global_state = cast( dict[str, tuple[Callable[..., Any], bool]], - ( - out - if out is not None - else self.tracing_context.global_context.global_state - ), + out + if out is not None + else self.tracing_context.global_context.global_state, ) # TODO - Consider having a torch level API for torch_function_state. As diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 95c06720e9c..64ff8800193 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -1713,7 +1713,6 @@ class RandomVariable(VariableTracker): tx.output.side_effects.mutation(self) state = self.random.getstate() - # Generate new random object with the same state and call the method def call_random_meth(*args, **kwargs): r = random.Random() r.setstate(state) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 76da1ac01e0..1a0d2ed31aa 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -678,8 +678,6 @@ def call_random_fn(tx, fn, args, kwargs): args = [x.as_python_constant() for x in args] kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} random_call_index = len(tx.output.random_calls) - # NB: it is probably not important for the example_value to be exactly correct, - # we just need the right type example_value = fn(*args, **kwargs) source = RandomValueSource(random_call_index) tx.output.random_calls.append((fn, args, kwargs)) @@ -876,12 +874,7 @@ class UserDefinedObjectVariable(UserDefinedVariable): and all(k.is_python_constant() for k in args) and all(v.is_python_constant() for v in kwargs.values()) ): - return call_random_fn( - tx, - self.value, - args, - kwargs, - ) + return call_random_fn(tx, self.value, args, kwargs) elif istype(self.value, types.MethodType): func = self.value.__func__ obj = self.value.__self__ diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c index 4a230125f94..d6457e7fb08 100644 --- a/torch/csrc/dynamo/eval_frame.c +++ b/torch/csrc/dynamo/eval_frame.c @@ -9,14 +9,13 @@ #include #include #include -#include #include PyObject* guard_error_hook = NULL; const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup"; typedef struct { - int active_dynamo_threads; + int active_dynamo_threads; } ModuleState; // static int active_dynamo_threads = 0; @@ -296,7 +295,6 @@ static PyObject* dynamo_call_callback( PyObject* cache_entry_pyobj = CacheEntry_to_obj(cache_entry); PyObject* res = PyObject_CallFunction( callable, "OOO", frame, cache_entry_pyobj, frame_state); - Py_DECREF(frame); Py_DECREF(cache_entry_pyobj); return res; @@ -539,17 +537,6 @@ static PyObject* dynamo_eval_custom_code( return result; } -static PyObject* random_state = NULL; - -static void save_random_state() { - Py_XSETREF(random_state, random_getstate(random_module())); -} - -static void restore_random_state() { - DEBUG_NULL_CHECK(random_state); - random_setstate(random_module(), random_state); -} - static PyObject* dynamo__custom_eval_frame_shim( PyThreadState* tstate, THP_EVAL_API_FRAME_OBJECT* frame, @@ -658,11 +645,6 @@ static PyObject* dynamo__custom_eval_frame( // in the shim. eval_frame_callback_set(Py_None); - // Preserve random state - restore before calling - // dynamo_eval_[frame_default]/[custom_code]! - // lookup or call_callback may burn RNG. - save_random_state(); - // A callback of Py_False indicates "run only" mode, the cache is checked, but // we never compile. // Also, if extra is marked as "cache_limit_hit", run in "run only" mode @@ -703,7 +685,6 @@ static PyObject* dynamo__custom_eval_frame( DEBUG_TRACE("skip recursive %s", get_frame_name(frame)); eval_frame_callback_set(Py_None); } - restore_random_state(); PyObject* ret = dynamo_eval_frame_default(tstate, frame, throw_flag); if (extra_state_cache_limit_hit(extra)) { eval_frame_callback_set(callback); @@ -716,7 +697,6 @@ static PyObject* dynamo__custom_eval_frame( // Re-enable custom behavior eval_frame_callback_set(callback); *should_clear_frame = 1; - restore_random_state(); return dynamo_eval_custom_code( tstate, frame, cached_code, trace_annotation, throw_flag); } @@ -749,7 +729,6 @@ static PyObject* dynamo__custom_eval_frame( eval_frame_callback_set(callback); *should_clear_frame = 1; framelocals_mapping_free(locals); - restore_random_state(); return dynamo_eval_custom_code( tstate, frame, cached_code, trace_annotation, throw_flag); } @@ -783,7 +762,6 @@ static PyObject* dynamo__custom_eval_frame( // code. DEBUG_TRACE("create skip recursive %s", get_frame_name(frame)); set_extra_state(F_CODE(frame), SKIP_CODE_RECURSIVE); - restore_random_state(); PyObject* r = dynamo_eval_frame_default(tstate, frame, throw_flag); // Re-enable custom behavior eval_frame_callback_set(callback); @@ -792,7 +770,6 @@ static PyObject* dynamo__custom_eval_frame( // Dynamo returned cache_limit_hit_flag, so we should recursively skip code. DEBUG_TRACE("create cache limit hit %s", get_frame_name(frame)); set_extra_state_cache_limit_hit(extra, true); - restore_random_state(); PyObject* r = dynamo_eval_frame_default(tstate, frame, throw_flag); // Re-enable custom behavior eval_frame_callback_set(callback); @@ -814,7 +791,6 @@ static PyObject* dynamo__custom_eval_frame( // Re-enable custom behavior eval_frame_callback_set(callback); *should_clear_frame = 1; - restore_random_state(); return dynamo_eval_custom_code( tstate, frame, @@ -827,7 +803,6 @@ static PyObject* dynamo__custom_eval_frame( set_extra_state(F_CODE(frame), SKIP_CODE); // Re-enable custom behavior eval_frame_callback_set(callback); - restore_random_state(); return dynamo_eval_frame_default(tstate, frame, throw_flag); } } @@ -856,9 +831,7 @@ static PyTypeObject THPPyInterpreterFrameType = { #endif // !(IS_PYTHON_3_14_PLUS) -static PyObject* increment_working_threads( - PyThreadState* tstate, - PyObject* module) { +static PyObject* increment_working_threads(PyThreadState* tstate, PyObject* module) { ModuleState* state = PyModule_GetState(module); if (state != NULL) { @@ -871,13 +844,11 @@ static PyObject* increment_working_threads( Py_RETURN_NONE; } -static PyObject* decrement_working_threads( - PyThreadState* tstate, - PyObject* module) { +static PyObject* decrement_working_threads(PyThreadState* tstate, PyObject* module) { ModuleState* state = PyModule_GetState(module); if (state != NULL) { - if (state->active_dynamo_threads > 0) { + if (state->active_dynamo_threads > 0) { state->active_dynamo_threads = state->active_dynamo_threads - 1; if (state->active_dynamo_threads == 0) { enable_eval_frame_default(tstate); @@ -888,10 +859,7 @@ static PyObject* decrement_working_threads( Py_RETURN_NONE; } -static PyObject* set_eval_frame( - PyObject* new_callback, - PyThreadState* tstate, - PyObject* module) { +static PyObject* set_eval_frame(PyObject* new_callback, PyThreadState* tstate, PyObject* module) { // Change the eval frame callback and return the old one // - None: disables TorchDynamo // - False: run-only mode (reuse existing compiles) @@ -1014,13 +982,13 @@ static PyObject* raise_sigtrap(PyObject* dummy, PyObject* obj) { Py_RETURN_NONE; } -static int clear_state(PyObject* module) { - ModuleState* state = PyModule_GetState(module); - if (state) { - state->active_dynamo_threads = 0; - return 0; - } - return -1; +static int clear_state(PyObject *module) { + ModuleState* state = PyModule_GetState(module); + if (state) { + state->active_dynamo_threads = 0; + return 0; + } + return -1; } static PyMethodDef _methods[] = { diff --git a/torch/csrc/dynamo/utils.cpp b/torch/csrc/dynamo/utils.cpp index 4a779d245ab..eea523e2747 100644 --- a/torch/csrc/dynamo/utils.cpp +++ b/torch/csrc/dynamo/utils.cpp @@ -2,29 +2,6 @@ namespace torch::dynamo { -// random utilities for C dynamo - -// random module reference -py::object _random{py::none()}; - -PyObject* random_module() { - if (_random.is_none()) { - _random = py::module_::import("random"); - } - return _random.ptr(); -} - -PyObject* random_getstate(PyObject* rng) { - py::handle rng_h(rng); - py::object state = rng_h.attr("getstate")(); - return state.release().ptr(); -} - -void random_setstate(PyObject* rng, PyObject* state) { - py::handle rng_h(rng), state_h(state); - rng_h.attr("setstate")(state_h); -} - static std::array _methods = {{ {nullptr, nullptr, diff --git a/torch/csrc/dynamo/utils.h b/torch/csrc/dynamo/utils.h index ca191f65c8b..c7bcfb5b5e6 100644 --- a/torch/csrc/dynamo/utils.h +++ b/torch/csrc/dynamo/utils.h @@ -1,18 +1,10 @@ #pragma once - -#ifdef __cplusplus - #include // C2039 MSVC #include #include -#endif // __cplusplus - #include - -#ifdef __cplusplus - // The visibility attribute is to avoid a warning about storing a field in the // struct that has a different visibility (from pybind) than the struct. #ifdef _WIN32 @@ -22,32 +14,5 @@ #endif namespace torch::dynamo { - -extern "C" { - -#endif // __cplusplus - -// reference to random module -// returns borrowed reference -PyObject* random_module(); - -// rng.getstate() -// rng can be random module or random.Random object -// rng: borrowed reference -// returns new reference -PyObject* random_getstate(PyObject* rng); - -// rng.setstate(state) -// rng can be random module or random.Random object -// rng, state: borrowed references -// no return value -void random_setstate(PyObject* rng, PyObject* state); - -#ifdef __cplusplus - -} // extern "C" - PyObject* torch_c_dynamo_utils_init(); } // namespace torch::dynamo - -#endif // __cplusplus