diff --git a/README.md b/README.md index 51e1bf566..620d2d579 100644 --- a/README.md +++ b/README.md @@ -311,6 +311,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu. 1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. 1. **[VisualBERT](https://huggingface.co/docs/transformers/model_doc/visual_bert)** (from UCLA NLP) released with the paper [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang. +1. **[WavLM](https://huggingface.co/docs/transformers/master/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei. 1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. 1. **[XLM](https://huggingface.co/docs/transformers/model_doc/xlm)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau. 1. **[XLM-ProphetNet](https://huggingface.co/docs/transformers/model_doc/xlmprophetnet)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. diff --git a/README_ko.md b/README_ko.md index c2bb731cd..dc888dcd8 100644 --- a/README_ko.md +++ b/README_ko.md @@ -290,6 +290,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. 1. **[VisualBERT](https://huggingface.co/docs/transformers/model_doc/visual_bert)** (from UCLA NLP) released with the paper [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang. 1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. +1. **[WavLM](https://huggingface.co/docs/transformers/master/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei. 1. **[XLM](https://huggingface.co/docs/transformers/model_doc/xlm)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau. 1. **[XLM-ProphetNet](https://huggingface.co/docs/transformers/model_doc/xlmprophetnet)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. 1. **[XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlmroberta)** (from Facebook AI), released together with the paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov. diff --git a/README_zh-hans.md b/README_zh-hans.md index 7aea040ac..9ecd8c1a6 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -314,6 +314,7 @@ conda install -c huggingface transformers 1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (来自 Google AI) 伴随论文 [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) 由 Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby 发布。 1. **[VisualBERT](https://huggingface.co/docs/transformers/model_doc/visual_bert)** (来自 UCLA NLP) 伴随论文 [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) 由 Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang 发布。 1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (来自 Facebook AI) 伴随论文 [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) 由 Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli 发布。 +1. **[WavLM](https://huggingface.co/docs/transformers/master/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei. 1. **[XLM](https://huggingface.co/docs/transformers/model_doc/xlm)** (来自 Facebook) 伴随论文 [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) 由 Guillaume Lample and Alexis Conneau 发布。 1. **[XLM-ProphetNet](https://huggingface.co/docs/transformers/model_doc/xlmprophetnet)** (来自 Microsoft Research) 伴随论文 [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) 由 Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou 发布。 1. **[XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlmroberta)** (来自 Facebook AI), 伴随论文 [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) 由 Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 52c034d3f..6ae9509cf 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -326,6 +326,7 @@ conda install -c huggingface transformers 1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. 1. **[VisualBERT](https://huggingface.co/docs/transformers/model_doc/visual_bert)** (from UCLA NLP) released with the paper [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang. 1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. +1. **[WavLM](https://huggingface.co/docs/transformers/master/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei. 1. **[XLM](https://huggingface.co/docs/transformers/model_doc/xlm)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau. 1. **[XLM-ProphetNet](https://huggingface.co/docs/transformers/model_doc/xlmprophetnet)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. 1. **[XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlmroberta)** (from Facebook AI), released together with the paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov. diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 74dce907d..2afd514ea 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -284,6 +284,8 @@ title: VisualBERT - local: model_doc/wav2vec2 title: Wav2Vec2 + - local: model_doc/wavlm + title: WavLM - local: model_doc/xlm title: XLM - local: model_doc/xlmprophetnet diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 4b6d468e8..444ba3e84 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -170,6 +170,7 @@ conversion utilities for the following models. 1. **[UniSpeechSat](model_doc/unispeech_sat)** (from Microsoft Research) released with the paper [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu. 1. **[Vision Transformer (ViT)](model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. 1. **[VisualBERT](model_doc/visual_bert)** (from UCLA NLP) released with the paper [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang. +1. **[WavLM](model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei. 1. **[Wav2Vec2](model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. 1. **[XLM](model_doc/xlm)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau. 1. **[XLM-ProphetNet](model_doc/xlmprophetnet)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. @@ -263,6 +264,7 @@ Flax), PyTorch, and/or TensorFlow. | VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ | | ViT | ❌ | ❌ | ✅ | ✅ | ✅ | | Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ | +| WavLM | ❌ | ❌ | ✅ | ❌ | ❌ | | XLM | ✅ | ❌ | ✅ | ✅ | ❌ | | XLM-RoBERTa | ✅ | ✅ | ✅ | ✅ | ❌ | | XLMProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/model_doc/wavlm.rst b/docs/source/model_doc/wavlm.rst new file mode 100644 index 000000000..e371e3977 --- /dev/null +++ b/docs/source/model_doc/wavlm.rst @@ -0,0 +1,83 @@ +.. + Copyright 2021 The HuggingFace Team. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + specific language governing permissions and limitations under the License. + +WavLM +----------------------------------------------------------------------------------------------------------------------- + +Overview +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The WavLM model was proposed in `WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing +`__ by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, +Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, +Michael Zeng, Furu Wei. + +The abstract from the paper is the following: + +*Self-supervised learning (SSL) achieves great success in speech recognition, while limited exploration has been +attempted for other speech processing tasks. As speech signal contains multi-faceted information including speaker +identity, paralinguistics, spoken content, etc., learning universal representations for all speech tasks is +challenging. In this paper, we propose a new pre-trained model, WavLM, to solve full-stack downstream speech tasks. +WavLM is built based on the HuBERT framework, with an emphasis on both spoken content modeling and speaker identity +preservation. We first equip the Transformer structure with gated relative position bias to improve its capability on +recognition tasks. For better speaker discrimination, we propose an utterance mixing training strategy, where +additional overlapped utterances are created unsupervisely and incorporated during model training. Lastly, we scale up +the training dataset from 60k hours to 94k hours. WavLM Large achieves state-of-the-art performance on the SUPERB +benchmark, and brings significant improvements for various speech processing tasks on their representative benchmarks.* + +Tips: + +- WavLM is a speech model that accepts a float array corresponding to the raw waveform of the speech signal. Please use + :class:`~transformers.Wav2Vec2Processor` for the feature extraction. +- WavLM model can be fine-tuned using connectionist temporal classification (CTC) so the model output has to be decoded + using :class:`~transformers.Wav2Vec2CTCTokenizer`. +- WavLM performs especially well on speaker verification, speaker identification, and speaker diarization tasks. + +Relevant checkpoints can be found under https://huggingface.co/models?other=wavlm. + +This model was contributed by `patrickvonplaten `__. The Authors' code can be +found `here `__. + + +WavLMConfig +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.WavLMConfig + :members: + + +WavLM specific outputs +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.models.wavlm.modeling_wavlm.WavLMBaseModelOutput + :members: + + +WavLMModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.WavLMModel + :members: forward + + +WavLMForCTC +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.WavLMForCTC + :members: forward + + +WavLMForSequenceClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.WavLMForSequenceClassification + :members: forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 2adf6c6a5..8a23fac8d 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -314,6 +314,10 @@ _import_structure = { "Wav2Vec2Tokenizer", ], "models.wav2vec2_with_lm": ["Wav2Vec2ProcessorWithLM"], + "models.wavlm": [ + "WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP", + "WavLMConfig", + ], "models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"], "models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"], "models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"], @@ -1372,6 +1376,15 @@ if is_torch_available(): "Wav2Vec2PreTrainedModel", ] ) + _import_structure["models.wavlm"].extend( + [ + "WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST", + "WavLMForCTC", + "WavLMForSequenceClassification", + "WavLMModel", + "WavLMPreTrainedModel", + ] + ) _import_structure["models.xlm"].extend( [ "XLM_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -2327,6 +2340,7 @@ if TYPE_CHECKING: Wav2Vec2Tokenizer, ) from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM + from .models.wavlm import WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP, WavLMConfig from .models.xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMTokenizer from .models.xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig from .models.xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig @@ -3210,6 +3224,13 @@ if TYPE_CHECKING: Wav2Vec2Model, Wav2Vec2PreTrainedModel, ) + from .models.wavlm import ( + WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST, + WavLMForCTC, + WavLMForSequenceClassification, + WavLMModel, + WavLMPreTrainedModel, + ) from .models.xlm import ( XLM_PRETRAINED_MODEL_ARCHIVE_LIST, XLMForMultipleChoice, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index dee637278..1c1aef5aa 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -108,6 +108,7 @@ from . import ( vit, wav2vec2, wav2vec2_with_lm, + wavlm, xlm, xlm_prophetnet, xlm_roberta, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index ad2712f46..81e0749c5 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -109,6 +109,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("sew", "SEWConfig"), ("unispeech-sat", "UniSpeechSatConfig"), ("unispeech", "UniSpeechConfig"), + ("wavlm", "WavLMConfig"), ] ) @@ -277,6 +278,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("sew", "SEW"), ("unispeech-sat", "UniSpeechSat"), ("unispeech", "UniSpeech"), + ("wavlm", "WavLM"), ] ) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index cbc340c11..978551fd0 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -52,6 +52,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("vit", "ViTModel"), ("wav2vec2", "Wav2Vec2Model"), ("unispeech-sat", "UniSpeechSatModel"), + ("wavlm", "WavLMModel"), ("unispeech", "UniSpeechModel"), ("hubert", "HubertModel"), ("m2m_100", "M2M100Model"), @@ -523,6 +524,7 @@ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("hubert", "HubertForSequenceClassification"), ("sew", "SEWForSequenceClassification"), ("sew-d", "SEWDForSequenceClassification"), + ("wavlm", "WavLMForSequenceClassification"), ] ) @@ -535,6 +537,7 @@ MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict( ("hubert", "HubertForCTC"), ("sew", "SEWForCTC"), ("sew-d", "SEWDForCTC"), + ("wavlm", "WavLMForCTC"), ] ) diff --git a/src/transformers/models/wavlm/__init__.py b/src/transformers/models/wavlm/__init__.py new file mode 100644 index 000000000..7f8cc1041 --- /dev/null +++ b/src/transformers/models/wavlm/__init__.py @@ -0,0 +1,51 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...file_utils import _LazyModule, is_torch_available + + +_import_structure = { + "configuration_wavlm": ["WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "WavLMConfig"], +} + +if is_torch_available(): + _import_structure["modeling_wavlm"] = [ + "WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST", + "WavLMForCTC", + "WavLMForSequenceClassification", + "WavLMModel", + "WavLMPreTrainedModel", + ] + +if TYPE_CHECKING: + from .configuration_wavlm import WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP, WavLMConfig + + if is_torch_available(): + from .modeling_wavlm import ( + WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST, + WavLMForCTC, + WavLMForSequenceClassification, + WavLMModel, + WavLMPreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/wavlm/configuration_wavlm.py b/src/transformers/models/wavlm/configuration_wavlm.py new file mode 100644 index 000000000..d8b99e315 --- /dev/null +++ b/src/transformers/models/wavlm/configuration_wavlm.py @@ -0,0 +1,304 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors, Microsoft Research, and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" WavLM model configuration """ + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "facebook/wavlm-base-960h": "https://huggingface.co/facebook/wavlm-base-960h/resolve/main/config.json", + # See all WavLM models at https://huggingface.co/models?filter=wavlm +} + + +class WavLMConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.WavLMModel`. It is used to + instantiate an WavLM model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the WavLM `facebook/wavlm-base-960h + `__ architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + + Args: + vocab_size (:obj:`int`, `optional`, defaults to 32): + Vocabulary size of the WavLM model. Defines the number of different tokens that can be represented by the + :obj:`inputs_ids` passed when calling :class:`~transformers.WavLMModel`. Vocabulary size of the model. + Defines the different tokens that can be represented by the `inputs_ids` passed to the forward method of + :class:`~transformers.WavLMModel`. + hidden_size (:obj:`int`, `optional`, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (:obj:`int`, `optional`, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (:obj:`int`, `optional`, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (:obj:`int`, `optional`, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, + :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported. + hidden_dropout (:obj:`float`, `optional`, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (:obj:`float`, `optional`, defaults to 0.1): + The dropout ratio for the attention probabilities. + final_dropout (:obj:`float`, `optional`, defaults to 0.1): + The dropout probability for the final projection layer of :class:`WavLMForCTC`. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): + The epsilon used by the layer normalization layers. + feat_extract_norm (:obj:`str`, `optional`, defaults to :obj:`"group"`): + The norm to be applied to 1D convolutional layers in feature extractor. One of :obj:`"group"` for group + normalization of only the first 1D convolutional layer or :obj:`"layer"` for layer normalization of all 1D + convolutional layers. + feat_proj_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout probability for output of the feature extractor. + feat_extract_activation (:obj:`str, `optional`, defaults to :obj:`"gelu"`): + The non-linear activation function (function or string) in the 1D convolutional layers of the feature + extractor. If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported. + feat_quantizer_dropout (obj:`float`, `optional`, defaults to 0.0): + The dropout probabilitiy for quantized feature extractor states. + conv_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 512, 512, 512)`): + A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the + feature extractor. The length of `conv_dim` defines the number of 1D convolutional layers. + conv_stride (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 2, 2, 2, 2, 2, 2)`): + A tuple of integers defining the stride of each 1D convolutional layer in the feature extractor. The length + of `conv_stride` defines the number of convolutional layers and has to match the the length of `conv_dim`. + conv_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(10, 3, 3, 3, 3, 3, 3)`): + A tuple of integers defining the kernel size of each 1D convolutional layer in the feature extractor. The + length of `conv_kernel` defines the number of convolutional layers and has to match the the length of + `conv_dim`. + conv_bias (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether the 1D convolutional layers have a bias. + num_conv_pos_embeddings (:obj:`int`, `optional`, defaults to 128): + Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional + embeddings layer. + num_conv_pos_embedding_groups (:obj:`int`, `optional`, defaults to 16): + Number of groups of 1D convolutional positional embeddings layer. + do_stable_layer_norm (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to apply `stable` layer norm architecture of the Transformer encoder. ``do_stable_layer_norm is + True`` corresponds to applying layer norm before the attention layer, whereas ``do_stable_layer_norm is + False`` corresponds to applying layer norm after the attention layer. + apply_spec_augment (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to apply *SpecAugment* data augmentation to the outputs of the feature extractor. For reference see + `SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition + `__. + mask_time_prob (:obj:`float`, `optional`, defaults to 0.05): + Propability of each feature vector along the time axis to be chosen as the start of the vector span to be + masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature vectors will be + masked along the time axis. This is only relevant if ``apply_spec_augment is True``. + mask_time_length (:obj:`int`, `optional`, defaults to 10): + Length of vector span along the time axis. + mask_time_min_masks (:obj:`int`, `optional`, defaults to 2),: + The minimum number of masks of length ``mask_feature_length`` generated along the time axis, each time + step, irrespectively of ``mask_feature_prob``. Only relevant if + ''mask_time_prob*len(time_axis)/mask_time_length < mask_time_min_masks'' + mask_feature_prob (:obj:`float`, `optional`, defaults to 0.0): + Propability of each feature vector along the feature axis to be chosen as the start of the vector span to + be masked. Approximately ``mask_time_prob * hidden_size // mask_time_length`` feature vectors will be + masked along the time axis. This is only relevant if ``apply_spec_augment is True``. + mask_feature_length (:obj:`int`, `optional`, defaults to 10): + Length of vector span along the feature axis. + num_codevectors_per_group (:obj:`int`, `optional`, defaults to 320): + Number of entries in each quantization codebook (group). + num_codevector_groups (:obj:`int`, `optional`, defaults to 2): + Number of codevector groups for product codevector quantization. + contrastive_logits_temperature (:obj:`float`, `optional`, defaults to 0.1): + The temperature `kappa` in the contrastive loss. + feat_quantizer_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout probabilitiy for the output of the feature extractor that's used by the quantizer. + num_negatives (:obj:`int`, `optional`, defaults to 100): + Number of negative samples for the contrastive loss. + codevector_dim (:obj:`int`, `optional`, defaults to 256): + Dimensionality of the quantized feature vectors. + proj_codevector_dim (:obj:`int`, `optional`, defaults to 256): + Dimensionality of the final projection of both the quantized and the transformer features. + diversity_loss_weight (:obj:`int`, `optional`, defaults to 0.1): + The weight of the codebook diversity loss component. + ctc_loss_reduction (:obj:`str`, `optional`, defaults to :obj:`"mean"`): + Specifies the reduction to apply to the output of ``torch.nn.CTCLoss``. Only relevant when training an + instance of :class:`~transformers.WavLMForCTC`. + ctc_zero_infinity (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses + mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an + instance of :class:`~transformers.WavLMForCTC`. + use_weighted_layer_sum (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an + instance of :class:`~transformers.WavLMForSequenceClassification`. + classifier_proj_size (:obj:`int`, `optional`, defaults to 256): + Dimensionality of the projection before token mean-pooling for classification. + add_adapter (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for + warm-starting Wav2Vec2 for SpeechEncoderDecoder models. + adapter_kernel_size (:obj:`int`, `optional`, defaults to 3): + Kernel size of the convolutional layers in the adapter network. Only relevant if ``add_adapter is True``. + adapter_stride (:obj:`int`, `optional`, defaults to 2): + Stride of the convolutional layers in the adapter network. Only relevant if ``add_adapter is True``. + num_adapter_layers (:obj:`int`, `optional`, defaults to 3): + Number of convolutional layers that should be used in the adapter network. Only relevant if ``add_adapter + is True``. + output_hidden_size (:obj:`int`, `optional`): + Dimensionality of the encoder output layer. If not defined, this defaults to `hidden-size`. Only relevant + if ``add_adapter is True``. + + Example:: + + Example:: + + >>> from transformers import WavLMModel, WavLMConfig + + >>> # Initializing a WavLM facebook/wavlm-base-960h style configuration + >>> configuration = WavLMConfig() + + >>> # Initializing a model from the facebook/wavlm-base-960h style configuration + >>> model = WavLMModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + model_type = "wavlm" + + def __init__( + self, + vocab_size=32, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout=0.1, + activation_dropout=0.1, + attention_dropout=0.1, + feat_proj_dropout=0.0, + feat_quantizer_dropout=0.0, + final_dropout=0.1, + layerdrop=0.1, + initializer_range=0.02, + layer_norm_eps=1e-5, + feat_extract_norm="group", + feat_extract_activation="gelu", + conv_dim=(512, 512, 512, 512, 512, 512, 512), + conv_stride=(5, 2, 2, 2, 2, 2, 2), + conv_kernel=(10, 3, 3, 3, 3, 2, 2), + conv_bias=False, + num_conv_pos_embeddings=128, + num_conv_pos_embedding_groups=16, + num_buckets=320, + max_bucket_distance=800, + do_stable_layer_norm=False, + apply_spec_augment=True, + mask_time_prob=0.05, + mask_time_length=10, + mask_time_min_masks=2, + mask_feature_prob=0.0, + mask_feature_length=10, + num_codevectors_per_group=320, + num_codevector_groups=2, + contrastive_logits_temperature=0.1, + num_negatives=100, + codevector_dim=256, + proj_codevector_dim=256, + diversity_loss_weight=0.1, + ctc_loss_reduction="mean", + ctc_zero_infinity=False, + use_weighted_layer_sum=False, + classifier_proj_size=256, + num_ctc_classes=80, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + add_adapter=False, + adapter_kernel_size=3, + adapter_stride=2, + num_adapter_layers=3, + output_hidden_size=None, + **kwargs + ): + super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id) + self.hidden_size = hidden_size + self.feat_extract_norm = feat_extract_norm + self.feat_extract_activation = feat_extract_activation + self.conv_dim = list(conv_dim) + self.conv_stride = list(conv_stride) + self.conv_kernel = list(conv_kernel) + self.conv_bias = conv_bias + self.num_buckets = num_buckets + self.max_bucket_distance = max_bucket_distance + self.num_conv_pos_embeddings = num_conv_pos_embeddings + self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups + self.num_feat_extract_layers = len(self.conv_dim) + self.num_hidden_layers = num_hidden_layers + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_attention_heads = num_attention_heads + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.feat_proj_dropout = feat_proj_dropout + self.final_dropout = final_dropout + self.layerdrop = layerdrop + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + self.num_ctc_classes = num_ctc_classes + self.vocab_size = vocab_size + self.do_stable_layer_norm = do_stable_layer_norm + self.use_weighted_layer_sum = use_weighted_layer_sum + self.classifier_proj_size = classifier_proj_size + + if ( + (len(self.conv_stride) != self.num_feat_extract_layers) + or (len(self.conv_kernel) != self.num_feat_extract_layers) + or (len(self.conv_dim) != self.num_feat_extract_layers) + ): + raise ValueError( + "Configuration for convolutional layers is incorrect. " + "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, " + f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) " + f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`." + ) + + # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779 + self.apply_spec_augment = apply_spec_augment + self.mask_time_prob = mask_time_prob + self.mask_time_length = mask_time_length + self.mask_time_min_masks = mask_time_min_masks + self.mask_feature_prob = mask_feature_prob + self.mask_feature_length = mask_feature_length + + # parameters for pretraining with codevector quantized representations + self.num_codevectors_per_group = num_codevectors_per_group + self.num_codevector_groups = num_codevector_groups + self.contrastive_logits_temperature = contrastive_logits_temperature + self.feat_quantizer_dropout = feat_quantizer_dropout + self.num_negatives = num_negatives + self.codevector_dim = codevector_dim + self.proj_codevector_dim = proj_codevector_dim + self.diversity_loss_weight = diversity_loss_weight + + # ctc loss + self.ctc_loss_reduction = ctc_loss_reduction + self.ctc_zero_infinity = ctc_zero_infinity + + # adapter + self.add_adapter = add_adapter + self.adapter_kernel_size = adapter_kernel_size + self.adapter_stride = adapter_stride + self.num_adapter_layers = num_adapter_layers + self.output_hidden_size = output_hidden_size or hidden_size diff --git a/src/transformers/models/wavlm/convert_wavlm_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/wavlm/convert_wavlm_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 000000000..8523fa87e --- /dev/null +++ b/src/transformers/models/wavlm/convert_wavlm_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,203 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert WavLM checkpoint.""" + + +import argparse + +import torch + +from transformers import WavLMConfig, WavLMModel, logging + +# Step 1. clone https://github.com/microsoft/unilm +# Step 2. git checkout to https://github.com/microsoft/unilm/commit/b94ec76c36f02fb2b0bf0dcb0b8554a2185173cd +# Step 3. cd unilm +# Step 4. ln -s $(realpath wavlm/modules.py) ./ # create simlink +# import classes +from unilm.wavlm.WavLM import WavLM as WavLMOrig +from unilm.wavlm.WavLM import WavLMConfig as WavLMConfigOrig + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +MAPPING = { + "post_extract_proj": "feature_projection.projection", + "encoder.pos_conv.0": "encoder.pos_conv_embed.conv", + "self_attn.k_proj": "encoder.layers.*.attention.k_proj", + "self_attn.v_proj": "encoder.layers.*.attention.v_proj", + "self_attn.q_proj": "encoder.layers.*.attention.q_proj", + "self_attn.out_proj": "encoder.layers.*.attention.out_proj", + "self_attn.grep_linear": "encoder.layers.*.attention.gru_rel_pos_linear", + "self_attn.relative_attention_bias": "encoder.layers.*.attention.rel_attn_embed", + "self_attn.grep_a": "encoder.layers.*.attention.gru_rel_pos_const", + "self_attn_layer_norm": "encoder.layers.*.layer_norm", + "fc1": "encoder.layers.*.feed_forward.intermediate_dense", + "fc2": "encoder.layers.*.feed_forward.output_dense", + "final_layer_norm": "encoder.layers.*.final_layer_norm", + "encoder.layer_norm": "encoder.layer_norm", + "w2v_model.layer_norm": "feature_projection.layer_norm", + "quantizer.weight_proj": "quantizer.weight_proj", + "quantizer.vars": "quantizer.codevectors", + "project_q": "project_q", + "final_proj": "project_hid", + "w2v_encoder.proj": "ctc_proj", + "mask_emb": "masked_spec_embed", +} +TOP_LEVEL_KEYS = [ + "ctc_proj", + "quantizer.weight_proj", + "quantizer.codevectors", + "project_q", + "project_hid", +] + + +def set_recursively(hf_pointer, key, value, full_name, weight_type): + for attribute in key.split("."): + hf_pointer = getattr(hf_pointer, attribute) + + if weight_type is not None: + hf_shape = getattr(hf_pointer, weight_type).shape + else: + hf_shape = hf_pointer.shape + + assert ( + hf_shape == value.shape + ), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}" + + if weight_type == "weight": + hf_pointer.weight.data = value + elif weight_type == "weight_g": + hf_pointer.weight_g.data = value + elif weight_type == "weight_v": + hf_pointer.weight_v.data = value + elif weight_type == "bias": + hf_pointer.bias.data = value + else: + hf_pointer.data = value + + logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.") + + +def recursively_load_weights(fairseq_model, hf_model): + unused_weights = [] + fairseq_dict = fairseq_model.state_dict() + + feature_extractor = hf_model.feature_extractor + + for name, value in fairseq_dict.items(): + is_used = False + if "conv_layers" in name: + load_conv_layer( + name, + value, + feature_extractor, + unused_weights, + hf_model.config.feat_extract_norm == "group", + ) + is_used = True + else: + for key, mapped_key in MAPPING.items(): + if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]: + is_used = True + if "*" in mapped_key: + layer_index = name.split(key)[0].split(".")[-2] + mapped_key = mapped_key.replace("*", layer_index) + if "weight_g" in name: + weight_type = "weight_g" + elif "weight_v" in name: + weight_type = "weight_v" + elif "bias" in name and "relative_attention_bias" not in name: + weight_type = "bias" + elif "weight" in name: + # TODO: don't match quantizer.weight_proj + weight_type = "weight" + else: + weight_type = None + + set_recursively(hf_model, mapped_key, value, name, weight_type) + continue + if not is_used: + unused_weights.append(name) + + logger.warning(f"Unused weights: {unused_weights}") + + +def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm): + name = full_name.split("conv_layers.")[-1] + items = name.split(".") + layer_id = int(items[0]) + type_id = int(items[1]) + + if type_id == 0: + if "bias" in name: + assert ( + value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape + ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found." + feature_extractor.conv_layers[layer_id].conv.bias.data = value + logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.") + elif "weight" in name: + assert ( + value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape + ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found." + feature_extractor.conv_layers[layer_id].conv.weight.data = value + logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.") + elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm): + if "bias" in name: + assert ( + value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape + ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found." + feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value + logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.") + elif "weight" in name: + assert ( + value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape + ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found." + feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value + logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.") + else: + unused_weights.append(full_name) + + +@torch.no_grad() +def convert_wavlm_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None): + + # load the pre-trained checkpoints + checkpoint = torch.load(checkpoint_path) + cfg = WavLMConfigOrig(checkpoint["cfg"]) + model = WavLMOrig(cfg) + model.load_state_dict(checkpoint["model"]) + model.eval() + + if config_path is not None: + config = WavLMConfig.from_pretrained(config_path) + else: + config = WavLMConfig() + + hf_wavlm = WavLMModel(config) + + recursively_load_weights(model, hf_wavlm) + + hf_wavlm.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint") + parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") + args = parser.parse_args() + convert_wavlm_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path) diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py new file mode 100755 index 000000000..b54ee6811 --- /dev/null +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -0,0 +1,1441 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors, Microsoft Research, and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch WavLM model. """ + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...deepspeed import is_deepspeed_zero3_enabled +from ...file_utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, +) +from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from .configuration_wavlm import WavLMConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "WavLMConfig" +_PROCESSOR_FOR_DOC = "Wav2Vec2Processor" +_CHECKPOINT_FOR_DOC = "microsoft/wavlm-base" + +_SEQ_CLASS_CHECKPOINT = "microsoft/wavlm-base" +_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor" + +_HIDDEN_STATES_START_POSITION = 2 + +WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/wavlm-base", + "microsoft/wavlm-base-plus", + "microsoft/wavlm-large", + # See all WavLM models at https://huggingface.co/models?filter=wavlm +] + + +@dataclass +class WavLMBaseModelOutput(ModelOutput): + """ + Output type of :class:`~transformers.WavLMBaseModelOutput`, with potential hidden states and attentions. + + Args: + last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + extract_features (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, conv_dim[-1])`): + Sequence of extracted feature vectors of the last convolutional layer of the model. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + extract_features: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement `SpecAugment: A Simple Data Augmentation Method for + ASR `__. Note that this method is not optimized to run on TPU and should be run + on CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. + """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked indices <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.sum(-1).detach().tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=np.bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. + dummy_mask_idx = spec_aug_mask_idx[0] + + spec_aug_mask_idx = np.concatenate( + [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] + ) + spec_aug_mask_idxs.append(spec_aug_mask_idx) + + spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) + + # expand masked indices to masked spans + spec_aug_mask_idxs = np.broadcast_to( + spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) + ) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) + + # add offset to the starting indexes so that that indexes now create a span + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( + batch_size, max_num_masked_span * mask_length + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # scatter indices to mask + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) + + return spec_aug_mask + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->WavLM +class WavLMNoLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->WavLM +class WavLMLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + hidden_states = hidden_states.transpose(-2, -1) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(-2, -1) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->WavLM +class WavLMGroupNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->WavLM +class WavLMPositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + padding=config.num_conv_pos_embeddings // 2, + groups=config.num_conv_pos_embedding_groups, + ) + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): + self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) + deepspeed.zero.register_external_parameter(self, self.conv.weight_v) + deepspeed.zero.register_external_parameter(self, self.conv.weight_g) + else: + self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) + + self.padding = WavLMSamePadLayer(config.num_conv_pos_embeddings) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->WavLM +class WavLMSamePadLayer(nn.Module): + def __init__(self, num_conv_pos_embeddings): + super().__init__() + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 + + def forward(self, hidden_states): + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureExtractor with Wav2Vec2->WavLM +class WavLMFeatureExtractor(nn.Module): + """Construct the features from raw audio waveform""" + + def __init__(self, config): + super().__init__() + + if config.feat_extract_norm == "group": + conv_layers = [WavLMGroupNormConvLayer(config, layer_id=0)] + [ + WavLMNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1) + ] + elif config.feat_extract_norm == "layer": + conv_layers = [WavLMLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)] + else: + raise ValueError( + f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" + ) + self.conv_layers = nn.ModuleList(conv_layers) + self.gradient_checkpointing = False + self._requires_grad = True + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def forward(self, input_values): + hidden_states = input_values[:, None] + + # make sure hidden_states require grad for gradient_checkpointing + if self._requires_grad and self.training: + hidden_states.requires_grad = True + + for conv_layer in self.conv_layers: + if self._requires_grad and self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(conv_layer), + hidden_states, + ) + else: + hidden_states = conv_layer(hidden_states) + + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->WavLM +class WavLMFeatureProjection(nn.Module): + def __init__(self, config): + super().__init__() + self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps) + self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size) + self.dropout = nn.Dropout(config.feat_proj_dropout) + + def forward(self, hidden_states): + # non-projected hidden states are needed for quantization + norm_hidden_states = self.layer_norm(hidden_states) + hidden_states = self.projection(norm_hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states, norm_hidden_states + + +class WavLMAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + num_buckets: int = 320, + max_distance: int = 800, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim ** -0.5 + + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.out_proj = nn.Linear(embed_dim, embed_dim) + + self.num_buckets = num_buckets + self.max_distance = max_distance + + self.gru_rel_pos_const = nn.Parameter(torch.ones(1, self.num_heads, 1, 1)) + self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8) + self.rel_attn_embed = nn.Embedding(self.num_buckets, self.num_heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_bias: Optional[torch.Tensor] = None, + output_attentions: bool = False, + index=0, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Attention layer with relative attention""" + bsz, tgt_len, _ = hidden_states.size() + + # first pass of attention layer creates position bias + if position_bias is None: + position_bias = self.compute_bias(tgt_len, tgt_len) + position_bias = ( + position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, tgt_len) + ) + + # Compute relative position bias: + # 1) get reshape hidden_states + gated_hidden_states = hidden_states.view(hidden_states.shape[:-1] + (self.num_heads, -1)) + gated_hidden_states = gated_hidden_states.permute(0, 2, 1, 3) + + # 2) project hidden states + relative_position_proj = self.gru_rel_pos_linear(gated_hidden_states) + relative_position_proj = relative_position_proj.view(gated_hidden_states.shape[:-1] + (2, 4)).sum(-1) + + # 3) compute gate for position bias from projected hidden states + gate_a, gate_b = torch.sigmoid(relative_position_proj).chunk(2, dim=-1) + gate_output = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0 + + # 4) apply gate to position bias to compute gated position_bias + gated_position_bias = gate_output.view(bsz * self.num_heads, -1, 1) * position_bias + gated_position_bias = gated_position_bias.view((-1, tgt_len, tgt_len)) + + attn_output, attn_weights = self.torch_multi_head_self_attention( + hidden_states, attention_mask, gated_position_bias, output_attentions + ) + + return attn_output, attn_weights, position_bias + + def torch_multi_head_self_attention( + self, + hidden_states: torch.FloatTensor, + attention_mask: Union[torch.LongTensor, torch.BoolTensor], + gated_position_bias: torch.FloatTensor, + output_attentions: bool, + ) -> (torch.FloatTensor, torch.FloatTensor): + """simple wrapper around torch's multi_head_attention_forward function""" + # self-attention assumes q = k = v + query = key = value = hidden_states.transpose(0, 1) + key_padding_mask = attention_mask.ne(1) if attention_mask is not None else None + + # disable bias and add_zero_attn + bias_k = bias_v = None + add_zero_attn = False + + # PyTorch 1.3.0 has F.multi_head_attention_forward defined + # so no problem with backwards compatibility + attn_output, attn_weights = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + bias_k, + bias_v, + add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + self.training, + key_padding_mask, + output_attentions, + gated_position_bias, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + + # [Seq_Len, Batch Size, ...] -> [Batch Size, Seq_Len, ...] + attn_output = attn_output.transpose(0, 1) + + if attn_weights is not None: + # IMPORTANT: Attention weights are averaged weights + # here which should not be the case. This is an open issue + # on PyTorch: https://github.com/pytorch/pytorch/issues/32590 + attn_weights = attn_weights[:, None].broadcast_to( + attn_weights.shape[:1] + (self.num_heads,) + attn_weights.shape[1:] + ) + + return attn_output, attn_weights + + def compute_bias(self, query_length: int, key_length: int) -> torch.FloatTensor: + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + relative_position_bucket = self._relative_positions_bucket(relative_position) + relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device) + values = self.rel_attn_embed(relative_position_bucket) + values = values.permute([2, 0, 1]) + return values + + def _relative_positions_bucket(self, relative_positions: torch.FloatTensor) -> torch.FloatTensor: + num_buckets = self.num_buckets // 2 + + relative_buckets = (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_positions_if_large = torch.log(relative_positions.float() / max_exact) + relative_positions_if_large = relative_positions_if_large / math.log(self.max_distance / max_exact) + relative_positions_if_large = relative_positions_if_large * (num_buckets - max_exact) + relative_postion_if_large = (max_exact + relative_positions_if_large).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->WavLM +class WavLMFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_dropout = nn.Dropout(config.activation_dropout) + + self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.output_dropout = nn.Dropout(config.hidden_dropout) + + def forward(self, hidden_states): + hidden_states = self.intermediate_dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.intermediate_dropout(hidden_states) + + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states) + return hidden_states + + +class WavLMEncoderLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = WavLMAttention( + embed_dim=config.hidden_size, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + num_buckets=config.num_buckets, + max_distance=config.max_bucket_distance, + ) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.feed_forward = WavLMFeedForward(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0): + attn_residual = hidden_states + hidden_states, attn_weights, position_bias = self.attention( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + index=index, + ) + hidden_states = self.dropout(hidden_states) + hidden_states = attn_residual + hidden_states + + hidden_states = self.layer_norm(hidden_states) + + hidden_states = hidden_states + self.feed_forward(hidden_states) + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states, position_bias) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class WavLMEncoderLayerStableLayerNorm(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = WavLMAttention( + embed_dim=config.hidden_size, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + num_buckets=config.num_buckets, + max_distance=config.max_bucket_distance, + ) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.feed_forward = WavLMFeedForward(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False): + attn_residual = hidden_states + hidden_states = self.layer_norm(hidden_states) + hidden_states, attn_weights, position_bias = self.attention( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + ) + hidden_states = self.dropout(hidden_states) + hidden_states = attn_residual + hidden_states + hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) + + outputs = (hidden_states, position_bias) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class WavLMEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.pos_conv_embed = WavLMPositionalConvEmbedding(config) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layers = nn.ModuleList([WavLMEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + # make sure padded tokens output 0 + hidden_states[~attention_mask] = 0.0 + + position_embeddings = self.pos_conv_embed(hidden_states) + hidden_states = hidden_states + position_embeddings + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + + position_bias = None + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = np.random.uniform(0, 1) + + skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False + if not skip_the_layer or deepspeed_zero3_is_enabled: + # under deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + # create gradient checkpointing function + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer), + hidden_states, + attention_mask, + position_bias, + ) + else: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + index=i, + ) + + hidden_states, position_bias = layer_outputs[:2] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class WavLMEncoderStableLayerNorm(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.pos_conv_embed = WavLMPositionalConvEmbedding(config) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layers = nn.ModuleList( + [WavLMEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + # make sure padded tokens are not attended to + hidden_states[~attention_mask] = 0 + + position_embeddings = self.pos_conv_embed(hidden_states) + hidden_states = hidden_states + position_embeddings + hidden_states = self.dropout(hidden_states) + + deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + + position_bias = None + + for layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = np.random.uniform(0, 1) + + skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False + if not skip_the_layer or deepspeed_zero3_is_enabled: + # under deepspeed zero3 all gpus must run in sync + # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication + if self.gradient_checkpointing and self.training: + # create gradient checkpointing function + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer), + hidden_states, + attention_mask, + position_bias, + ) + else: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + position_bias=position_bias, + ) + hidden_states, position_bias = layer_outputs[:2] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[2],) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions + ) + + +class WavLMGumbelVectorQuantizer(nn.Module): + """ + Vector quantization using gumbel softmax. See `CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX + `__ for more information. + """ + + def __init__(self, config): + super().__init__() + self.num_groups = config.num_codevector_groups + self.num_vars = config.num_codevectors_per_group + + if config.codevector_dim % self.num_groups != 0: + raise ValueError( + f"`config.codevector_dim {config.codevector_dim} must be divisible" + f" by `config.num_codevector_groups` {self.num_groups} " + "for concatenation." + ) + + # storage for codebook variables (codewords) + self.codevectors = nn.Parameter( + torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups) + ) + self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars) + + # can be decayed for training + self.temperature = 2 + + @staticmethod + def _compute_perplexity(probs): + marginal_probs = probs.mean(dim=0) + perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum() + return perplexity + + def forward(self, hidden_states): + batch_size, sequence_length, hidden_size = hidden_states.shape + + # project to codevector dim + hidden_states = self.weight_proj(hidden_states) + hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1) + + if self.training: + # sample code vector probs via gumbel in differentiateable way + codevector_probs = nn.functional.gumbel_softmax(hidden_states.float(), tau=self.temperature, hard=True) + codevector_probs = codevector_probs.type_as(hidden_states) + + # compute perplexity + codevector_soft_dist = torch.softmax( + hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1 + ) + perplexity = self._compute_perplexity(codevector_soft_dist) + else: + # take argmax in non-differentiable way + # comptute hard codevector distribution (one hot) + codevector_idx = hidden_states.argmax(dim=-1) + codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_( + -1, codevector_idx.view(-1, 1), 1.0 + ) + codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1) + + perplexity = self._compute_perplexity(codevector_probs) + + codevector_probs = codevector_probs.view(batch_size * sequence_length, -1) + # use probs to retrieve codevectors + codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors + codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1) + codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1) + + return codevectors, perplexity + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->WavLM +class WavLMAdapter(nn.Module): + def __init__(self, config): + super().__init__() + + # feature dim might need to be down-projected + if config.output_hidden_size != config.hidden_size: + self.proj = nn.Linear(config.hidden_size, config.output_hidden_size) + self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size) + else: + self.proj = self.proj_layer_norm = None + + self.layers = nn.ModuleList(WavLMAdapterLayer(config) for _ in range(config.num_adapter_layers)) + self.layerdrop = config.layerdrop + + def forward(self, hidden_states): + # down project hidden_states if necessary + if self.proj is not None and self.proj_layer_norm is not None: + hidden_states = self.proj(hidden_states) + hidden_states = self.proj_layer_norm(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + + for layer in self.layers: + layerdrop_prob = np.random.random() + if not self.training or (layerdrop_prob > self.layerdrop): + hidden_states = layer(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->WavLM +class WavLMAdapterLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.output_hidden_size, + 2 * config.output_hidden_size, + config.adapter_kernel_size, + stride=config.adapter_stride, + padding=1, + ) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = nn.functional.glu(hidden_states, dim=1) + + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PreTrainedModel with Wav2Vec2->WavLM, wav2vec2->wavlm +class WavLMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = WavLMConfig + base_model_prefix = "wavlm" + _keys_to_ignore_on_load_missing = [r"position_ids"] + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + # gumbel softmax requires special init + if isinstance(module, WavLMGumbelVectorQuantizer): + module.weight_proj.weight.data.normal_(mean=0.0, std=1) + module.weight_proj.bias.data.zero_() + nn.init.uniform_(module.codevectors) + elif isinstance(module, WavLMPositionalConvEmbedding): + nn.init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, WavLMFeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + + def _get_feat_extract_output_lengths( + self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None + ): + """ + Computes the output length of the convolutional layers + """ + + add_adapter = self.config.add_adapter if add_adapter is None else add_adapter + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return (input_length - kernel_size) // stride + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + if add_adapter: + for _ in range(self.config.num_adapter_layers): + input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride) + + return input_lengths + + def _get_feature_vector_attention_mask( + self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None + ): + # Effectively attention_mask.sum(-1), but not inplace to be able to run + # on inference mode. + non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1] + + output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter) + output_lengths = output_lengths.to(torch.long) + + batch_size = attention_mask.shape[0] + + attention_mask = torch.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + # these two operations makes sure that all values before the output lengths idxs are attended to + attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + return attention_mask + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (WavLMEncoder, WavLMEncoderStableLayerNorm, WavLMFeatureExtractor)): + module.gradient_checkpointing = value + + +WAVLM_START_DOCSTRING = r""" + WavLM was proposed in `WavLM: Unified Speech Representation Learning with Labeled and Unlabeled Data + `__ by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, + Michael Zeng, Xuedong Huang. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving etc.). + + This model is a PyTorch `torch.nn.Module `_ sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config (:class:`~transformers.WavLMConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + + +WAVLM_INPUTS_DOCSTRING = r""" + Args: + input_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file + into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install + soundfile`). To prepare the array into `input_values`, the :class:`~transformers.WavLMProcessor` should be + used for padding and conversion into a tensor of type `torch.FloatTensor`. See + :meth:`transformers.WavLMProcessor.__call__` for details. + attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in ``[0, + 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + + .. warning:: + :obj:`attention_mask` should only be passed if the corresponding processor has + ``config.return_attention_mask == True``. For all models whose processor has + ``config.return_attention_mask == False``, :obj:`attention_mask` should **not** be passed to avoid + degraded performance when doing batched inference. For such models :obj:`input_values` should simply be + padded with 0 and passed without :obj:`attention_mask`. Be aware that these models also yield slightly + different results depending on whether :obj:`input_values` is padded or not. + + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare WavLM Model transformer outputting raw hidden-states without any specific head on top.", + WAVLM_START_DOCSTRING, +) +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM +class WavLMModel(WavLMPreTrainedModel): + def __init__(self, config: WavLMConfig): + super().__init__(config) + self.config = config + self.feature_extractor = WavLMFeatureExtractor(config) + self.feature_projection = WavLMFeatureProjection(config) + + self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_()) + + if config.do_stable_layer_norm: + self.encoder = WavLMEncoderStableLayerNorm(config) + else: + self.encoder = WavLMEncoder(config) + + self.adapter = WavLMAdapter(config) if config.add_adapter else None + + # Initialize weights and apply final processing + self.post_init() + + def _mask_hidden_states( + self, + hidden_states: torch.FloatTensor, + mask_time_indices: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + """ + Masks extracted features along time axis and/or along feature axis according to `SpecAugment + `__ . + """ + + # `config.apply_spec_augment` can set masking to False + if not getattr(self.config, "apply_spec_augment", True): + return hidden_states + + # generate indices & apply SpecAugment along time axis + batch_size, sequence_length, hidden_size = hidden_states.size() + + if mask_time_indices is not None: + # apply SpecAugment along time axis with given mask_time_indices + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + elif self.config.mask_time_prob > 0 and self.training: + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=self.config.mask_time_min_masks, + ) + mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + + if self.config.mask_feature_prob > 0 and self.training: + # generate indices & apply SpecAugment along feature axis + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + min_masks=self.config.mask_feature_min_masks, + ) + mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool) + mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1) + hidden_states[mask_feature_indices] = 0 + + return hidden_states + + @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_PROCESSOR_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=WavLMBaseModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + def forward( + self, + input_values, + attention_mask=None, + mask_time_indices=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask( + extract_features.shape[1], attention_mask, add_adapter=False + ) + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if self.adapter is not None: + hidden_states = self.adapter(hidden_states) + + if not return_dict: + return (hidden_states, extract_features) + encoder_outputs[1:] + + return WavLMBaseModelOutput( + last_hidden_state=hidden_states, + extract_features=extract_features, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """WavLM Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). """, + WAVLM_START_DOCSTRING, +) +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM +class WavLMForCTC(WavLMPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.wavlm = WavLMModel(config) + self.dropout = nn.Dropout(config.final_dropout) + + if config.vocab_size is None: + raise ValueError( + f"You are trying to instantiate {self.__class__} with a configuration that " + "does not define the vocabulary size of the language model head. Please " + "instantiate the model as follows: `WavLMForCTC.from_pretrained(..., vocab_size=vocab_size)`. " + "or define `vocab_size` of your model's configuration." + ) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size) + + # Initialize weights and apply final processing + self.post_init() + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature extractor so that its parameter + will not be updated during training. + """ + self.wavlm.feature_extractor._freeze_parameters() + + @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_PROCESSOR_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_values, + attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + labels=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_length)`, `optional`): + Labels for connectionist temporal classification. Note that ``target_length`` has to be smaller or equal to + the sequence length of the output logits. Indices are selected in ``[-100, 0, ..., config.vocab_size - + 1]``. All labels set to ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., + config.vocab_size - 1]``. + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.wavlm( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + + if labels.max() >= self.config.vocab_size: + raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") + + # retrieve loss input_lengths from attention_mask + attention_mask = ( + attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long) + ) + input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + + # assuming that padded tokens are filled with -100 + # when not being attended to + labels_mask = labels >= 0 + target_lengths = labels_mask.sum(-1) + flattened_targets = labels.masked_select(labels_mask) + + # ctc_loss doesn't support fp16 + log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) + + with torch.backends.cudnn.flags(enabled=False): + loss = nn.functional.ctc_loss( + log_probs, + flattened_targets, + input_lengths, + target_lengths, + blank=self.config.pad_token_id, + reduction=self.config.ctc_loss_reduction, + zero_infinity=self.config.ctc_zero_infinity, + ) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +@add_start_docstrings( + """ + WavLM Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like + SUPERB Keyword Spotting. + """, + WAVLM_START_DOCSTRING, +) +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM +class WavLMForSequenceClassification(WavLMPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.wavlm = WavLMModel(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size) + self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature extractor so that its parameters + will not be updated during training. + """ + self.wavlm.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for param in self.wavlm.parameters(): + param.requires_grad = False + + @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_SEQ_CLASS_CHECKPOINT, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + def forward( + self, + input_values, + attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + labels=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.wavlm( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + if attention_mask is None: + pooled_output = hidden_states.mean(dim=1) + else: + padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) + hidden_states[~padding_mask] = 0.0 + pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index f2225f6cc..94d319960 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -5174,6 +5174,50 @@ class Wav2Vec2PreTrainedModel: requires_backends(self, ["torch"]) +WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class WavLMForCTC: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class WavLMForSequenceClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def forward(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class WavLMModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def forward(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class WavLMPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def forward(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + XLM_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/test_modeling_wavlm.py b/tests/test_modeling_wavlm.py new file mode 100644 index 000000000..c8c832b7d --- /dev/null +++ b/tests/test_modeling_wavlm.py @@ -0,0 +1,493 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch WavLM model. """ + +import math +import unittest + +import pytest +from datasets import load_dataset + +from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask +from transformers import WavLMConfig, is_torch_available +from transformers.testing_utils import require_torch, require_torchaudio, slow, torch_device + +from .test_configuration_common import ConfigTester +from .test_modeling_common import ModelTesterMixin, _config_zero_init + + +if is_torch_available(): + import torch + + from transformers import Wav2Vec2FeatureExtractor, WavLMForCTC, WavLMForSequenceClassification, WavLMModel + + +class WavLMModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=1024, # speech is longer + is_training=False, + hidden_size=16, + feat_extract_norm="group", + feat_extract_dropout=0.0, + feat_extract_activation="gelu", + conv_dim=(32, 32, 32), + conv_stride=(4, 4, 4), + conv_kernel=(8, 8, 8), + conv_bias=False, + num_conv_pos_embeddings=16, + num_conv_pos_embedding_groups=2, + num_hidden_layers=4, + num_attention_heads=2, + hidden_dropout_prob=0.1, # this is most likely not correctly set yet + intermediate_size=20, + layer_norm_eps=1e-5, + hidden_act="gelu", + initializer_range=0.02, + vocab_size=32, + do_stable_layer_norm=False, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.hidden_size = hidden_size + self.feat_extract_norm = feat_extract_norm + self.feat_extract_dropout = feat_extract_dropout + self.feat_extract_activation = feat_extract_activation + self.conv_dim = conv_dim + self.conv_stride = conv_stride + self.conv_kernel = conv_kernel + self.conv_bias = conv_bias + self.num_conv_pos_embeddings = num_conv_pos_embeddings + self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_dropout_prob = hidden_dropout_prob + self.intermediate_size = intermediate_size + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.vocab_size = vocab_size + self.do_stable_layer_norm = do_stable_layer_norm + self.scope = scope + + output_seq_length = self.seq_length + for kernel, stride in zip(self.conv_kernel, self.conv_stride): + output_seq_length = (output_seq_length - (kernel - 1)) / stride + self.output_seq_length = int(math.ceil(output_seq_length)) + self.encoder_seq_length = self.output_seq_length + + def prepare_config_and_inputs(self): + input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size) + attention_mask = random_attention_mask([self.batch_size, self.seq_length]) + + config = self.get_config() + + return config, input_values, attention_mask + + def get_config(self): + return WavLMConfig( + hidden_size=self.hidden_size, + feat_extract_norm=self.feat_extract_norm, + feat_extract_dropout=self.feat_extract_dropout, + feat_extract_activation=self.feat_extract_activation, + conv_dim=self.conv_dim, + conv_stride=self.conv_stride, + conv_kernel=self.conv_kernel, + conv_bias=self.conv_bias, + num_conv_pos_embeddings=self.num_conv_pos_embeddings, + num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + hidden_dropout_prob=self.hidden_dropout_prob, + intermediate_size=self.intermediate_size, + layer_norm_eps=self.layer_norm_eps, + hidden_act=self.hidden_act, + initializer_range=self.initializer_range, + vocab_size=self.vocab_size, + ) + + def create_and_check_model(self, config, input_values, attention_mask): + model = WavLMModel(config=config) + model.to(torch_device) + model.eval() + result = model(input_values, attention_mask=attention_mask) + self.parent.assertEqual( + result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size) + ) + + def create_and_check_batch_inference(self, config, input_values, *args): + # test does not pass for models making use of `group_norm` + # check: https://github.com/pytorch/fairseq/issues/3227 + model = WavLMModel(config=config) + model.to(torch_device) + model.eval() + + input_values = input_values[:3] + attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.bool) + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + + # pad input + for i in range(len(input_lengths)): + input_values[i, input_lengths[i] :] = 0.0 + attention_mask[i, input_lengths[i] :] = 0.0 + + batch_outputs = model(input_values, attention_mask=attention_mask).last_hidden_state + + for i in range(input_values.shape[0]): + input_slice = input_values[i : i + 1, : input_lengths[i]] + output = model(input_slice).last_hidden_state + + batch_output = batch_outputs[i : i + 1, : output.shape[1]] + self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3)) + + def check_ctc_loss(self, config, input_values, *args): + model = WavLMForCTC(config=config) + model.to(torch_device) + + # make sure that dropout is disabled + model.eval() + + input_values = input_values[:3] + attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long) + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths)) + labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size) + + # pad input + for i in range(len(input_lengths)): + input_values[i, input_lengths[i] :] = 0.0 + attention_mask[i, input_lengths[i] :] = 0 + + model.config.ctc_loss_reduction = "sum" + sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item() + + model.config.ctc_loss_reduction = "mean" + mean_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item() + + self.parent.assertTrue(isinstance(sum_loss, float)) + self.parent.assertTrue(isinstance(mean_loss, float)) + + def check_seq_classifier_loss(self, config, input_values, *args): + model = WavLMForSequenceClassification(config=config) + model.to(torch_device) + + # make sure that dropout is disabled + model.eval() + + input_values = input_values[:3] + attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long) + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label)) + + # pad input + for i in range(len(input_lengths)): + input_values[i, input_lengths[i] :] = 0.0 + attention_mask[i, input_lengths[i] :] = 0 + + masked_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item() + unmasked_loss = model(input_values, labels=labels).loss.item() + + self.parent.assertTrue(isinstance(masked_loss, float)) + self.parent.assertTrue(isinstance(unmasked_loss, float)) + self.parent.assertTrue(masked_loss != unmasked_loss) + + def check_ctc_training(self, config, input_values, *args): + config.ctc_zero_infinity = True + model = WavLMForCTC(config=config) + model.to(torch_device) + model.train() + + # freeze feature encoder + model.freeze_feature_extractor() + + input_values = input_values[:3] + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths)) + labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size) + + # pad input + for i in range(len(input_lengths)): + input_values[i, input_lengths[i] :] = 0.0 + + if max_length_labels[i] < labels.shape[-1]: + # it's important that we make sure that target lenghts are at least + # one shorter than logit lenghts to prevent -inf + labels[i, max_length_labels[i] - 1 :] = -100 + + loss = model(input_values, labels=labels).loss + self.parent.assertFalse(torch.isinf(loss).item()) + + loss.backward() + + def check_seq_classifier_training(self, config, input_values, *args): + config.ctc_zero_infinity = True + model = WavLMForSequenceClassification(config=config) + model.to(torch_device) + model.train() + + # freeze everything but the classification head + model.freeze_base_model() + + input_values = input_values[:3] + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label)) + + # pad input + for i in range(len(input_lengths)): + input_values[i, input_lengths[i] :] = 0.0 + + loss = model(input_values, labels=labels).loss + self.parent.assertFalse(torch.isinf(loss).item()) + + loss.backward() + + def check_labels_out_of_vocab(self, config, input_values, *args): + model = WavLMForCTC(config) + model.to(torch_device) + model.train() + + input_values = input_values[:3] + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths)) + labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size + 100) + + with pytest.raises(ValueError): + model(input_values, labels=labels) + + def prepare_config_and_inputs_for_common(self): + config, input_values, attention_mask = self.prepare_config_and_inputs() + inputs_dict = {"input_values": input_values, "attention_mask": attention_mask} + return config, inputs_dict + + +@require_torch +class WavLMModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (WavLMForCTC, WavLMModel, WavLMForSequenceClassification) if is_torch_available() else () + test_pruning = False + test_headmasking = False + test_torchscript = False + + def setUp(self): + self.model_tester = WavLMModelTester(self) + self.config_tester = ConfigTester(self, config_class=WavLMConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_ctc_loss_inference(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_ctc_loss(*config_and_inputs) + + def test_seq_classifier_loss_inference(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_seq_classifier_loss(*config_and_inputs) + + def test_ctc_train(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_ctc_training(*config_and_inputs) + + def test_seq_classifier_train(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_seq_classifier_training(*config_and_inputs) + + def test_labels_out_of_vocab(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_labels_out_of_vocab(*config_and_inputs) + + # WavLM has no inputs_embeds + def test_inputs_embeds(self): + pass + + # `input_ids` is renamed to `input_values` + def test_forward_signature(self): + pass + + # WavLM cannot resize token embeddings + # since it has no tokens embeddings + def test_resize_tokens_embeddings(self): + pass + + # WavLM has no inputs_embeds + # and thus the `get_input_embeddings` fn + # is not implemented + def test_model_common_attributes(self): + pass + + # WavLM uses PyTorch's multi-head-attention class + # and thus can't retain gradients on attentions + def test_retain_grad_hidden_states_attentions(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + config.output_attentions = True + + # no need to test all models as different heads yield the same functionality + model_class = self.all_model_classes[0] + model = model_class(config) + model.to(torch_device) + + # set layer drop to 0 + model.config.layerdrop = 0.0 + + input_values = inputs_dict["input_values"] + + input_lengths = torch.tensor( + [input_values.shape[1] for _ in range(input_values.shape[0])], dtype=torch.long, device=torch_device + ) + output_lengths = model._get_feat_extract_output_lengths(input_lengths) + + labels = ids_tensor((input_values.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size) + inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"]) + inputs_dict["labels"] = labels + + outputs = model(**inputs_dict) + + output = outputs[0] + + # Encoder-/Decoder-only models + hidden_states = outputs.hidden_states[0] + hidden_states.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + self.assertIsNotNone(hidden_states.grad) + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + uniform_init_parms = [ + "conv.weight", + "masked_spec_embed", + "codevectors", + "quantizer.weight_proj.weight", + "project_hid.weight", + "project_hid.bias", + "project_q.weight", + "project_q.bias", + "feature_projection.projection.weight", + "feature_projection.projection.bias", + "label_embeddings_concat", + "rel_attn_embed", + ] + if param.requires_grad: + if any([x in name for x in uniform_init_parms]): + self.assertTrue( + -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0, + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + else: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + # overwrite from test_modeling_common + def _mock_init_weights(self, module): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data.fill_(3) + if hasattr(module, "weight_g") and module.weight_g is not None: + module.weight_g.data.fill_(3) + if hasattr(module, "weight_v") and module.weight_v is not None: + module.weight_v.data.fill_(3) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.fill_(3) + if hasattr(module, "codevectors") and module.codevectors is not None: + module.codevectors.data.fill_(3) + if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: + module.masked_spec_embed.data.fill_(3) + + @slow + def test_model_from_pretrained(self): + model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus") + self.assertIsNotNone(model) + + +@require_torch +@require_torchaudio +@slow +class WavLMModelIntegrationTest(unittest.TestCase): + def _load_datasamples(self, num_samples): + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + # automatic decoding with librispeech + speech_samples = ds.sort("id").filter( + lambda x: x["id"] in [f"1272-141231-000{i}" for i in range(num_samples)] + )[:num_samples]["audio"] + + return [x["array"] for x in speech_samples] + + def test_inference_base(self): + model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus").to(torch_device) + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( + "microsoft/wavlm-base-plus", return_attention_mask=True + ) + + input_speech = self._load_datasamples(2) + + inputs = feature_extractor(input_speech, return_tensors="pt", padding=True) + + input_values = inputs.input_values.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + + with torch.no_grad(): + hidden_states_slice = ( + model(input_values, attention_mask=attention_mask).last_hidden_state[:, -2:, -2:].cpu() + ) + + EXPECTED_HIDDEN_STATES_SLICE = torch.tensor( + [[[0.0554, 0.1138], [0.0555, 0.1144]], [[0.0200, 0.1240], [0.0059, 0.0607]]] + ) + self.assertTrue(torch.allclose(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=1e-2)) + + def test_inference_large(self): + model = WavLMModel.from_pretrained("microsoft/wavlm-large").to(torch_device) + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( + "microsoft/wavlm-base-plus", return_attention_mask=True + ) + + input_speech = self._load_datasamples(2) + + inputs = feature_extractor(input_speech, return_tensors="pt", padding=True) + + input_values = inputs.input_values.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + + with torch.no_grad(): + hidden_states_slice = ( + model(input_values, attention_mask=attention_mask).last_hidden_state[:, -2:, -2:].cpu() + ) + + EXPECTED_HIDDEN_STATES_SLICE = torch.tensor( + [[[0.2132, 0.0486], [0.2119, 0.0571]], [[0.1386, 0.1837], [0.2455, 0.0614]]] + ) + self.assertTrue(torch.allclose(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=1e-2))