MLU devices : Checks if mlu is available via a `cndev-based` check which won't trigger the drivers and leaves mlu uninitialized (#34326)

* add Cambricon MLUs support

* fix mlu device rng state

* up for quality check

* up mlu to support fp16

* fix mlu device dependency error

* fix mlu device dependency error

* enable mlu device for bf16

* fix mlu device memory tracker

* Cambricon support SDPA and flash_attn

* MLU devices : Checks if `mlu` is available via a `cndev-based` check which won't trigger the drivers and leaves mlu uninitialized
This commit is contained in:
huismiling 2024-11-19 23:37:39 +08:00 committed by GitHub
parent e3a5889ef0
commit 581524389a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -684,25 +684,27 @@ def is_torch_npu_available(check_device=False):
@lru_cache()
def is_torch_mlu_available(check_device=False):
    """
    Checks if `mlu` is available via a `cndev-based` check which won't trigger the drivers and leaves mlu
    uninitialized.

    Args:
        check_device (`bool`, *optional*, defaults to `False`):
            Kept for backward compatibility with callers; the cndev-based probe below already
            verifies device availability without initializing the driver, so this flag is unused.

    Returns:
        `bool`: `True` if `torch_mlu` is importable and an MLU device is available, `False` otherwise.
    """
    if not _torch_available or importlib.util.find_spec("torch_mlu") is None:
        return False

    import torch
    import torch_mlu  # noqa: F401

    # Remember any pre-existing value so we can restore the caller's environment afterwards.
    pytorch_cndev_based_mlu_check_previous_value = os.environ.get("PYTORCH_CNDEV_BASED_MLU_CHECK")
    try:
        # Force the cndev-based check so torch.mlu.is_available() does not initialize the drivers.
        os.environ["PYTORCH_CNDEV_BASED_MLU_CHECK"] = str(1)
        available = torch.mlu.is_available()
    finally:
        # Restore (or remove) the variable even if is_available() raised.
        if pytorch_cndev_based_mlu_check_previous_value:
            os.environ["PYTORCH_CNDEV_BASED_MLU_CHECK"] = pytorch_cndev_based_mlu_check_previous_value
        else:
            os.environ.pop("PYTORCH_CNDEV_BASED_MLU_CHECK", None)

    return available
@lru_cache()