automatic dynamic unspecialize float (#141647)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141647
Approved by: https://github.com/ezyang
Authored by Bob Ren on 2024-11-29 10:53:37 -08:00; committed by PyTorch MergeBot
parent e29dabbd71
commit 2f72635a5c
9 changed files with 95 additions and 29 deletions
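
In short, this extends the "automatic dynamic" treatment that already existed for integer sizes to Python float inputs: the first compile specializes on the observed float value; a later call with a different value triggers one recompile in which the float is promoted to a dynamic tensor input; further distinct values then hit the same graph. The test updates below all follow that one-extra-recompile pattern. A minimal sketch of the user-visible behavior, assuming the new OSS default specialize_float=False from this commit (CompileCounter is a real test backend; the counts mirror the tests below):

    import torch
    import torch._dynamo.testing

    cnts = torch._dynamo.testing.CompileCounter()

    @torch.compile(backend=cnts)
    def f(x, y):
        return x * y

    x = torch.randn(3)
    f(x, 2.0)  # compile 1: specialized on y == 2.0
    f(x, 3.0)  # compile 2: automatic dynamic kicks in, y goes dynamic
    f(x, 4.0)  # no recompile: the dynamic graph covers any float
    assert cnts.frame_count == 2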

View file

@@ -1398,7 +1398,7 @@ utils_device.CURRENT_DEVICE == None""".split(
         cfg2.val = 2.0
         v = opt_fn(v, cfg2)  # 7
         self.assertEqual(v[0], 7)
-        self.assertEqual(cnts.op_count, 8)
+        self.assertEqual(cnts.op_count, 9)

     def test_config_getattr_default(self):
         class Cfg:
@@ -3747,8 +3747,18 @@ utils_device.CURRENT_DEVICE == None""".split(
             result1, result2, _ = opt_fn()
             self.assertAlmostEqual(orig1 + 1 * i, result1)
             self.assertTrue(torch.allclose(orig2 + 10 * i, result2))
-            self.assertEqual(cnts.frame_count, 1)
-            self.assertEqual(cnts.op_count, 3)
+            if i == 1:
+                # No automatic dynamic
+                self.assertEqual(cnts.frame_count, 1)
+                self.assertEqual(cnts.op_count, 3)
+            elif i == 2:
+                # Automatic dynamic float arguments kicked in
+                self.assertEqual(cnts.frame_count, 1)
+                self.assertEqual(cnts.op_count, 6)
+            else:
+                # No more recompiles
+                self.assertEqual(cnts.frame_count, 0)
+                self.assertEqual(cnts.op_count, 0)
             cnts.clear()

     def test_closure_with_mutation_and_graph_break(self):

View file

@@ -641,18 +641,19 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
         cf = torch.compile(backend=cnts, fullgraph=True)(f)
         x = torch.randn(3)
         self.assertEqual(f(x, 3.0), cf(x, 3.0))
         self.assertEqual(f(x, 2.0), cf(x, 2.0))
+        self.assertEqual(f(x, 3.0), cf(x, 3.0))  # automatic dynamic kicks in here
         self.assertEqual(f(x, 4.0), cf(x, 4.0))
-        self.assertExpectedInline(cnts.frame_count, """1""")  # no recompile
+        self.assertExpectedInline(cnts.frame_count, """2""")  # no recompile
         self.assertEqual(f(x, 5.0), cf(x, 5.0))
-        self.assertExpectedInline(cnts.frame_count, """2""")  # guard worked
+        self.assertExpectedInline(cnts.frame_count, """3""")  # guard worked
         self.assertEqual(f(x, math.nan), cf(x, math.nan))
-        self.assertExpectedInline(cnts.frame_count, """3""")  # nan always recompiles
+        self.assertExpectedInline(cnts.frame_count, """4""")  # nan always recompiles

     @torch._dynamo.config.patch(specialize_float=False, capture_scalar_outputs=True)
     def test_unspecialized_float_multiply_precision(self):
         dtypes = [torch.bfloat16, torch.float16, torch.float32, torch.float64]
-        for dtype in dtypes:
+        for i, dtype in enumerate(dtypes):

             def fn(x, y):
                 return x * y
@@ -662,10 +663,19 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
             x = torch.randn(5, dtype=dtype, requires_grad=True)
             y1 = 1.00048828125
             y2 = 1.00048828126
+            y3 = 1.00048828127

             self.assertEqual(fn_opt(x, y1), fn(x, y1))
             self.assertEqual(fn_opt(x, y2), fn(x, y2))
-            self.assertEqual(cnt.frame_count, 1)
+            self.assertEqual(fn_opt(x, y3), fn(x, y3))
+            if i == 0:
+                # This is a quirky part of automatic dynamic: since it just
+                # uses source name + tx.f_code as the key, subsequent
+                # recompilations will actually reuse the automatic dynamic
+                # choices.
+                self.assertEqual(cnt.frame_count, 2)
+            else:
+                self.assertEqual(cnt.frame_count, 1)

     @torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=False)
     def test_unspec_float_input_f64(self):
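
The "quirky" comment above deserves unpacking: automatic-dynamic decisions for scalars are keyed only by source name plus the enclosing code object, so the state outlives each iteration of the dtype loop. A minimal sketch of that bookkeeping with invented names (the real entry point, process_automatic_dynamic, appears in the variable-builder hunk further down):

    frame_state = {}  # (source_name, f_code) -> concrete value or "dynamic"

    def mark_scalar(source_name, f_code, value):
        key = (source_name, f_code)
        prev = frame_state.get(key)
        if prev is None:
            frame_state[key] = value       # first sighting: specialize on it
        elif prev != "dynamic" and prev != value:
            frame_state[key] = "dynamic"   # second distinct value: go dynamic
        return frame_state[key] == "dynamic"

    assert mark_scalar("L['y']", "fn", 1.00048828125) is False  # specialize
    assert mark_scalar("L['y']", "fn", 1.00048828126) is True   # now dynamic
    assert mark_scalar("L['y']", "fn", 1.00048828127) is True   # stays dynamic

This is why only the first dtype iteration pays for two compiles: later iterations find the entry already dynamic and compile once.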

View file

@@ -270,7 +270,7 @@ class TestDynamoTimed(TestCase):
             'runtime_cudagraphify_time_us': None,
             'runtime_triton_autotune_time_us': None,
             'shape_env_guard_count': 0,
-            'specialize_float': True,
+            'specialize_float': False,
             'start_time': 0.0001,
             'start_time_us': 100,
             'structured_logging_overhead_s': 0.0,

View file

@@ -607,6 +607,8 @@ if HAS_CUDA and not TEST_WITH_ASAN:
         @torch._functorch.config.patch("enable_autograd_cache", True)
         @torch._inductor.config.patch("fx_graph_cache", True)
         @torch._inductor.config.patch("fx_graph_remote_cache", False)
+        # Currently fx graph cache is turned off for specialize_float=False
+        @torch._dynamo.config.patch("specialize_float", True)
         def test_cache_hit_forward_miss_backward(self):
             # Test that we don't cache cudagraphs, skipping cudagraphs on backward on a cache miss

@@ -661,6 +663,8 @@
         @torch._functorch.config.patch("enable_autograd_cache", True)
         @torch._inductor.config.patch("fx_graph_cache", True)
         @torch._inductor.config.patch("fx_graph_remote_cache", False)
+        # Currently fx graph cache is turned off for specialize_float=False
+        @torch._dynamo.config.patch("specialize_float", True)
         def test_backward_gets_cached_cudagraphs(self):
             # We pass cpu tensors to foo and save that into the cache
             # On a subsequent run in a new process, cudagraphs should be

@@ -705,6 +709,8 @@
         @torch._functorch.config.patch("enable_autograd_cache", True)
         @torch._inductor.config.patch("fx_graph_cache", True)
         @torch._inductor.config.patch("fx_graph_remote_cache", False)
+        # Currently fx graph cache is turned off for specialize_float=False
+        @torch._dynamo.config.patch("specialize_float", True)
         def test_cached_forward_backward(self):
             counters.clear()
             AOTAutogradCache.clear()

View file

@@ -969,7 +969,7 @@ class TestInductorDynamic(TestCase):
             "divide": operator.truediv,
         }
-        for name, op in operations.items():
+        for i, (name, op) in enumerate(operations.items()):
             with self.subTest(operation=name):

                 def fn(x, y):

@@ -981,7 +981,14 @@ class TestInductorDynamic(TestCase):
                 x = torch.arange(3)
                 self.assertEqual(fn(x, 2.0), fn_opt(x, 2.0))
                 self.assertEqual(fn(x, 3.0), fn_opt(x, 3.0))
-                self.assertEqual(cnt.frame_count, 1)
+                self.assertEqual(fn(x, 4.0), fn_opt(x, 4.0))
+                if i == 0:
+                    # Automatic dynamic state persists across compiles, so
+                    # only the first compile goes through the automatic
+                    # dynamic step.
+                    self.assertEqual(cnt.frame_count, 2)
+                else:
+                    self.assertEqual(cnt.frame_count, 1)

     @torch._dynamo.config.patch(specialize_float=False)
     def test_unspecialized_float_fallback_specialization(self):

@@ -1005,8 +1012,25 @@ class TestInductorDynamic(TestCase):
         self.assertEqual(fn(x, 2.0, z), fn_opt(x, 2.0, z))
         self.assertEqual(fn(x, 3.0, z), fn_opt(x, 3.0, z))
         self.assertEqual(fn(x, 4.0, z), fn_opt(x, 4.0, z))
-        # We expect frame count to be 2 since we will have
-        # one sledgehammer restart.
+        # Automatic dynamic float arguments
        self.assertEqual(cnt.frame_count, 2)

+    def test_unspecialized_float_softshrink(self):
+        # This test is particularly interesting since it exercises both
+        # standard operator replacements, i.e. torch.ops.aten.mul.Tensor,
+        # and comparison replacements, i.e. torch.ops.aten.ge.Scalar.
+        def fn(x, y):
+            return torch._C._nn.softshrink(x, lambd=y)
+
+        cnt = CompileCounterWithBackend("inductor")
+        fn_opt = torch._dynamo.optimize(cnt)(fn)
+        x = torch.randn(5, 5)
+
+        print(fn(x, 2.0), fn_opt(x, 2.0))
+        self.assertEqual(fn(x, 2.0), fn_opt(x, 2.0))
+        self.assertEqual(fn(x, 3.0), fn_opt(x, 3.0))
+        self.assertEqual(fn(x, 4.0), fn_opt(x, 4.0))
+        self.assertEqual(cnt.frame_count, 2)
+
     @torch._dynamo.config.patch(specialize_float=False)
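
The new softshrink test is a good stress case because softshrink needs both arithmetic on the float lambd and comparisons against it. An illustrative reference implementation (not Inductor's actual lowering) showing why both replacement families fire:

    import torch

    def softshrink_ref(x: torch.Tensor, lambd: float) -> torch.Tensor:
        # comparisons against lambd (the gt/lt-style Scalar ops) ...
        above = x > lambd
        below = x < -lambd
        # ... plus arithmetic on lambd (the sub/add-style Tensor ops)
        return torch.where(above, x - lambd,
                           torch.where(below, x + lambd, torch.zeros_like(x)))

    x = torch.randn(5, 5)
    assert torch.allclose(
        softshrink_ref(x, 0.5), torch.nn.functional.softshrink(x, 0.5)
    )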
@@ -1021,8 +1045,7 @@ class TestInductorDynamic(TestCase):
         self.assertEqual(fn(2.0, y), fn_opt(2.0, y))
         self.assertEqual(fn(3.0, y), fn_opt(3.0, y))
         self.assertEqual(fn(4.0, y), fn_opt(4.0, y))
-        # We expect frame count to be N + 1 since we will have
-        # one sledgehammer restart for the first compile.
+        # N + 1 for automatic dynamic float arguments
         self.assertEqual(cnt.frame_count, 4)

     def test_sort_dynamic_shape_with_check(self, device):

View file

@@ -65,7 +65,7 @@ specialize_int = False
 # Whether or not to specialize on float inputs. Dynamo will always promote
 # float inputs into Tensor inputs, but at the moment, backends inconsistently
 # support codegen on float (this is to be fixed).
-specialize_float = True
+specialize_float = True if is_fbcode() else False

 # legacy config, does nothing now!
 dynamic_shapes = True
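
The default is now platform-dependent: floats stay specialized under fbcode but are unspecialized in open source. Either behavior can still be forced explicitly; the decorator form below is the same one the cudagraph-tree test hunks above use:

    import torch._dynamo

    torch._dynamo.config.specialize_float = True  # bake floats into the graph

    @torch._dynamo.config.patch("specialize_float", False)
    def compile_with_unspecialized_floats():
        ...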

View file

@@ -1904,6 +1904,13 @@ class VariableBuilder:
         if self.name in self.tx.output.unspec_variable_map:
             return self.tx.output.unspec_variable_map[self.name]

+        frame_state_entry = process_automatic_dynamic(
+            self.tx,
+            self.source.name(),
+            FrameStateSizeEntry.make_scalar(value),
+            is_unspecialized_nn_module=self.source.guard_source().is_unspecialized_nn_module(),
+        )
+
         # NB: we specialize on nan input, because our guard modeling in
         # ShapeEnv cannot deal with nan
         if (

@@ -1918,6 +1925,7 @@ class VariableBuilder:
             # python test/inductor/test_compiled_optimizers.py CompiledOptimizerTests.test_rmsprop_weight_decay_maximize_capturable_cuda  # noqa: B950
             or torch._inductor.config.triton.cudagraphs
             or justknobs_check("pytorch/compiler:unspecialize_float_killswitch", False)
+            or frame_state_entry.scalar is not auto_dynamic
         ):
             self.install_guards(GuardBuilder.CONSTANT_MATCH)
             return ConstantVariable.create(value=value, source=self.source)
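
The added clause is the crux of the commit: a float input keeps the old specialized treatment (a CONSTANT_MATCH guard plus a ConstantVariable) until process_automatic_dynamic has seen two distinct values and flipped the frame-state entry to auto_dynamic. A hedged paraphrase of the gate with stand-in names:

    import math

    AUTO_DYNAMIC = "auto_dynamic"  # stand-in for the real sentinel

    def treat_as_constant(value, frame_state_scalar,
                          cudagraphs=False, killswitch=False):
        return (
            math.isnan(value)                      # nan: guards can't model it
            or cudagraphs                          # cudagraphs still specialize
            or killswitch                          # JK killswitch: old behavior
            or frame_state_scalar != AUTO_DYNAMIC  # not yet seen 2 distinct values
        )

    assert treat_as_constant(3.0, 3.0)               # first value: specialize
    assert not treat_as_constant(3.0, AUTO_DYNAMIC)  # dynamic: unspecialize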

View file

@@ -205,11 +205,6 @@ def check_cacheable(gm: torch.fx.GraphModule):
             "Cannot cache a graph with compiled autograd enabled"
         )

-    if not torch._dynamo.config.specialize_float:
-        raise BypassAOTAutogradCache(
-            "Cannot cache a graph with specialize_float disabled"
-        )
-
     if not (
         torch._inductor.config.fx_graph_cache or should_use_remote_fx_graph_cache()
     ):

View file

@@ -73,10 +73,16 @@ graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")

 SUPPORTED_OPS = {
-    torch.ops.aten.mul.Tensor,
-    torch.ops.aten.add.Tensor,
-    torch.ops.aten.sub.Tensor,
-    torch.ops.aten.div.Tensor,
+    torch.ops.aten.mul.Tensor: torch.ops.aten.mul.Tensor,
+    torch.ops.aten.add.Tensor: torch.ops.aten.add.Tensor,
+    torch.ops.aten.sub.Tensor: torch.ops.aten.sub.Tensor,
+    torch.ops.aten.div.Tensor: torch.ops.aten.div.Tensor,
+    torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor,
+    torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor,
+    torch.ops.aten.ge.Scalar: torch.ops.aten.ge.Tensor,
+    torch.ops.aten.le.Scalar: torch.ops.aten.le.Tensor,
+    torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor,
+    torch.ops.aten.ne.Scalar: torch.ops.aten.ne.Tensor,
 }
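
SUPPORTED_OPS is now a mapping from each interceptable op to the overload that should replace it once the float operand has been tensorified into a 0-dim tensor: arithmetic ops map to themselves, while comparisons trade their .Scalar overload for the .Tensor one. A quick check that the two overloads agree:

    import torch

    x = torch.randn(5)
    assert torch.equal(
        torch.ops.aten.ge.Scalar(x, 0.5),                # takes a Python number
        torch.ops.aten.ge.Tensor(x, torch.tensor(0.5)),  # takes a 0-dim tensor
    )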
@@ -232,7 +238,9 @@ def tensorify_python_scalars(
             should_restart = True

         # Look for functions to convert
-        if node.op == "call_function" and node.target in SUPPORTED_OPS:
+        if node.op == "call_function" and (
+            replacement_op := SUPPORTED_OPS.get(node.target)
+        ):
             args: List[Any] = []
             transform = False
             compute_dtype = get_computation_dtype(node.meta["val"].dtype)

@@ -253,7 +261,13 @@ def tensorify_python_scalars(
                     # We use _expr instead of expr b/c we want the symbol not the replacement
                     tensorified_symbols.add(a.meta["val"].node._expr)

-                    if proxy.node.meta["val"].dtype != compute_dtype:
+                    # The upcasting is irrelevant when the compute dtype is
+                    # bool. This happens in cases where we are tensorifying a
+                    # comparison operator such as torch.ops.aten.gt.Tensor
+                    if (
+                        compute_dtype != torch.bool
+                        and proxy.node.meta["val"].dtype != compute_dtype
+                    ):
                         proxy = torch.ops.prims.convert_element_type.default(
                             proxy, compute_dtype
                         )

@@ -265,7 +279,7 @@ def tensorify_python_scalars(
                 args.append(a)

             if transform:
                 replacement_proxy = node.target(*args)
                replacement_proxy = replacement_op(*args)

                 if compute_dtype != node.meta["val"].dtype:
                     replacement_proxy = (
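
The torch.bool special case ties back to the comparison replacements above: ops like aten.gt.Tensor always yield boolean outputs, so converting the result to the computation dtype would be meaningless. For example:

    import torch

    x = torch.randn(4, dtype=torch.float16)
    out = torch.ops.aten.gt.Tensor(x, torch.tensor(0.0))
    assert out.dtype == torch.bool  # nothing to upcast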