From ed254c283f1eba7dab8ed1cafc1bbf063bb45463 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Tue, 17 Aug 2021 16:35:42 -0400 Subject: [PATCH] Add support for experimental json config for fallback (#8759) --- .../python/training/ortmodule/_fallback.py | 4 +- .../experimental/json_config/__init__.py | 3 + .../json_config/_load_config_from_json.py | 77 ++++++++++++++----- ...test_ortmodule_experimental_json_config.py | 6 ++ ..._ortmodule_experimental_json_config_1.json | 6 +- ..._ortmodule_experimental_json_config_2.json | 11 ++- 6 files changed, 85 insertions(+), 22 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_fallback.py b/orttraining/orttraining/python/training/ortmodule/_fallback.py index 6ead2080a5..9810e19d93 100644 --- a/orttraining/orttraining/python/training/ortmodule/_fallback.py +++ b/orttraining/orttraining/python/training/ortmodule/_fallback.py @@ -133,7 +133,7 @@ class _FallbackManager(object): # Read retry from environment variable for testing purposes retry = os.getenv('ORTMODULE_FALLBACK_RETRY', str(retry)).lower() in ['true', '1', 'yes'] - self.policy_exception_map = {_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD.value: {ORTModuleFallbackException, + self._policy_exception_map = {_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD.value: {ORTModuleFallbackException, ORTModuleDeviceException, ORTModuleIOError, ORTModuleTorchModelException, @@ -176,7 +176,7 @@ class _FallbackManager(object): def _set_exception(policy: _FallbackPolicy, exception: Exception, log_level: _logger.LogLevel): if policy is not _FallbackPolicy.FALLBACK_DISABLE and \ self.policy.is_set(policy) and \ - (policy.value in self.policy_exception_map and type(exception) in self.policy_exception_map[policy.value]): + (policy.value in self._policy_exception_map and type(exception) in self._policy_exception_map[policy.value]): if log_level <= _logger.LogLevel.WARNING: warnings.warn( diff --git 
a/orttraining/orttraining/python/training/ortmodule/experimental/json_config/__init__.py b/orttraining/orttraining/python/training/ortmodule/experimental/json_config/__init__.py index 0d62a217c7..2f7322394f 100644 --- a/orttraining/orttraining/python/training/ortmodule/experimental/json_config/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/experimental/json_config/__init__.py @@ -2,4 +2,7 @@ # Licensed under the MIT License. # __init__.py +# JSON global constants go here +JSON_PATH_ENVIRONMENT_KEY = "ORTMODULE_JSON_CONFIG_PATH" + from ._load_config_from_json import load_from_json diff --git a/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py b/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py index ce4c1ab5d9..0c6b99bdf9 100644 --- a/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py +++ b/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py @@ -9,11 +9,14 @@ from types import SimpleNamespace from onnxruntime.capi import _pybind_state as C from functools import reduce +from . import JSON_PATH_ENVIRONMENT_KEY +from ..._fallback import _FallbackPolicy from ..._graph_execution_manager import _SkipCheck from ...debug_options import DebugOptions, LogLevel, _SaveOnnxOptions log = logging.getLogger(__name__) + def _load_data_from_json(path): """Loads data from the json file path provided.""" @@ -26,8 +29,9 @@ def _load_data_from_json(path): return data + def _load_propagate_cast_ops(ortmodule_config_accessor, data): - """Load PropagateCastOps from json file onto ORTModule""" + """Loads PropagateCastOps from json file onto ORTModule.""" assert hasattr(data, _load_propagate_cast_ops.loading_key) log.info(f"Found keyword {_load_propagate_cast_ops.loading_key} in json. 
Loading attributes from file.") @@ -51,8 +55,9 @@ def _load_propagate_cast_ops(ortmodule_config_accessor, data): for key, _ in data.PropagateCastOps.__dict__.items(): key_to_function_mapping[key]() + def _load_use_external_gpu_allocator(ortmodule_config_accessor, data): - """Load UseExternalGPUAllocator from json file onto ORTModule""" + """Loads UseExternalGPUAllocator from json file onto ORTModule.""" assert hasattr(data, _load_use_external_gpu_allocator.loading_key) log.info(f"Found keyword {_load_use_external_gpu_allocator.loading_key} in json. Loading attributes from file.") @@ -61,8 +66,9 @@ def _load_use_external_gpu_allocator(ortmodule_config_accessor, data): ortmodule_config_accessor._use_external_gpu_allocator = data.UseExternalGPUAllocator ortmodule_config_accessor._get_torch_gpu_allocator_function_addresses() + def _load_enable_custom_autograd_function(ortmodule_config_accessor, data): - """Load EnableCustomAutogradFunction from json file onto ORTModule""" + """Loads EnableCustomAutogradFunction from json file onto ORTModule.""" assert hasattr(data, _load_enable_custom_autograd_function.loading_key) log.info(f"Found keyword {_load_enable_custom_autograd_function.loading_key} in json. Loading attributes from file.") @@ -70,8 +76,9 @@ def _load_enable_custom_autograd_function(ortmodule_config_accessor, data): assert isinstance(data.EnableCustomAutogradFunction, bool), f"{_load_enable_custom_autograd_function.loading_key} must be a boolean" ortmodule_config_accessor._enable_custom_autograd_function = data.EnableCustomAutogradFunction + def _load_allow_layer_norm_mod_precision(ortmodule_config_accessor, data): - """Load AllowLayerNormModPrecision from json file onto ORTModule""" + """Loads AllowLayerNormModPrecision from json file onto ORTModule.""" assert hasattr(data, _load_allow_layer_norm_mod_precision.loading_key) log.info(f"Found keyword {_load_allow_layer_norm_mod_precision.loading_key} in json. 
Loading attributes from file.") @@ -79,8 +86,9 @@ def _load_allow_layer_norm_mod_precision(ortmodule_config_accessor, data): assert isinstance(data.AllowLayerNormModPrecision, bool), f"{_load_allow_layer_norm_mod_precision.loading_key} must be a boolean" ortmodule_config_accessor._allow_layer_norm_mod_precision = data.AllowLayerNormModPrecision + def _load_enable_grad_acc_optimization(ortmodule_config_accessor, data): - """Load EnableGradAccOptimization from json file onto ORTModule""" + """Loads EnableGradAccOptimization from json file onto ORTModule.""" assert hasattr(data, _load_enable_grad_acc_optimization.loading_key) log.info(f"Found keyword {_load_enable_grad_acc_optimization.loading_key} in json. Loading attributes from file.") @@ -88,8 +96,9 @@ def _load_enable_grad_acc_optimization(ortmodule_config_accessor, data): assert isinstance(data.EnableGradAccOptimization, bool), f"{_load_enable_grad_acc_optimization.loading_key} must be a boolean" ortmodule_config_accessor._enable_grad_acc_optimization = data.EnableGradAccOptimization + def _load_run_symbolic_shape_infer(ortmodule_config_accessor, data): - """Load RunSymbolicShapeInference from json file onto ORTModule""" + """Loads RunSymbolicShapeInference from json file onto ORTModule.""" assert hasattr(data, _load_run_symbolic_shape_infer.loading_key) log.info(f"Found keyword {_load_run_symbolic_shape_infer.loading_key} in json. 
Loading attributes from file.") @@ -97,8 +106,9 @@ def _load_run_symbolic_shape_infer(ortmodule_config_accessor, data): assert isinstance(data.RunSymbolicShapeInference, bool), f"{_load_run_symbolic_shape_infer.loading_key} must be a boolean" ortmodule_config_accessor._run_symbolic_shape_infer = data.RunSymbolicShapeInference + def _load_use_static_shape(ortmodule_config_accessor, data): - """Load UseStaticShape from json file onto ORTModule""" + """Loads UseStaticShape from json file onto ORTModule.""" assert hasattr(data, _load_use_static_shape.loading_key) log.info(f"Found keyword {_load_use_static_shape.loading_key} in json. Loading attributes from file.") @@ -106,33 +116,38 @@ def _load_use_static_shape(ortmodule_config_accessor, data): assert isinstance(data.UseStaticShape, bool), f"{_load_use_static_shape.loading_key} must be a boolean" ortmodule_config_accessor._use_static_shape = data.UseStaticShape + def _load_skip_check(ortmodule_config_accessor, data): - """Load SkipCheck from json file onto ORTModule""" + """Loads SkipCheck from json file onto ORTModule.""" assert hasattr(data, _load_skip_check.loading_key) log.info(f"Found keyword {_load_skip_check.loading_key} in json. Loading attributes from file.") - skip_check = reduce(lambda x, y: x|y, [_SkipCheck[name] for name in data.SkipCheck]) + skip_check = reduce(lambda x, y: x | y, [_SkipCheck[name] for name in data.SkipCheck]) if skip_check.value > 0: ortmodule_config_accessor._skip_check = skip_check + def _load_debug_options(ortmodule_config_accessor, data): - """Load DebugOptions from json file onto ORTModule""" + """Loads DebugOptions from json file onto ORTModule.""" assert hasattr(data, _load_debug_options.loading_key) log.info(f"Found keyword {_load_debug_options.loading_key} in json. 
Loading attributes from file.") log_level = LogLevel.WARNING + def _update_log_level(): nonlocal log_level log_level = LogLevel[data.DebugOptions.LogLevel] save_onnx = False + def _update_save_onnx(): nonlocal save_onnx save_onnx = data.DebugOptions.SaveONNX onnx_prefix = '' + def _update_onnx_prefix(): nonlocal onnx_prefix onnx_prefix = data.DebugOptions.ONNXPrefix @@ -153,8 +168,9 @@ def _load_debug_options(ortmodule_config_accessor, data): debug_options = DebugOptions(log_level=log_level, save_onnx=save_onnx, onnx_prefix=onnx_prefix) ortmodule_config_accessor._debug_options = debug_options + def _load_use_memory_efficient_gradient(ortmodule_config_accessor, data): - """Load UseMemoryEfficientGradient from json file onto ORTModule""" + """Loads UseMemoryEfficientGradient from json file onto ORTModule.""" assert hasattr(data, _load_use_memory_efficient_gradient.loading_key) log.info(f"Found keyword {_load_use_memory_efficient_gradient.loading_key} in json. Loading attributes from file.") @@ -162,6 +178,18 @@ def _load_use_memory_efficient_gradient(ortmodule_config_accessor, data): assert isinstance(data.UseMemoryEfficientGradient, bool), f"{_load_use_memory_efficient_gradient.loading_key} must be a boolean" ortmodule_config_accessor._use_memory_efficient_gradient = data.UseMemoryEfficientGradient + +def _load_fallback_policy(ortmodule_config_accessor, data): + """Loads FallbackPolicy from json file onto ORTModule.""" + + assert hasattr(data, _load_fallback_policy.loading_key) + log.info(f"Found keyword {_load_fallback_policy.loading_key} in json. 
Loading attributes from file.") + + fallback_policy = reduce(lambda x, y: x | y, [_FallbackPolicy[name] for name in data.FallbackPolicy]) + if fallback_policy.value > 0: + ortmodule_config_accessor._fallback_manager.policy = fallback_policy + + def _define_load_function_keys(): """Define static key variables for each loading function""" @@ -175,9 +203,11 @@ def _define_load_function_keys(): _load_skip_check.loading_key = "SkipCheck" _load_debug_options.loading_key = "DebugOptions" _load_use_memory_efficient_gradient.loading_key = "UseMemoryEfficientGradient" + _load_fallback_policy.loading_key = "FallbackPolicy" + def load_from_json(ortmodule, path=None): - """Load config from json file at given path. + """Loads config from json file at given path. Here is the schema that the json file must adhere to: { @@ -205,23 +235,33 @@ def load_from_json(ortmodule, path=None): "SaveONNX": true, "ONNXPrefix": "my_model", "SaveONNXPath": "/path/to/onnx/directory" - } + }, + "FallbackPolicy": # list of strings representing fallback policies (`_FallbackPolicy`s which can be aggregated using |) + [ + "FALLBACK_DISABLE", + "FALLBACK_FORCE_TORCH_FORWARD", + "FALLBACK_UNSUPPORTED_DEVICE", + "FALLBACK_UNSUPPORTED_DATA", + "FALLBACK_UNSUPPORTED_TORCH_MODEL", + "FALLBACK_UNSUPPORTED_ONNX_MODEL", + "FALLBACK_BAD_INITIALIZATION" + ] } Args: ortmodule (:obj:`ORTModule`): ORTModule instance that needs to be configured path (:obj:`str`, optional): Path to json file. Alternatively, users can set the - environement variable ORTMODULE_JSON_CONFIG_PATH to the json config path. In case + environment variable ORTMODULE_JSON_CONFIG_PATH to the json config path. In case both path and environment variable are set, the environment variable gets precedence. """ - global JSON_PATH_ENVIRONMENT_KEY - JSON_PATH_ENVIRONMENT_KEY = "ORTMODULE_JSON_CONFIG_PATH" path = os.getenv(JSON_PATH_ENVIRONMENT_KEY, path) # figure out the json path if path is None: - raise ValueError(f"Path to json is not provided. 
Provide the path through function call or by setting the environment variable {json_path_environment_key}") + raise ValueError( + "Path to json is not provided." + f"Provide the path through function call or setting the environment variable {JSON_PATH_ENVIRONMENT_KEY}") # load the entire json file data = _load_data_from_json(path) @@ -240,7 +280,8 @@ def load_from_json(ortmodule, path=None): _load_use_static_shape.loading_key: _load_use_static_shape, _load_skip_check.loading_key: _load_skip_check, _load_debug_options.loading_key: _load_debug_options, - _load_use_memory_efficient_gradient.loading_key: _load_use_memory_efficient_gradient + _load_use_memory_efficient_gradient.loading_key: _load_use_memory_efficient_gradient, + _load_fallback_policy.loading_key: _load_fallback_policy } for training_mode in [True, False]: diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config.py index 9408c53674..f6a2a8986e 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config.py @@ -69,6 +69,9 @@ def test_load_config_from_json_1(): # test use memory aware gradient builder. assert ort_model_attributes._use_memory_efficient_gradient == False + # test fallback policy + assert ort_model_attributes._fallback_manager.policy.value == 1 + def test_load_config_from_json_2(): device = 'cuda' model = ORTModule(Net().to(device)) @@ -117,3 +120,6 @@ def test_load_config_from_json_2(): # test use memory aware gradient builder. 
assert ort_model_attributes._use_memory_efficient_gradient == True + + # test fallback policy + assert ort_model_attributes._fallback_manager.policy.value == 250 diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config_1.json b/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config_1.json index 4a0be5a2b9..47d1ee4425 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config_1.json +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config_1.json @@ -23,5 +23,9 @@ "SaveONNX": true, "ONNXPrefix": "my_model" }, - "UseMemoryEfficientGradient" : false + "UseMemoryEfficientGradient" : false, + "FallbackPolicy": + [ + "FALLBACK_DISABLE" + ] } diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config_2.json b/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config_2.json index 28f1a45c67..7130dbaf67 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config_2.json +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config_2.json @@ -22,5 +22,14 @@ "SaveONNX": true, "ONNXPrefix": "my_other_model" }, - "UseMemoryEfficientGradient" : true + "UseMemoryEfficientGradient" : true, + "FallbackPolicy": + [ + "FALLBACK_FORCE_TORCH_FORWARD", + "FALLBACK_UNSUPPORTED_DEVICE", + "FALLBACK_UNSUPPORTED_DATA", + "FALLBACK_UNSUPPORTED_TORCH_MODEL", + "FALLBACK_UNSUPPORTED_ONNX_MODEL", + "FALLBACK_BAD_INITIALIZATION" + ] }