mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Update test_tp.py
This commit is contained in:
parent
8c419c6fa9
commit
a3a55d06df
1 changed files with 2 additions and 2 deletions
|
|
@ -84,8 +84,8 @@ class TestTensorParallel(TestCasePlus):
|
|||
expected_model_memory = 16
|
||||
overhead_factor = 1.2
|
||||
|
||||
# # Check that we do not use more than the expected sharded size during initialization
|
||||
if not torch.cuda.max_memory_allocated(device) / 1024**3 < (expected_model_memory / world_size) * overhead_factor:
|
||||
# Check that we do not use more than the expected sharded size during initialization
|
||||
if torch.cuda.max_memory_allocated(device) / 1024**3 > (expected_model_memory / world_size) * overhead_factor:
|
||||
raise ValueError("Loading the model used more than the expected fraction of model size per device")
|
||||
|
||||
torch.distributed.barrier()
|
||||
|
|
|
|||
Loading…
Reference in a new issue