Update test_tp.py

This commit is contained in:
Cyril Vallez 2025-01-31 22:41:02 +01:00
parent 8c419c6fa9
commit a3a55d06df
No known key found for this signature in database

View file

@ -84,8 +84,8 @@ class TestTensorParallel(TestCasePlus):
expected_model_memory = 16
overhead_factor = 1.2
# # Check that we do not use more than the expected sharded size during initialization
if not torch.cuda.max_memory_allocated(device) / 1024**3 < (expected_model_memory / world_size) * overhead_factor:
# Check that we do not use more than the expected sharded size during initialization
if torch.cuda.max_memory_allocated(device) / 1024**3 > (expected_model_memory / world_size) * overhead_factor:
raise ValueError("Loading the model used more than the expected fraction of model size per device")
torch.distributed.barrier()