mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Update ORTTrainer to permit Rocm and permit export of opset 13 (#7059)
* update orttrainer to permit rocm and allow export for opset 13 * wrap rocm check in try-except block
This commit is contained in:
parent
53392664d3
commit
c0994fdfbb
2 changed files with 14 additions and 5 deletions
|
|
@ -193,6 +193,12 @@ class ORTTrainer(object):
|
|||
break
|
||||
assert dtype is not None, f"ONNX model with unknown output type ({o_desc.name})"
|
||||
|
||||
try:
|
||||
from torch.utils.cpp_extension import ROCM_HOME
|
||||
self.is_rocm_pytorch = (True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False)
|
||||
except ImportError:
|
||||
self.is_rocm_pytorch = False
|
||||
|
||||
# TODO: Remove when experimental checkpoint functions are removed.
|
||||
self._state_dict = {}
|
||||
|
||||
|
|
@ -675,10 +681,13 @@ class ORTTrainer(object):
|
|||
|
||||
if 'cuda' in self.options.device.id.lower():
|
||||
cuda_ep_options = {"device_id": _utils.get_device_index(self.options.device.id)}
|
||||
if self.options.device.mem_limit > 0:
|
||||
cuda_ep_options["cuda_mem_limit"] = self.options.device.mem_limit
|
||||
|
||||
cuda_ep_name = "CUDAExecutionProvider"
|
||||
cuda_ep_name = ("ROCMExecutionProvider" if self.is_rocm_pytorch else "CUDAExecutionProvider")
|
||||
if self.options.device.mem_limit > 0:
|
||||
if not self.is_rocm_pytorch:
|
||||
cuda_ep_options["cuda_mem_limit"] = self.options.device.mem_limit
|
||||
else:
|
||||
warnings.warn("Ignoring 'mem_limit' for {}".format(cuda_ep_name))
|
||||
|
||||
if cuda_ep_name not in providers:
|
||||
raise RuntimeError(
|
||||
|
|
|
|||
|
|
@ -269,7 +269,7 @@ class ORTTrainerOptions(object):
|
|||
'onnx_opset_version': {
|
||||
'type': 'integer',
|
||||
'min' : 12,
|
||||
'max' : 12,
|
||||
'max' : 13,
|
||||
'default': 12
|
||||
},
|
||||
'enable_onnx_contrib_ops' : {
|
||||
|
|
@ -723,7 +723,7 @@ _ORTTRAINER_OPTIONS_SCHEMA = {
|
|||
'onnx_opset_version': {
|
||||
'type': 'integer',
|
||||
'min': 12,
|
||||
'max': 12,
|
||||
'max': 13,
|
||||
'default': 12
|
||||
},
|
||||
'enable_onnx_contrib_ops': {
|
||||
|
|
|
|||
Loading…
Reference in a new issue