Clean up unused MHA code to avoid confusion (#105956)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105956
Approved by: https://github.com/wz337, https://github.com/ezyang, https://github.com/wanchaol
This commit is contained in:
fduwjj 2023-07-26 17:39:50 -07:00 committed by PyTorch MergeBot
parent 1a59be2c9e
commit 487ebcac3b
5 changed files with 5 additions and 674 deletions

View file

@@ -53,13 +53,11 @@ used for input/output preparation:
.. autofunction:: make_output_shard_1d
.. autofunction:: make_output_tensor
Currently, there are some constraints which make it hard for the `nn.MultiheadAttention`
module to work out of the box for Tensor Parallelism, so we built this multihead_attention
module for Tensor Parallelism users. Also, in ``parallelize_module``, we automatically
swap ``nn.MultiheadAttention`` to this custom module when ``PairwiseParallel`` is specified.
Currently, there are some constraints which make it hard for the ``MultiheadAttention``
module to work out of the box for Tensor Parallelism, so we recommend users try ``ColwiseParallel``
and ``RowwiseParallel`` for each parameter, as sketched below. Some code changes may be needed,
since we are parallelizing on the head dim of the ``MultiheadAttention`` module.
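A minimal sketch of that recommendation (the names here are hypothetical: it assumes an
attention implementation whose projections are plain ``nn.Linear`` sub-modules named
``qkv`` and ``proj``, plus a 1-D ``device_mesh``)::

    parallelize_module(
        model,
        device_mesh,
        {
            "attn.qkv": ColwiseParallel(),   # shard the fused q/k/v projection column-wise
            "attn.proj": RowwiseParallel(),  # shard the output projection row-wise
        },
    )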
.. autoclass:: torch.distributed.tensor.parallel.multihead_attention_tp.TensorParallelMultiheadAttention
    :members:
We also enabled 2D parallelism to integrate with ``FullyShardedDataParallel``.
Users just need to call the following API explicitly:

View file

@@ -3,7 +3,6 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
@@ -13,7 +12,6 @@ from torch.distributed.tensor.parallel import (
    PairwiseParallel,
    parallelize_module,
    SequenceParallel,
    TensorParallelMultiheadAttention,
)
from torch.distributed.tensor.parallel.input_reshard import input_reshard
from torch.testing._internal.common_utils import (
@@ -25,22 +23,10 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    NUM_DEVICES,
    skip_unless_torch_gpu,
    with_comms,
)
class MultiheadAttnWrap(nn.Module):
    def __init__(self, embed_dim, num_heads, add_bias_kv=False, device=None):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads, add_bias_kv=add_bias_kv, device=device
        )

    def forward(self, query, key, value):
        return self.attn(query, key, value)


class DistTensorParallelExampleTest(DTensorTestBase):
    def _check_module(self, m1, m2, check_grad=False, rank0_only_params=None):
        rank0_only_params = [] if rank0_only_params is None else rank0_only_params
@@ -127,289 +113,6 @@ class DistTensorParallelExampleTest(DTensorTestBase):
    def test_mlp_megatron_e2e(self, is_seq_parallel, recompute_activation):
        self._test_mlp_magatron_e2e(is_seq_parallel=is_seq_parallel, recompute_activation=recompute_activation)
    # TensorParallelMultiheadAttention == dist_module(TensorParallelMultiheadAttention)
    # baddbmm introduces nan occasionally on CPU: https://github.com/pytorch/pytorch/issues/80588
    @with_comms
    @skip_unless_torch_gpu
    def test_self_attn_megatron_e2e(self):
        inp_size = [8, 12, 16]
        # Ensure all tp ranks have the same input.
        torch.manual_seed(0)
        inp = torch.rand(*inp_size, device=self.device_type)

        # Initialize both models using the same seed.
        torch.manual_seed(5)
        model = TensorParallelMultiheadAttention(
            16,
            8,
            tp_size=NUM_DEVICES,
            add_bias_kv=True,
            device=self.device_type,
        )
        torch.manual_seed(5)
        model_tp = TensorParallelMultiheadAttention(
            16,
            8,
            tp_size=NUM_DEVICES,
            add_bias_kv=True,
            device=self.device_type,
        )

        # Ensure both models are initialized the same way.
        self.assertEqual(model.qkv.weight, model_tp.qkv.weight)
        self.assertEqual(model.qkv.bias, model_tp.qkv.bias)
        self.assertEqual(model.proj.weight, model_tp.proj.weight)
        self.assertEqual(model.proj.bias, model_tp.proj.bias)

        # Shard the module and initialize the optimizer.
        device_mesh = DeviceMesh(self.device_type, list(range(NUM_DEVICES)))
        parallelize_module(model_tp, device_mesh, PairwiseParallel())
        device_mesh = model_tp.qkv.weight.device_mesh
        replicate = [Replicate()] * device_mesh.ndim

        # Ensure both models are still initialized the same way after sharding.
        self.assertEqual(
            model.qkv.weight,
            model_tp.qkv.weight.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.qkv.bias,
            model_tp.qkv.bias.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.weight,
            model_tp.proj.weight.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.bias,
            model_tp.proj.bias.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )

        LR = 0.25
        optim = torch.optim.SGD(model.parameters(), lr=LR)
        optim_tp = torch.optim.SGD(model_tp.parameters(), lr=LR)

        output = model(inp, inp, inp)
        output_tp = model_tp(inp, inp, inp)
        self.assertEqual(output, output_tp)

        output.sum().backward()
        output_tp.sum().backward()
        device_mesh = model_tp.qkv.weight.device_mesh

        # Ensure the gradients are the same.
        self.assertEqual(
            model.qkv.weight.grad,
            model_tp.qkv.weight.grad.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.qkv.bias.grad,
            model_tp.qkv.bias.grad.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.weight.grad,
            model_tp.proj.weight.grad.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.bias.grad,
            model_tp.proj.bias.grad.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )

        optim.step()
        optim_tp.step()

        # Ensure the model weights are still the same after the update.
        self.assertEqual(
            model.qkv.weight,
            model_tp.qkv.weight.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.qkv.bias,
            model_tp.qkv.bias.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.weight,
            model_tp.proj.weight.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.bias,
            model_tp.proj.bias.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )

        inp = torch.rand(*inp_size, device=self.device_type)
        output = model(inp, inp, inp)
        output_tp = model_tp(inp, inp, inp)
        self.assertEqual(output, output_tp)
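    # The repeated redistribute/to_local comparisons above could be factored
    # into a small helper; a minimal sketch (hypothetical, not part of this diff),
    # assuming a DTensor parameter on a 1-D mesh:
    #
    #   def _full_local(param):
    #       mesh = param.device_mesh
    #       return param.redistribute(
    #           device_mesh=mesh, placements=[Replicate()] * mesh.ndim
    #       ).to_local()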
    # TensorParallelMultiheadAttention == dist_module(torch.nn.MultiheadAttention)
    # baddbmm introduces nan occasionally on CPU: https://github.com/pytorch/pytorch/issues/80588
    @with_comms
    @skip_unless_torch_gpu
    def test_self_attn_replacement_megatron_e2e(self):
        inp_size = [8, 12, 16]
        # Ensure all tp ranks have the same input.
        torch.manual_seed(0)
        inp = torch.rand(*inp_size, device=self.device_type)

        # TODO: our sharding function cannot shard the root node.
        torch.manual_seed(5)
        model = TensorParallelMultiheadAttention(
            16,
            8,
            tp_size=NUM_DEVICES,
            add_bias_kv=True,
            device=self.device_type,
        )
        model_tp = MultiheadAttnWrap(16, 8, add_bias_kv=True, device=self.device_type)

        # TODO: somehow using torch.nn.MultiheadAttention's initial params does not work;
        # use the TensorParallelMultiheadAttention parameters instead.
        x = model.qkv.weight.clone().detach().requires_grad_()
        model_tp.attn.register_parameter("in_proj_weight", torch.nn.Parameter(x))
        x = model.qkv.bias.clone().detach().requires_grad_()
        model_tp.attn.register_parameter("in_proj_bias", torch.nn.Parameter(x))
        x = model.proj.weight.clone().detach().requires_grad_()
        model_tp.attn.out_proj.register_parameter("weight", torch.nn.Parameter(x))
        x = model.proj.bias.clone().detach().requires_grad_()
        model_tp.attn.out_proj.register_parameter("bias", torch.nn.Parameter(x))

        # Check that the parameters are the same.
        self.assertEqual(model.qkv.weight, model_tp.attn.in_proj_weight)
        self.assertEqual(model.qkv.bias, model_tp.attn.in_proj_bias)
        self.assertEqual(model.proj.weight, model_tp.attn.out_proj.weight)
        self.assertEqual(model.proj.bias, model_tp.attn.out_proj.bias)

        # Shard the module and initialize the optimizer.
        device_mesh = DeviceMesh(self.device_type, list(range(NUM_DEVICES)))
        parallelize_module(model_tp, device_mesh, PairwiseParallel())
        device_mesh = model_tp.attn.qkv.weight.device_mesh
        replicate = [Replicate()] * device_mesh.ndim

        # Ensure both models are initialized the same way.
        self.assertEqual(
            model.qkv.weight,
            model_tp.attn.qkv.weight.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.qkv.bias,
            model_tp.attn.qkv.bias.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.weight,
            model_tp.attn.proj.weight.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.bias,
            model_tp.attn.proj.bias.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )

        LR = 0.25
        optim = torch.optim.SGD(model.parameters(), lr=LR)
        optim_tp = torch.optim.SGD(model_tp.parameters(), lr=LR)

        output = model(inp, inp, inp)
        output_tp = model_tp(inp, inp, inp)
        self.assertEqual(output, output_tp)

        output.sum().backward()
        output_tp.sum().backward()
        device_mesh = model_tp.attn.qkv.weight.device_mesh

        # Ensure the gradients are the same.
        self.assertEqual(
            model.qkv.weight.grad,
            model_tp.attn.qkv.weight.grad.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.qkv.bias.grad,
            model_tp.attn.qkv.bias.grad.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.weight.grad,
            model_tp.attn.proj.weight.grad.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.bias.grad,
            model_tp.attn.proj.bias.grad.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )

        optim.step()
        optim_tp.step()

        # Ensure the model weights are still the same after the update.
        self.assertEqual(
            model.qkv.weight,
            model_tp.attn.qkv.weight.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.qkv.bias,
            model_tp.attn.qkv.bias.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.weight,
            model_tp.attn.proj.weight.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )
        self.assertEqual(
            model.proj.bias,
            model_tp.attn.proj.bias.redistribute(
                device_mesh=device_mesh, placements=replicate
            ).to_local(),
        )

        inp = torch.rand(*inp_size, device=self.device_type)
        output = model(inp, inp, inp)
        output_tp = model_tp(inp, inp, inp)
        self.assertEqual(output, output_tp)
instantiate_parametrized_tests(DistTensorParallelExampleTest)

View file

@@ -1,8 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from torch.distributed.tensor.parallel.api import parallelize_module
from torch.distributed.tensor.parallel.multihead_attention_tp import (
    TensorParallelMultiheadAttention,
)
from torch.distributed.tensor.parallel.style import (
    ColwiseParallel,
@@ -27,7 +24,6 @@ __all__ = [
    "ParallelStyle",
    "RowwiseParallel",
    "SequenceParallel",
    "TensorParallelMultiheadAttention",
    "make_input_replicate_1d",
    "make_input_reshard_replicate",
    "make_input_shard_1d",

View file

@@ -18,9 +18,6 @@ from torch.distributed._tensor.random import (
)
from torch.distributed._tensor.sharding_prop import _CachingPropagator
from torch.distributed.tensor.parallel._utils import _create_1d_device_mesh
from torch.distributed.tensor.parallel.multihead_attention_tp import (
    TensorParallelMultiheadAttention,
)
from torch.distributed.tensor.parallel.style import (
    ColwiseParallel,
    PairwiseParallel,
PairwiseParallel,
@@ -108,9 +105,7 @@ def parallelize_module(  # type: ignore[return]
    if isinstance(parallelize_plan, (ColwiseParallel, RowwiseParallel)):
        return _parallelize_linear(module, device_mesh, parallelize_plan)
    # PairwiseParallel
    if _is_mha_for_pairwise_parallel(module):
        return _parallelize_multihead_attn(module, device_mesh)
    elif _is_mlp_for_pairwise_parallel(module):
    if _is_mlp_for_pairwise_parallel(module):
        return _parallelize_mlp(module, device_mesh, parallelize_plan)
    else:
        for n, m in module.named_children():
@@ -140,20 +135,6 @@ def parallelize_module(  # type: ignore[return]
        )


def _is_mha_for_pairwise_parallel(module: nn.Module) -> bool:
    """
    Check whether the MHA module is one that can be handled by Pairwise parallel.

    Args:
        module (:class:`nn.Module`):
            Module to be checked.

    Return:
        A boolean indicating whether the module is an MHA module supported by
        Pairwise parallel.
    """
    return isinstance(module, (TensorParallelMultiheadAttention, nn.MultiheadAttention))


def _is_mlp_for_pairwise_parallel(module: nn.Module) -> bool:
    """
    Traverse through all the immediate children of the given module and count the
@@ -304,80 +285,6 @@ def _parallelize_linear(
    return module


def _parallelize_multihead_attn(
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallel_style: ParallelStyle = PairwiseParallel(),
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    This function assumes the input module consists of nn.Linear sub-modules
    and parallelizes it based on the given parallel style. We don't change
    the FQN of any sub-module; each parameter is replaced in place.

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.
        parallel_style (:class:`ParallelStyle`):
            Object which describes how we prepare input/output
            for Tensor Parallelism.
        tp_mesh_dim (int):
            The dimension of ``device_mesh`` on which we perform
            Tensor Parallelism.

    Return:
        A parallelized :class:`nn.Module` object.

    .. warning::
        We only support ``PairwiseParallel`` right now.
    """
    if not isinstance(parallel_style, PairwiseParallel):
        raise NotImplementedError(
            "Only PairwiseParallel is supported for Multihead Attention parallelization."
        )

    if device_mesh.ndim > 1:
        device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)

    if isinstance(module, nn.MultiheadAttention):
        tp_multi_head_attention = TensorParallelMultiheadAttention(
            module.embed_dim,
            module.num_heads,
            device=torch.device(device_mesh.device_type),
            tp_size=device_mesh.size(tp_mesh_dim),
            add_bias_kv=module.bias_k is not None,
        )
        tp_multi_head_attention.copy(module)
        module = tp_multi_head_attention

    assert isinstance(module, TensorParallelMultiheadAttention), (
        f"Expects TensorParallelMultiheadAttention but got {type(module)}"
    )

    # Shard the TensorParallelMultiheadAttention module.
    for n, m in module.named_children():
        if n == "qkv":
            # Col-wise parallelize the qkv layer.
            distribute_module(
                m,
                device_mesh,
                _colwise_parallelize_linear_fn,
                input_fn=parallel_style._prepare_input,  # type: ignore[arg-type, misc] # pyre-ignore[6]
            )
        elif n == "proj":
            # Row-wise parallelize the proj layer.
            distribute_module(
                m,
                device_mesh,
                _rowwise_parallelize_linear_fn,
                output_fn=parallel_style._prepare_output,  # type: ignore[arg-type, misc] # pyre-ignore[6]
            )
    return module
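# A minimal usage sketch of the removed helper (hypothetical mesh and world
# size; not part of the original file):
#
#   mesh = DeviceMesh("cuda", list(range(world_size)))
#   mha = torch.nn.MultiheadAttention(16, 8, device="cuda")
#   mha_tp = _parallelize_multihead_attn(mha, mesh)
#   # mha is swapped for a TensorParallelMultiheadAttention whose qkv layer
#   # is column-wise sharded and whose proj layer is row-wise sharded.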
def _parallelize_mlp(
    module: nn.Module,
    device_mesh: DeviceMesh,

View file

@@ -1,273 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# pyre-ignore-all-errors[6]
import math
from typing import Optional, Union

import torch
from torch.distributed._tensor import DTensor as DT
from torch.distributed._tensor.placement_types import Shard
from torch.distributed.tensor.parallel._view_with_dim_change import (
    _view_with_sharding_dim_change,
)

__all__ = ["TensorParallelMultiheadAttention"]
# TODO: Add a test to check equivalence between our multihead attention
# and other mainstream implementations (Megatron-LM or PyTorch).


def _stride_same_as_shard(
    tensor: torch.Tensor, tp_size: int, chunk_dim: int, cat_dim: int
) -> torch.Tensor:
    """
    Adjust a local tensor's stride to match the sharded case,
    so that view results stay the same.
    """
    if isinstance(tensor, DT):
        return tensor
    view_size = list(tensor.size())
    view_size[chunk_dim] //= tp_size
    return torch.cat(
        [t.view(*view_size) for t in tensor.chunk(tp_size, dim=chunk_dim)],
        dim=cat_dim,
    ).contiguous()
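# A tiny worked example of the reshuffle above (hypothetical values, not part
# of the original file): with tp_size=2, chunk_dim=2, cat_dim=1, a (2, 3, 4)
# tensor is split into two (2, 3, 2) chunks that are concatenated along dim 1,
# so the result has shape (2, 6, 2) and matches the local layout of a
# dim-2-sharded tensor:
#
#   t = torch.arange(24.0).reshape(2, 3, 4)
#   out = _stride_same_as_shard(t, tp_size=2, chunk_dim=2, cat_dim=1)
#   assert out.shape == (2, 6, 2)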
class TensorParallelMultiheadAttention(torch.nn.Module):
    """
    Multi-head attention block from Transformer models.

    Since we need some customizations for the attention layer,
    we provide a customized but mathematically equivalent
    attention module to the one defined in torch.nn.

    Note that:
    We currently only support self-attention with limited input
    args, and we assume that the input tensor has three dimensions.
    Although we do implement the logic for general multihead
    attention, it has not been fully tested.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        kdim: Optional[int] = None,
        vdim: Optional[int] = None,
        batch_first: bool = False,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        tp_size: int = 1,
        self_attention: bool = True,
    ) -> None:
        super().__init__()
        self.device: torch.device = (
            torch.device("cuda" if torch.cuda.is_available() else "cpu")
            if device is None
            else device
        )
        self.num_heads = num_heads
        self.hidden_size = embed_dim
        self.hidden_size_per_attention_head: int = self.hidden_size // num_heads
        self.scale: float = self.hidden_size_per_attention_head**-0.5

        if self_attention:
            self.qkv: torch.nn.Module = torch.nn.Linear(
                embed_dim, embed_dim * 3, bias=add_bias_kv, device=self.device
            )
            torch.nn.init.xavier_uniform_(self.qkv.weight)
            if add_bias_kv:
                torch.nn.init.zeros_(self.qkv.bias)
        else:
            self.query: torch.nn.Module = torch.nn.Linear(
                embed_dim, embed_dim, bias=add_bias_kv, device=self.device
            )
            self.key: torch.nn.Module = torch.nn.Linear(
                embed_dim, embed_dim, bias=add_bias_kv, device=self.device
            )
            self.value: torch.nn.Module = torch.nn.Linear(
                embed_dim, embed_dim, bias=add_bias_kv, device=self.device
            )
            torch.nn.init.xavier_uniform_(self.query.weight)
            torch.nn.init.xavier_uniform_(self.key.weight)
            torch.nn.init.xavier_uniform_(self.value.weight)
            if add_bias_kv:
                torch.nn.init.zeros_(self.query.bias)
                torch.nn.init.zeros_(self.key.bias)
                torch.nn.init.zeros_(self.value.bias)

        self.proj: torch.nn.Module = torch.nn.Linear(
            embed_dim, embed_dim, bias=bias, device=self.device
        )
        torch.nn.init.kaiming_uniform_(self.proj.weight, a=math.sqrt(5))
        if bias:
            torch.nn.init.zeros_(self.proj.bias)

        self.tp_size = tp_size
        self.hidden_size = embed_dim
        self.norm_factor: float = math.sqrt(self.hidden_size_per_attention_head)
        self.self_attention = self_attention
    def forward(
        self,
        query: Union[torch.Tensor, DT],
        key: Union[torch.Tensor, DT],
        value: Union[torch.Tensor, DT],
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
        average_attn_weights: bool = True,
    ) -> Union[torch.Tensor, DT]:
        b, sq, h = query.shape
        sk = key.size(1)
        nh = self.num_heads
        hn = self.hidden_size_per_attention_head

        # x: [b, sq/sk/sv, h]

        # ===================
        # Permute. [sq/sk/sv, b, h]
        # ===================
        if not self.self_attention:
            # =====================
            # Query, Key, and Value
            # =====================
            query = query.permute(1, 0, 2).contiguous()
            key = key.permute(1, 0, 2).contiguous()
            value = value.permute(1, 0, 2).contiguous()

            # Attention heads [sq/sk/sv, b, h] --> [sq/sk/sv * b, (nh * hn)]
            query = query.view(-1, h)
            key = key.view(-1, h)
            value = value.view(-1, h)

            query_layer = _view_with_sharding_dim_change(
                self.query(query), 1, (sq, b * nh, hn)
            )
            key_layer = _view_with_sharding_dim_change(
                self.key(key), 1, (sk, b * nh, hn)
            )
            value_layer = _view_with_sharding_dim_change(
                self.value(value), 1, (sk, b * nh, hn)
            )
        else:
            assert torch.equal(query, key) and torch.equal(
                query, value
            ), "inputs are different for self-attention."

            # =====================
            # Query
            # =====================
            query = query.permute(1, 0, 2).contiguous()

            # Attention heads [sq, b, h] --> [sq * b, (nh * 3 * hn)]
            query = query.view(-1, h)
            mixed_x_layer = self.qkv(query)

            # [sq * b, 3 * h] --> [sq, b, nh, 3 * hn]
            mixed_x_layer = _view_with_sharding_dim_change(
                mixed_x_layer, 2, (sq, b, nh, 3 * hn)
            )

            # [sq, b, nh, 3 * hn] --> 3 [sq, b, nh, hn]
            last_dim = mixed_x_layer.dim() - 1
            last_dim_size = mixed_x_layer.size(last_dim) // 3
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                last_dim_size, dim=last_dim
            )

            query_layer = _stride_same_as_shard(query_layer, self.tp_size, 2, 1)
            key_layer = _stride_same_as_shard(key_layer, self.tp_size, 2, 1)
            value_layer = _stride_same_as_shard(value_layer, self.tp_size, 2, 1)

            # [sq, b, nh, hn] -> [sq, b * nh, hn]
            query_layer = _view_with_sharding_dim_change(
                query_layer, 1, (sq, b * nh, -1)
            )
            key_layer = _view_with_sharding_dim_change(key_layer, 1, (sq, b * nh, -1))
            value_layer = _view_with_sharding_dim_change(
                value_layer, 1, (sq, b * nh, -1)
            )

        # ===================================
        # Raw attention scores. [b, nh, s, s]
        # ===================================
        factor = self.tp_size if isinstance(query_layer, DT) else 1
        # preallocating result tensor: [b * nh, sq, sk]
        matmul_result = torch.empty(
            b * nh // factor,
            sq,
            sk,
            dtype=query_layer.dtype,
            device=self.device,
        )
        if isinstance(query_layer, DT):
            matmul_result = DT.from_local(
                matmul_result,
                query_layer.device_mesh,
                [Shard(0)],
                run_check=False,
            )

        # Raw attention scores. [b * nh, sq, sk]
        attn = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),  # [b * nh, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * nh, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor),
        )
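        # Shape note (added for clarity; not in the original file): the
        # baddbmm above computes attn = (Q @ K^T) / norm_factor with
        # Q: [b * nh, sq, hn] and K^T: [b * nh, hn, sk], i.e. the standard
        # scaled dot-product scores; with beta=0.0 the preallocated
        # matmul_result only supplies the output's shape, dtype, and sharding.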
        # ===============
        # Attention probs
        # ===============
        attn = attn.softmax(dim=-1)

        # =========================
        # Context layer. [sq * b, hidden]
        # =========================
        # bmm: [b * nh, sq, hn]
        context_layer = torch.bmm(attn, value_layer.transpose(0, 1))
        # change view [nh, b, sq, hn]
        context_layer = context_layer.view(nh, b, sq, hn)

        # [nh, b, sq, hn] --> [sq, b, nh, hn]
        context_layer = context_layer.permute(2, 1, 0, 3).contiguous()

        # [sq, b, nh, hn] --> [sq * b, hidden]
        context_layer = _view_with_sharding_dim_change(
            context_layer.contiguous(), 1, (-1, self.hidden_size)
        )

        # =================
        # Projection. [sq, b, h]
        # =================
        output = self.proj(context_layer).view(sq, b, h)

        # ===================
        # Permute. [b, sq, h]
        # ===================
        output = output.permute(1, 0, 2)

        return output
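    # A minimal self-attention usage sketch (hypothetical values, mirroring
    # the shapes used in the tests above; with tp_size=1 everything stays
    # local):
    #
    #   mha = TensorParallelMultiheadAttention(16, 8, add_bias_kv=True)
    #   x = torch.rand(8, 12, 16, device=mha.device)  # [b, sq, h]
    #   out = mha(x, x, x)  # query, key, and value must be equal
    #   assert out.shape == (8, 12, 16)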
    def copy(self, that: torch.nn.MultiheadAttention) -> None:
        # TODO: the current implementation assumes `self` is a self-attention module.
        assert (
            self.hidden_size == that.embed_dim
        ), "embed_dim must be equal in TensorParallelMultiheadAttention.copy()!"
        if that.in_proj_weight is not None:
            self.qkv.register_parameter("weight", that.in_proj_weight)
        if that.in_proj_bias is not None:
            self.qkv.register_parameter("bias", that.in_proj_bias)
        if that.out_proj.weight is not None:
            # TODO: The use of Parameter here avoids a `mypy` issue caused
            # by the `Tensor` type annotation on Linear.weight, to which
            # a Parameter object is actually assigned.
            self.proj.register_parameter(
                "weight", torch.nn.Parameter(that.out_proj.weight)
            )
        if that.out_proj.bias is not None:
            self.proj.register_parameter("bias", that.out_proj.bias)
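    # A minimal sketch of the copy() contract above (hypothetical setup, not
    # part of the original file):
    #
    #   src = torch.nn.MultiheadAttention(16, 8)
    #   dst = TensorParallelMultiheadAttention(16, 8, add_bias_kv=True)
    #   dst.copy(src)  # dst.qkv / dst.proj now reference src's projections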