mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
automatic dynamic unspecialize float (#141647)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141647 Approved by: https://github.com/ezyang
This commit is contained in:
parent
e29dabbd71
commit
2f72635a5c
9 changed files with 95 additions and 29 deletions
|
|
@ -1398,7 +1398,7 @@ utils_device.CURRENT_DEVICE == None""".split(
|
|||
cfg2.val = 2.0
|
||||
v = opt_fn(v, cfg2) # 7
|
||||
self.assertEqual(v[0], 7)
|
||||
self.assertEqual(cnts.op_count, 8)
|
||||
self.assertEqual(cnts.op_count, 9)
|
||||
|
||||
def test_config_getattr_default(self):
|
||||
class Cfg:
|
||||
|
|
@ -3747,8 +3747,18 @@ utils_device.CURRENT_DEVICE == None""".split(
|
|||
result1, result2, _ = opt_fn()
|
||||
self.assertAlmostEqual(orig1 + 1 * i, result1)
|
||||
self.assertTrue(torch.allclose(orig2 + 10 * i, result2))
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
self.assertEqual(cnts.op_count, 3)
|
||||
if i == 1:
|
||||
# No automatic dynamic
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
self.assertEqual(cnts.op_count, 3)
|
||||
elif i == 2:
|
||||
# Automatic dynamic float arguments kicked in
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
self.assertEqual(cnts.op_count, 6)
|
||||
else:
|
||||
# No more recompiles
|
||||
self.assertEqual(cnts.frame_count, 0)
|
||||
self.assertEqual(cnts.op_count, 0)
|
||||
cnts.clear()
|
||||
|
||||
def test_closure_with_mutation_and_graph_break(self):
|
||||
|
|
|
|||
|
|
@ -641,18 +641,19 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
|
|||
cf = torch.compile(backend=cnts, fullgraph=True)(f)
|
||||
|
||||
x = torch.randn(3)
|
||||
self.assertEqual(f(x, 3.0), cf(x, 3.0))
|
||||
self.assertEqual(f(x, 2.0), cf(x, 2.0))
|
||||
self.assertEqual(f(x, 3.0), cf(x, 3.0)) # automatic dynamic kicks in here
|
||||
self.assertEqual(f(x, 4.0), cf(x, 4.0))
|
||||
self.assertExpectedInline(cnts.frame_count, """1""") # no recompile
|
||||
self.assertExpectedInline(cnts.frame_count, """2""") # no recompile
|
||||
self.assertEqual(f(x, 5.0), cf(x, 5.0))
|
||||
self.assertExpectedInline(cnts.frame_count, """2""") # guard worked
|
||||
self.assertExpectedInline(cnts.frame_count, """3""") # guard worked
|
||||
self.assertEqual(f(x, math.nan), cf(x, math.nan))
|
||||
self.assertExpectedInline(cnts.frame_count, """3""") # nan always recompiles
|
||||
self.assertExpectedInline(cnts.frame_count, """4""") # nan always recompiles
|
||||
|
||||
@torch._dynamo.config.patch(specialize_float=False, capture_scalar_outputs=True)
|
||||
def test_unspecialized_float_multiply_precision(self):
|
||||
dtypes = [torch.bfloat16, torch.float16, torch.float32, torch.float64]
|
||||
for dtype in dtypes:
|
||||
for i, dtype in enumerate(dtypes):
|
||||
|
||||
def fn(x, y):
|
||||
return x * y
|
||||
|
|
@ -662,10 +663,19 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
|
|||
x = torch.randn(5, dtype=dtype, requires_grad=True)
|
||||
y1 = 1.00048828125
|
||||
y2 = 1.00048828126
|
||||
y3 = 1.00048828127
|
||||
|
||||
self.assertEqual(fn_opt(x, y1), fn(x, y1))
|
||||
self.assertEqual(fn_opt(x, y2), fn(x, y2))
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
self.assertEqual(fn_opt(x, y3), fn(x, y3))
|
||||
if i == 0:
|
||||
# This is a somewhat quirky part of automatic dynamic:
|
||||
# since it just uses source name + tx.f_code as the key,
|
||||
# subsequent recompilations will actually reuse the automatic
|
||||
# dynamic choices.
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
else:
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
|
||||
@torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=False)
|
||||
def test_unspec_float_input_f64(self):
|
||||
|
|
|
|||
|
|
@ -270,7 +270,7 @@ class TestDynamoTimed(TestCase):
|
|||
'runtime_cudagraphify_time_us': None,
|
||||
'runtime_triton_autotune_time_us': None,
|
||||
'shape_env_guard_count': 0,
|
||||
'specialize_float': True,
|
||||
'specialize_float': False,
|
||||
'start_time': 0.0001,
|
||||
'start_time_us': 100,
|
||||
'structured_logging_overhead_s': 0.0,
|
||||
|
|
|
|||
|
|
@ -607,6 +607,8 @@ if HAS_CUDA and not TEST_WITH_ASAN:
|
|||
@torch._functorch.config.patch("enable_autograd_cache", True)
|
||||
@torch._inductor.config.patch("fx_graph_cache", True)
|
||||
@torch._inductor.config.patch("fx_graph_remote_cache", False)
|
||||
# Currently fx graph cache is turned off for specialize_float=False
|
||||
@torch._dynamo.config.patch("specialize_float", True)
|
||||
def test_cache_hit_forward_miss_backward(self):
|
||||
# Test that we don't cache cudagraphs, skipping cudagraphs on backward on a cache miss
|
||||
|
||||
|
|
@ -661,6 +663,8 @@ if HAS_CUDA and not TEST_WITH_ASAN:
|
|||
@torch._functorch.config.patch("enable_autograd_cache", True)
|
||||
@torch._inductor.config.patch("fx_graph_cache", True)
|
||||
@torch._inductor.config.patch("fx_graph_remote_cache", False)
|
||||
# Currently fx graph cache is turned off for specialize_float=False
|
||||
@torch._dynamo.config.patch("specialize_float", True)
|
||||
def test_backward_gets_cached_cudagraphs(self):
|
||||
# We pass cpu tensors to foo and save that into the cache
|
||||
# On a subsequent run in a new process, cudagraphs should be
|
||||
|
|
@ -705,6 +709,8 @@ if HAS_CUDA and not TEST_WITH_ASAN:
|
|||
@torch._functorch.config.patch("enable_autograd_cache", True)
|
||||
@torch._inductor.config.patch("fx_graph_cache", True)
|
||||
@torch._inductor.config.patch("fx_graph_remote_cache", False)
|
||||
# Currently fx graph cache is turned off for specialize_float=False
|
||||
@torch._dynamo.config.patch("specialize_float", True)
|
||||
def test_cached_forward_backward(self):
|
||||
counters.clear()
|
||||
AOTAutogradCache.clear()
|
||||
|
|
|
|||
|
|
@ -969,7 +969,7 @@ class TestInductorDynamic(TestCase):
|
|||
"divide": operator.truediv,
|
||||
}
|
||||
|
||||
for name, op in operations.items():
|
||||
for i, (name, op) in enumerate(operations.items()):
|
||||
with self.subTest(operation=name):
|
||||
|
||||
def fn(x, y):
|
||||
|
|
@ -981,7 +981,14 @@ class TestInductorDynamic(TestCase):
|
|||
x = torch.arange(3)
|
||||
self.assertEqual(fn(x, 2.0), fn_opt(x, 2.0))
|
||||
self.assertEqual(fn(x, 3.0), fn_opt(x, 3.0))
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
self.assertEqual(fn(x, 4.0), fn_opt(x, 4.0))
|
||||
if i == 0:
|
||||
# Automatic dynamic state persists across
|
||||
# compiles so only the first compile
|
||||
# goes through the automatic dynamic step.
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
else:
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
|
||||
@torch._dynamo.config.patch(specialize_float=False)
|
||||
def test_unspecialized_float_fallback_specialization(self):
|
||||
|
|
@ -1005,8 +1012,25 @@ class TestInductorDynamic(TestCase):
|
|||
self.assertEqual(fn(x, 2.0, z), fn_opt(x, 2.0, z))
|
||||
self.assertEqual(fn(x, 3.0, z), fn_opt(x, 3.0, z))
|
||||
self.assertEqual(fn(x, 4.0, z), fn_opt(x, 4.0, z))
|
||||
# We expect frame count to be 2 since we will have
|
||||
# one sledgehammer restart.
|
||||
# Automatic dynamic float arguments
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
|
||||
# Pin specialize_float=False explicitly, like the neighbouring tests, so the
# expected frame count does not depend on the config default (the default is
# True under fbcode and False otherwise).
@torch._dynamo.config.patch(specialize_float=False)
def test_unspecialized_float_softshrink(self):
    """Compile softshrink with a Python float ``lambd`` and check recompiles.

    Calls the compiled function with three distinct float values
    (2.0, 3.0, 4.0), compares each result against eager, and expects
    exactly two compiled frames in total.
    """
    # This test is particularly interesting since it exercises
    # both standard operator replacements, i.e. torch.ops.aten.mul.Tensor,
    # as well as comparison replacements, i.e. torch.ops.aten.ge.Scalar.
    def fn(x, y):
        return torch._C._nn.softshrink(x, lambd=y)

    cnt = CompileCounterWithBackend("inductor")
    fn_opt = torch._dynamo.optimize(cnt)(fn)
    x = torch.randn(5, 5)

    self.assertEqual(fn(x, 2.0), fn_opt(x, 2.0))
    self.assertEqual(fn(x, 3.0), fn_opt(x, 3.0))
    self.assertEqual(fn(x, 4.0), fn_opt(x, 4.0))
    # Presumably one initial compile plus one automatic-dynamic recompile
    # when the float argument changes; the third value reuses that graph.
    self.assertEqual(cnt.frame_count, 2)
|
||||
|
||||
@torch._dynamo.config.patch(specialize_float=False)
|
||||
|
|
@ -1021,8 +1045,7 @@ class TestInductorDynamic(TestCase):
|
|||
self.assertEqual(fn(2.0, y), fn_opt(2.0, y))
|
||||
self.assertEqual(fn(3.0, y), fn_opt(3.0, y))
|
||||
self.assertEqual(fn(4.0, y), fn_opt(4.0, y))
|
||||
# We expect frame count to be N + 1 since we will have
|
||||
# one sledgehammer restart for the first compile.
|
||||
# N + 1 for automatic dynamic float arguments
|
||||
self.assertEqual(cnt.frame_count, 4)
|
||||
|
||||
def test_sort_dynamic_shape_with_check(self, device):
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ specialize_int = False
|
|||
# Whether or not to specialize on float inputs. Dynamo will always promote
|
||||
# float inputs into Tensor inputs, but at the moment, backends inconsistently
|
||||
# support codegen on float (this is to be fixed).
|
||||
specialize_float = True
|
||||
specialize_float = True if is_fbcode() else False
|
||||
|
||||
# legacy config, does nothing now!
|
||||
dynamic_shapes = True
|
||||
|
|
|
|||
|
|
@ -1904,6 +1904,13 @@ class VariableBuilder:
|
|||
if self.name in self.tx.output.unspec_variable_map:
|
||||
return self.tx.output.unspec_variable_map[self.name]
|
||||
|
||||
frame_state_entry = process_automatic_dynamic(
|
||||
self.tx,
|
||||
self.source.name(),
|
||||
FrameStateSizeEntry.make_scalar(value),
|
||||
is_unspecialized_nn_module=self.source.guard_source().is_unspecialized_nn_module(),
|
||||
)
|
||||
|
||||
# NB: we specialize on nan input, because our guard modeling in
|
||||
# ShapeEnv cannot deal with nan
|
||||
if (
|
||||
|
|
@ -1918,6 +1925,7 @@ class VariableBuilder:
|
|||
# python test/inductor/test_compiled_optimizers.py CompiledOptimizerTests.test_rmsprop_weight_decay_maximize_capturable_cuda # noqa: B950
|
||||
or torch._inductor.config.triton.cudagraphs
|
||||
or justknobs_check("pytorch/compiler:unspecialize_float_killswitch", False)
|
||||
or frame_state_entry.scalar is not auto_dynamic
|
||||
):
|
||||
self.install_guards(GuardBuilder.CONSTANT_MATCH)
|
||||
return ConstantVariable.create(value=value, source=self.source)
|
||||
|
|
|
|||
|
|
@ -205,11 +205,6 @@ def check_cacheable(gm: torch.fx.GraphModule):
|
|||
"Cannot cache a graph with compiled autograd enabled"
|
||||
)
|
||||
|
||||
if not torch._dynamo.config.specialize_float:
|
||||
raise BypassAOTAutogradCache(
|
||||
"Cannot cache a graph with specialize_float disabled"
|
||||
)
|
||||
|
||||
if not (
|
||||
torch._inductor.config.fx_graph_cache or should_use_remote_fx_graph_cache()
|
||||
):
|
||||
|
|
|
|||
|
|
@ -73,10 +73,16 @@ graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
|
|||
|
||||
|
||||
SUPPORTED_OPS = {
|
||||
torch.ops.aten.mul.Tensor,
|
||||
torch.ops.aten.add.Tensor,
|
||||
torch.ops.aten.sub.Tensor,
|
||||
torch.ops.aten.div.Tensor,
|
||||
torch.ops.aten.mul.Tensor: torch.ops.aten.mul.Tensor,
|
||||
torch.ops.aten.add.Tensor: torch.ops.aten.add.Tensor,
|
||||
torch.ops.aten.sub.Tensor: torch.ops.aten.sub.Tensor,
|
||||
torch.ops.aten.div.Tensor: torch.ops.aten.div.Tensor,
|
||||
torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor,
|
||||
torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor,
|
||||
torch.ops.aten.ge.Scalar: torch.ops.aten.ge.Tensor,
|
||||
torch.ops.aten.le.Scalar: torch.ops.aten.le.Tensor,
|
||||
torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor,
|
||||
torch.ops.aten.ne.Scalar: torch.ops.aten.ne.Tensor,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -232,7 +238,9 @@ def tensorify_python_scalars(
|
|||
should_restart = True
|
||||
|
||||
# Look for functions to convert
|
||||
if node.op == "call_function" and node.target in SUPPORTED_OPS:
|
||||
if node.op == "call_function" and (
|
||||
replacement_op := SUPPORTED_OPS.get(node.target)
|
||||
):
|
||||
args: List[Any] = []
|
||||
transform = False
|
||||
compute_dtype = get_computation_dtype(node.meta["val"].dtype)
|
||||
|
|
@ -253,7 +261,13 @@ def tensorify_python_scalars(
|
|||
# We use _expr instead of expr b/c we want the symbol not the replacement
|
||||
tensorified_symbols.add(a.meta["val"].node._expr)
|
||||
|
||||
if proxy.node.meta["val"].dtype != compute_dtype:
|
||||
# The upcasting is irrelevant when the compute dtype is bool. This happens
|
||||
# in cases where we are tensorifying a comparison operator such as
|
||||
# torch.ops.aten.gt.Tensor
|
||||
if (
|
||||
compute_dtype != torch.bool
|
||||
and proxy.node.meta["val"].dtype != compute_dtype
|
||||
):
|
||||
proxy = torch.ops.prims.convert_element_type.default(
|
||||
proxy, compute_dtype
|
||||
)
|
||||
|
|
@ -265,7 +279,7 @@ def tensorify_python_scalars(
|
|||
args.append(a)
|
||||
|
||||
if transform:
|
||||
replacement_proxy = node.target(*args)
|
||||
replacement_proxy = replacement_op(*args)
|
||||
|
||||
if compute_dtype != node.meta["val"].dtype:
|
||||
replacement_proxy = (
|
||||
|
|
|
|||
Loading…
Reference in a new issue