# pytorch/test/distributed/tensor/test_common_rules.py
# 2025-01-08 00:32:33 +00:00
#
# 417 lines
# 17 KiB
# Python

# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import OpSchema
from torch.distributed.tensor._ops._common_rules import einop_rule, pointwise_rule
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)
# Shorthand for the ATen operator namespace; the tests below pick op overloads
# from it (e.g. ``aten.mm.default``, ``aten.add.Tensor``).
aten = torch.ops.aten
class CommonRulesTest(DTensorTestBase):
@property
def world_size(self) -> int:
# hard code world size to 4 as we need to test
# at least with 2d mesh
return 4
def _gen_tensor_meta(self, shape):
empty_tensor = torch.empty(shape)
return TensorMeta(
empty_tensor.shape,
empty_tensor.stride(),
empty_tensor.dtype,
)
@with_comms
def test_einop_basic_propagation(self):
# plain einsum, mm
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
mm_call = aten.mm.default
# propagate col-wise sharding
mat1, mat2 = [-1, -1], [-1, 0]
mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
mat2_tensor_meta = self._gen_tensor_meta(torch.Size([4, 8]))
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [], tensor_meta=mat1_tensor_meta
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, mat2, [], tensor_meta=mat2_tensor_meta
)
output_sharding = einop_rule(
"mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
)
output_spec = output_sharding.output_spec
self.assertIsNotNone(output_spec)
self.assertEqual(output_spec.dim_map, [-1, 0])
# propagate row-wise sharding
mat1, mat2 = [0, -1], [-1, -1]
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [], tensor_meta=mat1_tensor_meta
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, mat2, [], tensor_meta=mat2_tensor_meta
)
output_sharding = einop_rule(
"mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
)
output_spec = output_sharding.output_spec
self.assertIsNotNone(output_spec)
self.assertEqual(output_spec.dim_map, [0, -1])
# generate partial
mat1, mat2 = [-1, 0], [0, -1]
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [], tensor_meta=mat1_tensor_meta
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, mat2, [], tensor_meta=mat2_tensor_meta
)
output_sharding = einop_rule(
"mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
)
output_spec = output_sharding.output_spec
self.assertIsNotNone(output_spec)
self.assertTrue(output_spec.placements[0].is_partial())
@with_comms
def test_einop_pointwise_propagation(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
add_call = aten.add.Tensor
# addition
mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 8]))
mat1 = [0, -1]
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [], tensor_meta=mat1_tensor_meta
)
output_sharding = einop_rule(
"ij,ij->ij", OpSchema(add_call, (mat1_spec, mat1_spec), {})
)
output_spec = output_sharding.output_spec
self.assertIsNotNone(output_spec)
self.assertEqual(output_spec.dim_map, [0, -1])
# broadcast addition
mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 8]))
mat1 = [-1, 0, -1]
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [], tensor_meta=mat1_tensor_meta
)
mat2_tensor_meta = self._gen_tensor_meta(torch.Size([2]))
mat2_spec = DTensorSpec.from_dim_map(
mesh, [-1], [], tensor_meta=mat2_tensor_meta
)
output_sharding = einop_rule(
"ijk,k->ijk", OpSchema(add_call, (mat1_spec, mat2_spec), {})
)
output_spec = output_sharding.output_spec
self.assertIsNotNone(output_spec)
self.assertEqual(output_spec.dim_map, [-1, 0, -1])
# broadcast to a common shape
mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 8, 8]))
mat2_tensor_meta = self._gen_tensor_meta(torch.Size([1, 8]))
mat1_spec = DTensorSpec.from_dim_map(
mesh, [0, -1, -1], [], tensor_meta=mat1_tensor_meta
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, [-1, -1], [], tensor_meta=mat2_tensor_meta
)
output_sharding = einop_rule(
"ijk,1k->ijk", OpSchema(add_call, (mat1_spec, mat2_spec), {})
)
output_spec = output_sharding.output_spec
self.assertIsNotNone(output_spec)
self.assertEqual(output_spec.dim_map, [0, -1, -1])
@with_comms
def test_einop_merge_sharding(self):
# 2d mesh einop merge sharding
mesh_shape = torch.arange(self.world_size).reshape(
self.world_size // 2, self.world_size // 2
)
mesh = DeviceMesh(self.device_type, mesh_shape)
mm_call = aten.mm.default
mat1, mat2 = [0, -1], [-1, 1]
mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
mat2_tensor_meta = self._gen_tensor_meta(torch.Size([4, 8]))
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [], tensor_meta=mat1_tensor_meta
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, mat2, [], tensor_meta=mat2_tensor_meta
)
output_sharding = einop_rule(
"mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
)
output_spec = output_sharding.output_spec
self.assertIsNotNone(output_spec)
self.assertEqual(output_spec.dim_map, [0, 1])
@with_comms
def test_einop_linearity(self):
mesh_shape = torch.arange(self.world_size).reshape(
self.world_size // 2, self.world_size // 2
)
mesh = DeviceMesh(self.device_type, mesh_shape)
mm_call = aten.mm.default
mat1, mat2 = [0, -1], [-1, -1]
mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
mat2_tensor_meta = self._gen_tensor_meta(torch.Size([4, 8]))
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [1], tensor_meta=mat1_tensor_meta
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, mat2, [], tensor_meta=mat2_tensor_meta
)
# if not turn on linearity, partial sum is not eligible to propagate, we return
# suggestion to reshard inputs with no partial sum (i.e. all_reduce one input)
output_sharding = einop_rule(
"mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
)
self.assertIsNone(output_sharding.output_spec)
suggestions = output_sharding.redistribute_schema
self.assertIsNotNone(suggestions)
suggested_spec = suggestions.args_schema[0]
self.assertFalse(suggested_spec.placements[1].is_partial())
# einop prop with linearity on mm, should give back suggestion
# on converting placements to partial
output_sharding = einop_rule(
"mk,kn->mn",
OpSchema(mm_call, (mat1_spec, mat2_spec), {}),
linearity=True,
)
self.assertIsNone(output_sharding.output_spec)
suggestions = output_sharding.redistribute_schema
self.assertIsNotNone(suggestions)
mat2_spec = suggestions.args_schema[1]
# mat2 mesh dim 1 should become partial now!
self.assertTrue(mat2_spec.placements[1].is_partial())
# einop prop with linearity on point-wise, should give back suggestion
# on converting placements to partial
add_call = aten.add.Tensor
mat1, mat2 = [0, -1], [0, -1]
mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 6]))
mat2_tensor_meta = self._gen_tensor_meta(torch.Size([8, 6]))
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [1], tensor_meta=mat1_tensor_meta
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, mat2, [], tensor_meta=mat2_tensor_meta
)
output_sharding = einop_rule(
"ij,ij->ij",
OpSchema(add_call, (mat1_spec, mat2_spec), {}),
linearity=True,
)
self.assertIsNone(output_sharding.output_spec)
suggestions = output_sharding.redistribute_schema
self.assertIsNotNone(suggestions)
mat2_spec = suggestions.args_schema[1]
# mat2 mesh dim 1 should become partial now!
self.assertTrue(mat2_spec.placements[1].is_partial())
@with_comms
def test_einop_multi_sharding_on_mesh_dim(self):
# einop prop with multi sharding on same mesh dim
mesh_shape = torch.arange(self.world_size)
mesh = DeviceMesh(self.device_type, mesh_shape)
mm_call = aten.mm.default
mat1, mat2 = [0, -1], [0, -1]
mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 12]))
mat2_tensor_meta = self._gen_tensor_meta(torch.Size([12, 4]))
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [], tensor_meta=mat1_tensor_meta
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, mat2, [], tensor_meta=mat2_tensor_meta
)
output_sharding = einop_rule(
"mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
)
output_spec = output_sharding.output_spec
self.assertIsNone(output_spec)
self.assertIsNotNone(output_sharding.redistribute_schema)
# ensure that the suggestion is to reshard the second
# arg by all_gather its tensor dim sharding
schema_suggestion = output_sharding.redistribute_schema
self.assertEqual(schema_suggestion.args_schema[0].dim_map, [0, -1])
self.assertEqual(schema_suggestion.args_schema[1].dim_map, [-1, -1])
@with_comms
def test_einop_errors(self):
mesh_shape = torch.arange(self.world_size).reshape(
self.world_size // 2, self.world_size // 2
)
mesh = DeviceMesh(self.device_type, mesh_shape)
add_call = aten.add.Tensor
mat1, mat2 = [0, -1], [1, -1]
mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
mat2_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [], tensor_meta=mat1_tensor_meta
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, mat2, [], tensor_meta=mat2_tensor_meta
)
with self.assertRaisesRegex(RuntimeError, "sharded two different ways:"):
einop_rule("ij,ij->ij", OpSchema(add_call, (mat1_spec, mat2_spec), {}))
@with_comms
def test_pointwise_rules_broadcasting(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
where_call = aten.where.self
inp1, inp2, inp3 = [0], [], [-1, -1]
inp1_tensor_meta = self._gen_tensor_meta(torch.Size([8]))
inp2_tensor_meta = self._gen_tensor_meta(torch.Size([]))
inp3_tensor_meta = self._gen_tensor_meta(torch.Size([1, 1]))
condition = DTensorSpec.from_dim_map(
mesh, inp1, [], tensor_meta=inp1_tensor_meta
)
self_tensor = DTensorSpec.from_dim_map(
mesh, inp2, [], tensor_meta=inp2_tensor_meta
)
other_tensor = DTensorSpec.from_dim_map(
mesh, inp3, [], tensor_meta=inp3_tensor_meta
)
# propagate point-wise sharding with broadcasting
output_sharding = pointwise_rule(
OpSchema(where_call, (condition, self_tensor, other_tensor), {})
)
output_spec = output_sharding.output_spec
self.assertIsNotNone(output_spec)
self.assertEqual(output_spec.dim_map, [-1, 0])
@with_comms
def test_pointwise_rules_suggestion(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
lerp_call = aten.lerp.Scalar
# propagate point-wise sharding
inp1, inp2 = [-1, -1], [-1, 0]
mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
mat2_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
mat1_spec = DTensorSpec.from_dim_map(
mesh, inp1, [], tensor_meta=mat1_tensor_meta
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, inp2, [], tensor_meta=mat2_tensor_meta
)
# adding a positional argument -1 to arg schema
output_sharding = pointwise_rule(
OpSchema(lerp_call, (mat1_spec, mat2_spec, -1), {})
)
self.assertIsNone(output_sharding.output_spec)
self.assertIsNotNone(output_sharding.redistribute_schema)
# ensure that the suggestion from pointwise rules still have
# the positional args that are not DTensorSpec
schema_suggestion = output_sharding.redistribute_schema
self.assertEqual(len(schema_suggestion.args_schema), 3)
self.assertEqual(schema_suggestion.args_schema[2], -1)
@with_comms
def test_pointwise_multi_sharding_on_mesh_dim(self):
# 2d mesh pointwise sharding
mesh_shape = torch.arange(self.world_size).reshape(
self.world_size // 2, self.world_size // 2
)
mesh = DeviceMesh(self.device_type, mesh_shape)
add_call = aten.add.Tensor
# basic case to test implicit broadcasting shape alignment
mat1, mat2 = [-1, 0], [0]
mat1_tensor_meta = self._gen_tensor_meta(torch.Size([20, 6]))
mat2_tensor_meta = self._gen_tensor_meta(torch.Size([6]))
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [], tensor_meta=mat1_tensor_meta
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, mat2, [], tensor_meta=mat2_tensor_meta
)
output_sharding = pointwise_rule(OpSchema(add_call, (mat1_spec, mat2_spec), {}))
output_spec = output_sharding.output_spec
self.assertIsNotNone(output_spec)
self.assertEqual(output_spec.dim_map, [-1, 0])
# more advanced case that needs reshard one input to align sharding
mat1, mat2 = [0, -1, -1, 1], [0, -1, 1]
mat1_tensor_meta = self._gen_tensor_meta(torch.Size([12, 1, 1, 8]))
mat2_tensor_meta = self._gen_tensor_meta(torch.Size([12, 4, 8]))
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [], tensor_meta=mat1_tensor_meta
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, mat2, [], tensor_meta=mat2_tensor_meta
)
output_sharding = pointwise_rule(OpSchema(add_call, (mat1_spec, mat2_spec), {}))
output_spec = output_sharding.output_spec
self.assertIsNone(output_spec)
self.assertIsNotNone(output_sharding.redistribute_schema)
# ensure that the suggestion is to reshard the first
# arg by all_gather first tensor dim sharding
schema_suggestion = output_sharding.redistribute_schema
self.assertEqual(schema_suggestion.args_schema[0].dim_map, [-1, -1, -1, 1])
self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat2)
@with_comms
def test_pointwise_enforce_sharding_multi_sharding_on_mesh_dim(self):
# 2d mesh pointwise sharding
mesh_shape = torch.arange(self.world_size).reshape(
self.world_size // 2, self.world_size // 2
)
mesh = DeviceMesh(self.device_type, mesh_shape)
add_call = aten.add_.Tensor
# more advanced case that needs reshard one input to align sharding
mat1, mat2 = [0, -1, 1], [-1, -1, 0]
mat1_tensor_meta = self._gen_tensor_meta(torch.Size([12, 4, 8]))
mat2_tensor_meta = self._gen_tensor_meta(torch.Size([12, 1, 8]))
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [], tensor_meta=mat1_tensor_meta
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, mat2, [], tensor_meta=mat2_tensor_meta
)
output_sharding = pointwise_rule(OpSchema(add_call, (mat1_spec, mat2_spec), {}))
output_spec = output_sharding.output_spec
self.assertIsNone(output_spec)
self.assertIsNotNone(output_sharding.redistribute_schema)
# ensure that the suggestion is to reshard the second
# arg as we should enforce the sharding of the first arg
schema_suggestion = output_sharding.redistribute_schema
self.assertEqual(schema_suggestion.args_schema[0].dim_map, mat1)
self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat1)
# Run this module's tests when it is executed directly as a script.
if __name__ == "__main__":
    run_tests()