Update ORTTrainer to support ROCm and allow ONNX export with opset 13 (#7059)

* Update ORTTrainer to support ROCm and allow export with opset 13

* Wrap the ROCm check in a try-except block
This commit is contained in:
Suffian Khan 2021-03-23 11:09:48 -07:00 committed by GitHub
parent 53392664d3
commit c0994fdfbb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 5 deletions

View file

@ -193,6 +193,12 @@ class ORTTrainer(object):
break
assert dtype is not None, f"ONNX model with unknown output type ({o_desc.name})"
try:
from torch.utils.cpp_extension import ROCM_HOME
self.is_rocm_pytorch = (True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False)
except ImportError:
self.is_rocm_pytorch = False
# TODO: Remove when experimental checkpoint functions are removed.
self._state_dict = {}
@ -675,10 +681,13 @@ class ORTTrainer(object):
if 'cuda' in self.options.device.id.lower():
cuda_ep_options = {"device_id": _utils.get_device_index(self.options.device.id)}
if self.options.device.mem_limit > 0:
cuda_ep_options["cuda_mem_limit"] = self.options.device.mem_limit
cuda_ep_name = "CUDAExecutionProvider"
cuda_ep_name = ("ROCMExecutionProvider" if self.is_rocm_pytorch else "CUDAExecutionProvider")
if self.options.device.mem_limit > 0:
if not self.is_rocm_pytorch:
cuda_ep_options["cuda_mem_limit"] = self.options.device.mem_limit
else:
warnings.warn("Ignoring 'mem_limit' for {}".format(cuda_ep_name))
if cuda_ep_name not in providers:
raise RuntimeError(

View file

@ -269,7 +269,7 @@ class ORTTrainerOptions(object):
'onnx_opset_version': {
'type': 'integer',
'min' : 12,
'max' : 12,
'max' : 13,
'default': 12
},
'enable_onnx_contrib_ops' : {
@ -723,7 +723,7 @@ _ORTTRAINER_OPTIONS_SCHEMA = {
'onnx_opset_version': {
'type': 'integer',
'min': 12,
'max': 12,
'max': 13,
'default': 12
},
'enable_onnx_contrib_ops': {