From 4adc415b6de5b899a48379b6cba01bd930cb587d Mon Sep 17 00:00:00 2001
From: Chander G
Date: Thu, 9 Jan 2025 18:02:21 +0530
Subject: [PATCH] Setup loss_type in config at model init time (#34616)

* setup loss_type in config at model init time

ensures no additional graph break introduced when torch.compile'ed

fixes #34615

Signed-off-by: ChanderG

* lookup loss mapping at init time instead of manual setup

Signed-off-by: ChanderG

* remove redundant lookup at loss_function time

* overwride losstype at init time

---------

Signed-off-by: ChanderG
Co-authored-by: Arthur Zucker
---
 src/transformers/modeling_utils.py | 26 ++++++++++++++------------
 1 file changed, 14 insertions(+), 12 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 40892f0cd..8eb2d7439 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1319,6 +1319,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         )
         self.config = config
 
+        # for initialization of the loss
+        loss_type = self.__class__.__name__
+        if loss_type not in LOSS_MAPPING:
+            loss_groups = f"({'|'.join(LOSS_MAPPING)})"
+            loss_type = re.findall(loss_groups, self.__class__.__name__)
+            if len(loss_type) > 0:
+                loss_type = loss_type[0]
+            else:
+                loss_type = None
+        self.loss_type = loss_type
+
         self.name_or_path = config.name_or_path
         self.warnings_issued = {}
         self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
@@ -5110,18 +5121,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 
     @property
     def loss_function(self):
-        if getattr(self.config, "loss_type", None) is not None:
-            loss_type = self.config.loss_type
-        else:
-            loss_type = self.__class__.__name__
-            if loss_type not in LOSS_MAPPING:
-                loss_groups = f"({'|'.join(LOSS_MAPPING)})"
-                loss_type = re.findall(loss_groups, self.__class__.__name__)
-                if len(loss_type) > 0:
-                    loss_type = loss_type[0]
-                else:
-                    loss_type = None
-        if loss_type is None or loss_type not in LOSS_MAPPING and getattr(self.config, "loss_type", None) is not None:
+        loss_type = getattr(self, "loss_type", None)
+
+        if loss_type is None or loss_type not in LOSS_MAPPING:
             logger.warning_once(
                 f"`loss_type={loss_type}` was set in the config but it is unrecognised."
                 f"Using the default loss: `ForCausalLMLoss`."