Liqun/ort training version (#7620)

This commit is contained in:
liqunfu 2021-05-14 09:54:19 -07:00 committed by GitHub
parent bfbcc89db1
commit 359fe1d197
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 223 additions and 12 deletions

View file

@ -10,13 +10,30 @@ or the `Github project <https://github.com/microsoft/onnxruntime/>`_.
__version__ = "1.7.0"
__author__ = "Microsoft"
from onnxruntime.capi._pybind_state import get_all_providers, get_available_providers, get_device, set_seed, \
RunOptions, SessionOptions, set_default_logger_severity, enable_telemetry_events, disable_telemetry_events, \
NodeArg, ModelMetadata, GraphOptimizationLevel, ExecutionMode, ExecutionOrder, OrtDevice, SessionIOBinding, \
OrtAllocatorType, OrtMemType, OrtArenaCfg, OrtMemoryInfo, create_and_register_allocator
# we need to do device version validation (for example to check Cuda version for an onnxruntime-training package).
# in order to know whether the onnxruntime package is for training it needs
# to do import onnxruntime.training.ortmodule first.
# onnxruntime.capi._pybind_state is required before import onnxruntime.training.ortmodule.
# however, import onnxruntime.capi._pybind_state will already raise an exception if a required Cuda version
# is not found.
# here we need to save the exception and continue with Cuda version validation in order to post
# meaningful messages to the user.
# the saved exception is raised after device version validation.
try:
from onnxruntime.capi._pybind_state import get_all_providers, get_available_providers, get_device, set_seed, \
RunOptions, SessionOptions, set_default_logger_severity, enable_telemetry_events, disable_telemetry_events, \
NodeArg, ModelMetadata, GraphOptimizationLevel, ExecutionMode, ExecutionOrder, OrtDevice, SessionIOBinding, \
OrtAllocatorType, OrtMemType, OrtArenaCfg, OrtMemoryInfo, create_and_register_allocator
import_capi_exception = None
except Exception as e:
import_capi_exception = e
from onnxruntime.capi import onnxruntime_validation
if import_capi_exception:
raise import_capi_exception
from onnxruntime.capi.onnxruntime_inference_collection import InferenceSession, IOBinding, OrtValue
from onnxruntime.capi import onnxruntime_validation
from onnxruntime.capi.training import * # noqa: F403
@ -26,4 +43,8 @@ try:
except ImportError:
pass
from onnxruntime.capi.onnxruntime_validation import package_name, version, cuda_version
if version:
__version__ = version
onnxruntime_validation.check_distro_info()

View file

@ -0,0 +1,89 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import warnings
import ctypes
import sys
def find_cudart_versions(build_env=False, build_cuda_version=None):
# ctypes.CDLL and ctypes.util.find_library load the latest installed library.
# it may not the the library that would be loaded by onnxruntime.
# for example, in an environment with Cuda 11.1 and subsequently
# conda cudatoolkit 10.2.89 installed. ctypes will find cudart 10.2. however,
# onnxruntime built with Cuda 11.1 will find and load cudart for Cuda 11.1.
# for the above reason, we need find all versions in the environment and
# only give warnings if the expected cuda version is not found.
# in onnxruntime build environment, we expected only one Cuda version.
if not sys.platform.startswith('linux'):
warnings.warn('find_cudart_versions only works on Linux')
return None
cudart_possible_versions = {None, build_cuda_version}
def get_cudart_version(find_cudart_version=None):
cudart_lib_filename = 'libcudart.so'
if find_cudart_version:
cudart_lib_filename = cudart_lib_filename + '.' + find_cudart_version
try:
cudart = ctypes.CDLL(cudart_lib_filename)
cudart.cudaRuntimeGetVersion.restype = int
cudart.cudaRuntimeGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
version = ctypes.c_int()
status = cudart.cudaRuntimeGetVersion(ctypes.byref(version))
if status != 0:
return None
except: # noqa
return None
return version.value
# use set to avoid duplications
cudart_found_versions = {
get_cudart_version(cudart_version) for cudart_version in cudart_possible_versions}
# convert to list and remove None
return [ver for ver in cudart_found_versions if ver]
def find_cudnn_supported_cuda_versions(build_env=False):
# comments in get_cudart_version apply here
if not sys.platform.startswith('linux'):
warnings.warn('find_cudnn_versions only works on Linux')
cudnn_possible_versions = {None}
if not build_env:
# if not in a build environment, there may be more than one installed cudnn.
# https://developer.nvidia.com/rdp/cudnn-archive to include all that may support Cuda 10+.
cudnn_possible_versions.update({
'8.2',
'8.1.1', '8.1.0',
'8.0.5', '8.0.4', '8.0.3', '8.0.2', '8.0.1',
'7.6.5', '7.6.4', '7.6.3', '7.6.2', '7.6.1', '7.6.0',
'7.5.1', '7.5.0',
'7.4.2', '7.4.1',
'7.3.1', '7.3.0',
})
def get_cudnn_supported_cuda_version(find_cudnn_version=None):
cudnn_lib_filename = 'libcudnn.so'
if find_cudnn_version:
cudnn_lib_filename = cudnn_lib_filename + '.' + find_cudnn_version
# in cudnn.h cudnn version are calculated as:
# #define CUDNN_VERSION (CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)
try:
cudnn = ctypes.CDLL(cudnn_lib_filename)
# cudnn_ver = cudnn.cudnnGetVersion()
cuda_ver = cudnn.cudnnGetCudartVersion()
return cuda_ver
except: # noqa
return None
# use set to avoid duplications
cuda_found_versions = {get_cudnn_supported_cuda_version(cudnn_version) for cudnn_version in cudnn_possible_versions}
# convert to list and remove None
return [ver for ver in cuda_found_versions if ver]

View file

@ -56,3 +56,71 @@ def check_distro_info():
else:
warnings.warn('Unsupported platform (%s). ONNX Runtime supports Linux, macOS and Windows platforms, only.' %
__my_system__)
def validate_build_package_info():
import_ortmodule_exception = None
try:
from onnxruntime.training.ortmodule import ORTModule # noqa
has_ortmodule = True
except ImportError:
has_ortmodule = False
except Exception as e:
# this may happen if Cuda is not installed, we want to raise it after
# for any exception other than not having ortmodule, we want to continue
# device version validation and raise the exception after.
import_ortmodule_exception = e
has_ortmodule = True
package_name = ''
version = ''
cuda_version = ''
if has_ortmodule:
try:
# collect onnxruntime package name, version, and cuda version
from .build_and_package_info import package_name
from .build_and_package_info import __version__ as version
try:
from .build_and_package_info import cuda_version
except: # noqa
pass
if cuda_version:
# collect cuda library build info. the library info may not be available
# when the build environment has none or multiple libraries installed
try:
from .build_and_package_info import cudart_version
except: # noqa
warnings.warn('WARNING: failed to get cudart_version from onnxruntime build info.')
cudart_version = None
def print_build_package_info():
warnings.warn('onnxruntime training package info: package_name: %s' % package_name)
warnings.warn('onnxruntime training package info: __version__: %s' % version)
warnings.warn('onnxruntime training package info: cuda_version: %s' % cuda_version)
warnings.warn('onnxruntime build info: cudart_version: %s' % cudart_version)
# collection cuda library info from current environment.
from onnxruntime.capi.onnxruntime_collect_build_info import find_cudart_versions
local_cudart_versions = find_cudart_versions(build_env=False, build_cuda_version=cuda_version)
if cudart_version and cudart_version not in local_cudart_versions:
print_build_package_info()
warnings.warn('WARNING: failed to find cudart version that matches onnxruntime build info')
warnings.warn('WARNING: found cudart versions: %s' % local_cudart_versions)
else:
# TODO: rcom
pass
except Exception as e: # noqa
warnings.warn('WARNING: failed to collect onnxruntime version and build info')
print(e)
if import_ortmodule_exception:
raise import_ortmodule_exception
return has_ortmodule, package_name, version, cuda_version
has_ortmodule, package_name, version, cuda_version = validate_build_package_info()

View file

@ -262,9 +262,8 @@ if enable_training:
# install an onnxruntime training package with matching torch cuda version.
package_name = 'onnxruntime-training'
if cuda_version:
# removing '.' to make Cuda version number in the same form as Pytorch.
cuda_version = cuda_version.replace('.', '')
local_version = '+cu' + cuda_version
# removing '.' to make local Cuda version number in the same form as Pytorch.
local_version = '+cu' + cuda_version.replace('.', '')
if rocm_version:
# removing '.' to make Cuda version number in the same form as Pytorch.
rocm_version = rocm_version.replace('.', '')
@ -369,6 +368,34 @@ if not path.exists(requirements_path):
with open(requirements_path) as f:
install_requires = f.read().splitlines()
if enable_training:
def save_build_and_package_info(package_name, version_number, cuda_version):
sys.path.append(path.join(path.dirname(__file__), 'onnxruntime', 'python'))
from onnxruntime_collect_build_info import find_cudart_versions
version_path = path.join('onnxruntime', 'capi', 'build_and_package_info.py')
with open(version_path, 'w') as f:
f.write("package_name = '{}'\n".format(package_name))
f.write("__version__ = '{}'\n".format(version_number))
if cuda_version:
f.write("cuda_version = '{}'\n".format(cuda_version))
# cudart_versions are integers
cudart_versions = find_cudart_versions(build_env=True)
if len(cudart_versions) == 1:
f.write("cudart_version = {}\n".format(cudart_versions[0]))
else:
print(
"Error getting cudart version. ",
"did not find any cudart library" if len(cudart_versions) == 0 else "found multiple cudart libraries")
else:
# TODO: rocm
pass
save_build_and_package_info(package_name, version_number, cuda_version)
# Setup
setup(
name=package_name,

View file

@ -1298,7 +1298,13 @@ def run_orttraining_test_orttrainer_frontend_separately(cwd):
def run_training_python_frontend_tests(cwd):
run_subprocess([sys.executable, 'onnxruntime_test_ort_trainer.py'], cwd=cwd)
# have to disable due to (with torchvision==0.9.1+cu102 which is required by ortmodule):
# Downloading http://yann.lecun.com/exdb/mnist/
# https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
# Failed to download (trying next):
# HTTP Error 404: Not Found
# run_subprocess([sys.executable, 'onnxruntime_test_ort_trainer.py'], cwd=cwd)
run_subprocess([sys.executable, 'onnxruntime_test_training_unit_tests.py'], cwd=cwd)
run_subprocess([
sys.executable, 'orttraining_test_transformers.py',

View file

@ -4,8 +4,8 @@
sklearn
numpy==1.16.6
transformers==v2.10.0
torch==1.6.0+cu101
torchvision==0.7.0+cu101
torchtext==0.7.0
torch==1.8.1+cu102
torchvision==0.9.1+cu102
torchtext==0.9.1
tensorboard==v2.0.0
h5py