transformers/examples/seq2seq/run_distributed_eval.py

249 lines
9.4 KiB
Python
Raw Normal View History

#!/usr/bin/env python
import argparse
import shutil
import time
from json import JSONDecodeError
from logging import getLogger
from pathlib import Path
from typing import Dict, List, Tuple

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from utils import (
    Seq2SeqDataset,
    calculate_bleu,
    calculate_rouge,
    chunks,
    lmap,
    load_json,
    parse_numeric_n_bool_cl_kwargs,
    save_json,
    use_task_specific_params,
    write_txt_file,
)
logger = getLogger(__name__)
def eval_data_dir(
    data_dir,
    save_dir: str,
    model_name: str,
    bs: int = 8,
    max_source_length: int = 1024,
    type_path="val",
    n_obs=None,
    fp16=False,
    task="summarization",
    local_rank=None,
    num_return_sequences=1,
    dataset_kwargs: Dict = None,
    prefix="",
    **generate_kwargs,
) -> Tuple[List[Dict], int]:
    """Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json

    Args:
        data_dir: directory handed to ``Seq2SeqDataset``.
        save_dir: directory that receives ``rank_{local_rank}_output.json``.
        model_name: name or path given to ``from_pretrained`` for both model and tokenizer.
        bs: batch size for the sampler and the ``DataLoader``.
        max_source_length: passed to the dataset; ``None`` falls back to ``tokenizer.model_max_length``.
        type_path: dataset split name forwarded to ``Seq2SeqDataset``.
        n_obs: forwarded to ``Seq2SeqDataset`` (``None`` presumably means "all" -- see the CLI help).
        fp16: if true, the model is converted with ``.half()``.
        task: key used by ``use_task_specific_params`` to update the model config.
        local_rank: rank of this process; required (must not be ``None``).
        num_return_sequences: sequences generated per input; ``num_beams`` is raised to match if smaller.
        dataset_kwargs: extra keyword arguments for ``Seq2SeqDataset``; ``None`` means no extras.
        prefix: forwarded to the dataset; ``None`` falls back to ``model.config.prefix`` (or ``""``).
        **generate_kwargs: forwarded to ``model.generate`` (``num_beams`` is extracted first).

    Returns:
        ``(results, num_replicas)`` where ``results`` is a list of ``{"pred": ..., "id": ...}`` records
        produced by this process and ``num_replicas`` is read from the distributed sampler.
    """
    model_name = str(model_name)
    assert local_rank is not None
    torch.distributed.init_process_group(backend="nccl", rank=local_rank)

    save_dir = Path(save_dir)
    save_path = save_dir.joinpath(f"rank_{local_rank}_output.json")
    torch.cuda.set_device(local_rank)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda()
    if fp16:
        model = model.half()
    # determine if we need to increase num_beams
    use_task_specific_params(model, task)  # update config with task specific params
    num_beams = generate_kwargs.pop("num_beams", model.config.num_beams)  # AttributeError risk?
    if num_return_sequences > num_beams:
        num_beams = num_return_sequences

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logger.info(f"Inferred tokenizer type: {tokenizer.__class__}")  # if this is wrong, check config.model_type.

    if max_source_length is None:
        max_source_length = tokenizer.model_max_length
    if prefix is None:
        prefix = getattr(model.config, "prefix", "") or ""
    ds = Seq2SeqDataset(
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length=1024,
        type_path=type_path,
        n_obs=n_obs,
        prefix=prefix,
        # dataset_kwargs defaults to None; splatting None would raise TypeError.
        **(dataset_kwargs or {}),
    )
    # I set shuffle=True for a more accurate progress bar.
    # If all the longest samples are first, the prog bar estimate is too high at the beginning.
    sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False, shuffle=True)
    data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn)
    results = []
    for batch in tqdm(data_loader):
        summaries = model.generate(
            input_ids=batch["input_ids"].to(model.device),
            attention_mask=batch["attention_mask"].to(model.device),
            num_return_sequences=num_return_sequences,
            num_beams=num_beams,
            **generate_kwargs,
        )
        preds = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        ids = batch["ids"]
        if num_return_sequences > 1:
            preds = chunks(preds, num_return_sequences)  # batch size chunks, each of size num_return_seq
        for i, pred in enumerate(preds):
            results.append(dict(pred=pred, id=ids[i].item()))
    save_json(results, save_path)
    return results, sampler.num_replicas
def run_generate():
    """Parse the command line, run ``eval_data_dir`` on this process, and aggregate on rank <= 0.

    Every process writes its partial results under ``{save_dir}_tmp``. The process whose
    ``--local_rank`` is <= 0 then waits for all partial files, merges them, and either

    * saves ``pseudolabel_results.json`` and returns (when ``--num_return_sequences`` > 1), or
    * scores the predictions against ``{data_dir}/{type_path}.target`` (BLEU when "translation"
      is in ``--task``, otherwise ROUGE), and writes the metrics json and the generations file.

    Command-line arguments not known to the parser are converted by
    ``parse_numeric_n_bool_cl_kwargs`` and forwarded to ``model.generate``.

    Raises:
        ValueError: if ``{save_dir}_tmp`` already contains ``rank_*.json`` files.
    """
    parser = argparse.ArgumentParser(
        epilog="Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate"
    )
    parser.add_argument("--data_dir", type=str, help="like cnn_dm/test.source")
    parser.add_argument(
        "--model_name",
        type=str,
        help="like facebook/bart-large-cnn,t5-base, etc.",
        default="sshleifer/distilbart-xsum-12-3",
    )
    parser.add_argument("--save_dir", type=str, help="where to save", default="tmp_gen")
    parser.add_argument("--max_source_length", type=int, default=None)
    parser.add_argument(
        "--type_path", type=str, default="test", help="which subset to evaluate typically train/val/test"
    )
    parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
    parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
    parser.add_argument(
        "--local_rank", type=int, default=-1, required=False, help="should be passed by distributed.launch"
    )
    parser.add_argument(
        "--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all."
    )
    parser.add_argument(
        "--num_return_sequences", type=int, default=1, required=False, help="How many sequences to return"
    )
    parser.add_argument(
        "--sync_timeout",
        type=int,
        default=600,
        required=False,
        help="How long should master process wait for other processes to finish.",
    )
    parser.add_argument("--src_lang", type=str, default=None, required=False)
    parser.add_argument("--tgt_lang", type=str, default=None, required=False)
    parser.add_argument(
        "--prefix", type=str, required=False, default=None, help="will be added to the begininng of src examples"
    )
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--debug", action="store_true")
    start_time = time.time()
    args, rest = parser.parse_known_args()
    generate_kwargs = parse_numeric_n_bool_cl_kwargs(rest)
    if generate_kwargs and args.local_rank <= 0:
        print(f"parsed the following generate kwargs: {generate_kwargs}")

    json_save_dir = Path(args.save_dir + "_tmp")
    json_save_dir.mkdir(exist_ok=True)  # this handles locking.
    intermediate_files = list(json_save_dir.glob("rank_*.json"))
    if intermediate_files:
        raise ValueError(f"Found files at {json_save_dir} please move or remove them.")
        # In theory, a node could finish and save before another node hits this. If this happens, we can address later.

    # Only forward the language options that were actually given on the command line.
    dataset_kwargs = {}
    if args.src_lang is not None:
        dataset_kwargs["src_lang"] = args.src_lang
    if args.tgt_lang is not None:
        dataset_kwargs["tgt_lang"] = args.tgt_lang

    Path(args.save_dir).mkdir(exist_ok=True)
    # The per-process records are also saved to disk by eval_data_dir; only the replica count is needed here.
    _, num_replicas = eval_data_dir(
        args.data_dir,
        json_save_dir,
        args.model_name,
        type_path=args.type_path,
        bs=args.bs,
        fp16=args.fp16,
        task=args.task,
        local_rank=args.local_rank,
        n_obs=args.n_obs,
        max_source_length=args.max_source_length,
        num_return_sequences=args.num_return_sequences,
        prefix=args.prefix,
        dataset_kwargs=dataset_kwargs,
        **generate_kwargs,
    )

    if args.local_rank <= 0:
        save_dir = Path(args.save_dir)
        save_dir.mkdir(exist_ok=True)
        partial_results = gather_results_from_each_node(num_replicas, json_save_dir, args.sync_timeout)
        preds = combine_partial_results(partial_results)
        if args.num_return_sequences > 1:
            # Several candidates per input: save them as-is and skip scoring.
            save_path = save_dir.joinpath("pseudolabel_results.json")
            print(f"Saving aggregated results at {save_path}, intermediate in {json_save_dir}/")
            save_json(preds, save_path)
            return
        tgt_file = Path(args.data_dir).joinpath(args.type_path + ".target")
        # Use a context manager so the target file is closed deterministically.
        with open(tgt_file) as f:
            labels = [x.rstrip() for x in f.readlines()][: len(preds)]

        # Calculate metrics, save metrics, and save _generations.txt
        calc_bleu = "translation" in args.task
        score_fn = calculate_bleu if calc_bleu else calculate_rouge
        metric_name = "bleu" if calc_bleu else "rouge"
        metrics: Dict = score_fn(preds, labels)
        metrics["n_obs"] = len(preds)
        runtime = time.time() - start_time
        metrics["seconds_per_sample"] = round(runtime / metrics["n_obs"], 4)
        metrics["n_gpus"] = num_replicas
        # TODO(@stas00): add whatever metadata to metrics
        metrics_save_path = save_dir.joinpath(f"{args.type_path}_{metric_name}.json")
        save_json(metrics, metrics_save_path, indent=None)
        print(metrics)
        write_txt_file(preds, save_dir.joinpath(f"{args.type_path}_generations.txt"))
        if args.debug:
            write_txt_file(labels, save_dir.joinpath(f"{args.type_path}.target"))
        else:
            # Outside debug mode the per-rank intermediate files are removed.
            shutil.rmtree(json_save_dir)
2020-09-16 19:38:37 +00:00
def combine_partial_results(partial_results) -> List:
    """Flatten the per-process record lists and return the predictions ordered by record id.

    Args:
        partial_results: iterable of iterables of records; each record is a mapping with
            an ``"id"`` key (used for ordering) and a ``"pred"`` key.

    Returns:
        The ``"pred"`` values of all records, sorted by ``"id"``. The sort is stable, so records
        with equal ids keep their input order.
    """
    records = [record for partial_result in partial_results for record in partial_result]
    records.sort(key=lambda x: x["id"])
    return [record["pred"] for record in records]
def gather_results_from_each_node(num_replicas, save_dir, timeout, poll_interval: float = 0.5) -> List[List[Dict]]:
    """Wait until at least ``num_replicas`` ``rank_*.json`` files exist in ``save_dir`` and load them.

    Args:
        num_replicas: minimum number of ``rank_*.json`` files to wait for.
        save_dir: ``Path`` of the directory that is polled.
        timeout: maximum number of seconds to keep polling.
        poll_interval: seconds to sleep between polls, so the wait does not spin a CPU core
            and flood the filesystem with directory listings.

    Returns:
        One loaded json payload per file found, in the order returned by ``save_dir.glob``.

    Raises:
        TimeoutError: if the files are not all present and parseable within ``timeout`` seconds.
    """
    # WAIT FOR lots of .json files
    start_wait = time.time()
    logger.info("waiting for all nodes to finish")
    while (time.time() - start_wait) < timeout:
        json_files = list(save_dir.glob("rank_*.json"))
        if len(json_files) >= num_replicas:
            try:
                # make sure all json files are fully saved
                return lmap(load_json, json_files)
            except JSONDecodeError:
                # A file exists but is not completely written yet; retry after the pause below.
                pass
        time.sleep(poll_interval)
    raise TimeoutError("Rank 0 gave up on waiting for other processes")
if __name__ == "__main__":
    # Script entry point: all options are read from the command line by run_generate().
    # Usage for MT (illustrative only -- flags are taken from run_generate's parser; the launcher
    # is expected to supply --local_rank; confirm the exact launch command for your setup):
    #   python -m torch.distributed.launch --nproc_per_node=<n_gpus> run_distributed_eval.py \
    #       --model_name <model> --data_dir <data_dir> --save_dir <save_dir> --task translation --bs 8
    run_generate()