diff --git a/tests/test_pipelines_conversational.py b/tests/test_pipelines_conversational.py index 4ea4d808a..4860fce72 100644 --- a/tests/test_pipelines_conversational.py +++ b/tests/test_pipelines_conversational.py @@ -53,9 +53,9 @@ class SimpleConversationPipelineTests(unittest.TestCase): model = GPT2LMHeadModel(config) # Force model output to be L V, D = model.lm_head.weight.shape - bias = torch.zeros(V, requires_grad=True) - weight = torch.zeros((V, D), requires_grad=True) + bias = torch.zeros(V) bias[76] = 1 + weight = torch.zeros((V, D), requires_grad=True) model.lm_head.bias = torch.nn.Parameter(bias) model.lm_head.weight = torch.nn.Parameter(weight)