From efb35a4107478f7d2ebcf56572c0967e68536e15 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 11 Jan 2022 11:59:38 +0100 Subject: [PATCH] [Wav2Vec2ProcessorWithLM] improve decoder downlaod (#15040) --- .../processing_wav2vec2_with_lm.py | 9 ++++++++- tests/test_processor_wav2vec2_with_lm.py | 14 ++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py index c90c4b9c9..e5fddc80f 100644 --- a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +++ b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py @@ -166,7 +166,14 @@ class Wav2Vec2ProcessorWithLM: # BeamSearchDecoderCTC has no auto class kwargs.pop("_from_auto", None) - decoder = BeamSearchDecoderCTC.load_from_hf_hub(pretrained_model_name_or_path, **kwargs) + # make sure that only relevant filenames are downloaded + language_model_filenames = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*") + alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME + allow_regex = [language_model_filenames, alphabet_filename] + + decoder = BeamSearchDecoderCTC.load_from_hf_hub( + pretrained_model_name_or_path, allow_regex=allow_regex, **kwargs + ) # set language model attributes for attribute in ["alpha", "beta", "unk_score_offset", "score_boundary"]: diff --git a/tests/test_processor_wav2vec2_with_lm.py b/tests/test_processor_wav2vec2_with_lm.py index 14e76d38f..119e8eb7b 100644 --- a/tests/test_processor_wav2vec2_with_lm.py +++ b/tests/test_processor_wav2vec2_with_lm.py @@ -18,6 +18,7 @@ import shutil import tempfile import unittest from multiprocessing import Pool +from pathlib import Path import numpy as np @@ -234,3 +235,16 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): self.assertListEqual(decoded_decoder, decoded_processor) self.assertListEqual([" ", " "], decoded_processor) + + def test_decoder_download_ignores_files(self): + processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm") + + language_model = processor.decoder.model_container[processor.decoder._model_key] + path_to_cached_dir = Path(language_model._kenlm_model.path.decode("utf-8")).parent.parent.absolute() + + downloaded_decoder_files = os.listdir(path_to_cached_dir) + + # test that only decoder relevant files from + # https://huggingface.co/hf-internal-testing/processor_with_lm/tree/main + # are downloaded and none of the rest (e.g. README.md, ...) + self.assertListEqual(downloaded_decoder_files, ["alphabet.json", "language_model"])