diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py index ddfa28fbf..e32d6e66d 100755 --- a/examples/pytorch/language-modeling/run_clm.py +++ b/examples/pytorch/language-modeling/run_clm.py @@ -467,7 +467,7 @@ def main(): trainer.save_metrics("eval", metrics) if training_args.push_to_hub: - kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "text-generation"} + kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"} if data_args.dataset_name is not None: kwargs["dataset_tags"] = data_args.dataset_name if data_args.dataset_config_name is not None: diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py index da687aea1..425647875 100755 --- a/examples/pytorch/language-modeling/run_mlm.py +++ b/examples/pytorch/language-modeling/run_mlm.py @@ -497,7 +497,7 @@ def main(): trainer.save_metrics("eval", metrics) if training_args.push_to_hub: - kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "fill-mask"} + kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "fill-mask"} if data_args.dataset_name is not None: kwargs["dataset_tags"] = data_args.dataset_name if data_args.dataset_config_name is not None: diff --git a/examples/pytorch/language-modeling/run_plm.py b/examples/pytorch/language-modeling/run_plm.py index b4cf5f532..c19d7dfde 100755 --- a/examples/pytorch/language-modeling/run_plm.py +++ b/examples/pytorch/language-modeling/run_plm.py @@ -471,7 +471,7 @@ def main(): trainer.save_metrics("eval", metrics) if training_args.push_to_hub: - kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "language-modeling"} + kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "language-modeling"} if data_args.dataset_name is not None: kwargs["dataset_tags"] = data_args.dataset_name if data_args.dataset_config_name is not None: diff --git 
a/examples/pytorch/multiple-choice/run_swag.py b/examples/pytorch/multiple-choice/run_swag.py index 0dd11d286..b21406bc0 100755 --- a/examples/pytorch/multiple-choice/run_swag.py +++ b/examples/pytorch/multiple-choice/run_swag.py @@ -430,7 +430,7 @@ def main(): if training_args.push_to_hub: trainer.push_to_hub( finetuned_from=model_args.model_name_or_path, - tags="multiple-choice", + tasks="multiple-choice", dataset_tags="swag", dataset_args="regular", dataset="SWAG", diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py index c3e1520bc..b6ba8c7a8 100755 --- a/examples/pytorch/question-answering/run_qa.py +++ b/examples/pytorch/question-answering/run_qa.py @@ -601,7 +601,7 @@ def main(): trainer.save_metrics("predict", metrics) if training_args.push_to_hub: - kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "question-answering"} + kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "question-answering"} if data_args.dataset_name is not None: kwargs["dataset_tags"] = data_args.dataset_name if data_args.dataset_config_name is not None: diff --git a/examples/pytorch/question-answering/run_qa_beam_search.py b/examples/pytorch/question-answering/run_qa_beam_search.py index ef5396f72..70c2d1f62 100755 --- a/examples/pytorch/question-answering/run_qa_beam_search.py +++ b/examples/pytorch/question-answering/run_qa_beam_search.py @@ -640,7 +640,7 @@ def main(): trainer.save_metrics("predict", metrics) if training_args.push_to_hub: - kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "question-answering"} + kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "question-answering"} if data_args.dataset_name is not None: kwargs["dataset_tags"] = data_args.dataset_name if data_args.dataset_config_name is not None: diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py index 98dbcef74..277c19324 
100755 --- a/examples/pytorch/summarization/run_summarization.py +++ b/examples/pytorch/summarization/run_summarization.py @@ -583,7 +583,7 @@ def main(): writer.write("\n".join(predictions)) if training_args.push_to_hub: - kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "summarization"} + kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "summarization"} if data_args.dataset_name is not None: kwargs["dataset_tags"] = data_args.dataset_name if data_args.dataset_config_name is not None: diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py index b7fe21424..0c1d60a69 100755 --- a/examples/pytorch/text-classification/run_glue.py +++ b/examples/pytorch/text-classification/run_glue.py @@ -538,7 +538,7 @@ def main(): writer.write(f"{index}\t{item}\n") if training_args.push_to_hub: - kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "text-classification"} + kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"} if data_args.task_name is not None: kwargs["language"] = "en" kwargs["dataset_tags"] = "glue" diff --git a/examples/pytorch/token-classification/run_ner.py b/examples/pytorch/token-classification/run_ner.py index ab1372ba4..ffa4f7773 100755 --- a/examples/pytorch/token-classification/run_ner.py +++ b/examples/pytorch/token-classification/run_ner.py @@ -522,7 +522,7 @@ def main(): writer.write(" ".join(prediction) + "\n") if training_args.push_to_hub: - kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "token-classification"} + kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "token-classification"} if data_args.dataset_name is not None: kwargs["dataset_tags"] = data_args.dataset_name if data_args.dataset_config_name is not None: diff --git a/examples/pytorch/translation/run_translation.py b/examples/pytorch/translation/run_translation.py index a89ea80b4..3f4a45875 100755 --- 
a/examples/pytorch/translation/run_translation.py +++ b/examples/pytorch/translation/run_translation.py @@ -575,7 +575,7 @@ def main(): writer.write("\n".join(predictions)) if training_args.push_to_hub: - kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "translation"} + kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "translation"} if data_args.dataset_name is not None: kwargs["dataset_tags"] = data_args.dataset_name if data_args.dataset_config_name is not None: diff --git a/src/transformers/modelcard.py b/src/transformers/modelcard.py index 49f250265..eb71f6821 100644 --- a/src/transformers/modelcard.py +++ b/src/transformers/modelcard.py @@ -42,8 +42,31 @@ from .file_utils import ( ) from .training_args import ParallelMode from .utils import logging +from .utils.modeling_auto_mapping import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, + MODEL_FOR_MASKED_LM_MAPPING_NAMES, + MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, + MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, + MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES, + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, +) +TASK_MAPPING = { + "text-generation": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, + "fill-mask": MODEL_FOR_MASKED_LM_MAPPING_NAMES, + "object-detection": MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, + "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, + "text2text-generation": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, + "text-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, + "table-question-answering": MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES, + "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, +} + logger = logging.get_logger(__name__) @@ -246,9 +269,12 @@ should probably proofread and complete it, then 
remove this comment. --> TASK_TAG_TO_NAME_MAPPING = { "fill-mask": "Masked Language Modeling", + "image-classification": "Image Classification", "multiple-choice": "Multiple Choice", + "object-detection": "Object Detection", "question-answering": "Question Answering", "summarization": "Summarization", + "table-question-answering": "Table Question Answering", "text-classification": "Text Classification", "text-generation": "Causal Language Modeling", "text2text-generation": "Sequence-to-sequence Language Modeling", @@ -304,6 +330,25 @@ def infer_metric_tags_from_eval_results(eval_results): return result +def is_hf_dataset(dataset): + if not is_datasets_available(): + return False + + from datasets import Dataset + + return isinstance(dataset, Dataset) + + +def _get_mapping_values(mapping): + result = [] + for v in mapping.values(): + if isinstance(v, (tuple, list)): + result += list(v) + else: + result.append(v) + return result + + @dataclass class TrainingSummary: model_name: str @@ -311,6 +356,7 @@ class TrainingSummary: license: Optional[str] = None tags: Optional[Union[str, List[str]]] = None finetuned_from: Optional[str] = None + tasks: Optional[Union[str, List[str]]] = None dataset: Optional[Union[str, List[str]]] = None dataset_tags: Optional[Union[str, List[str]]] = None dataset_args: Optional[Union[str, List[str]]] = None @@ -320,7 +366,12 @@ class TrainingSummary: def __post_init__(self): # Infer default license from the checkpoint used, if possible. 
- if self.license is None and not is_offline_mode() and self.finetuned_from is not None: + if ( + self.license is None + and not is_offline_mode() + and self.finetuned_from is not None + and len(self.finetuned_from) > 0 + ): try: model_info = HfApi().model_info(self.finetuned_from) for tag in model_info.tags: @@ -342,7 +393,7 @@ class TrainingSummary: dataset_arg_mapping = {tag: arg for tag, arg in zip(dataset_tags, dataset_args)} task_mapping = { - tag: TASK_TAG_TO_NAME_MAPPING[tag] for tag in _listify(self.tags) if tag in TASK_TAG_TO_NAME_MAPPING + task: TASK_TAG_TO_NAME_MAPPING[task] for task in _listify(self.tasks) if task in TASK_TAG_TO_NAME_MAPPING } if len(task_mapping) == 0 and len(dataset_mapping) == 0: @@ -405,6 +456,8 @@ class TrainingSummary: else: if isinstance(self.dataset, str): model_card += f"the {self.dataset} dataset." + elif isinstance(self.dataset, (tuple, list)) and len(self.dataset) == 1: + model_card += f"the {self.dataset[0]} dataset." else: model_card += ( ", ".join([f"the {ds}" for ds in self.dataset[:-1]]) + f" and the {self.dataset[-1]} datasets." @@ -459,11 +512,40 @@ class TrainingSummary: tags=None, model_name=None, finetuned_from=None, + tasks=None, dataset_tags=None, dataset=None, dataset_args=None, ): - # TODO (Sylvain) Add a default for `pipeline-tag` inferred from the model. + # Infer default from dataset + one_dataset = trainer.train_dataset if trainer.train_dataset is not None else trainer.eval_dataset + if is_hf_dataset(one_dataset) and (dataset_tags is None or dataset_args is None): + default_tag = one_dataset.builder_name + # Those are not real datasets from the Hub so we exclude them. 
+ if default_tag not in ["csv", "json", "pandas", "parquet", "text"]: + if dataset_tags is None: + dataset_tags = [default_tag] + if dataset_args is None: + dataset_args = [one_dataset.config_name] + + if dataset is None and dataset_tags is not None: + dataset = dataset_tags + + # Infer default finetuned_from + if ( + finetuned_from is None + and hasattr(trainer.model.config, "_name_or_path") + and not os.path.isdir(trainer.model.config._name_or_path) + ): + finetuned_from = trainer.model.config._name_or_path + + # Infer default task tag: + if tasks is None: + model_class_name = trainer.model.__class__.__name__ + for task, mapping in TASK_MAPPING.items(): + if model_class_name in _get_mapping_values(mapping): + tasks = task + if model_name is None: model_name = Path(trainer.args.output_dir).name @@ -476,6 +558,7 @@ class TrainingSummary: tags=tags, model_name=model_name, finetuned_from=finetuned_from, + tasks=tasks, dataset_tags=dataset_tags, dataset=dataset, dataset_args=dataset_args, diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 9f882e56a..70aeec25c 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2433,6 +2433,7 @@ class Trainer: tags: Optional[str] = None, model_name: Optional[str] = None, finetuned_from: Optional[str] = None, + tasks: Optional[str] = None, dataset_tags: Optional[Union[str, List[str]]] = None, dataset: Optional[Union[str, List[str]]] = None, dataset_args: Optional[Union[str, List[str]]] = None, @@ -2444,6 +2445,7 @@ class Trainer: tags=tags, model_name=model_name, finetuned_from=finetuned_from, + tasks=tasks, dataset_tags=dataset_tags, dataset=dataset, dataset_args=dataset_args, diff --git a/src/transformers/utils/modeling_auto_mapping.py b/src/transformers/utils/modeling_auto_mapping.py index f6abd0bcf..10e7aabba 100644 --- a/src/transformers/utils/modeling_auto_mapping.py +++ b/src/transformers/utils/modeling_auto_mapping.py @@ -36,3 +36,323 @@ 
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( ("IBertConfig", "IBertForQuestionAnswering"), ] ) + + +MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( + [ + ("RoFormerConfig", "RoFormerForCausalLM"), + ("BigBirdPegasusConfig", "BigBirdPegasusForCausalLM"), + ("GPTNeoConfig", "GPTNeoForCausalLM"), + ("BigBirdConfig", "BigBirdForCausalLM"), + ("CamembertConfig", "CamembertForCausalLM"), + ("XLMRobertaConfig", "XLMRobertaForCausalLM"), + ("RobertaConfig", "RobertaForCausalLM"), + ("BertConfig", "BertLMHeadModel"), + ("OpenAIGPTConfig", "OpenAIGPTLMHeadModel"), + ("GPT2Config", "GPT2LMHeadModel"), + ("TransfoXLConfig", "TransfoXLLMHeadModel"), + ("XLNetConfig", "XLNetLMHeadModel"), + ("XLMConfig", "XLMWithLMHeadModel"), + ("CTRLConfig", "CTRLLMHeadModel"), + ("ReformerConfig", "ReformerModelWithLMHead"), + ("BertGenerationConfig", "BertGenerationDecoder"), + ("XLMProphetNetConfig", "XLMProphetNetForCausalLM"), + ("ProphetNetConfig", "ProphetNetForCausalLM"), + ("BartConfig", "BartForCausalLM"), + ("MBartConfig", "MBartForCausalLM"), + ("PegasusConfig", "PegasusForCausalLM"), + ("MarianConfig", "MarianForCausalLM"), + ("BlenderbotConfig", "BlenderbotForCausalLM"), + ("BlenderbotSmallConfig", "BlenderbotSmallForCausalLM"), + ("MegatronBertConfig", "MegatronBertForCausalLM"), + ] +) + + +MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + ("ViTConfig", "ViTForImageClassification"), + ("DeiTConfig", "('DeiTForImageClassification', 'DeiTForImageClassificationWithTeacher')"), + ] +) + + +MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( + [ + ("RoFormerConfig", "RoFormerForMaskedLM"), + ("BigBirdConfig", "BigBirdForMaskedLM"), + ("Wav2Vec2Config", "Wav2Vec2ForMaskedLM"), + ("ConvBertConfig", "ConvBertForMaskedLM"), + ("LayoutLMConfig", "LayoutLMForMaskedLM"), + ("DistilBertConfig", "DistilBertForMaskedLM"), + ("AlbertConfig", "AlbertForMaskedLM"), + ("BartConfig", "BartForConditionalGeneration"), + ("MBartConfig", "MBartForConditionalGeneration"), + 
("CamembertConfig", "CamembertForMaskedLM"), + ("XLMRobertaConfig", "XLMRobertaForMaskedLM"), + ("LongformerConfig", "LongformerForMaskedLM"), + ("RobertaConfig", "RobertaForMaskedLM"), + ("SqueezeBertConfig", "SqueezeBertForMaskedLM"), + ("BertConfig", "BertForMaskedLM"), + ("MegatronBertConfig", "MegatronBertForMaskedLM"), + ("MobileBertConfig", "MobileBertForMaskedLM"), + ("FlaubertConfig", "FlaubertWithLMHeadModel"), + ("XLMConfig", "XLMWithLMHeadModel"), + ("ElectraConfig", "ElectraForMaskedLM"), + ("ReformerConfig", "ReformerForMaskedLM"), + ("FunnelConfig", "FunnelForMaskedLM"), + ("MPNetConfig", "MPNetForMaskedLM"), + ("TapasConfig", "TapasForMaskedLM"), + ("DebertaConfig", "DebertaForMaskedLM"), + ("DebertaV2Config", "DebertaV2ForMaskedLM"), + ("IBertConfig", "IBertForMaskedLM"), + ] +) + + +MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( + [ + ("RoFormerConfig", "RoFormerForMultipleChoice"), + ("BigBirdConfig", "BigBirdForMultipleChoice"), + ("ConvBertConfig", "ConvBertForMultipleChoice"), + ("CamembertConfig", "CamembertForMultipleChoice"), + ("ElectraConfig", "ElectraForMultipleChoice"), + ("XLMRobertaConfig", "XLMRobertaForMultipleChoice"), + ("LongformerConfig", "LongformerForMultipleChoice"), + ("RobertaConfig", "RobertaForMultipleChoice"), + ("SqueezeBertConfig", "SqueezeBertForMultipleChoice"), + ("BertConfig", "BertForMultipleChoice"), + ("DistilBertConfig", "DistilBertForMultipleChoice"), + ("MegatronBertConfig", "MegatronBertForMultipleChoice"), + ("MobileBertConfig", "MobileBertForMultipleChoice"), + ("XLNetConfig", "XLNetForMultipleChoice"), + ("AlbertConfig", "AlbertForMultipleChoice"), + ("XLMConfig", "XLMForMultipleChoice"), + ("FlaubertConfig", "FlaubertForMultipleChoice"), + ("FunnelConfig", "FunnelForMultipleChoice"), + ("MPNetConfig", "MPNetForMultipleChoice"), + ("IBertConfig", "IBertForMultipleChoice"), + ] +) + + +MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict( + [ + ("BertConfig", 
"BertForNextSentencePrediction"), + ("MegatronBertConfig", "MegatronBertForNextSentencePrediction"), + ("MobileBertConfig", "MobileBertForNextSentencePrediction"), + ] +) + + +MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict( + [ + ("DetrConfig", "DetrForObjectDetection"), + ] +) + + +MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( + [ + ("BigBirdPegasusConfig", "BigBirdPegasusForConditionalGeneration"), + ("M2M100Config", "M2M100ForConditionalGeneration"), + ("LEDConfig", "LEDForConditionalGeneration"), + ("BlenderbotSmallConfig", "BlenderbotSmallForConditionalGeneration"), + ("MT5Config", "MT5ForConditionalGeneration"), + ("T5Config", "T5ForConditionalGeneration"), + ("PegasusConfig", "PegasusForConditionalGeneration"), + ("MarianConfig", "MarianMTModel"), + ("MBartConfig", "MBartForConditionalGeneration"), + ("BlenderbotConfig", "BlenderbotForConditionalGeneration"), + ("BartConfig", "BartForConditionalGeneration"), + ("FSMTConfig", "FSMTForConditionalGeneration"), + ("EncoderDecoderConfig", "EncoderDecoderModel"), + ("XLMProphetNetConfig", "XLMProphetNetForConditionalGeneration"), + ("ProphetNetConfig", "ProphetNetForConditionalGeneration"), + ] +) + + +MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + ("RoFormerConfig", "RoFormerForSequenceClassification"), + ("BigBirdPegasusConfig", "BigBirdPegasusForSequenceClassification"), + ("BigBirdConfig", "BigBirdForSequenceClassification"), + ("ConvBertConfig", "ConvBertForSequenceClassification"), + ("LEDConfig", "LEDForSequenceClassification"), + ("DistilBertConfig", "DistilBertForSequenceClassification"), + ("AlbertConfig", "AlbertForSequenceClassification"), + ("CamembertConfig", "CamembertForSequenceClassification"), + ("XLMRobertaConfig", "XLMRobertaForSequenceClassification"), + ("MBartConfig", "MBartForSequenceClassification"), + ("BartConfig", "BartForSequenceClassification"), + ("LongformerConfig", "LongformerForSequenceClassification"), + ("RobertaConfig", 
"RobertaForSequenceClassification"), + ("SqueezeBertConfig", "SqueezeBertForSequenceClassification"), + ("LayoutLMConfig", "LayoutLMForSequenceClassification"), + ("BertConfig", "BertForSequenceClassification"), + ("XLNetConfig", "XLNetForSequenceClassification"), + ("MegatronBertConfig", "MegatronBertForSequenceClassification"), + ("MobileBertConfig", "MobileBertForSequenceClassification"), + ("FlaubertConfig", "FlaubertForSequenceClassification"), + ("XLMConfig", "XLMForSequenceClassification"), + ("ElectraConfig", "ElectraForSequenceClassification"), + ("FunnelConfig", "FunnelForSequenceClassification"), + ("DebertaConfig", "DebertaForSequenceClassification"), + ("DebertaV2Config", "DebertaV2ForSequenceClassification"), + ("GPT2Config", "GPT2ForSequenceClassification"), + ("GPTNeoConfig", "GPTNeoForSequenceClassification"), + ("OpenAIGPTConfig", "OpenAIGPTForSequenceClassification"), + ("ReformerConfig", "ReformerForSequenceClassification"), + ("CTRLConfig", "CTRLForSequenceClassification"), + ("TransfoXLConfig", "TransfoXLForSequenceClassification"), + ("MPNetConfig", "MPNetForSequenceClassification"), + ("TapasConfig", "TapasForSequenceClassification"), + ("IBertConfig", "IBertForSequenceClassification"), + ] +) + + +MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( + [ + ("TapasConfig", "TapasForQuestionAnswering"), + ] +) + + +MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + ("RoFormerConfig", "RoFormerForTokenClassification"), + ("BigBirdConfig", "BigBirdForTokenClassification"), + ("ConvBertConfig", "ConvBertForTokenClassification"), + ("LayoutLMConfig", "LayoutLMForTokenClassification"), + ("DistilBertConfig", "DistilBertForTokenClassification"), + ("CamembertConfig", "CamembertForTokenClassification"), + ("FlaubertConfig", "FlaubertForTokenClassification"), + ("XLMConfig", "XLMForTokenClassification"), + ("XLMRobertaConfig", "XLMRobertaForTokenClassification"), + ("LongformerConfig", "LongformerForTokenClassification"), + 
("RobertaConfig", "RobertaForTokenClassification"), + ("SqueezeBertConfig", "SqueezeBertForTokenClassification"), + ("BertConfig", "BertForTokenClassification"), + ("MegatronBertConfig", "MegatronBertForTokenClassification"), + ("MobileBertConfig", "MobileBertForTokenClassification"), + ("XLNetConfig", "XLNetForTokenClassification"), + ("AlbertConfig", "AlbertForTokenClassification"), + ("ElectraConfig", "ElectraForTokenClassification"), + ("FunnelConfig", "FunnelForTokenClassification"), + ("MPNetConfig", "MPNetForTokenClassification"), + ("DebertaConfig", "DebertaForTokenClassification"), + ("DebertaV2Config", "DebertaV2ForTokenClassification"), + ("IBertConfig", "IBertForTokenClassification"), + ] +) + + +MODEL_MAPPING_NAMES = OrderedDict( + [ + ("VisualBertConfig", "VisualBertModel"), + ("RoFormerConfig", "RoFormerModel"), + ("CLIPConfig", "CLIPModel"), + ("BigBirdPegasusConfig", "BigBirdPegasusModel"), + ("DeiTConfig", "DeiTModel"), + ("LukeConfig", "LukeModel"), + ("DetrConfig", "DetrModel"), + ("GPTNeoConfig", "GPTNeoModel"), + ("BigBirdConfig", "BigBirdModel"), + ("Speech2TextConfig", "Speech2TextModel"), + ("ViTConfig", "ViTModel"), + ("Wav2Vec2Config", "Wav2Vec2Model"), + ("M2M100Config", "M2M100Model"), + ("ConvBertConfig", "ConvBertModel"), + ("LEDConfig", "LEDModel"), + ("BlenderbotSmallConfig", "BlenderbotSmallModel"), + ("RetriBertConfig", "RetriBertModel"), + ("MT5Config", "MT5Model"), + ("T5Config", "T5Model"), + ("PegasusConfig", "PegasusModel"), + ("MarianConfig", "MarianModel"), + ("MBartConfig", "MBartModel"), + ("BlenderbotConfig", "BlenderbotModel"), + ("DistilBertConfig", "DistilBertModel"), + ("AlbertConfig", "AlbertModel"), + ("CamembertConfig", "CamembertModel"), + ("XLMRobertaConfig", "XLMRobertaModel"), + ("BartConfig", "BartModel"), + ("LongformerConfig", "LongformerModel"), + ("RobertaConfig", "RobertaModel"), + ("LayoutLMConfig", "LayoutLMModel"), + ("SqueezeBertConfig", "SqueezeBertModel"), + ("BertConfig", "BertModel"), + 
("OpenAIGPTConfig", "OpenAIGPTModel"), + ("GPT2Config", "GPT2Model"), + ("MegatronBertConfig", "MegatronBertModel"), + ("MobileBertConfig", "MobileBertModel"), + ("TransfoXLConfig", "TransfoXLModel"), + ("XLNetConfig", "XLNetModel"), + ("FlaubertConfig", "FlaubertModel"), + ("FSMTConfig", "FSMTModel"), + ("XLMConfig", "XLMModel"), + ("CTRLConfig", "CTRLModel"), + ("ElectraConfig", "ElectraModel"), + ("ReformerConfig", "ReformerModel"), + ("FunnelConfig", "('FunnelModel', 'FunnelBaseModel')"), + ("LxmertConfig", "LxmertModel"), + ("BertGenerationConfig", "BertGenerationEncoder"), + ("DebertaConfig", "DebertaModel"), + ("DebertaV2Config", "DebertaV2Model"), + ("DPRConfig", "DPRQuestionEncoder"), + ("XLMProphetNetConfig", "XLMProphetNetModel"), + ("ProphetNetConfig", "ProphetNetModel"), + ("MPNetConfig", "MPNetModel"), + ("TapasConfig", "TapasModel"), + ("IBertConfig", "IBertModel"), + ] +) + + +MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( + [ + ("RoFormerConfig", "RoFormerForMaskedLM"), + ("BigBirdPegasusConfig", "BigBirdPegasusForConditionalGeneration"), + ("GPTNeoConfig", "GPTNeoForCausalLM"), + ("BigBirdConfig", "BigBirdForMaskedLM"), + ("Speech2TextConfig", "Speech2TextForConditionalGeneration"), + ("Wav2Vec2Config", "Wav2Vec2ForMaskedLM"), + ("M2M100Config", "M2M100ForConditionalGeneration"), + ("ConvBertConfig", "ConvBertForMaskedLM"), + ("LEDConfig", "LEDForConditionalGeneration"), + ("BlenderbotSmallConfig", "BlenderbotSmallForConditionalGeneration"), + ("LayoutLMConfig", "LayoutLMForMaskedLM"), + ("T5Config", "T5ForConditionalGeneration"), + ("DistilBertConfig", "DistilBertForMaskedLM"), + ("AlbertConfig", "AlbertForMaskedLM"), + ("CamembertConfig", "CamembertForMaskedLM"), + ("XLMRobertaConfig", "XLMRobertaForMaskedLM"), + ("MarianConfig", "MarianMTModel"), + ("FSMTConfig", "FSMTForConditionalGeneration"), + ("BartConfig", "BartForConditionalGeneration"), + ("LongformerConfig", "LongformerForMaskedLM"), + ("RobertaConfig", "RobertaForMaskedLM"), + 
("SqueezeBertConfig", "SqueezeBertForMaskedLM"), + ("BertConfig", "BertForMaskedLM"), + ("OpenAIGPTConfig", "OpenAIGPTLMHeadModel"), + ("GPT2Config", "GPT2LMHeadModel"), + ("MegatronBertConfig", "MegatronBertForCausalLM"), + ("MobileBertConfig", "MobileBertForMaskedLM"), + ("TransfoXLConfig", "TransfoXLLMHeadModel"), + ("XLNetConfig", "XLNetLMHeadModel"), + ("FlaubertConfig", "FlaubertWithLMHeadModel"), + ("XLMConfig", "XLMWithLMHeadModel"), + ("CTRLConfig", "CTRLLMHeadModel"), + ("ElectraConfig", "ElectraForMaskedLM"), + ("EncoderDecoderConfig", "EncoderDecoderModel"), + ("ReformerConfig", "ReformerModelWithLMHead"), + ("FunnelConfig", "FunnelForMaskedLM"), + ("MPNetConfig", "MPNetForMaskedLM"), + ("TapasConfig", "TapasForMaskedLM"), + ("DebertaConfig", "DebertaForMaskedLM"), + ("DebertaV2Config", "DebertaV2ForMaskedLM"), + ("IBertConfig", "IBertForMaskedLM"), + ] +) diff --git a/utils/class_mapping_update.py b/utils/class_mapping_update.py index 126600acd..71f02dcef 100644 --- a/utils/class_mapping_update.py +++ b/utils/class_mapping_update.py @@ -30,31 +30,77 @@ sys.path.insert(1, git_repo_path) src = "src/transformers/models/auto/modeling_auto.py" dst = "src/transformers/utils/modeling_auto_mapping.py" + if os.path.exists(dst) and os.path.getmtime(src) < os.path.getmtime(dst): # speed things up by only running this script if the src is newer than dst sys.exit(0) # only load if needed -from transformers.models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING # noqa - - -entries = "\n".join( - [f' ("{k.__name__}", "{v.__name__}"),' for k, v in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items()] +from transformers.models.auto.modeling_auto import ( # noqa + MODEL_FOR_CAUSAL_LM_MAPPING, + MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, + MODEL_FOR_MASKED_LM_MAPPING, + MODEL_FOR_MULTIPLE_CHOICE_MAPPING, + MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, + MODEL_FOR_OBJECT_DETECTION_MAPPING, + MODEL_FOR_PRETRAINING_MAPPING, + MODEL_FOR_QUESTION_ANSWERING_MAPPING, + 
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, + MODEL_MAPPING, + MODEL_WITH_LM_HEAD_MAPPING, ) + + +# Those constants don't have a name attribute, so we need to define it manually +mappings = { + "MODEL_FOR_QUESTION_ANSWERING_MAPPING": MODEL_FOR_QUESTION_ANSWERING_MAPPING, + "MODEL_FOR_CAUSAL_LM_MAPPING": MODEL_FOR_CAUSAL_LM_MAPPING, + "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, + "MODEL_FOR_MASKED_LM_MAPPING": MODEL_FOR_MASKED_LM_MAPPING, + "MODEL_FOR_MULTIPLE_CHOICE_MAPPING": MODEL_FOR_MULTIPLE_CHOICE_MAPPING, + "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, + "MODEL_FOR_OBJECT_DETECTION_MAPPING": MODEL_FOR_OBJECT_DETECTION_MAPPING, + "MODEL_FOR_PRETRAINING_MAPPING": MODEL_FOR_PRETRAINING_MAPPING, + "MODEL_FOR_QUESTION_ANSWERING_MAPPING": MODEL_FOR_QUESTION_ANSWERING_MAPPING, + "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING": MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, + "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, + "MODEL_MAPPING": MODEL_MAPPING, + "MODEL_WITH_LM_HEAD_MAPPING": MODEL_WITH_LM_HEAD_MAPPING, +} + + +def get_name(value): + if isinstance(value, tuple): + return tuple(get_name(o) for o in value) + return value.__name__ + + content = [ "# THIS FILE HAS BEEN AUTOGENERATED. To update:", "# 1. modify: models/auto/modeling_auto.py", "# 2.
run: python utils/class_mapping_update.py", "from collections import OrderedDict", "", - "", - "MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(", - " [", - entries, - " ]", - ")", - "", ] -print(f"updating {dst}") + +for name, mapping in mappings.items(): + entries = "\n".join([f' ("{k.__name__}", "{get_name(v)}"),' for k, v in mapping.items()]) + + content += [ + "", + f"{name}_NAMES = OrderedDict(", + " [", + entries, + " ]", + ")", + "", + ] + +print(f"Updating {dst}") with open(dst, "w", encoding="utf-8", newline="\n") as f: f.write("\n".join(content))