From 54e55b52d4886d4c63e592310b4253e01c606285 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 5 Mar 2021 15:24:14 +0100 Subject: [PATCH] Fixing conversation test for torch 1.8 (#10545) --- tests/test_pipelines_conversational.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_pipelines_conversational.py b/tests/test_pipelines_conversational.py index 4ea4d808a..4860fce72 100644 --- a/tests/test_pipelines_conversational.py +++ b/tests/test_pipelines_conversational.py @@ -53,9 +53,9 @@ class SimpleConversationPipelineTests(unittest.TestCase): model = GPT2LMHeadModel(config) # Force model output to be L V, D = model.lm_head.weight.shape - bias = torch.zeros(V, requires_grad=True) - weight = torch.zeros((V, D), requires_grad=True) + bias = torch.zeros(V) bias[76] = 1 + weight = torch.zeros((V, D), requires_grad=True) model.lm_head.bias = torch.nn.Parameter(bias) model.lm_head.weight = torch.nn.Parameter(weight)