From c0994fdfbb38345b2037abc2b4c7d5490243a2f0 Mon Sep 17 00:00:00 2001 From: Suffian Khan Date: Tue, 23 Mar 2021 11:09:48 -0700 Subject: [PATCH] Update ORTTrainer to permit Rocm and permit export of opset 13 (#7059) * update orttrainer to permit rocm and allow export for opset 13 * wrap rocm check in try-except block --- .../orttraining/python/training/orttrainer.py | 15 ++++++++++++--- .../python/training/orttrainer_options.py | 4 ++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py index 92ac42d815..f7b3e5e68b 100644 --- a/orttraining/orttraining/python/training/orttrainer.py +++ b/orttraining/orttraining/python/training/orttrainer.py @@ -193,6 +193,12 @@ class ORTTrainer(object): break assert dtype is not None, f"ONNX model with unknown output type ({o_desc.name})" + try: + from torch.utils.cpp_extension import ROCM_HOME + self.is_rocm_pytorch = (True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False) + except ImportError: + self.is_rocm_pytorch = False + # TODO: Remove when experimental checkpoint functions are removed. self._state_dict = {} @@ -675,10 +681,13 @@ class ORTTrainer(object): if 'cuda' in self.options.device.id.lower(): cuda_ep_options = {"device_id": _utils.get_device_index(self.options.device.id)} - if self.options.device.mem_limit > 0: - cuda_ep_options["cuda_mem_limit"] = self.options.device.mem_limit - cuda_ep_name = "CUDAExecutionProvider" + cuda_ep_name = ("ROCMExecutionProvider" if self.is_rocm_pytorch else "CUDAExecutionProvider") + if self.options.device.mem_limit > 0: + if not self.is_rocm_pytorch: + cuda_ep_options["cuda_mem_limit"] = self.options.device.mem_limit + else: + warnings.warn("Ignoring 'mem_limit' for {}".format(cuda_ep_name)) if cuda_ep_name not in providers: raise RuntimeError( diff --git a/orttraining/orttraining/python/training/orttrainer_options.py b/orttraining/orttraining/python/training/orttrainer_options.py index eb74879e0f..8208ea22be 100644 --- a/orttraining/orttraining/python/training/orttrainer_options.py +++ b/orttraining/orttraining/python/training/orttrainer_options.py @@ -269,7 +269,7 @@ class ORTTrainerOptions(object): 'onnx_opset_version': { 'type': 'integer', 'min' : 12, - 'max' : 12, + 'max' : 13, 'default': 12 }, 'enable_onnx_contrib_ops' : { @@ -723,7 +723,7 @@ _ORTTRAINER_OPTIONS_SCHEMA = { 'onnx_opset_version': { 'type': 'integer', 'min': 12, - 'max': 12, + 'max': 13, 'default': 12 }, 'enable_onnx_contrib_ops': {