mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
fix compile DTensor.from_local in trace_rule lookup (#119659)
There's a bug when converting from TorchVariable to trace rule lookups: in some corner cases the DTensor.from_local call does not match the trace name rule lookup, resulting in a None lookup and a fallback to UserFunctionVariable, which makes the tracing silently wrong by tracing into the DTensor.from_local function. It is not exactly clear yet why the lookup failed. This PR fixes the DTensor.from_local tracing to make sure that in every case we hit the InGraphFunctionVariable. Pull Request resolved: https://github.com/pytorch/pytorch/pull/119659 Approved by: https://github.com/yifuwang
This commit is contained in:
parent
379183a0dd
commit
bfb9ea1a43
2 changed files with 25 additions and 11 deletions
|
|
@ -8,6 +8,7 @@ from unittest.mock import patch
|
|||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
import torch._dynamo.testing
|
||||
import torch.distributed as dist
|
||||
import torch.distributed._functional_collectives as funcol
|
||||
import torch.nn as nn
|
||||
|
|
@ -193,11 +194,16 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
|
|||
# _dt_lib_impl = torch.library.Library("dtensor", "IMPL")
|
||||
# _dt_lib_impl.impl("from_local", from_local_tensor, "Autograd")
|
||||
|
||||
x = torch.ones(1)
|
||||
x = torch.ones(1, requires_grad=True)
|
||||
ref = fn(x)
|
||||
opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
|
||||
cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
|
||||
opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
|
||||
res = opt_fn(x)
|
||||
# backward should work as well
|
||||
res.sum().backward()
|
||||
|
||||
self.assertEqual(res, ref)
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
|
||||
# test if user calls from_local with mesh/placements as kwargs and that should still work
|
||||
def from_local_kwargs_fn(x):
|
||||
|
|
@ -207,11 +213,10 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
|
|||
return dt.to_local() + 2
|
||||
|
||||
ref = from_local_kwargs_fn(x)
|
||||
opt_kwargs_fn = torch.compile(
|
||||
from_local_kwargs_fn, backend="aot_eager", fullgraph=True
|
||||
)
|
||||
opt_kwargs_fn = torch.compile(from_local_kwargs_fn, backend=cnt, fullgraph=True)
|
||||
res = opt_kwargs_fn(x)
|
||||
self.assertEqual(res, ref)
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
|
||||
def test_dynamo_dtensor_from_local_redistribute(self):
|
||||
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
|
||||
|
|
@ -224,7 +229,8 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
|
|||
|
||||
x = torch.ones(1)
|
||||
ref = fn(x)
|
||||
opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
|
||||
cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
|
||||
opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(res, ref)
|
||||
|
||||
|
|
@ -238,7 +244,7 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
|
|||
x = torch.ones(1)
|
||||
ref = redistribute_kwargs_fn(x)
|
||||
opt_kwargs_fn = torch.compile(
|
||||
redistribute_kwargs_fn, backend="aot_eager", fullgraph=True
|
||||
redistribute_kwargs_fn, backend=cnt, fullgraph=True
|
||||
)
|
||||
res = opt_kwargs_fn(x)
|
||||
self.assertEqual(res, ref)
|
||||
|
|
@ -302,9 +308,12 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
|
|||
parallelize_plan=parallel_plan,
|
||||
)
|
||||
|
||||
compiled_model = torch.compile(model)
|
||||
cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
|
||||
compiled_model = torch.compile(model, backend=cnt, fullgraph=True)
|
||||
inp = torch.rand(20, 16).to(self.device_type)
|
||||
out = compiled_model(inp)
|
||||
out.sum().backward()
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
|
||||
code = run_and_get_triton_code(compiled_model, inp)
|
||||
# Check that `buf2` is correctly waited on before first use.
|
||||
|
|
@ -379,9 +388,12 @@ class TestDTensorCompileE2E(DTensorTestBase):
|
|||
torch.manual_seed(rng_seed)
|
||||
inp = torch.rand(20, 10, device=self.device_type)
|
||||
out = model(inp)
|
||||
compiled_mod = torch.compile(model, backend="aot_eager", fullgraph=True)
|
||||
cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
|
||||
compiled_mod = torch.compile(model, backend=cnt, fullgraph=True)
|
||||
compiled_out = compiled_mod(inp)
|
||||
compiled_out.sum().backward()
|
||||
self.assertEqual(compiled_out, out)
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
|
|
@ -431,10 +443,12 @@ class TestDTensorCompileE2E(DTensorTestBase):
|
|||
)
|
||||
|
||||
# TODO: once aot autograd support is ready we can just use default backend
|
||||
compiled_2d = torch.compile(fsdp_2d, backend="aot_eager")
|
||||
cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
|
||||
compiled_2d = torch.compile(fsdp_2d, backend=cnt)
|
||||
compiled_output = compiled_2d(inp)
|
||||
|
||||
self.assertEqual(out, compiled_output)
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ manual_torch_name_rule_map = {
|
|||
"torch.distributed.is_initialized": TorchInGraphFunctionVariable,
|
||||
"torch.distributed.get_rank": TorchInGraphFunctionVariable,
|
||||
"torch.distributed.get_world_size": TorchInGraphFunctionVariable,
|
||||
"torch.distributed._tensor.DTensor#from_local": TorchInGraphFunctionVariable,
|
||||
"torch.distributed._tensor.api.DTensor#from_local": TorchInGraphFunctionVariable,
|
||||
"torch.distributed.distributed_c10d._get_group_size_by_name": TorchInGraphFunctionVariable,
|
||||
"torch.distributed.distributed_c10d._get_group_tag": TorchInGraphFunctionVariable,
|
||||
"torch.distributed.distributed_c10d.get_process_group_ranks": TorchInGraphFunctionVariable,
|
||||
|
|
|
|||
Loading…
Reference in a new issue