Mirror of https://github.com/saymrwulf/transformers.git (synced 2026-05-14 20:58:08 +00:00)
Enable memory metrics in tests that need it (#11859)
This commit is contained in:
parent db0b2477cc
commit 6da129cb31
1 changed file with 3 additions and 3 deletions
@@ -1102,7 +1102,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
     def test_mem_metrics(self):
 
         # with mem metrics enabled
-        trainer = get_regression_trainer()
+        trainer = get_regression_trainer(skip_memory_metrics=False)
         self.check_mem_metrics(trainer, self.assertIn)
 
         # with mem metrics disabled
@@ -1123,7 +1123,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         b = torch.ones(1000, bs) - 0.001
 
         # 1. with mem metrics enabled
-        trainer = get_regression_trainer(a=a, b=b, eval_len=16)
+        trainer = get_regression_trainer(a=a, b=b, eval_len=16, skip_memory_metrics=False)
         metrics = trainer.evaluate()
         del trainer
         gc.collect()
@@ -1144,7 +1144,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertLess(fp32_eval, 5_000)
 
         # 2. with mem metrics disabled
-        trainer = get_regression_trainer(a=a, b=b, eval_len=16, fp16_full_eval=True)
+        trainer = get_regression_trainer(a=a, b=b, eval_len=16, fp16_full_eval=True, skip_memory_metrics=False)
         metrics = trainer.evaluate()
         fp16_init = metrics["init_mem_gpu_alloc_delta"]
         fp16_eval = metrics["eval_mem_gpu_alloc_delta"]
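Each hunk above passes skip_memory_metrics=False to get_regression_trainer(), which suggests the Trainer skips memory tracking unless a test explicitly opts back in. The sketch below is my own illustration of the behavior these tests rely on, not code from this commit: with skip_memory_metrics=False in TrainingArguments, Trainer.evaluate() attaches *_mem_* deltas (such as the "eval_mem_gpu_alloc_delta" key asserted on above, or the *_cpu_* variants on CPU-only machines) to its metrics dict. TinyRegressionDataset, TinyRegressionModel, and the output_dir name are hypothetical stand-ins for the test suite's get_regression_trainer() helper.

# Minimal sketch (assumed setup, not part of this commit): enable memory metrics
# and inspect the *_mem_* keys that Trainer.evaluate() reports.
import torch
from torch.utils.data import Dataset
from transformers import Trainer, TrainingArguments


class TinyRegressionDataset(Dataset):
    # Hypothetical stand-in for the regression dataset used by get_regression_trainer()
    def __init__(self, n=64):
        self.x = torch.randn(n, 1)
        self.y = 2.0 * self.x + 0.1

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        return {"input_x": self.x[i], "labels": self.y[i]}


class TinyRegressionModel(torch.nn.Module):
    # Hypothetical stand-in for the regression model used by get_regression_trainer()
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, input_x, labels=None):
        preds = self.linear(input_x)
        loss = torch.nn.functional.mse_loss(preds, labels) if labels is not None else None
        return {"loss": loss, "logits": preds}


args = TrainingArguments(
    output_dir="tmp_mem_metrics",
    skip_memory_metrics=False,  # opt back in to memory tracking, as the tests above do
    report_to="none",
)
trainer = Trainer(model=TinyRegressionModel(), args=args, eval_dataset=TinyRegressionDataset())
metrics = trainer.evaluate()
# Print only the memory-related keys; on a CUDA machine this includes the
# *_gpu_alloc_delta keys the fp16 test reads in the last hunk.
print({k: v for k, v in metrics.items() if "_mem_" in k})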