diff --git a/tests/tp/test_tp.py b/tests/tp/test_tp.py index bb431cef9..586de4f5b 100644 --- a/tests/tp/test_tp.py +++ b/tests/tp/test_tp.py @@ -84,8 +84,8 @@ class TestTensorParallel(TestCasePlus): expected_model_memory = 16 overhead_factor = 1.2 - # # Check that we do not use more than the expected sharded size during initialization - if not torch.cuda.max_memory_allocated(device) / 1024**3 < (expected_model_memory / world_size) * overhead_factor: + # Check that we do not use more than the expected sharded size during initialization + if torch.cuda.max_memory_allocated(device) / 1024**3 > (expected_model_memory / world_size) * overhead_factor: raise ValueError("Loading the model used more than the expected fraction of model size per device") torch.distributed.barrier()