mirror of https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Correct NATTEN function signatures and force new version (#22298)
commit 5990743fdd
parent d35f729649

4 changed files with 4 additions and 4 deletions
setup.py
@@ -129,7 +129,7 @@ _deps = [
     "keras-nlp>=0.3.1",
     "librosa",
     "nltk",
-    "natten>=0.14.5",
+    "natten>=0.14.6",
     "numpy>=1.17",
     "onnxconverter-common",
     "onnxruntime-tools>=1.4.2",
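The setup.py change simply raises the natten floor from 0.14.5 to 0.14.6, the release that made kernel_size an explicit argument to the NATTEN functions (see the modeling hunks below). As a minimal sketch, not part of this commit, an environment can be checked against the new pin using only the standard library plus packaging:

# Sketch: verify the installed natten satisfies the new pin from setup.py.
from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version  # ships with pip/setuptools environments

try:
    installed = Version(version("natten"))
    if installed < Version("0.14.6"):
        raise RuntimeError(f"natten {installed} is too old; this commit requires >= 0.14.6")
except PackageNotFoundError:
    print('natten not installed; run: pip install "natten>=0.14.6"')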
src/transformers/dependency_versions_table.py
@@ -35,7 +35,7 @@ deps = {
     "keras-nlp": "keras-nlp>=0.3.1",
     "librosa": "librosa",
     "nltk": "nltk",
-    "natten": "natten>=0.14.5",
+    "natten": "natten>=0.14.6",
     "numpy": "numpy>=1.17",
     "onnxconverter-common": "onnxconverter-common",
     "onnxruntime-tools": "onnxruntime-tools>=1.4.2",
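dependency_versions_table.py mirrors the pins in setup.py so they can be enforced at runtime. A minimal sketch of that enforcement, assuming the usual transformers helper (require_version raises an informative ImportError when the pin is not met):

# Sketch: runtime enforcement of the pin kept in dependency_versions_table.py.
from transformers.utils.versions import require_version

# Raises ImportError with the hint if the installed natten is older than the pin.
require_version("natten>=0.14.6", "To fix: pip install -U natten")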
src/transformers/models/dinat/modeling_dinat.py
@@ -356,7 +356,7 @@ class NeighborhoodAttention(nn.Module):
         # seem a bit unusual, but is taken from the original Transformer paper.
         attention_probs = self.dropout(attention_probs)

-        context_layer = natten2dav(attention_probs, value_layer, self.dilation)
+        context_layer = natten2dav(attention_probs, value_layer, self.kernel_size, self.dilation)
         context_layer = context_layer.permute(0, 2, 3, 1, 4).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(new_context_layer_shape)
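This is the actual signature fix: in natten >= 0.14.6, natten2dav takes kernel_size explicitly rather than inferring it, so the call gains a fourth argument. A standalone sketch of the new call, with illustrative tensor sizes following the NAT/DiNAT shape convention (the sizes are assumptions, not from the commit):

# Sketch: calling natten2dav with the post-0.14.6 signature.
import torch
from natten.functional import natten2dav  # requires natten >= 0.14.6

batch, heads, height, width, head_dim = 2, 4, 14, 14, 32
kernel_size, dilation = 7, 1

# Attention weights over each position's kernel_size**2 neighborhood.
attention_probs = torch.randn(batch, heads, height, width, kernel_size**2).softmax(dim=-1)
value_layer = torch.randn(batch, heads, height, width, head_dim)

# New signature: kernel_size is passed explicitly, before dilation.
context_layer = natten2dav(attention_probs, value_layer, kernel_size, dilation)
print(context_layer.shape)  # torch.Size([2, 4, 14, 14, 32])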
src/transformers/models/nat/modeling_nat.py
@@ -348,7 +348,7 @@ class NeighborhoodAttention(nn.Module):
         # seem a bit unusual, but is taken from the original Transformer paper.
         attention_probs = self.dropout(attention_probs)

-        context_layer = natten2dav(attention_probs, value_layer, 1)
+        context_layer = natten2dav(attention_probs, value_layer, self.kernel_size, 1)
         context_layer = context_layer.permute(0, 2, 3, 1, 4).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(new_context_layer_shape)
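The NAT model gets the same fix, with dilation hard-coded to 1 since NAT, unlike DiNAT, does not dilate its neighborhoods. The three context lines after the call are plain shape bookkeeping that merges the attention heads back together; a self-contained sketch with illustrative sizes:

# Sketch: the head-merging reshape that follows natten2dav.
import torch

batch, heads, height, width, head_dim = 2, 4, 14, 14, 32
all_head_size = heads * head_dim

context_layer = torch.randn(batch, heads, height, width, head_dim)
# (B, heads, H, W, D) -> (B, H, W, heads, D), then flatten the last two dims.
context_layer = context_layer.permute(0, 2, 3, 1, 4).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
print(context_layer.shape)  # torch.Size([2, 14, 14, 128])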