mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
MLU devices : Checks if mlu is available via a cndev-based check which won't trigger the drivers and leaves mlu uninitialized (#34326)
* add Cambricon MLUs support * fix mlu device rng state * up for quality check * up mlu to support fp16 * fix mlu device dependency error * fix mlu device dependency error * enable mlu device for bf16 * fix mlu device memory tracker * Cambricon support SDPA and flash_attn * MLU devices : Checks if `mlu` is available via a `cndev-based` check which won't trigger the drivers and leaves mlu uninitialized
This commit is contained in:
parent
e3a5889ef0
commit
581524389a
1 changed files with 14 additions and 12 deletions
|
|
@ -684,25 +684,27 @@ def is_torch_npu_available(check_device=False):
|
|||
|
||||
@lru_cache()
def is_torch_mlu_available(check_device=False):
    """
    Checks if `mlu` is available via a `cndev-based` check which won't trigger the drivers and leave mlu
    uninitialized.

    Args:
        check_device (`bool`, *optional*, defaults to `False`):
            Kept only for backward compatibility with existing callers. The cndev-based probe below
            already reports device availability without initializing the drivers, so this flag is
            no longer consulted.

    Returns:
        `bool`: `True` if `torch_mlu` is installed and an MLU device is reported available.
    """
    # Bail out early when torch itself or the Cambricon `torch_mlu` plugin is not installed.
    if not _torch_available or importlib.util.find_spec("torch_mlu") is None:
        return False

    import torch
    import torch_mlu  # noqa: F401

    from ..dependency_versions_table import deps

    # Temporarily force the cndev-based availability check so that `torch.mlu.is_available()`
    # does not initialize the MLU drivers as a side effect of being called.
    pytorch_cndev_based_mlu_check_previous_value = os.environ.get("PYTORCH_CNDEV_BASED_MLU_CHECK")
    try:
        os.environ["PYTORCH_CNDEV_BASED_MLU_CHECK"] = str(1)
        available = torch.mlu.is_available()
    finally:
        # Restore the caller's environment exactly. Compare against `None` rather than relying on
        # truthiness: a pre-existing empty-string value must be restored, not silently dropped.
        if pytorch_cndev_based_mlu_check_previous_value is None:
            os.environ.pop("PYTORCH_CNDEV_BASED_MLU_CHECK", None)
        else:
            os.environ["PYTORCH_CNDEV_BASED_MLU_CHECK"] = pytorch_cndev_based_mlu_check_previous_value

    # MLU setups require the Cambricon DeepSpeed fork; record that in the dependency table.
    deps["deepspeed"] = "deepspeed-mlu>=0.10.1"

    return available
|
||||
|
||||
|
||||
@lru_cache()
|
||||
|
|
|
|||
Loading…
Reference in a new issue