mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Setup loss_type in config at model init time (#34616)
* setup loss_type in config at model init time

  ensures no additional graph break introduced when torch.compile'ed

  fixes #34615

  Signed-off-by: ChanderG <mail@chandergovind.org>

* lookup loss mapping at init time instead of manual setup

  Signed-off-by: ChanderG <mail@chandergovind.org>

* remove redundant lookup at loss_function time

* overwride losstype at init time

---------

Signed-off-by: ChanderG <mail@chandergovind.org>
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
parent
c8ab6ce6ce
commit
4adc415b6d
1 changed file with 14 additions and 12 deletions
|
|
@ -1319,6 +1319,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
)
|
||||
self.config = config
|
||||
|
||||
# for initialization of the loss
|
||||
loss_type = self.__class__.__name__
|
||||
if loss_type not in LOSS_MAPPING:
|
||||
loss_groups = f"({'|'.join(LOSS_MAPPING)})"
|
||||
loss_type = re.findall(loss_groups, self.__class__.__name__)
|
||||
if len(loss_type) > 0:
|
||||
loss_type = loss_type[0]
|
||||
else:
|
||||
loss_type = None
|
||||
self.loss_type = loss_type
|
||||
|
||||
self.name_or_path = config.name_or_path
|
||||
self.warnings_issued = {}
|
||||
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
|
||||
|
|
@ -5110,18 +5121,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
|
||||
@property
|
||||
def loss_function(self):
|
||||
if getattr(self.config, "loss_type", None) is not None:
|
||||
loss_type = self.config.loss_type
|
||||
else:
|
||||
loss_type = self.__class__.__name__
|
||||
if loss_type not in LOSS_MAPPING:
|
||||
loss_groups = f"({'|'.join(LOSS_MAPPING)})"
|
||||
loss_type = re.findall(loss_groups, self.__class__.__name__)
|
||||
if len(loss_type) > 0:
|
||||
loss_type = loss_type[0]
|
||||
else:
|
||||
loss_type = None
|
||||
if loss_type is None or loss_type not in LOSS_MAPPING and getattr(self.config, "loss_type", None) is not None:
|
||||
loss_type = getattr(self, "loss_type", None)
|
||||
|
||||
if loss_type is None or loss_type not in LOSS_MAPPING:
|
||||
logger.warning_once(
|
||||
f"`loss_type={loss_type}` was set in the config but it is unrecognised."
|
||||
f"Using the default loss: `ForCausalLMLoss`."
|
||||
|
|
|
|||
Loading…
Reference in a new issue