mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[inductor] Better messaging when triton version is too old (#130403)
Summary: If triton is available, but we can't import triton.compiler.compiler.triton_key, then we see some annoying behavior: 1) If we don't actually need to compile triton, the subprocess pool will still spew error messages about the import failure; it's unclear to users if this is an actual problem. 2) If we do need to compile triton, we a) see the error messages from above and b) get a vanilla import exception without the helpful "RuntimeError: Cannot find a working triton installation ...". Test Plan: Ran with and without torch.compile for a) a recent version of triton, b) triton 2.2, and c) no triton. In all cases, verified the expected output (success or a meaningful error message). Pull Request resolved: https://github.com/pytorch/pytorch/pull/130403 Approved by: https://github.com/eellison
This commit is contained in:
parent
ceedee23ec
commit
358da54be5
3 changed files with 8 additions and 5 deletions
|
|
@@ -38,6 +38,7 @@ from torch._inductor.runtime.compile_tasks import (
 )

 from torch.hub import _Faketqdm, tqdm
+from torch.utils._triton import has_triton_package

 if TYPE_CHECKING:
     from torch._inductor.runtime.hints import HalideMeta
|
|
@@ -63,8 +64,8 @@ def pre_fork_setup():
         from triton.compiler.compiler import triton_key

         triton_key()
-    except ModuleNotFoundError:
-        # Might not be installed.
+    except ImportError:
+        # Triton might not be installed or might be an old version.
         pass

|
|
@@ -267,6 +268,8 @@ class AsyncCompile:
         if (
             os.environ.get("TORCH_TNT_IN_USE", "0") == "1"
             or os.environ.get("TORCH_WARM_POOL", "1") != "1"
+            # The subprocess pool is only used for the Triton backend
+            or not has_triton_package()
         ):
             pass
         else:
|
|
|
|||
|
|
@@ -2637,7 +2637,7 @@ class Scheduler:
             )
         elif is_gpu(device.type):
             raise RuntimeError(
-                "Cannot find a working triton installation. More information on installing Triton can be found at https://github.com/openai/triton"  # noqa: B950
+                "Cannot find a working triton installation. Either the package is not installed or it is too old. More information on installing Triton can be found at https://github.com/openai/triton"  # noqa: B950
             )

         return device_scheduling(self)
|
|
|
|||
|
|
@@ -6,9 +6,9 @@ import hashlib
 @functools.lru_cache(None)
 def has_triton_package() -> bool:
     try:
-        import triton
+        from triton.compiler.compiler import triton_key

-        return triton is not None
+        return triton_key is not None
     except ImportError:
         return False

|
|
|
|||
Loading…
Reference in a new issue