[inductor] Better messaging when triton version is too old (#130403)

Summary:
If triton is available, but we can't import triton.compiler.compiler.triton_key, then we see some annoying behavior:
1) If we don't actually need to compile triton, the subprocess pool will still spew error messages about the import failure; it's unclear to users if this is an actual problem.
2) If we do need to compile triton, we a) see the error messages from above and b) get a vanilla import exception without the helpful "RuntimeError: Cannot find a working triton installation ..."

Test Plan: Ran with and without torch.compile for a) recent version of triton, b) triton 2.2, and c) no triton. In all cases, verified expected output (success or meaningful error message)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130403
Approved by: https://github.com/eellison
This commit is contained in:
Sam Larsen 2024-07-10 08:51:40 -07:00 committed by PyTorch MergeBot
parent ceedee23ec
commit 358da54be5
3 changed files with 8 additions and 5 deletions

View file

@ -38,6 +38,7 @@ from torch._inductor.runtime.compile_tasks import (
)
from torch.hub import _Faketqdm, tqdm
from torch.utils._triton import has_triton_package
if TYPE_CHECKING:
from torch._inductor.runtime.hints import HalideMeta
@ -63,8 +64,8 @@ def pre_fork_setup():
from triton.compiler.compiler import triton_key
triton_key()
except ModuleNotFoundError:
# Might not be installed.
except ImportError:
# Triton might not be installed or might be an old version.
pass
@ -267,6 +268,8 @@ class AsyncCompile:
if (
os.environ.get("TORCH_TNT_IN_USE", "0") == "1"
or os.environ.get("TORCH_WARM_POOL", "1") != "1"
# The subprocess pool is only used for the Triton backend
or not has_triton_package()
):
pass
else:

View file

@ -2637,7 +2637,7 @@ class Scheduler:
)
elif is_gpu(device.type):
raise RuntimeError(
"Cannot find a working triton installation. More information on installing Triton can be found at https://github.com/openai/triton" # noqa: B950
"Cannot find a working triton installation. Either the package is not installed or it is too old. More information on installing Triton can be found at https://github.com/openai/triton" # noqa: B950
)
return device_scheduling(self)

View file

@ -6,9 +6,9 @@ import hashlib
@functools.lru_cache(None)
def has_triton_package() -> bool:
try:
import triton
from triton.compiler.compiler import triton_key
return triton is not None
return triton_key is not None
except ImportError:
return False