mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Print full stacktrace exception when exporter fails (#9169)
This commit is contained in:
parent
39dc6ea8a3
commit
fd91bf91c9
4 changed files with 96 additions and 84 deletions
|
|
@ -7,11 +7,18 @@ from . import _logger
|
|||
|
||||
import os
|
||||
import torch
|
||||
import traceback
|
||||
import warnings
|
||||
|
||||
from enum import IntFlag
|
||||
from typing import Optional
|
||||
from ._fallback_exceptions import (ORTModuleFallbackException,
|
||||
ORTModuleInitException,
|
||||
ORTModuleDeviceException,
|
||||
ORTModuleIOError,
|
||||
ORTModuleTorchModelException,
|
||||
ORTModuleONNXModelException,
|
||||
wrap_exception)
|
||||
from . import _utils
|
||||
|
||||
|
||||
class _FallbackPolicy(IntFlag):
|
||||
|
|
@ -43,67 +50,6 @@ class _FallbackPolicy(IntFlag):
|
|||
return _FallbackPolicy.FALLBACK_DISABLE in self
|
||||
|
||||
|
||||
class ORTModuleFallbackException(Exception):
|
||||
'''Base exception class for fallback
|
||||
|
||||
Although it must be specialized for specific scenarios,
|
||||
it can also be used for generic exception that require fallback
|
||||
'''
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ORTModuleInitException(ORTModuleFallbackException):
|
||||
'''Trigger fallback for ORTModule initialization related exceptions
|
||||
|
||||
This exception is triggered when an incompatible or missing requirements for ORTModule are detected,
|
||||
including PyTorch version, missing ORTModule's PyTorch C++ extension binaries, etc.
|
||||
'''
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ORTModuleDeviceException(ORTModuleFallbackException):
|
||||
'''Trigger fallback for device related exceptions
|
||||
|
||||
NOTE: This exception is raised during device validation within ORTModule frontend.
|
||||
Some device related exceptions can only be detected during PyTorch ONNX exporter execution.
|
||||
This exception does not capture these scenarios.
|
||||
'''
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ORTModuleIOError(ORTModuleFallbackException):
|
||||
'''Trigger fallback for I/O related exceptions
|
||||
|
||||
NOTE: This exception is raised during I/O validation within ORTModule Frontend.
|
||||
Some I/O related exceptions can only be detected during PyTorch ONNX exporter execution.
|
||||
This exception does not capture these scenarios.
|
||||
'''
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ORTModuleTorchModelException(ORTModuleFallbackException):
|
||||
'''Trigger fallback for PyTorch modules related exceptions
|
||||
|
||||
This exception is raised during model validation within ORTModule frontend and is based on
|
||||
checking type(model) over a hardcoded list of incompatible models.
|
||||
'''
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ORTModuleONNXModelException(ORTModuleFallbackException):
|
||||
'''Trigger fallback for ONNX model related exceptions
|
||||
|
||||
This exception is raised during model conversion to ONNX and post-processing validation within ORTModule frontend.
|
||||
'''
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class _FallbackManager(object):
|
||||
'''Manages fallbacks based on incoming exceptions and specified policies
|
||||
|
||||
|
|
@ -210,28 +156,9 @@ class _FallbackManager(object):
|
|||
warnings.warn(
|
||||
(f'Fallback to PyTorch due to exception {type(self._exception)} was triggered. '
|
||||
'Report this issue with a minimal repro at https://www.github.com/microsoft/onnxruntime. '
|
||||
f'See details below:\n\n{get_exception_as_string(self._exception)}'), UserWarning)
|
||||
f'See details below:\n\n{_utils.get_exception_as_string(self._exception)}'), UserWarning)
|
||||
|
||||
# Pending fallbacks are resetted to enforce retries
|
||||
if self.retry:
|
||||
self._exception = None
|
||||
return model(*inputs, **kwargs)
|
||||
|
||||
|
||||
def wrap_exception(new_exception: ORTModuleFallbackException, raised_exception: Exception) -> ORTModuleFallbackException:
|
||||
'''Wraps `raised_exception` exception as cause for the returned `new_exception` exception'''
|
||||
|
||||
exception = None
|
||||
try:
|
||||
raise new_exception(raised_exception) from raised_exception
|
||||
except Exception as e:
|
||||
exception = e
|
||||
return exception
|
||||
|
||||
def get_exception_as_string(exception):
|
||||
assert isinstance(exception, Exception), 'exception must be a `Exception`'
|
||||
|
||||
try:
|
||||
raise exception
|
||||
except:
|
||||
return traceback.format_exc()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# _fallback_exceptions.py
|
||||
|
||||
|
||||
class ORTModuleFallbackException(Exception):
|
||||
'''Base exception class for fallback
|
||||
|
||||
Although it must be specialized for specific scenarios,
|
||||
it can also be used for generic exception that require fallback
|
||||
'''
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ORTModuleInitException(ORTModuleFallbackException):
|
||||
'''Trigger fallback for ORTModule initialization related exceptions
|
||||
|
||||
This exception is triggered when an incompatible or missing requirements for ORTModule are detected,
|
||||
including PyTorch version, missing ORTModule's PyTorch C++ extension binaries, etc.
|
||||
'''
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ORTModuleDeviceException(ORTModuleFallbackException):
|
||||
'''Trigger fallback for device related exceptions
|
||||
|
||||
NOTE: This exception is raised during device validation within ORTModule frontend.
|
||||
Some device related exceptions can only be detected during PyTorch ONNX exporter execution.
|
||||
This exception does not capture these scenarios.
|
||||
'''
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ORTModuleIOError(ORTModuleFallbackException):
|
||||
'''Trigger fallback for I/O related exceptions
|
||||
|
||||
NOTE: This exception is raised during I/O validation within ORTModule Frontend.
|
||||
Some I/O related exceptions can only be detected during PyTorch ONNX exporter execution.
|
||||
This exception does not capture these scenarios.
|
||||
'''
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ORTModuleTorchModelException(ORTModuleFallbackException):
|
||||
'''Trigger fallback for PyTorch modules related exceptions
|
||||
|
||||
This exception is raised during model validation within ORTModule frontend and is based on
|
||||
checking type(model) over a hardcoded list of incompatible models.
|
||||
'''
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ORTModuleONNXModelException(ORTModuleFallbackException):
|
||||
'''Trigger fallback for ONNX model related exceptions
|
||||
|
||||
This exception is raised during model conversion to ONNX and post-processing validation within ORTModule frontend.
|
||||
'''
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def wrap_exception(new_exception: ORTModuleFallbackException, raised_exception: Exception) -> ORTModuleFallbackException:
|
||||
'''Wraps `raised_exception` exception as cause for the returned `new_exception` exception'''
|
||||
|
||||
exception = None
|
||||
try:
|
||||
raise new_exception(raised_exception) from raised_exception
|
||||
except Exception as e:
|
||||
exception = e
|
||||
return exception
|
||||
|
|
@ -368,7 +368,8 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
keep_initializers_as_inputs=True)
|
||||
except Exception as e:
|
||||
raise wrap_exception(ORTModuleONNXModelException,
|
||||
RuntimeError(f'There was an error while exporting the PyTorch model to ONNX: {e}'))
|
||||
RuntimeError(f'There was an error while exporting the PyTorch model to ONNX: '
|
||||
f'\n\n{_utils.get_exception_as_string(e)}'))
|
||||
exported_model = onnx.load_model_from_string(f.getvalue())
|
||||
|
||||
exported_model = _post_process_after_export(exported_model,
|
||||
|
|
|
|||
|
|
@ -5,13 +5,14 @@
|
|||
|
||||
from onnxruntime.capi.onnxruntime_inference_collection import OrtValue
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
from ._fallback import _FallbackManager, ORTModuleFallbackException, ORTModuleDeviceException, wrap_exception
|
||||
from ._fallback_exceptions import ORTModuleDeviceException, wrap_exception
|
||||
|
||||
import os
|
||||
import copy
|
||||
import inspect
|
||||
import torch
|
||||
from torch.utils.dlpack import from_dlpack, to_dlpack
|
||||
import traceback
|
||||
from typing import List
|
||||
import types
|
||||
import warnings
|
||||
|
|
@ -199,3 +200,11 @@ def parse_os_env_skip_check_flags(env_name, default_skip_check_str):
|
|||
"""Returns a list of SkipChecks as defined by os env variable env_name or default provided"""
|
||||
|
||||
return os.getenv(env_name, default_skip_check_str).split('|')
|
||||
|
||||
def get_exception_as_string(exception):
|
||||
assert isinstance(exception, Exception), 'exception must be a `Exception`'
|
||||
|
||||
try:
|
||||
raise exception
|
||||
except:
|
||||
return traceback.format_exc()
|
||||
|
|
|
|||
Loading…
Reference in a new issue