diff --git a/src/transformers/integrations/fbgemm_fp8.py b/src/transformers/integrations/fbgemm_fp8.py
index a0f5b2b76..71c2b570c 100644
--- a/src/transformers/integrations/fbgemm_fp8.py
+++ b/src/transformers/integrations/fbgemm_fp8.py
@@ -45,6 +45,8 @@ class FbgemmFp8Linear(torch.nn.Module):
 
     def forward(self, x):
         num_tokens = None
+        # quantize_fp8_per_row will squash the leading dimensions, so save the desired shape here
+        output_shape = (*x.shape[:-1], -1)
         # x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
         # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
         x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
@@ -60,6 +62,7 @@ class FbgemmFp8Linear(torch.nn.Module):
         output = output + self.bias if self.bias is not None else output
         # Hacky for now, we have the output to the device of x
         output = output.to(x.device)
+        output = output.reshape(output_shape)
         del x_quantized, x_scale
         return output
 
diff --git a/tests/quantization/fbgemm_fp8/test_fbgemm_fp8.py b/tests/quantization/fbgemm_fp8/test_fbgemm_fp8.py
index 61a1eecba..a9ff650c0 100644
--- a/tests/quantization/fbgemm_fp8/test_fbgemm_fp8.py
+++ b/tests/quantization/fbgemm_fp8/test_fbgemm_fp8.py
@@ -268,3 +268,34 @@ class FbgemmFp8Test(unittest.TestCase):
         output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
 
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+
+
+@require_torch_gpu
+@require_accelerate
+@require_fbgemm_gpu
+class FbgemmFp8LinearTest(unittest.TestCase):
+    def test_linear_preserves_shape(self):
+        """
+        Test that FbgemmFp8Linear preserves shape when in_features == out_features.
+        """
+        from transformers.integrations import FbgemmFp8Linear
+
+        with init_empty_weights(include_buffers=True):
+            linear = FbgemmFp8Linear(1024, 1024, True)
+            x = torch.rand((17, 23, 1024))
+
+            x_ = linear(x)
+            self.assertEqual(x_.shape, x.shape)
+
+    def test_linear_with_diff_feature_size_preserves_shape(self):
+        """
+        Test that FbgemmFp8Linear generates the correct shape when in_features != out_features.
+        """
+        from transformers.integrations import FbgemmFp8Linear
+
+        with init_empty_weights(include_buffers=True):
+            linear = FbgemmFp8Linear(1024, 2048, True)
+            x = torch.rand((17, 23, 1024))
+
+            x_ = linear(x)
+            self.assertEqual(x_.shape, (17, 23, 2048))