Add support for experimental json config for fallback (#8759)

This commit is contained in:
Thiago Crepaldi 2021-08-17 16:35:42 -04:00 committed by GitHub
parent 6ecf626a9c
commit ed254c283f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 85 additions and 22 deletions

View file

@ -133,7 +133,7 @@ class _FallbackManager(object):
# Read retry from environment variable for testing purposes
retry = os.getenv('ORTMODULE_FALLBACK_RETRY', str(retry)).lower() in ['true', '1', 'yes']
self.policy_exception_map = {_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD.value: {ORTModuleFallbackException,
self._policy_exception_map = {_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD.value: {ORTModuleFallbackException,
ORTModuleDeviceException,
ORTModuleIOError,
ORTModuleTorchModelException,
@ -176,7 +176,7 @@ class _FallbackManager(object):
def _set_exception(policy: _FallbackPolicy, exception: Exception, log_level: _logger.LogLevel):
if policy is not _FallbackPolicy.FALLBACK_DISABLE and \
self.policy.is_set(policy) and \
(policy.value in self.policy_exception_map and type(exception) in self.policy_exception_map[policy.value]):
(policy.value in self._policy_exception_map and type(exception) in self._policy_exception_map[policy.value]):
if log_level <= _logger.LogLevel.WARNING:
warnings.warn(

View file

@ -2,4 +2,7 @@
# Licensed under the MIT License.
# __init__.py
# JSON global constants goes here
JSON_PATH_ENVIRONMENT_KEY = "ORTMODULE_JSON_CONFIG_PATH"
from ._load_config_from_json import load_from_json

View file

@ -9,11 +9,14 @@ from types import SimpleNamespace
from onnxruntime.capi import _pybind_state as C
from functools import reduce
from . import JSON_PATH_ENVIRONMENT_KEY
from ..._fallback import _FallbackPolicy
from ..._graph_execution_manager import _SkipCheck
from ...debug_options import DebugOptions, LogLevel, _SaveOnnxOptions
log = logging.getLogger(__name__)
def _load_data_from_json(path):
"""Loads data from the json file path provided."""
@ -26,8 +29,9 @@ def _load_data_from_json(path):
return data
def _load_propagate_cast_ops(ortmodule_config_accessor, data):
"""Load PropagateCastOps from json file onto ORTModule"""
"""Loads PropagateCastOps from json file onto ORTModule."""
assert hasattr(data, _load_propagate_cast_ops.loading_key)
log.info(f"Found keyword {_load_propagate_cast_ops.loading_key} in json. Loading attributes from file.")
@ -51,8 +55,9 @@ def _load_propagate_cast_ops(ortmodule_config_accessor, data):
for key, _ in data.PropagateCastOps.__dict__.items():
key_to_function_mapping[key]()
def _load_use_external_gpu_allocator(ortmodule_config_accessor, data):
"""Load UseExternalGPUAllocator from json file onto ORTModule"""
"""Loads UseExternalGPUAllocator from json file onto ORTModule."""
assert hasattr(data, _load_use_external_gpu_allocator.loading_key)
log.info(f"Found keyword {_load_use_external_gpu_allocator.loading_key} in json. Loading attributes from file.")
@ -61,8 +66,9 @@ def _load_use_external_gpu_allocator(ortmodule_config_accessor, data):
ortmodule_config_accessor._use_external_gpu_allocator = data.UseExternalGPUAllocator
ortmodule_config_accessor._get_torch_gpu_allocator_function_addresses()
def _load_enable_custom_autograd_function(ortmodule_config_accessor, data):
"""Load EnableCustomAutogradFunction from json file onto ORTModule"""
"""Loads EnableCustomAutogradFunction from json file onto ORTModule."""
assert hasattr(data, _load_enable_custom_autograd_function.loading_key)
log.info(f"Found keyword {_load_enable_custom_autograd_function.loading_key} in json. Loading attributes from file.")
@ -70,8 +76,9 @@ def _load_enable_custom_autograd_function(ortmodule_config_accessor, data):
assert isinstance(data.EnableCustomAutogradFunction, bool), f"{_load_enable_custom_autograd_function.loading_key} must be a boolean"
ortmodule_config_accessor._enable_custom_autograd_function = data.EnableCustomAutogradFunction
def _load_allow_layer_norm_mod_precision(ortmodule_config_accessor, data):
"""Load AllowLayerNormModPrecision from json file onto ORTModule"""
"""Loads AllowLayerNormModPrecision from json file onto ORTModule."""
assert hasattr(data, _load_allow_layer_norm_mod_precision.loading_key)
log.info(f"Found keyword {_load_allow_layer_norm_mod_precision.loading_key} in json. Loading attributes from file.")
@ -79,8 +86,9 @@ def _load_allow_layer_norm_mod_precision(ortmodule_config_accessor, data):
assert isinstance(data.AllowLayerNormModPrecision, bool), f"{_load_allow_layer_norm_mod_precision.loading_key} must be a boolean"
ortmodule_config_accessor._allow_layer_norm_mod_precision = data.AllowLayerNormModPrecision
def _load_enable_grad_acc_optimization(ortmodule_config_accessor, data):
"""Load EnableGradAccOptimization from json file onto ORTModule"""
"""Loads EnableGradAccOptimization from json file onto ORTModule."""
assert hasattr(data, _load_enable_grad_acc_optimization.loading_key)
log.info(f"Found keyword {_load_enable_grad_acc_optimization.loading_key} in json. Loading attributes from file.")
@ -88,8 +96,9 @@ def _load_enable_grad_acc_optimization(ortmodule_config_accessor, data):
assert isinstance(data.EnableGradAccOptimization, bool), f"{_load_enable_grad_acc_optimization.loading_key} must be a boolean"
ortmodule_config_accessor._enable_grad_acc_optimization = data.EnableGradAccOptimization
def _load_run_symbolic_shape_infer(ortmodule_config_accessor, data):
"""Load RunSymbolicShapeInference from json file onto ORTModule"""
"""Loads RunSymbolicShapeInference from json file onto ORTModule."""
assert hasattr(data, _load_run_symbolic_shape_infer.loading_key)
log.info(f"Found keyword {_load_run_symbolic_shape_infer.loading_key} in json. Loading attributes from file.")
@ -97,8 +106,9 @@ def _load_run_symbolic_shape_infer(ortmodule_config_accessor, data):
assert isinstance(data.RunSymbolicShapeInference, bool), f"{_load_run_symbolic_shape_infer.loading_key} must be a boolean"
ortmodule_config_accessor._run_symbolic_shape_infer = data.RunSymbolicShapeInference
def _load_use_static_shape(ortmodule_config_accessor, data):
"""Load UseStaticShape from json file onto ORTModule"""
"""Loads UseStaticShape from json file onto ORTModule."""
assert hasattr(data, _load_use_static_shape.loading_key)
log.info(f"Found keyword {_load_use_static_shape.loading_key} in json. Loading attributes from file.")
@ -106,33 +116,38 @@ def _load_use_static_shape(ortmodule_config_accessor, data):
assert isinstance(data.UseStaticShape, bool), f"{_load_use_static_shape.loading_key} must be a boolean"
ortmodule_config_accessor._use_static_shape = data.UseStaticShape
def _load_skip_check(ortmodule_config_accessor, data):
"""Load SkipCheck from json file onto ORTModule"""
"""Loads SkipCheck from json file onto ORTModule."""
assert hasattr(data, _load_skip_check.loading_key)
log.info(f"Found keyword {_load_skip_check.loading_key} in json. Loading attributes from file.")
skip_check = reduce(lambda x, y: x|y, [_SkipCheck[name] for name in data.SkipCheck])
skip_check = reduce(lambda x, y: x | y, [_SkipCheck[name] for name in data.SkipCheck])
if skip_check.value > 0:
ortmodule_config_accessor._skip_check = skip_check
def _load_debug_options(ortmodule_config_accessor, data):
"""Load DebugOptions from json file onto ORTModule"""
"""Loads DebugOptions from json file onto ORTModule."""
assert hasattr(data, _load_debug_options.loading_key)
log.info(f"Found keyword {_load_debug_options.loading_key} in json. Loading attributes from file.")
log_level = LogLevel.WARNING
def _update_log_level():
nonlocal log_level
log_level = LogLevel[data.DebugOptions.LogLevel]
save_onnx = False
def _update_save_onnx():
nonlocal save_onnx
save_onnx = data.DebugOptions.SaveONNX
onnx_prefix = ''
def _update_onnx_prefix():
nonlocal onnx_prefix
onnx_prefix = data.DebugOptions.ONNXPrefix
@ -153,8 +168,9 @@ def _load_debug_options(ortmodule_config_accessor, data):
debug_options = DebugOptions(log_level=log_level, save_onnx=save_onnx, onnx_prefix=onnx_prefix)
ortmodule_config_accessor._debug_options = debug_options
def _load_use_memory_efficient_gradient(ortmodule_config_accessor, data):
"""Load UseMemoryEfficientGradient from json file onto ORTModule"""
"""Loads UseMemoryEfficientGradient from json file onto ORTModule."""
assert hasattr(data, _load_use_memory_efficient_gradient.loading_key)
log.info(f"Found keyword {_load_use_memory_efficient_gradient.loading_key} in json. Loading attributes from file.")
@ -162,6 +178,18 @@ def _load_use_memory_efficient_gradient(ortmodule_config_accessor, data):
assert isinstance(data.UseMemoryEfficientGradient, bool), f"{_load_use_memory_efficient_gradient.loading_key} must be a boolean"
ortmodule_config_accessor._use_memory_efficient_gradient = data.UseMemoryEfficientGradient
def _load_fallback_policy(ortmodule_config_accessor, data):
"""Loads SkipCheck from json file onto ORTModule."""
assert hasattr(data, _load_fallback_policy.loading_key)
log.info(f"Found keyword {_load_fallback_policy.loading_key} in json. Loading attributes from file.")
fallback_policy = reduce(lambda x, y: x | y, [_FallbackPolicy[name] for name in data.FallbackPolicy])
if fallback_policy.value > 0:
ortmodule_config_accessor._fallback_manager.policy = fallback_policy
def _define_load_function_keys():
"""Define static key variables for each loading function"""
@ -175,9 +203,11 @@ def _define_load_function_keys():
_load_skip_check.loading_key = "SkipCheck"
_load_debug_options.loading_key = "DebugOptions"
_load_use_memory_efficient_gradient.loading_key = "UseMemoryEfficientGradient"
_load_fallback_policy.loading_key = "FallbackPolicy"
def load_from_json(ortmodule, path=None):
"""Load config from json file at given path.
"""Loads config from json file at given path.
Here is the schema that the json file must adhere to:
{
@ -205,23 +235,33 @@ def load_from_json(ortmodule, path=None):
"SaveONNX": true,
"ONNXPrefix": "my_model",
"SaveONNXPath": "/path/to/onnx/directory"
}
},
"FallbackPolicy": # list of strings representing fallback policies (`_FallbackPolicy`s which can be aggregated using |
[
"FALLBACK_DISABLE",
"FALLBACK_FORCE_TORCH_FORWARD",
"FALLBACK_UNSUPPORTED_DEVICE",
"FALLBACK_UNSUPPORTED_DATA",
"FALLBACK_UNSUPPORTED_TORCH_MODEL",
"FALLBACK_UNSUPPORTED_ONNX_MODEL",
"FALLBACK_BAD_INITIALIZATION",
],
}
Args:
ortmodule (:obj:`ORTModule`): ORTModule instance that needs to be configured
path (:obj:`str`, optional): Path to json file. Alternatively, users can set the
environement variable ORTMODULE_JSON_CONFIG_PATH to the json config path. In case
environment variable ORTMODULE_JSON_CONFIG_PATH to the json config path. In case
both path and environment variable are set, the environment variable gets precedence.
"""
global JSON_PATH_ENVIRONMENT_KEY
JSON_PATH_ENVIRONMENT_KEY = "ORTMODULE_JSON_CONFIG_PATH"
path = os.getenv(JSON_PATH_ENVIRONMENT_KEY, path)
# figure out the json path
if path is None:
raise ValueError(f"Path to json is not provided. Provide the path through function call or by setting the environment variable {json_path_environment_key}")
raise ValueError(
"Path to json is not provided."
f"Provide the path through function call or setting the environment variable {JSON_PATH_ENVIRONMENT_KEY}")
# load the entire json file
data = _load_data_from_json(path)
@ -240,7 +280,8 @@ def load_from_json(ortmodule, path=None):
_load_use_static_shape.loading_key: _load_use_static_shape,
_load_skip_check.loading_key: _load_skip_check,
_load_debug_options.loading_key: _load_debug_options,
_load_use_memory_efficient_gradient.loading_key: _load_use_memory_efficient_gradient
_load_use_memory_efficient_gradient.loading_key: _load_use_memory_efficient_gradient,
_load_fallback_policy.loading_key: _load_fallback_policy
}
for training_mode in [True, False]:

View file

@ -69,6 +69,9 @@ def test_load_config_from_json_1():
# test use memory aware gradient builder.
assert ort_model_attributes._use_memory_efficient_gradient == False
# test fallback policy
assert ort_model_attributes._fallback_manager.policy.value == 1
def test_load_config_from_json_2():
device = 'cuda'
model = ORTModule(Net().to(device))
@ -117,3 +120,6 @@ def test_load_config_from_json_2():
# test use memory aware gradient builder.
assert ort_model_attributes._use_memory_efficient_gradient == True
# test fallback policy
assert ort_model_attributes._fallback_manager.policy.value == 250

View file

@ -23,5 +23,9 @@
"SaveONNX": true,
"ONNXPrefix": "my_model"
},
"UseMemoryEfficientGradient" : false
"UseMemoryEfficientGradient" : false,
"FallbackPolicy":
[
"FALLBACK_DISABLE"
]
}

View file

@ -22,5 +22,14 @@
"SaveONNX": true,
"ONNXPrefix": "my_other_model"
},
"UseMemoryEfficientGradient" : true
"UseMemoryEfficientGradient" : true,
"FallbackPolicy":
[
"FALLBACK_FORCE_TORCH_FORWARD",
"FALLBACK_UNSUPPORTED_DEVICE",
"FALLBACK_UNSUPPORTED_DATA",
"FALLBACK_UNSUPPORTED_TORCH_MODEL",
"FALLBACK_UNSUPPORTED_ONNX_MODEL",
"FALLBACK_BAD_INITIALIZATION"
]
}