Auto forward non method attribute lookups to the user's model and bind custom methods to ORTModule (#8798)

This commit is contained in:
baijumeswani 2021-09-03 08:25:44 -07:00 committed by GitHub
parent c343f7cb43
commit 0cc2909573
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 375 additions and 8 deletions

View file

@ -22,6 +22,7 @@ from onnxruntime.capi import _pybind_state as C
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
from abc import ABC, abstractmethod
import copy
from functools import reduce
import io
import inspect
import os
@ -94,7 +95,10 @@ class GraphExecutionManager(GraphExecutionInterface):
self._execution_agent = None
# indicators of some logic have been executed previously thus could be skipped for faster training
self._skip_check = _SkipCheck.SKIP_CHECK_DISABLED
self._skip_check = reduce(lambda x, y: x | y,
[_SkipCheck[name] for name in
_utils.parse_os_env_skip_check_flags('ORTMODULE_SKIPCHECK_POLICY',
_SkipCheck.SKIP_CHECK_DISABLED.name)])
self._first_skip_check_warning = True
# Graph transformer config
@ -158,6 +162,10 @@ class GraphExecutionManager(GraphExecutionInterface):
# Memory aware gradient builder.
self._use_memory_efficient_gradient = False
# Flag to re-export the model due to attribute change on original module.
# Re-export will be avoided if _skip_check is enabled.
self._original_model_has_changed = False
def _get_torch_gpu_allocator_function_addresses(self):
if self._use_external_gpu_allocator and torch.cuda.is_available():
# CPP extension to get torch GPU allocator's alloc and free function addresses
@ -278,7 +286,7 @@ class GraphExecutionManager(GraphExecutionInterface):
schema = _io._extract_schema(
{'args': copy.copy(inputs), 'kwargs': copy.copy(kwargs)})
if self._onnx_models.exported_model and schema == self._input_info.schema:
if self._onnx_models.exported_model and schema == self._input_info.schema and not self._original_model_has_changed:
# All required models have already been exported previously
return False
@ -421,3 +429,7 @@ class GraphExecutionManager(GraphExecutionInterface):
# between forward calls.
self._graph_initializers = [param for name, param in self._flattened_module.named_parameters()
if name in self._graph_initializer_names]
def signal_model_changed(self):
"""Signals the execution manager to re-export the model on the next forward call"""
self._original_model_has_changed = True

View file

@ -92,4 +92,4 @@ class TorchModulePytorch(TorchModuleInterface):
@TorchModuleInterface.module.getter
def module(self):
return self._original_module.module
return self._original_module

View file

@ -7,8 +7,14 @@ from onnxruntime.capi.onnxruntime_inference_collection import OrtValue
from onnxruntime.capi import _pybind_state as C
from ._fallback import _FallbackManager, ORTModuleFallbackException, ORTModuleDeviceException, wrap_exception
import os
import copy
import inspect
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
from typing import List
import types
import warnings
def _ortvalue_to_torch_tensor(ortvalue):
@ -112,3 +118,63 @@ def _create_iobinding(io_binding, inputs, model, device):
for value_info in model.graph.output:
io_binding.bind_output(value_info.name, device.type, device_id=get_device_index(device))
def check_for_name_collisions_and_bind_methods_to_ortmodule(ortmodule: torch.nn.Module,
user_module: torch.nn.Module):
"""Warns if there are any common attributes between the user's model and ORTModule and binds user methods to ORTModule
If there are methods defined on the user's model that ORTModule does not recognize (custom methods),
then this function binds these methods to ORTModule.
Args:
ortmodule: the ORTModule instance
user_module: the user's torch.nn.Module
Raises:
UserWarning: If there are any overlapping attributes between the ortmodule and user_module (except forward)
"""
ortmodule_attributes = dict(inspect.getmembers(ortmodule))
torch_module_attributes = dict(inspect.getmembers(torch.nn.Module()))
user_module_attributes = inspect.getmembers(user_module)
# Check if any user defined attribute collides with ORTModule's attributes
for attribute_name, attribute in user_module_attributes:
if inspect.ismethod(attribute):
# Skip the dunder methods
if attribute_name.startswith('__'):
continue
# if the attribute is not a torch attribute, or if the torch attribute
# corresponding to attribute_name is not a method or the user attribute
# does not equal the torch attribute, then this is a user defined method.
if attribute_name not in torch_module_attributes or \
not inspect.ismethod(torch_module_attributes[attribute_name]) or \
attribute.__func__ != torch_module_attributes[attribute_name].__func__:
# forward is expected to be defined by the user.
if attribute_name == 'forward':
continue
# This is a user defined/overriden method. Check for collisions.
if attribute_name in ortmodule_attributes:
# This is a user defined method, issue a warning.
warnings.warn(f"User Module's attribute name {attribute_name} collides with ORTModule's attribute name. "
"User Module's method may not be called upon invocation through ORTModule.")
else:
# This is a custom method, copy it and bind the copy to ORTModule.
# This is needed for cases where the user's custom method invokes
# the forward method. It should go through ORTModule's forward implementation
# and not go through the user defined forward method.
ortmodule.__dict__[attribute_name] = types.MethodType(copy.deepcopy(attribute.__func__), ortmodule)
else:
if attribute_name not in torch_module_attributes and attribute_name in ortmodule_attributes:
# This is a user defined attribute that collides with ORTModule
if attribute_name in ortmodule_attributes:
warnings.warn(f"User Module's attribute name {attribute_name} collides with ORTModule's attribute name. "
"User Module's attribute may not be returned when trying to retrieve the attribute through ORTModule.")
def parse_os_env_skip_check_flags(env_name, default_skip_check_str):
"""Returns a list of SkipChecks as defined by os env variable env_name or default provided"""
return os.getenv(env_name, default_skip_check_str).split('|')

View file

@ -5,8 +5,10 @@
from ._torch_module_factory import TorchModuleFactory
from ._torch_module_pytorch import TorchModulePytorch
from ._torch_module_ort import TorchModuleORT
from ._custom_op_symbolic_registry import CustomOpSymbolicRegistry
from ._custom_gradient_registry import CustomGradientRegistry
from . import _utils
from .debug_options import DebugOptions
from ._fallback import _FallbackManager, _FallbackPolicy, ORTModuleFallbackException, ORTModuleTorchModelException, wrap_exception
from . import _FALLBACK_INIT_EXCEPTION, MINIMUM_RUNTIME_PYTORCH_VERSION_STR, ORTMODULE_FALLBACK_POLICY, ORTMODULE_FALLBACK_RETRY
@ -15,6 +17,7 @@ from onnxruntime.tools import pytorch_export_contrib_ops
import functools
import torch
from typing import Iterator, Optional, Tuple, TypeVar, Set, Callable
import warnings
# Needed to override PyTorch methods
T = TypeVar('T', bound='Module')
@ -32,6 +35,18 @@ class ORTModule(torch.nn.Module):
"""
def __init__(self, module, debug_options=None):
# NOTE: torch.nn.Modules that call setattr on their internal attributes regularly
# (for example PyTorch Lightning), will trigger regular re-exports. This is
# because ORTModule auto detects such setattrs on the original module and
# marks the model as stale. This is a known limitation. To disable repeated
# re-export checks when not required, please set the environment variable
# ORTMODULE_SKIPCHECK_POLICY to SKIP_CHECK_BUILD_GRADIENT|SKIP_CHECK_EXECUTION_AGENT
# Set _is_initialized attribute first which starts off as False.
# This variable will be used for comparing strings in __setattr__ and __getattr__
# NOTE: Do not rename/move.
self._is_initialized = False
# Python default arguments are evaluated on function definition
# and not on function invocation. So, if debug_options is not provided,
# instantiate it inside the function.
@ -41,12 +56,15 @@ class ORTModule(torch.nn.Module):
# Fallback settings
self._fallback_manager = _FallbackManager(policy=ORTMODULE_FALLBACK_POLICY,
retry=ORTMODULE_FALLBACK_RETRY)
try:
# Read ORTModule module initialization status
global _FALLBACK_INIT_EXCEPTION
if _FALLBACK_INIT_EXCEPTION:
raise _FALLBACK_INIT_EXCEPTION
super(ORTModule, self).__init__()
self._torch_module = TorchModuleFactory()(module, debug_options, self._fallback_manager)
# Create forward dynamically, so each ORTModule instance will have its own copy.
@ -68,16 +86,19 @@ class ORTModule(torch.nn.Module):
functools.update_wrapper(
self.forward.__func__, self._torch_module.forward.__func__)
super(ORTModule, self).__init__()
# Support contrib OPs
pytorch_export_contrib_ops.register()
CustomOpSymbolicRegistry.register_all()
CustomGradientRegistry.register_all()
# Warn user if there are name collisions between user model's and ORTModule attributes
# And if there are custom methods defined on the user's model, copy and bind them to
# ORTModule.
_utils.check_for_name_collisions_and_bind_methods_to_ortmodule(self, module)
except ORTModuleFallbackException as e:
self._torch_module = TorchModulePytorch(module)
# TODO: Rework after "custom methods" task is designed
# TODO: Rework by implementing the "__getattribute__" method.
# Assigning all default attributes from user's original torch.nn.Module into ORTModule
self._backward_hooks = module._backward_hooks
self._forward_hooks = module._forward_hooks
@ -92,7 +113,6 @@ class ORTModule(torch.nn.Module):
self.forward = module.forward
# Exceptions subject to fallback are handled here
# import pdb; pdb.set_trace()
self._fallback_manager.handle_exception(exception=e,
log_level=debug_options.logging.log_level)
except Exception as e:
@ -102,6 +122,11 @@ class ORTModule(torch.nn.Module):
log_level=debug_options.logging.log_level,
override_policy=_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD)
# Finally, ORTModule initialization is complete.
# Assign self._is_initialized to True after all the ORTModule class attributes have been assigned
# else, they will be assigned to self._torch_module.original_module instead.
self._is_initialized = True
# IMPORTANT: DO NOT add code here
# This declaration is for automatic document generation purposes only
# The actual forward implementation is bound during ORTModule initialization
@ -269,3 +294,36 @@ class ORTModule(torch.nn.Module):
"""Override :meth:`~torch.nn.Module.named_modules`"""
yield from self._torch_module.named_modules(*args, **kwargs)
def __getattr__(self, name: str):
if '_is_initialized' in self.__dict__ and self.__dict__['_is_initialized'] == True:
# If ORTModule is intitialized and attribute is not found in ORTModule,
# it must be present in the user's torch.nn.Module. Forward the call to
# the user's model.
assert '_torch_module' in self.__dict__, "ORTModule does not have a reference to the user's model"
return getattr(self.module, name)
else:
return super(ORTModule, self).__getattr__(name)
def __setattr__(self, name: str, value) -> None:
if name in self.__dict__:
# If the name is an attribute of ORTModule, update only ORTModule
self.__dict__[name] = value
elif '_is_initialized' in self.__dict__ and self.__dict__['_is_initialized'] == True:
assert '_torch_module' in self.__dict__, "ORTModule does not have a reference to the user's model"
# If the name is an attribute of user model, or is a new attribute, update there.
# Set the attribute on the user's original module
setattr(self.module, name, value)
# Signal to execution manager to re-export the model.
# Re-export will be avoided if _skip_check is enabled.
if isinstance(self._torch_module, TorchModuleORT):
for training_mode in [False, True]:
self._torch_module._execution_manager(training_mode).signal_model_changed()
else:
# Setting any new attributes should be done on ORTModule only when 'torch_module' is not defined
self.__dict__[name] = value

View file

@ -2,6 +2,8 @@
# Licensed under the MIT License.
# orttraining_test_ortmodule_api.py
from functools import reduce
import itertools
import math
import random
import copy
@ -18,7 +20,7 @@ from inspect import signature
import tempfile
import os
from onnxruntime.training.ortmodule import ORTModule, _utils, _io, DebugOptions, LogLevel, _fallback
from onnxruntime.training.ortmodule import ORTModule, _utils, _io, DebugOptions, LogLevel, _fallback, _graph_execution_manager
import _test_helpers
# Import autocasting libs
@ -3469,3 +3471,232 @@ def test_ortmodule_list_dict_input_with_kwargs_and_registered_buffer():
kwargs_input_copy = copy.deepcopy(kwargs_input)
_test_helpers.assert_values_are_close(pt_model(x, **kwargs_input), ort_model(x_copy, **kwargs_input_copy))
def test_ortmodule_user_defined_method():
class UserDefinedMethodsNet(torch.nn.Module):
def __init__(self):
super(UserDefinedMethodsNet, self).__init__()
self.dummy = torch.nn.Parameter(torch.FloatTensor([12]))
def forward(self, a):
return self.dummy + a
def custom_method_returns_input(self, user_input):
return user_input
device = 'cuda'
N, D_in, H, D_out = 64, 784, 500, 10
pt_model = UserDefinedMethodsNet().to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
x = torch.randn(N, D_in, device=device)
y = copy.deepcopy(x)
out = ort_model.custom_method_returns_input(x)
assert x is out
pt_out = pt_model(x)
ort_out = ort_model(y)
_test_helpers.assert_values_are_close(pt_out, ort_out)
def test_ortmodule_user_getattr_gets_successfully():
class UserDefinedMethodsNet(torch.nn.Module):
def __init__(self):
super(UserDefinedMethodsNet, self).__init__()
self.dummy = torch.nn.Parameter(torch.FloatTensor([12]))
def forward(self, a):
return self.dummy + a
def custom_method_returns_input(self, user_input):
return user_input
device = 'cuda'
N, D_in, H, D_out = 64, 784, 500, 10
pt_model = UserDefinedMethodsNet().to(device)
ort_model = ORTModule(pt_model)
assert ort_model.custom_method_returns_input != pt_model.custom_method_returns_input
assert ort_model.custom_method_returns_input.__func__ == pt_model.custom_method_returns_input.__func__
assert ort_model.dummy is pt_model.dummy
@pytest.mark.parametrize("attribute", ['True', 'lambda x : x'])
def test_ortmodule_setattr_new_attribute(attribute):
class UserNet(torch.nn.Module):
def __init__(self):
super(UserNet, self).__init__()
self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))
def forward(self, a):
return self.dummy + a
device = 'cuda'
pt_model = UserNet().to(device)
ort_model = ORTModule(pt_model)
ort_model.a_new_attribute = attribute
assert hasattr(pt_model, 'a_new_attribute')
assert pt_model.a_new_attribute == attribute
assert 'a_new_attribute' not in ort_model.__dict__
def test_ortmodule_setattr_on_ortmodule_copied_user_model_attribute():
class UserNet(torch.nn.Module):
def __init__(self):
super(UserNet, self).__init__()
self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))
def forward(self, a):
return self.dummy + a
def custom_method(self, a):
return a
def my_new_custom_method(self, a, b, c):
return a + b + c
device = 'cuda'
pt_model = UserNet().to(device)
ort_model = ORTModule(pt_model)
# custom_method is copied by ORTModule from the users model
# and bound to itself
ort_model.custom_method = my_new_custom_method
# dummy is defined on pt model
ort_model.dummy = torch.nn.Parameter(torch.FloatTensor([12]))
assert hasattr(pt_model, 'dummy')
assert torch.eq(pt_model.dummy, torch.nn.Parameter(torch.FloatTensor([12])))
assert 'dummy' not in ort_model.__dict__
assert hasattr(pt_model, 'custom_method')
assert pt_model.custom_method is not my_new_custom_method
assert ort_model.custom_method is my_new_custom_method
def test_ortmodule_setattr_ortmodule_attribute():
class UserNet(torch.nn.Module):
def __init__(self):
super(UserNet, self).__init__()
self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))
def forward(self, a):
return self.dummy + a
device = 'cuda'
pt_model = UserNet().to(device)
ort_model = ORTModule(pt_model)
ort_model._torch_module = True
assert not hasattr(pt_model, '_torch_module')
assert '_torch_module' in ort_model.__dict__
assert ort_model._torch_module == True
def test_ortmodule_setattr_signals_model_changed():
class UserNet(torch.nn.Module):
def __init__(self, input_flag):
super(UserNet, self).__init__()
self.dummy = torch.nn.Parameter(torch.FloatTensor([10]))
self.input_flag = input_flag
def forward(self, a):
if self.input_flag:
return self.dummy + a
else:
return a
device = 'cuda'
N, D_in, H, D_out = 64, 784, 500, 10
pt_model = UserNet(True).to(device)
ort_model = ORTModule(pt_model)
_ = ort_model(torch.randn(N, D_in, device=device))
exported_model1 = ort_model._torch_module._execution_manager(True)._onnx_models.exported_model
for training_mode in [False, True]:
assert ort_model._torch_module._execution_manager(training_mode)._original_model_has_changed == False
ort_model.input_flag = False
for training_mode in [False, True]:
assert ort_model._torch_module._execution_manager(training_mode)._original_model_has_changed == True
_ = ort_model(torch.randn(N, D_in, device=device))
exported_model2 = ort_model._torch_module._execution_manager(True)._onnx_models.exported_model
assert exported_model1 != exported_model2
def test_ortmodule_attribute_name_collision_warning():
class UserNet(torch.nn.Module):
def __init__(self):
super(UserNet, self).__init__()
self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))
self._torch_module = True
def forward(self, a):
return self.dummy + a
def load_state_dict(self):
pass
device = 'cuda'
pt_model = UserNet().to(device)
with pytest.warns(UserWarning) as warning_record:
ort_model = ORTModule(pt_model)
assert len(warning_record) == 2
assert "_torch_module collides with ORTModule's attribute name." in warning_record[0].message.args[0]
assert "load_state_dict collides with ORTModule's attribute name." in warning_record[1].message.args[0]
def test_ortmodule_ortmodule_method_attribute_copy():
class UserNetWithSelfCallingForward(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(UserNetWithSelfCallingForward, self).__init__()
self.fc1 = torch.nn.Linear(input_size, hidden_size)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(hidden_size, num_classes)
def forward(self, input1):
out = self.fc1(input1)
out = self.relu(out)
out = self.fc2(out)
return out
def run_forward(self, *args, **kwargs):
return self(*args, **kwargs)
device = 'cuda'
N, D_in, H, D_out = 64, 784, 500, 10
pt_model = UserNetWithSelfCallingForward(D_in, H, D_out).to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
x_1 = torch.randn(N, D_in, device=device)
x_2 = copy.deepcopy(x_1)
x_3 = copy.deepcopy(x_1)
# Executed on ORTModule
out1 = ort_model(x_1)
# Executed on ORTModule even though run_forward is not defined on ORTModule
out2 = ort_model.run_forward(x_2)
# Executed on pytorch module since it is directly invoked from there
out3 = pt_model.run_forward(x_3)
assert torch.equal(out1, out2)
_test_helpers.assert_values_are_close(out2, out3)
assert type(out1.grad_fn).__name__ == '_ORTModuleFunctionBackward'
assert type(out2.grad_fn).__name__ == '_ORTModuleFunctionBackward'
assert type(out3.grad_fn).__name__ == 'AddmmBackward'
@pytest.mark.parametrize("policy_str, policy",[
('SKIP_CHECK_DISABLED', _graph_execution_manager._SkipCheck.SKIP_CHECK_DISABLED),
('SKIP_CHECK_DEVICE', _graph_execution_manager._SkipCheck.SKIP_CHECK_DEVICE),
('SKIP_CHECK_BUILD_GRADIENT', _graph_execution_manager._SkipCheck.SKIP_CHECK_BUILD_GRADIENT),
('SKIP_CHECK_EXECUTION_AGENT', _graph_execution_manager._SkipCheck.SKIP_CHECK_EXECUTION_AGENT),
])
def test_ortmodule_skip_check_load_from_os_env(policy_str, policy):
device = 'cuda'
N, D_in, H, D_out = 64, 784, 500, 10
os.environ['ORTMODULE_SKIPCHECK_POLICY'] = policy_str
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
ort_model = ORTModule(model)
for training_mode in [False, True]:
assert ort_model._torch_module._execution_manager(training_mode)._skip_check == policy
del os.environ['ORTMODULE_SKIPCHECK_POLICY']