mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
fix style
This commit is contained in:
parent
33f73712dc
commit
3ead98b2f6
2 changed files with 2 additions and 2 deletions
|
|
@ -28,7 +28,6 @@ if is_accelerate_available():
|
||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,7 @@ if is_torch_available():
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
|
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
class FP8ConfigTest(unittest.TestCase):
|
class FP8ConfigTest(unittest.TestCase):
|
||||||
def test_to_dict(self):
|
def test_to_dict(self):
|
||||||
|
|
@ -200,6 +201,7 @@ class FP8QuantizerTest(unittest.TestCase):
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
class FP8LinearTest(unittest.TestCase):
|
class FP8LinearTest(unittest.TestCase):
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
|
|
||||||
def test_linear_preserves_shape(self):
|
def test_linear_preserves_shape(self):
|
||||||
"""
|
"""
|
||||||
Test that FP8Linear preserves shape when in_features == out_features.
|
Test that FP8Linear preserves shape when in_features == out_features.
|
||||||
|
|
@ -218,7 +220,6 @@ class FP8LinearTest(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
from transformers.integrations import FP8Linear
|
from transformers.integrations import FP8Linear
|
||||||
|
|
||||||
|
|
||||||
linear = FP8Linear(128, 256, device=self.device)
|
linear = FP8Linear(128, 256, device=self.device)
|
||||||
x = torch.rand((1, 5, 128)).to(self.device)
|
x = torch.rand((1, 5, 128)).to(self.device)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue