Fixing conversation test for torch 1.8 (#10545)

This commit is contained in:
Nicolas Patry 2021-03-05 15:24:14 +01:00 committed by GitHub
parent dc9aaa3848
commit 54e55b52d4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -53,9 +53,9 @@ class SimpleConversationPipelineTests(unittest.TestCase):
model = GPT2LMHeadModel(config)
# Force model output to be L
V, D = model.lm_head.weight.shape
bias = torch.zeros(V, requires_grad=True)
weight = torch.zeros((V, D), requires_grad=True)
bias = torch.zeros(V)
bias[76] = 1
weight = torch.zeros((V, D), requires_grad=True)
model.lm_head.bias = torch.nn.Parameter(bias)
model.lm_head.weight = torch.nn.Parameter(weight)