diff --git a/docs/source/distributed.tensor.parallel.rst b/docs/source/distributed.tensor.parallel.rst
index f9fd13530a2..ee74f157637 100644
--- a/docs/source/distributed.tensor.parallel.rst
+++ b/docs/source/distributed.tensor.parallel.rst
@@ -53,13 +53,11 @@ used for input/output preparation:
 .. autofunction:: make_output_shard_1d
 .. autofunction:: make_output_tensor

-Currently, there are some constraints which makes it hard for the `nn.MultiheadAttention`
-module to work out of box for Tensor Parallelism, so we built this multihead_attention
-module for Tensor Parallelism users. Also, in ``parallelize_module``, we automatically
-swap ``nn.MultiheadAttention`` to this custom module when specifying ``PairwiseParallel``.
+Currently, there are some constraints which make it hard for the ``MultiheadAttention``
+module to work out of the box for Tensor Parallelism, so we recommend that users apply
+``ColwiseParallel`` and ``RowwiseParallel`` to each parameter instead. Some code changes may
+be needed, since we parallelize on the head dimension of the ``MultiheadAttention`` module.

-.. autoclass:: torch.distributed.tensor.parallel.multihead_attention_tp.TensorParallelMultiheadAttention
-   :members:
-
 We also enabled 2D parallelism to integrate with ``FullyShardedDataParallel``.
 Users just need to call the following API explicitly:
diff --git a/test/distributed/tensor/parallel/test_tp_examples.py b/test/distributed/tensor/parallel/test_tp_examples.py
index f5eb18d212c..d2a5e7af2f0 100644
--- a/test/distributed/tensor/parallel/test_tp_examples.py
+++ b/test/distributed/tensor/parallel/test_tp_examples.py
@@ -3,7 +3,6 @@
 import torch
 import torch.distributed as dist
-import torch.nn as nn
 from torch.distributed._tensor import DeviceMesh, DTensor, Replicate
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
@@ -13,7 +12,6 @@ from torch.distributed.tensor.parallel import (
     PairwiseParallel,
     parallelize_module,
     SequenceParallel,
-    TensorParallelMultiheadAttention,
 )
 from torch.distributed.tensor.parallel.input_reshard import input_reshard
 from torch.testing._internal.common_utils import (
@@ -25,22 +23,10 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
     MLPModule,
     NUM_DEVICES,
-    skip_unless_torch_gpu,
     with_comms,
 )


-class MultiheadAttnWrap(nn.Module):
-    def __init__(self, embed_dim, num_heads, add_bias_kv=False, device=None):
-        super().__init__()
-        self.attn = nn.MultiheadAttention(
-            embed_dim, num_heads, add_bias_kv=add_bias_kv, device=device
-        )
-
-    def forward(self, query, key, value):
-        return self.attn(query, key, value)
-
-
 class DistTensorParallelExampleTest(DTensorTestBase):
     def _check_module(self, m1, m2, check_grad=False, rank0_only_params=None):
         rank0_only_params = [] if rank0_only_params is None else rank0_only_params
@@ -127,289 +113,6 @@ class DistTensorParallelExampleTest(DTensorTestBase):
     def test_mlp_megatron_e2e(self, is_seq_parallel, recompute_activation):
         self._test_mlp_magatron_e2e(is_seq_parallel=is_seq_parallel, recompute_activation=recompute_activation)

-    # TensorParallelMultiheadAttention == dist_module(TensorParallelMultiheadAttention)
-    # baddbmm introduces nan occasionally on CPU: https://github.com/pytorch/pytorch/issues/80588
-    @with_comms
-    @skip_unless_torch_gpu
-    def test_self_attn_megatron_e2e(self):
-        inp_size = [8, 12, 16]
-        # Ensure all tp ranks have same input.
-        torch.manual_seed(0)
-        inp = torch.rand(*inp_size, device=self.device_type)
-
-        # Initialize model using same seed.
- torch.manual_seed(5) - model = TensorParallelMultiheadAttention( - 16, - 8, - tp_size=NUM_DEVICES, - add_bias_kv=True, - device=self.device_type, - ) - torch.manual_seed(5) - model_tp = TensorParallelMultiheadAttention( - 16, - 8, - tp_size=NUM_DEVICES, - add_bias_kv=True, - device=self.device_type, - ) - - # Ensure model are initialized the same way. - self.assertEqual(model.qkv.weight, model_tp.qkv.weight) - self.assertEqual(model.qkv.bias, model_tp.qkv.bias) - self.assertEqual(model.proj.weight, model_tp.proj.weight) - self.assertEqual(model.proj.bias, model_tp.proj.bias) - - # Shard module and initialize optimizer. - device_mesh = DeviceMesh(self.device_type, list(range(NUM_DEVICES))) - parallelize_module(model_tp, device_mesh, PairwiseParallel()) - - device_mesh = model_tp.qkv.weight.device_mesh - replicate = [Replicate()] * device_mesh.ndim - # Ensure model are initialized the same way. - self.assertEqual( - model.qkv.weight, - model_tp.qkv.weight.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - self.assertEqual( - model.qkv.bias, - model_tp.qkv.bias.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - self.assertEqual( - model.proj.weight, - model_tp.proj.weight.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - self.assertEqual( - model.proj.bias, - model_tp.proj.bias.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - - LR = 0.25 - optim = torch.optim.SGD(model.parameters(), lr=LR) - optim_tp = torch.optim.SGD(model_tp.parameters(), lr=LR) - - output = model(inp, inp, inp) - output_tp = model_tp(inp, inp, inp) - self.assertEqual(output, output_tp) - - output.sum().backward() - output_tp.sum().backward() - - device_mesh = model_tp.qkv.weight.device_mesh - # Ensure gradients are same. - self.assertEqual( - model.qkv.weight.grad, - model_tp.qkv.weight.grad.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - self.assertEqual( - model.qkv.bias.grad, - model_tp.qkv.bias.grad.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - self.assertEqual( - model.proj.weight.grad, - model_tp.proj.weight.grad.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - self.assertEqual( - model.proj.bias.grad, - model_tp.proj.bias.grad.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - - optim.step() - optim_tp.step() - - # Ensure model weights are still same after update. 
- self.assertEqual( - model.qkv.weight, - model_tp.qkv.weight.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - self.assertEqual( - model.qkv.bias, - model_tp.qkv.bias.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - self.assertEqual( - model.proj.weight, - model_tp.proj.weight.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - self.assertEqual( - model.proj.bias, - model_tp.proj.bias.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - - inp = torch.rand(*inp_size, device=self.device_type) - output = model(inp, inp, inp) - output_tp = model_tp(inp, inp, inp) - self.assertEqual(output, output_tp) - - # TensorParallelMultiheadAttention == dist_module(torch.nn.MultiheadAttention) - # baddbmm introduces nan occasionally on CPU: https://github.com/pytorch/pytorch/issues/80588 - @with_comms - @skip_unless_torch_gpu - def test_self_attn_replacement_megatron_e2e(self): - inp_size = [8, 12, 16] - # Ensure all tp ranks have same input. - torch.manual_seed(0) - inp = torch.rand(*inp_size, device=self.device_type) - - # TODO: our sharding function cannot shard the root node - torch.manual_seed(5) - model = TensorParallelMultiheadAttention( - 16, - 8, - tp_size=NUM_DEVICES, - add_bias_kv=True, - device=self.device_type, - ) - model_tp = MultiheadAttnWrap(16, 8, add_bias_kv=True, device=self.device_type) - - # TODO: somehow using torch.nn.MultiheadAttention's initial params does not work - # Use TensorParallelMultiheadAttention parameters instead - x = model.qkv.weight.clone().detach().requires_grad_() - model_tp.attn.register_parameter("in_proj_weight", torch.nn.Parameter(x)) - - x = model.qkv.bias.clone().detach().requires_grad_() - model_tp.attn.register_parameter("in_proj_bias", torch.nn.Parameter(x)) - - x = model.proj.weight.clone().detach().requires_grad_() - model_tp.attn.out_proj.register_parameter("weight", torch.nn.Parameter(x)) - - x = model.proj.bias.clone().detach().requires_grad_() - model_tp.attn.out_proj.register_parameter("bias", torch.nn.Parameter(x)) - - # check if parameters are same - self.assertEqual(model.qkv.weight, model_tp.attn.in_proj_weight) - self.assertEqual(model.qkv.bias, model_tp.attn.in_proj_bias) - self.assertEqual(model.proj.weight, model_tp.attn.out_proj.weight) - self.assertEqual(model.proj.bias, model_tp.attn.out_proj.bias) - - # Shard module and initialize optimizer. - device_mesh = DeviceMesh(self.device_type, list(range(NUM_DEVICES))) - parallelize_module(model_tp, device_mesh, PairwiseParallel()) - - device_mesh = model_tp.attn.qkv.weight.device_mesh - replicate = [Replicate()] * device_mesh.ndim - # Ensure model are initialized the same way. 
- self.assertEqual( - model.qkv.weight, - model_tp.attn.qkv.weight.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - self.assertEqual( - model.qkv.bias, - model_tp.attn.qkv.bias.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - self.assertEqual( - model.proj.weight, - model_tp.attn.proj.weight.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - self.assertEqual( - model.proj.bias, - model_tp.attn.proj.bias.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - - LR = 0.25 - optim = torch.optim.SGD(model.parameters(), lr=LR) - optim_tp = torch.optim.SGD(model_tp.parameters(), lr=LR) - - output = model(inp, inp, inp) - output_tp = model_tp(inp, inp, inp) - self.assertEqual(output, output_tp) - - output.sum().backward() - output_tp.sum().backward() - - device_mesh = model_tp.attn.qkv.weight.device_mesh - # Ensure gradients are same. - self.assertEqual( - model.qkv.weight.grad, - model_tp.attn.qkv.weight.grad.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - self.assertEqual( - model.qkv.bias.grad, - model_tp.attn.qkv.bias.grad.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - self.assertEqual( - model.proj.weight.grad, - model_tp.attn.proj.weight.grad.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - self.assertEqual( - model.proj.bias.grad, - model_tp.attn.proj.bias.grad.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - - optim.step() - optim_tp.step() - - # Ensure model weights are still same after update. - self.assertEqual( - model.qkv.weight, - model_tp.attn.qkv.weight.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - self.assertEqual( - model.qkv.bias, - model_tp.attn.qkv.bias.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - self.assertEqual( - model.proj.weight, - model_tp.attn.proj.weight.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - self.assertEqual( - model.proj.bias, - model_tp.attn.proj.bias.redistribute( - device_mesh=device_mesh, placements=replicate - ).to_local(), - ) - - inp = torch.rand(*inp_size, device=self.device_type) - output = model(inp, inp, inp) - output_tp = model_tp(inp, inp, inp) - self.assertEqual(output, output_tp) - instantiate_parametrized_tests(DistTensorParallelExampleTest) diff --git a/torch/distributed/tensor/parallel/__init__.py b/torch/distributed/tensor/parallel/__init__.py index caf2853c019..81167ddc49f 100644 --- a/torch/distributed/tensor/parallel/__init__.py +++ b/torch/distributed/tensor/parallel/__init__.py @@ -1,8 +1,5 @@ # Copyright (c) Meta Platforms, Inc. 
and affiliates from torch.distributed.tensor.parallel.api import parallelize_module -from torch.distributed.tensor.parallel.multihead_attention_tp import ( - TensorParallelMultiheadAttention, -) from torch.distributed.tensor.parallel.style import ( ColwiseParallel, @@ -27,7 +24,6 @@ __all__ = [ "ParallelStyle", "RowwiseParallel", "SequenceParallel", - "TensorParallelMultiheadAttention", "make_input_replicate_1d", "make_input_reshard_replicate", "make_input_shard_1d", diff --git a/torch/distributed/tensor/parallel/api.py b/torch/distributed/tensor/parallel/api.py index 4cb2867727e..06e936cc7f7 100644 --- a/torch/distributed/tensor/parallel/api.py +++ b/torch/distributed/tensor/parallel/api.py @@ -18,9 +18,6 @@ from torch.distributed._tensor.random import ( ) from torch.distributed._tensor.sharding_prop import _CachingPropagator from torch.distributed.tensor.parallel._utils import _create_1d_device_mesh -from torch.distributed.tensor.parallel.multihead_attention_tp import ( - TensorParallelMultiheadAttention, -) from torch.distributed.tensor.parallel.style import ( ColwiseParallel, PairwiseParallel, @@ -108,9 +105,7 @@ def parallelize_module( # type: ignore[return] if isinstance(parallelize_plan, (ColwiseParallel, RowwiseParallel)): return _parallelize_linear(module, device_mesh, parallelize_plan) # PairwiseParallel - if _is_mha_for_pairwise_parallel(module): - return _parallelize_multihead_attn(module, device_mesh) - elif _is_mlp_for_pairwise_parallel(module): + if _is_mlp_for_pairwise_parallel(module): return _parallelize_mlp(module, device_mesh, parallelize_plan) else: for n, m in module.named_children(): @@ -140,20 +135,6 @@ def parallelize_module( # type: ignore[return] ) -def _is_mha_for_pairwise_parallel(module: nn.Module) -> bool: - """ - Check whether the mha module is the one can be handled for Pairwise parallel. - - Args: - module (:class:`nn.Module`): - Module to be checked. - - Return: - A boolean object which specifies whether the module is MHA supported by Pairwise parallel or not. - """ - return isinstance(module, (TensorParallelMultiheadAttention, nn.MultiheadAttention)) - - def _is_mlp_for_pairwise_parallel(module: nn.Module) -> bool: """ Traverse through all the immediate children of the given module and count the @@ -304,80 +285,6 @@ def _parallelize_linear( return module -def _parallelize_multihead_attn( - module: nn.Module, - device_mesh: DeviceMesh, - parallel_style: ParallelStyle = PairwiseParallel(), - tp_mesh_dim: int = 0, -) -> nn.Module: - """ - This function assumes the input module is a sequence of nn.Linear - and we parallelize the module based on the given parallel style. - We don't change the FQN of each sub-module and replace each parameter - in place. - - Args: - module (:class:`nn.Module`): - Module to be parallelized. - device_mesh (:class:`DeviceMesh`): - Object which describes the mesh topology of devices. - parallel_style (:class:`ParallelStyle`): - Object which contains how we prepare input/output - for Tensor Parallelism. - tp_mesh_dim (int): - The dimension of `device_mesh` where we perform - Tensor Parallelism on. - - Return: - A :class:`nn.Module` object parallelized. - - .. warning:: - We only support ``PairwiseParallel`` right now. - """ - - if not isinstance(parallel_style, PairwiseParallel): - raise NotImplementedError( - "Only support PairwiseParallel for Multihead Attention parallelization." 
- ) - - if device_mesh.ndim > 1: - device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim) - - if isinstance(module, nn.MultiheadAttention): - tp_multi_head_attention = TensorParallelMultiheadAttention( - module.embed_dim, - module.num_heads, - device=torch.device(device_mesh.device_type), - tp_size=device_mesh.size(tp_mesh_dim), - add_bias_kv=module.bias_k is not None, - ) - tp_multi_head_attention.copy(module) - module = tp_multi_head_attention - - assert isinstance(module, TensorParallelMultiheadAttention), ( - f"Expects TensorParallelMultiheadAttention but got {type(module)}" - ) - # shard TPMA - for n, m in module.named_children(): - if n == "qkv": - # Col-wise Parallelize the qkv layer. - distribute_module( - m, - device_mesh, - _colwise_parallelize_linear_fn, - input_fn=parallel_style._prepare_input, # type: ignore[arg-type, misc] # pyre-ignore[6] - ) - elif n == "proj": - # Row-wise Parallelize the proj layer - distribute_module( - m, - device_mesh, - _rowwise_parallelize_linear_fn, - output_fn=parallel_style._prepare_output, # type: ignore[arg-type, misc] # pyre-ignore[6] - ) - return module - - def _parallelize_mlp( module: nn.Module, device_mesh: DeviceMesh, diff --git a/torch/distributed/tensor/parallel/multihead_attention_tp.py b/torch/distributed/tensor/parallel/multihead_attention_tp.py deleted file mode 100644 index 6f64f892300..00000000000 --- a/torch/distributed/tensor/parallel/multihead_attention_tp.py +++ /dev/null @@ -1,273 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# pyre-ignore-all-errors[6] - -import math - -from typing import Optional, Union - -import torch -from torch.distributed._tensor import DTensor as DT -from torch.distributed._tensor.placement_types import Shard -from torch.distributed.tensor.parallel._view_with_dim_change import ( - _view_with_sharding_dim_change, -) - -__all__ = ["TensorParallelMultiheadAttention"] - - -# TODO: Add a test to test equivalence between our Multihead Attention -# with other mainstream ones (Megatron-LM or PyTorch). -def _stride_same_as_shard( - tensor: torch.Tensor, tp_size: int, chunk_dim: int, cat_dim: int -) -> torch.Tensor: - """ - Adjust local tensor's stride same as the sharded situation. - So that view result will keeps the same. - """ - if isinstance(tensor, DT): - return tensor - view_size = list(tensor.size()) - view_size[chunk_dim] //= tp_size - return torch.cat( - [t.view(*view_size) for t in tensor.chunk(tp_size, dim=chunk_dim)], - dim=cat_dim, - ).contiguous() - - -class TensorParallelMultiheadAttention(torch.nn.Module): - """ - Multi-head Attention block from Transformer models. - Since we need some customizations for the attention layer, - we are writing a customized but mathematically equivalent - attention module as defined in torch.nn. - - Note that: - We now only support the case when it's self attention with - limited input args and we also assume that the input tensor - has a dimension of three. Although we do implement the logic - for multihead attention, it was not fully tested. 
- """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - bias: bool = True, - add_bias_kv: bool = False, - add_zero_attn: bool = False, - kdim: Optional[int] = None, - vdim: Optional[int] = None, - batch_first: bool = False, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - tp_size: int = 1, - self_attention: bool = True, - ) -> None: - super().__init__() - self.device: torch.device = ( - torch.device("cuda" if torch.cuda.is_available() else "cpu") - if device is None - else device - ) - self.num_heads = num_heads - self.hidden_size = embed_dim - self.hidden_size_per_attention_head: int = self.hidden_size // num_heads - self.scale: float = self.hidden_size_per_attention_head**-0.5 - if self_attention: - self.qkv: torch.nn.Module = torch.nn.Linear( - embed_dim, embed_dim * 3, bias=add_bias_kv, device=self.device - ) - torch.nn.init.xavier_uniform_(self.qkv.weight) - if add_bias_kv: - torch.nn.init.zeros_(self.qkv.bias) - else: - self.query: torch.nn.Module = torch.nn.Linear( - embed_dim, embed_dim, bias=add_bias_kv, device=self.device - ) - self.key: torch.nn.Module = torch.nn.Linear( - embed_dim, embed_dim, bias=add_bias_kv, device=self.device - ) - self.value: torch.nn.Module = torch.nn.Linear( - embed_dim, embed_dim, bias=add_bias_kv, device=self.device - ) - torch.nn.init.xavier_uniform_(self.query.weight) - torch.nn.init.xavier_uniform_(self.key.weight) - torch.nn.init.xavier_uniform_(self.value.weight) - if add_bias_kv: - torch.nn.init.zeros_(self.query.bias) - torch.nn.init.zeros_(self.key.bias) - torch.nn.init.zeros_(self.value.bias) - self.proj: torch.nn.Module = torch.nn.Linear( - embed_dim, embed_dim, bias=bias, device=self.device - ) - torch.nn.init.kaiming_uniform_(self.proj.weight, a=math.sqrt(5)) - if bias: - torch.nn.init.zeros_(self.proj.bias) - self.tp_size = tp_size - self.hidden_size = embed_dim - self.norm_factor: float = math.sqrt(self.hidden_size_per_attention_head) - self.self_attention = self_attention - - def forward( - self, - query: Union[torch.Tensor, DT], - key: Union[torch.Tensor, DT], - value: Union[torch.Tensor, DT], - key_padding_mask: Optional[torch.Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[torch.Tensor] = None, - average_attn_weights: bool = True, - ) -> Union[torch.Tensor, DT]: - b, sq, h = query.shape - sk = key.size(1) - nh = self.num_heads - hn = self.hidden_size_per_attention_head - - # x: [b, sq/sk/sv, h] - # =================== - # Permute. [sq/sk/sv, b, h] - # =================== - if not self.self_attention: - # ===================== - # Query, Key, and Value - # ===================== - query = query.permute(1, 0, 2).contiguous() - key = key.permute(1, 0, 2).contiguous() - value = value.permute(1, 0, 2).contiguous() - - # Attention heads [sq/sk/sv, b, h] --> [sq/sk/sv * b, (nh * hn)] - query = query.view(-1, h) - key = key.view(-1, h) - value = value.view(-1, h) - - query_layer = _view_with_sharding_dim_change( - self.query(query), 1, (sq, b * nh, hn) - ) - key_layer = _view_with_sharding_dim_change( - self.key(key), 1, (sk, b * nh, hn) - ) - value_layer = _view_with_sharding_dim_change( - self.value(value), 1, (sk, b * nh, hn) - ) - else: - assert torch.equal(query, key) and torch.equal( - query, value - ), "inputs are different for self-attention." 
- # ===================== - # Query - # ===================== - query = query.permute(1, 0, 2).contiguous() - - # Attention heads [sq, b, h] --> [sq * b, (nh * 3 * hn)] - query = query.view(-1, h) - mixed_x_layer = self.qkv(query) - - # [sq * b, 3 * h] --> [sq, b, nh, 3 * hn] - mixed_x_layer = _view_with_sharding_dim_change( - mixed_x_layer, 2, (sq, b, nh, 3 * hn) - ) - - # [sq, b, nh, 3 * hn] --> 3 [sq, b, nh, hn] - last_dim = mixed_x_layer.dim() - 1 - last_dim_size = mixed_x_layer.size(last_dim) // 3 - (query_layer, key_layer, value_layer) = mixed_x_layer.split( - last_dim_size, dim=last_dim - ) - - query_layer = _stride_same_as_shard(query_layer, self.tp_size, 2, 1) - key_layer = _stride_same_as_shard(key_layer, self.tp_size, 2, 1) - value_layer = _stride_same_as_shard(value_layer, self.tp_size, 2, 1) - # [sq, b, nh, hn] -> [sq, b * nh, hn] - query_layer = _view_with_sharding_dim_change( - query_layer, 1, (sq, b * nh, -1) - ) - key_layer = _view_with_sharding_dim_change(key_layer, 1, (sq, b * nh, -1)) - value_layer = _view_with_sharding_dim_change( - value_layer, 1, (sq, b * nh, -1) - ) - - # =================================== - # Raw attention scores. [b, nh, s, s] - # =================================== - - factor = self.tp_size if isinstance(query_layer, DT) else 1 - # preallocating result tensor: [b * nh, sq, sk] - matmul_result = torch.empty( - b * nh // factor, - sq, - sk, - dtype=query_layer.dtype, - device=self.device, - ) - if isinstance(query_layer, DT): - matmul_result = DT.from_local( - matmul_result, - query_layer.device_mesh, - [Shard(0)], - run_check=False, - ) - - # Raw attention scores. [b * nh, sq, sk] - attn = torch.baddbmm( - matmul_result, - query_layer.transpose(0, 1), # [b * nh, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * nh, hn, sk] - beta=0.0, - alpha=(1.0 / self.norm_factor), - ) - - # =============== - # Attention probs - # =============== - attn = attn.softmax(dim=-1) - - # ========================= - # Context layer. [sq * b, hidden] - # ========================= - - # bmm: [b * nh, sq, hn] - context_layer = torch.bmm(attn, value_layer.transpose(0, 1)) - - # change view [nh, b, sq, hn] - context_layer = context_layer.view(nh, b, sq, hn) - - # [nh, b, sq, hn] --> [sq, b, nh, hn] - context_layer = context_layer.permute(2, 1, 0, 3).contiguous() - - # [sq, b, nh, hn] --> [sq * b, hidden] - context_layer = _view_with_sharding_dim_change( - context_layer.contiguous(), 1, (-1, self.hidden_size) - ) - - # ================= - # Projection. [sq, b, h] - # ================= - output = self.proj(context_layer).view(sq, b, h) - - # =================== - # Permute. [b, sq, h] - # =================== - output = output.permute(1, 0, 2) - - return output - - def copy(self, that: torch.nn.MultiheadAttention) -> None: - # TODO: current implementation assume `self` is a self attention module - assert ( - self.hidden_size == that.embed_dim - ), "embed_dim must be equal in TensorParallelMultiheadAttention.copy()!" 
- - if that.in_proj_weight is not None: - self.qkv.register_parameter("weight", that.in_proj_weight) - if that.in_proj_bias is not None: - self.qkv.register_parameter("bias", that.in_proj_bias) - if that.out_proj.weight is not None: - # TODO: The use of Parameter is to avoid `mypy` issue caused - # by the `tensor` type annotation on Linear.weight to which - # a Parameter object is actually assigned - self.proj.register_parameter( - "weight", torch.nn.Parameter(that.out_proj.weight) - ) - if that.out_proj.bias is not None: - self.proj.register_parameter("bias", that.out_proj.bias)
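
The documentation change above recommends applying ``ColwiseParallel`` and ``RowwiseParallel`` to each projection of an attention block now that ``TensorParallelMultiheadAttention`` is removed. Below is a minimal sketch of that recipe, assuming a hand-written self-attention module whose ``q_proj``/``k_proj``/``v_proj``/``out_proj`` names are hypothetical (not part of this change) and assuming ``parallelize_module`` accepts a dict-style parallelize plan keyed by submodule name. Since parallelization happens on the head dimension, the reshapes use ``-1`` for the local head count; the default input/output layouts of the two styles may still require adjustment, which is the kind of code change the new documentation refers to.

# Hypothetical sketch (not part of this patch): shard a hand-written self-attention
# block with ColwiseParallel / RowwiseParallel, as the updated docs suggest.
# Assumes the default process group has already been initialized (e.g. via torchrun).
import torch
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class SimpleSelfAttention(nn.Module):
    """Self-attention with separate projections so each nn.Linear can be sharded."""

    def __init__(self, embed_dim: int, num_heads: int) -> None:
        super().__init__()
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, _ = x.shape
        # Use -1 for the head count so the reshape works whether a rank holds
        # all heads or only its shard after column-wise parallelization.
        q = self.q_proj(x).view(b, s, -1, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(b, s, -1, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(b, s, -1, self.head_dim).transpose(1, 2)
        scores = torch.softmax(q @ k.transpose(-2, -1) / self.head_dim**0.5, dim=-1)
        out = (scores @ v).transpose(1, 2).reshape(b, s, -1)
        return self.out_proj(out)


def parallelize_attention(model: SimpleSelfAttention, mesh: DeviceMesh) -> nn.Module:
    # Column-shard q/k/v (splitting the weight across heads) and row-shard the
    # output projection so its reduction restores the full hidden dimension.
    plan = {
        "q_proj": ColwiseParallel(),
        "k_proj": ColwiseParallel(),
        "v_proj": ColwiseParallel(),
        "out_proj": RowwiseParallel(),
    }
    return parallelize_module(model, mesh, plan)


# Example wiring (num_heads should be divisible by the tensor-parallel world size):
#   mesh = DeviceMesh("cuda", list(range(torch.distributed.get_world_size())))
#   attn = parallelize_attention(SimpleSelfAttention(16, 8).cuda(), mesh)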