Fix megatron_gpt2 attention block's causal mask (#12007)
* Fix megatron_gpt2 attention block's causal mask.
* Compatibility with checkpoints created with recent versions of Megatron-LM.
* Added integration test for the released Megatron-GPT2 model.
* Code style changes.
* Added option to the Megatron conversion script to read from a config file.

Co-authored-by: Guido Novati <gnovati@nvidia.com>
This commit is contained in:
parent: 783b0dd589
commit: ecd6efe7cb
2 changed files with 186 additions and 44 deletions
@@ -24,6 +24,8 @@ import zipfile

 import torch

+from transformers import GPT2Config
+

 ####################################################################################################
@@ -48,17 +50,45 @@ def recursive_print(name, val, spaces=0):
         print(msg, ":", val)


+def fix_query_key_value_ordering(param, checkpoint_version, num_splits, num_heads, hidden_size):
+    # Permutes layout of param tensor to [num_splits * num_heads * hidden_size, :]
+    # for compatibility with later versions of NVIDIA Megatron-LM.
+    # The inverse operation is performed inside Megatron-LM to read checkpoints:
+    # https://github.com/NVIDIA/Megatron-LM/blob/v2.4/megatron/checkpointing.py#L209
+    # If param is the weight tensor of the self-attention block, the returned tensor
+    # will have to be transposed one more time to be read by HuggingFace GPT2.
+    input_shape = param.size()
+    if checkpoint_version == 1.0:
+        # version 1.0 stores [num_heads * hidden_size * num_splits, :]
+        saved_shape = (num_heads, hidden_size, num_splits) + input_shape[1:]
+        param = param.view(*saved_shape)
+        param = param.transpose(0, 2)
+        param = param.transpose(1, 2).contiguous()
+    elif checkpoint_version >= 2.0:
+        # other versions store [num_heads * num_splits * hidden_size, :]
+        saved_shape = (num_heads, num_splits, hidden_size) + input_shape[1:]
+        param = param.view(*saved_shape)
+        param = param.transpose(0, 1).contiguous()
+    param = param.view(*input_shape)
+    return param
+
+
 ####################################################################################################


-def convert_megatron_checkpoint(args, input_state_dict):
+def convert_megatron_checkpoint(args, input_state_dict, config):
     # The converted output model.
     output_state_dict = {}

     # The number of heads.
-    heads = 16
+    heads = config.n_head
     # The hidden_size per head.
-    hidden_size_per_head = 64
+    hidden_size_per_head = config.n_embd // config.n_head
+    # Megatron-LM checkpoint version
+    if "checkpoint_version" in input_state_dict.keys():
+        checkpoint_version = input_state_dict["checkpoint_version"]
+    else:
+        checkpoint_version = 0.0

     # The model.
     model = input_state_dict["model"]
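Note on the new helper: recent Megatron-LM checkpoints store the fused query/key/value rows grouped per attention head, while transformers' GPT-2 `c_attn` expects all query rows first, then all key rows, then all value rows. Below is a minimal standalone sketch of that reordering for checkpoint_version >= 2.0 (toy sizes chosen for readability; this snippet is illustrative and not part of the patch):

    import torch

    num_heads, head_size, num_splits = 2, 3, 3  # the 3 "splits" are query, key, value
    rows = num_heads * num_splits * head_size

    # Megatron >= 2.0 layout: rows grouped as [head][q/k/v][head_size].
    megatron_rows = torch.arange(rows).view(num_heads, num_splits, head_size)

    # Reorder to the layout GPT-2 expects: [q/k/v][head][head_size].
    gpt2_rows = megatron_rows.transpose(0, 1).contiguous().view(rows)

    print(megatron_rows.reshape(rows).tolist())  # 0..17 in checkpoint order
    print(gpt2_rows.tolist())                    # all heads' query rows first, then key, then value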
@@ -69,22 +99,21 @@ def convert_megatron_checkpoint(args, input_state_dict):

     # The word embeddings.
     word_embeddings = embeddings["word_embeddings"]["weight"]
-    # Truncate the embedding table to 50257 rows.
-    word_embeddings = word_embeddings[:50257, :]
-    # Truncate the embedding table to 50257 rows.
+    # Truncate the embedding table to vocab_size rows.
+    word_embeddings = word_embeddings[: config.vocab_size, :]
     output_state_dict["transformer.wte.weight"] = word_embeddings

     # The position embeddings.
     pos_embeddings = embeddings["position_embeddings"]["weight"]
     # Read the hidden dimension.
-    hidden_size = pos_embeddings.size(0)
+    n_embed = pos_embeddings.size(0)
     # DEBUG.
-    assert hidden_size == heads * hidden_size_per_head
+    assert n_embed == heads * hidden_size_per_head
     # Store the position embeddings.
     output_state_dict["transformer.wpe.weight"] = pos_embeddings

     # The transformer.
-    transformer = lm["transformer"]
+    transformer = lm["transformer"] if "transformer" in lm.keys() else lm["encoder"]

     # The regex to extract layer names.
     layer_re = re.compile("layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
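Why the truncation: Megatron-LM typically pads the word-embedding table beyond the real vocabulary so it splits evenly across tensor-parallel ranks (e.g. 50257 rounded up to 50304). The converter keeps only the first config.vocab_size rows. A small sketch under those assumptions (the sizes here are illustrative, not read from a checkpoint):

    import torch

    padded_vocab, n_embd, vocab_size = 50304, 1024, 50257
    word_embeddings = torch.randn(padded_vocab, n_embd)  # stand-in for the checkpoint tensor

    word_embeddings = word_embeddings[:vocab_size, :]     # what the converter does
    assert word_embeddings.shape == (vocab_size, n_embd)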
@@ -92,6 +121,7 @@ def convert_megatron_checkpoint(args, input_state_dict):
     # The simple map of names for "automated" rules.
     megatron_to_transformers = {
         "attention.dense": ".attn.c_proj.",
+        "self_attention.dense": ".attn.c_proj.",
         "mlp.dense_h_to_4h": ".mlp.c_fc.",
         "mlp.dense_4h_to_h": ".mlp.c_proj.",
     }
@@ -122,26 +152,32 @@ def convert_megatron_checkpoint(args, input_state_dict):
             output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = val

         # Transpose the QKV matrix.
-        elif op_name == "attention.query_key_value" and weight_or_bias == "weight":
+        elif (
+            op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value"
+        ) and weight_or_bias == "weight":

             # Insert a tensor of 1x1xDxD bias.
-            zeros = torch.ones(1, 1, hidden_size, hidden_size)
-            output_state_dict[layer_name + ".attn.bias"] = zeros
+            causal_mask = torch.tril(torch.ones((n_embed, n_embed), dtype=torch.uint8)).view(1, 1, n_embed, n_embed)
+            output_state_dict[layer_name + ".attn.bias"] = causal_mask

             # Insert a "dummy" tensor for masked_bias.
             masked_bias = torch.tensor(-1e4)
             output_state_dict[layer_name + ".attn.masked_bias"] = masked_bias

+            out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head)
             # Megatron stores (3*D) x D but transformers-GPT2 expects D x 3*D.
-            out_val = val.transpose(0, 1)
+            out_val = out_val.transpose(0, 1).contiguous()
             # Store.
             output_state_dict[layer_name + ".attn.c_attn.weight"] = out_val

         # Transpose the bias.
-        elif op_name == "attention.query_key_value" and weight_or_bias == "bias":
+        elif (
+            op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value"
+        ) and weight_or_bias == "bias":

+            out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head)
             # Store. No change of shape.
-            output_state_dict[layer_name + ".attn.c_attn.bias"] = val
+            output_state_dict[layer_name + ".attn.c_attn.bias"] = out_val

         # Transpose the weights.
         elif weight_or_bias == "weight":
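Why this hunk matters: GPT-2 in transformers keeps a lower-triangular buffer (the "bias" key written above) and uses it to mask out future positions before the softmax. The old converter stored an all-ones tensor there, so nothing was masked and attention was not causal. A rough sketch of how such a buffer is consumed (a simplified paraphrase of the attention masking, not the library code verbatim):

    import torch

    seq_len = 5
    scores = torch.randn(1, 1, seq_len, seq_len)  # raw attention scores
    causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.uint8)).view(1, 1, seq_len, seq_len)
    masked_bias = torch.tensor(-1e4)

    # Future positions get a large negative score and ~0 probability after softmax.
    scores = torch.where(causal_mask.bool(), scores, masked_bias.to(scores.dtype))
    probs = scores.softmax(dim=-1)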
@@ -155,6 +191,9 @@ def convert_megatron_checkpoint(args, input_state_dict):
             out_name = megatron_to_transformers[op_name]
             output_state_dict[layer_name + out_name + "bias"] = val

+    # DEBUG.
+    assert config.n_layer == layer_idx + 1
+
     # The final layernorm.
     output_state_dict["transformer.ln_f.weight"] = transformer["final_layernorm.weight"]
     output_state_dict["transformer.ln_f.bias"] = transformer["final_layernorm.bias"]
@@ -162,33 +201,8 @@ def convert_megatron_checkpoint(args, input_state_dict):
     # For LM head, transformers' wants the matrix to weight embeddings.
     output_state_dict["lm_head.weight"] = word_embeddings

-    # The config.
-    output_config = {
-        "activation_function": "gelu_new",
-        "architectures": ["GPT2LMHeadModel"],
-        "attn_pdrop": 0.1,
-        "bos_token_id": 50256,
-        "embd_pdrop": 0.1,
-        "eos_token_id": 50256,
-        "initializer_range": 0.02,
-        "layer_norm_epsilon": 1e-05,
-        "model_type": "gpt2",
-        "n_ctx": 1024,
-        "n_embd": 1024,
-        "n_head": 16,
-        "n_layer": 24,
-        "n_positions": 1024,
-        "resid_pdrop": 0.1,
-        "summary_activation": None,
-        "summary_first_dropout": 0.1,
-        "summary_proj_to_labels": True,
-        "summary_type": "cls_index",
-        "summary_use_proj": True,
-        "vocab_size": 50257,
-    }
-
     # It should be done!
-    return output_state_dict, output_config
+    return output_state_dict


 ####################################################################################################
@@ -198,21 +212,62 @@ def main():
     # Create the argument parser.
     parser = argparse.ArgumentParser()
     parser.add_argument("--print-checkpoint-structure", action="store_true")
-    parser.add_argument("path_to_checkpoint", type=str, help="Path to the ZIP file containing the checkpoint")
+    parser.add_argument(
+        "path_to_checkpoint",
+        type=str,
+        help="Path to the ZIP file containing the checkpoint",
+    )
+    parser.add_argument(
+        "--config_file",
+        default="",
+        type=str,
+        help="An optional config json file describing the pre-trained model.",
+    )
     args = parser.parse_args()

     # Extract the basename.
     basename = os.path.dirname(args.path_to_checkpoint)

     # Load the model.
-    print('Extracting PyTorch state dictionary from "{}"'.format(args.path_to_checkpoint))
+    print(f"Extracting PyTorch state dictionary from {args.path_to_checkpoint}")
     with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint:
         with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
             input_state_dict = torch.load(pytorch_dict, map_location="cpu")

+    # Read the config, or default to the model released by NVIDIA.
+    if args.config_file == "":
+        # Spell out all parameters in case the defaults change.
+        config = GPT2Config(
+            vocab_size=50257,
+            n_positions=1024,
+            n_ctx=1024,
+            n_embd=1024,
+            n_layer=24,
+            n_head=16,
+            n_inner=4096,
+            activation_function="gelu_new",
+            resid_pdrop=0.1,
+            embd_pdrop=0.1,
+            attn_pdrop=0.1,
+            layer_norm_epsilon=1e-5,
+            initializer_range=0.02,
+            summary_type="cls_index",
+            summary_use_proj=True,
+            summary_activation=None,
+            summary_proj_to_labels=True,
+            summary_first_dropout=0.1,
+            scale_attn_weights=True,
+            gradient_checkpointing=False,
+            use_cache=True,
+            bos_token_id=50256,
+            eos_token_id=50256,
+        )
+    else:
+        config = GPT2Config.from_json_file(args.config_file)
+
     # Convert.
     print("Converting")
-    output_state_dict, output_config = convert_megatron_checkpoint(args, input_state_dict)
+    output_state_dict = convert_megatron_checkpoint(args, input_state_dict, config)

     # Print the structure of converted state dict.
     if args.print_checkpoint_structure:
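A note on the new --config_file option: it expects a GPT2Config JSON on disk. A small sketch of producing such a file for a hypothetical Megatron-GPT2 whose sizes differ from the released 345m model (the sizes and filename below are made up for illustration):

    from transformers import GPT2Config

    config = GPT2Config(n_embd=1536, n_head=16, n_layer=40, vocab_size=50257)
    config.to_json_file("megatron_gpt2_config.json")  # pass this path via --config_file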
@@ -220,6 +275,9 @@ def main():

     # Store the config to file.
     output_config_file = os.path.join(basename, "config.json")
+    output_config = config.to_dict()
+    output_config["architectures"] = ["GPT2LMHeadModel"]
+    output_config["model_type"] = "gpt2"
     print(f'Saving config to "{output_config_file}"')
     with open(output_config_file, "w") as f:
         json.dump(output_config, f)
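Once the script has written the converted weights and the config.json above into a directory, the result loads like any other GPT-2 checkpoint. A quick sketch, assuming the output directory is nvidia/megatron-gpt2-345m/ (the path used by the integration test below) and that the model reuses GPT-2's BPE vocabulary:

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    model = GPT2LMHeadModel.from_pretrained("nvidia/megatron-gpt2-345m/")
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    inputs = tokenizer("Megatron-GPT2 converts to", return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    print(logits.shape)  # (1, sequence_length, 50257)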
tests/test_modeling_megatron_gpt2.py (new file, 84 lines)
@@ -0,0 +1,84 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+
+from transformers import is_torch_available
+from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
+
+
+if is_torch_available():
+    import torch
+
+    from transformers import GPT2LMHeadModel
+
+
+@require_torch
+@require_sentencepiece
+@require_tokenizers
+class MegatronGPT2IntegrationTest(unittest.TestCase):
+    @slow
+    def test_inference_no_head(self):
+        directory = "nvidia/megatron-gpt2-345m/"
+        if "MYDIR" in os.environ:
+            directory = os.path.join(os.environ["MYDIR"], directory)
+        model = GPT2LMHeadModel.from_pretrained(directory)
+        model.to(torch_device)
+        model.half()
+
+        input_ids = torch.tensor(
+            [[101, 7110, 1005, 1056, 2023, 11333, 17413, 1029, 102]],
+            device=torch_device,
+            dtype=torch.long,
+        )
+
+        with torch.no_grad():
+            output = model(input_ids).logits
+
+        expected_shape = torch.Size((1, 9, 50257))
+        self.assertEqual(output.shape, expected_shape)
+
+        expected_diag = torch.tensor(
+            [
+                4.9414,
+                -0.2920,
+                -1.2148,
+                -4.0273,
+                -0.5161,
+                -5.2109,
+                -1.2412,
+                -1.8301,
+                -1.7734,
+                -4.7148,
+                -0.2317,
+                -1.0811,
+                -2.1777,
+                0.4141,
+                -3.7969,
+                -4.0586,
+                -2.5332,
+                -3.3809,
+                4.3867,
+            ],
+            device=torch_device,
+            dtype=torch.half,
+        )
+
+        for i in range(19):
+            r, c = 8 * i // 17, 2792 * i  # along the diagonal
+            computed, expected = output[0, r, c], expected_diag[i]
+            msg = f"row={r} col={c} computed={computed} expected={expected}"
+            self.assertAlmostEqual(computed, expected, delta=1e-4, msg=msg)