mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
update QA models tests + run_generation
This commit is contained in:
parent
15d8b1266c
commit
e691fc0963
4 changed files with 41 additions and 27 deletions
|
|
@ -131,8 +131,10 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
|
|||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model_name', type=str, default=None, required=True,
|
||||
help="GPT, GPT-2, Transformer-XL or XLNet pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--prompt", type=str, default="")
|
||||
parser.add_argument("--padding_text", type=str, default="")
|
||||
parser.add_argument("--length", type=int, default=20)
|
||||
|
|
@ -150,15 +152,10 @@ def main():
|
|||
|
||||
set_seed(args)
|
||||
|
||||
args.model_type = ""
|
||||
for key in MODEL_CLASSES:
|
||||
if key in args.model_name.lower():
|
||||
args.model_type = key # take the first match in model types
|
||||
break
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
tokenizer = tokenizer_class.from_pretrained(args.model_name)
|
||||
model = model_class.from_pretrained(args.model_name)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
|
||||
model = model_class.from_pretrained(args.model_name_or_path)
|
||||
model.to(args.device)
|
||||
model.eval()
|
||||
|
||||
|
|
|
|||
|
|
@ -101,7 +101,8 @@ class ExamplesTests(unittest.TestCase):
|
|||
"--prompt=Hello",
|
||||
"--length=10",
|
||||
"--seed=42"]
|
||||
model_name = "--model_name=openai-gpt"
|
||||
model_type, model_name = ("--model_type=openai-gpt",
|
||||
"--model_name_or_path=openai-gpt")
|
||||
with patch.object(sys, 'argv', testargs + [model_name]):
|
||||
result = run_generation.main()
|
||||
self.assertGreaterEqual(len(result), 10)
|
||||
|
|
|
|||
|
|
@ -191,17 +191,19 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
|
|||
cls_index=sequence_labels,
|
||||
is_impossible=is_impossible_labels)
|
||||
|
||||
total_loss, start_logits, end_logits, cls_logits = outputs
|
||||
(total_loss,) = outputs
|
||||
|
||||
outputs = model(input_ids, start_positions=sequence_labels,
|
||||
end_positions=sequence_labels)
|
||||
|
||||
total_loss, start_logits, end_logits = outputs
|
||||
(total_loss,) = outputs
|
||||
|
||||
result = {
|
||||
"loss": total_loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
"start_top_log_probs": start_top_log_probs,
|
||||
"start_top_index": start_top_index,
|
||||
"end_top_log_probs": end_top_log_probs,
|
||||
"end_top_index": end_top_index,
|
||||
"cls_logits": cls_logits,
|
||||
}
|
||||
|
||||
|
|
@ -209,11 +211,17 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
|
|||
list(result["loss"].size()),
|
||||
[])
|
||||
self.parent.assertListEqual(
|
||||
list(result["start_logits"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
list(result["start_top_log_probs"].size()),
|
||||
[self.batch_size, model.config.start_n_top])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_logits"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
list(result["start_top_index"].size()),
|
||||
[self.batch_size, model.config.start_n_top])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_top_log_probs"].size()),
|
||||
[self.batch_size, model.config.start_n_top * model.config.end_n_top])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_top_index"].size()),
|
||||
[self.batch_size, model.config.start_n_top * model.config.end_n_top])
|
||||
self.parent.assertListEqual(
|
||||
list(result["cls_logits"].size()),
|
||||
[self.batch_size])
|
||||
|
|
|
|||
|
|
@ -210,17 +210,19 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
|||
cls_index=sequence_labels,
|
||||
is_impossible=is_impossible_labels)
|
||||
|
||||
total_loss, start_logits, end_logits, cls_logits, mems = outputs
|
||||
total_loss, mems = outputs
|
||||
|
||||
outputs = model(input_ids_1, start_positions=sequence_labels,
|
||||
end_positions=sequence_labels)
|
||||
|
||||
total_loss, start_logits, end_logits, mems = outputs
|
||||
total_loss, mems = outputs
|
||||
|
||||
result = {
|
||||
"loss": total_loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
"start_top_log_probs": start_top_log_probs,
|
||||
"start_top_index": start_top_index,
|
||||
"end_top_log_probs": end_top_log_probs,
|
||||
"end_top_index": end_top_index,
|
||||
"cls_logits": cls_logits,
|
||||
"mems": mems,
|
||||
}
|
||||
|
|
@ -229,11 +231,17 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
|||
list(result["loss"].size()),
|
||||
[])
|
||||
self.parent.assertListEqual(
|
||||
list(result["start_logits"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
list(result["start_top_log_probs"].size()),
|
||||
[self.batch_size, model.config.start_n_top])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_logits"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
list(result["start_top_index"].size()),
|
||||
[self.batch_size, model.config.start_n_top])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_top_log_probs"].size()),
|
||||
[self.batch_size, model.config.start_n_top * model.config.end_n_top])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_top_index"].size()),
|
||||
[self.batch_size, model.config.start_n_top * model.config.end_n_top])
|
||||
self.parent.assertListEqual(
|
||||
list(result["cls_logits"].size()),
|
||||
[self.batch_size])
|
||||
|
|
|
|||
Loading…
Reference in a new issue