fix module order (#18312)

- put gelu before 4h to h
This commit is contained in:
Younes Belkada 2022-07-27 13:06:01 +02:00 committed by GitHub
parent 70e7d1d656
commit 7996ef74dd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -352,9 +352,9 @@ class BloomMLP(nn.Module):
self.pretraining_tp = config.pretraining_tp
self.slow_but_exact = config.slow_but_exact
self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
self.gelu_impl = BloomGelu()
self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
self.hidden_dropout = config.hidden_dropout
self.gelu_impl = BloomGelu()
def forward(self, hidden_states, residual):
hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))