fix compile DTensor.from_local in trace_rule_look up (#119659)

There's a bug when converting from TorchVariable to trace-rule lookups:
in some corner cases the DTensor.from_local call does not match the trace
name rule lookup, resulting in a None lookup and a fallback to
UserFunctionVariable, which makes the tracing silently wrong by tracing
into the body of the DTensor.from_local function. Not exactly sure yet
why the lookup failed.

This PR fixes the DTensor.from_local tracing to make sure that in every
case we hit the TorchInGraphFunctionVariable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119659
Approved by: https://github.com/yifuwang
This commit is contained in:
Wanchao Liang 2024-02-12 00:22:30 -08:00 committed by PyTorch MergeBot
parent 379183a0dd
commit bfb9ea1a43
2 changed files with 25 additions and 11 deletions

View file

@ -8,6 +8,7 @@ from unittest.mock import patch
import torch
import torch._dynamo
import torch._dynamo.testing
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.nn as nn
@ -193,11 +194,16 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
# _dt_lib_impl = torch.library.Library("dtensor", "IMPL")
# _dt_lib_impl.impl("from_local", from_local_tensor, "Autograd")
x = torch.ones(1)
x = torch.ones(1, requires_grad=True)
ref = fn(x)
opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
res = opt_fn(x)
# backward should work as well
res.sum().backward()
self.assertEqual(res, ref)
self.assertEqual(cnt.frame_count, 1)
# test if user calls from_local with mesh/placements as kwargs and that should still work
def from_local_kwargs_fn(x):
@ -207,11 +213,10 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
return dt.to_local() + 2
ref = from_local_kwargs_fn(x)
opt_kwargs_fn = torch.compile(
from_local_kwargs_fn, backend="aot_eager", fullgraph=True
)
opt_kwargs_fn = torch.compile(from_local_kwargs_fn, backend=cnt, fullgraph=True)
res = opt_kwargs_fn(x)
self.assertEqual(res, ref)
self.assertEqual(cnt.frame_count, 2)
def test_dynamo_dtensor_from_local_redistribute(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
@ -224,7 +229,8 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
x = torch.ones(1)
ref = fn(x)
opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
res = opt_fn(x)
self.assertEqual(res, ref)
@ -238,7 +244,7 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
x = torch.ones(1)
ref = redistribute_kwargs_fn(x)
opt_kwargs_fn = torch.compile(
redistribute_kwargs_fn, backend="aot_eager", fullgraph=True
redistribute_kwargs_fn, backend=cnt, fullgraph=True
)
res = opt_kwargs_fn(x)
self.assertEqual(res, ref)
@ -302,9 +308,12 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
parallelize_plan=parallel_plan,
)
compiled_model = torch.compile(model)
cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
compiled_model = torch.compile(model, backend=cnt, fullgraph=True)
inp = torch.rand(20, 16).to(self.device_type)
out = compiled_model(inp)
out.sum().backward()
self.assertEqual(cnt.frame_count, 1)
code = run_and_get_triton_code(compiled_model, inp)
# Check that `buf2` is correctly waited on before first use.
@ -379,9 +388,12 @@ class TestDTensorCompileE2E(DTensorTestBase):
torch.manual_seed(rng_seed)
inp = torch.rand(20, 10, device=self.device_type)
out = model(inp)
compiled_mod = torch.compile(model, backend="aot_eager", fullgraph=True)
cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
compiled_mod = torch.compile(model, backend=cnt, fullgraph=True)
compiled_out = compiled_mod(inp)
compiled_out.sum().backward()
self.assertEqual(compiled_out, out)
self.assertEqual(cnt.frame_count, 1)
@with_comms
@skip_if_lt_x_gpu(4)
@ -431,10 +443,12 @@ class TestDTensorCompileE2E(DTensorTestBase):
)
# TODO: once aot autograd support is ready we can just use default backend
compiled_2d = torch.compile(fsdp_2d, backend="aot_eager")
cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
compiled_2d = torch.compile(fsdp_2d, backend=cnt)
compiled_output = compiled_2d(inp)
self.assertEqual(out, compiled_output)
self.assertEqual(cnt.frame_count, 1)
@with_comms
@skip_if_lt_x_gpu(4)

View file

@ -92,7 +92,7 @@ manual_torch_name_rule_map = {
"torch.distributed.is_initialized": TorchInGraphFunctionVariable,
"torch.distributed.get_rank": TorchInGraphFunctionVariable,
"torch.distributed.get_world_size": TorchInGraphFunctionVariable,
"torch.distributed._tensor.DTensor#from_local": TorchInGraphFunctionVariable,
"torch.distributed._tensor.api.DTensor#from_local": TorchInGraphFunctionVariable,
"torch.distributed.distributed_c10d._get_group_size_by_name": TorchInGraphFunctionVariable,
"torch.distributed.distributed_c10d._get_group_tag": TorchInGraphFunctionVariable,
"torch.distributed.distributed_c10d.get_process_group_ranks": TorchInGraphFunctionVariable,