diff --git a/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py b/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py
index a9876fcd6..d80d470b4 100755
--- a/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py
+++ b/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py
@@ -28,7 +28,6 @@ from typing import Dict, List, Optional, Union
 
 import datasets
 import evaluate
-import numpy as np
 import torch
 from datasets import DatasetDict, load_dataset
 
@@ -712,10 +711,14 @@ def main():
         logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
         return
 
-    def compute_metrics(pred):
-        pred_logits = pred.predictions
-        pred_ids = np.argmax(pred_logits, axis=-1)
+    # For languages like Chinese with large vocabulary size, we need to discard logits
+    # and only keep the argmax, otherwise we run out of memory during evaluation.
+    def preprocess_logits_for_metrics(logits, labels):
+        pred_ids = torch.argmax(logits, dim=-1)
+        return pred_ids, labels
 
+    def compute_metrics(pred):
+        pred_ids = pred.predictions[0]
         pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
 
         pred_str = tokenizer.batch_decode(pred_ids)
@@ -762,6 +765,7 @@ def main():
         train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
         eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
         tokenizer=processor,
+        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
 
    # 8. Finally, we can start training
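
Why the hook helps: during evaluation, `Trainer` accumulates whatever `preprocess_logits_for_metrics` returns (or the raw logits, when the hook is absent) across all batches before calling `compute_metrics`. With a large CTC vocabulary, the accumulated float32 logits dominate memory, while the int64 argmax ids are orders of magnitude smaller. A back-of-the-envelope sketch of the savings, with all sizes hypothetical and chosen only for illustration:

```python
# Rough memory arithmetic for accumulating eval predictions with a
# large-vocabulary CTC model (all sizes below are hypothetical).
eval_samples = 10_000   # utterances in the eval split
frames = 500            # CTC output frames per utterance
vocab_size = 21_128     # e.g. a Chinese character vocabulary

full_logits_bytes = eval_samples * frames * vocab_size * 4  # float32 logits
argmax_ids_bytes = eval_samples * frames * 8                # int64 token ids

print(f"full logits: {full_logits_bytes / 2**30:,.0f} GiB")  # ~394 GiB -> OOM
print(f"argmax ids:  {argmax_ids_bytes / 2**20:,.0f} MiB")   # ~38 MiB
```

Since the hook runs once per batch, before the predictions are gathered and moved off the device, only the reduced tensors are ever concatenated. `compute_metrics` then reads the ids back out of `pred.predictions[0]`, which is why its first line changes in the diff above.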