mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-17 01:44:45 +00:00
Add support for experimental json config for fallback (#8759)
This commit is contained in:
parent
6ecf626a9c
commit
ed254c283f
6 changed files with 85 additions and 22 deletions
|
|
@ -133,7 +133,7 @@ class _FallbackManager(object):
|
|||
# Read retry from environment variable for testing purposes
|
||||
retry = os.getenv('ORTMODULE_FALLBACK_RETRY', str(retry)).lower() in ['true', '1', 'yes']
|
||||
|
||||
self.policy_exception_map = {_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD.value: {ORTModuleFallbackException,
|
||||
self._policy_exception_map = {_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD.value: {ORTModuleFallbackException,
|
||||
ORTModuleDeviceException,
|
||||
ORTModuleIOError,
|
||||
ORTModuleTorchModelException,
|
||||
|
|
@ -176,7 +176,7 @@ class _FallbackManager(object):
|
|||
def _set_exception(policy: _FallbackPolicy, exception: Exception, log_level: _logger.LogLevel):
|
||||
if policy is not _FallbackPolicy.FALLBACK_DISABLE and \
|
||||
self.policy.is_set(policy) and \
|
||||
(policy.value in self.policy_exception_map and type(exception) in self.policy_exception_map[policy.value]):
|
||||
(policy.value in self._policy_exception_map and type(exception) in self._policy_exception_map[policy.value]):
|
||||
|
||||
if log_level <= _logger.LogLevel.WARNING:
|
||||
warnings.warn(
|
||||
|
|
|
|||
|
|
@ -2,4 +2,7 @@
|
|||
# Licensed under the MIT License.
|
||||
# __init__.py
|
||||
|
||||
# JSON global constants go here
|
||||
JSON_PATH_ENVIRONMENT_KEY = "ORTMODULE_JSON_CONFIG_PATH"
|
||||
|
||||
from ._load_config_from_json import load_from_json
|
||||
|
|
|
|||
|
|
@ -9,11 +9,14 @@ from types import SimpleNamespace
|
|||
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
from functools import reduce
|
||||
from . import JSON_PATH_ENVIRONMENT_KEY
|
||||
from ..._fallback import _FallbackPolicy
|
||||
from ..._graph_execution_manager import _SkipCheck
|
||||
from ...debug_options import DebugOptions, LogLevel, _SaveOnnxOptions
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _load_data_from_json(path):
|
||||
"""Loads data from the json file path provided."""
|
||||
|
||||
|
|
@ -26,8 +29,9 @@ def _load_data_from_json(path):
|
|||
|
||||
return data
|
||||
|
||||
|
||||
def _load_propagate_cast_ops(ortmodule_config_accessor, data):
|
||||
"""Load PropagateCastOps from json file onto ORTModule"""
|
||||
"""Loads PropagateCastOps from json file onto ORTModule."""
|
||||
|
||||
assert hasattr(data, _load_propagate_cast_ops.loading_key)
|
||||
log.info(f"Found keyword {_load_propagate_cast_ops.loading_key} in json. Loading attributes from file.")
|
||||
|
|
@ -51,8 +55,9 @@ def _load_propagate_cast_ops(ortmodule_config_accessor, data):
|
|||
for key, _ in data.PropagateCastOps.__dict__.items():
|
||||
key_to_function_mapping[key]()
|
||||
|
||||
|
||||
def _load_use_external_gpu_allocator(ortmodule_config_accessor, data):
|
||||
"""Load UseExternalGPUAllocator from json file onto ORTModule"""
|
||||
"""Loads UseExternalGPUAllocator from json file onto ORTModule."""
|
||||
|
||||
assert hasattr(data, _load_use_external_gpu_allocator.loading_key)
|
||||
log.info(f"Found keyword {_load_use_external_gpu_allocator.loading_key} in json. Loading attributes from file.")
|
||||
|
|
@ -61,8 +66,9 @@ def _load_use_external_gpu_allocator(ortmodule_config_accessor, data):
|
|||
ortmodule_config_accessor._use_external_gpu_allocator = data.UseExternalGPUAllocator
|
||||
ortmodule_config_accessor._get_torch_gpu_allocator_function_addresses()
|
||||
|
||||
|
||||
def _load_enable_custom_autograd_function(ortmodule_config_accessor, data):
|
||||
"""Load EnableCustomAutogradFunction from json file onto ORTModule"""
|
||||
"""Loads EnableCustomAutogradFunction from json file onto ORTModule."""
|
||||
|
||||
assert hasattr(data, _load_enable_custom_autograd_function.loading_key)
|
||||
log.info(f"Found keyword {_load_enable_custom_autograd_function.loading_key} in json. Loading attributes from file.")
|
||||
|
|
@ -70,8 +76,9 @@ def _load_enable_custom_autograd_function(ortmodule_config_accessor, data):
|
|||
assert isinstance(data.EnableCustomAutogradFunction, bool), f"{_load_enable_custom_autograd_function.loading_key} must be a boolean"
|
||||
ortmodule_config_accessor._enable_custom_autograd_function = data.EnableCustomAutogradFunction
|
||||
|
||||
|
||||
def _load_allow_layer_norm_mod_precision(ortmodule_config_accessor, data):
|
||||
"""Load AllowLayerNormModPrecision from json file onto ORTModule"""
|
||||
"""Loads AllowLayerNormModPrecision from json file onto ORTModule."""
|
||||
|
||||
assert hasattr(data, _load_allow_layer_norm_mod_precision.loading_key)
|
||||
log.info(f"Found keyword {_load_allow_layer_norm_mod_precision.loading_key} in json. Loading attributes from file.")
|
||||
|
|
@ -79,8 +86,9 @@ def _load_allow_layer_norm_mod_precision(ortmodule_config_accessor, data):
|
|||
assert isinstance(data.AllowLayerNormModPrecision, bool), f"{_load_allow_layer_norm_mod_precision.loading_key} must be a boolean"
|
||||
ortmodule_config_accessor._allow_layer_norm_mod_precision = data.AllowLayerNormModPrecision
|
||||
|
||||
|
||||
def _load_enable_grad_acc_optimization(ortmodule_config_accessor, data):
|
||||
"""Load EnableGradAccOptimization from json file onto ORTModule"""
|
||||
"""Loads EnableGradAccOptimization from json file onto ORTModule."""
|
||||
|
||||
assert hasattr(data, _load_enable_grad_acc_optimization.loading_key)
|
||||
log.info(f"Found keyword {_load_enable_grad_acc_optimization.loading_key} in json. Loading attributes from file.")
|
||||
|
|
@ -88,8 +96,9 @@ def _load_enable_grad_acc_optimization(ortmodule_config_accessor, data):
|
|||
assert isinstance(data.EnableGradAccOptimization, bool), f"{_load_enable_grad_acc_optimization.loading_key} must be a boolean"
|
||||
ortmodule_config_accessor._enable_grad_acc_optimization = data.EnableGradAccOptimization
|
||||
|
||||
|
||||
def _load_run_symbolic_shape_infer(ortmodule_config_accessor, data):
|
||||
"""Load RunSymbolicShapeInference from json file onto ORTModule"""
|
||||
"""Loads RunSymbolicShapeInference from json file onto ORTModule."""
|
||||
|
||||
assert hasattr(data, _load_run_symbolic_shape_infer.loading_key)
|
||||
log.info(f"Found keyword {_load_run_symbolic_shape_infer.loading_key} in json. Loading attributes from file.")
|
||||
|
|
@ -97,8 +106,9 @@ def _load_run_symbolic_shape_infer(ortmodule_config_accessor, data):
|
|||
assert isinstance(data.RunSymbolicShapeInference, bool), f"{_load_run_symbolic_shape_infer.loading_key} must be a boolean"
|
||||
ortmodule_config_accessor._run_symbolic_shape_infer = data.RunSymbolicShapeInference
|
||||
|
||||
|
||||
def _load_use_static_shape(ortmodule_config_accessor, data):
|
||||
"""Load UseStaticShape from json file onto ORTModule"""
|
||||
"""Loads UseStaticShape from json file onto ORTModule."""
|
||||
|
||||
assert hasattr(data, _load_use_static_shape.loading_key)
|
||||
log.info(f"Found keyword {_load_use_static_shape.loading_key} in json. Loading attributes from file.")
|
||||
|
|
@ -106,33 +116,38 @@ def _load_use_static_shape(ortmodule_config_accessor, data):
|
|||
assert isinstance(data.UseStaticShape, bool), f"{_load_use_static_shape.loading_key} must be a boolean"
|
||||
ortmodule_config_accessor._use_static_shape = data.UseStaticShape
|
||||
|
||||
|
||||
def _load_skip_check(ortmodule_config_accessor, data):
|
||||
"""Load SkipCheck from json file onto ORTModule"""
|
||||
"""Loads SkipCheck from json file onto ORTModule."""
|
||||
|
||||
assert hasattr(data, _load_skip_check.loading_key)
|
||||
log.info(f"Found keyword {_load_skip_check.loading_key} in json. Loading attributes from file.")
|
||||
|
||||
skip_check = reduce(lambda x, y: x|y, [_SkipCheck[name] for name in data.SkipCheck])
|
||||
skip_check = reduce(lambda x, y: x | y, [_SkipCheck[name] for name in data.SkipCheck])
|
||||
if skip_check.value > 0:
|
||||
ortmodule_config_accessor._skip_check = skip_check
|
||||
|
||||
|
||||
def _load_debug_options(ortmodule_config_accessor, data):
|
||||
"""Load DebugOptions from json file onto ORTModule"""
|
||||
"""Loads DebugOptions from json file onto ORTModule."""
|
||||
|
||||
assert hasattr(data, _load_debug_options.loading_key)
|
||||
log.info(f"Found keyword {_load_debug_options.loading_key} in json. Loading attributes from file.")
|
||||
|
||||
log_level = LogLevel.WARNING
|
||||
|
||||
def _update_log_level():
|
||||
nonlocal log_level
|
||||
log_level = LogLevel[data.DebugOptions.LogLevel]
|
||||
|
||||
save_onnx = False
|
||||
|
||||
def _update_save_onnx():
|
||||
nonlocal save_onnx
|
||||
save_onnx = data.DebugOptions.SaveONNX
|
||||
|
||||
onnx_prefix = ''
|
||||
|
||||
def _update_onnx_prefix():
|
||||
nonlocal onnx_prefix
|
||||
onnx_prefix = data.DebugOptions.ONNXPrefix
|
||||
|
|
@ -153,8 +168,9 @@ def _load_debug_options(ortmodule_config_accessor, data):
|
|||
debug_options = DebugOptions(log_level=log_level, save_onnx=save_onnx, onnx_prefix=onnx_prefix)
|
||||
ortmodule_config_accessor._debug_options = debug_options
|
||||
|
||||
|
||||
def _load_use_memory_efficient_gradient(ortmodule_config_accessor, data):
|
||||
"""Load UseMemoryEfficientGradient from json file onto ORTModule"""
|
||||
"""Loads UseMemoryEfficientGradient from json file onto ORTModule."""
|
||||
|
||||
assert hasattr(data, _load_use_memory_efficient_gradient.loading_key)
|
||||
log.info(f"Found keyword {_load_use_memory_efficient_gradient.loading_key} in json. Loading attributes from file.")
|
||||
|
|
@ -162,6 +178,18 @@ def _load_use_memory_efficient_gradient(ortmodule_config_accessor, data):
|
|||
assert isinstance(data.UseMemoryEfficientGradient, bool), f"{_load_use_memory_efficient_gradient.loading_key} must be a boolean"
|
||||
ortmodule_config_accessor._use_memory_efficient_gradient = data.UseMemoryEfficientGradient
|
||||
|
||||
|
||||
def _load_fallback_policy(ortmodule_config_accessor, data):
|
||||
"""Loads FallbackPolicy from json file onto ORTModule."""
|
||||
|
||||
assert hasattr(data, _load_fallback_policy.loading_key)
|
||||
log.info(f"Found keyword {_load_fallback_policy.loading_key} in json. Loading attributes from file.")
|
||||
|
||||
fallback_policy = reduce(lambda x, y: x | y, [_FallbackPolicy[name] for name in data.FallbackPolicy])
|
||||
if fallback_policy.value > 0:
|
||||
ortmodule_config_accessor._fallback_manager.policy = fallback_policy
|
||||
|
||||
|
||||
def _define_load_function_keys():
|
||||
"""Define static key variables for each loading function"""
|
||||
|
||||
|
|
@ -175,9 +203,11 @@ def _define_load_function_keys():
|
|||
_load_skip_check.loading_key = "SkipCheck"
|
||||
_load_debug_options.loading_key = "DebugOptions"
|
||||
_load_use_memory_efficient_gradient.loading_key = "UseMemoryEfficientGradient"
|
||||
_load_fallback_policy.loading_key = "FallbackPolicy"
|
||||
|
||||
|
||||
def load_from_json(ortmodule, path=None):
|
||||
"""Load config from json file at given path.
|
||||
"""Loads config from json file at given path.
|
||||
|
||||
Here is the schema that the json file must adhere to:
|
||||
{
|
||||
|
|
@ -205,23 +235,33 @@ def load_from_json(ortmodule, path=None):
|
|||
"SaveONNX": true,
|
||||
"ONNXPrefix": "my_model",
|
||||
"SaveONNXPath": "/path/to/onnx/directory"
|
||||
}
|
||||
},
|
||||
"FallbackPolicy": # list of strings naming fallback policies (`_FallbackPolicy` members), which are aggregated using bitwise OR (|)
|
||||
[
|
||||
"FALLBACK_DISABLE",
|
||||
"FALLBACK_FORCE_TORCH_FORWARD",
|
||||
"FALLBACK_UNSUPPORTED_DEVICE",
|
||||
"FALLBACK_UNSUPPORTED_DATA",
|
||||
"FALLBACK_UNSUPPORTED_TORCH_MODEL",
|
||||
"FALLBACK_UNSUPPORTED_ONNX_MODEL",
|
||||
"FALLBACK_BAD_INITIALIZATION"
|
||||
]
|
||||
}
|
||||
|
||||
Args:
|
||||
ortmodule (:obj:`ORTModule`): ORTModule instance that needs to be configured
|
||||
path (:obj:`str`, optional): Path to json file. Alternatively, users can set the
|
||||
environement variable ORTMODULE_JSON_CONFIG_PATH to the json config path. In case
|
||||
environment variable ORTMODULE_JSON_CONFIG_PATH to the json config path. In case
|
||||
both path and environment variable are set, the environment variable gets precedence.
|
||||
"""
|
||||
|
||||
global JSON_PATH_ENVIRONMENT_KEY
|
||||
JSON_PATH_ENVIRONMENT_KEY = "ORTMODULE_JSON_CONFIG_PATH"
|
||||
path = os.getenv(JSON_PATH_ENVIRONMENT_KEY, path)
|
||||
|
||||
# figure out the json path
|
||||
if path is None:
|
||||
raise ValueError(f"Path to json is not provided. Provide the path through function call or by setting the environment variable {json_path_environment_key}")
|
||||
raise ValueError(
|
||||
"Path to json is not provided."
|
||||
f"Provide the path through function call or setting the environment variable {JSON_PATH_ENVIRONMENT_KEY}")
|
||||
|
||||
# load the entire json file
|
||||
data = _load_data_from_json(path)
|
||||
|
|
@ -240,7 +280,8 @@ def load_from_json(ortmodule, path=None):
|
|||
_load_use_static_shape.loading_key: _load_use_static_shape,
|
||||
_load_skip_check.loading_key: _load_skip_check,
|
||||
_load_debug_options.loading_key: _load_debug_options,
|
||||
_load_use_memory_efficient_gradient.loading_key: _load_use_memory_efficient_gradient
|
||||
_load_use_memory_efficient_gradient.loading_key: _load_use_memory_efficient_gradient,
|
||||
_load_fallback_policy.loading_key: _load_fallback_policy
|
||||
}
|
||||
|
||||
for training_mode in [True, False]:
|
||||
|
|
|
|||
|
|
@ -69,6 +69,9 @@ def test_load_config_from_json_1():
|
|||
# test use memory aware gradient builder.
|
||||
assert ort_model_attributes._use_memory_efficient_gradient == False
|
||||
|
||||
# test fallback policy
|
||||
assert ort_model_attributes._fallback_manager.policy.value == 1
|
||||
|
||||
def test_load_config_from_json_2():
|
||||
device = 'cuda'
|
||||
model = ORTModule(Net().to(device))
|
||||
|
|
@ -117,3 +120,6 @@ def test_load_config_from_json_2():
|
|||
|
||||
# test use memory aware gradient builder.
|
||||
assert ort_model_attributes._use_memory_efficient_gradient == True
|
||||
|
||||
# test fallback policy
|
||||
assert ort_model_attributes._fallback_manager.policy.value == 250
|
||||
|
|
|
|||
|
|
@ -23,5 +23,9 @@
|
|||
"SaveONNX": true,
|
||||
"ONNXPrefix": "my_model"
|
||||
},
|
||||
"UseMemoryEfficientGradient" : false
|
||||
"UseMemoryEfficientGradient" : false,
|
||||
"FallbackPolicy":
|
||||
[
|
||||
"FALLBACK_DISABLE"
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,5 +22,14 @@
|
|||
"SaveONNX": true,
|
||||
"ONNXPrefix": "my_other_model"
|
||||
},
|
||||
"UseMemoryEfficientGradient" : true
|
||||
"UseMemoryEfficientGradient" : true,
|
||||
"FallbackPolicy":
|
||||
[
|
||||
"FALLBACK_FORCE_TORCH_FORWARD",
|
||||
"FALLBACK_UNSUPPORTED_DEVICE",
|
||||
"FALLBACK_UNSUPPORTED_DATA",
|
||||
"FALLBACK_UNSUPPORTED_TORCH_MODEL",
|
||||
"FALLBACK_UNSUPPORTED_ONNX_MODEL",
|
||||
"FALLBACK_BAD_INITIALIZATION"
|
||||
]
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue