diff --git a/setup.py b/setup.py index 1d6246f63f..e1b265b860 100644 --- a/setup.py +++ b/setup.py @@ -327,7 +327,7 @@ if enable_training: + 'cu' + cuda_version.replace('.', '') else: local_version = '+cu' + cuda_version.replace('.', '') - if rocm_version: + elif rocm_version: # removing '.' to make Cuda version number in the same form as Pytorch. rocm_version = rocm_version.replace('.', '') if torch_version: