From e691fc0963b64e8c19b4a71ddaefe418096776b6 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 15 Jul 2019 17:45:24 +0200 Subject: [PATCH] update QA models tests + run_generation --- examples/run_generation.py | 17 ++++++------- examples/test_examples.py | 3 ++- .../tests/modeling_xlm_test.py | 24 ++++++++++++------- .../tests/modeling_xlnet_test.py | 24 ++++++++++++------- 4 files changed, 41 insertions(+), 27 deletions(-) diff --git a/examples/run_generation.py b/examples/run_generation.py index 4108b2894..a2a8f2910 100644 --- a/examples/run_generation.py +++ b/examples/run_generation.py @@ -131,8 +131,10 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k= def main(): parser = argparse.ArgumentParser() - parser.add_argument('--model_name', type=str, default=None, required=True, - help="GPT, GPT-2, Transformer-XL or XLNet pre-trained model selected in the list: " + ", ".join(ALL_MODELS)) + parser.add_argument("--model_type", default=None, type=str, required=True, + help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) + parser.add_argument("--model_name_or_path", default=None, type=str, required=True, + help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)) parser.add_argument("--prompt", type=str, default="") parser.add_argument("--padding_text", type=str, default="") parser.add_argument("--length", type=int, default=20) @@ -150,15 +152,10 @@ def main(): set_seed(args) - args.model_type = "" - for key in MODEL_CLASSES: - if key in args.model_name.lower(): - args.model_type = key # take the first match in model types - break - + args.model_type = args.model_type.lower() model_class, tokenizer_class = MODEL_CLASSES[args.model_type] - tokenizer = tokenizer_class.from_pretrained(args.model_name) - model = model_class.from_pretrained(args.model_name) + tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) + model = model_class.from_pretrained(args.model_name_or_path) model.to(args.device) model.eval() diff --git a/examples/test_examples.py b/examples/test_examples.py index 00370e936..2f88d129f 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -101,7 +101,8 @@ class ExamplesTests(unittest.TestCase): "--prompt=Hello", "--length=10", "--seed=42"] - model_name = "--model_name=openai-gpt" + model_type, model_name = ("--model_type=openai-gpt", + "--model_name_or_path=openai-gpt") with patch.object(sys, 'argv', testargs + [model_name]): result = run_generation.main() self.assertGreaterEqual(len(result), 10) diff --git a/pytorch_transformers/tests/modeling_xlm_test.py b/pytorch_transformers/tests/modeling_xlm_test.py index 85189859a..4308c18d4 100644 --- a/pytorch_transformers/tests/modeling_xlm_test.py +++ b/pytorch_transformers/tests/modeling_xlm_test.py @@ -191,17 +191,19 @@ class XLMModelTest(CommonTestCases.CommonModelTester): cls_index=sequence_labels, is_impossible=is_impossible_labels) - total_loss, start_logits, end_logits, cls_logits = outputs + (total_loss,) = outputs outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels) - total_loss, start_logits, end_logits = outputs + (total_loss,) = outputs result = { "loss": total_loss, - "start_logits": start_logits, - "end_logits": end_logits, + "start_top_log_probs": start_top_log_probs, + "start_top_index": start_top_index, + "end_top_log_probs": end_top_log_probs, + "end_top_index": end_top_index, "cls_logits": cls_logits, } @@ -209,11 +211,17 @@ class XLMModelTest(CommonTestCases.CommonModelTester): list(result["loss"].size()), []) self.parent.assertListEqual( - list(result["start_logits"].size()), - [self.batch_size, self.seq_length]) + list(result["start_top_log_probs"].size()), + [self.batch_size, model.config.start_n_top]) self.parent.assertListEqual( - list(result["end_logits"].size()), - [self.batch_size, self.seq_length]) + list(result["start_top_index"].size()), + [self.batch_size, model.config.start_n_top]) + self.parent.assertListEqual( + list(result["end_top_log_probs"].size()), + [self.batch_size, model.config.start_n_top * model.config.end_n_top]) + self.parent.assertListEqual( + list(result["end_top_index"].size()), + [self.batch_size, model.config.start_n_top * model.config.end_n_top]) self.parent.assertListEqual( list(result["cls_logits"].size()), [self.batch_size]) diff --git a/pytorch_transformers/tests/modeling_xlnet_test.py b/pytorch_transformers/tests/modeling_xlnet_test.py index 8360a08d6..290c5766e 100644 --- a/pytorch_transformers/tests/modeling_xlnet_test.py +++ b/pytorch_transformers/tests/modeling_xlnet_test.py @@ -210,17 +210,19 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): cls_index=sequence_labels, is_impossible=is_impossible_labels) - total_loss, start_logits, end_logits, cls_logits, mems = outputs + total_loss, mems = outputs outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels) - total_loss, start_logits, end_logits, mems = outputs + total_loss, mems = outputs result = { "loss": total_loss, - "start_logits": start_logits, - "end_logits": end_logits, + "start_top_log_probs": start_top_log_probs, + "start_top_index": start_top_index, + "end_top_log_probs": end_top_log_probs, + "end_top_index": end_top_index, "cls_logits": cls_logits, "mems": mems, } @@ -229,11 +231,17 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): list(result["loss"].size()), []) self.parent.assertListEqual( - list(result["start_logits"].size()), - [self.batch_size, self.seq_length]) + list(result["start_top_log_probs"].size()), + [self.batch_size, model.config.start_n_top]) self.parent.assertListEqual( - list(result["end_logits"].size()), - [self.batch_size, self.seq_length]) + list(result["start_top_index"].size()), + [self.batch_size, model.config.start_n_top]) + self.parent.assertListEqual( + list(result["end_top_log_probs"].size()), + [self.batch_size, model.config.start_n_top * model.config.end_n_top]) + self.parent.assertListEqual( + list(result["end_top_index"].size()), + [self.batch_size, model.config.start_n_top * model.config.end_n_top]) self.parent.assertListEqual( list(result["cls_logits"].size()), [self.batch_size])