mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Fixing conversation test for torch 1.8 (#10545)
This commit is contained in:
parent
dc9aaa3848
commit
54e55b52d4
1 changed files with 2 additions and 2 deletions
|
|
@ -53,9 +53,9 @@ class SimpleConversationPipelineTests(unittest.TestCase):
|
|||
model = GPT2LMHeadModel(config)
|
||||
# Force model output to be L
|
||||
V, D = model.lm_head.weight.shape
|
||||
bias = torch.zeros(V, requires_grad=True)
|
||||
weight = torch.zeros((V, D), requires_grad=True)
|
||||
bias = torch.zeros(V)
|
||||
bias[76] = 1
|
||||
weight = torch.zeros((V, D), requires_grad=True)
|
||||
|
||||
model.lm_head.bias = torch.nn.Parameter(bias)
|
||||
model.lm_head.weight = torch.nn.Parameter(weight)
|
||||
|
|
|
|||
Loading…
Reference in a new issue