diff --git a/src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py b/src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py
index 2d2d54b81..cc8899566 100644
--- a/src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py
+++ b/src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py
@@ -24,6 +24,8 @@ import zipfile
 
 import torch
 
+from transformers import GPT2Config
+
 
 ####################################################################################################
 
@@ -48,17 +50,45 @@ def recursive_print(name, val, spaces=0):
         print(msg, ":", val)
 
 
+def fix_query_key_value_ordering(param, checkpoint_version, num_splits, num_heads, hidden_size):
+    # Permutes layout of param tensor to [num_splits * num_heads * hidden_size, :]
+    # for compatibility with later versions of NVIDIA Megatron-LM.
+    # The inverse operation is performed inside Megatron-LM to read checkpoints:
+    # https://github.com/NVIDIA/Megatron-LM/blob/v2.4/megatron/checkpointing.py#L209
+    # If param is the weight tensor of the self-attention block, the returned tensor
+    # will have to be transposed one more time to be read by HuggingFace GPT2.
+    input_shape = param.size()
+    if checkpoint_version == 1.0:
+        # version 1.0 stores [num_heads * hidden_size * num_splits, :]
+        saved_shape = (num_heads, hidden_size, num_splits) + input_shape[1:]
+        param = param.view(*saved_shape)
+        param = param.transpose(0, 2)
+        param = param.transpose(1, 2).contiguous()
+    elif checkpoint_version >= 2.0:
+        # other versions store [num_heads * num_splits * hidden_size, :]
+        saved_shape = (num_heads, num_splits, hidden_size) + input_shape[1:]
+        param = param.view(*saved_shape)
+        param = param.transpose(0, 1).contiguous()
+    param = param.view(*input_shape)
+    return param
+
+
 ####################################################################################################
 
 
-def convert_megatron_checkpoint(args, input_state_dict):
+def convert_megatron_checkpoint(args, input_state_dict, config):
     # The converted output model.
     output_state_dict = {}
 
     # The number of heads.
-    heads = 16
+    heads = config.n_head
     # The hidden_size per head.
-    hidden_size_per_head = 64
+    hidden_size_per_head = config.n_embd // config.n_head
+    # Megatron-LM checkpoint version
+    if "checkpoint_version" in input_state_dict.keys():
+        checkpoint_version = input_state_dict["checkpoint_version"]
+    else:
+        checkpoint_version = 0.0
 
     # The model.
     model = input_state_dict["model"]
@@ -69,22 +99,21 @@ def convert_megatron_checkpoint(args, input_state_dict):
 
     # The word embeddings.
     word_embeddings = embeddings["word_embeddings"]["weight"]
-    # Truncate the embedding table to 50257 rows.
-    word_embeddings = word_embeddings[:50257, :]
-    # Truncate the embedding table to 50257 rows.
+    # Truncate the embedding table to vocab_size rows.
+    word_embeddings = word_embeddings[: config.vocab_size, :]
     output_state_dict["transformer.wte.weight"] = word_embeddings
 
     # The position embeddings.
     pos_embeddings = embeddings["position_embeddings"]["weight"]
     # Read the hidden dimension.
-    hidden_size = pos_embeddings.size(0)
+    n_embed = pos_embeddings.size(0)
     # DEBUG.
-    assert hidden_size == heads * hidden_size_per_head
+    assert n_embed == heads * hidden_size_per_head
     # Store the position embeddings.
     output_state_dict["transformer.wpe.weight"] = pos_embeddings
 
     # The transformer.
-    transformer = lm["transformer"]
+    transformer = lm["transformer"] if "transformer" in lm.keys() else lm["encoder"]
 
     # The regex to extract layer names.
     layer_re = re.compile("layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
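The new fix_query_key_value_ordering helper above is the subtle part of this change. A minimal standalone sketch (not part of the patch; the toy dimensions and variable names are illustrative only) of what the checkpoint_version >= 2.0 branch does to a fused QKV weight, followed by the extra transpose that the conversion loop later applies for GPT2's Conv1D layout:

import torch

# Toy sizes: 2 heads of dimension 3, so the hidden size D = 6 and the fused QKV weight is (3*D) x D.
num_heads, head_dim, num_splits = 2, 3, 3
hidden = num_heads * head_dim
param = torch.arange(num_splits * hidden * hidden, dtype=torch.float32).view(num_splits * hidden, hidden)

# checkpoint_version >= 2.0 stores the rows grouped as [heads, (q, k, v), head_dim];
# the helper regroups them as [(q, k, v), heads, head_dim] by swapping the first two axes.
saved_shape = (num_heads, num_splits, head_dim) + param.size()[1:]
reordered = param.view(*saved_shape).transpose(0, 1).contiguous().view(param.size())

# GPT2's c_attn is a Conv1D that stores its weight as [in_features, out_features], hence the
# additional transpose applied to the weight (but not to the bias) in the conversion loop.
c_attn_weight = reordered.transpose(0, 1).contiguous()
print(c_attn_weight.shape)  # torch.Size([6, 18]), i.e. D x 3*D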
@@ -92,6 +121,7 @@ def convert_megatron_checkpoint(args, input_state_dict):
     # The simple map of names for "automated" rules.
     megatron_to_transformers = {
         "attention.dense": ".attn.c_proj.",
+        "self_attention.dense": ".attn.c_proj.",
         "mlp.dense_h_to_4h": ".mlp.c_fc.",
         "mlp.dense_4h_to_h": ".mlp.c_proj.",
     }
@@ -122,26 +152,32 @@ def convert_megatron_checkpoint(args, input_state_dict):
             output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = val
 
         # Transpose the QKV matrix.
-        elif op_name == "attention.query_key_value" and weight_or_bias == "weight":
+        elif (
+            op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value"
+        ) and weight_or_bias == "weight":
 
             # Insert a tensor of 1x1xDxD bias.
-            zeros = torch.ones(1, 1, hidden_size, hidden_size)
-            output_state_dict[layer_name + ".attn.bias"] = zeros
+            causal_mask = torch.tril(torch.ones((n_embed, n_embed), dtype=torch.uint8)).view(1, 1, n_embed, n_embed)
+            output_state_dict[layer_name + ".attn.bias"] = causal_mask
 
             # Insert a "dummy" tensor for masked_bias.
             masked_bias = torch.tensor(-1e4)
             output_state_dict[layer_name + ".attn.masked_bias"] = masked_bias
 
+            out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head)
             # Megatron stores (3*D) x D but transformers-GPT2 expects D x 3*D.
-            out_val = val.transpose(0, 1)
+            out_val = out_val.transpose(0, 1).contiguous()
             # Store.
             output_state_dict[layer_name + ".attn.c_attn.weight"] = out_val
 
         # Transpose the bias.
-        elif op_name == "attention.query_key_value" and weight_or_bias == "bias":
+        elif (
+            op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value"
+        ) and weight_or_bias == "bias":
 
+            out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head)
             # Store. No change of shape.
-            output_state_dict[layer_name + ".attn.c_attn.bias"] = val
+            output_state_dict[layer_name + ".attn.c_attn.bias"] = out_val
 
         # Transpose the weights.
         elif weight_or_bias == "weight":
@@ -155,6 +191,9 @@ def convert_megatron_checkpoint(args, input_state_dict):
             out_name = megatron_to_transformers[op_name]
             output_state_dict[layer_name + out_name + "bias"] = val
 
+    # DEBUG.
+    assert config.n_layer == layer_idx + 1
+
     # The final layernorm.
     output_state_dict["transformer.ln_f.weight"] = transformer["final_layernorm.weight"]
     output_state_dict["transformer.ln_f.bias"] = transformer["final_layernorm.bias"]
@@ -162,33 +201,8 @@ def convert_megatron_checkpoint(args, input_state_dict):
     # For LM head, transformers' wants the matrix to weight embeddings.
     output_state_dict["lm_head.weight"] = word_embeddings
 
-    # The config.
-    output_config = {
-        "activation_function": "gelu_new",
-        "architectures": ["GPT2LMHeadModel"],
-        "attn_pdrop": 0.1,
-        "bos_token_id": 50256,
-        "embd_pdrop": 0.1,
-        "eos_token_id": 50256,
-        "initializer_range": 0.02,
-        "layer_norm_epsilon": 1e-05,
-        "model_type": "gpt2",
-        "n_ctx": 1024,
-        "n_embd": 1024,
-        "n_head": 16,
-        "n_layer": 24,
-        "n_positions": 1024,
-        "resid_pdrop": 0.1,
-        "summary_activation": None,
-        "summary_first_dropout": 0.1,
-        "summary_proj_to_labels": True,
-        "summary_type": "cls_index",
-        "summary_use_proj": True,
-        "vocab_size": 50257,
-    }
-
     # It should be done!
-    return output_state_dict, output_config
+    return output_state_dict
 
 
 ####################################################################################################
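Newer Megatron-LM releases name the attention block self_attention instead of attention, which is why both the name map and the query_key_value branches above now accept either spelling. A small sketch (separate from the patch; the sample key is hypothetical) of how layer_re and megatron_to_transformers turn a Megatron parameter name into the corresponding transformers GPT2 key, with the layer_name construction mirroring the script's transformer.h.<layer> prefix:

import re

layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
megatron_to_transformers = {
    "attention.dense": ".attn.c_proj.",
    "self_attention.dense": ".attn.c_proj.",
    "mlp.dense_h_to_4h": ".mlp.c_fc.",
    "mlp.dense_4h_to_h": ".mlp.c_proj.",
}

key = "layers.7.self_attention.dense.weight"  # hypothetical checkpoint entry
m = layer_re.match(key)
layer_idx, op_name, weight_or_bias = int(m.group(1)), m.group(2), m.group(3)
layer_name = f"transformer.h.{layer_idx}"
print(layer_name + megatron_to_transformers[op_name] + weight_or_bias)
# -> transformer.h.7.attn.c_proj.weight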
@@ -198,21 +212,62 @@ def main():
     # Create the argument parser.
     parser = argparse.ArgumentParser()
     parser.add_argument("--print-checkpoint-structure", action="store_true")
-    parser.add_argument("path_to_checkpoint", type=str, help="Path to the ZIP file containing the checkpoint")
+    parser.add_argument(
+        "path_to_checkpoint",
+        type=str,
+        help="Path to the ZIP file containing the checkpoint",
+    )
+    parser.add_argument(
+        "--config_file",
+        default="",
+        type=str,
+        help="An optional config json file describing the pre-trained model.",
+    )
     args = parser.parse_args()
 
     # Extract the basename.
     basename = os.path.dirname(args.path_to_checkpoint)
 
     # Load the model.
-    print('Extracting PyTorch state dictionary from "{}"'.format(args.path_to_checkpoint))
+    print(f"Extracting PyTorch state dictionary from {args.path_to_checkpoint}")
     with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint:
         with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
             input_state_dict = torch.load(pytorch_dict, map_location="cpu")
 
+    # Read the config, or default to the model released by NVIDIA.
+    if args.config_file == "":
+        # Spell out all parameters in case the defaults change.
+        config = GPT2Config(
+            vocab_size=50257,
+            n_positions=1024,
+            n_ctx=1024,
+            n_embd=1024,
+            n_layer=24,
+            n_head=16,
+            n_inner=4096,
+            activation_function="gelu_new",
+            resid_pdrop=0.1,
+            embd_pdrop=0.1,
+            attn_pdrop=0.1,
+            layer_norm_epsilon=1e-5,
+            initializer_range=0.02,
+            summary_type="cls_index",
+            summary_use_proj=True,
+            summary_activation=None,
+            summary_proj_to_labels=True,
+            summary_first_dropout=0.1,
+            scale_attn_weights=True,
+            gradient_checkpointing=False,
+            use_cache=True,
+            bos_token_id=50256,
+            eos_token_id=50256,
+        )
+    else:
+        config = GPT2Config.from_json_file(args.config_file)
+
     # Convert.
     print("Converting")
-    output_state_dict, output_config = convert_megatron_checkpoint(args, input_state_dict)
+    output_state_dict = convert_megatron_checkpoint(args, input_state_dict, config)
 
     # Print the structure of converted state dict.
     if args.print_checkpoint_structure:
@@ -220,6 +275,9 @@ def main():
 
     # Store the config to file.
     output_config_file = os.path.join(basename, "config.json")
+    output_config = config.to_dict()
+    output_config["architectures"] = ["GPT2LMHeadModel"]
+    output_config["model_type"] = "gpt2"
    print(f'Saving config to "{output_config_file}"')
     with open(output_config_file, "w") as f:
         json.dump(output_config, f)
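Once the converter has been run and config.json has been written next to the checkpoint, the result loads like any other GPT2 model. A rough smoke test, assuming a hypothetical local directory holding the converted files (the path and prompt below are placeholders, not something this patch defines):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

directory = "nvidia/megatron-gpt2-345m/"  # wherever the converted config.json and weights were written
model = GPT2LMHeadModel.from_pretrained(directory)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # NVIDIA's 345M GPT-2 reuses the standard GPT-2 BPE vocabulary

input_ids = tokenizer("The conversion seems to have worked", return_tensors="pt").input_ids
with torch.no_grad():
    logits = model(input_ids).logits
print(logits.shape)  # (1, sequence_length, vocab_size)

The integration test added below performs essentially this check with a fixed input and pre-computed logits.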
diff --git a/tests/test_modeling_megatron_gpt2.py b/tests/test_modeling_megatron_gpt2.py
new file mode 100644
index 000000000..a1f7c472e
--- /dev/null
+++ b/tests/test_modeling_megatron_gpt2.py
@@ -0,0 +1,84 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+
+from transformers import is_torch_available
+from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
+
+
+if is_torch_available():
+    import torch
+
+    from transformers import GPT2LMHeadModel
+
+
+@require_torch
+@require_sentencepiece
+@require_tokenizers
+class MegatronGPT2IntegrationTest(unittest.TestCase):
+    @slow
+    def test_inference_no_head(self):
+        directory = "nvidia/megatron-gpt2-345m/"
+        if "MYDIR" in os.environ:
+            directory = os.path.join(os.environ["MYDIR"], directory)
+        model = GPT2LMHeadModel.from_pretrained(directory)
+        model.to(torch_device)
+        model.half()
+
+        input_ids = torch.tensor(
+            [[101, 7110, 1005, 1056, 2023, 11333, 17413, 1029, 102]],
+            device=torch_device,
+            dtype=torch.long,
+        )
+
+        with torch.no_grad():
+            output = model(input_ids).logits
+
+        expected_shape = torch.Size((1, 9, 50257))
+        self.assertEqual(output.shape, expected_shape)
+
+        expected_diag = torch.tensor(
+            [
+                4.9414,
+                -0.2920,
+                -1.2148,
+                -4.0273,
+                -0.5161,
+                -5.2109,
+                -1.2412,
+                -1.8301,
+                -1.7734,
+                -4.7148,
+                -0.2317,
+                -1.0811,
+                -2.1777,
+                0.4141,
+                -3.7969,
+                -4.0586,
+                -2.5332,
+                -3.3809,
+                4.3867,
+            ],
+            device=torch_device,
+            dtype=torch.half,
+        )
+
+        for i in range(19):
+            r, c = 8 * i // 17, 2792 * i  # along the diagonal
+            computed, expected = output[0, r, c], expected_diag[i]
+            msg = f"row={r} col={c} computed={computed} expected={expected}"
+            self.assertAlmostEqual(computed, expected, delta=1e-4, msg=msg)
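For clarity on the "# along the diagonal" comment: the 19 sampled positions walk from one corner of the (9, 50257) logits matrix to the other. A tiny illustration (not part of the test) of the indices it touches:

for i in range(19):
    r, c = 8 * i // 17, 2792 * i
    print(i, (r, c))
# i=0 -> (0, 0), i=9 -> (4, 25128), i=18 -> (8, 50256): an approximate diagonal across
# all 9 rows and the full 50257-entry vocabulary dimension.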