refactor test (#20300)

- simplifies the devce checking test
This commit is contained in:
Younes Belkada 2022-11-17 15:59:22 +01:00 committed by GitHub
parent 700e0cd65f
commit 4bb0764750
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -215,23 +215,8 @@ class MixedInt8TestMultiGpu(BaseMixedInt8Test):
self.model_name, load_in_8bit=True, max_memory=memory_mapping, device_map="auto"
)
def get_list_devices(model):
list_devices = []
for _, module in model.named_children():
if len(list(module.children())) > 0:
list_devices.extend(get_list_devices(module))
else:
# Do a try except since we can encounter Dropout modules that does not
# have any device set
try:
list_devices.append(next(module.parameters()).device.index)
except BaseException:
continue
return list_devices
list_devices = get_list_devices(model_parallel)
# Check that we have dispatched the model into 2 separate devices
self.assertTrue((1 in list_devices) and (0 in list_devices))
# Check correct device map
self.assertEqual(set(model_parallel.hf_device_map.values()), {0, 1})
# Check that inference pass works on the model
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")