diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py index 92ac42d815..f7b3e5e68b 100644 --- a/orttraining/orttraining/python/training/orttrainer.py +++ b/orttraining/orttraining/python/training/orttrainer.py @@ -193,6 +193,12 @@ class ORTTrainer(object): break assert dtype is not None, f"ONNX model with unknown output type ({o_desc.name})" + try: + from torch.utils.cpp_extension import ROCM_HOME + self.is_rocm_pytorch = (True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False) + except ImportError: + self.is_rocm_pytorch = False + # TODO: Remove when experimental checkpoint functions are removed. self._state_dict = {} @@ -675,10 +681,13 @@ class ORTTrainer(object): if 'cuda' in self.options.device.id.lower(): cuda_ep_options = {"device_id": _utils.get_device_index(self.options.device.id)} - if self.options.device.mem_limit > 0: - cuda_ep_options["cuda_mem_limit"] = self.options.device.mem_limit - cuda_ep_name = "CUDAExecutionProvider" + cuda_ep_name = ("ROCMExecutionProvider" if self.is_rocm_pytorch else "CUDAExecutionProvider") + if self.options.device.mem_limit > 0: + if not self.is_rocm_pytorch: + cuda_ep_options["cuda_mem_limit"] = self.options.device.mem_limit + else: + warnings.warn("Ignoring 'mem_limit' for {}".format(cuda_ep_name)) if cuda_ep_name not in providers: raise RuntimeError( diff --git a/orttraining/orttraining/python/training/orttrainer_options.py b/orttraining/orttraining/python/training/orttrainer_options.py index eb74879e0f..8208ea22be 100644 --- a/orttraining/orttraining/python/training/orttrainer_options.py +++ b/orttraining/orttraining/python/training/orttrainer_options.py @@ -269,7 +269,7 @@ class ORTTrainerOptions(object): 'onnx_opset_version': { 'type': 'integer', 'min' : 12, - 'max' : 12, + 'max' : 13, 'default': 12 }, 'enable_onnx_contrib_ops' : { @@ -723,7 +723,7 @@ _ORTTRAINER_OPTIONS_SCHEMA = { 'onnx_opset_version': { 'type': 'integer', 'min': 12, - 'max': 12, + 'max': 13, 'default': 12 }, 'enable_onnx_contrib_ops': {