From e719b65c31e48c07e78dea27bc28aaaebf69c16e Mon Sep 17 00:00:00 2001 From: Theia Vogel Date: Wed, 11 Sep 2024 04:26:44 -0700 Subject: [PATCH] Fix `FbgemmFp8Linear` not preserving tensor shape (#33239) * add tests for linear shape behavior * fix linear shape behavior ended up adding the reshape at the end, after f8f8bf16_rowwise, because adding it directly after quantize_fp8_per_row caused f8f8bf16_rowwise to drop the seq_len dimension. (i.e., (17, 23, 1014) -> (17, 1024)) * save shape up front + comment --- src/transformers/integrations/fbgemm_fp8.py | 3 ++ .../fbgemm_fp8/test_fbgemm_fp8.py | 31 +++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/src/transformers/integrations/fbgemm_fp8.py b/src/transformers/integrations/fbgemm_fp8.py index a0f5b2b76..71c2b570c 100644 --- a/src/transformers/integrations/fbgemm_fp8.py +++ b/src/transformers/integrations/fbgemm_fp8.py @@ -45,6 +45,8 @@ class FbgemmFp8Linear(torch.nn.Module): def forward(self, x): num_tokens = None + # quantize_fp8_per_row will squash the leading dimensions, so save the desired shape here + output_shape = (*x.shape[:-1], -1) # x_quantized and x_scale are not necessarily on the same device as x, this is an issue. # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45 x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row( @@ -60,6 +62,7 @@ class FbgemmFp8Linear(torch.nn.Module): output = output + self.bias if self.bias is not None else output # Hacky for now, we have the output to the device of x output = output.to(x.device) + output = output.reshape(output_shape) del x_quantized, x_scale return output diff --git a/tests/quantization/fbgemm_fp8/test_fbgemm_fp8.py b/tests/quantization/fbgemm_fp8/test_fbgemm_fp8.py index 61a1eecba..a9ff650c0 100644 --- a/tests/quantization/fbgemm_fp8/test_fbgemm_fp8.py +++ b/tests/quantization/fbgemm_fp8/test_fbgemm_fp8.py @@ -268,3 +268,34 @@ class FbgemmFp8Test(unittest.TestCase): output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens) self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + +@require_torch_gpu +@require_accelerate +@require_fbgemm_gpu +class FbgemmFp8LinearTest(unittest.TestCase): + def test_linear_preserves_shape(self): + """ + Test that FbgemmFp8Linear preserves shape when in_features == out_features. + """ + from transformers.integrations import FbgemmFp8Linear + + with init_empty_weights(include_buffers=True): + linear = FbgemmFp8Linear(1024, 1024, True) + x = torch.rand((17, 23, 1024)) + + x_ = linear(x) + self.assertEqual(x_.shape, x.shape) + + def test_linear_with_diff_feature_size_preserves_shape(self): + """ + Test that FbgemmFp8Linear generates the correct shape when in_features != out_features. + """ + from transformers.integrations import FbgemmFp8Linear + + with init_empty_weights(include_buffers=True): + linear = FbgemmFp8Linear(1024, 2048, True) + x = torch.rand((17, 23, 1024)) + + x_ = linear(x) + self.assertEqual(x_.shape, (17, 23, 2048))