fix style

This commit is contained in:
MekkCyber 2025-02-08 11:23:12 +00:00
parent 33f73712dc
commit 3ead98b2f6
2 changed files with 2 additions and 2 deletions

View file

@@ -28,7 +28,6 @@ if is_accelerate_available():
from accelerate import init_empty_weights
logger = logging.get_logger(__name__)

View file

@@ -33,6 +33,7 @@ if is_torch_available():
if is_accelerate_available():
from accelerate import init_empty_weights
@require_torch_gpu
class FP8ConfigTest(unittest.TestCase):
def test_to_dict(self):
@@ -200,6 +201,7 @@ class FP8QuantizerTest(unittest.TestCase):
@require_torch_gpu
class FP8LinearTest(unittest.TestCase):
device = "cuda"
def test_linear_preserves_shape(self):
"""
Test that FP8Linear preserves shape when in_features == out_features.
@@ -218,7 +220,6 @@ class FP8LinearTest(unittest.TestCase):
"""
from transformers.integrations import FP8Linear
linear = FP8Linear(128, 256, device=self.device)
x = torch.rand((1, 5, 128)).to(self.device)