diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 0b44b5d350..30f020938e 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -12,6 +12,7 @@ from packaging import version from torch.onnx import symbolic_helper from onnxruntime.capi._pybind_state import register_torch_autograd_function +from onnxruntime.training import ortmodule from . import _logger from ._fallback import ORTModuleONNXModelException, wrap_exception @@ -42,6 +43,13 @@ _CAST_PYTORCH_TO_ONNX = { } +def _full_name(klass): + module = klass.__module__ + if module == "builtins": + return klass.__qualname__ # avoid outputs like 'builtins.str' + return module + "." + klass.__qualname__ + + def _pytorch_type_to_onnx(scalar_type: str) -> torch.onnx.TensorProtoDataType: try: return torch.onnx.JitScalarType.from_name(scalar_type).onnx_type() @@ -66,7 +74,10 @@ def _export_pt_1_10(g, n, *args, **kwargs): """ try: name = kwargs["name"] - if name in BANNED_AUTOGRAD_FUNCTION_NAMES: + if name in BANNED_AUTOGRAD_FUNCTION_NAMES and ( + not ortmodule._defined_from_envvar("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", 0) + or name != torch.utils.checkpoint.CheckpointFunction.__name__ + ): raise Exception( f"The autograd.Function {name} should not be exported to ONNX. " "Please replace ORTModule with HierarchalORTModule to only" @@ -238,11 +249,16 @@ def _post_process_after_export(exported_model, enable_custom_autograd_function, def _post_process_enabling_autograd_fallback(exported_model): registered_name_mappings = {} + skipped_autograd_function_list = ortmodule._defined_from_envvar("ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS", "").split( + "," + ) for kclass in torch.autograd.Function.__subclasses__(): + full_qualified_name = _full_name(kclass) + if full_qualified_name in skipped_autograd_function_list: + continue # Collect mapping of class names to full qualified class names. if kclass.__name__ not in registered_name_mappings: registered_name_mappings[kclass.__name__] = [] - full_qualified_name = kclass.__module__ + "." + kclass.__qualname__ registered_name_mappings[kclass.__name__].append(full_qualified_name) # Register function with class names. diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index fd29774240..a18f2347b0 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -267,7 +267,7 @@ class GraphExecutionManager(GraphExecutionInterface): provider_option_map = {"device_id": str(self._device.index)} if not self.is_rocm_pytorch: # Set Conv algo search mode to HEURISTIC by default, which is same as PyTorch's default setting. - conv_algo_search = ortmodule._defined_from_envvar("CONV_ALGO_SEARCH", "HEURISTIC", warn=True) + conv_algo_search = ortmodule._defined_from_envvar("ORTMODULE_CONV_ALGO_SEARCH", "HEURISTIC", warn=True) if conv_algo_search not in ["HEURISTIC", "EXHAUSTIVE"]: warnings.warn("Invalid value of env CONV_ALGO_SEARCH. Must be HEURISTIC or EXHAUSTIVE.") conv_algo_search = "HEURISTIC" diff --git a/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py b/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py index 650a51ecb1..06f024111d 100644 --- a/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py @@ -2,10 +2,13 @@ # Licensed under the MIT License. # debug_options.py import tempfile +import warnings + import torch -from ... import ORTModule + from .... import ortmodule -from ...debug_options import DebugOptions +from ... import ORTModule +from ...debug_options import DebugOptions, LogLevel # nn.Module's in this set are considered exportable to ONNX. # For other nn.Module's, torch.onnx.export is called to check if @@ -15,6 +18,40 @@ _force_exportable_set = set( ) +class _IteratedORTModule(torch.nn.Module): + """ + It's possible that a module instance is called multiple times in a single forward() call with different inputs. + If the number of inputs or the data types are different, the exported graph for a given input set cannot be used + for others. The _IteratedORTModule class is used to handle this case. It creates multiple ORTModule instances + for a given nn.Module instance and uses one of them for each input set. + + NOTE that we assume that for each step run, the running order of different input sets are same. + If it's not this case (e.g., a module is used for checkpointing so that a same input set is used twice), + this class cannot handle it. An ideal way is to maintain a map from different input sets (maybe compute a hash) + to ORTModule instances. + """ + + def __init__(self, module, count, log_level, save_onnx, onnx_prefix): + super(_IteratedORTModule, self).__init__() + assert count > 1 + self._count = count + self._it = count - 1 + self._ortmodules = [] + for idx in range(count): + self._ortmodules.append( + ORTModule( + module, + debug_options=DebugOptions( + log_level=log_level, save_onnx=save_onnx, onnx_prefix=onnx_prefix + "_it" + str(idx) + ), + ) + ) + + def forward(self, *inputs, **kwargs): + self._it = (self._it + 1) % self._count + return self._ortmodules[self._it](*inputs, **kwargs) + + class HierarchicalORTModule(torch.nn.Module): """ Recursively wraps submodules of `module` as ORTModule whenever possible @@ -55,7 +92,9 @@ class HierarchicalORTModule(torch.nn.Module): self._initialized = False super(HierarchicalORTModule, self).__init__() self._original_module = module - self._debug_options = debug_options if debug_options else DebugOptions() + self._log_level = debug_options.logging.log_level if debug_options else LogLevel.ERROR + self._save_onnx = debug_options.save_onnx_models.save if debug_options else False + self._name_prefix = debug_options.save_onnx_models.name_prefix if debug_options else "" def _initialize(self, *args, **kwargs): handle_pool = [] @@ -69,17 +108,19 @@ class HierarchicalORTModule(torch.nn.Module): module_arg_pool[module] = [args] # Recursively hook "record_args" to module and all its sub-modules. - # The function "record_args" records the inputs for each nn.Module, - # and later we will try exporting those nn.Module's with their recorded - # inputs. + # The function "record_args" records the inputs for each nn.Module and later we will try exporting + # those nn.Module's with their recorded inputs. + # NOTE that if a module is not called from forward(), it will fail to be captured by this hook. def recursive_hook(module): + # We cannot skip module in whitelist because it's possible that a module is called multiple times + # so that we still need to know the number of different input sets and use _IteratedORTModule to handle it. handle_pool.append(module.register_forward_pre_hook(record_args)) - for name, sub in module._modules.items(): - if isinstance(sub, torch.nn.ModuleList): - for name1, sub1 in sub._modules.items(): - recursive_hook(sub1) + for _, sub_module in module._modules.items(): + if isinstance(sub_module, torch.nn.ModuleList): + for _, sub_module_item in sub_module._modules.items(): + recursive_hook(sub_module_item) else: - recursive_hook(sub) + recursive_hook(sub_module) exportable_list = {} @@ -87,91 +128,75 @@ class HierarchicalORTModule(torch.nn.Module): # "module" can be wrapped as ORTModule. Otherwise, "module" is # not exportable to ONNX. def check_exportable(module): - # forward functions of classes in _force_exportable_set may not be called - # thus not in module_arg_pool + def try_export(module, args): + try: + with tempfile.NamedTemporaryFile(prefix="sub-module") as temp, torch.no_grad(): + torch.onnx.export( + module, + args, + temp, + opset_version=ortmodule.ONNX_OPSET_VERSION, + do_constant_folding=False, + export_params=False, + keep_initializers_as_inputs=True, + training=torch.onnx.TrainingMode.TRAINING, + ) + except Exception as e: + if self._log_level <= LogLevel.WARNING: + warnings.warn( + f"Failed to export module with type {type(module).__name__}. Error message: {str(e)}", + UserWarning, + ) + return False + return True + if type(module) in _force_exportable_set: exportable_list[module] = True - return True - sub_dict = module._modules - if not sub_dict: - # No sub-module exists, so this module is a leaf - # module in overall model hierarchy. - exportable = True - # Check if this leaf module is exportable. + return + + # It's possible that the model runs a module by calling some other function instead of forward() + # so that the module is not captured by the forward pre-hook. In this case, we will treat it as + # not exportable for now. + module_exportable = module in module_arg_pool + if module_exportable: for args in module_arg_pool[module]: - try: - with tempfile.NamedTemporaryFile(prefix="sub-module") as temp: - torch.onnx.export( - module, - args, - temp, - opset_version=ortmodule.ONNX_OPSET_VERSION, - do_constant_folding=False, - export_params=False, - keep_initializers_as_inputs=True, - training=torch.onnx.TrainingMode.TRAINING, - ) - except Exception as e: - exportable = False + if not try_export(module, args): + module_exportable = False + break + elif self._log_level <= LogLevel.WARNING: + warnings.warn( + f"Module with type {type(module).__name__} is not exportable because it's not in module_arg_pool.", + UserWarning, + ) - exportable_list[module] = exportable - return exportable - else: - sub_exportable = True - for name, sub_module in sub_dict.items(): - if isinstance(sub_module, torch.nn.ModuleList): - for name1, sub_module1 in sub_module._modules.items(): - sub_exportable1 = check_exportable(sub_module1) - sub_exportable = sub_exportable and sub_exportable1 - else: - sub_exportable1 = check_exportable(sub_module) - sub_exportable = sub_exportable and sub_exportable1 + exportable_list[module] = module_exportable + if module_exportable: + return - if sub_exportable is False: - # At least one existing sub-module is not exportable, - # so is the entire module. - exportable_list[module] = sub_exportable - return sub_exportable + sub_module_dict = module._modules + if not sub_module_dict: + # No sub-module exists, so this module is a leaf + return + + for _, sub_module in sub_module_dict.items(): + if isinstance(sub_module, torch.nn.ModuleList): + for _, sub_module_item in sub_module._modules.items(): + check_exportable(sub_module_item) else: - # Now, we know all sub-modules are exportable, so - # we are going to check if the composition of them - # is still exportable at this module level. - module_exportable = True - for args in module_arg_pool[module]: - try: - with tempfile.NamedTemporaryFile(prefix="sub-module") as temp: - torch.onnx.export( - module, - args, - temp, - opset_version=ortmodule.ONNX_OPSET_VERSION, - do_constant_folding=False, - export_params=False, - keep_initializers_as_inputs=True, - training=torch.onnx.TrainingMode.TRAINING, - ) - except Exception as e: - # If this module is not exportable for one arg - # group, we say this module is not exportable. - module_exportable = False - # Already found a broken case. - # No need to check next case. - break - - exportable_list[module] = module_exportable - return exportable_list[module] + check_exportable(sub_module) # Add a hook to record forward's input for all modules. recursive_hook(self._original_module) # Run forward with actual input to record all possible # inputs for all invoked modules. - _ = self._original_module(*args, **kwargs) + with torch.no_grad(): + _ = self._original_module(*args, **kwargs) # We already have "supported_modules" so # we no longer need those hooks in forward pass. - for h in handle_pool: - h.remove() + for handle in handle_pool: + handle.remove() # Try exporter on all module-input pairs. If a module can be exported with # all its recorded inputs, then it's exporable. @@ -184,27 +209,70 @@ class HierarchicalORTModule(torch.nn.Module): # Top-down wrapper to replace nn.Module's with ORTModule. # Note that using bottom-up wrapper may lead to much # ORTModule instances and each ORTModule owns a much smaller graph. - def recursive_wrap(module): - sub_dict = module._modules - for name, sub in sub_dict.items(): - if isinstance(sub, torch.nn.ModuleList): + def recursive_wrap(module, save_onnx=False, onnx_prefix=""): + sub_module_dict = module._modules + for name, sub_module in sub_module_dict.items(): + new_prefix = onnx_prefix + "_" + name + if isinstance(sub_module, torch.nn.ModuleList): # We encounter a list of sub-modules. # Let's wrap them one-by-one. - for name1, sub1 in sub._modules.items(): - if is_supported(sub1): - sub._modules[name1] = ORTModule(sub1, debug_options=self._debug_options) + idx = 0 + for item_name, sub_module_item in sub_module._modules.items(): + # Avoid saving too many graphs. + new_save_onnx = save_onnx and idx == 0 + sub_new_prefix = new_prefix + "_" + item_name + if is_supported(sub_module_item): + if sub_module_item in module_arg_pool and len(module_arg_pool[sub_module_item]) > 1: + sub_module._modules[item_name] = _IteratedORTModule( + sub_module_item, + len(module_arg_pool[sub_module_item]), + self._log_level, + new_save_onnx, + sub_new_prefix, + ) + else: + sub_module._modules[item_name] = ORTModule( + sub_module_item, + debug_options=DebugOptions( + log_level=self._log_level, save_onnx=new_save_onnx, onnx_prefix=sub_new_prefix + ), + ) else: - recursive_wrap(sub1) + recursive_wrap(sub_module_item, new_save_onnx, sub_new_prefix) + idx += 1 else: - if is_supported(sub): + if is_supported(sub_module): # Just wrap it as ORTModule when possible. - sub_dict[name] = ORTModule(sub, debug_options=self._debug_options) + if sub_module in module_arg_pool and len(module_arg_pool[sub_module]) > 1: + sub_module_dict[name] = _IteratedORTModule( + sub_module, len(module_arg_pool[sub_module]), self._log_level, save_onnx, new_prefix + ) + else: + sub_module_dict[name] = ORTModule( + sub_module, + debug_options=DebugOptions( + log_level=self._log_level, save_onnx=save_onnx, onnx_prefix=new_prefix + ), + ) else: # This sub-module is not exportable to ONNX # Let's check its sub-modules. - recursive_wrap(sub) + recursive_wrap(sub_module, save_onnx, new_prefix) - recursive_wrap(self._original_module) + if is_supported(self._original_module): + self._original_module = ORTModule( + self._original_module, + debug_options=DebugOptions( + log_level=self._log_level, save_onnx=self._save_onnx, onnx_prefix=self._name_prefix + ), + ) + else: + recursive_wrap(self._original_module, self._save_onnx, self._name_prefix) + if self._log_level <= LogLevel.WARNING: + warnings.warn( + f"Wrapped module: {str(self._original_module)}.", + UserWarning, + ) self._initialized = True def forward(self, *inputs, **kwargs): diff --git a/orttraining/orttraining/test/python/orttraining_test_hierarchical_ortmodule.py b/orttraining/orttraining/test/python/orttraining_test_hierarchical_ortmodule.py index 7eba32402c..42daff79bd 100644 --- a/orttraining/orttraining/test/python/orttraining_test_hierarchical_ortmodule.py +++ b/orttraining/orttraining/test/python/orttraining_test_hierarchical_ortmodule.py @@ -1,11 +1,13 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from collections.abc import Iterable + import torch import torch.nn as nn import torch.nn.functional as F -from collections.abc import Iterable from torch.utils.checkpoint import checkpoint + from onnxruntime.training.ortmodule import ORTModule from onnxruntime.training.ortmodule.experimental.hierarchical_ortmodule import HierarchicalORTModule @@ -140,11 +142,52 @@ class MainWithMultiModuleOutputs(nn.Module): return y1, y2 +class G(nn.Module): + def __init__(self): + super(G, self).__init__() + self.l1 = nn.Linear(2, 2) + + def forward(self, x): + if x.dtype == torch.float16: + x = x.to(torch.float32) + x = self.l1(x) + return x if x.dtype == torch.float32 else x.to(torch.float16) + + def forward_fp16(self, x): + assert x.dtype == torch.float16 + return self.l1(x.to(torch.float32)).to(torch.float16) + + +class MainWithModuleMultipleCalls(nn.Module): + # Module with mixed precision. + def __init__(self): + super(MainWithModuleMultipleCalls, self).__init__() + self.b = B() + self.g = G() + + def forward(self, x): + x = self.g(x) + x = self.g(x.to(torch.float16)).to(torch.float32) + return self.b(x) + + +class MainWithNonForwardCall(nn.Module): + # Module with mixed precision. + def __init__(self): + super(MainWithNonForwardCall, self).__init__() + self.b = B() + self.g = G() + + def forward(self, x): + x = self.g.forward_fp16(x.to(torch.float16)).to(torch.float32) + return self.b(x) + + def test_hierarchical_ortmodule(): - def count_ortmodule(module): - n = 1 if isinstance(module, ORTModule) else 0 + def count_ortmodule(module, is_iterated=False): + n = 1 if type(module).__name__ == ("_IteratedORTModule" if is_iterated else "ORTModule") else 0 for sub in module._modules.values(): - n = n + count_ortmodule(sub) + n = n + count_ortmodule(sub, is_iterated) return n def call_backward(y): @@ -162,7 +205,7 @@ def test_hierarchical_ortmodule(): else: torch.allclose(y, y_ref) - def trial(module_to_wrap, args, expected_num_ortmodule): + def trial(module_to_wrap, args, expected_num_ortmodule, expected_num_iterated_ortmodule=0): # Run baseline model. m = module_to_wrap @@ -185,6 +228,7 @@ def test_hierarchical_ortmodule(): # Some sub-modules become ORTModule. assert expected_num_ortmodule == count_ortmodule(m) + assert expected_num_iterated_ortmodule == count_ortmodule(m, is_iterated=True) call_allclose(y, y_ref) call_allclose(g, g_ref) @@ -196,6 +240,8 @@ def test_hierarchical_ortmodule(): trial(MainWithMultiModuleOutputs(), [torch.rand(2).requires_grad_()], 10) trial(MainWithNonTensorInput(), [torch.rand(2).requires_grad_(), "reverse"], 6) trial(MainWithNonTensorInput(), [torch.rand(2).requires_grad_(), "normal"], 6) + trial(MainWithModuleMultipleCalls(), [torch.rand(2).requires_grad_()], 2, 1) + trial(MainWithNonForwardCall(), [torch.rand(2).requires_grad_()], 3) if __name__ == "__main__": diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 83bfdead8b..8a1765ca61 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -777,7 +777,7 @@ def test_gradient_correctness_conv1d(use_fp16, input_requires_grad, conv_algo_se return if conv_algo_search is not None: - os.environ["CONV_ALGO_SEARCH"] = conv_algo_search + os.environ["ORTMODULE_CONV_ALGO_SEARCH"] = conv_algo_search device = "cuda" N, seq_len, C_in, C_out, kernel_size = 32, 128, 1536, 1536, 3 @@ -820,7 +820,7 @@ def test_gradient_correctness_conv1d(use_fp16, input_requires_grad, conv_algo_se assert actual_conv_algo_search == expected_conv_algo_search if conv_algo_search is not None: - del os.environ["CONV_ALGO_SEARCH"] + del os.environ["ORTMODULE_CONV_ALGO_SEARCH"] def _run_gradient_correctness_transpose(perm, shape): diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py index 99d7f99f93..6932e9f707 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py @@ -1104,3 +1104,107 @@ def test_non_differentiable_autograd_function(): assert torch.allclose(y_ref, y_train) run() + + +def test_checkpoint_function(): + class A(torch.nn.Module): + # A supported module. + def __init__(self): + super(A, self).__init__() + self.l1 = torch.nn.Linear(2, 2) + + def forward(self, x): + return self.l1(x) + + class B(torch.nn.Module): + # This module is not exportable to ONNX because it + # uses gradient-checkpointing. However, its two sub-module's + # are exportable, so ORTModule should be used to compute them. + def __init__(self): + super(B, self).__init__() + self.l1 = torch.nn.Linear(2, 2) + self.a = A() + + def forward(self, x): + def custom(): + def custom_forward(x_): + return self.a(x_) + + return custom_forward + + z = self.l1(torch.utils.checkpoint.checkpoint(custom(), x)) + return z + + def run(): + m = B().to("cuda") + x = torch.rand((2, 2), dtype=torch.float).to("cuda") + + # Baseline. + y_ref = m(x) + print("Ref:") + print(y_ref) + + os.environ["ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT"] = "1" + + m = ORTModule(m) + + # Inferene mode. + y_infer = m(x) + print(y_infer) + assert torch.allclose(y_ref, y_infer) + + # Training mode. + m.train() + y_train = m(x) + print("Train:") + assert torch.allclose(y_ref, y_train) + + del os.environ["ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT"] + + run() + + +def test_skipped_autograd_function(): + class TestSkippedFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, x): + ctx.save_for_backward(x) + return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors + return None + + class TestSkippedModel(torch.nn.Module): + def __init__(self, output_size): + super(TestSkippedModel, self).__init__() + self.custom_fn = TestSkippedFunction.apply + self.bias = Parameter(torch.empty(output_size, device=torch.cuda.current_device(), dtype=torch.float)) + + with torch.no_grad(): + self.bias.uniform_() + + def forward(self, model_input): + # model_input did not require_grad + out = self.custom_fn(model_input) + return out + self.bias + + output_size = 1024 + + os.environ[ + "ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS" + ] = "orttraining_test_ortmodule_autograd.test_skipped_autograd_function..TestSkippedFunction" + + m = ORTModule(TestSkippedModel(output_size).to("cuda")) + can_run = True + try: + m(torch.randn(output_size, dtype=torch.float, device="cuda")) + except RuntimeError as e: + assert "No forward registered for TestSkippedFunction" in str(e) + can_run = False + + assert not can_run + + del os.environ["ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS"]