Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-14 20:58:08 +00:00
Model card defaults (#12122)
* [WIP] Model card defaults
* finetuned_from default value
* Add all mappings to the mapping file
* Be more defensive on finetuned_from arg
* Add default task tag
* Separate tags from tasks
* Edge case for dataset
* Apply suggestions from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
parent 6e7cc5cc51
commit 7d7ceca396

14 changed files with 477 additions and 26 deletions
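In user-facing terms, the example scripts now pass the pipeline identifier under a dedicated `tasks` keyword instead of overloading `tags`, and `TrainingSummary.from_trainer` fills in sensible defaults when values are omitted. A minimal sketch of the new call shape; the checkpoint and dataset names here are placeholders, not values from this commit:

    # Placeholder values; any Hub checkpoint / dataset id works the same way.
    kwargs = {"finetuned_from": "gpt2", "tasks": "text-generation"}
    kwargs["dataset_tags"] = "wikitext"           # hypothetical dataset tag
    kwargs["dataset_args"] = "wikitext-2-raw-v1"  # hypothetical dataset config
    # trainer.push_to_hub(**kwargs)  # requires a configured Trainer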
@@ -467,7 +467,7 @@ def main():
         trainer.save_metrics("eval", metrics)
 
     if training_args.push_to_hub:
-        kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "text-generation"}
+        kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
         if data_args.dataset_name is not None:
             kwargs["dataset_tags"] = data_args.dataset_name
             if data_args.dataset_config_name is not None:
@@ -497,7 +497,7 @@ def main():
         trainer.save_metrics("eval", metrics)
 
     if training_args.push_to_hub:
-        kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "fill-mask"}
+        kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "fill-mask"}
         if data_args.dataset_name is not None:
             kwargs["dataset_tags"] = data_args.dataset_name
             if data_args.dataset_config_name is not None:
@@ -471,7 +471,7 @@ def main():
         trainer.save_metrics("eval", metrics)
 
     if training_args.push_to_hub:
-        kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "language-modeling"}
+        kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "language-modeling"}
         if data_args.dataset_name is not None:
             kwargs["dataset_tags"] = data_args.dataset_name
             if data_args.dataset_config_name is not None:
@@ -430,7 +430,7 @@ def main():
     if training_args.push_to_hub:
         trainer.push_to_hub(
             finetuned_from=model_args.model_name_or_path,
-            tags="multiple-choice",
+            tasks="multiple-choice",
             dataset_tags="swag",
             dataset_args="regular",
             dataset="SWAG",
@@ -601,7 +601,7 @@ def main():
         trainer.save_metrics("predict", metrics)
 
     if training_args.push_to_hub:
-        kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "question-answering"}
+        kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "question-answering"}
         if data_args.dataset_name is not None:
             kwargs["dataset_tags"] = data_args.dataset_name
             if data_args.dataset_config_name is not None:
@@ -640,7 +640,7 @@ def main():
         trainer.save_metrics("predict", metrics)
 
     if training_args.push_to_hub:
-        kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "question-answering"}
+        kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "question-answering"}
         if data_args.dataset_name is not None:
             kwargs["dataset_tags"] = data_args.dataset_name
             if data_args.dataset_config_name is not None:
@@ -583,7 +583,7 @@ def main():
                 writer.write("\n".join(predictions))
 
     if training_args.push_to_hub:
-        kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "summarization"}
+        kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "summarization"}
         if data_args.dataset_name is not None:
             kwargs["dataset_tags"] = data_args.dataset_name
             if data_args.dataset_config_name is not None:
@@ -538,7 +538,7 @@ def main():
                         writer.write(f"{index}\t{item}\n")
 
     if training_args.push_to_hub:
-        kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "text-classification"}
+        kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"}
         if data_args.task_name is not None:
             kwargs["language"] = "en"
             kwargs["dataset_tags"] = "glue"
@@ -522,7 +522,7 @@ def main():
                 writer.write(" ".join(prediction) + "\n")
 
     if training_args.push_to_hub:
-        kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "token-classification"}
+        kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "token-classification"}
         if data_args.dataset_name is not None:
             kwargs["dataset_tags"] = data_args.dataset_name
             if data_args.dataset_config_name is not None:
@@ -575,7 +575,7 @@ def main():
                 writer.write("\n".join(predictions))
 
     if training_args.push_to_hub:
-        kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "translation"}
+        kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "translation"}
         if data_args.dataset_name is not None:
             kwargs["dataset_tags"] = data_args.dataset_name
             if data_args.dataset_config_name is not None:
@@ -42,8 +42,31 @@ from .file_utils import (
 )
 from .training_args import ParallelMode
 from .utils import logging
+from .utils.modeling_auto_mapping import (
+    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
+    MODEL_FOR_MASKED_LM_MAPPING_NAMES,
+    MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
+    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
+    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
+    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
+    MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
+    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
+)
+
+TASK_MAPPING = {
+    "text-generation": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+    "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
+    "fill-mask": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
+    "object-detection": MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
+    "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
+    "text2text-generation": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
+    "text-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
+    "table-question-answering": MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
+    "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
+}
 
 
 logger = logging.get_logger(__name__)
 
 
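A sketch of how `TASK_MAPPING` is consulted to infer a task tag from a model class, mirroring the walk `from_trainer` performs further down; the two-entry mapping here is a toy stand-in for the real tables above:

    from collections import OrderedDict

    # Toy stand-ins for the *_MAPPING_NAMES tables (the real ones live in
    # modeling_auto_mapping); the real code flattens tuple values first via
    # _get_mapping_values.
    TASK_MAPPING = {
        "text-generation": OrderedDict([("GPT2Config", "GPT2LMHeadModel")]),
        "fill-mask": OrderedDict([("BertConfig", "BertForMaskedLM")]),
    }

    model_class_name = "GPT2LMHeadModel"  # normally trainer.model.__class__.__name__
    tasks = None
    for task, mapping in TASK_MAPPING.items():
        if model_class_name in mapping.values():
            tasks = task
    print(tasks)  # -> "text-generation"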
@@ -246,9 +269,12 @@ should probably proofread and complete it, then remove this comment. -->
 
 TASK_TAG_TO_NAME_MAPPING = {
     "fill-mask": "Masked Language Modeling",
+    "image-classification": "Image Classification",
+    "multiple-choice": "Multiple Choice",
+    "object-detection": "Object Detection",
     "question-answering": "Question Answering",
     "summarization": "Summarization",
     "table-question-answering": "Table Question Answering",
     "text-classification": "Text Classification",
     "text-generation": "Causal Language Modeling",
     "text2text-generation": "Sequence-to-sequence Language Modeling",
@@ -304,6 +330,25 @@ def infer_metric_tags_from_eval_results(eval_results):
     return result
 
 
+def is_hf_dataset(dataset):
+    if not is_datasets_available():
+        return False
+
+    from datasets import Dataset
+
+    return isinstance(dataset, Dataset)
+
+
+def _get_mapping_values(mapping):
+    result = []
+    for v in mapping.values():
+        if isinstance(v, (tuple, list)):
+            result += list(v)
+        else:
+            result.append(v)
+    return result
+
+
 @dataclass
 class TrainingSummary:
     model_name: str
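For reference, what `_get_mapping_values` does in isolation; the toy mapping below is made up, but the function body is the one added above:

    from collections import OrderedDict

    def _get_mapping_values(mapping):
        # Flatten a mapping whose values are names or tuples/lists of names.
        result = []
        for v in mapping.values():
            if isinstance(v, (tuple, list)):
                result += list(v)
            else:
                result.append(v)
        return result

    toy = OrderedDict([("AConfig", "AModel"), ("BConfig", ("BModel", "BModelAlt"))])
    print(_get_mapping_values(toy))  # -> ['AModel', 'BModel', 'BModelAlt']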
@@ -311,6 +356,7 @@ class TrainingSummary:
     license: Optional[str] = None
     tags: Optional[Union[str, List[str]]] = None
     finetuned_from: Optional[str] = None
+    tasks: Optional[Union[str, List[str]]] = None
     dataset: Optional[Union[str, List[str]]] = None
     dataset_tags: Optional[Union[str, List[str]]] = None
     dataset_args: Optional[Union[str, List[str]]] = None
@@ -320,7 +366,12 @@ class TrainingSummary:
 
     def __post_init__(self):
         # Infer default license from the checkpoint used, if possible.
-        if self.license is None and not is_offline_mode() and self.finetuned_from is not None:
+        if (
+            self.license is None
+            and not is_offline_mode()
+            and self.finetuned_from is not None
+            and len(self.finetuned_from) > 0
+        ):
             try:
                 model_info = HfApi().model_info(self.finetuned_from)
                 for tag in model_info.tags:
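The extra `len(self.finetuned_from) > 0` condition is the "be more defensive" part of the commit message: an empty string is not `None`, so it previously slipped through and triggered a pointless Hub lookup. A minimal illustration of the guard on its own (the helper name is ours, for illustration):

    def should_query_hub(license, offline, finetuned_from):
        # Mirrors the guarded condition above.
        return license is None and not offline and finetuned_from is not None and len(finetuned_from) > 0

    print(should_query_hub(None, False, ""))      # False: empty string no longer triggers a lookup
    print(should_query_hub(None, False, "gpt2"))  # True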
@@ -342,7 +393,7 @@ class TrainingSummary:
         dataset_arg_mapping = {tag: arg for tag, arg in zip(dataset_tags, dataset_args)}
 
         task_mapping = {
-            tag: TASK_TAG_TO_NAME_MAPPING[tag] for tag in _listify(self.tags) if tag in TASK_TAG_TO_NAME_MAPPING
+            task: TASK_TAG_TO_NAME_MAPPING[task] for task in _listify(self.tasks) if task in TASK_TAG_TO_NAME_MAPPING
         }
 
         if len(task_mapping) == 0 and len(dataset_mapping) == 0:
@@ -405,6 +456,8 @@ class TrainingSummary:
         else:
             if isinstance(self.dataset, str):
                 model_card += f"the {self.dataset} dataset."
+            elif isinstance(self.dataset, (tuple, list)) and len(self.dataset) == 1:
+                model_card += f"the {self.dataset[0]} dataset."
             else:
                 model_card += (
                     ", ".join([f"the {ds}" for ds in self.dataset[:-1]]) + f" and the {self.dataset[-1]} datasets."
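This is the "edge case for dataset" from the commit message: a one-element list used to fall through to the join branch and render badly. A small demonstration of the three phrasings, with the branching reduced to the string logic only:

    def dataset_clause(dataset):
        # Same branching as above, isolated for illustration.
        if isinstance(dataset, str):
            return f"the {dataset} dataset."
        elif isinstance(dataset, (tuple, list)) and len(dataset) == 1:
            return f"the {dataset[0]} dataset."
        else:
            return ", ".join([f"the {ds}" for ds in dataset[:-1]]) + f" and the {dataset[-1]} datasets."

    print(dataset_clause("glue"))            # -> "the glue dataset."
    print(dataset_clause(["glue"]))          # -> "the glue dataset."
    print(dataset_clause(["glue", "squad"])) # -> "the glue and the squad datasets."
    # Without the new elif, ["glue"] produced " and the glue datasets." (empty join plus the tail).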
@@ -459,11 +512,40 @@ class TrainingSummary:
         tags=None,
         model_name=None,
         finetuned_from=None,
+        tasks=None,
         dataset_tags=None,
         dataset=None,
         dataset_args=None,
     ):
-        # TODO (Sylvain) Add a default for `pipeline-tag` inferred from the model.
+        # Infer default from dataset
+        one_dataset = trainer.train_dataset if trainer.train_dataset is not None else trainer.eval_dataset
+        if is_hf_dataset(one_dataset) and (dataset_tags is None or dataset_args is None):
+            default_tag = one_dataset.builder_name
+            # Those are not real datasets from the Hub so we exclude them.
+            if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
+                if dataset_tags is None:
+                    dataset_tags = [default_tag]
+                if dataset_args is None:
+                    dataset_args = [one_dataset.config_name]
+
+        if dataset is None and dataset_tags is not None:
+            dataset = dataset_tags
+
+        # Infer default finetuned_from
+        if (
+            finetuned_from is None
+            and hasattr(trainer.model.config, "_name_or_path")
+            and not os.path.isdir(trainer.model.config._name_or_path)
+        ):
+            finetuned_from = trainer.model.config._name_or_path
+
+        # Infer default task tag:
+        if tasks is None:
+            model_class_name = trainer.model.__class__.__name__
+            for task, mapping in TASK_MAPPING.items():
+                if model_class_name in _get_mapping_values(mapping):
+                    tasks = task
+
         if model_name is None:
             model_name = Path(trainer.args.output_dir).name
 
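Of the three inferred defaults, `finetuned_from` is the subtlest: a local directory path would make a meaningless "finetuned from" link on the Hub, so the config value is only kept when it does not point at a directory. A reduced sketch; the checkpoint name is a stand-in for `trainer.model.config._name_or_path`:

    import os

    name_or_path = "distilbert-base-uncased"  # illustrative
    finetuned_from = None
    if finetuned_from is None and not os.path.isdir(name_or_path):
        # A Hub id rather than a local directory, so it is safe to record as the base checkpoint.
        finetuned_from = name_or_path
    print(finetuned_from)  # -> "distilbert-base-uncased"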
@@ -476,6 +558,7 @@ class TrainingSummary:
             tags=tags,
             model_name=model_name,
             finetuned_from=finetuned_from,
+            tasks=tasks,
             dataset_tags=dataset_tags,
             dataset=dataset,
             dataset_args=dataset_args,
@@ -2433,6 +2433,7 @@ class Trainer:
         tags: Optional[str] = None,
         model_name: Optional[str] = None,
         finetuned_from: Optional[str] = None,
+        tasks: Optional[str] = None,
         dataset_tags: Optional[Union[str, List[str]]] = None,
         dataset: Optional[Union[str, List[str]]] = None,
         dataset_args: Optional[Union[str, List[str]]] = None,
@@ -2444,6 +2445,7 @@ class Trainer:
             tags=tags,
             model_name=model_name,
             finetuned_from=finetuned_from,
+            tasks=tasks,
             dataset_tags=dataset_tags,
             dataset=dataset,
             dataset_args=dataset_args,
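With both the signature and the forwarding in place, the new keyword travels from `Trainer.push_to_hub` through `create_model_card` into the generated card. A usage sketch; it requires a configured `Trainer`, and all values are placeholders:

    # trainer.push_to_hub(
    #     finetuned_from="bert-base-cased",
    #     tasks="text-classification",
    #     dataset_tags="glue",
    #     dataset_args="mrpc",
    # )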
@@ -36,3 +36,323 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
         ("IBertConfig", "IBertForQuestionAnswering"),
     ]
 )
+
+
+MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
+    [
+        ("RoFormerConfig", "RoFormerForCausalLM"),
+        ("BigBirdPegasusConfig", "BigBirdPegasusForCausalLM"),
+        ("GPTNeoConfig", "GPTNeoForCausalLM"),
+        ("BigBirdConfig", "BigBirdForCausalLM"),
+        ("CamembertConfig", "CamembertForCausalLM"),
+        ("XLMRobertaConfig", "XLMRobertaForCausalLM"),
+        ("RobertaConfig", "RobertaForCausalLM"),
+        ("BertConfig", "BertLMHeadModel"),
+        ("OpenAIGPTConfig", "OpenAIGPTLMHeadModel"),
+        ("GPT2Config", "GPT2LMHeadModel"),
+        ("TransfoXLConfig", "TransfoXLLMHeadModel"),
+        ("XLNetConfig", "XLNetLMHeadModel"),
+        ("XLMConfig", "XLMWithLMHeadModel"),
+        ("CTRLConfig", "CTRLLMHeadModel"),
+        ("ReformerConfig", "ReformerModelWithLMHead"),
+        ("BertGenerationConfig", "BertGenerationDecoder"),
+        ("XLMProphetNetConfig", "XLMProphetNetForCausalLM"),
+        ("ProphetNetConfig", "ProphetNetForCausalLM"),
+        ("BartConfig", "BartForCausalLM"),
+        ("MBartConfig", "MBartForCausalLM"),
+        ("PegasusConfig", "PegasusForCausalLM"),
+        ("MarianConfig", "MarianForCausalLM"),
+        ("BlenderbotConfig", "BlenderbotForCausalLM"),
+        ("BlenderbotSmallConfig", "BlenderbotSmallForCausalLM"),
+        ("MegatronBertConfig", "MegatronBertForCausalLM"),
+    ]
+)
+
+
+MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+    [
+        ("ViTConfig", "ViTForImageClassification"),
+        ("DeiTConfig", "('DeiTForImageClassification', 'DeiTForImageClassificationWithTeacher')"),
+    ]
+)
+
+
+MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
+    [
+        ("RoFormerConfig", "RoFormerForMaskedLM"),
+        ("BigBirdConfig", "BigBirdForMaskedLM"),
+        ("Wav2Vec2Config", "Wav2Vec2ForMaskedLM"),
+        ("ConvBertConfig", "ConvBertForMaskedLM"),
+        ("LayoutLMConfig", "LayoutLMForMaskedLM"),
+        ("DistilBertConfig", "DistilBertForMaskedLM"),
+        ("AlbertConfig", "AlbertForMaskedLM"),
+        ("BartConfig", "BartForConditionalGeneration"),
+        ("MBartConfig", "MBartForConditionalGeneration"),
+        ("CamembertConfig", "CamembertForMaskedLM"),
+        ("XLMRobertaConfig", "XLMRobertaForMaskedLM"),
+        ("LongformerConfig", "LongformerForMaskedLM"),
+        ("RobertaConfig", "RobertaForMaskedLM"),
+        ("SqueezeBertConfig", "SqueezeBertForMaskedLM"),
+        ("BertConfig", "BertForMaskedLM"),
+        ("MegatronBertConfig", "MegatronBertForMaskedLM"),
+        ("MobileBertConfig", "MobileBertForMaskedLM"),
+        ("FlaubertConfig", "FlaubertWithLMHeadModel"),
+        ("XLMConfig", "XLMWithLMHeadModel"),
+        ("ElectraConfig", "ElectraForMaskedLM"),
+        ("ReformerConfig", "ReformerForMaskedLM"),
+        ("FunnelConfig", "FunnelForMaskedLM"),
+        ("MPNetConfig", "MPNetForMaskedLM"),
+        ("TapasConfig", "TapasForMaskedLM"),
+        ("DebertaConfig", "DebertaForMaskedLM"),
+        ("DebertaV2Config", "DebertaV2ForMaskedLM"),
+        ("IBertConfig", "IBertForMaskedLM"),
+    ]
+)
+
+
+MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
+    [
+        ("RoFormerConfig", "RoFormerForMultipleChoice"),
+        ("BigBirdConfig", "BigBirdForMultipleChoice"),
+        ("ConvBertConfig", "ConvBertForMultipleChoice"),
+        ("CamembertConfig", "CamembertForMultipleChoice"),
+        ("ElectraConfig", "ElectraForMultipleChoice"),
+        ("XLMRobertaConfig", "XLMRobertaForMultipleChoice"),
+        ("LongformerConfig", "LongformerForMultipleChoice"),
+        ("RobertaConfig", "RobertaForMultipleChoice"),
+        ("SqueezeBertConfig", "SqueezeBertForMultipleChoice"),
+        ("BertConfig", "BertForMultipleChoice"),
+        ("DistilBertConfig", "DistilBertForMultipleChoice"),
+        ("MegatronBertConfig", "MegatronBertForMultipleChoice"),
+        ("MobileBertConfig", "MobileBertForMultipleChoice"),
+        ("XLNetConfig", "XLNetForMultipleChoice"),
+        ("AlbertConfig", "AlbertForMultipleChoice"),
+        ("XLMConfig", "XLMForMultipleChoice"),
+        ("FlaubertConfig", "FlaubertForMultipleChoice"),
+        ("FunnelConfig", "FunnelForMultipleChoice"),
+        ("MPNetConfig", "MPNetForMultipleChoice"),
+        ("IBertConfig", "IBertForMultipleChoice"),
+    ]
+)
+
+
+MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
+    [
+        ("BertConfig", "BertForNextSentencePrediction"),
+        ("MegatronBertConfig", "MegatronBertForNextSentencePrediction"),
+        ("MobileBertConfig", "MobileBertForNextSentencePrediction"),
+    ]
+)
+
+
+MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
+    [
+        ("DetrConfig", "DetrForObjectDetection"),
+    ]
+)
+
+
+MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
+    [
+        ("BigBirdPegasusConfig", "BigBirdPegasusForConditionalGeneration"),
+        ("M2M100Config", "M2M100ForConditionalGeneration"),
+        ("LEDConfig", "LEDForConditionalGeneration"),
+        ("BlenderbotSmallConfig", "BlenderbotSmallForConditionalGeneration"),
+        ("MT5Config", "MT5ForConditionalGeneration"),
+        ("T5Config", "T5ForConditionalGeneration"),
+        ("PegasusConfig", "PegasusForConditionalGeneration"),
+        ("MarianConfig", "MarianMTModel"),
+        ("MBartConfig", "MBartForConditionalGeneration"),
+        ("BlenderbotConfig", "BlenderbotForConditionalGeneration"),
+        ("BartConfig", "BartForConditionalGeneration"),
+        ("FSMTConfig", "FSMTForConditionalGeneration"),
+        ("EncoderDecoderConfig", "EncoderDecoderModel"),
+        ("XLMProphetNetConfig", "XLMProphetNetForConditionalGeneration"),
+        ("ProphetNetConfig", "ProphetNetForConditionalGeneration"),
+    ]
+)
+
+
+MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+    [
+        ("RoFormerConfig", "RoFormerForSequenceClassification"),
+        ("BigBirdPegasusConfig", "BigBirdPegasusForSequenceClassification"),
+        ("BigBirdConfig", "BigBirdForSequenceClassification"),
+        ("ConvBertConfig", "ConvBertForSequenceClassification"),
+        ("LEDConfig", "LEDForSequenceClassification"),
+        ("DistilBertConfig", "DistilBertForSequenceClassification"),
+        ("AlbertConfig", "AlbertForSequenceClassification"),
+        ("CamembertConfig", "CamembertForSequenceClassification"),
+        ("XLMRobertaConfig", "XLMRobertaForSequenceClassification"),
+        ("MBartConfig", "MBartForSequenceClassification"),
+        ("BartConfig", "BartForSequenceClassification"),
+        ("LongformerConfig", "LongformerForSequenceClassification"),
+        ("RobertaConfig", "RobertaForSequenceClassification"),
+        ("SqueezeBertConfig", "SqueezeBertForSequenceClassification"),
+        ("LayoutLMConfig", "LayoutLMForSequenceClassification"),
+        ("BertConfig", "BertForSequenceClassification"),
+        ("XLNetConfig", "XLNetForSequenceClassification"),
+        ("MegatronBertConfig", "MegatronBertForSequenceClassification"),
+        ("MobileBertConfig", "MobileBertForSequenceClassification"),
+        ("FlaubertConfig", "FlaubertForSequenceClassification"),
+        ("XLMConfig", "XLMForSequenceClassification"),
+        ("ElectraConfig", "ElectraForSequenceClassification"),
+        ("FunnelConfig", "FunnelForSequenceClassification"),
+        ("DebertaConfig", "DebertaForSequenceClassification"),
+        ("DebertaV2Config", "DebertaV2ForSequenceClassification"),
+        ("GPT2Config", "GPT2ForSequenceClassification"),
+        ("GPTNeoConfig", "GPTNeoForSequenceClassification"),
+        ("OpenAIGPTConfig", "OpenAIGPTForSequenceClassification"),
+        ("ReformerConfig", "ReformerForSequenceClassification"),
+        ("CTRLConfig", "CTRLForSequenceClassification"),
+        ("TransfoXLConfig", "TransfoXLForSequenceClassification"),
+        ("MPNetConfig", "MPNetForSequenceClassification"),
+        ("TapasConfig", "TapasForSequenceClassification"),
+        ("IBertConfig", "IBertForSequenceClassification"),
+    ]
+)
+
+
+MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
+    [
+        ("TapasConfig", "TapasForQuestionAnswering"),
+    ]
+)
+
+
+MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+    [
+        ("RoFormerConfig", "RoFormerForTokenClassification"),
+        ("BigBirdConfig", "BigBirdForTokenClassification"),
+        ("ConvBertConfig", "ConvBertForTokenClassification"),
+        ("LayoutLMConfig", "LayoutLMForTokenClassification"),
+        ("DistilBertConfig", "DistilBertForTokenClassification"),
+        ("CamembertConfig", "CamembertForTokenClassification"),
+        ("FlaubertConfig", "FlaubertForTokenClassification"),
+        ("XLMConfig", "XLMForTokenClassification"),
+        ("XLMRobertaConfig", "XLMRobertaForTokenClassification"),
+        ("LongformerConfig", "LongformerForTokenClassification"),
+        ("RobertaConfig", "RobertaForTokenClassification"),
+        ("SqueezeBertConfig", "SqueezeBertForTokenClassification"),
+        ("BertConfig", "BertForTokenClassification"),
+        ("MegatronBertConfig", "MegatronBertForTokenClassification"),
+        ("MobileBertConfig", "MobileBertForTokenClassification"),
+        ("XLNetConfig", "XLNetForTokenClassification"),
+        ("AlbertConfig", "AlbertForTokenClassification"),
+        ("ElectraConfig", "ElectraForTokenClassification"),
+        ("FunnelConfig", "FunnelForTokenClassification"),
+        ("MPNetConfig", "MPNetForTokenClassification"),
+        ("DebertaConfig", "DebertaForTokenClassification"),
+        ("DebertaV2Config", "DebertaV2ForTokenClassification"),
+        ("IBertConfig", "IBertForTokenClassification"),
+    ]
+)
+
+
+MODEL_MAPPING_NAMES = OrderedDict(
+    [
+        ("VisualBertConfig", "VisualBertModel"),
+        ("RoFormerConfig", "RoFormerModel"),
+        ("CLIPConfig", "CLIPModel"),
+        ("BigBirdPegasusConfig", "BigBirdPegasusModel"),
+        ("DeiTConfig", "DeiTModel"),
+        ("LukeConfig", "LukeModel"),
+        ("DetrConfig", "DetrModel"),
+        ("GPTNeoConfig", "GPTNeoModel"),
+        ("BigBirdConfig", "BigBirdModel"),
+        ("Speech2TextConfig", "Speech2TextModel"),
+        ("ViTConfig", "ViTModel"),
+        ("Wav2Vec2Config", "Wav2Vec2Model"),
+        ("M2M100Config", "M2M100Model"),
+        ("ConvBertConfig", "ConvBertModel"),
+        ("LEDConfig", "LEDModel"),
+        ("BlenderbotSmallConfig", "BlenderbotSmallModel"),
+        ("RetriBertConfig", "RetriBertModel"),
+        ("MT5Config", "MT5Model"),
+        ("T5Config", "T5Model"),
+        ("PegasusConfig", "PegasusModel"),
+        ("MarianConfig", "MarianModel"),
+        ("MBartConfig", "MBartModel"),
+        ("BlenderbotConfig", "BlenderbotModel"),
+        ("DistilBertConfig", "DistilBertModel"),
+        ("AlbertConfig", "AlbertModel"),
+        ("CamembertConfig", "CamembertModel"),
+        ("XLMRobertaConfig", "XLMRobertaModel"),
+        ("BartConfig", "BartModel"),
+        ("LongformerConfig", "LongformerModel"),
+        ("RobertaConfig", "RobertaModel"),
+        ("LayoutLMConfig", "LayoutLMModel"),
+        ("SqueezeBertConfig", "SqueezeBertModel"),
+        ("BertConfig", "BertModel"),
+        ("OpenAIGPTConfig", "OpenAIGPTModel"),
+        ("GPT2Config", "GPT2Model"),
+        ("MegatronBertConfig", "MegatronBertModel"),
+        ("MobileBertConfig", "MobileBertModel"),
+        ("TransfoXLConfig", "TransfoXLModel"),
+        ("XLNetConfig", "XLNetModel"),
+        ("FlaubertConfig", "FlaubertModel"),
+        ("FSMTConfig", "FSMTModel"),
+        ("XLMConfig", "XLMModel"),
+        ("CTRLConfig", "CTRLModel"),
+        ("ElectraConfig", "ElectraModel"),
+        ("ReformerConfig", "ReformerModel"),
+        ("FunnelConfig", "('FunnelModel', 'FunnelBaseModel')"),
+        ("LxmertConfig", "LxmertModel"),
+        ("BertGenerationConfig", "BertGenerationEncoder"),
+        ("DebertaConfig", "DebertaModel"),
+        ("DebertaV2Config", "DebertaV2Model"),
+        ("DPRConfig", "DPRQuestionEncoder"),
+        ("XLMProphetNetConfig", "XLMProphetNetModel"),
+        ("ProphetNetConfig", "ProphetNetModel"),
+        ("MPNetConfig", "MPNetModel"),
+        ("TapasConfig", "TapasModel"),
+        ("IBertConfig", "IBertModel"),
+    ]
+)
+
+
+MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
+    [
+        ("RoFormerConfig", "RoFormerForMaskedLM"),
+        ("BigBirdPegasusConfig", "BigBirdPegasusForConditionalGeneration"),
+        ("GPTNeoConfig", "GPTNeoForCausalLM"),
+        ("BigBirdConfig", "BigBirdForMaskedLM"),
+        ("Speech2TextConfig", "Speech2TextForConditionalGeneration"),
+        ("Wav2Vec2Config", "Wav2Vec2ForMaskedLM"),
+        ("M2M100Config", "M2M100ForConditionalGeneration"),
+        ("ConvBertConfig", "ConvBertForMaskedLM"),
+        ("LEDConfig", "LEDForConditionalGeneration"),
+        ("BlenderbotSmallConfig", "BlenderbotSmallForConditionalGeneration"),
+        ("LayoutLMConfig", "LayoutLMForMaskedLM"),
+        ("T5Config", "T5ForConditionalGeneration"),
+        ("DistilBertConfig", "DistilBertForMaskedLM"),
+        ("AlbertConfig", "AlbertForMaskedLM"),
+        ("CamembertConfig", "CamembertForMaskedLM"),
+        ("XLMRobertaConfig", "XLMRobertaForMaskedLM"),
+        ("MarianConfig", "MarianMTModel"),
+        ("FSMTConfig", "FSMTForConditionalGeneration"),
+        ("BartConfig", "BartForConditionalGeneration"),
+        ("LongformerConfig", "LongformerForMaskedLM"),
+        ("RobertaConfig", "RobertaForMaskedLM"),
+        ("SqueezeBertConfig", "SqueezeBertForMaskedLM"),
+        ("BertConfig", "BertForMaskedLM"),
+        ("OpenAIGPTConfig", "OpenAIGPTLMHeadModel"),
+        ("GPT2Config", "GPT2LMHeadModel"),
+        ("MegatronBertConfig", "MegatronBertForCausalLM"),
+        ("MobileBertConfig", "MobileBertForMaskedLM"),
+        ("TransfoXLConfig", "TransfoXLLMHeadModel"),
+        ("XLNetConfig", "XLNetLMHeadModel"),
+        ("FlaubertConfig", "FlaubertWithLMHeadModel"),
+        ("XLMConfig", "XLMWithLMHeadModel"),
+        ("CTRLConfig", "CTRLLMHeadModel"),
+        ("ElectraConfig", "ElectraForMaskedLM"),
+        ("EncoderDecoderConfig", "EncoderDecoderModel"),
+        ("ReformerConfig", "ReformerModelWithLMHead"),
+        ("FunnelConfig", "FunnelForMaskedLM"),
+        ("MPNetConfig", "MPNetForMaskedLM"),
+        ("TapasConfig", "TapasForMaskedLM"),
+        ("DebertaConfig", "DebertaForMaskedLM"),
+        ("DebertaV2Config", "DebertaV2ForMaskedLM"),
+        ("IBertConfig", "IBertForMaskedLM"),
+    ]
+)
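One quirk of this generated file is visible in the DeiT and Funnel entries: multi-class values are stored as the string repr of a tuple, not a tuple, because the generator interpolates the tuple into an f-string. A hedged way for a reader of the file to recover the actual names from such a value:

    import ast

    value = "('DeiTForImageClassification', 'DeiTForImageClassificationWithTeacher')"
    # Parse the stringified tuple back into real Python strings; plain names pass through.
    names = ast.literal_eval(value) if value.startswith("(") else (value,)
    print(names)  # -> ('DeiTForImageClassification', 'DeiTForImageClassificationWithTeacher')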
@@ -30,31 +30,77 @@ sys.path.insert(1, git_repo_path)
 src = "src/transformers/models/auto/modeling_auto.py"
 dst = "src/transformers/utils/modeling_auto_mapping.py"
 
 
 if os.path.exists(dst) and os.path.getmtime(src) < os.path.getmtime(dst):
     # speed things up by only running this script if the src is newer than dst
     sys.exit(0)
 
 # only load if needed
-from transformers.models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING  # noqa
-
-
-entries = "\n".join(
-    [f'        ("{k.__name__}", "{v.__name__}"),' for k, v in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items()]
-)
+from transformers.models.auto.modeling_auto import (  # noqa
+    MODEL_FOR_CAUSAL_LM_MAPPING,
+    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+    MODEL_FOR_MASKED_LM_MAPPING,
+    MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
+    MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
+    MODEL_FOR_OBJECT_DETECTION_MAPPING,
+    MODEL_FOR_PRETRAINING_MAPPING,
+    MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+    MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
+    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+    MODEL_MAPPING,
+    MODEL_WITH_LM_HEAD_MAPPING,
+)
+
+
+# Those constants don't have a name attribute, so we need to define it manually
+mappings = {
+    "MODEL_FOR_QUESTION_ANSWERING_MAPPING": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+    "MODEL_FOR_CAUSAL_LM_MAPPING": MODEL_FOR_CAUSAL_LM_MAPPING,
+    "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+    "MODEL_FOR_MASKED_LM_MAPPING": MODEL_FOR_MASKED_LM_MAPPING,
+    "MODEL_FOR_MULTIPLE_CHOICE_MAPPING": MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
+    "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
+    "MODEL_FOR_OBJECT_DETECTION_MAPPING": MODEL_FOR_OBJECT_DETECTION_MAPPING,
+    "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+    "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+    "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING": MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
+    "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+    "MODEL_MAPPING": MODEL_MAPPING,
+    "MODEL_WITH_LM_HEAD_MAPPING": MODEL_WITH_LM_HEAD_MAPPING,
+}
+
+
+def get_name(value):
+    if isinstance(value, tuple):
+        return tuple(get_name(o) for o in value)
+    return value.__name__
+
 
 content = [
     "# THIS FILE HAS BEEN AUTOGENERATED. To update:",
     "# 1. modify: models/auto/modeling_auto.py",
     "# 2. run: python utils/class_mapping_update.py",
     "from collections import OrderedDict",
     "",
-    "",
-    "MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(",
-    "    [",
-    entries,
-    "    ]",
-    ")",
-    "",
 ]
-print(f"updating {dst}")
 
+for name, mapping in mappings.items():
+    entries = "\n".join([f'        ("{k.__name__}", "{get_name(v)}"),' for k, v in mapping.items()])
+
+    content += [
+        "",
+        f"{name}_NAMES = OrderedDict(",
+        "    [",
+        entries,
+        "    ]",
+        ")",
+        "",
+    ]
+
+print(f"Updating {dst}")
 with open(dst, "w", encoding="utf-8", newline="\n") as f:
     f.write("\n".join(content))
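The mapping file above is regenerated with `python utils/class_mapping_update.py` whenever `modeling_auto.py` changes. The tuple handling comes from `get_name`; its behavior in isolation, with stand-in classes for illustration:

    # Stand-in classes; any class objects would do.
    class DeiTForImageClassification: ...
    class DeiTForImageClassificationWithTeacher: ...

    def get_name(value):
        # Recurse into tuples, otherwise return the class name.
        if isinstance(value, tuple):
            return tuple(get_name(o) for o in value)
        return value.__name__

    print(get_name((DeiTForImageClassification, DeiTForImageClassificationWithTeacher)))
    # -> ('DeiTForImageClassification', 'DeiTForImageClassificationWithTeacher')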