Clean up unused MHA code to avoid confusion (#105956)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105956 Approved by: https://github.com/wz337, https://github.com/ezyang, https://github.com/wanchaol
This commit is contained in:
parent
1a59be2c9e
commit
487ebcac3b
5 changed files with 5 additions and 674 deletions
@@ -53,13 +53,11 @@ used for input/output preparation:
.. autofunction:: make_output_shard_1d
.. autofunction:: make_output_tensor

Currently, there are some constraints which make it hard for the `nn.MultiheadAttention`
module to work out of the box for Tensor Parallelism, so we built this multihead_attention
module for Tensor Parallelism users. Also, in ``parallelize_module``, we automatically
swap ``nn.MultiheadAttention`` to this custom module when specifying ``PairwiseParallel``.
Currently, there are some constraints which make it hard for the ``MultiheadAttention``
module to work out of the box for Tensor Parallelism, so we recommend that users try ``ColwiseParallel``
and ``RowwiseParallel`` for each parameter. There might be some code changes needed now
since we are parallelizing on the head dim of the ``MultiheadAttention`` module.
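
For illustration, a minimal sketch of such a per-parameter plan; the toy attention block,
the submodule names, and the mesh setup below are assumptions for the example only::

    import torch
    import torch.nn as nn
    from torch.distributed._tensor import DeviceMesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    class ToyAttention(nn.Module):
        # Hypothetical attention block: a fused qkv projection plus an output projection.
        def __init__(self, embed_dim: int) -> None:
            super().__init__()
            self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
            self.proj = nn.Linear(embed_dim, embed_dim)

    # Hypothetical 1-D mesh over all visible GPUs; assumes the default
    # process group has already been initialized.
    mesh = DeviceMesh("cuda", list(range(torch.cuda.device_count())))

    # Shard the fused input projection column-wise and the output projection
    # row-wise, which splits the attention computation along the head dim.
    block = parallelize_module(
        ToyAttention(16),
        mesh,
        {"qkv": ColwiseParallel(), "proj": RowwiseParallel()},
    )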

.. autoclass:: torch.distributed.tensor.parallel.multihead_attention_tp.TensorParallelMultiheadAttention
    :members:

We also enabled 2D parallelism to integrate with ``FullyShardedDataParallel``.
Users just need to call the following API explicitly:

@@ -3,7 +3,6 @@

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,

@@ -13,7 +12,6 @@ from torch.distributed.tensor.parallel import (
    PairwiseParallel,
    parallelize_module,
    SequenceParallel,
    TensorParallelMultiheadAttention,
)
from torch.distributed.tensor.parallel.input_reshard import input_reshard
from torch.testing._internal.common_utils import (

@@ -25,22 +23,10 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    NUM_DEVICES,
    skip_unless_torch_gpu,
    with_comms,
)


class MultiheadAttnWrap(nn.Module):
    def __init__(self, embed_dim, num_heads, add_bias_kv=False, device=None):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads, add_bias_kv=add_bias_kv, device=device
        )

    def forward(self, query, key, value):
        return self.attn(query, key, value)


class DistTensorParallelExampleTest(DTensorTestBase):
    def _check_module(self, m1, m2, check_grad=False, rank0_only_params=None):
        rank0_only_params = [] if rank0_only_params is None else rank0_only_params

@@ -127,289 +113,6 @@ class DistTensorParallelExampleTest(DTensorTestBase):
    def test_mlp_megatron_e2e(self, is_seq_parallel, recompute_activation):
        self._test_mlp_magatron_e2e(is_seq_parallel=is_seq_parallel, recompute_activation=recompute_activation)

    # TensorParallelMultiheadAttention == dist_module(TensorParallelMultiheadAttention)
    # baddbmm introduces nan occasionally on CPU: https://github.com/pytorch/pytorch/issues/80588
    @with_comms
    @skip_unless_torch_gpu
    def test_self_attn_megatron_e2e(self):
        inp_size = [8, 12, 16]
        # Ensure all tp ranks have the same input.
        torch.manual_seed(0)
        inp = torch.rand(*inp_size, device=self.device_type)

        # Initialize both models using the same seed.
        torch.manual_seed(5)
        model = TensorParallelMultiheadAttention(
            16,
            8,
            tp_size=NUM_DEVICES,
            add_bias_kv=True,
            device=self.device_type,
        )
        torch.manual_seed(5)
        model_tp = TensorParallelMultiheadAttention(
            16,
            8,
            tp_size=NUM_DEVICES,
            add_bias_kv=True,
            device=self.device_type,
        )

        # Ensure the models are initialized the same way.
        self.assertEqual(model.qkv.weight, model_tp.qkv.weight)
        self.assertEqual(model.qkv.bias, model_tp.qkv.bias)
        self.assertEqual(model.proj.weight, model_tp.proj.weight)
        self.assertEqual(model.proj.bias, model_tp.proj.bias)

        # Shard module and initialize optimizer.
        device_mesh = DeviceMesh(self.device_type, list(range(NUM_DEVICES)))
        parallelize_module(model_tp, device_mesh, PairwiseParallel())

        device_mesh = model_tp.qkv.weight.device_mesh
        replicate = [Replicate()] * device_mesh.ndim
        # Ensure the models are still initialized the same way after sharding.
        self.assertEqual(
            model.qkv.weight,
            model_tp.qkv.weight.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.qkv.bias,
            model_tp.qkv.bias.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.weight,
            model_tp.proj.weight.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.bias,
            model_tp.proj.bias.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )

        LR = 0.25
        optim = torch.optim.SGD(model.parameters(), lr=LR)
        optim_tp = torch.optim.SGD(model_tp.parameters(), lr=LR)

        output = model(inp, inp, inp)
        output_tp = model_tp(inp, inp, inp)
        self.assertEqual(output, output_tp)

        output.sum().backward()
        output_tp.sum().backward()

        device_mesh = model_tp.qkv.weight.device_mesh
        # Ensure gradients are the same.
        self.assertEqual(
            model.qkv.weight.grad,
            model_tp.qkv.weight.grad.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.qkv.bias.grad,
            model_tp.qkv.bias.grad.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.weight.grad,
            model_tp.proj.weight.grad.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.bias.grad,
            model_tp.proj.bias.grad.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )

        optim.step()
        optim_tp.step()

        # Ensure model weights are still the same after the update.
        self.assertEqual(
            model.qkv.weight,
            model_tp.qkv.weight.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.qkv.bias,
            model_tp.qkv.bias.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.weight,
            model_tp.proj.weight.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.bias,
            model_tp.proj.bias.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )

        inp = torch.rand(*inp_size, device=self.device_type)
        output = model(inp, inp, inp)
        output_tp = model_tp(inp, inp, inp)
        self.assertEqual(output, output_tp)

    # TensorParallelMultiheadAttention == dist_module(torch.nn.MultiheadAttention)
    # baddbmm introduces nan occasionally on CPU: https://github.com/pytorch/pytorch/issues/80588
    @with_comms
    @skip_unless_torch_gpu
    def test_self_attn_replacement_megatron_e2e(self):
        inp_size = [8, 12, 16]
        # Ensure all tp ranks have the same input.
        torch.manual_seed(0)
        inp = torch.rand(*inp_size, device=self.device_type)

        # TODO: our sharding function cannot shard the root node
        torch.manual_seed(5)
        model = TensorParallelMultiheadAttention(
            16,
            8,
            tp_size=NUM_DEVICES,
            add_bias_kv=True,
            device=self.device_type,
        )
        model_tp = MultiheadAttnWrap(16, 8, add_bias_kv=True, device=self.device_type)

        # TODO: somehow using torch.nn.MultiheadAttention's initial params does not work
        # Use TensorParallelMultiheadAttention parameters instead
        x = model.qkv.weight.clone().detach().requires_grad_()
        model_tp.attn.register_parameter("in_proj_weight", torch.nn.Parameter(x))

        x = model.qkv.bias.clone().detach().requires_grad_()
        model_tp.attn.register_parameter("in_proj_bias", torch.nn.Parameter(x))

        x = model.proj.weight.clone().detach().requires_grad_()
        model_tp.attn.out_proj.register_parameter("weight", torch.nn.Parameter(x))

        x = model.proj.bias.clone().detach().requires_grad_()
        model_tp.attn.out_proj.register_parameter("bias", torch.nn.Parameter(x))

        # Check that the parameters are the same.
        self.assertEqual(model.qkv.weight, model_tp.attn.in_proj_weight)
        self.assertEqual(model.qkv.bias, model_tp.attn.in_proj_bias)
        self.assertEqual(model.proj.weight, model_tp.attn.out_proj.weight)
        self.assertEqual(model.proj.bias, model_tp.attn.out_proj.bias)

        # Shard module and initialize optimizer.
        device_mesh = DeviceMesh(self.device_type, list(range(NUM_DEVICES)))
        parallelize_module(model_tp, device_mesh, PairwiseParallel())

        device_mesh = model_tp.attn.qkv.weight.device_mesh
        replicate = [Replicate()] * device_mesh.ndim
        # Ensure the models are still initialized the same way after sharding.
        self.assertEqual(
            model.qkv.weight,
            model_tp.attn.qkv.weight.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.qkv.bias,
            model_tp.attn.qkv.bias.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.weight,
            model_tp.attn.proj.weight.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.bias,
            model_tp.attn.proj.bias.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )

        LR = 0.25
        optim = torch.optim.SGD(model.parameters(), lr=LR)
        optim_tp = torch.optim.SGD(model_tp.parameters(), lr=LR)

        output = model(inp, inp, inp)
        output_tp = model_tp(inp, inp, inp)
        self.assertEqual(output, output_tp)

        output.sum().backward()
        output_tp.sum().backward()

        device_mesh = model_tp.attn.qkv.weight.device_mesh
        # Ensure gradients are the same.
        self.assertEqual(
            model.qkv.weight.grad,
            model_tp.attn.qkv.weight.grad.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.qkv.bias.grad,
            model_tp.attn.qkv.bias.grad.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.weight.grad,
            model_tp.attn.proj.weight.grad.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.bias.grad,
            model_tp.attn.proj.bias.grad.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )

        optim.step()
        optim_tp.step()

        # Ensure model weights are still the same after the update.
        self.assertEqual(
            model.qkv.weight,
            model_tp.attn.qkv.weight.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.qkv.bias,
            model_tp.attn.qkv.bias.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.weight,
            model_tp.attn.proj.weight.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.bias,
            model_tp.attn.proj.bias.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )

        inp = torch.rand(*inp_size, device=self.device_type)
        output = model(inp, inp, inp)
        output_tp = model_tp(inp, inp, inp)
        self.assertEqual(output, output_tp)


instantiate_parametrized_tests(DistTensorParallelExampleTest)

@@ -1,8 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from torch.distributed.tensor.parallel.api import parallelize_module
from torch.distributed.tensor.parallel.multihead_attention_tp import (
    TensorParallelMultiheadAttention,
)

from torch.distributed.tensor.parallel.style import (
    ColwiseParallel,

@@ -27,7 +24,6 @@ __all__ = [
    "ParallelStyle",
    "RowwiseParallel",
    "SequenceParallel",
    "TensorParallelMultiheadAttention",
    "make_input_replicate_1d",
    "make_input_reshard_replicate",
    "make_input_shard_1d",

@@ -18,9 +18,6 @@ from torch.distributed._tensor.random import (
)
from torch.distributed._tensor.sharding_prop import _CachingPropagator
from torch.distributed.tensor.parallel._utils import _create_1d_device_mesh
from torch.distributed.tensor.parallel.multihead_attention_tp import (
    TensorParallelMultiheadAttention,
)
from torch.distributed.tensor.parallel.style import (
    ColwiseParallel,
    PairwiseParallel,

@@ -108,9 +105,7 @@ def parallelize_module(  # type: ignore[return]
    if isinstance(parallelize_plan, (ColwiseParallel, RowwiseParallel)):
        return _parallelize_linear(module, device_mesh, parallelize_plan)
    # PairwiseParallel
    if _is_mha_for_pairwise_parallel(module):
        return _parallelize_multihead_attn(module, device_mesh)
    elif _is_mlp_for_pairwise_parallel(module):
    if _is_mlp_for_pairwise_parallel(module):
        return _parallelize_mlp(module, device_mesh, parallelize_plan)
    else:
        for n, m in module.named_children():

@@ -140,20 +135,6 @@ def parallelize_module(  # type: ignore[return]
        )


def _is_mha_for_pairwise_parallel(module: nn.Module) -> bool:
    """
    Check whether the MHA module is one that can be handled by Pairwise parallelism.

    Args:
        module (:class:`nn.Module`):
            Module to be checked.

    Return:
        A boolean indicating whether the module is an MHA module supported by Pairwise parallelism.
    """
    return isinstance(module, (TensorParallelMultiheadAttention, nn.MultiheadAttention))


def _is_mlp_for_pairwise_parallel(module: nn.Module) -> bool:
    """
    Traverse through all the immediate children of the given module and count the

@@ -304,80 +285,6 @@ def _parallelize_linear(
    return module


def _parallelize_multihead_attn(
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallel_style: ParallelStyle = PairwiseParallel(),
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    This function assumes the input module is a sequence of nn.Linear
    and we parallelize the module based on the given parallel style.
    We don't change the FQN of each sub-module and replace each parameter
    in place.

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.
        parallel_style (:class:`ParallelStyle`):
            Object which contains how we prepare input/output
            for Tensor Parallelism.
        tp_mesh_dim (int):
            The dimension of ``device_mesh`` on which we perform
            Tensor Parallelism.

    Return:
        A parallelized :class:`nn.Module` object.

    .. warning::
        We only support ``PairwiseParallel`` right now.
    """

    if not isinstance(parallel_style, PairwiseParallel):
        raise NotImplementedError(
            "Only support PairwiseParallel for Multihead Attention parallelization."
        )

    if device_mesh.ndim > 1:
        device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)

    if isinstance(module, nn.MultiheadAttention):
        tp_multi_head_attention = TensorParallelMultiheadAttention(
            module.embed_dim,
            module.num_heads,
            device=torch.device(device_mesh.device_type),
            tp_size=device_mesh.size(tp_mesh_dim),
            add_bias_kv=module.bias_k is not None,
        )
        tp_multi_head_attention.copy(module)
        module = tp_multi_head_attention

    assert isinstance(module, TensorParallelMultiheadAttention), (
        f"Expects TensorParallelMultiheadAttention but got {type(module)}"
    )
    # Shard the TensorParallelMultiheadAttention module.
    for n, m in module.named_children():
        if n == "qkv":
            # Col-wise parallelize the qkv layer.
            distribute_module(
                m,
                device_mesh,
                _colwise_parallelize_linear_fn,
                input_fn=parallel_style._prepare_input,  # type: ignore[arg-type, misc] # pyre-ignore[6]
            )
        elif n == "proj":
            # Row-wise parallelize the proj layer.
            distribute_module(
                m,
                device_mesh,
                _rowwise_parallelize_linear_fn,
                output_fn=parallel_style._prepare_output,  # type: ignore[arg-type, misc] # pyre-ignore[6]
            )
    return module


def _parallelize_mlp(
    module: nn.Module,
    device_mesh: DeviceMesh,

@@ -1,273 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# pyre-ignore-all-errors[6]

import math

from typing import Optional, Union

import torch
from torch.distributed._tensor import DTensor as DT
from torch.distributed._tensor.placement_types import Shard
from torch.distributed.tensor.parallel._view_with_dim_change import (
    _view_with_sharding_dim_change,
)

__all__ = ["TensorParallelMultiheadAttention"]


# TODO: Add a test of equivalence between our Multihead Attention
# and other mainstream ones (Megatron-LM or PyTorch).
def _stride_same_as_shard(
    tensor: torch.Tensor, tp_size: int, chunk_dim: int, cat_dim: int
) -> torch.Tensor:
    """
    Adjust the local tensor's stride to match the sharded case,
    so that the view result stays the same.
    """
    if isinstance(tensor, DT):
        return tensor
    view_size = list(tensor.size())
    view_size[chunk_dim] //= tp_size
    return torch.cat(
        [t.view(*view_size) for t in tensor.chunk(tp_size, dim=chunk_dim)],
        dim=cat_dim,
    ).contiguous()


class TensorParallelMultiheadAttention(torch.nn.Module):
    """
    Multi-head Attention block from Transformer models.
    Since we need some customizations for the attention layer,
    we are writing a customized but mathematically equivalent
    attention module as defined in torch.nn.

    Note that:
    We now only support the case of self attention with
    limited input args, and we also assume that the input tensor
    has three dimensions. Although we do implement the logic
    for multihead attention, it was not fully tested.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        kdim: Optional[int] = None,
        vdim: Optional[int] = None,
        batch_first: bool = False,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        tp_size: int = 1,
        self_attention: bool = True,
    ) -> None:
        super().__init__()
        self.device: torch.device = (
            torch.device("cuda" if torch.cuda.is_available() else "cpu")
            if device is None
            else device
        )
        self.num_heads = num_heads
        self.hidden_size = embed_dim
        self.hidden_size_per_attention_head: int = self.hidden_size // num_heads
        self.scale: float = self.hidden_size_per_attention_head**-0.5
        if self_attention:
            self.qkv: torch.nn.Module = torch.nn.Linear(
                embed_dim, embed_dim * 3, bias=add_bias_kv, device=self.device
            )
            torch.nn.init.xavier_uniform_(self.qkv.weight)
            if add_bias_kv:
                torch.nn.init.zeros_(self.qkv.bias)
        else:
            self.query: torch.nn.Module = torch.nn.Linear(
                embed_dim, embed_dim, bias=add_bias_kv, device=self.device
            )
            self.key: torch.nn.Module = torch.nn.Linear(
                embed_dim, embed_dim, bias=add_bias_kv, device=self.device
            )
            self.value: torch.nn.Module = torch.nn.Linear(
                embed_dim, embed_dim, bias=add_bias_kv, device=self.device
            )
            torch.nn.init.xavier_uniform_(self.query.weight)
            torch.nn.init.xavier_uniform_(self.key.weight)
            torch.nn.init.xavier_uniform_(self.value.weight)
            if add_bias_kv:
                torch.nn.init.zeros_(self.query.bias)
                torch.nn.init.zeros_(self.key.bias)
                torch.nn.init.zeros_(self.value.bias)
        self.proj: torch.nn.Module = torch.nn.Linear(
            embed_dim, embed_dim, bias=bias, device=self.device
        )
        torch.nn.init.kaiming_uniform_(self.proj.weight, a=math.sqrt(5))
        if bias:
            torch.nn.init.zeros_(self.proj.bias)
        self.tp_size = tp_size
        self.hidden_size = embed_dim
        self.norm_factor: float = math.sqrt(self.hidden_size_per_attention_head)
        self.self_attention = self_attention

    def forward(
        self,
        query: Union[torch.Tensor, DT],
        key: Union[torch.Tensor, DT],
        value: Union[torch.Tensor, DT],
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
        average_attn_weights: bool = True,
    ) -> Union[torch.Tensor, DT]:
        b, sq, h = query.shape
        sk = key.size(1)
        nh = self.num_heads
        hn = self.hidden_size_per_attention_head

        # x: [b, sq/sk/sv, h]
        # ===================
        # Permute. [sq/sk/sv, b, h]
        # ===================
        if not self.self_attention:
            # =====================
            # Query, Key, and Value
            # =====================
            query = query.permute(1, 0, 2).contiguous()
            key = key.permute(1, 0, 2).contiguous()
            value = value.permute(1, 0, 2).contiguous()

            # Attention heads [sq/sk/sv, b, h] --> [sq/sk/sv * b, (nh * hn)]
            query = query.view(-1, h)
            key = key.view(-1, h)
            value = value.view(-1, h)

            query_layer = _view_with_sharding_dim_change(
                self.query(query), 1, (sq, b * nh, hn)
            )
            key_layer = _view_with_sharding_dim_change(
                self.key(key), 1, (sk, b * nh, hn)
            )
            value_layer = _view_with_sharding_dim_change(
                self.value(value), 1, (sk, b * nh, hn)
            )
        else:
            assert torch.equal(query, key) and torch.equal(
                query, value
            ), "inputs are different for self-attention."
            # =====================
            # Query
            # =====================
            query = query.permute(1, 0, 2).contiguous()

            # Attention heads [sq, b, h] --> [sq * b, (nh * 3 * hn)]
            query = query.view(-1, h)
            mixed_x_layer = self.qkv(query)

            # [sq * b, 3 * h] --> [sq, b, nh, 3 * hn]
            mixed_x_layer = _view_with_sharding_dim_change(
                mixed_x_layer, 2, (sq, b, nh, 3 * hn)
            )

            # [sq, b, nh, 3 * hn] --> 3 [sq, b, nh, hn]
            last_dim = mixed_x_layer.dim() - 1
            last_dim_size = mixed_x_layer.size(last_dim) // 3
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                last_dim_size, dim=last_dim
            )

            query_layer = _stride_same_as_shard(query_layer, self.tp_size, 2, 1)
            key_layer = _stride_same_as_shard(key_layer, self.tp_size, 2, 1)
            value_layer = _stride_same_as_shard(value_layer, self.tp_size, 2, 1)
            # [sq, b, nh, hn] -> [sq, b * nh, hn]
            query_layer = _view_with_sharding_dim_change(
                query_layer, 1, (sq, b * nh, -1)
            )
            key_layer = _view_with_sharding_dim_change(key_layer, 1, (sq, b * nh, -1))
            value_layer = _view_with_sharding_dim_change(
                value_layer, 1, (sq, b * nh, -1)
            )

        # ===================================
        # Raw attention scores. [b, nh, s, s]
        # ===================================

        factor = self.tp_size if isinstance(query_layer, DT) else 1
        # preallocating result tensor: [b * nh, sq, sk]
        matmul_result = torch.empty(
            b * nh // factor,
            sq,
            sk,
            dtype=query_layer.dtype,
            device=self.device,
        )
        if isinstance(query_layer, DT):
            matmul_result = DT.from_local(
                matmul_result,
                query_layer.device_mesh,
                [Shard(0)],
                run_check=False,
            )

        # Raw attention scores. [b * nh, sq, sk]
        attn = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),  # [b * nh, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * nh, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor),
        )

        # ===============
        # Attention probs
        # ===============
        attn = attn.softmax(dim=-1)

        # =========================
        # Context layer. [sq * b, hidden]
        # =========================

        # bmm: [b * nh, sq, hn]
        context_layer = torch.bmm(attn, value_layer.transpose(0, 1))

        # change view [nh, b, sq, hn]
        context_layer = context_layer.view(nh, b, sq, hn)

        # [nh, b, sq, hn] --> [sq, b, nh, hn]
        context_layer = context_layer.permute(2, 1, 0, 3).contiguous()

        # [sq, b, nh, hn] --> [sq * b, hidden]
        context_layer = _view_with_sharding_dim_change(
            context_layer.contiguous(), 1, (-1, self.hidden_size)
        )

        # =================
        # Projection. [sq, b, h]
        # =================
        output = self.proj(context_layer).view(sq, b, h)

        # ===================
        # Permute. [b, sq, h]
        # ===================
        output = output.permute(1, 0, 2)

        return output

    def copy(self, that: torch.nn.MultiheadAttention) -> None:
        # TODO: the current implementation assumes `self` is a self attention module
        assert (
            self.hidden_size == that.embed_dim
        ), "embed_dim must be equal in TensorParallelMultiheadAttention.copy()!"

        if that.in_proj_weight is not None:
            self.qkv.register_parameter("weight", that.in_proj_weight)
        if that.in_proj_bias is not None:
            self.qkv.register_parameter("bias", that.in_proj_bias)
        if that.out_proj.weight is not None:
            # TODO: The use of Parameter is to avoid a `mypy` issue caused
            # by the `tensor` type annotation on Linear.weight, to which
            # a Parameter object is actually assigned.
            self.proj.register_parameter(
                "weight", torch.nn.Parameter(that.out_proj.weight)
            )
        if that.out_proj.bias is not None:
            self.proj.register_parameter("bias", that.out_proj.bias)