mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
parent
700e0cd65f
commit
4bb0764750
1 changed files with 2 additions and 17 deletions
|
|
@ -215,23 +215,8 @@ class MixedInt8TestMultiGpu(BaseMixedInt8Test):
|
|||
self.model_name, load_in_8bit=True, max_memory=memory_mapping, device_map="auto"
|
||||
)
|
||||
|
||||
def get_list_devices(model):
|
||||
list_devices = []
|
||||
for _, module in model.named_children():
|
||||
if len(list(module.children())) > 0:
|
||||
list_devices.extend(get_list_devices(module))
|
||||
else:
|
||||
# Do a try except since we can encounter Dropout modules that does not
|
||||
# have any device set
|
||||
try:
|
||||
list_devices.append(next(module.parameters()).device.index)
|
||||
except BaseException:
|
||||
continue
|
||||
return list_devices
|
||||
|
||||
list_devices = get_list_devices(model_parallel)
|
||||
# Check that we have dispatched the model into 2 separate devices
|
||||
self.assertTrue((1 in list_devices) and (0 in list_devices))
|
||||
# Check correct device map
|
||||
self.assertEqual(set(model_parallel.hf_device_map.values()), {0, 1})
|
||||
|
||||
# Check that inference pass works on the model
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
|
|
|
|||
Loading…
Reference in a new issue