From 03b980990a2dba03611f2d89cdac07ea57254d48 Mon Sep 17 00:00:00 2001 From: yuanwu2017 Date: Fri, 5 Jan 2024 19:21:29 +0800 Subject: [PATCH] Don't check the device when device_map=auto (#28351) When running the case on multi-cards server with devcie_map-auto, It will not always be allocated to device 0, Because other processes may be using these cards. It will select the devices that can accommodate this model. Signed-off-by: yuanwu --- tests/pipelines/test_pipelines_text_generation.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index dc77204f3..b80944e80 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -276,7 +276,6 @@ class TextGenerationPipelineTests(unittest.TestCase): model="hf-internal-testing/tiny-random-bloom", model_kwargs={"device_map": "auto", "torch_dtype": torch.bfloat16}, ) - self.assertEqual(pipe.model.device, torch.device(0)) self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16) out = pipe("This is a test") self.assertEqual( @@ -293,7 +292,6 @@ class TextGenerationPipelineTests(unittest.TestCase): # Upgraded those two to real pipeline arguments (they just get sent for the model as they're unlikely to mean anything else.) pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.bfloat16) - self.assertEqual(pipe.model.device, torch.device(0)) self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16) out = pipe("This is a test") self.assertEqual( @@ -310,7 +308,6 @@ class TextGenerationPipelineTests(unittest.TestCase): # torch_dtype will be automatically set to float32 if not provided - check: https://github.com/huggingface/transformers/pull/20602 pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto") - self.assertEqual(pipe.model.device, torch.device(0)) self.assertEqual(pipe.model.lm_head.weight.dtype, torch.float32) out = pipe("This is a test") self.assertEqual(