Setup loss_type in config at model init time (#34616)

* setup loss_type in config at model init time

ensures no additional graph break introduced when torch.compile'ed

fixes #34615

Signed-off-by: ChanderG <mail@chandergovind.org>

* lookup loss mapping at init time instead of manual setup

Signed-off-by: ChanderG <mail@chandergovind.org>

* remove redundant lookup at loss_function time

* overwride losstype at init time

---------

Signed-off-by: ChanderG <mail@chandergovind.org>
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
Chander G 2025-01-09 18:02:21 +05:30 committed by GitHub
parent c8ab6ce6ce
commit 4adc415b6d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1319,6 +1319,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
self.config = config
# for initialization of the loss
loss_type = self.__class__.__name__
if loss_type not in LOSS_MAPPING:
loss_groups = f"({'|'.join(LOSS_MAPPING)})"
loss_type = re.findall(loss_groups, self.__class__.__name__)
if len(loss_type) > 0:
loss_type = loss_type[0]
else:
loss_type = None
self.loss_type = loss_type
self.name_or_path = config.name_or_path
self.warnings_issued = {}
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
@ -5110,18 +5121,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
@property
def loss_function(self):
if getattr(self.config, "loss_type", None) is not None:
loss_type = self.config.loss_type
else:
loss_type = self.__class__.__name__
if loss_type not in LOSS_MAPPING:
loss_groups = f"({'|'.join(LOSS_MAPPING)})"
loss_type = re.findall(loss_groups, self.__class__.__name__)
if len(loss_type) > 0:
loss_type = loss_type[0]
else:
loss_type = None
if loss_type is None or loss_type not in LOSS_MAPPING and getattr(self.config, "loss_type", None) is not None:
loss_type = getattr(self, "loss_type", None)
if loss_type is None or loss_type not in LOSS_MAPPING:
logger.warning_once(
f"`loss_type={loss_type}` was set in the config but it is unrecognised."
f"Using the default loss: `ForCausalLMLoss`."