# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import OpSchema
from torch.distributed.tensor._ops._common_rules import einop_rule, pointwise_rule
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


aten = torch.ops.aten


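# Conventions used throughout these tests: a dim_map entry of -1 means the
# corresponding tensor dim is replicated, while a non-negative entry j means
# that tensor dim is sharded over mesh dim j. The list passed as the third
# argument to DTensorSpec.from_dim_map records the mesh dims on which the
# tensor is a pending (partial) sum (see test_einop_linearity, where `[1]`
# marks mat1 as partial on mesh dim 1).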
class CommonRulesTest(DTensorTestBase):
    @property
    def world_size(self) -> int:
        # hard code world size to 4 as we need to test
        # at least with 2d mesh
        return 4

    def _gen_tensor_meta(self, shape):
        empty_tensor = torch.empty(shape)
        return TensorMeta(
            empty_tensor.shape,
            empty_tensor.stride(),
            empty_tensor.dtype,
        )

    @with_comms
    def test_einop_basic_propagation(self):
        # plain einsum, mm
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        mm_call = aten.mm.default
        # propagate col-wise sharding
        mat1, mat2 = [-1, -1], [-1, 0]

        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([4, 8]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [-1, 0])

        # propagate row-wise sharding
        mat1, mat2 = [0, -1], [-1, -1]
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [0, -1])

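        # when the contracting dim k is sharded on both operands, each rank only
        # computes a partial matmul result, so the propagated output placement on
        # that mesh dim is a pending sum (partial) rather than a sharding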
        # generate partial
        mat1, mat2 = [-1, 0], [0, -1]
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertTrue(output_spec.placements[0].is_partial())

    @with_comms
    def test_einop_pointwise_propagation(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        add_call = aten.add.Tensor
        # addition
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 8]))
        mat1 = [0, -1]
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        output_sharding = einop_rule(
            "ij,ij->ij", OpSchema(add_call, (mat1_spec, mat1_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [0, -1])

        # broadcast addition
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 8]))
        mat1 = [-1, 0, -1]
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )

        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([2]))
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, [-1], [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = einop_rule(
            "ijk,k->ijk", OpSchema(add_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [-1, 0, -1])

        # broadcast to a common shape
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 8, 8]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([1, 8]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, [0, -1, -1], [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, [-1, -1], [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = einop_rule(
            "ijk,1k->ijk", OpSchema(add_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [0, -1, -1])

    @with_comms
    def test_einop_merge_sharding(self):
        # 2d mesh einop merge sharding
        mesh_shape = torch.arange(self.world_size).reshape(
            self.world_size // 2, self.world_size // 2
        )
        mesh = DeviceMesh(self.device_type, mesh_shape)

        mm_call = aten.mm.default

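        # mat1 is row-wise sharded on mesh dim 0 and mat2 is column-wise sharded
        # on mesh dim 1; the two shardings merge into a 2d-sharded output ([0, 1])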
        mat1, mat2 = [0, -1], [-1, 1]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([4, 8]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [0, 1])

    @with_comms
    def test_einop_linearity(self):
        mesh_shape = torch.arange(self.world_size).reshape(
            self.world_size // 2, self.world_size // 2
        )
        mesh = DeviceMesh(self.device_type, mesh_shape)

        mm_call = aten.mm.default

        mat1, mat2 = [0, -1], [-1, -1]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([4, 8]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [1], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
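        # the `[1]` sums argument above marks mat1 as a pending (partial) sum on
        # mesh dim 1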
        # if linearity is not turned on, a partial sum is not eligible to propagate;
        # we return a suggestion to reshard the inputs with no partial sum
        # (i.e. all_reduce one input)
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
        )
        self.assertIsNone(output_sharding.output_spec)
        suggestions = output_sharding.redistribute_schema
        self.assertIsNotNone(suggestions)
        suggested_spec = suggestions.args_schema[0]
        self.assertFalse(suggested_spec.placements[1].is_partial())

        # einop prop with linearity on mm should give back a suggestion
        # to convert placements to partial
        output_sharding = einop_rule(
            "mk,kn->mn",
            OpSchema(mm_call, (mat1_spec, mat2_spec), {}),
            linearity=True,
        )
        self.assertIsNone(output_sharding.output_spec)
        suggestions = output_sharding.redistribute_schema
        self.assertIsNotNone(suggestions)
        mat2_spec = suggestions.args_schema[1]
        # mat2 mesh dim 1 should become partial now!
        self.assertTrue(mat2_spec.placements[1].is_partial())

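        # add is linear, so a pending sum on an operand can flow through the op;
        # with linearity enabled the rule again suggests converting the other
        # operand's placement to partial instead of forcing an all_reduce first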
        # einop prop with linearity on point-wise ops should give back a suggestion
        # to convert placements to partial
        add_call = aten.add.Tensor
        mat1, mat2 = [0, -1], [0, -1]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 6]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([8, 6]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [1], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )

        output_sharding = einop_rule(
            "ij,ij->ij",
            OpSchema(add_call, (mat1_spec, mat2_spec), {}),
            linearity=True,
        )
        self.assertIsNone(output_sharding.output_spec)
        suggestions = output_sharding.redistribute_schema
        self.assertIsNotNone(suggestions)
        mat2_spec = suggestions.args_schema[1]
        # mat2 mesh dim 1 should become partial now!
        self.assertTrue(mat2_spec.placements[1].is_partial())

    @with_comms
    def test_einop_multi_sharding_on_mesh_dim(self):
        # einop prop with multi sharding on same mesh dim
        mesh_shape = torch.arange(self.world_size)
        mesh = DeviceMesh(self.device_type, mesh_shape)

        mm_call = aten.mm.default
        mat1, mat2 = [0, -1], [0, -1]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 12]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([12, 4]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNone(output_spec)
        self.assertIsNotNone(output_sharding.redistribute_schema)

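        # propagation fails because mesh dim 0 is claimed by two different einsum
        # dims (m from mat1 and k from mat2), which a single mesh dim cannot satisfy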
        # ensure that the suggestion is to reshard the second
        # arg by all_gathering its tensor dim sharding
        schema_suggestion = output_sharding.redistribute_schema
        self.assertEqual(schema_suggestion.args_schema[0].dim_map, [0, -1])
        self.assertEqual(schema_suggestion.args_schema[1].dim_map, [-1, -1])

    @with_comms
    def test_einop_errors(self):
        mesh_shape = torch.arange(self.world_size).reshape(
            self.world_size // 2, self.world_size // 2
        )
        mesh = DeviceMesh(self.device_type, mesh_shape)

        add_call = aten.add.Tensor
        mat1, mat2 = [0, -1], [1, -1]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )

        with self.assertRaisesRegex(RuntimeError, "sharded two different ways:"):
            einop_rule("ij,ij->ij", OpSchema(add_call, (mat1_spec, mat2_spec), {}))

    @with_comms
    def test_pointwise_rules_broadcasting(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        where_call = aten.where.self
        inp1, inp2, inp3 = [0], [], [-1, -1]
        inp1_tensor_meta = self._gen_tensor_meta(torch.Size([8]))
        inp2_tensor_meta = self._gen_tensor_meta(torch.Size([]))
        inp3_tensor_meta = self._gen_tensor_meta(torch.Size([1, 1]))
        condition = DTensorSpec.from_dim_map(
            mesh, inp1, [], tensor_meta=inp1_tensor_meta
        )
        self_tensor = DTensorSpec.from_dim_map(
            mesh, inp2, [], tensor_meta=inp2_tensor_meta
        )
        other_tensor = DTensorSpec.from_dim_map(
            mesh, inp3, [], tensor_meta=inp3_tensor_meta
        )
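        # the 1d condition ([8], sharded on its only dim) broadcasts against the
        # 2d other tensor ([1, 1]); after shape alignment the sharded dim lands on
        # the last output dim, hence the expected dim_map of [-1, 0]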
        # propagate point-wise sharding with broadcasting
        output_sharding = pointwise_rule(
            OpSchema(where_call, (condition, self_tensor, other_tensor), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [-1, 0])

    @with_comms
    def test_pointwise_rules_suggestion(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        lerp_call = aten.lerp.Scalar
        # propagate point-wise sharding
        inp1, inp2 = [-1, -1], [-1, 0]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, inp1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, inp2, [], tensor_meta=mat2_tensor_meta
        )
        # adding a non-DTensorSpec positional argument (-1) to the arg schema
        output_sharding = pointwise_rule(
            OpSchema(lerp_call, (mat1_spec, mat2_spec, -1), {})
        )
        self.assertIsNone(output_sharding.output_spec)
        self.assertIsNotNone(output_sharding.redistribute_schema)

        # ensure that the suggestion from pointwise rules still has
        # the positional args that are not DTensorSpec
        schema_suggestion = output_sharding.redistribute_schema
        self.assertEqual(len(schema_suggestion.args_schema), 3)
        self.assertEqual(schema_suggestion.args_schema[2], -1)

    @with_comms
    def test_pointwise_multi_sharding_on_mesh_dim(self):
        # 2d mesh pointwise sharding
        mesh_shape = torch.arange(self.world_size).reshape(
            self.world_size // 2, self.world_size // 2
        )
        mesh = DeviceMesh(self.device_type, mesh_shape)

        add_call = aten.add.Tensor

        # basic case to test implicit broadcasting shape alignment
        mat1, mat2 = [-1, 0], [0]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([20, 6]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([6]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = pointwise_rule(OpSchema(add_call, (mat1_spec, mat2_spec), {}))
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [-1, 0])

        # more advanced case that needs to reshard one input to align sharding
        mat1, mat2 = [0, -1, -1, 1], [0, -1, 1]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([12, 1, 1, 8]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([12, 4, 8]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = pointwise_rule(OpSchema(add_call, (mat1_spec, mat2_spec), {}))
        output_spec = output_sharding.output_spec
        self.assertIsNone(output_spec)
        self.assertIsNotNone(output_sharding.redistribute_schema)

        # ensure that the suggestion is to reshard the first
        # arg by all_gathering its first tensor dim sharding
        schema_suggestion = output_sharding.redistribute_schema
        self.assertEqual(schema_suggestion.args_schema[0].dim_map, [-1, -1, -1, 1])
        self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat2)

    @with_comms
    def test_pointwise_enforce_sharding_multi_sharding_on_mesh_dim(self):
        # 2d mesh pointwise sharding
        mesh_shape = torch.arange(self.world_size).reshape(
            self.world_size // 2, self.world_size // 2
        )
        mesh = DeviceMesh(self.device_type, mesh_shape)

        add_call = aten.add_.Tensor

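        # add_ is in-place, so the sharding of the first (self) argument is kept
        # fixed; the resharding suggestion should therefore target the other operand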
        # more advanced case that needs to reshard one input to align sharding
        mat1, mat2 = [0, -1, 1], [-1, -1, 0]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([12, 4, 8]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([12, 1, 8]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = pointwise_rule(OpSchema(add_call, (mat1_spec, mat2_spec), {}))
        output_spec = output_sharding.output_spec
        self.assertIsNone(output_spec)
        self.assertIsNotNone(output_sharding.redistribute_schema)

        # ensure that the suggestion is to reshard the second
        # arg as we should enforce the sharding of the first arg
        schema_suggestion = output_sharding.redistribute_schema
        self.assertEqual(schema_suggestion.args_schema[0].dim_map, mat1)
        self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat1)


if __name__ == "__main__":
    run_tests()