diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 272b4af5a..b31d3b8d4 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1493,7 +1493,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     message += (
                         ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
                     )
-                raise ValueError(message + ".")
+                if config._attn_implementation in ALL_ATTENTION_FUNCTIONS:
+                    pass
+                else:
+                    raise ValueError(message + ".")
 
         # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
         requested_attn_implementation = config._attn_implementation_internal
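
For context on what the added branch does (this is not part of the patch): before the change, reaching this point raised unconditionally; after it, the code first consults ALL_ATTENTION_FUNCTIONS, so an attention implementation registered in that mapping is accepted instead of rejected. Below is a minimal, self-contained sketch of that control flow. ALL_ATTENTION_FUNCTIONS is modeled as a plain dict, and the helper name check_attn_implementation and the "my_custom_attention" entry are illustrative assumptions, not transformers API.

    # Minimal sketch of the patched control flow, with a plain dict standing
    # in for transformers' registry of attention implementations.
    ALL_ATTENTION_FUNCTIONS = {
        "sdpa": object(),                 # stand-in for a real attention function
        "my_custom_attention": object(),  # e.g. a user-registered implementation
    }

    def check_attn_implementation(requested):
        # Mirrors the patched check: only raise when the requested
        # implementation is absent from the registry.
        message = f'Specified `attn_implementation="{requested}"` is not supported'
        if requested in ALL_ATTENTION_FUNCTIONS:
            return  # patched behavior: a registered implementation passes
        raise ValueError(message + ".")

    check_attn_implementation("my_custom_attention")  # accepted after the patch
    try:
        check_attn_implementation("unknown_attention")
    except ValueError as err:
        print(err)  # still raises for genuinely unknown implementations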