Print full stacktrace exception when exporter fails (#9169)

This commit is contained in:
baijumeswani 2021-09-24 07:24:37 -07:00 committed by GitHub
parent 39dc6ea8a3
commit fd91bf91c9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 96 additions and 84 deletions

View file

@ -7,11 +7,18 @@ from . import _logger
import os
import torch
import traceback
import warnings
from enum import IntFlag
from typing import Optional
from ._fallback_exceptions import (ORTModuleFallbackException,
ORTModuleInitException,
ORTModuleDeviceException,
ORTModuleIOError,
ORTModuleTorchModelException,
ORTModuleONNXModelException,
wrap_exception)
from . import _utils
class _FallbackPolicy(IntFlag):
@ -43,67 +50,6 @@ class _FallbackPolicy(IntFlag):
return _FallbackPolicy.FALLBACK_DISABLE in self
class ORTModuleFallbackException(Exception):
'''Base exception class for fallback
Although it must be specialized for specific scenarios,
it can also be used for generic exception that require fallback
'''
pass
class ORTModuleInitException(ORTModuleFallbackException):
'''Trigger fallback for ORTModule initialization related exceptions
This exception is triggered when an incompatible or missing requirements for ORTModule are detected,
including PyTorch version, missing ORTModule's PyTorch C++ extension binaries, etc.
'''
pass
class ORTModuleDeviceException(ORTModuleFallbackException):
'''Trigger fallback for device related exceptions
NOTE: This exception is raised during device validation within ORTModule frontend.
Some device related exceptions can only be detected during PyTorch ONNX exporter execution.
This exception does not capture these scenarios.
'''
pass
class ORTModuleIOError(ORTModuleFallbackException):
'''Trigger fallback for I/O related exceptions
NOTE: This exception is raised during I/O validation within ORTModule Frontend.
Some I/O related exceptions can only be detected during PyTorch ONNX exporter execution.
This exception does not capture these scenarios.
'''
pass
class ORTModuleTorchModelException(ORTModuleFallbackException):
'''Trigger fallback for PyTorch modules related exceptions
This exception is raised during model validation within ORTModule frontend and is based on
checking type(model) over a hardcoded list of incompatible models.
'''
pass
class ORTModuleONNXModelException(ORTModuleFallbackException):
'''Trigger fallback for ONNX model related exceptions
This exception is raised during model conversion to ONNX and post-processing validation within ORTModule frontend.
'''
pass
class _FallbackManager(object):
'''Manages fallbacks based on incoming exceptions and specified policies
@ -210,28 +156,9 @@ class _FallbackManager(object):
warnings.warn(
(f'Fallback to PyTorch due to exception {type(self._exception)} was triggered. '
'Report this issue with a minimal repro at https://www.github.com/microsoft/onnxruntime. '
f'See details below:\n\n{get_exception_as_string(self._exception)}'), UserWarning)
f'See details below:\n\n{_utils.get_exception_as_string(self._exception)}'), UserWarning)
# Pending fallbacks are resetted to enforce retries
if self.retry:
self._exception = None
return model(*inputs, **kwargs)
def wrap_exception(new_exception: ORTModuleFallbackException, raised_exception: Exception) -> ORTModuleFallbackException:
'''Wraps `raised_exception` exception as cause for the returned `new_exception` exception'''
exception = None
try:
raise new_exception(raised_exception) from raised_exception
except Exception as e:
exception = e
return exception
def get_exception_as_string(exception):
assert isinstance(exception, Exception), 'exception must be a `Exception`'
try:
raise exception
except:
return traceback.format_exc()

View file

@ -0,0 +1,75 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# _fallback_exceptions.py
class ORTModuleFallbackException(Exception):
'''Base exception class for fallback
Although it must be specialized for specific scenarios,
it can also be used for generic exception that require fallback
'''
pass
class ORTModuleInitException(ORTModuleFallbackException):
'''Trigger fallback for ORTModule initialization related exceptions
This exception is triggered when an incompatible or missing requirements for ORTModule are detected,
including PyTorch version, missing ORTModule's PyTorch C++ extension binaries, etc.
'''
pass
class ORTModuleDeviceException(ORTModuleFallbackException):
'''Trigger fallback for device related exceptions
NOTE: This exception is raised during device validation within ORTModule frontend.
Some device related exceptions can only be detected during PyTorch ONNX exporter execution.
This exception does not capture these scenarios.
'''
pass
class ORTModuleIOError(ORTModuleFallbackException):
'''Trigger fallback for I/O related exceptions
NOTE: This exception is raised during I/O validation within ORTModule Frontend.
Some I/O related exceptions can only be detected during PyTorch ONNX exporter execution.
This exception does not capture these scenarios.
'''
pass
class ORTModuleTorchModelException(ORTModuleFallbackException):
'''Trigger fallback for PyTorch modules related exceptions
This exception is raised during model validation within ORTModule frontend and is based on
checking type(model) over a hardcoded list of incompatible models.
'''
pass
class ORTModuleONNXModelException(ORTModuleFallbackException):
'''Trigger fallback for ONNX model related exceptions
This exception is raised during model conversion to ONNX and post-processing validation within ORTModule frontend.
'''
pass
def wrap_exception(new_exception: ORTModuleFallbackException, raised_exception: Exception) -> ORTModuleFallbackException:
'''Wraps `raised_exception` exception as cause for the returned `new_exception` exception'''
exception = None
try:
raise new_exception(raised_exception) from raised_exception
except Exception as e:
exception = e
return exception

View file

@ -368,7 +368,8 @@ class GraphExecutionManager(GraphExecutionInterface):
keep_initializers_as_inputs=True)
except Exception as e:
raise wrap_exception(ORTModuleONNXModelException,
RuntimeError(f'There was an error while exporting the PyTorch model to ONNX: {e}'))
RuntimeError(f'There was an error while exporting the PyTorch model to ONNX: '
f'\n\n{_utils.get_exception_as_string(e)}'))
exported_model = onnx.load_model_from_string(f.getvalue())
exported_model = _post_process_after_export(exported_model,

View file

@ -5,13 +5,14 @@
from onnxruntime.capi.onnxruntime_inference_collection import OrtValue
from onnxruntime.capi import _pybind_state as C
from ._fallback import _FallbackManager, ORTModuleFallbackException, ORTModuleDeviceException, wrap_exception
from ._fallback_exceptions import ORTModuleDeviceException, wrap_exception
import os
import copy
import inspect
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
import traceback
from typing import List
import types
import warnings
@ -199,3 +200,11 @@ def parse_os_env_skip_check_flags(env_name, default_skip_check_str):
"""Returns a list of SkipChecks as defined by os env variable env_name or default provided"""
return os.getenv(env_name, default_skip_check_str).split('|')
def get_exception_as_string(exception):
assert isinstance(exception, Exception), 'exception must be a `Exception`'
try:
raise exception
except:
return traceback.format_exc()