diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index f1d544d92bc..8947e577010 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -1398,7 +1398,7 @@ utils_device.CURRENT_DEVICE == None""".split(
         cfg2.val = 2.0
         v = opt_fn(v, cfg2)  # 7
         self.assertEqual(v[0], 7)
-        self.assertEqual(cnts.op_count, 8)
+        self.assertEqual(cnts.op_count, 9)
 
     def test_config_getattr_default(self):
         class Cfg:
@@ -3747,8 +3747,18 @@ utils_device.CURRENT_DEVICE == None""".split(
             result1, result2, _ = opt_fn()
             self.assertAlmostEqual(orig1 + 1 * i, result1)
             self.assertTrue(torch.allclose(orig2 + 10 * i, result2))
-            self.assertEqual(cnts.frame_count, 1)
-            self.assertEqual(cnts.op_count, 3)
+            if i == 1:
+                # No automatic dynamic
+                self.assertEqual(cnts.frame_count, 1)
+                self.assertEqual(cnts.op_count, 3)
+            elif i == 2:
+                # Automatic dynamic float arguments kicked in
+                self.assertEqual(cnts.frame_count, 1)
+                self.assertEqual(cnts.op_count, 6)
+            else:
+                # No more recompiles
+                self.assertEqual(cnts.frame_count, 0)
+                self.assertEqual(cnts.op_count, 0)
             cnts.clear()
 
     def test_closure_with_mutation_and_graph_break(self):
diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py
index 565759a1fdb..4aaf5bd7dff 100644
--- a/test/dynamo/test_unspec.py
+++ b/test/dynamo/test_unspec.py
@@ -641,18 +641,19 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
 
         cf = torch.compile(backend=cnts, fullgraph=True)(f)
         x = torch.randn(3)
-        self.assertEqual(f(x, 3.0), cf(x, 3.0))
+        self.assertEqual(f(x, 2.0), cf(x, 2.0))
+        self.assertEqual(f(x, 3.0), cf(x, 3.0))  # automatic dynamic kicks in here
         self.assertEqual(f(x, 4.0), cf(x, 4.0))
-        self.assertExpectedInline(cnts.frame_count, """1""")  # no recompile
+        self.assertExpectedInline(cnts.frame_count, """2""")  # no recompile
         self.assertEqual(f(x, 5.0), cf(x, 5.0))
-        self.assertExpectedInline(cnts.frame_count, """2""")  # guard worked
+        self.assertExpectedInline(cnts.frame_count, """3""")  # guard worked
         self.assertEqual(f(x, math.nan), cf(x, math.nan))
-        self.assertExpectedInline(cnts.frame_count, """3""")  # nan always recompiles
+        self.assertExpectedInline(cnts.frame_count, """4""")  # nan always recompiles
 
     @torch._dynamo.config.patch(specialize_float=False, capture_scalar_outputs=True)
     def test_unspecialized_float_multiply_precision(self):
         dtypes = [torch.bfloat16, torch.float16, torch.float32, torch.float64]
-        for dtype in dtypes:
+        for i, dtype in enumerate(dtypes):
 
             def fn(x, y):
                 return x * y
@@ -662,9 +663,20 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
             x = torch.randn(5, dtype=dtype, requires_grad=True)
             y1 = 1.00048828125
             y2 = 1.00048828126
+            y3 = 1.00048828127
             self.assertEqual(fn_opt(x, y1), fn(x, y1))
             self.assertEqual(fn_opt(x, y2), fn(x, y2))
-            self.assertEqual(cnt.frame_count, 1)
+            self.assertEqual(fn_opt(x, y3), fn(x, y3))
+            if i == 0:
+                # This is a quirky part of automatic dynamic: since it just
+                # uses source name + tx.f_code as the key, subsequent
+                # recompilations will actually reuse the automatic dynamic
+                # choices.
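+                # Concretely, only the first dtype pays the static-then-dynamic
+                # pair of compiles; later dtypes compile dynamically right away.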
+                self.assertEqual(cnt.frame_count, 2)
+            else:
+                self.assertEqual(cnt.frame_count, 1)
 
     @torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=False)
     def test_unspec_float_input_f64(self):
diff --git a/test/dynamo/test_utils.py b/test/dynamo/test_utils.py
index e4e95bb9c69..578d276ec2c 100644
--- a/test/dynamo/test_utils.py
+++ b/test/dynamo/test_utils.py
@@ -270,7 +270,7 @@ class TestDynamoTimed(TestCase):
                 'runtime_cudagraphify_time_us': None,
                 'runtime_triton_autotune_time_us': None,
                 'shape_env_guard_count': 0,
-                'specialize_float': True,
+                'specialize_float': False,
                 'start_time': 0.0001,
                 'start_time_us': 100,
                 'structured_logging_overhead_s': 0.0,
diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py
index 4fe92c73427..41a009a2483 100644
--- a/test/inductor/test_cudagraph_trees.py
+++ b/test/inductor/test_cudagraph_trees.py
@@ -607,6 +607,8 @@ if HAS_CUDA and not TEST_WITH_ASAN:
         @torch._functorch.config.patch("enable_autograd_cache", True)
         @torch._inductor.config.patch("fx_graph_cache", True)
         @torch._inductor.config.patch("fx_graph_remote_cache", False)
+        # Currently the fx graph cache is turned off for specialize_float=False
+        @torch._dynamo.config.patch("specialize_float", True)
         def test_cache_hit_forward_miss_backward(self):
             # Test that we don't cache cudagraphs, skipping cudagraphs on backward on a cache miss
 
@@ -661,6 +663,8 @@ if HAS_CUDA and not TEST_WITH_ASAN:
         @torch._functorch.config.patch("enable_autograd_cache", True)
         @torch._inductor.config.patch("fx_graph_cache", True)
         @torch._inductor.config.patch("fx_graph_remote_cache", False)
+        # Currently the fx graph cache is turned off for specialize_float=False
+        @torch._dynamo.config.patch("specialize_float", True)
         def test_backward_gets_cached_cudagraphs(self):
             # We pass cpu tensors to foo and save that into the cache
             # On a subsequent run in a new process, cudagraphs should be
@@ -705,6 +709,8 @@ if HAS_CUDA and not TEST_WITH_ASAN:
         @torch._functorch.config.patch("enable_autograd_cache", True)
         @torch._inductor.config.patch("fx_graph_cache", True)
         @torch._inductor.config.patch("fx_graph_remote_cache", False)
+        # Currently the fx graph cache is turned off for specialize_float=False
+        @torch._dynamo.config.patch("specialize_float", True)
         def test_cached_forward_backward(self):
             counters.clear()
             AOTAutogradCache.clear()
diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py
index 9f78a8b6e9f..df461980258 100644
--- a/test/inductor/test_torchinductor_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_dynamic_shapes.py
@@ -969,7 +969,7 @@ class TestInductorDynamic(TestCase):
             "divide": operator.truediv,
         }
 
-        for name, op in operations.items():
+        for i, (name, op) in enumerate(operations.items()):
             with self.subTest(operation=name):
 
                 def fn(x, y):
@@ -981,7 +981,17 @@ class TestInductorDynamic(TestCase):
                 x = torch.arange(3)
                 self.assertEqual(fn(x, 2.0), fn_opt(x, 2.0))
                 self.assertEqual(fn(x, 3.0), fn_opt(x, 3.0))
-                self.assertEqual(cnt.frame_count, 1)
+                self.assertEqual(fn(x, 4.0), fn_opt(x, 4.0))
+                if i == 0:
+                    # Automatic dynamic state persists across compiles, so
+                    # only the first compile goes through the automatic
+                    # dynamic step.
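+                    # Concretely: 2.0 compiles static, 3.0 recompiles dynamic,
+                    # and 4.0 reuses the dynamic graph; later operations
+                    # compile dynamically from the start.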
+                    self.assertEqual(cnt.frame_count, 2)
+                else:
+                    self.assertEqual(cnt.frame_count, 1)
 
     @torch._dynamo.config.patch(specialize_float=False)
     def test_unspecialized_float_fallback_specialization(self):
@@ -1005,8 +1015,23 @@ class TestInductorDynamic(TestCase):
         self.assertEqual(fn(x, 2.0, z), fn_opt(x, 2.0, z))
         self.assertEqual(fn(x, 3.0, z), fn_opt(x, 3.0, z))
         self.assertEqual(fn(x, 4.0, z), fn_opt(x, 4.0, z))
-        # We expect frame count to be 2 since we will have
-        # one sledgehammer restart.
+        # Automatic dynamic float arguments
         self.assertEqual(cnt.frame_count, 2)
+
+    def test_unspecialized_float_softshrink(self):
+        # This test is particularly interesting since it exercises both
+        # standard operator replacements, i.e. torch.ops.aten.mul.Tensor,
+        # as well as comparison replacements, i.e. torch.ops.aten.ge.Scalar.
+        def fn(x, y):
+            return torch._C._nn.softshrink(x, lambd=y)
+
+        cnt = CompileCounterWithBackend("inductor")
+        fn_opt = torch._dynamo.optimize(cnt)(fn)
+        x = torch.randn(5, 5)
+
+        self.assertEqual(fn(x, 2.0), fn_opt(x, 2.0))
+        self.assertEqual(fn(x, 3.0), fn_opt(x, 3.0))
+        self.assertEqual(fn(x, 4.0), fn_opt(x, 4.0))
+        self.assertEqual(cnt.frame_count, 2)
 
     @torch._dynamo.config.patch(specialize_float=False)
@@ -1021,8 +1046,7 @@ class TestInductorDynamic(TestCase):
         self.assertEqual(fn(2.0, y), fn_opt(2.0, y))
         self.assertEqual(fn(3.0, y), fn_opt(3.0, y))
         self.assertEqual(fn(4.0, y), fn_opt(4.0, y))
-        # We expect frame count to be N + 1 since we will have
-        # one sledgehammer restart for the first compile.
+        # N + 1 for automatic dynamic float arguments
         self.assertEqual(cnt.frame_count, 4)
 
     def test_sort_dynamic_shape_with_check(self, device):
diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py
index 5c36654ae5d..139c3ec0b6d 100644
--- a/torch/_dynamo/config.py
+++ b/torch/_dynamo/config.py
@@ -65,7 +65,7 @@ specialize_int = False
 # Whether or not to specialize on float inputs. Dynamo will always promote
 # float inputs into Tensor inputs, but at the moment, backends inconsistently
 # support codegen on float (this is to be fixed).
-specialize_float = True
+specialize_float = True if is_fbcode() else False
 
 # legacy config, does nothing now!
 dynamic_shapes = True
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index f9c0e56ecb8..8243033ba38 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -1904,6 +1904,17 @@ class VariableBuilder:
         if self.name in self.tx.output.unspec_variable_map:
             return self.tx.output.unspec_variable_map[self.name]
 
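+        # Record this float in the frame state. If the value varies across
+        # frames, automatic dynamic promotes the entry to auto_dynamic; the
+        # check below keeps specializing on the constant until that happens,
+        # after which we fall through and wrap the float dynamically.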
+        frame_state_entry = process_automatic_dynamic(
+            self.tx,
+            self.source.name(),
+            FrameStateSizeEntry.make_scalar(value),
+            is_unspecialized_nn_module=self.source.guard_source().is_unspecialized_nn_module(),
+        )
+
         # NB: we specialize on nan input, because our guard modeling in
         # ShapeEnv cannot deal with nan
         if (
@@ -1918,6 +1929,7 @@ class VariableBuilder:
             # python test/inductor/test_compiled_optimizers.py CompiledOptimizerTests.test_rmsprop_weight_decay_maximize_capturable_cuda  # noqa: B950
             or torch._inductor.config.triton.cudagraphs
             or justknobs_check("pytorch/compiler:unspecialize_float_killswitch", False)
+            or frame_state_entry.scalar is not auto_dynamic
         ):
             self.install_guards(GuardBuilder.CONSTANT_MATCH)
             return ConstantVariable.create(value=value, source=self.source)
diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py
index ad2833d8dbc..94ad6335b58 100644
--- a/torch/_functorch/_aot_autograd/autograd_cache.py
+++ b/torch/_functorch/_aot_autograd/autograd_cache.py
@@ -205,11 +205,6 @@ def check_cacheable(gm: torch.fx.GraphModule):
             "Cannot cache a graph with compiled autograd enabled"
         )
 
-    if not torch._dynamo.config.specialize_float:
-        raise BypassAOTAutogradCache(
-            "Cannot cache a graph with specialize_float disabled"
-        )
-
     if not (
         torch._inductor.config.fx_graph_cache or should_use_remote_fx_graph_cache()
     ):
diff --git a/torch/fx/passes/_tensorify_python_scalars.py b/torch/fx/passes/_tensorify_python_scalars.py
index 1786cfcecd7..fb7902a7d02 100644
--- a/torch/fx/passes/_tensorify_python_scalars.py
+++ b/torch/fx/passes/_tensorify_python_scalars.py
@@ -73,10 +73,20 @@ graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
 
 
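+# Maps each supported op to the op emitted once the python float has been
+# tensorified. Arithmetic ops map to themselves; the .Scalar comparison
+# overloads map to their .Tensor counterparts, since the scalar argument
+# becomes a tensor.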
 SUPPORTED_OPS = {
-    torch.ops.aten.mul.Tensor,
-    torch.ops.aten.add.Tensor,
-    torch.ops.aten.sub.Tensor,
-    torch.ops.aten.div.Tensor,
+    torch.ops.aten.mul.Tensor: torch.ops.aten.mul.Tensor,
+    torch.ops.aten.add.Tensor: torch.ops.aten.add.Tensor,
+    torch.ops.aten.sub.Tensor: torch.ops.aten.sub.Tensor,
+    torch.ops.aten.div.Tensor: torch.ops.aten.div.Tensor,
+    torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor,
+    torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor,
+    torch.ops.aten.ge.Scalar: torch.ops.aten.ge.Tensor,
+    torch.ops.aten.le.Scalar: torch.ops.aten.le.Tensor,
+    torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor,
+    torch.ops.aten.ne.Scalar: torch.ops.aten.ne.Tensor,
 }
 
 
@@ -232,7 +242,9 @@ def tensorify_python_scalars(
                     should_restart = True
 
         # Look for functions to convert
-        if node.op == "call_function" and node.target in SUPPORTED_OPS:
+        if node.op == "call_function" and (
+            replacement_op := SUPPORTED_OPS.get(node.target)
+        ):
             args: List[Any] = []
             transform = False
             compute_dtype = get_computation_dtype(node.meta["val"].dtype)
@@ -253,7 +265,13 @@ def tensorify_python_scalars(
                     # We use _expr instead of expr b/c we want the symbol not the replacement
                     tensorified_symbols.add(a.meta["val"].node._expr)
 
-                    if proxy.node.meta["val"].dtype != compute_dtype:
+                    # The upcasting is irrelevant when the compute dtype is bool.
+                    # This happens in cases where we are tensorifying a comparison
+                    # operator such as torch.ops.aten.gt.Tensor.
+                    if (
+                        compute_dtype != torch.bool
+                        and proxy.node.meta["val"].dtype != compute_dtype
+                    ):
                         proxy = torch.ops.prims.convert_element_type.default(
                             proxy, compute_dtype
                         )
@@ -265,7 +283,7 @@ def tensorify_python_scalars(
                     args.append(a)
 
             if transform:
-                replacement_proxy = node.target(*args)
+                replacement_proxy = replacement_op(*args)
 
                 if compute_dtype != node.meta["val"].dtype:
                     replacement_proxy = (