diff --git a/setup.py b/setup.py index a3bf000d85..18f5fac3e0 100644 --- a/setup.py +++ b/setup.py @@ -428,7 +428,7 @@ with open(requirements_path) as f: if enable_training: - def save_build_and_package_info(package_name, version_number, cuda_version): + def save_build_and_package_info(package_name, version_number, cuda_version, rocm_version): sys.path.append(path.join(path.dirname(__file__), 'onnxruntime', 'python')) from onnxruntime_collect_build_info import find_cudart_versions @@ -450,11 +450,10 @@ if enable_training: "did not find any cudart library" if not cudart_versions or len(cudart_versions) == 0 else "found multiple cudart libraries") - else: - # TODO: rocm - pass + elif rocm_version: + f.write("rocm_version = '{}'\n".format(rocm_version)) - save_build_and_package_info(package_name, version_number, cuda_version) + save_build_and_package_info(package_name, version_number, cuda_version, rocm_version) # Setup setup(