# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import OpSchema
from torch.distributed.tensor._ops._common_rules import einop_rule, pointwise_rule
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
aten = torch.ops.aten
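
# Notation used throughout these tests: DTensorSpec.from_dim_map takes a
# dim_map where dim_map[i] is the mesh dimension that shards tensor dim i
# (-1 means replicated), and a sums list naming the mesh dimensions on which
# the tensor is Partial (i.e. holds a pending cross-rank reduction).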


class CommonRulesTest(DTensorTestBase):
    @property
    def world_size(self) -> int:
        # hard-code world size to 4 as we need to test
        # with at least a 2d mesh
        return 4

    def _gen_tensor_meta(self, shape):
        # sharding rules only need tensor metadata (shape/stride/dtype),
        # not real data
        empty_tensor = torch.empty(shape)
        return TensorMeta(
            empty_tensor.shape,
            empty_tensor.stride(),
            empty_tensor.dtype,
        )

    @with_comms
    def test_einop_basic_propagation(self):
        # plain einsum, mm
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        mm_call = aten.mm.default
        # propagate col-wise sharding
        mat1, mat2 = [-1, -1], [-1, 0]

        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([4, 8]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [-1, 0])

        # propagate row-wise sharding
        mat1, mat2 = [0, -1], [-1, -1]
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [0, -1])

        # generate partial: contracting over a k dim that is sharded on
        # mesh dim 0 leaves a pending sum on that mesh dim
        mat1, mat2 = [-1, 0], [0, -1]
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertTrue(output_spec.placements[0].is_partial())
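
    # Why the last case above is Partial: with k sharded, each rank holds
    # mat1[:, k_shard] and mat2[k_shard, :], so the local mm computes only a
    # partial sum over its slice of k; summing the local products across
    # mesh dim 0 (e.g. via all_reduce) would recover the full result.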

    @with_comms
    def test_einop_pointwise_propagation(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        add_call = aten.add.Tensor
        # addition
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 8]))
        mat1 = [0, -1]
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        output_sharding = einop_rule(
            "ij,ij->ij", OpSchema(add_call, (mat1_spec, mat1_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [0, -1])

        # broadcast addition
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 8]))
        mat1 = [-1, 0, -1]
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )

        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([2]))
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, [-1], [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = einop_rule(
            "ijk,k->ijk", OpSchema(add_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [-1, 0, -1])

        # broadcast to a common shape
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 8, 8]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([1, 8]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, [0, -1, -1], [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, [-1, -1], [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = einop_rule(
            "ijk,1k->ijk", OpSchema(add_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [0, -1, -1])
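
    # In the broadcast cases above, the smaller operand's dims align with the
    # trailing dims of the larger one (standard right-aligned broadcasting,
    # reflected in the einsum subscripts "ijk,k" and "ijk,1k"), so a
    # replicated low-rank operand never constrains the leading dims' sharding.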

    @with_comms
    def test_einop_merge_sharding(self):
        # 2d mesh einop merge sharding
        mesh_shape = torch.arange(self.world_size).reshape(
            self.world_size // 2, self.world_size // 2
        )
        mesh = DeviceMesh(self.device_type, mesh_shape)

        mm_call = aten.mm.default

        mat1, mat2 = [0, -1], [-1, 1]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([4, 8]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [0, 1])

    @with_comms
    def test_einop_linearity(self):
        mesh_shape = torch.arange(self.world_size).reshape(
            self.world_size // 2, self.world_size // 2
        )
        mesh = DeviceMesh(self.device_type, mesh_shape)

        mm_call = aten.mm.default

        mat1, mat2 = [0, -1], [-1, -1]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([4, 8]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [1], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        # if linearity is not turned on, a partial sum is not eligible to
        # propagate; we return a suggestion to reshard the inputs with no
        # partial sum (i.e. all_reduce one input)
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
        )
        self.assertIsNone(output_sharding.output_spec)
        suggestions = output_sharding.redistribute_schema
        self.assertIsNotNone(suggestions)
        suggested_spec = suggestions.args_schema[0]
        self.assertFalse(suggested_spec.placements[1].is_partial())

        # einop prop with linearity on mm should give back a suggestion
        # on converting placements to partial
        output_sharding = einop_rule(
            "mk,kn->mn",
            OpSchema(mm_call, (mat1_spec, mat2_spec), {}),
            linearity=True,
        )
        self.assertIsNone(output_sharding.output_spec)
        suggestions = output_sharding.redistribute_schema
        self.assertIsNotNone(suggestions)
        mat2_spec = suggestions.args_schema[1]
        # mat2 mesh dim 1 should become partial now!
        self.assertTrue(mat2_spec.placements[1].is_partial())

        # einop prop with linearity on point-wise should give back a
        # suggestion on converting placements to partial
        add_call = aten.add.Tensor
        mat1, mat2 = [0, -1], [0, -1]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 6]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([8, 6]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [1], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )

        output_sharding = einop_rule(
            "ij,ij->ij",
            OpSchema(add_call, (mat1_spec, mat2_spec), {}),
            linearity=True,
        )
        self.assertIsNone(output_sharding.output_spec)
        suggestions = output_sharding.redistribute_schema
        self.assertIsNotNone(suggestions)
        mat2_spec = suggestions.args_schema[1]
        # mat2 mesh dim 1 should become partial now!
        self.assertTrue(mat2_spec.placements[1].is_partial())
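
    # Linearity in a nutshell: for an op f that is linear in an input,
    # f(x1 + x2) == f(x1) + f(x2), so a Partial (pending-sum) input can flow
    # through f and the reduction can be deferred until after the op,
    # e.g. mm(x1 + x2, w) == mm(x1, w) + mm(x2, w).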

    @with_comms
    def test_einop_multi_sharding_on_mesh_dim(self):
        # einop prop with multi sharding on same mesh dim
        mesh_shape = torch.arange(self.world_size)
        mesh = DeviceMesh(self.device_type, mesh_shape)

        mm_call = aten.mm.default
        mat1, mat2 = [0, -1], [0, -1]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 12]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([12, 4]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNone(output_spec)
        self.assertIsNotNone(output_sharding.redistribute_schema)

        # ensure that the suggestion is to reshard the second arg by
        # all_gathering its sharded tensor dim
        schema_suggestion = output_sharding.redistribute_schema
        self.assertEqual(schema_suggestion.args_schema[0].dim_map, [0, -1])
        self.assertEqual(schema_suggestion.args_schema[1].dim_map, [-1, -1])
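
    # Both operands above claim mesh dim 0 for different einsum subscripts
    # (m for mat1, k for mat2); one mesh dim cannot shard two different dims
    # of the same computation, hence the suggestion to replicate one input.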

    @with_comms
    def test_einop_errors(self):
        mesh_shape = torch.arange(self.world_size).reshape(
            self.world_size // 2, self.world_size // 2
        )
        mesh = DeviceMesh(self.device_type, mesh_shape)

        add_call = aten.add.Tensor
        mat1, mat2 = [0, -1], [1, -1]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )

        with self.assertRaisesRegex(RuntimeError, "sharded two different ways:"):
            einop_rule("ij,ij->ij", OpSchema(add_call, (mat1_spec, mat2_spec), {}))

    @with_comms
    def test_pointwise_rules_broadcasting(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        where_call = aten.where.self
        inp1, inp2, inp3 = [0], [], [-1, -1]
        inp1_tensor_meta = self._gen_tensor_meta(torch.Size([8]))
        inp2_tensor_meta = self._gen_tensor_meta(torch.Size([]))
        inp3_tensor_meta = self._gen_tensor_meta(torch.Size([1, 1]))
        condition = DTensorSpec.from_dim_map(
            mesh, inp1, [], tensor_meta=inp1_tensor_meta
        )
        self_tensor = DTensorSpec.from_dim_map(
            mesh, inp2, [], tensor_meta=inp2_tensor_meta
        )
        other_tensor = DTensorSpec.from_dim_map(
            mesh, inp3, [], tensor_meta=inp3_tensor_meta
        )
        # propagate point-wise sharding with broadcasting
        output_sharding = pointwise_rule(
            OpSchema(where_call, (condition, self_tensor, other_tensor), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [-1, 0])
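
    # The broadcasted output above is 2d ([1, 8]); the 1d condition's sharded
    # dim lines up with the output's trailing dim, hence the output dim_map
    # of [-1, 0].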

    @with_comms
    def test_pointwise_rules_suggestion(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        lerp_call = aten.lerp.Scalar
        # propagate point-wise sharding
        inp1, inp2 = [-1, -1], [-1, 0]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, inp1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, inp2, [], tensor_meta=mat2_tensor_meta
        )
        # adding a positional argument -1 to the arg schema
        output_sharding = pointwise_rule(
            OpSchema(lerp_call, (mat1_spec, mat2_spec, -1), {})
        )
        self.assertIsNone(output_sharding.output_spec)
        self.assertIsNotNone(output_sharding.redistribute_schema)

        # ensure that the suggestion from pointwise rules still has the
        # positional args that are not DTensorSpec
        schema_suggestion = output_sharding.redistribute_schema
        self.assertEqual(len(schema_suggestion.args_schema), 3)
        self.assertEqual(schema_suggestion.args_schema[2], -1)

    @with_comms
    def test_pointwise_multi_sharding_on_mesh_dim(self):
        # 2d mesh pointwise sharding
        mesh_shape = torch.arange(self.world_size).reshape(
            self.world_size // 2, self.world_size // 2
        )
        mesh = DeviceMesh(self.device_type, mesh_shape)

        add_call = aten.add.Tensor

        # basic case to test implicit broadcasting shape alignment
        mat1, mat2 = [-1, 0], [0]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([20, 6]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([6]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = pointwise_rule(OpSchema(add_call, (mat1_spec, mat2_spec), {}))
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [-1, 0])

        # more advanced case that needs to reshard one input to align sharding
        mat1, mat2 = [0, -1, -1, 1], [0, -1, 1]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([12, 1, 1, 8]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([12, 4, 8]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = pointwise_rule(OpSchema(add_call, (mat1_spec, mat2_spec), {}))
        output_spec = output_sharding.output_spec
        self.assertIsNone(output_spec)
        self.assertIsNotNone(output_sharding.redistribute_schema)

        # ensure that the suggestion is to reshard the first arg by
        # all_gathering its first tensor dim sharding
        schema_suggestion = output_sharding.redistribute_schema
        self.assertEqual(schema_suggestion.args_schema[0].dim_map, [-1, -1, -1, 1])
        self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat2)

    @with_comms
    def test_pointwise_enforce_sharding_multi_sharding_on_mesh_dim(self):
        # 2d mesh pointwise sharding
        mesh_shape = torch.arange(self.world_size).reshape(
            self.world_size // 2, self.world_size // 2
        )
        mesh = DeviceMesh(self.device_type, mesh_shape)

        add_call = aten.add_.Tensor

        # more advanced case that needs to reshard one input to align sharding
        mat1, mat2 = [0, -1, 1], [-1, -1, 0]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([12, 4, 8]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([12, 1, 8]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = pointwise_rule(OpSchema(add_call, (mat1_spec, mat2_spec), {}))
        output_spec = output_sharding.output_spec
        self.assertIsNone(output_spec)
        self.assertIsNotNone(output_sharding.redistribute_schema)

        # ensure that the suggestion is to reshard the second arg, as the
        # inplace op must keep (enforce) the sharding of the first arg
        schema_suggestion = output_sharding.redistribute_schema
        self.assertEqual(schema_suggestion.args_schema[0].dim_map, mat1)
        self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat1)
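
# The rules exercised here back the user-facing DTensor API. A minimal sketch
# of the row-wise mm case at the DTensor level (illustrative only, assuming a
# 4-GPU launch; not part of this test suite):
#
#   from torch.distributed.tensor import (
#       Replicate, Shard, distribute_tensor, init_device_mesh
#   )
#   mesh = init_device_mesh("cuda", (4,))
#   a = distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)])     # dim_map [0, -1]
#   b = distribute_tensor(torch.randn(4, 8), mesh, [Replicate()])  # dim_map [-1, -1]
#   c = torch.mm(a, b)  # propagation yields a row-sharded ([0, -1]) output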


if __name__ == "__main__":
    run_tests()