From bfb9ea1a43c0fc5e650f22bd21d132e3bbb60a0d Mon Sep 17 00:00:00 2001
From: Wanchao Liang
Date: Mon, 12 Feb 2024 00:22:30 -0800
Subject: [PATCH] fix compile DTensor.from_local in trace_rules lookup
 (#119659)

There's a bug in the conversion from TorchVariable to trace rule
lookups: in some corner cases, DTensor.from_local calls do not match
the trace-rule name lookup, so the lookup returns None and Dynamo falls
back to UserFunctionVariable. That makes the tracing silently wrong,
because Dynamo traces into the body of DTensor.from_local. It is not
exactly clear yet why the lookup failed.

This PR fixes DTensor.from_local tracing to make sure that in every
case we hit TorchInGraphFunctionVariable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119659
Approved by: https://github.com/yifuwang
---
 .../_tensor/test_dtensor_compile.py | 34 +++++++++++++------
 torch/_dynamo/trace_rules.py        |  2 +-
 2 files changed, 25 insertions(+), 11 deletions(-)

diff --git a/test/distributed/_tensor/test_dtensor_compile.py b/test/distributed/_tensor/test_dtensor_compile.py
index 44ad14efad4..2d5798b457f 100644
--- a/test/distributed/_tensor/test_dtensor_compile.py
+++ b/test/distributed/_tensor/test_dtensor_compile.py
@@ -8,6 +8,7 @@ from unittest.mock import patch
 
 import torch
 import torch._dynamo
+import torch._dynamo.testing
 import torch.distributed as dist
 import torch.distributed._functional_collectives as funcol
 import torch.nn as nn
@@ -193,11 +194,16 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
         # _dt_lib_impl = torch.library.Library("dtensor", "IMPL")
         # _dt_lib_impl.impl("from_local", from_local_tensor, "Autograd")
 
-        x = torch.ones(1)
+        x = torch.ones(1, requires_grad=True)
         ref = fn(x)
-        opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
+        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
+        opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
         res = opt_fn(x)
+        # backward should work as well
+        res.sum().backward()
+
         self.assertEqual(res, ref)
+        self.assertEqual(cnt.frame_count, 1)
 
         # test if user calls from_local with mesh/placements as kwargs and that should still work
         def from_local_kwargs_fn(x):
@@ -207,11 +213,10 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
             return dt.to_local() + 2
 
         ref = from_local_kwargs_fn(x)
-        opt_kwargs_fn = torch.compile(
-            from_local_kwargs_fn, backend="aot_eager", fullgraph=True
-        )
+        opt_kwargs_fn = torch.compile(from_local_kwargs_fn, backend=cnt, fullgraph=True)
         res = opt_kwargs_fn(x)
         self.assertEqual(res, ref)
+        self.assertEqual(cnt.frame_count, 2)
 
     def test_dynamo_dtensor_from_local_redistribute(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
@@ -224,7 +229,8 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
 
         x = torch.ones(1)
         ref = fn(x)
-        opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
+        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
+        opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
         res = opt_fn(x)
         self.assertEqual(res, ref)
 
@@ -238,7 +244,7 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
         x = torch.ones(1)
         ref = redistribute_kwargs_fn(x)
         opt_kwargs_fn = torch.compile(
-            redistribute_kwargs_fn, backend="aot_eager", fullgraph=True
+            redistribute_kwargs_fn, backend=cnt, fullgraph=True
         )
         res = opt_kwargs_fn(x)
         self.assertEqual(res, ref)
@@ -302,9 +308,12 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
             parallelize_plan=parallel_plan,
         )
 
-        compiled_model = torch.compile(model)
+        cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
+        compiled_model = torch.compile(model, backend=cnt, fullgraph=True)
         inp = torch.rand(20, 16).to(self.device_type)
         out = compiled_model(inp)
+        out.sum().backward()
+        self.assertEqual(cnt.frame_count, 1)
 
         code = run_and_get_triton_code(compiled_model, inp)
         # Check that `buf2` is correctly waited on before first use.
@@ -379,9 +388,12 @@ class TestDTensorCompileE2E(DTensorTestBase):
         torch.manual_seed(rng_seed)
         inp = torch.rand(20, 10, device=self.device_type)
         out = model(inp)
-        compiled_mod = torch.compile(model, backend="aot_eager", fullgraph=True)
+        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
+        compiled_mod = torch.compile(model, backend=cnt, fullgraph=True)
         compiled_out = compiled_mod(inp)
+        compiled_out.sum().backward()
         self.assertEqual(compiled_out, out)
+        self.assertEqual(cnt.frame_count, 1)
 
     @with_comms
     @skip_if_lt_x_gpu(4)
@@ -431,10 +443,12 @@
         )
 
         # TODO: once aot autograd support is ready we can just use default backend
-        compiled_2d = torch.compile(fsdp_2d, backend="aot_eager")
+        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
+        compiled_2d = torch.compile(fsdp_2d, backend=cnt)
         compiled_output = compiled_2d(inp)
 
         self.assertEqual(out, compiled_output)
+        self.assertEqual(cnt.frame_count, 1)
 
     @with_comms
     @skip_if_lt_x_gpu(4)
diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py
index 667a9b93464..98656b7b9ac 100644
--- a/torch/_dynamo/trace_rules.py
+++ b/torch/_dynamo/trace_rules.py
@@ -92,7 +92,7 @@ manual_torch_name_rule_map = {
     "torch.distributed.is_initialized": TorchInGraphFunctionVariable,
     "torch.distributed.get_rank": TorchInGraphFunctionVariable,
     "torch.distributed.get_world_size": TorchInGraphFunctionVariable,
-    "torch.distributed._tensor.DTensor#from_local": TorchInGraphFunctionVariable,
+    "torch.distributed._tensor.api.DTensor#from_local": TorchInGraphFunctionVariable,
     "torch.distributed.distributed_c10d._get_group_size_by_name": TorchInGraphFunctionVariable,
     "torch.distributed.distributed_c10d._get_group_tag": TorchInGraphFunctionVariable,
     "torch.distributed.distributed_c10d.get_process_group_ranks": TorchInGraphFunctionVariable,
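A minimal sketch of why the old rule key could miss, assuming the
trace-rule name is derived from the class's defining module via
__module__ (this illustrates the naming mismatch only; the exact
trace_rules lookup internals may differ):

    import torch.distributed._tensor as dtensor_pkg

    # DTensor is re-exported from the torch.distributed._tensor package,
    # but it is defined in torch/distributed/_tensor/api.py, so
    # __module__ reports the defining module, not the re-export path.
    print(dtensor_pkg.DTensor.__module__)  # torch.distributed._tensor.api

    # A rule key built from __module__ therefore matches only the new
    # "torch.distributed._tensor.api.DTensor#from_local" entry; the old
    # "torch.distributed._tensor.DTensor#from_local" key never matches,
    # so the lookup returned None and Dynamo fell back to
    # UserFunctionVariable.
    key = f"{dtensor_pkg.DTensor.__module__}.{dtensor_pkg.DTensor.__name__}#from_local"
    print(key)  # torch.distributed._tensor.api.DTensor#from_local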