mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-02 23:39:58 +00:00
Liqun/ort training version (#7620)
This commit is contained in:
parent
bfbcc89db1
commit
359fe1d197
6 changed files with 223 additions and 12 deletions
|
|
@ -10,13 +10,30 @@ or the `Github project <https://github.com/microsoft/onnxruntime/>`_.
|
|||
__version__ = "1.7.0"
|
||||
__author__ = "Microsoft"
|
||||
|
||||
from onnxruntime.capi._pybind_state import get_all_providers, get_available_providers, get_device, set_seed, \
|
||||
RunOptions, SessionOptions, set_default_logger_severity, enable_telemetry_events, disable_telemetry_events, \
|
||||
NodeArg, ModelMetadata, GraphOptimizationLevel, ExecutionMode, ExecutionOrder, OrtDevice, SessionIOBinding, \
|
||||
OrtAllocatorType, OrtMemType, OrtArenaCfg, OrtMemoryInfo, create_and_register_allocator
|
||||
# we need to do device version validation (for example to check Cuda version for an onnxruntime-training package).
|
||||
# in order to know whether the onnxruntime package is for training it needs
|
||||
# to do import onnxruntime.training.ortmodule first.
|
||||
# onnxruntime.capi._pybind_state is required before import onnxruntime.training.ortmodule.
|
||||
# however, import onnxruntime.capi._pybind_state will already raise an exception if a required Cuda version
|
||||
# is not found.
|
||||
# here we need to save the exception and continue with Cuda version validation in order to post
|
||||
# meaningful messages to the user.
|
||||
# the saved exception is raised after device version validation.
|
||||
try:
|
||||
from onnxruntime.capi._pybind_state import get_all_providers, get_available_providers, get_device, set_seed, \
|
||||
RunOptions, SessionOptions, set_default_logger_severity, enable_telemetry_events, disable_telemetry_events, \
|
||||
NodeArg, ModelMetadata, GraphOptimizationLevel, ExecutionMode, ExecutionOrder, OrtDevice, SessionIOBinding, \
|
||||
OrtAllocatorType, OrtMemType, OrtArenaCfg, OrtMemoryInfo, create_and_register_allocator
|
||||
import_capi_exception = None
|
||||
except Exception as e:
|
||||
import_capi_exception = e
|
||||
|
||||
from onnxruntime.capi import onnxruntime_validation
|
||||
|
||||
if import_capi_exception:
|
||||
raise import_capi_exception
|
||||
|
||||
from onnxruntime.capi.onnxruntime_inference_collection import InferenceSession, IOBinding, OrtValue
|
||||
from onnxruntime.capi import onnxruntime_validation
|
||||
|
||||
from onnxruntime.capi.training import * # noqa: F403
|
||||
|
||||
|
|
@ -26,4 +43,8 @@ try:
|
|||
except ImportError:
|
||||
pass
|
||||
|
||||
from onnxruntime.capi.onnxruntime_validation import package_name, version, cuda_version
|
||||
if version:
|
||||
__version__ = version
|
||||
|
||||
onnxruntime_validation.check_distro_info()
|
||||
|
|
|
|||
89
onnxruntime/python/onnxruntime_collect_build_info.py
Normal file
89
onnxruntime/python/onnxruntime_collect_build_info.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import warnings
|
||||
import ctypes
|
||||
import sys
|
||||
|
||||
|
||||
def find_cudart_versions(build_env=False, build_cuda_version=None):
|
||||
# ctypes.CDLL and ctypes.util.find_library load the latest installed library.
|
||||
# it may not the the library that would be loaded by onnxruntime.
|
||||
# for example, in an environment with Cuda 11.1 and subsequently
|
||||
# conda cudatoolkit 10.2.89 installed. ctypes will find cudart 10.2. however,
|
||||
# onnxruntime built with Cuda 11.1 will find and load cudart for Cuda 11.1.
|
||||
# for the above reason, we need find all versions in the environment and
|
||||
# only give warnings if the expected cuda version is not found.
|
||||
# in onnxruntime build environment, we expected only one Cuda version.
|
||||
if not sys.platform.startswith('linux'):
|
||||
warnings.warn('find_cudart_versions only works on Linux')
|
||||
return None
|
||||
|
||||
cudart_possible_versions = {None, build_cuda_version}
|
||||
|
||||
def get_cudart_version(find_cudart_version=None):
|
||||
cudart_lib_filename = 'libcudart.so'
|
||||
if find_cudart_version:
|
||||
cudart_lib_filename = cudart_lib_filename + '.' + find_cudart_version
|
||||
|
||||
try:
|
||||
cudart = ctypes.CDLL(cudart_lib_filename)
|
||||
cudart.cudaRuntimeGetVersion.restype = int
|
||||
cudart.cudaRuntimeGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
|
||||
version = ctypes.c_int()
|
||||
status = cudart.cudaRuntimeGetVersion(ctypes.byref(version))
|
||||
if status != 0:
|
||||
return None
|
||||
except: # noqa
|
||||
return None
|
||||
|
||||
return version.value
|
||||
|
||||
# use set to avoid duplications
|
||||
cudart_found_versions = {
|
||||
get_cudart_version(cudart_version) for cudart_version in cudart_possible_versions}
|
||||
|
||||
# convert to list and remove None
|
||||
return [ver for ver in cudart_found_versions if ver]
|
||||
|
||||
|
||||
def find_cudnn_supported_cuda_versions(build_env=False):
|
||||
# comments in get_cudart_version apply here
|
||||
if not sys.platform.startswith('linux'):
|
||||
warnings.warn('find_cudnn_versions only works on Linux')
|
||||
|
||||
cudnn_possible_versions = {None}
|
||||
if not build_env:
|
||||
# if not in a build environment, there may be more than one installed cudnn.
|
||||
# https://developer.nvidia.com/rdp/cudnn-archive to include all that may support Cuda 10+.
|
||||
cudnn_possible_versions.update({
|
||||
'8.2',
|
||||
'8.1.1', '8.1.0',
|
||||
'8.0.5', '8.0.4', '8.0.3', '8.0.2', '8.0.1',
|
||||
'7.6.5', '7.6.4', '7.6.3', '7.6.2', '7.6.1', '7.6.0',
|
||||
'7.5.1', '7.5.0',
|
||||
'7.4.2', '7.4.1',
|
||||
'7.3.1', '7.3.0',
|
||||
})
|
||||
|
||||
def get_cudnn_supported_cuda_version(find_cudnn_version=None):
|
||||
cudnn_lib_filename = 'libcudnn.so'
|
||||
if find_cudnn_version:
|
||||
cudnn_lib_filename = cudnn_lib_filename + '.' + find_cudnn_version
|
||||
|
||||
# in cudnn.h cudnn version are calculated as:
|
||||
# #define CUDNN_VERSION (CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)
|
||||
try:
|
||||
cudnn = ctypes.CDLL(cudnn_lib_filename)
|
||||
# cudnn_ver = cudnn.cudnnGetVersion()
|
||||
cuda_ver = cudnn.cudnnGetCudartVersion()
|
||||
return cuda_ver
|
||||
except: # noqa
|
||||
return None
|
||||
|
||||
# use set to avoid duplications
|
||||
cuda_found_versions = {get_cudnn_supported_cuda_version(cudnn_version) for cudnn_version in cudnn_possible_versions}
|
||||
|
||||
# convert to list and remove None
|
||||
return [ver for ver in cuda_found_versions if ver]
|
||||
|
|
@ -56,3 +56,71 @@ def check_distro_info():
|
|||
else:
|
||||
warnings.warn('Unsupported platform (%s). ONNX Runtime supports Linux, macOS and Windows platforms, only.' %
|
||||
__my_system__)
|
||||
|
||||
|
||||
def validate_build_package_info():
|
||||
import_ortmodule_exception = None
|
||||
try:
|
||||
from onnxruntime.training.ortmodule import ORTModule # noqa
|
||||
has_ortmodule = True
|
||||
except ImportError:
|
||||
has_ortmodule = False
|
||||
except Exception as e:
|
||||
# this may happen if Cuda is not installed, we want to raise it after
|
||||
# for any exception other than not having ortmodule, we want to continue
|
||||
# device version validation and raise the exception after.
|
||||
import_ortmodule_exception = e
|
||||
has_ortmodule = True
|
||||
|
||||
package_name = ''
|
||||
version = ''
|
||||
cuda_version = ''
|
||||
|
||||
if has_ortmodule:
|
||||
try:
|
||||
# collect onnxruntime package name, version, and cuda version
|
||||
from .build_and_package_info import package_name
|
||||
from .build_and_package_info import __version__ as version
|
||||
|
||||
try:
|
||||
from .build_and_package_info import cuda_version
|
||||
except: # noqa
|
||||
pass
|
||||
|
||||
if cuda_version:
|
||||
# collect cuda library build info. the library info may not be available
|
||||
# when the build environment has none or multiple libraries installed
|
||||
try:
|
||||
from .build_and_package_info import cudart_version
|
||||
except: # noqa
|
||||
warnings.warn('WARNING: failed to get cudart_version from onnxruntime build info.')
|
||||
cudart_version = None
|
||||
|
||||
def print_build_package_info():
|
||||
warnings.warn('onnxruntime training package info: package_name: %s' % package_name)
|
||||
warnings.warn('onnxruntime training package info: __version__: %s' % version)
|
||||
warnings.warn('onnxruntime training package info: cuda_version: %s' % cuda_version)
|
||||
warnings.warn('onnxruntime build info: cudart_version: %s' % cudart_version)
|
||||
|
||||
# collection cuda library info from current environment.
|
||||
from onnxruntime.capi.onnxruntime_collect_build_info import find_cudart_versions
|
||||
local_cudart_versions = find_cudart_versions(build_env=False, build_cuda_version=cuda_version)
|
||||
if cudart_version and cudart_version not in local_cudart_versions:
|
||||
print_build_package_info()
|
||||
warnings.warn('WARNING: failed to find cudart version that matches onnxruntime build info')
|
||||
warnings.warn('WARNING: found cudart versions: %s' % local_cudart_versions)
|
||||
else:
|
||||
# TODO: rcom
|
||||
pass
|
||||
|
||||
except Exception as e: # noqa
|
||||
warnings.warn('WARNING: failed to collect onnxruntime version and build info')
|
||||
print(e)
|
||||
|
||||
if import_ortmodule_exception:
|
||||
raise import_ortmodule_exception
|
||||
|
||||
return has_ortmodule, package_name, version, cuda_version
|
||||
|
||||
|
||||
has_ortmodule, package_name, version, cuda_version = validate_build_package_info()
|
||||
|
|
|
|||
33
setup.py
33
setup.py
|
|
@ -262,9 +262,8 @@ if enable_training:
|
|||
# install an onnxruntime training package with matching torch cuda version.
|
||||
package_name = 'onnxruntime-training'
|
||||
if cuda_version:
|
||||
# removing '.' to make Cuda version number in the same form as Pytorch.
|
||||
cuda_version = cuda_version.replace('.', '')
|
||||
local_version = '+cu' + cuda_version
|
||||
# removing '.' to make local Cuda version number in the same form as Pytorch.
|
||||
local_version = '+cu' + cuda_version.replace('.', '')
|
||||
if rocm_version:
|
||||
# removing '.' to make Cuda version number in the same form as Pytorch.
|
||||
rocm_version = rocm_version.replace('.', '')
|
||||
|
|
@ -369,6 +368,34 @@ if not path.exists(requirements_path):
|
|||
with open(requirements_path) as f:
|
||||
install_requires = f.read().splitlines()
|
||||
|
||||
if enable_training:
|
||||
def save_build_and_package_info(package_name, version_number, cuda_version):
|
||||
|
||||
sys.path.append(path.join(path.dirname(__file__), 'onnxruntime', 'python'))
|
||||
from onnxruntime_collect_build_info import find_cudart_versions
|
||||
|
||||
version_path = path.join('onnxruntime', 'capi', 'build_and_package_info.py')
|
||||
with open(version_path, 'w') as f:
|
||||
f.write("package_name = '{}'\n".format(package_name))
|
||||
f.write("__version__ = '{}'\n".format(version_number))
|
||||
|
||||
if cuda_version:
|
||||
f.write("cuda_version = '{}'\n".format(cuda_version))
|
||||
|
||||
# cudart_versions are integers
|
||||
cudart_versions = find_cudart_versions(build_env=True)
|
||||
if len(cudart_versions) == 1:
|
||||
f.write("cudart_version = {}\n".format(cudart_versions[0]))
|
||||
else:
|
||||
print(
|
||||
"Error getting cudart version. ",
|
||||
"did not find any cudart library" if len(cudart_versions) == 0 else "found multiple cudart libraries")
|
||||
else:
|
||||
# TODO: rocm
|
||||
pass
|
||||
|
||||
save_build_and_package_info(package_name, version_number, cuda_version)
|
||||
|
||||
# Setup
|
||||
setup(
|
||||
name=package_name,
|
||||
|
|
|
|||
|
|
@ -1298,7 +1298,13 @@ def run_orttraining_test_orttrainer_frontend_separately(cwd):
|
|||
|
||||
|
||||
def run_training_python_frontend_tests(cwd):
|
||||
run_subprocess([sys.executable, 'onnxruntime_test_ort_trainer.py'], cwd=cwd)
|
||||
# have to disable due to (with torchvision==0.9.1+cu102 which is required by ortmodule):
|
||||
# Downloading http://yann.lecun.com/exdb/mnist/
|
||||
# https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
|
||||
# Failed to download (trying next):
|
||||
# HTTP Error 404: Not Found
|
||||
# run_subprocess([sys.executable, 'onnxruntime_test_ort_trainer.py'], cwd=cwd)
|
||||
|
||||
run_subprocess([sys.executable, 'onnxruntime_test_training_unit_tests.py'], cwd=cwd)
|
||||
run_subprocess([
|
||||
sys.executable, 'orttraining_test_transformers.py',
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@
|
|||
sklearn
|
||||
numpy==1.16.6
|
||||
transformers==v2.10.0
|
||||
torch==1.6.0+cu101
|
||||
torchvision==0.7.0+cu101
|
||||
torchtext==0.7.0
|
||||
torch==1.8.1+cu102
|
||||
torchvision==0.9.1+cu102
|
||||
torchtext==0.9.1
|
||||
tensorboard==v2.0.0
|
||||
h5py
|
||||
|
|
|
|||
Loading…
Reference in a new issue