This PR adds support for rowwise-sharded embedding by introducing a MaskPartial placement that inherits from the default Partial placement and overrides the Partial contracts to construct the mask and release it after the reduction. The MaskPartial placement has the potential to support other ops whose sharded computation requires a mask for semantic correctness. For now it lives in the embedding ops, but we can move it to a common place if needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118080
Approved by: https://github.com/tianyu-l
ghstack dependencies: #118079
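For readers unfamiliar with the masking trick, the idea behind the placement can be illustrated without any DTensor internals. The sketch below simulates a rowwise-sharded embedding lookup on a single process: each hypothetical rank masks out the indices it does not own, looks up the rest in its local weight shard, zeroes the masked rows, and the sum over ranks (standing in for the all_reduce the test below asserts on) reconstructs the replicated output. All names here are illustrative, not the actual _MaskPartial API.

import torch

def masked_rowwise_embedding(indices, weight_shards):
    """Simulate a rowwise-sharded embedding lookup via masking.

    `weight_shards` is a list of [local_rows, dim] tensors, one per
    hypothetical rank; summing the per-rank partial outputs stands in
    for the all_reduce that materializes the full result.
    """
    partial_outputs = []
    offset = 0
    for shard in weight_shards:
        local_rows = shard.size(0)
        # Mask indices that this rank does not own.
        mask = (indices < offset) | (indices >= offset + local_rows)
        # Shift owned indices into local range; masked slots read row 0.
        local_indices = (indices - offset).masked_fill(mask, 0)
        out = shard[local_indices]
        # Zero out rows for masked indices so the cross-rank sum is exact.
        out = out.masked_fill(mask.unsqueeze(-1), 0.0)
        partial_outputs.append(out)
        offset += local_rows
    # The reduction consumes the mask: partial outputs become a full result.
    return torch.stack(partial_outputs).sum(dim=0)

torch.manual_seed(0)
weight = torch.randn(16, 4)
indices = torch.randint(0, 16, (5,))
shards = list(weight.chunk(4, dim=0))  # 4 hypothetical ranks
assert torch.equal(masked_rowwise_embedding(indices, shards), weight[indices])

The actual placement does the same bookkeeping lazily: per the description above, it stashes the mask when the operand is partitioned and releases it once the reduction has consumed it.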
188 lines · 6.1 KiB · Python
# Owner(s): ["oncall: distributed"]
import sys

import torch
from torch.distributed._tensor import (
    distribute_module,
    distribute_tensor,
    DTensor,
    Replicate,
    Shard,
)
from torch.distributed._tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


funcol = torch.ops.c10d_functional


class TestEmbeddingOp(DTensorTestBase):
    def _apply_sharding(self, embedding_mod, shard_dim, device_mesh):
        def shard_embedding_fn(name, module, device_mesh):
            for name, param in module.named_parameters():
                dist_param = torch.nn.Parameter(
                    distribute_tensor(param, device_mesh, [Shard(shard_dim)])
                )
                module.register_parameter(name, dist_param)

        sharded_embedding = distribute_module(
            embedding_mod, device_mesh, shard_embedding_fn
        )
        return sharded_embedding

    def _run_embedding_op_test(
        self,
        device_mesh,
        shard_dim,
        input_size,
        num_embeddings,
        embedding_dim,
        **kwargs,
    ):
        # Use same seed.
        torch.manual_seed(0)
        local_embedding = torch.nn.Embedding(
            num_embeddings,
            embedding_dim,
            device=self.device_type,
            **kwargs,
        )
        sharded_embedding = torch.nn.Embedding(
            num_embeddings,
            embedding_dim,
            device=self.device_type,
            **kwargs,
        )

        # Shard the parameter of local embedding and set it to sharded embedding.
        sharded_embedding.weight = torch.nn.Parameter(
            local_embedding.weight.clone().detach()
        )

        sharded_embedding = self._apply_sharding(
            sharded_embedding, shard_dim, device_mesh
        )

        # Run sharded computation
        torch.manual_seed(10)
        inp = torch.randint(
            0, num_embeddings, tuple(input_size), device=self.device_type
        )
        target = torch.empty(
            *inp.size(), embedding_dim, dtype=torch.float, device=self.device_type
        ).random_(0, 1)
        dist_inp = distribute_tensor(inp, device_mesh, [Replicate()])

        # fwd computation, ensure no comm happened
        with CommDebugMode() as fwd_mode:
            dist_output = sharded_embedding(dist_inp)
            self.assertEqual(fwd_mode.get_total_counts(), 0)

        output = dist_output.full_tensor()
        # Run local computation
        local_output = local_embedding(inp)

        # Verify
        self.assertEqual(local_output, output)

        # Use a sample cross entropy loss to verify backward and grad computation.
        loss = torch.nn.CrossEntropyLoss()
        emb_loss = loss(
            output,
            target,
        )
        emb_dup_loss = loss(
            local_output,
            target,
        )

        # local embedding backward
        emb_dup_loss.backward()

        # sharded embedding bwd computation, ensure no comm happened
        with CommDebugMode() as bwd_mode:
            emb_loss.backward()
            self.assertEqual(bwd_mode.get_total_counts(), 0)

        gradient = sharded_embedding.weight.grad.full_tensor()

        local_grad = local_embedding.weight.grad

        # Verify gradient.
        self.assertEqual(gradient, local_grad)

        # Validate for torch.nn.functional.embedding version.
        local_output = torch.nn.functional.embedding(
            inp,
            local_embedding.weight,
            **kwargs,
        )
        sharded_output = torch.nn.functional.embedding(
            DTensor.from_local(inp, device_mesh, [Replicate()], run_check=False),
            sharded_embedding.weight,
            **kwargs,
        )
        self.assertEqual(local_output, sharded_output.full_tensor())

    @with_comms
    def test_sharded_embedding_colwise(self):
        mesh = self.build_device_mesh()
        self._run_embedding_op_test(mesh, 1, [5, 4], 17, 12)
        self._run_embedding_op_test(mesh, 1, [6, 7, 6], 21, 11)
        self._run_embedding_op_test(mesh, 1, [8, 6, 5, 4], 23, 13)
        self._run_embedding_op_test(mesh, 1, [8, 6, 5, 4, 7], 23, 16)
        self._run_embedding_op_test(mesh, 1, [4], 15, 14)
        self._run_embedding_op_test(mesh, 1, [34], 15, 14, padding_idx=10)
        self._run_embedding_op_test(mesh, 1, [8, 6, 5, 4], 23, 13, padding_idx=12)

    @with_comms
    def test_sharded_embedding_colwise_max_norm_errors(self):
        mesh = self.build_device_mesh()
        with self.assertRaisesRegex(
            NotImplementedError,
            "aten.embedding_renorm_.default does not have a sharding strategy registered.",
        ):
            self._run_embedding_op_test(
                mesh, 1, [8, 6, 5, 4], 23, 13, padding_idx=12, max_norm=2.0
            )

    @with_comms
    def test_sharded_embedding_rowwise(self):
        mesh = self.build_device_mesh()
        # test correctness
        self._run_embedding_op_test(mesh, 0, [5, 12], 16, 22)
        self._run_embedding_op_test(mesh, 0, [6, 7, 6], 13, 22)
        self._run_embedding_op_test(mesh, 0, [34], 15, 14, padding_idx=10)

        from torch.distributed._tensor.ops.embedding_ops import _MaskPartial

        # test collectives
        embedding_mod = torch.nn.Embedding(10, 20, device=self.device_type)
        sharded_embedding = self._apply_sharding(embedding_mod, 0, mesh)
        inp = torch.randint(0, 10, (8, 8), device=self.device_type)
        replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)
        output = sharded_embedding(replicated_inp)
        self.assertIsInstance(output.placements[0], _MaskPartial)

        comm_mode = CommDebugMode()

        with comm_mode:
            output.full_tensor()
            self.assertEqual(comm_mode.get_total_counts(), 1)
            self.assertEqual(comm_mode.get_comm_counts()[funcol.all_reduce], 1)


if __name__ == "__main__":
    run_tests()