be permissive

This commit is contained in:
Arthur Zucker 2024-12-12 11:33:37 +01:00
parent 584b443096
commit 95cb944ee6

View file

@ -1493,7 +1493,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
message += (
', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
)
raise ValueError(message + ".")
if config._attn_implementation in ALL_ATTENTION_FUNCTIONS:
pass
else:
raise ValueError(message + ".")
# If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
requested_attn_implementation = config._attn_implementation_internal