mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
Fixes of Hierarchical ORTModule and ORTModule PythonOp (#13347)
The PR applies some fixes to Hierarchical ORTModule and ORTModule PythonOp. For Hierarchical ORTModule: - Don't wrap a module if the caller invokes a function other than forward() - Support a single module instance being called multiple times with different types of inputs - Check whether a module can be wrapped from top to bottom instead of from bottom to top For ORTModule PythonOp: - Add an environment variable to allow using torch.utils.checkpoint.CheckpointFunction - Add an environment variable to skip registering some autograd functions so that there is no conflict for some models.
This commit is contained in:
parent
418304743d
commit
b6b3f41636
6 changed files with 337 additions and 103 deletions
|
|
@ -12,6 +12,7 @@ from packaging import version
|
|||
from torch.onnx import symbolic_helper
|
||||
|
||||
from onnxruntime.capi._pybind_state import register_torch_autograd_function
|
||||
from onnxruntime.training import ortmodule
|
||||
|
||||
from . import _logger
|
||||
from ._fallback import ORTModuleONNXModelException, wrap_exception
|
||||
|
|
@ -42,6 +43,13 @@ _CAST_PYTORCH_TO_ONNX = {
|
|||
}
|
||||
|
||||
|
||||
def _full_name(klass):
|
||||
module = klass.__module__
|
||||
if module == "builtins":
|
||||
return klass.__qualname__ # avoid outputs like 'builtins.str'
|
||||
return module + "." + klass.__qualname__
|
||||
|
||||
|
||||
def _pytorch_type_to_onnx(scalar_type: str) -> torch.onnx.TensorProtoDataType:
|
||||
try:
|
||||
return torch.onnx.JitScalarType.from_name(scalar_type).onnx_type()
|
||||
|
|
@ -66,7 +74,10 @@ def _export_pt_1_10(g, n, *args, **kwargs):
|
|||
"""
|
||||
try:
|
||||
name = kwargs["name"]
|
||||
if name in BANNED_AUTOGRAD_FUNCTION_NAMES:
|
||||
if name in BANNED_AUTOGRAD_FUNCTION_NAMES and (
|
||||
not ortmodule._defined_from_envvar("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", 0)
|
||||
or name != torch.utils.checkpoint.CheckpointFunction.__name__
|
||||
):
|
||||
raise Exception(
|
||||
f"The autograd.Function {name} should not be exported to ONNX. "
|
||||
"Please replace ORTModule with HierarchalORTModule to only"
|
||||
|
|
@ -238,11 +249,16 @@ def _post_process_after_export(exported_model, enable_custom_autograd_function,
|
|||
|
||||
def _post_process_enabling_autograd_fallback(exported_model):
|
||||
registered_name_mappings = {}
|
||||
skipped_autograd_function_list = ortmodule._defined_from_envvar("ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS", "").split(
|
||||
","
|
||||
)
|
||||
for kclass in torch.autograd.Function.__subclasses__():
|
||||
full_qualified_name = _full_name(kclass)
|
||||
if full_qualified_name in skipped_autograd_function_list:
|
||||
continue
|
||||
# Collect mapping of class names to full qualified class names.
|
||||
if kclass.__name__ not in registered_name_mappings:
|
||||
registered_name_mappings[kclass.__name__] = []
|
||||
full_qualified_name = kclass.__module__ + "." + kclass.__qualname__
|
||||
registered_name_mappings[kclass.__name__].append(full_qualified_name)
|
||||
|
||||
# Register function with class names.
|
||||
|
|
|
|||
|
|
@ -267,7 +267,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
provider_option_map = {"device_id": str(self._device.index)}
|
||||
if not self.is_rocm_pytorch:
|
||||
# Set Conv algo search mode to HEURISTIC by default, which is same as PyTorch's default setting.
|
||||
conv_algo_search = ortmodule._defined_from_envvar("CONV_ALGO_SEARCH", "HEURISTIC", warn=True)
|
||||
conv_algo_search = ortmodule._defined_from_envvar("ORTMODULE_CONV_ALGO_SEARCH", "HEURISTIC", warn=True)
|
||||
if conv_algo_search not in ["HEURISTIC", "EXHAUSTIVE"]:
|
||||
warnings.warn("Invalid value of env CONV_ALGO_SEARCH. Must be HEURISTIC or EXHAUSTIVE.")
|
||||
conv_algo_search = "HEURISTIC"
|
||||
|
|
|
|||
|
|
@ -2,10 +2,13 @@
|
|||
# Licensed under the MIT License.
|
||||
# debug_options.py
|
||||
import tempfile
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from ... import ORTModule
|
||||
|
||||
from .... import ortmodule
|
||||
from ...debug_options import DebugOptions
|
||||
from ... import ORTModule
|
||||
from ...debug_options import DebugOptions, LogLevel
|
||||
|
||||
# nn.Module's in this set are considered exportable to ONNX.
|
||||
# For other nn.Module's, torch.onnx.export is called to check if
|
||||
|
|
@ -15,6 +18,40 @@ _force_exportable_set = set(
|
|||
)
|
||||
|
||||
|
||||
class _IteratedORTModule(torch.nn.Module):
    """
    Round-robin dispatcher over several ORTModule wrappers of one nn.Module instance.

    A module instance may be called multiple times in a single forward() call with different inputs.
    If the number of inputs or their data types differ, the graph exported for one input set cannot be
    reused for the others. This class creates `count` ORTModule instances for the given nn.Module and
    uses one of them, in turn, for each call.

    NOTE that we assume that, for every step, the different input sets arrive in the same order.
    If this is not the case (e.g., the module is used for checkpointing so that the same input set is
    seen twice), this class cannot handle it. An ideal solution would maintain a map from input sets
    (maybe via a hash) to ORTModule instances.

    Args:
        module: the nn.Module instance to wrap; every ORTModule below wraps this same instance.
        count: number of ORTModule wrappers to create; must be greater than 1.
        log_level: forwarded to DebugOptions(log_level=...).
        save_onnx: forwarded to DebugOptions(save_onnx=...).
        onnx_prefix: base prefix; wrapper ``idx`` uses ``onnx_prefix + "_it" + str(idx)``.
    """

    def __init__(self, module, count, log_level, save_onnx, onnx_prefix):
        super().__init__()
        assert count > 1
        self._count = count
        # Start on the last slot so that the first forward() call wraps around to index 0.
        self._it = count - 1
        # NOTE(review): this is a plain Python list rather than nn.ModuleList, so the wrappers are
        # not registered as sub-modules of this module — presumably intentional; confirm.
        self._ortmodules = [
            ORTModule(
                module,
                debug_options=DebugOptions(
                    log_level=log_level, save_onnx=save_onnx, onnx_prefix=onnx_prefix + "_it" + str(idx)
                ),
            )
            for idx in range(count)
        ]

    def forward(self, *inputs, **kwargs):
        """Advance to the next wrapper (cyclically) and run it on the given inputs."""
        self._it = (self._it + 1) % self._count
        return self._ortmodules[self._it](*inputs, **kwargs)
|
||||
|
||||
|
||||
class HierarchicalORTModule(torch.nn.Module):
|
||||
"""
|
||||
Recursively wraps submodules of `module` as ORTModule whenever possible
|
||||
|
|
@ -55,7 +92,9 @@ class HierarchicalORTModule(torch.nn.Module):
|
|||
self._initialized = False
|
||||
super(HierarchicalORTModule, self).__init__()
|
||||
self._original_module = module
|
||||
self._debug_options = debug_options if debug_options else DebugOptions()
|
||||
self._log_level = debug_options.logging.log_level if debug_options else LogLevel.ERROR
|
||||
self._save_onnx = debug_options.save_onnx_models.save if debug_options else False
|
||||
self._name_prefix = debug_options.save_onnx_models.name_prefix if debug_options else ""
|
||||
|
||||
def _initialize(self, *args, **kwargs):
|
||||
handle_pool = []
|
||||
|
|
@ -69,17 +108,19 @@ class HierarchicalORTModule(torch.nn.Module):
|
|||
module_arg_pool[module] = [args]
|
||||
|
||||
# Recursively hook "record_args" to module and all its sub-modules.
|
||||
# The function "record_args" records the inputs for each nn.Module,
|
||||
# and later we will try exporting those nn.Module's with their recorded
|
||||
# inputs.
|
||||
# The function "record_args" records the inputs for each nn.Module and later we will try exporting
|
||||
# those nn.Module's with their recorded inputs.
|
||||
# NOTE that if a module is not called from forward(), it will fail to be captured by this hook.
|
||||
def recursive_hook(module):
|
||||
# We cannot skip module in whitelist because it's possible that a module is called multiple times
|
||||
# so that we still need to know the number of different input sets and use _IteratedORTModule to handle it.
|
||||
handle_pool.append(module.register_forward_pre_hook(record_args))
|
||||
for name, sub in module._modules.items():
|
||||
if isinstance(sub, torch.nn.ModuleList):
|
||||
for name1, sub1 in sub._modules.items():
|
||||
recursive_hook(sub1)
|
||||
for _, sub_module in module._modules.items():
|
||||
if isinstance(sub_module, torch.nn.ModuleList):
|
||||
for _, sub_module_item in sub_module._modules.items():
|
||||
recursive_hook(sub_module_item)
|
||||
else:
|
||||
recursive_hook(sub)
|
||||
recursive_hook(sub_module)
|
||||
|
||||
exportable_list = {}
|
||||
|
||||
|
|
@ -87,91 +128,75 @@ class HierarchicalORTModule(torch.nn.Module):
|
|||
# "module" can be wrapped as ORTModule. Otherwise, "module" is
|
||||
# not exportable to ONNX.
|
||||
def check_exportable(module):
|
||||
# forward functions of classes in _force_exportable_set may not be called
|
||||
# thus not in module_arg_pool
|
||||
def try_export(module, args):
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(prefix="sub-module") as temp, torch.no_grad():
|
||||
torch.onnx.export(
|
||||
module,
|
||||
args,
|
||||
temp,
|
||||
opset_version=ortmodule.ONNX_OPSET_VERSION,
|
||||
do_constant_folding=False,
|
||||
export_params=False,
|
||||
keep_initializers_as_inputs=True,
|
||||
training=torch.onnx.TrainingMode.TRAINING,
|
||||
)
|
||||
except Exception as e:
|
||||
if self._log_level <= LogLevel.WARNING:
|
||||
warnings.warn(
|
||||
f"Failed to export module with type {type(module).__name__}. Error message: {str(e)}",
|
||||
UserWarning,
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
if type(module) in _force_exportable_set:
|
||||
exportable_list[module] = True
|
||||
return True
|
||||
sub_dict = module._modules
|
||||
if not sub_dict:
|
||||
# No sub-module exists, so this module is a leaf
|
||||
# module in overall model hierarchy.
|
||||
exportable = True
|
||||
# Check if this leaf module is exportable.
|
||||
return
|
||||
|
||||
# It's possible that the model runs a module by calling some other function instead of forward()
|
||||
# so that the module is not captured by the forward pre-hook. In this case, we will treat it as
|
||||
# not exportable for now.
|
||||
module_exportable = module in module_arg_pool
|
||||
if module_exportable:
|
||||
for args in module_arg_pool[module]:
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(prefix="sub-module") as temp:
|
||||
torch.onnx.export(
|
||||
module,
|
||||
args,
|
||||
temp,
|
||||
opset_version=ortmodule.ONNX_OPSET_VERSION,
|
||||
do_constant_folding=False,
|
||||
export_params=False,
|
||||
keep_initializers_as_inputs=True,
|
||||
training=torch.onnx.TrainingMode.TRAINING,
|
||||
)
|
||||
except Exception as e:
|
||||
exportable = False
|
||||
if not try_export(module, args):
|
||||
module_exportable = False
|
||||
break
|
||||
elif self._log_level <= LogLevel.WARNING:
|
||||
warnings.warn(
|
||||
f"Module with type {type(module).__name__} is not exportable because it's not in module_arg_pool.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
exportable_list[module] = exportable
|
||||
return exportable
|
||||
else:
|
||||
sub_exportable = True
|
||||
for name, sub_module in sub_dict.items():
|
||||
if isinstance(sub_module, torch.nn.ModuleList):
|
||||
for name1, sub_module1 in sub_module._modules.items():
|
||||
sub_exportable1 = check_exportable(sub_module1)
|
||||
sub_exportable = sub_exportable and sub_exportable1
|
||||
else:
|
||||
sub_exportable1 = check_exportable(sub_module)
|
||||
sub_exportable = sub_exportable and sub_exportable1
|
||||
exportable_list[module] = module_exportable
|
||||
if module_exportable:
|
||||
return
|
||||
|
||||
if sub_exportable is False:
|
||||
# At least one existing sub-module is not exportable,
|
||||
# so is the entire module.
|
||||
exportable_list[module] = sub_exportable
|
||||
return sub_exportable
|
||||
sub_module_dict = module._modules
|
||||
if not sub_module_dict:
|
||||
# No sub-module exists, so this module is a leaf
|
||||
return
|
||||
|
||||
for _, sub_module in sub_module_dict.items():
|
||||
if isinstance(sub_module, torch.nn.ModuleList):
|
||||
for _, sub_module_item in sub_module._modules.items():
|
||||
check_exportable(sub_module_item)
|
||||
else:
|
||||
# Now, we know all sub-modules are exportable, so
|
||||
# we are going to check if the composition of them
|
||||
# is still exportable at this module level.
|
||||
module_exportable = True
|
||||
for args in module_arg_pool[module]:
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(prefix="sub-module") as temp:
|
||||
torch.onnx.export(
|
||||
module,
|
||||
args,
|
||||
temp,
|
||||
opset_version=ortmodule.ONNX_OPSET_VERSION,
|
||||
do_constant_folding=False,
|
||||
export_params=False,
|
||||
keep_initializers_as_inputs=True,
|
||||
training=torch.onnx.TrainingMode.TRAINING,
|
||||
)
|
||||
except Exception as e:
|
||||
# If this module is not exportable for one arg
|
||||
# group, we say this module is not exportable.
|
||||
module_exportable = False
|
||||
# Already found a broken case.
|
||||
# No need to check next case.
|
||||
break
|
||||
|
||||
exportable_list[module] = module_exportable
|
||||
return exportable_list[module]
|
||||
check_exportable(sub_module)
|
||||
|
||||
# Add a hook to record forward's input for all modules.
|
||||
recursive_hook(self._original_module)
|
||||
|
||||
# Run forward with actual input to record all possible
|
||||
# inputs for all invoked modules.
|
||||
_ = self._original_module(*args, **kwargs)
|
||||
with torch.no_grad():
|
||||
_ = self._original_module(*args, **kwargs)
|
||||
|
||||
# We already have "supported_modules" so
|
||||
# we no longer need those hooks in forward pass.
|
||||
for h in handle_pool:
|
||||
h.remove()
|
||||
for handle in handle_pool:
|
||||
handle.remove()
|
||||
|
||||
# Try exporter on all module-input pairs. If a module can be exported with
|
||||
# all its recorded inputs, then it's exportable.
|
||||
|
|
@ -184,27 +209,70 @@ class HierarchicalORTModule(torch.nn.Module):
|
|||
# Top-down wrapper to replace nn.Module's with ORTModule.
|
||||
# Note that using a bottom-up wrapper may lead to many
|
||||
# ORTModule instances and each ORTModule owns a much smaller graph.
|
||||
def recursive_wrap(module):
|
||||
sub_dict = module._modules
|
||||
for name, sub in sub_dict.items():
|
||||
if isinstance(sub, torch.nn.ModuleList):
|
||||
def recursive_wrap(module, save_onnx=False, onnx_prefix=""):
|
||||
sub_module_dict = module._modules
|
||||
for name, sub_module in sub_module_dict.items():
|
||||
new_prefix = onnx_prefix + "_" + name
|
||||
if isinstance(sub_module, torch.nn.ModuleList):
|
||||
# We encounter a list of sub-modules.
|
||||
# Let's wrap them one-by-one.
|
||||
for name1, sub1 in sub._modules.items():
|
||||
if is_supported(sub1):
|
||||
sub._modules[name1] = ORTModule(sub1, debug_options=self._debug_options)
|
||||
idx = 0
|
||||
for item_name, sub_module_item in sub_module._modules.items():
|
||||
# Avoid saving too many graphs.
|
||||
new_save_onnx = save_onnx and idx == 0
|
||||
sub_new_prefix = new_prefix + "_" + item_name
|
||||
if is_supported(sub_module_item):
|
||||
if sub_module_item in module_arg_pool and len(module_arg_pool[sub_module_item]) > 1:
|
||||
sub_module._modules[item_name] = _IteratedORTModule(
|
||||
sub_module_item,
|
||||
len(module_arg_pool[sub_module_item]),
|
||||
self._log_level,
|
||||
new_save_onnx,
|
||||
sub_new_prefix,
|
||||
)
|
||||
else:
|
||||
sub_module._modules[item_name] = ORTModule(
|
||||
sub_module_item,
|
||||
debug_options=DebugOptions(
|
||||
log_level=self._log_level, save_onnx=new_save_onnx, onnx_prefix=sub_new_prefix
|
||||
),
|
||||
)
|
||||
else:
|
||||
recursive_wrap(sub1)
|
||||
recursive_wrap(sub_module_item, new_save_onnx, sub_new_prefix)
|
||||
idx += 1
|
||||
else:
|
||||
if is_supported(sub):
|
||||
if is_supported(sub_module):
|
||||
# Just wrap it as ORTModule when possible.
|
||||
sub_dict[name] = ORTModule(sub, debug_options=self._debug_options)
|
||||
if sub_module in module_arg_pool and len(module_arg_pool[sub_module]) > 1:
|
||||
sub_module_dict[name] = _IteratedORTModule(
|
||||
sub_module, len(module_arg_pool[sub_module]), self._log_level, save_onnx, new_prefix
|
||||
)
|
||||
else:
|
||||
sub_module_dict[name] = ORTModule(
|
||||
sub_module,
|
||||
debug_options=DebugOptions(
|
||||
log_level=self._log_level, save_onnx=save_onnx, onnx_prefix=new_prefix
|
||||
),
|
||||
)
|
||||
else:
|
||||
# This sub-module is not exportable to ONNX
|
||||
# Let's check its sub-modules.
|
||||
recursive_wrap(sub)
|
||||
recursive_wrap(sub_module, save_onnx, new_prefix)
|
||||
|
||||
recursive_wrap(self._original_module)
|
||||
if is_supported(self._original_module):
|
||||
self._original_module = ORTModule(
|
||||
self._original_module,
|
||||
debug_options=DebugOptions(
|
||||
log_level=self._log_level, save_onnx=self._save_onnx, onnx_prefix=self._name_prefix
|
||||
),
|
||||
)
|
||||
else:
|
||||
recursive_wrap(self._original_module, self._save_onnx, self._name_prefix)
|
||||
if self._log_level <= LogLevel.WARNING:
|
||||
warnings.warn(
|
||||
f"Wrapped module: {str(self._original_module)}.",
|
||||
UserWarning,
|
||||
)
|
||||
self._initialized = True
|
||||
|
||||
def forward(self, *inputs, **kwargs):
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from collections.abc import Iterable
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from onnxruntime.training.ortmodule import ORTModule
|
||||
from onnxruntime.training.ortmodule.experimental.hierarchical_ortmodule import HierarchicalORTModule
|
||||
|
||||
|
|
@ -140,11 +142,52 @@ class MainWithMultiModuleOutputs(nn.Module):
|
|||
return y1, y2
|
||||
|
||||
|
||||
class G(nn.Module):
    """Single 2x2 linear layer that accepts float16 or float32 input.

    ``forward`` upcasts float16 input to float32 before applying the layer.
    ``forward_fp16`` is a second entry point that only accepts float16 input
    and casts the result back to float16.
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(2, 2)

    def forward(self, x):
        # The linear layer is applied in float32, so half-precision input is upcast first.
        if x.dtype == torch.float16:
            x = x.to(torch.float32)
        x = self.l1(x)
        # Return float32 results unchanged; anything else is cast to float16.
        return x if x.dtype == torch.float32 else x.to(torch.float16)

    def forward_fp16(self, x):
        # Entry point that bypasses forward(): float16 in, float16 out.
        assert x.dtype == torch.float16
        return self.l1(x.to(torch.float32)).to(torch.float16)
|
||||
|
||||
|
||||
class MainWithModuleMultipleCalls(nn.Module):
    """Model that calls the same sub-module ``g`` twice in one forward pass.

    The first call passes ``x`` as given; the second passes it cast to
    float16, so the two calls see inputs of different dtypes (mixed precision).
    """

    def __init__(self):
        super().__init__()
        self.b = B()
        self.g = G()

    def forward(self, x):
        # First call: input dtype as supplied by the caller.
        x = self.g(x)
        # Second call on the same instance: float16 input, result cast back to float32.
        x = self.g(x.to(torch.float16)).to(torch.float32)
        return self.b(x)
|
||||
|
||||
|
||||
class MainWithNonForwardCall(nn.Module):
    """Model that runs its sub-module ``g`` through a method other than ``forward()``.

    ``g.forward_fp16`` is invoked directly, so ``g`` itself is never called
    via ``g(...)`` during this model's forward pass.
    """

    def __init__(self):
        super().__init__()
        self.b = B()
        self.g = G()

    def forward(self, x):
        # Direct method call on g (not g(...)): float16 in, result cast back to float32.
        x = self.g.forward_fp16(x.to(torch.float16)).to(torch.float32)
        return self.b(x)
|
||||
|
||||
|
||||
def test_hierarchical_ortmodule():
|
||||
def count_ortmodule(module):
|
||||
n = 1 if isinstance(module, ORTModule) else 0
|
||||
def count_ortmodule(module, is_iterated=False):
|
||||
n = 1 if type(module).__name__ == ("_IteratedORTModule" if is_iterated else "ORTModule") else 0
|
||||
for sub in module._modules.values():
|
||||
n = n + count_ortmodule(sub)
|
||||
n = n + count_ortmodule(sub, is_iterated)
|
||||
return n
|
||||
|
||||
def call_backward(y):
|
||||
|
|
@ -162,7 +205,7 @@ def test_hierarchical_ortmodule():
|
|||
else:
|
||||
torch.allclose(y, y_ref)
|
||||
|
||||
def trial(module_to_wrap, args, expected_num_ortmodule):
|
||||
def trial(module_to_wrap, args, expected_num_ortmodule, expected_num_iterated_ortmodule=0):
|
||||
# Run baseline model.
|
||||
m = module_to_wrap
|
||||
|
||||
|
|
@ -185,6 +228,7 @@ def test_hierarchical_ortmodule():
|
|||
|
||||
# Some sub-modules become ORTModule.
|
||||
assert expected_num_ortmodule == count_ortmodule(m)
|
||||
assert expected_num_iterated_ortmodule == count_ortmodule(m, is_iterated=True)
|
||||
|
||||
call_allclose(y, y_ref)
|
||||
call_allclose(g, g_ref)
|
||||
|
|
@ -196,6 +240,8 @@ def test_hierarchical_ortmodule():
|
|||
trial(MainWithMultiModuleOutputs(), [torch.rand(2).requires_grad_()], 10)
|
||||
trial(MainWithNonTensorInput(), [torch.rand(2).requires_grad_(), "reverse"], 6)
|
||||
trial(MainWithNonTensorInput(), [torch.rand(2).requires_grad_(), "normal"], 6)
|
||||
trial(MainWithModuleMultipleCalls(), [torch.rand(2).requires_grad_()], 2, 1)
|
||||
trial(MainWithNonForwardCall(), [torch.rand(2).requires_grad_()], 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -777,7 +777,7 @@ def test_gradient_correctness_conv1d(use_fp16, input_requires_grad, conv_algo_se
|
|||
return
|
||||
|
||||
if conv_algo_search is not None:
|
||||
os.environ["CONV_ALGO_SEARCH"] = conv_algo_search
|
||||
os.environ["ORTMODULE_CONV_ALGO_SEARCH"] = conv_algo_search
|
||||
|
||||
device = "cuda"
|
||||
N, seq_len, C_in, C_out, kernel_size = 32, 128, 1536, 1536, 3
|
||||
|
|
@ -820,7 +820,7 @@ def test_gradient_correctness_conv1d(use_fp16, input_requires_grad, conv_algo_se
|
|||
assert actual_conv_algo_search == expected_conv_algo_search
|
||||
|
||||
if conv_algo_search is not None:
|
||||
del os.environ["CONV_ALGO_SEARCH"]
|
||||
del os.environ["ORTMODULE_CONV_ALGO_SEARCH"]
|
||||
|
||||
|
||||
def _run_gradient_correctness_transpose(perm, shape):
|
||||
|
|
|
|||
|
|
@ -1104,3 +1104,107 @@ def test_non_differentiable_autograd_function():
|
|||
assert torch.allclose(y_ref, y_train)
|
||||
|
||||
run()
|
||||
|
||||
|
||||
def test_checkpoint_function():
    """Check that a module using torch.utils.checkpoint runs under ORTModule.

    The model is evaluated once as a plain nn.Module (baseline) and then, with
    ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT set to "1", wrapped in ORTModule and
    evaluated twice; both ORTModule outputs must match the baseline.
    Requires a CUDA device.
    """

    class A(torch.nn.Module):
        # A supported module.
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(2, 2)

        def forward(self, x):
            return self.l1(x)

    class B(torch.nn.Module):
        # Runs sub-module `a` through torch.utils.checkpoint.checkpoint and then applies `l1`.
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(2, 2)
            self.a = A()

        def forward(self, x):
            def custom():
                def custom_forward(x_):
                    return self.a(x_)

                return custom_forward

            z = self.l1(torch.utils.checkpoint.checkpoint(custom(), x))
            return z

    def run():
        m = B().to("cuda")
        x = torch.rand((2, 2), dtype=torch.float).to("cuda")

        # Baseline.
        y_ref = m(x)
        print("Ref:")
        print(y_ref)

        os.environ["ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT"] = "1"
        try:
            m = ORTModule(m)

            # Inference mode.
            # NOTE(review): m.eval() is never called before this forward pass — confirm intent.
            y_infer = m(x)
            print(y_infer)
            assert torch.allclose(y_ref, y_infer)

            # Training mode.
            m.train()
            y_train = m(x)
            print("Train:")
            assert torch.allclose(y_ref, y_train)
        finally:
            # Always restore the environment so a failing assertion cannot affect later tests.
            os.environ.pop("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", None)

    run()
|
||||
|
||||
|
||||
def test_skipped_autograd_function():
    """Check that an autograd function listed in ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS is not usable.

    The env var is set to the fully qualified name of ``TestSkippedFunction``; running the model
    under ORTModule is then expected to raise a RuntimeError mentioning
    "No forward registered for TestSkippedFunction". Requires a CUDA device.
    """

    class TestSkippedFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

        @staticmethod
        def backward(ctx, grad_output):
            # No gradient is computed; this test only runs a forward call.
            return None

    class TestSkippedModel(torch.nn.Module):
        def __init__(self, output_size):
            super().__init__()
            self.custom_fn = TestSkippedFunction.apply
            self.bias = Parameter(torch.empty(output_size, device=torch.cuda.current_device(), dtype=torch.float))

            with torch.no_grad():
                self.bias.uniform_()

        def forward(self, model_input):
            # model_input did not require_grad
            out = self.custom_fn(model_input)
            return out + self.bias

    output_size = 1024

    # The value must match the class's "<module>.<qualname>" exactly, including "<locals>".
    os.environ[
        "ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS"
    ] = "orttraining_test_ortmodule_autograd.test_skipped_autograd_function.<locals>.TestSkippedFunction"
    try:
        m = ORTModule(TestSkippedModel(output_size).to("cuda"))
        can_run = True
        try:
            m(torch.randn(output_size, dtype=torch.float, device="cuda"))
        except RuntimeError as e:
            assert "No forward registered for TestSkippedFunction" in str(e)
            can_run = False

        assert not can_run
    finally:
        # Always restore the environment so a failing assertion cannot affect later tests.
        os.environ.pop("ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS", None)
|
||||
|
|
|
|||
Loading…
Reference in a new issue