diff --git a/orttraining/orttraining/python/training/ortmodule/_fallback.py b/orttraining/orttraining/python/training/ortmodule/_fallback.py index 53b0e25009..4c316e68b4 100644 --- a/orttraining/orttraining/python/training/ortmodule/_fallback.py +++ b/orttraining/orttraining/python/training/ortmodule/_fallback.py @@ -7,11 +7,18 @@ from . import _logger import os import torch -import traceback import warnings from enum import IntFlag from typing import Optional +from ._fallback_exceptions import (ORTModuleFallbackException, + ORTModuleInitException, + ORTModuleDeviceException, + ORTModuleIOError, + ORTModuleTorchModelException, + ORTModuleONNXModelException, + wrap_exception) +from . import _utils class _FallbackPolicy(IntFlag): @@ -43,67 +50,6 @@ class _FallbackPolicy(IntFlag): return _FallbackPolicy.FALLBACK_DISABLE in self -class ORTModuleFallbackException(Exception): - '''Base exception class for fallback - - Although it must be specialized for specific scenarios, - it can also be used for generic exception that require fallback - ''' - - pass - - -class ORTModuleInitException(ORTModuleFallbackException): - '''Trigger fallback for ORTModule initialization related exceptions - - This exception is triggered when an incompatible or missing requirements for ORTModule are detected, - including PyTorch version, missing ORTModule's PyTorch C++ extension binaries, etc. - ''' - - pass - - -class ORTModuleDeviceException(ORTModuleFallbackException): - '''Trigger fallback for device related exceptions - - NOTE: This exception is raised during device validation within ORTModule frontend. - Some device related exceptions can only be detected during PyTorch ONNX exporter execution. - This exception does not capture these scenarios. - ''' - - pass - - -class ORTModuleIOError(ORTModuleFallbackException): - '''Trigger fallback for I/O related exceptions - - NOTE: This exception is raised during I/O validation within ORTModule Frontend. - Some I/O related exceptions can only be detected during PyTorch ONNX exporter execution. - This exception does not capture these scenarios. - ''' - - pass - - -class ORTModuleTorchModelException(ORTModuleFallbackException): - '''Trigger fallback for PyTorch modules related exceptions - - This exception is raised during model validation within ORTModule frontend and is based on - checking type(model) over a hardcoded list of incompatible models. - ''' - - pass - - -class ORTModuleONNXModelException(ORTModuleFallbackException): - '''Trigger fallback for ONNX model related exceptions - - This exception is raised during model conversion to ONNX and post-processing validation within ORTModule frontend. - ''' - - pass - - class _FallbackManager(object): '''Manages fallbacks based on incoming exceptions and specified policies @@ -210,28 +156,9 @@ class _FallbackManager(object): warnings.warn( (f'Fallback to PyTorch due to exception {type(self._exception)} was triggered. ' 'Report this issue with a minimal repro at https://www.github.com/microsoft/onnxruntime. ' - f'See details below:\n\n{get_exception_as_string(self._exception)}'), UserWarning) + f'See details below:\n\n{_utils.get_exception_as_string(self._exception)}'), UserWarning) # Pending fallbacks are resetted to enforce retries if self.retry: self._exception = None return model(*inputs, **kwargs) - - -def wrap_exception(new_exception: ORTModuleFallbackException, raised_exception: Exception) -> ORTModuleFallbackException: - '''Wraps `raised_exception` exception as cause for the returned `new_exception` exception''' - - exception = None - try: - raise new_exception(raised_exception) from raised_exception - except Exception as e: - exception = e - return exception - -def get_exception_as_string(exception): - assert isinstance(exception, Exception), 'exception must be a `Exception`' - - try: - raise exception - except: - return traceback.format_exc() diff --git a/orttraining/orttraining/python/training/ortmodule/_fallback_exceptions.py b/orttraining/orttraining/python/training/ortmodule/_fallback_exceptions.py new file mode 100644 index 0000000000..2b9914bdf7 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/_fallback_exceptions.py @@ -0,0 +1,75 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# _fallback_exceptions.py + + +class ORTModuleFallbackException(Exception): + '''Base exception class for fallback + + Although it must be specialized for specific scenarios, + it can also be used for generic exception that require fallback + ''' + + pass + + +class ORTModuleInitException(ORTModuleFallbackException): + '''Trigger fallback for ORTModule initialization related exceptions + + This exception is triggered when an incompatible or missing requirements for ORTModule are detected, + including PyTorch version, missing ORTModule's PyTorch C++ extension binaries, etc. + ''' + + pass + + +class ORTModuleDeviceException(ORTModuleFallbackException): + '''Trigger fallback for device related exceptions + + NOTE: This exception is raised during device validation within ORTModule frontend. + Some device related exceptions can only be detected during PyTorch ONNX exporter execution. + This exception does not capture these scenarios. + ''' + + pass + + +class ORTModuleIOError(ORTModuleFallbackException): + '''Trigger fallback for I/O related exceptions + + NOTE: This exception is raised during I/O validation within ORTModule Frontend. + Some I/O related exceptions can only be detected during PyTorch ONNX exporter execution. + This exception does not capture these scenarios. + ''' + + pass + + +class ORTModuleTorchModelException(ORTModuleFallbackException): + '''Trigger fallback for PyTorch modules related exceptions + + This exception is raised during model validation within ORTModule frontend and is based on + checking type(model) over a hardcoded list of incompatible models. + ''' + + pass + + +class ORTModuleONNXModelException(ORTModuleFallbackException): + '''Trigger fallback for ONNX model related exceptions + + This exception is raised during model conversion to ONNX and post-processing validation within ORTModule frontend. + ''' + + pass + + +def wrap_exception(new_exception: ORTModuleFallbackException, raised_exception: Exception) -> ORTModuleFallbackException: + '''Wraps `raised_exception` exception as cause for the returned `new_exception` exception''' + + exception = None + try: + raise new_exception(raised_exception) from raised_exception + except Exception as e: + exception = e + return exception \ No newline at end of file diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index beb063bef6..8de45c29e5 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -368,7 +368,8 @@ class GraphExecutionManager(GraphExecutionInterface): keep_initializers_as_inputs=True) except Exception as e: raise wrap_exception(ORTModuleONNXModelException, - RuntimeError(f'There was an error while exporting the PyTorch model to ONNX: {e}')) + RuntimeError(f'There was an error while exporting the PyTorch model to ONNX: ' + f'\n\n{_utils.get_exception_as_string(e)}')) exported_model = onnx.load_model_from_string(f.getvalue()) exported_model = _post_process_after_export(exported_model, diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index 50161e8e77..1cf035c37b 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -5,13 +5,14 @@ from onnxruntime.capi.onnxruntime_inference_collection import OrtValue from onnxruntime.capi import _pybind_state as C -from ._fallback import _FallbackManager, ORTModuleFallbackException, ORTModuleDeviceException, wrap_exception +from ._fallback_exceptions import ORTModuleDeviceException, wrap_exception import os import copy import inspect import torch from torch.utils.dlpack import from_dlpack, to_dlpack +import traceback from typing import List import types import warnings @@ -199,3 +200,11 @@ def parse_os_env_skip_check_flags(env_name, default_skip_check_str): """Returns a list of SkipChecks as defined by os env variable env_name or default provided""" return os.getenv(env_name, default_skip_check_str).split('|') + +def get_exception_as_string(exception): + assert isinstance(exception, Exception), 'exception must be a `Exception`' + + try: + raise exception + except: + return traceback.format_exc()