Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-15 21:01:19 +00:00
[T5] Fix speed degradation bug t5 (#10496)
* fix speed degradation bug t5
* fix for all models
* fix code quality
parent 5dc303e281
commit 2d2ed2cc18
9 changed files with 30 additions and 11 deletions
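The change is the same in every file: the inf/nan clamp that was added for fp16 training is now guarded by a dtype check, so full-precision forward passes no longer run the torch.isinf / torch.isnan scan (and pay for the host-device synchronization that evaluating its result in a Python `if` forces) on every layer. Below is a minimal sketch of the pattern, not code from the repository; the helper name clamp_fp16_overflow is hypothetical, since the models inline this logic directly in their forward methods.

import torch

def clamp_fp16_overflow(hidden_states: torch.Tensor) -> torch.Tensor:
    """Clamp inf/nan values that can appear when training in float16.

    The dtype guard is the point of this commit: the isinf/isnan scan and the
    sync it triggers now run only when the tensor is actually fp16, where
    overflow is possible, instead of on every forward pass.
    """
    if hidden_states.dtype == torch.float16 and (
        torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
    ):
        # Pull values just inside the representable fp16 range.
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
    return hidden_states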
@@ -319,7 +319,9 @@ class BartEncoderLayer(nn.Module):
         hidden_states = residual + hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
 
-        if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
+        if hidden_states.dtype == torch.float16 and (
+            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+        ):
             clamp_value = torch.finfo(hidden_states.dtype).max - 1000
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
@@ -322,7 +322,9 @@ class BlenderbotEncoderLayer(nn.Module):
         hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
         hidden_states = residual + hidden_states
 
-        if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
+        if hidden_states.dtype == torch.float16 and (
+            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+        ):
             clamp_value = torch.finfo(hidden_states.dtype).max - 1000
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
@@ -320,7 +320,9 @@ class BlenderbotSmallEncoderLayer(nn.Module):
         hidden_states = residual + hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
 
-        if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
+        if hidden_states.dtype == torch.float16 and (
+            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+        ):
             clamp_value = torch.finfo(hidden_states.dtype).max - 1000
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
@@ -925,7 +925,9 @@ class LEDEncoderLayer(nn.Module):
         hidden_states = residual + hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
 
-        if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
+        if hidden_states.dtype == torch.float16 and (
+            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+        ):
             clamp_value = torch.finfo(hidden_states.dtype).max - 1000
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
         return (hidden_states,) + attn_outputs[1:]
@@ -337,7 +337,9 @@ class MarianEncoderLayer(nn.Module):
         hidden_states = residual + hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
 
-        if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
+        if hidden_states.dtype == torch.float16 and (
+            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+        ):
             clamp_value = torch.finfo(hidden_states.dtype).max - 1000
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
@@ -326,7 +326,9 @@ class MBartEncoderLayer(nn.Module):
         hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
         hidden_states = residual + hidden_states
 
-        if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
+        if hidden_states.dtype == torch.float16 and (
+            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+        ):
             clamp_value = torch.finfo(hidden_states.dtype).max - 1000
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
@@ -337,7 +337,9 @@ class PegasusEncoderLayer(nn.Module):
         hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
         hidden_states = residual + hidden_states
 
-        if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
+        if hidden_states.dtype == torch.float16 and (
+            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+        ):
             clamp_value = torch.finfo(hidden_states.dtype).max - 1000
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
@@ -643,7 +643,7 @@ class T5Block(nn.Module):
         attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights
 
         # clamp inf values to enable fp16 training
-        if torch.isinf(hidden_states).any():
+        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
             clamp_value = torch.finfo(hidden_states.dtype).max - 1000
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
@@ -668,7 +668,9 @@ class T5Block(nn.Module):
                 output_attentions=output_attentions,
             )
             hidden_states = cross_attention_outputs[0]
-            if torch.isinf(hidden_states).any():
+
+            # clamp inf values to enable fp16 training
+            if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
                 clamp_value = torch.finfo(hidden_states.dtype).max - 1000
                 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
@@ -681,9 +683,12 @@ class T5Block(nn.Module):
 
         # Apply Feed Forward layer
         hidden_states = self.layer[-1](hidden_states)
-        if torch.isinf(hidden_states).any():
+
+        # clamp inf values to enable fp16 training
+        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
             clamp_value = torch.finfo(hidden_states.dtype).max - 1000
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
 
         outputs = (hidden_states,)
 
         outputs = outputs + (present_key_value_state,) + attention_outputs
@@ -1824,7 +1824,7 @@ class {{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
         hidden_states = residual + hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
 
-        if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
+        if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
             clamp_value = torch.finfo(hidden_states.dtype).max - 1000
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)