Fixes of Hierarchical ORTModule and ORTModule PythonOp (#13347)

The PR applies some fixes to Hierarchical ORTModule and ORTModule
PythonOp.

For Hierarchical ORTModule:
- Don't wrap module if the caller is to call other function instead of
forward() function
- Support single module instance is call multiple times with different
types of inputs
- Check if module can be warped from top to bottom instead of from
bottom to top

For ORTModule PythonOp:
- Add env variable control to allow using
torch.utils.checkpoint.CheckpointFunction
- Add env variable control to skip register some autograd functions so
that there is no conflict for some models.
This commit is contained in:
Vincent Wang 2022-10-20 08:16:03 +08:00 committed by GitHub
parent 418304743d
commit b6b3f41636
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 337 additions and 103 deletions

View file

@ -12,6 +12,7 @@ from packaging import version
from torch.onnx import symbolic_helper
from onnxruntime.capi._pybind_state import register_torch_autograd_function
from onnxruntime.training import ortmodule
from . import _logger
from ._fallback import ORTModuleONNXModelException, wrap_exception
@ -42,6 +43,13 @@ _CAST_PYTORCH_TO_ONNX = {
}
def _full_name(klass):
module = klass.__module__
if module == "builtins":
return klass.__qualname__ # avoid outputs like 'builtins.str'
return module + "." + klass.__qualname__
def _pytorch_type_to_onnx(scalar_type: str) -> torch.onnx.TensorProtoDataType:
try:
return torch.onnx.JitScalarType.from_name(scalar_type).onnx_type()
@ -66,7 +74,10 @@ def _export_pt_1_10(g, n, *args, **kwargs):
"""
try:
name = kwargs["name"]
if name in BANNED_AUTOGRAD_FUNCTION_NAMES:
if name in BANNED_AUTOGRAD_FUNCTION_NAMES and (
not ortmodule._defined_from_envvar("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", 0)
or name != torch.utils.checkpoint.CheckpointFunction.__name__
):
raise Exception(
f"The autograd.Function {name} should not be exported to ONNX. "
"Please replace ORTModule with HierarchalORTModule to only"
@ -238,11 +249,16 @@ def _post_process_after_export(exported_model, enable_custom_autograd_function,
def _post_process_enabling_autograd_fallback(exported_model):
registered_name_mappings = {}
skipped_autograd_function_list = ortmodule._defined_from_envvar("ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS", "").split(
","
)
for kclass in torch.autograd.Function.__subclasses__():
full_qualified_name = _full_name(kclass)
if full_qualified_name in skipped_autograd_function_list:
continue
# Collect mapping of class names to full qualified class names.
if kclass.__name__ not in registered_name_mappings:
registered_name_mappings[kclass.__name__] = []
full_qualified_name = kclass.__module__ + "." + kclass.__qualname__
registered_name_mappings[kclass.__name__].append(full_qualified_name)
# Register function with class names.

View file

@ -267,7 +267,7 @@ class GraphExecutionManager(GraphExecutionInterface):
provider_option_map = {"device_id": str(self._device.index)}
if not self.is_rocm_pytorch:
# Set Conv algo search mode to HEURISTIC by default, which is same as PyTorch's default setting.
conv_algo_search = ortmodule._defined_from_envvar("CONV_ALGO_SEARCH", "HEURISTIC", warn=True)
conv_algo_search = ortmodule._defined_from_envvar("ORTMODULE_CONV_ALGO_SEARCH", "HEURISTIC", warn=True)
if conv_algo_search not in ["HEURISTIC", "EXHAUSTIVE"]:
warnings.warn("Invalid value of env CONV_ALGO_SEARCH. Must be HEURISTIC or EXHAUSTIVE.")
conv_algo_search = "HEURISTIC"

View file

@ -2,10 +2,13 @@
# Licensed under the MIT License.
# debug_options.py
import tempfile
import warnings
import torch
from ... import ORTModule
from .... import ortmodule
from ...debug_options import DebugOptions
from ... import ORTModule
from ...debug_options import DebugOptions, LogLevel
# nn.Module's in this set are considered exportable to ONNX.
# For other nn.Module's, torch.onnx.export is called to check if
@ -15,6 +18,40 @@ _force_exportable_set = set(
)
class _IteratedORTModule(torch.nn.Module):
"""
It's possible that a module instance is called multiple times in a single forward() call with different inputs.
If the number of inputs or the data types are different, the exported graph for a given input set cannot be used
for others. The _IteratedORTModule class is used to handle this case. It creates multiple ORTModule instances
for a given nn.Module instance and uses one of them for each input set.
NOTE that we assume that for each step run, the running order of different input sets are same.
If it's not this case (e.g., a module is used for checkpointing so that a same input set is used twice),
this class cannot handle it. An ideal way is to maintain a map from different input sets (maybe compute a hash)
to ORTModule instances.
"""
def __init__(self, module, count, log_level, save_onnx, onnx_prefix):
super(_IteratedORTModule, self).__init__()
assert count > 1
self._count = count
self._it = count - 1
self._ortmodules = []
for idx in range(count):
self._ortmodules.append(
ORTModule(
module,
debug_options=DebugOptions(
log_level=log_level, save_onnx=save_onnx, onnx_prefix=onnx_prefix + "_it" + str(idx)
),
)
)
def forward(self, *inputs, **kwargs):
self._it = (self._it + 1) % self._count
return self._ortmodules[self._it](*inputs, **kwargs)
class HierarchicalORTModule(torch.nn.Module):
"""
Recursively wraps submodules of `module` as ORTModule whenever possible
@ -55,7 +92,9 @@ class HierarchicalORTModule(torch.nn.Module):
self._initialized = False
super(HierarchicalORTModule, self).__init__()
self._original_module = module
self._debug_options = debug_options if debug_options else DebugOptions()
self._log_level = debug_options.logging.log_level if debug_options else LogLevel.ERROR
self._save_onnx = debug_options.save_onnx_models.save if debug_options else False
self._name_prefix = debug_options.save_onnx_models.name_prefix if debug_options else ""
def _initialize(self, *args, **kwargs):
handle_pool = []
@ -69,17 +108,19 @@ class HierarchicalORTModule(torch.nn.Module):
module_arg_pool[module] = [args]
# Recursively hook "record_args" to module and all its sub-modules.
# The function "record_args" records the inputs for each nn.Module,
# and later we will try exporting those nn.Module's with their recorded
# inputs.
# The function "record_args" records the inputs for each nn.Module and later we will try exporting
# those nn.Module's with their recorded inputs.
# NOTE that if a module is not called from forward(), it will fail to be captured by this hook.
def recursive_hook(module):
# We cannot skip module in whitelist because it's possible that a module is called multiple times
# so that we still need to know the number of different input sets and use _IteratedORTModule to handle it.
handle_pool.append(module.register_forward_pre_hook(record_args))
for name, sub in module._modules.items():
if isinstance(sub, torch.nn.ModuleList):
for name1, sub1 in sub._modules.items():
recursive_hook(sub1)
for _, sub_module in module._modules.items():
if isinstance(sub_module, torch.nn.ModuleList):
for _, sub_module_item in sub_module._modules.items():
recursive_hook(sub_module_item)
else:
recursive_hook(sub)
recursive_hook(sub_module)
exportable_list = {}
@ -87,91 +128,75 @@ class HierarchicalORTModule(torch.nn.Module):
# "module" can be wrapped as ORTModule. Otherwise, "module" is
# not exportable to ONNX.
def check_exportable(module):
# forward functions of classes in _force_exportable_set may not be called
# thus not in module_arg_pool
def try_export(module, args):
try:
with tempfile.NamedTemporaryFile(prefix="sub-module") as temp, torch.no_grad():
torch.onnx.export(
module,
args,
temp,
opset_version=ortmodule.ONNX_OPSET_VERSION,
do_constant_folding=False,
export_params=False,
keep_initializers_as_inputs=True,
training=torch.onnx.TrainingMode.TRAINING,
)
except Exception as e:
if self._log_level <= LogLevel.WARNING:
warnings.warn(
f"Failed to export module with type {type(module).__name__}. Error message: {str(e)}",
UserWarning,
)
return False
return True
if type(module) in _force_exportable_set:
exportable_list[module] = True
return True
sub_dict = module._modules
if not sub_dict:
# No sub-module exists, so this module is a leaf
# module in overall model hierarchy.
exportable = True
# Check if this leaf module is exportable.
return
# It's possible that the model runs a module by calling some other function instead of forward()
# so that the module is not captured by the forward pre-hook. In this case, we will treat it as
# not exportable for now.
module_exportable = module in module_arg_pool
if module_exportable:
for args in module_arg_pool[module]:
try:
with tempfile.NamedTemporaryFile(prefix="sub-module") as temp:
torch.onnx.export(
module,
args,
temp,
opset_version=ortmodule.ONNX_OPSET_VERSION,
do_constant_folding=False,
export_params=False,
keep_initializers_as_inputs=True,
training=torch.onnx.TrainingMode.TRAINING,
)
except Exception as e:
exportable = False
if not try_export(module, args):
module_exportable = False
break
elif self._log_level <= LogLevel.WARNING:
warnings.warn(
f"Module with type {type(module).__name__} is not exportable because it's not in module_arg_pool.",
UserWarning,
)
exportable_list[module] = exportable
return exportable
else:
sub_exportable = True
for name, sub_module in sub_dict.items():
if isinstance(sub_module, torch.nn.ModuleList):
for name1, sub_module1 in sub_module._modules.items():
sub_exportable1 = check_exportable(sub_module1)
sub_exportable = sub_exportable and sub_exportable1
else:
sub_exportable1 = check_exportable(sub_module)
sub_exportable = sub_exportable and sub_exportable1
exportable_list[module] = module_exportable
if module_exportable:
return
if sub_exportable is False:
# At least one existing sub-module is not exportable,
# so is the entire module.
exportable_list[module] = sub_exportable
return sub_exportable
sub_module_dict = module._modules
if not sub_module_dict:
# No sub-module exists, so this module is a leaf
return
for _, sub_module in sub_module_dict.items():
if isinstance(sub_module, torch.nn.ModuleList):
for _, sub_module_item in sub_module._modules.items():
check_exportable(sub_module_item)
else:
# Now, we know all sub-modules are exportable, so
# we are going to check if the composition of them
# is still exportable at this module level.
module_exportable = True
for args in module_arg_pool[module]:
try:
with tempfile.NamedTemporaryFile(prefix="sub-module") as temp:
torch.onnx.export(
module,
args,
temp,
opset_version=ortmodule.ONNX_OPSET_VERSION,
do_constant_folding=False,
export_params=False,
keep_initializers_as_inputs=True,
training=torch.onnx.TrainingMode.TRAINING,
)
except Exception as e:
# If this module is not exportable for one arg
# group, we say this module is not exportable.
module_exportable = False
# Already found a broken case.
# No need to check next case.
break
exportable_list[module] = module_exportable
return exportable_list[module]
check_exportable(sub_module)
# Add a hook to record forward's input for all modules.
recursive_hook(self._original_module)
# Run forward with actual input to record all possible
# inputs for all invoked modules.
_ = self._original_module(*args, **kwargs)
with torch.no_grad():
_ = self._original_module(*args, **kwargs)
# We already have "supported_modules" so
# we no longer need those hooks in forward pass.
for h in handle_pool:
h.remove()
for handle in handle_pool:
handle.remove()
# Try exporter on all module-input pairs. If a module can be exported with
# all its recorded inputs, then it's exporable.
@ -184,27 +209,70 @@ class HierarchicalORTModule(torch.nn.Module):
# Top-down wrapper to replace nn.Module's with ORTModule.
# Note that using bottom-up wrapper may lead to much
# ORTModule instances and each ORTModule owns a much smaller graph.
def recursive_wrap(module):
sub_dict = module._modules
for name, sub in sub_dict.items():
if isinstance(sub, torch.nn.ModuleList):
def recursive_wrap(module, save_onnx=False, onnx_prefix=""):
sub_module_dict = module._modules
for name, sub_module in sub_module_dict.items():
new_prefix = onnx_prefix + "_" + name
if isinstance(sub_module, torch.nn.ModuleList):
# We encounter a list of sub-modules.
# Let's wrap them one-by-one.
for name1, sub1 in sub._modules.items():
if is_supported(sub1):
sub._modules[name1] = ORTModule(sub1, debug_options=self._debug_options)
idx = 0
for item_name, sub_module_item in sub_module._modules.items():
# Avoid saving too many graphs.
new_save_onnx = save_onnx and idx == 0
sub_new_prefix = new_prefix + "_" + item_name
if is_supported(sub_module_item):
if sub_module_item in module_arg_pool and len(module_arg_pool[sub_module_item]) > 1:
sub_module._modules[item_name] = _IteratedORTModule(
sub_module_item,
len(module_arg_pool[sub_module_item]),
self._log_level,
new_save_onnx,
sub_new_prefix,
)
else:
sub_module._modules[item_name] = ORTModule(
sub_module_item,
debug_options=DebugOptions(
log_level=self._log_level, save_onnx=new_save_onnx, onnx_prefix=sub_new_prefix
),
)
else:
recursive_wrap(sub1)
recursive_wrap(sub_module_item, new_save_onnx, sub_new_prefix)
idx += 1
else:
if is_supported(sub):
if is_supported(sub_module):
# Just wrap it as ORTModule when possible.
sub_dict[name] = ORTModule(sub, debug_options=self._debug_options)
if sub_module in module_arg_pool and len(module_arg_pool[sub_module]) > 1:
sub_module_dict[name] = _IteratedORTModule(
sub_module, len(module_arg_pool[sub_module]), self._log_level, save_onnx, new_prefix
)
else:
sub_module_dict[name] = ORTModule(
sub_module,
debug_options=DebugOptions(
log_level=self._log_level, save_onnx=save_onnx, onnx_prefix=new_prefix
),
)
else:
# This sub-module is not exportable to ONNX
# Let's check its sub-modules.
recursive_wrap(sub)
recursive_wrap(sub_module, save_onnx, new_prefix)
recursive_wrap(self._original_module)
if is_supported(self._original_module):
self._original_module = ORTModule(
self._original_module,
debug_options=DebugOptions(
log_level=self._log_level, save_onnx=self._save_onnx, onnx_prefix=self._name_prefix
),
)
else:
recursive_wrap(self._original_module, self._save_onnx, self._name_prefix)
if self._log_level <= LogLevel.WARNING:
warnings.warn(
f"Wrapped module: {str(self._original_module)}.",
UserWarning,
)
self._initialized = True
def forward(self, *inputs, **kwargs):

View file

@ -1,11 +1,13 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from collections.abc import Iterable
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections.abc import Iterable
from torch.utils.checkpoint import checkpoint
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.ortmodule.experimental.hierarchical_ortmodule import HierarchicalORTModule
@ -140,11 +142,52 @@ class MainWithMultiModuleOutputs(nn.Module):
return y1, y2
class G(nn.Module):
def __init__(self):
super(G, self).__init__()
self.l1 = nn.Linear(2, 2)
def forward(self, x):
if x.dtype == torch.float16:
x = x.to(torch.float32)
x = self.l1(x)
return x if x.dtype == torch.float32 else x.to(torch.float16)
def forward_fp16(self, x):
assert x.dtype == torch.float16
return self.l1(x.to(torch.float32)).to(torch.float16)
class MainWithModuleMultipleCalls(nn.Module):
# Module with mixed precision.
def __init__(self):
super(MainWithModuleMultipleCalls, self).__init__()
self.b = B()
self.g = G()
def forward(self, x):
x = self.g(x)
x = self.g(x.to(torch.float16)).to(torch.float32)
return self.b(x)
class MainWithNonForwardCall(nn.Module):
# Module with mixed precision.
def __init__(self):
super(MainWithNonForwardCall, self).__init__()
self.b = B()
self.g = G()
def forward(self, x):
x = self.g.forward_fp16(x.to(torch.float16)).to(torch.float32)
return self.b(x)
def test_hierarchical_ortmodule():
def count_ortmodule(module):
n = 1 if isinstance(module, ORTModule) else 0
def count_ortmodule(module, is_iterated=False):
n = 1 if type(module).__name__ == ("_IteratedORTModule" if is_iterated else "ORTModule") else 0
for sub in module._modules.values():
n = n + count_ortmodule(sub)
n = n + count_ortmodule(sub, is_iterated)
return n
def call_backward(y):
@ -162,7 +205,7 @@ def test_hierarchical_ortmodule():
else:
torch.allclose(y, y_ref)
def trial(module_to_wrap, args, expected_num_ortmodule):
def trial(module_to_wrap, args, expected_num_ortmodule, expected_num_iterated_ortmodule=0):
# Run baseline model.
m = module_to_wrap
@ -185,6 +228,7 @@ def test_hierarchical_ortmodule():
# Some sub-modules become ORTModule.
assert expected_num_ortmodule == count_ortmodule(m)
assert expected_num_iterated_ortmodule == count_ortmodule(m, is_iterated=True)
call_allclose(y, y_ref)
call_allclose(g, g_ref)
@ -196,6 +240,8 @@ def test_hierarchical_ortmodule():
trial(MainWithMultiModuleOutputs(), [torch.rand(2).requires_grad_()], 10)
trial(MainWithNonTensorInput(), [torch.rand(2).requires_grad_(), "reverse"], 6)
trial(MainWithNonTensorInput(), [torch.rand(2).requires_grad_(), "normal"], 6)
trial(MainWithModuleMultipleCalls(), [torch.rand(2).requires_grad_()], 2, 1)
trial(MainWithNonForwardCall(), [torch.rand(2).requires_grad_()], 3)
if __name__ == "__main__":

View file

@ -777,7 +777,7 @@ def test_gradient_correctness_conv1d(use_fp16, input_requires_grad, conv_algo_se
return
if conv_algo_search is not None:
os.environ["CONV_ALGO_SEARCH"] = conv_algo_search
os.environ["ORTMODULE_CONV_ALGO_SEARCH"] = conv_algo_search
device = "cuda"
N, seq_len, C_in, C_out, kernel_size = 32, 128, 1536, 1536, 3
@ -820,7 +820,7 @@ def test_gradient_correctness_conv1d(use_fp16, input_requires_grad, conv_algo_se
assert actual_conv_algo_search == expected_conv_algo_search
if conv_algo_search is not None:
del os.environ["CONV_ALGO_SEARCH"]
del os.environ["ORTMODULE_CONV_ALGO_SEARCH"]
def _run_gradient_correctness_transpose(perm, shape):

View file

@ -1104,3 +1104,107 @@ def test_non_differentiable_autograd_function():
assert torch.allclose(y_ref, y_train)
run()
def test_checkpoint_function():
class A(torch.nn.Module):
# A supported module.
def __init__(self):
super(A, self).__init__()
self.l1 = torch.nn.Linear(2, 2)
def forward(self, x):
return self.l1(x)
class B(torch.nn.Module):
# This module is not exportable to ONNX because it
# uses gradient-checkpointing. However, its two sub-module's
# are exportable, so ORTModule should be used to compute them.
def __init__(self):
super(B, self).__init__()
self.l1 = torch.nn.Linear(2, 2)
self.a = A()
def forward(self, x):
def custom():
def custom_forward(x_):
return self.a(x_)
return custom_forward
z = self.l1(torch.utils.checkpoint.checkpoint(custom(), x))
return z
def run():
m = B().to("cuda")
x = torch.rand((2, 2), dtype=torch.float).to("cuda")
# Baseline.
y_ref = m(x)
print("Ref:")
print(y_ref)
os.environ["ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT"] = "1"
m = ORTModule(m)
# Inferene mode.
y_infer = m(x)
print(y_infer)
assert torch.allclose(y_ref, y_infer)
# Training mode.
m.train()
y_train = m(x)
print("Train:")
assert torch.allclose(y_ref, y_train)
del os.environ["ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT"]
run()
def test_skipped_autograd_function():
class TestSkippedFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, x):
ctx.save_for_backward(x)
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors
return None
class TestSkippedModel(torch.nn.Module):
def __init__(self, output_size):
super(TestSkippedModel, self).__init__()
self.custom_fn = TestSkippedFunction.apply
self.bias = Parameter(torch.empty(output_size, device=torch.cuda.current_device(), dtype=torch.float))
with torch.no_grad():
self.bias.uniform_()
def forward(self, model_input):
# model_input did not require_grad
out = self.custom_fn(model_input)
return out + self.bias
output_size = 1024
os.environ[
"ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS"
] = "orttraining_test_ortmodule_autograd.test_skipped_autograd_function.<locals>.TestSkippedFunction"
m = ORTModule(TestSkippedModel(output_size).to("cuda"))
can_run = True
try:
m(torch.randn(output_size, dtype=torch.float, device="cuda"))
except RuntimeError as e:
assert "No forward registered for TestSkippedFunction" in str(e)
can_run = False
assert not can_run
del os.environ["ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS"]