mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Fix add-new-model-like when old model checkpoint is not found (#15805)
* Fix add-new-model-like command when old checkpoint can't be recovered
* Style
This commit is contained in:
parent
bb7949b35a
commit
7f921bcf47
1 changed file with 28 additions and 2 deletions
|
|
@ -1115,6 +1115,7 @@ def create_new_model_like(
|
|||
new_model_patterns: ModelPatterns,
|
||||
add_copied_from: bool = True,
|
||||
frameworks: Optional[List[str]] = None,
|
||||
old_checkpoint: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Creates a new model module like a given model of the Transformers library.
|
||||
|
|
@ -1126,11 +1127,22 @@ def create_new_model_like(
|
|||
Whether or not to add "Copied from" statements to all classes in the new model modeling files.
|
||||
frameworks (`List[str]`, *optional*):
|
||||
If passed, will limit the duplicate to the frameworks specified.
|
||||
old_checkpoint (`str`, *optional*):
|
||||
The name of the base checkpoint for the old model. Should be passed along when it can't be automatically
|
||||
recovered from the `model_type`.
|
||||
"""
|
||||
# Retrieve all the old model info.
|
||||
model_info = retrieve_info_for_model(model_type, frameworks=frameworks)
|
||||
model_files = model_info["model_files"]
|
||||
old_model_patterns = model_info["model_patterns"]
|
||||
if old_checkpoint is not None:
|
||||
old_model_patterns.checkpoint = old_checkpoint
|
||||
if len(old_model_patterns.checkpoint) == 0:
|
||||
raise ValueError(
|
||||
"The old model checkpoint could not be recovered from the model type. Please pass it to the "
|
||||
"`old_checkpoint` argument."
|
||||
)
|
||||
|
||||
keep_old_processing = True
|
||||
for processing_attr in ["feature_extractor_class", "processor_class", "tokenizer_class"]:
|
||||
if getattr(old_model_patterns, processing_attr) != getattr(new_model_patterns, processing_attr):
|
||||
|
|
@ -1291,8 +1303,15 @@ class AddNewModelLikeCommand(BaseTransformersCLICommand):
|
|||
self.model_patterns = ModelPatterns(**config["new_model_patterns"])
|
||||
self.add_copied_from = config.get("add_copied_from", True)
|
||||
self.frameworks = config.get("frameworks", ["pt", "tf", "flax"])
|
||||
self.old_checkpoint = config.get("old_checkpoint", None)
|
||||
else:
|
||||
self.old_model_type, self.model_patterns, self.add_copied_from, self.frameworks = get_user_input()
|
||||
(
|
||||
self.old_model_type,
|
||||
self.model_patterns,
|
||||
self.add_copied_from,
|
||||
self.frameworks,
|
||||
self.old_checkpoint,
|
||||
) = get_user_input()
|
||||
|
||||
self.path_to_repo = path_to_repo
|
||||
|
||||
|
|
@ -1310,6 +1329,7 @@ class AddNewModelLikeCommand(BaseTransformersCLICommand):
|
|||
new_model_patterns=self.model_patterns,
|
||||
add_copied_from=self.add_copied_from,
|
||||
frameworks=self.frameworks,
|
||||
old_checkpoint=self.old_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1402,6 +1422,12 @@ def get_user_input():
|
|||
old_processor_class = old_model_info["model_patterns"].processor_class
|
||||
old_frameworks = old_model_info["frameworks"]
|
||||
|
||||
old_checkpoint = None
|
||||
if len(old_model_info["model_patterns"].checkpoint) == 0:
|
||||
old_checkpoint = get_user_field(
|
||||
"We couldn't find the name of the base checkpoint for that model, please enter it here."
|
||||
)
|
||||
|
||||
model_name = get_user_field("What is the name for your new model?")
|
||||
default_patterns = ModelPatterns(model_name, model_name)
|
||||
|
||||
|
|
@ -1497,4 +1523,4 @@ def get_user_input():
|
|||
)
|
||||
frameworks = list(set(frameworks.split(" ")))
|
||||
|
||||
return (old_model_type, model_patterns, add_copied_from, frameworks)
|
||||
return (old_model_type, model_patterns, add_copied_from, frameworks, old_checkpoint)
|
||||
|
|
|
|||
Loading…
Reference in a new issue