From 4bb07647504a277398856e828fa48ddbec97678e Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu, 17 Nov 2022 15:59:22 +0100 Subject: [PATCH] refactor test (#20300) - simplifies the devce checking test --- tests/mixed_int8/test_mixed_int8.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index a459ffa84..6e8e7b842 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -215,23 +215,8 @@ class MixedInt8TestMultiGpu(BaseMixedInt8Test): self.model_name, load_in_8bit=True, max_memory=memory_mapping, device_map="auto" ) - def get_list_devices(model): - list_devices = [] - for _, module in model.named_children(): - if len(list(module.children())) > 0: - list_devices.extend(get_list_devices(module)) - else: - # Do a try except since we can encounter Dropout modules that does not - # have any device set - try: - list_devices.append(next(module.parameters()).device.index) - except BaseException: - continue - return list_devices - - list_devices = get_list_devices(model_parallel) - # Check that we have dispatched the model into 2 separate devices - self.assertTrue((1 in list_devices) and (0 in list_devices)) + # Check correct device map + self.assertEqual(set(model_parallel.hf_device_map.values()), {0, 1}) # Check that inference pass works on the model encoded_input = self.tokenizer(self.input_text, return_tensors="pt")