mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Revert "[dynamo] save/restore system random state more carefully (#145750)"
This reverts commit e3d3f2b22e.
Reverted https://github.com/pytorch/pytorch/pull/145750 on behalf of https://github.com/eellison due to bisected perf regression ([comment](https://github.com/pytorch/pytorch/pull/145750#issuecomment-2620028414))
This commit is contained in:
parent
28982ceb3b
commit
3481c2aec4
9 changed files with 22 additions and 203 deletions
|
|
@ -1,5 +1,4 @@
|
||||||
# Owner(s): ["module: dynamo"]
|
# Owner(s): ["module: dynamo"]
|
||||||
import contextlib
|
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
import unittest
|
import unittest
|
||||||
|
|
@ -256,74 +255,6 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
|
||||||
y2 = opt_fn(inp, *get_rng())
|
y2 = opt_fn(inp, *get_rng())
|
||||||
self.assertEqual(y1, y2)
|
self.assertEqual(y1, y2)
|
||||||
|
|
||||||
def test_random_in_dynamo(self):
|
|
||||||
# test that system random calls still work even
|
|
||||||
# if Dynamo calls random methods.
|
|
||||||
|
|
||||||
exit_stack = contextlib.ExitStack()
|
|
||||||
|
|
||||||
def patch_fn_with_rng_burn(name):
|
|
||||||
orig_fn = eval(name)
|
|
||||||
|
|
||||||
def bad(*args, **kwargs):
|
|
||||||
# burn random call within dynamo
|
|
||||||
random.random()
|
|
||||||
return orig_fn(*args, **kwargs)
|
|
||||||
|
|
||||||
exit_stack.enter_context(unittest.mock.patch(name, bad))
|
|
||||||
|
|
||||||
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
|
|
||||||
|
|
||||||
patch_fn_with_rng_burn("torch._dynamo.eval_frame._maybe_set_eval_frame")
|
|
||||||
patch_fn_with_rng_burn("torch._dynamo.convert_frame._compile")
|
|
||||||
patch_fn_with_rng_burn(
|
|
||||||
"torch._dynamo.symbolic_convert.InstructionTranslator.run"
|
|
||||||
)
|
|
||||||
|
|
||||||
def f1(x):
|
|
||||||
# simple test
|
|
||||||
r1 = random.randint(1, 9)
|
|
||||||
y = x + random.uniform(10, 20)
|
|
||||||
r2 = random.randrange(0, 10)
|
|
||||||
return y + r1, r2
|
|
||||||
|
|
||||||
random.seed(1)
|
|
||||||
ref1 = f1(x)
|
|
||||||
opt_f1 = torch.compile(f1, backend="eager", fullgraph=True)
|
|
||||||
random.seed(1)
|
|
||||||
res1 = opt_f1(x)
|
|
||||||
self.assertEqual(ref1, res1)
|
|
||||||
|
|
||||||
def f2(x):
|
|
||||||
# test with graph breaks
|
|
||||||
r1 = random.randint(1, 9)
|
|
||||||
x = x + r1
|
|
||||||
torch._dynamo.graph_break()
|
|
||||||
r2 = random.randint(10, 19)
|
|
||||||
x = x + r2
|
|
||||||
return x, r1, r2
|
|
||||||
|
|
||||||
random.seed(2)
|
|
||||||
ref2 = f2(x)
|
|
||||||
opt_f2 = torch.compile(f2, backend="eager")
|
|
||||||
random.seed(2)
|
|
||||||
res2 = opt_f2(x)
|
|
||||||
self.assertEqual(ref2, res2)
|
|
||||||
|
|
||||||
def f3(x):
|
|
||||||
# test consecutive calls
|
|
||||||
return x + random.randint(1, 10)
|
|
||||||
|
|
||||||
random.seed(3)
|
|
||||||
ref3 = f3(x)
|
|
||||||
ref3_ = f3(x)
|
|
||||||
opt_f3 = torch.compile(f3, backend="eager", fullgraph=True)
|
|
||||||
random.seed(3)
|
|
||||||
res3 = opt_f3(x)
|
|
||||||
res3_ = opt_f3(x)
|
|
||||||
self.assertEqual(ref3, res3)
|
|
||||||
self.assertEqual(ref3_, res3_)
|
|
||||||
|
|
||||||
def test_builtin_getitem(self):
|
def test_builtin_getitem(self):
|
||||||
# builtin getitem args[0] is python list and args[1] is unspec
|
# builtin getitem args[0] is python list and args[1] is unspec
|
||||||
def fn(x, idx):
|
def fn(x, idx):
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import pstats
|
import pstats
|
||||||
|
import random
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
|
|
@ -192,14 +193,13 @@ def fx_forward_from_src_skip_result(
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
# TODO it is possible to move more global state preservation to eval_frame.py/c.
|
|
||||||
# See how we preserve Python random state.
|
|
||||||
def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
|
def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||||
"""
|
"""
|
||||||
Context manager to:
|
Context manager to:
|
||||||
1) Save/restore torch.is_grad_enabled() state
|
1) Save/restore torch.is_grad_enabled() state
|
||||||
2) Save/restore torch random state
|
2) Save/restore python random state
|
||||||
3) Monkey patch torch.fx.graph_module._forward_from_src
|
3) Save/restore torch random state
|
||||||
|
4) Monkey patch torch.fx.graph_module._forward_from_src
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@functools.wraps(fn)
|
@functools.wraps(fn)
|
||||||
|
|
@ -217,6 +217,7 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||||
prior_mobile_allocator_state = (
|
prior_mobile_allocator_state = (
|
||||||
torch._C._is_default_mobile_cpu_allocator_set()
|
torch._C._is_default_mobile_cpu_allocator_set()
|
||||||
)
|
)
|
||||||
|
py_rng_state = random.getstate()
|
||||||
prior_dtype = torch.get_default_dtype()
|
prior_dtype = torch.get_default_dtype()
|
||||||
torch_rng_state = torch.random.get_rng_state()
|
torch_rng_state = torch.random.get_rng_state()
|
||||||
cuda_rng_state = None
|
cuda_rng_state = None
|
||||||
|
|
@ -244,6 +245,7 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||||
torch.use_deterministic_algorithms(
|
torch.use_deterministic_algorithms(
|
||||||
prior_deterministic, warn_only=prior_warn_only
|
prior_deterministic, warn_only=prior_warn_only
|
||||||
)
|
)
|
||||||
|
random.setstate(py_rng_state)
|
||||||
torch.random.set_rng_state(torch_rng_state)
|
torch.random.set_rng_state(torch_rng_state)
|
||||||
torch.set_default_dtype(prior_dtype)
|
torch.set_default_dtype(prior_dtype)
|
||||||
curr_mobile_allocator_state = (
|
curr_mobile_allocator_state = (
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@ import functools
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
import sys
|
import sys
|
||||||
import sysconfig
|
import sysconfig
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
@ -535,7 +534,6 @@ class _TorchDynamoContext:
|
||||||
@functools.wraps(fn)
|
@functools.wraps(fn)
|
||||||
def _fn(*args, **kwargs):
|
def _fn(*args, **kwargs):
|
||||||
prior = set_eval_frame(None)
|
prior = set_eval_frame(None)
|
||||||
prev_rng_state = random.getstate()
|
|
||||||
try:
|
try:
|
||||||
if is_fx_tracing():
|
if is_fx_tracing():
|
||||||
if config.error_on_nested_fx_trace:
|
if config.error_on_nested_fx_trace:
|
||||||
|
|
@ -569,8 +567,6 @@ class _TorchDynamoContext:
|
||||||
_maybe_set_eval_frame(_callback_from_stance(callback))
|
_maybe_set_eval_frame(_callback_from_stance(callback))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# ignore random state updates since beginning of _fn
|
|
||||||
random.setstate(prev_rng_state)
|
|
||||||
return fn(*args, **kwargs)
|
return fn(*args, **kwargs)
|
||||||
except ShortenTraceback as e:
|
except ShortenTraceback as e:
|
||||||
# Failures in the backend likely don't have useful
|
# Failures in the backend likely don't have useful
|
||||||
|
|
@ -579,8 +575,6 @@ class _TorchDynamoContext:
|
||||||
finally:
|
finally:
|
||||||
# Restore the dynamic layer stack depth if necessary.
|
# Restore the dynamic layer stack depth if necessary.
|
||||||
set_eval_frame(None)
|
set_eval_frame(None)
|
||||||
# NB: assumes no random calls made between fn() and here
|
|
||||||
prev_rng_state = random.getstate()
|
|
||||||
torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
|
torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
|
||||||
saved_dynamic_layer_stack_depth
|
saved_dynamic_layer_stack_depth
|
||||||
)
|
)
|
||||||
|
|
@ -590,10 +584,6 @@ class _TorchDynamoContext:
|
||||||
cleanup()
|
cleanup()
|
||||||
finally:
|
finally:
|
||||||
_maybe_set_eval_frame(prior)
|
_maybe_set_eval_frame(prior)
|
||||||
# ignore random state updates:
|
|
||||||
# - since beginning of _fn if fn was not called
|
|
||||||
# - since end of fn if fn was called (even if exn occurs)
|
|
||||||
random.setstate(prev_rng_state)
|
|
||||||
|
|
||||||
# hooks to properly handle inlining
|
# hooks to properly handle inlining
|
||||||
_fn._torchdynamo_inline = fn # type: ignore[attr-defined]
|
_fn._torchdynamo_inline = fn # type: ignore[attr-defined]
|
||||||
|
|
@ -750,22 +740,18 @@ class DisableContext(_TorchDynamoContext):
|
||||||
@functools.wraps(fn)
|
@functools.wraps(fn)
|
||||||
def _fn(*args, **kwargs):
|
def _fn(*args, **kwargs):
|
||||||
prior = set_eval_frame(None)
|
prior = set_eval_frame(None)
|
||||||
prev_rng_state = random.getstate()
|
|
||||||
try:
|
try:
|
||||||
prior_skip_guard_eval_unsafe = set_skip_guard_eval_unsafe(
|
prior_skip_guard_eval_unsafe = set_skip_guard_eval_unsafe(
|
||||||
_is_skip_guard_eval_unsafe_stance()
|
_is_skip_guard_eval_unsafe_stance()
|
||||||
)
|
)
|
||||||
_maybe_set_eval_frame(_callback_from_stance(self.callback))
|
_maybe_set_eval_frame(_callback_from_stance(self.callback))
|
||||||
try:
|
try:
|
||||||
random.setstate(prev_rng_state)
|
|
||||||
return fn(*args, **kwargs)
|
return fn(*args, **kwargs)
|
||||||
finally:
|
finally:
|
||||||
set_eval_frame(None)
|
set_eval_frame(None)
|
||||||
set_skip_guard_eval_unsafe(prior_skip_guard_eval_unsafe)
|
set_skip_guard_eval_unsafe(prior_skip_guard_eval_unsafe)
|
||||||
prev_rng_state = random.getstate()
|
|
||||||
finally:
|
finally:
|
||||||
_maybe_set_eval_frame(prior)
|
_maybe_set_eval_frame(prior)
|
||||||
random.setstate(prev_rng_state)
|
|
||||||
|
|
||||||
_fn._torchdynamo_disable = True # type: ignore[attr-defined]
|
_fn._torchdynamo_disable = True # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -628,11 +628,9 @@ class OutputGraph:
|
||||||
"""
|
"""
|
||||||
global_state = cast(
|
global_state = cast(
|
||||||
dict[str, tuple[Callable[..., Any], bool]],
|
dict[str, tuple[Callable[..., Any], bool]],
|
||||||
(
|
out
|
||||||
out
|
if out is not None
|
||||||
if out is not None
|
else self.tracing_context.global_context.global_state,
|
||||||
else self.tracing_context.global_context.global_state
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO - Consider having a torch level API for torch_function_state. As
|
# TODO - Consider having a torch level API for torch_function_state. As
|
||||||
|
|
|
||||||
|
|
@ -1713,7 +1713,6 @@ class RandomVariable(VariableTracker):
|
||||||
tx.output.side_effects.mutation(self)
|
tx.output.side_effects.mutation(self)
|
||||||
state = self.random.getstate()
|
state = self.random.getstate()
|
||||||
|
|
||||||
# Generate new random object with the same state and call the method
|
|
||||||
def call_random_meth(*args, **kwargs):
|
def call_random_meth(*args, **kwargs):
|
||||||
r = random.Random()
|
r = random.Random()
|
||||||
r.setstate(state)
|
r.setstate(state)
|
||||||
|
|
|
||||||
|
|
@ -678,8 +678,6 @@ def call_random_fn(tx, fn, args, kwargs):
|
||||||
args = [x.as_python_constant() for x in args]
|
args = [x.as_python_constant() for x in args]
|
||||||
kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
|
kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
|
||||||
random_call_index = len(tx.output.random_calls)
|
random_call_index = len(tx.output.random_calls)
|
||||||
# NB: it is probably not important for the example_value to be exactly correct,
|
|
||||||
# we just need the right type
|
|
||||||
example_value = fn(*args, **kwargs)
|
example_value = fn(*args, **kwargs)
|
||||||
source = RandomValueSource(random_call_index)
|
source = RandomValueSource(random_call_index)
|
||||||
tx.output.random_calls.append((fn, args, kwargs))
|
tx.output.random_calls.append((fn, args, kwargs))
|
||||||
|
|
@ -876,12 +874,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||||
and all(k.is_python_constant() for k in args)
|
and all(k.is_python_constant() for k in args)
|
||||||
and all(v.is_python_constant() for v in kwargs.values())
|
and all(v.is_python_constant() for v in kwargs.values())
|
||||||
):
|
):
|
||||||
return call_random_fn(
|
return call_random_fn(tx, self.value, args, kwargs)
|
||||||
tx,
|
|
||||||
self.value,
|
|
||||||
args,
|
|
||||||
kwargs,
|
|
||||||
)
|
|
||||||
elif istype(self.value, types.MethodType):
|
elif istype(self.value, types.MethodType):
|
||||||
func = self.value.__func__
|
func = self.value.__func__
|
||||||
obj = self.value.__self__
|
obj = self.value.__self__
|
||||||
|
|
|
||||||
|
|
@ -9,14 +9,13 @@
|
||||||
#include <torch/csrc/dynamo/debug_macros.h>
|
#include <torch/csrc/dynamo/debug_macros.h>
|
||||||
#include <torch/csrc/dynamo/extra_state.h>
|
#include <torch/csrc/dynamo/extra_state.h>
|
||||||
#include <torch/csrc/dynamo/framelocals_mapping.h>
|
#include <torch/csrc/dynamo/framelocals_mapping.h>
|
||||||
#include <torch/csrc/dynamo/utils.h>
|
|
||||||
#include <torch/csrc/utils/python_compat.h>
|
#include <torch/csrc/utils/python_compat.h>
|
||||||
|
|
||||||
PyObject* guard_error_hook = NULL;
|
PyObject* guard_error_hook = NULL;
|
||||||
const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup";
|
const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup";
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int active_dynamo_threads;
|
int active_dynamo_threads;
|
||||||
} ModuleState;
|
} ModuleState;
|
||||||
|
|
||||||
// static int active_dynamo_threads = 0;
|
// static int active_dynamo_threads = 0;
|
||||||
|
|
@ -296,7 +295,6 @@ static PyObject* dynamo_call_callback(
|
||||||
PyObject* cache_entry_pyobj = CacheEntry_to_obj(cache_entry);
|
PyObject* cache_entry_pyobj = CacheEntry_to_obj(cache_entry);
|
||||||
PyObject* res = PyObject_CallFunction(
|
PyObject* res = PyObject_CallFunction(
|
||||||
callable, "OOO", frame, cache_entry_pyobj, frame_state);
|
callable, "OOO", frame, cache_entry_pyobj, frame_state);
|
||||||
|
|
||||||
Py_DECREF(frame);
|
Py_DECREF(frame);
|
||||||
Py_DECREF(cache_entry_pyobj);
|
Py_DECREF(cache_entry_pyobj);
|
||||||
return res;
|
return res;
|
||||||
|
|
@ -539,17 +537,6 @@ static PyObject* dynamo_eval_custom_code(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
static PyObject* random_state = NULL;
|
|
||||||
|
|
||||||
static void save_random_state() {
|
|
||||||
Py_XSETREF(random_state, random_getstate(random_module()));
|
|
||||||
}
|
|
||||||
|
|
||||||
static void restore_random_state() {
|
|
||||||
DEBUG_NULL_CHECK(random_state);
|
|
||||||
random_setstate(random_module(), random_state);
|
|
||||||
}
|
|
||||||
|
|
||||||
static PyObject* dynamo__custom_eval_frame_shim(
|
static PyObject* dynamo__custom_eval_frame_shim(
|
||||||
PyThreadState* tstate,
|
PyThreadState* tstate,
|
||||||
THP_EVAL_API_FRAME_OBJECT* frame,
|
THP_EVAL_API_FRAME_OBJECT* frame,
|
||||||
|
|
@ -658,11 +645,6 @@ static PyObject* dynamo__custom_eval_frame(
|
||||||
// in the shim.
|
// in the shim.
|
||||||
eval_frame_callback_set(Py_None);
|
eval_frame_callback_set(Py_None);
|
||||||
|
|
||||||
// Preserve random state - restore before calling
|
|
||||||
// dynamo_eval_[frame_default]/[custom_code]!
|
|
||||||
// lookup or call_callback may burn RNG.
|
|
||||||
save_random_state();
|
|
||||||
|
|
||||||
// A callback of Py_False indicates "run only" mode, the cache is checked, but
|
// A callback of Py_False indicates "run only" mode, the cache is checked, but
|
||||||
// we never compile.
|
// we never compile.
|
||||||
// Also, if extra is marked as "cache_limit_hit", run in "run only" mode
|
// Also, if extra is marked as "cache_limit_hit", run in "run only" mode
|
||||||
|
|
@ -703,7 +685,6 @@ static PyObject* dynamo__custom_eval_frame(
|
||||||
DEBUG_TRACE("skip recursive %s", get_frame_name(frame));
|
DEBUG_TRACE("skip recursive %s", get_frame_name(frame));
|
||||||
eval_frame_callback_set(Py_None);
|
eval_frame_callback_set(Py_None);
|
||||||
}
|
}
|
||||||
restore_random_state();
|
|
||||||
PyObject* ret = dynamo_eval_frame_default(tstate, frame, throw_flag);
|
PyObject* ret = dynamo_eval_frame_default(tstate, frame, throw_flag);
|
||||||
if (extra_state_cache_limit_hit(extra)) {
|
if (extra_state_cache_limit_hit(extra)) {
|
||||||
eval_frame_callback_set(callback);
|
eval_frame_callback_set(callback);
|
||||||
|
|
@ -716,7 +697,6 @@ static PyObject* dynamo__custom_eval_frame(
|
||||||
// Re-enable custom behavior
|
// Re-enable custom behavior
|
||||||
eval_frame_callback_set(callback);
|
eval_frame_callback_set(callback);
|
||||||
*should_clear_frame = 1;
|
*should_clear_frame = 1;
|
||||||
restore_random_state();
|
|
||||||
return dynamo_eval_custom_code(
|
return dynamo_eval_custom_code(
|
||||||
tstate, frame, cached_code, trace_annotation, throw_flag);
|
tstate, frame, cached_code, trace_annotation, throw_flag);
|
||||||
}
|
}
|
||||||
|
|
@ -749,7 +729,6 @@ static PyObject* dynamo__custom_eval_frame(
|
||||||
eval_frame_callback_set(callback);
|
eval_frame_callback_set(callback);
|
||||||
*should_clear_frame = 1;
|
*should_clear_frame = 1;
|
||||||
framelocals_mapping_free(locals);
|
framelocals_mapping_free(locals);
|
||||||
restore_random_state();
|
|
||||||
return dynamo_eval_custom_code(
|
return dynamo_eval_custom_code(
|
||||||
tstate, frame, cached_code, trace_annotation, throw_flag);
|
tstate, frame, cached_code, trace_annotation, throw_flag);
|
||||||
}
|
}
|
||||||
|
|
@ -783,7 +762,6 @@ static PyObject* dynamo__custom_eval_frame(
|
||||||
// code.
|
// code.
|
||||||
DEBUG_TRACE("create skip recursive %s", get_frame_name(frame));
|
DEBUG_TRACE("create skip recursive %s", get_frame_name(frame));
|
||||||
set_extra_state(F_CODE(frame), SKIP_CODE_RECURSIVE);
|
set_extra_state(F_CODE(frame), SKIP_CODE_RECURSIVE);
|
||||||
restore_random_state();
|
|
||||||
PyObject* r = dynamo_eval_frame_default(tstate, frame, throw_flag);
|
PyObject* r = dynamo_eval_frame_default(tstate, frame, throw_flag);
|
||||||
// Re-enable custom behavior
|
// Re-enable custom behavior
|
||||||
eval_frame_callback_set(callback);
|
eval_frame_callback_set(callback);
|
||||||
|
|
@ -792,7 +770,6 @@ static PyObject* dynamo__custom_eval_frame(
|
||||||
// Dynamo returned cache_limit_hit_flag, so we should recursively skip code.
|
// Dynamo returned cache_limit_hit_flag, so we should recursively skip code.
|
||||||
DEBUG_TRACE("create cache limit hit %s", get_frame_name(frame));
|
DEBUG_TRACE("create cache limit hit %s", get_frame_name(frame));
|
||||||
set_extra_state_cache_limit_hit(extra, true);
|
set_extra_state_cache_limit_hit(extra, true);
|
||||||
restore_random_state();
|
|
||||||
PyObject* r = dynamo_eval_frame_default(tstate, frame, throw_flag);
|
PyObject* r = dynamo_eval_frame_default(tstate, frame, throw_flag);
|
||||||
// Re-enable custom behavior
|
// Re-enable custom behavior
|
||||||
eval_frame_callback_set(callback);
|
eval_frame_callback_set(callback);
|
||||||
|
|
@ -814,7 +791,6 @@ static PyObject* dynamo__custom_eval_frame(
|
||||||
// Re-enable custom behavior
|
// Re-enable custom behavior
|
||||||
eval_frame_callback_set(callback);
|
eval_frame_callback_set(callback);
|
||||||
*should_clear_frame = 1;
|
*should_clear_frame = 1;
|
||||||
restore_random_state();
|
|
||||||
return dynamo_eval_custom_code(
|
return dynamo_eval_custom_code(
|
||||||
tstate,
|
tstate,
|
||||||
frame,
|
frame,
|
||||||
|
|
@ -827,7 +803,6 @@ static PyObject* dynamo__custom_eval_frame(
|
||||||
set_extra_state(F_CODE(frame), SKIP_CODE);
|
set_extra_state(F_CODE(frame), SKIP_CODE);
|
||||||
// Re-enable custom behavior
|
// Re-enable custom behavior
|
||||||
eval_frame_callback_set(callback);
|
eval_frame_callback_set(callback);
|
||||||
restore_random_state();
|
|
||||||
return dynamo_eval_frame_default(tstate, frame, throw_flag);
|
return dynamo_eval_frame_default(tstate, frame, throw_flag);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -856,9 +831,7 @@ static PyTypeObject THPPyInterpreterFrameType = {
|
||||||
|
|
||||||
#endif // !(IS_PYTHON_3_14_PLUS)
|
#endif // !(IS_PYTHON_3_14_PLUS)
|
||||||
|
|
||||||
static PyObject* increment_working_threads(
|
static PyObject* increment_working_threads(PyThreadState* tstate, PyObject* module) {
|
||||||
PyThreadState* tstate,
|
|
||||||
PyObject* module) {
|
|
||||||
ModuleState* state = PyModule_GetState(module);
|
ModuleState* state = PyModule_GetState(module);
|
||||||
|
|
||||||
if (state != NULL) {
|
if (state != NULL) {
|
||||||
|
|
@ -871,13 +844,11 @@ static PyObject* increment_working_threads(
|
||||||
Py_RETURN_NONE;
|
Py_RETURN_NONE;
|
||||||
}
|
}
|
||||||
|
|
||||||
static PyObject* decrement_working_threads(
|
static PyObject* decrement_working_threads(PyThreadState* tstate, PyObject* module) {
|
||||||
PyThreadState* tstate,
|
|
||||||
PyObject* module) {
|
|
||||||
ModuleState* state = PyModule_GetState(module);
|
ModuleState* state = PyModule_GetState(module);
|
||||||
|
|
||||||
if (state != NULL) {
|
if (state != NULL) {
|
||||||
if (state->active_dynamo_threads > 0) {
|
if (state->active_dynamo_threads > 0) {
|
||||||
state->active_dynamo_threads = state->active_dynamo_threads - 1;
|
state->active_dynamo_threads = state->active_dynamo_threads - 1;
|
||||||
if (state->active_dynamo_threads == 0) {
|
if (state->active_dynamo_threads == 0) {
|
||||||
enable_eval_frame_default(tstate);
|
enable_eval_frame_default(tstate);
|
||||||
|
|
@ -888,10 +859,7 @@ static PyObject* decrement_working_threads(
|
||||||
Py_RETURN_NONE;
|
Py_RETURN_NONE;
|
||||||
}
|
}
|
||||||
|
|
||||||
static PyObject* set_eval_frame(
|
static PyObject* set_eval_frame(PyObject* new_callback, PyThreadState* tstate, PyObject* module) {
|
||||||
PyObject* new_callback,
|
|
||||||
PyThreadState* tstate,
|
|
||||||
PyObject* module) {
|
|
||||||
// Change the eval frame callback and return the old one
|
// Change the eval frame callback and return the old one
|
||||||
// - None: disables TorchDynamo
|
// - None: disables TorchDynamo
|
||||||
// - False: run-only mode (reuse existing compiles)
|
// - False: run-only mode (reuse existing compiles)
|
||||||
|
|
@ -1014,13 +982,13 @@ static PyObject* raise_sigtrap(PyObject* dummy, PyObject* obj) {
|
||||||
Py_RETURN_NONE;
|
Py_RETURN_NONE;
|
||||||
}
|
}
|
||||||
|
|
||||||
static int clear_state(PyObject* module) {
|
static int clear_state(PyObject *module) {
|
||||||
ModuleState* state = PyModule_GetState(module);
|
ModuleState* state = PyModule_GetState(module);
|
||||||
if (state) {
|
if (state) {
|
||||||
state->active_dynamo_threads = 0;
|
state->active_dynamo_threads = 0;
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
static PyMethodDef _methods[] = {
|
static PyMethodDef _methods[] = {
|
||||||
|
|
|
||||||
|
|
@ -2,29 +2,6 @@
|
||||||
|
|
||||||
namespace torch::dynamo {
|
namespace torch::dynamo {
|
||||||
|
|
||||||
// random utilities for C dynamo
|
|
||||||
|
|
||||||
// random module reference
|
|
||||||
py::object _random{py::none()};
|
|
||||||
|
|
||||||
PyObject* random_module() {
|
|
||||||
if (_random.is_none()) {
|
|
||||||
_random = py::module_::import("random");
|
|
||||||
}
|
|
||||||
return _random.ptr();
|
|
||||||
}
|
|
||||||
|
|
||||||
PyObject* random_getstate(PyObject* rng) {
|
|
||||||
py::handle rng_h(rng);
|
|
||||||
py::object state = rng_h.attr("getstate")();
|
|
||||||
return state.release().ptr();
|
|
||||||
}
|
|
||||||
|
|
||||||
void random_setstate(PyObject* rng, PyObject* state) {
|
|
||||||
py::handle rng_h(rng), state_h(state);
|
|
||||||
rng_h.attr("setstate")(state_h);
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::array<PyMethodDef, 1> _methods = {{
|
static std::array<PyMethodDef, 1> _methods = {{
|
||||||
{nullptr,
|
{nullptr,
|
||||||
nullptr,
|
nullptr,
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,10 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
|
|
||||||
#include <torch/csrc/python_headers.h>
|
#include <torch/csrc/python_headers.h>
|
||||||
// C2039 MSVC
|
// C2039 MSVC
|
||||||
#include <pybind11/complex.h>
|
#include <pybind11/complex.h>
|
||||||
#include <torch/csrc/utils/pybind.h>
|
#include <torch/csrc/utils/pybind.h>
|
||||||
|
|
||||||
#endif // __cplusplus
|
|
||||||
|
|
||||||
#include <Python.h>
|
#include <Python.h>
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
|
|
||||||
// The visibility attribute is to avoid a warning about storing a field in the
|
// The visibility attribute is to avoid a warning about storing a field in the
|
||||||
// struct that has a different visibility (from pybind) than the struct.
|
// struct that has a different visibility (from pybind) than the struct.
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
|
|
@ -22,32 +14,5 @@
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace torch::dynamo {
|
namespace torch::dynamo {
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
|
|
||||||
#endif // __cplusplus
|
|
||||||
|
|
||||||
// reference to random module
|
|
||||||
// returns borrowed reference
|
|
||||||
PyObject* random_module();
|
|
||||||
|
|
||||||
// rng.getstate()
|
|
||||||
// rng can be random module or random.Random object
|
|
||||||
// rng: borrowed reference
|
|
||||||
// returns new reference
|
|
||||||
PyObject* random_getstate(PyObject* rng);
|
|
||||||
|
|
||||||
// rng.setstate(state)
|
|
||||||
// rng can be random module or random.Random object
|
|
||||||
// rng, state: borrowed references
|
|
||||||
// no return value
|
|
||||||
void random_setstate(PyObject* rng, PyObject* state);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
|
|
||||||
} // extern "C"
|
|
||||||
|
|
||||||
PyObject* torch_c_dynamo_utils_init();
|
PyObject* torch_c_dynamo_utils_init();
|
||||||
} // namespace torch::dynamo
|
} // namespace torch::dynamo
|
||||||
|
|
||||||
#endif // __cplusplus
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue