From 0cc29095733cecc55efe0b8e0d8ff7cd2a9e427a Mon Sep 17 00:00:00 2001 From: baijumeswani Date: Fri, 3 Sep 2021 08:25:44 -0700 Subject: [PATCH] Auto forward non method attribute lookups to the user's model and bind custom methods to ORTModule (#8798) --- .../ortmodule/_graph_execution_manager.py | 16 +- .../ortmodule/_torch_module_pytorch.py | 2 +- .../python/training/ortmodule/_utils.py | 66 +++++ .../python/training/ortmodule/ortmodule.py | 66 ++++- .../python/orttraining_test_ortmodule_api.py | 233 +++++++++++++++++- 5 files changed, 375 insertions(+), 8 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 7ddd033cb3..8fe54e1784 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -22,6 +22,7 @@ from onnxruntime.capi import _pybind_state as C from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference from abc import ABC, abstractmethod import copy +from functools import reduce import io import inspect import os @@ -94,7 +95,10 @@ class GraphExecutionManager(GraphExecutionInterface): self._execution_agent = None # indicators of some logic have been executed previously thus could be skipped for faster training - self._skip_check = _SkipCheck.SKIP_CHECK_DISABLED + self._skip_check = reduce(lambda x, y: x | y, + [_SkipCheck[name] for name in + _utils.parse_os_env_skip_check_flags('ORTMODULE_SKIPCHECK_POLICY', + _SkipCheck.SKIP_CHECK_DISABLED.name)]) self._first_skip_check_warning = True # Graph transformer config @@ -158,6 +162,10 @@ class GraphExecutionManager(GraphExecutionInterface): # Memory aware gradient builder. self._use_memory_efficient_gradient = False + # Flag to re-export the model due to attribute change on original module. + # Re-export will be avoided if _skip_check is enabled. 
+ self._original_model_has_changed = False + def _get_torch_gpu_allocator_function_addresses(self): if self._use_external_gpu_allocator and torch.cuda.is_available(): # CPP extension to get torch GPU allocator's alloc and free function addresses @@ -278,7 +286,7 @@ class GraphExecutionManager(GraphExecutionInterface): schema = _io._extract_schema( {'args': copy.copy(inputs), 'kwargs': copy.copy(kwargs)}) - if self._onnx_models.exported_model and schema == self._input_info.schema: + if self._onnx_models.exported_model and schema == self._input_info.schema and not self._original_model_has_changed: # All required models have already been exported previously return False @@ -421,3 +429,7 @@ class GraphExecutionManager(GraphExecutionInterface): # between forward calls. self._graph_initializers = [param for name, param in self._flattened_module.named_parameters() if name in self._graph_initializer_names] + + def signal_model_changed(self): + """Signals the execution manager to re-export the model on the next forward call""" + self._original_model_has_changed = True diff --git a/orttraining/orttraining/python/training/ortmodule/_torch_module_pytorch.py b/orttraining/orttraining/python/training/ortmodule/_torch_module_pytorch.py index 0bc79efb55..c1f6cc7192 100644 --- a/orttraining/orttraining/python/training/ortmodule/_torch_module_pytorch.py +++ b/orttraining/orttraining/python/training/ortmodule/_torch_module_pytorch.py @@ -92,4 +92,4 @@ class TorchModulePytorch(TorchModuleInterface): @TorchModuleInterface.module.getter def module(self): - return self._original_module.module + return self._original_module diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index e0cbbaf28f..c1a146b2ff 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -7,8 +7,14 @@ from 
onnxruntime.capi.onnxruntime_inference_collection import OrtValue from onnxruntime.capi import _pybind_state as C from ._fallback import _FallbackManager, ORTModuleFallbackException, ORTModuleDeviceException, wrap_exception +import os +import copy +import inspect import torch from torch.utils.dlpack import from_dlpack, to_dlpack +from typing import List +import types +import warnings def _ortvalue_to_torch_tensor(ortvalue): @@ -112,3 +118,63 @@ def _create_iobinding(io_binding, inputs, model, device): for value_info in model.graph.output: io_binding.bind_output(value_info.name, device.type, device_id=get_device_index(device)) + +def check_for_name_collisions_and_bind_methods_to_ortmodule(ortmodule: torch.nn.Module, + user_module: torch.nn.Module): + """Warns if there are any common attributes between the user's model and ORTModule and binds user methods to ORTModule + + If there are methods defined on the user's model that ORTModule does not recognize (custom methods), + then this function binds these methods to ORTModule. + + Args: + ortmodule: the ORTModule instance + user_module: the user's torch.nn.Module + + Raises: + UserWarning: If there are any overlapping attributes between the ortmodule and user_module (except forward) + """ + + ortmodule_attributes = dict(inspect.getmembers(ortmodule)) + torch_module_attributes = dict(inspect.getmembers(torch.nn.Module())) + user_module_attributes = inspect.getmembers(user_module) + + # Check if any user defined attribute collides with ORTModule's attributes + for attribute_name, attribute in user_module_attributes: + if inspect.ismethod(attribute): + # Skip the dunder methods + if attribute_name.startswith('__'): + continue + + # if the attribute is not a torch attribute, or if the torch attribute + # corresponding to attribute_name is not a method or the user attribute + # does not equal the torch attribute, then this is a user defined method. 
+ if attribute_name not in torch_module_attributes or \ + not inspect.ismethod(torch_module_attributes[attribute_name]) or \ + attribute.__func__ != torch_module_attributes[attribute_name].__func__: + + # forward is expected to be defined by the user. + if attribute_name == 'forward': + continue + + # This is a user defined/overridden method. Check for collisions. + if attribute_name in ortmodule_attributes: + # This is a user defined method, issue a warning. + warnings.warn(f"User Module's attribute name {attribute_name} collides with ORTModule's attribute name. " + "User Module's method may not be called upon invocation through ORTModule.") + else: + # This is a custom method, copy it and bind the copy to ORTModule. + # This is needed for cases where the user's custom method invokes + # the forward method. It should go through ORTModule's forward implementation + # and not go through the user defined forward method. + ortmodule.__dict__[attribute_name] = types.MethodType(copy.deepcopy(attribute.__func__), ortmodule) + else: + if attribute_name not in torch_module_attributes and attribute_name in ortmodule_attributes: + # This is a user defined attribute that collides with ORTModule + if attribute_name in ortmodule_attributes: + warnings.warn(f"User Module's attribute name {attribute_name} collides with ORTModule's attribute name. 
" + "User Module's attribute may not be returned when trying to retrieve the attribute through ORTModule.") + +def parse_os_env_skip_check_flags(env_name, default_skip_check_str): + """Returns a list of SkipChecks as defined by os env variable env_name or default provided""" + + return os.getenv(env_name, default_skip_check_str).split('|') diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index 964582c107..e2cd8fdec5 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -5,8 +5,10 @@ from ._torch_module_factory import TorchModuleFactory from ._torch_module_pytorch import TorchModulePytorch +from ._torch_module_ort import TorchModuleORT from ._custom_op_symbolic_registry import CustomOpSymbolicRegistry from ._custom_gradient_registry import CustomGradientRegistry +from . import _utils from .debug_options import DebugOptions from ._fallback import _FallbackManager, _FallbackPolicy, ORTModuleFallbackException, ORTModuleTorchModelException, wrap_exception from . import _FALLBACK_INIT_EXCEPTION, MINIMUM_RUNTIME_PYTORCH_VERSION_STR, ORTMODULE_FALLBACK_POLICY, ORTMODULE_FALLBACK_RETRY @@ -15,6 +17,7 @@ from onnxruntime.tools import pytorch_export_contrib_ops import functools import torch from typing import Iterator, Optional, Tuple, TypeVar, Set, Callable +import warnings # Needed to override PyTorch methods T = TypeVar('T', bound='Module') @@ -32,6 +35,18 @@ class ORTModule(torch.nn.Module): """ def __init__(self, module, debug_options=None): + + # NOTE: torch.nn.Modules that call setattr on their internal attributes regularly + # (for example PyTorch Lightning), will trigger regular re-exports. This is + # because ORTModule auto detects such setattrs on the original module and + # marks the model as stale. This is a known limitation. 
To disable repeated + # re-export checks when not required, please set the environment variable + # ORTMODULE_SKIPCHECK_POLICY to SKIP_CHECK_BUILD_GRADIENT|SKIP_CHECK_EXECUTION_AGENT + + # Set _is_initialized attribute first which starts off as False. + # This variable will be used for comparing strings in __setattr__ and __getattr__ + # NOTE: Do not rename/move. + self._is_initialized = False # Python default arguments are evaluated on function definition # and not on function invocation. So, if debug_options is not provided, # instantiate it inside the function. @@ -41,12 +56,15 @@ class ORTModule(torch.nn.Module): # Fallback settings self._fallback_manager = _FallbackManager(policy=ORTMODULE_FALLBACK_POLICY, retry=ORTMODULE_FALLBACK_RETRY) + try: # Read ORTModule module initialization status global _FALLBACK_INIT_EXCEPTION if _FALLBACK_INIT_EXCEPTION: raise _FALLBACK_INIT_EXCEPTION + super(ORTModule, self).__init__() + self._torch_module = TorchModuleFactory()(module, debug_options, self._fallback_manager) # Create forward dynamically, so each ORTModule instance will have its own copy. @@ -68,16 +86,19 @@ class ORTModule(torch.nn.Module): functools.update_wrapper( self.forward.__func__, self._torch_module.forward.__func__) - super(ORTModule, self).__init__() - # Support contrib OPs pytorch_export_contrib_ops.register() CustomOpSymbolicRegistry.register_all() CustomGradientRegistry.register_all() + # Warn user if there are name collisions between user model's and ORTModule attributes + # And if there are custom methods defined on the user's model, copy and bind them to + # ORTModule. + _utils.check_for_name_collisions_and_bind_methods_to_ortmodule(self, module) + except ORTModuleFallbackException as e: self._torch_module = TorchModulePytorch(module) - # TODO: Rework after "custom methods" task is designed + # TODO: Rework by implementing the "__getattribute__" method. 
# Assigning all default attributes from user's original torch.nn.Module into ORTModule self._backward_hooks = module._backward_hooks self._forward_hooks = module._forward_hooks @@ -92,7 +113,6 @@ class ORTModule(torch.nn.Module): self.forward = module.forward # Exceptions subject to fallback are handled here - # import pdb; pdb.set_trace() self._fallback_manager.handle_exception(exception=e, log_level=debug_options.logging.log_level) except Exception as e: @@ -102,6 +122,11 @@ class ORTModule(torch.nn.Module): log_level=debug_options.logging.log_level, override_policy=_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD) + # Finally, ORTModule initialization is complete. + # Assign self._is_initialized to True after all the ORTModule class attributes have been assigned + # else, they will be assigned to self._torch_module.original_module instead. + self._is_initialized = True + # IMPORTANT: DO NOT add code here # This declaration is for automatic document generation purposes only # The actual forward implementation is bound during ORTModule initialization @@ -269,3 +294,36 @@ class ORTModule(torch.nn.Module): """Override :meth:`~torch.nn.Module.named_modules`""" yield from self._torch_module.named_modules(*args, **kwargs) + + def __getattr__(self, name: str): + if '_is_initialized' in self.__dict__ and self.__dict__['_is_initialized'] == True: + # If ORTModule is initialized and attribute is not found in ORTModule, + # it must be present in the user's torch.nn.Module. Forward the call to + # the user's model. 
+ assert '_torch_module' in self.__dict__, "ORTModule does not have a reference to the user's model" + return getattr(self.module, name) + else: + return super(ORTModule, self).__getattr__(name) + + def __setattr__(self, name: str, value) -> None: + + if name in self.__dict__: + # If the name is an attribute of ORTModule, update only ORTModule + self.__dict__[name] = value + + elif '_is_initialized' in self.__dict__ and self.__dict__['_is_initialized'] == True: + + assert '_torch_module' in self.__dict__, "ORTModule does not have a reference to the user's model" + + # If the name is an attribute of user model, or is a new attribute, update there. + # Set the attribute on the user's original module + setattr(self.module, name, value) + # Signal to execution manager to re-export the model. + # Re-export will be avoided if _skip_check is enabled. + if isinstance(self._torch_module, TorchModuleORT): + for training_mode in [False, True]: + self._torch_module._execution_manager(training_mode).signal_model_changed() + + else: + # Setting any new attributes should be done on ORTModule only when 'torch_module' is not defined + self.__dict__[name] = value diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 7fc2583c4e..b44337479b 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -2,6 +2,8 @@ # Licensed under the MIT License. 
# orttraining_test_ortmodule_api.py +from functools import reduce +import itertools import math import random import copy @@ -18,7 +20,7 @@ from inspect import signature import tempfile import os -from onnxruntime.training.ortmodule import ORTModule, _utils, _io, DebugOptions, LogLevel, _fallback +from onnxruntime.training.ortmodule import ORTModule, _utils, _io, DebugOptions, LogLevel, _fallback, _graph_execution_manager import _test_helpers # Import autocasting libs @@ -3469,3 +3471,232 @@ def test_ortmodule_list_dict_input_with_kwargs_and_registered_buffer(): kwargs_input_copy = copy.deepcopy(kwargs_input) _test_helpers.assert_values_are_close(pt_model(x, **kwargs_input), ort_model(x_copy, **kwargs_input_copy)) + +def test_ortmodule_user_defined_method(): + class UserDefinedMethodsNet(torch.nn.Module): + def __init__(self): + super(UserDefinedMethodsNet, self).__init__() + self.dummy = torch.nn.Parameter(torch.FloatTensor([12])) + + def forward(self, a): + return self.dummy + a + + def custom_method_returns_input(self, user_input): + return user_input + + device = 'cuda' + N, D_in, H, D_out = 64, 784, 500, 10 + pt_model = UserDefinedMethodsNet().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + x = torch.randn(N, D_in, device=device) + y = copy.deepcopy(x) + + out = ort_model.custom_method_returns_input(x) + assert x is out + + pt_out = pt_model(x) + ort_out = ort_model(y) + _test_helpers.assert_values_are_close(pt_out, ort_out) + +def test_ortmodule_user_getattr_gets_successfully(): + class UserDefinedMethodsNet(torch.nn.Module): + def __init__(self): + super(UserDefinedMethodsNet, self).__init__() + self.dummy = torch.nn.Parameter(torch.FloatTensor([12])) + + def forward(self, a): + return self.dummy + a + + def custom_method_returns_input(self, user_input): + return user_input + + device = 'cuda' + N, D_in, H, D_out = 64, 784, 500, 10 + pt_model = UserDefinedMethodsNet().to(device) + ort_model = ORTModule(pt_model) + + assert 
ort_model.custom_method_returns_input != pt_model.custom_method_returns_input + assert ort_model.custom_method_returns_input.__func__ == pt_model.custom_method_returns_input.__func__ + assert ort_model.dummy is pt_model.dummy + +@pytest.mark.parametrize("attribute", ['True', 'lambda x : x']) +def test_ortmodule_setattr_new_attribute(attribute): + class UserNet(torch.nn.Module): + def __init__(self): + super(UserNet, self).__init__() + self.dummy = torch.nn.Parameter(torch.FloatTensor([0])) + + def forward(self, a): + return self.dummy + a + + device = 'cuda' + pt_model = UserNet().to(device) + ort_model = ORTModule(pt_model) + ort_model.a_new_attribute = attribute + + assert hasattr(pt_model, 'a_new_attribute') + assert pt_model.a_new_attribute == attribute + assert 'a_new_attribute' not in ort_model.__dict__ + +def test_ortmodule_setattr_on_ortmodule_copied_user_model_attribute(): + class UserNet(torch.nn.Module): + def __init__(self): + super(UserNet, self).__init__() + self.dummy = torch.nn.Parameter(torch.FloatTensor([0])) + + def forward(self, a): + return self.dummy + a + + def custom_method(self, a): + return a + + def my_new_custom_method(self, a, b, c): + return a + b + c + + device = 'cuda' + pt_model = UserNet().to(device) + ort_model = ORTModule(pt_model) + # custom_method is copied by ORTModule from the users model + # and bound to itself + ort_model.custom_method = my_new_custom_method + # dummy is defined on pt model + ort_model.dummy = torch.nn.Parameter(torch.FloatTensor([12])) + + assert hasattr(pt_model, 'dummy') + assert torch.eq(pt_model.dummy, torch.nn.Parameter(torch.FloatTensor([12]))) + assert 'dummy' not in ort_model.__dict__ + + assert hasattr(pt_model, 'custom_method') + assert pt_model.custom_method is not my_new_custom_method + assert ort_model.custom_method is my_new_custom_method + +def test_ortmodule_setattr_ortmodule_attribute(): + class UserNet(torch.nn.Module): + def __init__(self): + super(UserNet, self).__init__() + self.dummy 
= torch.nn.Parameter(torch.FloatTensor([0])) + + def forward(self, a): + return self.dummy + a + + device = 'cuda' + pt_model = UserNet().to(device) + ort_model = ORTModule(pt_model) + ort_model._torch_module = True + + assert not hasattr(pt_model, '_torch_module') + assert '_torch_module' in ort_model.__dict__ + assert ort_model._torch_module == True + +def test_ortmodule_setattr_signals_model_changed(): + class UserNet(torch.nn.Module): + def __init__(self, input_flag): + super(UserNet, self).__init__() + self.dummy = torch.nn.Parameter(torch.FloatTensor([10])) + self.input_flag = input_flag + + def forward(self, a): + if self.input_flag: + return self.dummy + a + else: + return a + + device = 'cuda' + N, D_in, H, D_out = 64, 784, 500, 10 + pt_model = UserNet(True).to(device) + ort_model = ORTModule(pt_model) + + _ = ort_model(torch.randn(N, D_in, device=device)) + exported_model1 = ort_model._torch_module._execution_manager(True)._onnx_models.exported_model + + for training_mode in [False, True]: + assert ort_model._torch_module._execution_manager(training_mode)._original_model_has_changed == False + ort_model.input_flag = False + + for training_mode in [False, True]: + assert ort_model._torch_module._execution_manager(training_mode)._original_model_has_changed == True + + _ = ort_model(torch.randn(N, D_in, device=device)) + exported_model2 = ort_model._torch_module._execution_manager(True)._onnx_models.exported_model + + assert exported_model1 != exported_model2 + +def test_ortmodule_attribute_name_collision_warning(): + class UserNet(torch.nn.Module): + def __init__(self): + super(UserNet, self).__init__() + self.dummy = torch.nn.Parameter(torch.FloatTensor([0])) + self._torch_module = True + + def forward(self, a): + return self.dummy + a + + def load_state_dict(self): + pass + + device = 'cuda' + pt_model = UserNet().to(device) + with pytest.warns(UserWarning) as warning_record: + ort_model = ORTModule(pt_model) + + assert len(warning_record) == 2 + assert 
"_torch_module collides with ORTModule's attribute name." in warning_record[0].message.args[0] + assert "load_state_dict collides with ORTModule's attribute name." in warning_record[1].message.args[0] + +def test_ortmodule_ortmodule_method_attribute_copy(): + class UserNetWithSelfCallingForward(torch.nn.Module): + def __init__(self, input_size, hidden_size, num_classes): + super(UserNetWithSelfCallingForward, self).__init__() + + self.fc1 = torch.nn.Linear(input_size, hidden_size) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(hidden_size, num_classes) + + def forward(self, input1): + out = self.fc1(input1) + out = self.relu(out) + out = self.fc2(out) + return out + + def run_forward(self, *args, **kwargs): + return self(*args, **kwargs) + + device = 'cuda' + N, D_in, H, D_out = 64, 784, 500, 10 + pt_model = UserNetWithSelfCallingForward(D_in, H, D_out).to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + x_1 = torch.randn(N, D_in, device=device) + x_2 = copy.deepcopy(x_1) + x_3 = copy.deepcopy(x_1) + # Executed on ORTModule + out1 = ort_model(x_1) + # Executed on ORTModule even though run_forward is not defined on ORTModule + out2 = ort_model.run_forward(x_2) + # Executed on pytorch module since it is directly invoked from there + out3 = pt_model.run_forward(x_3) + + assert torch.equal(out1, out2) + _test_helpers.assert_values_are_close(out2, out3) + + assert type(out1.grad_fn).__name__ == '_ORTModuleFunctionBackward' + assert type(out2.grad_fn).__name__ == '_ORTModuleFunctionBackward' + assert type(out3.grad_fn).__name__ == 'AddmmBackward' + +@pytest.mark.parametrize("policy_str, policy",[ + ('SKIP_CHECK_DISABLED', _graph_execution_manager._SkipCheck.SKIP_CHECK_DISABLED), + ('SKIP_CHECK_DEVICE', _graph_execution_manager._SkipCheck.SKIP_CHECK_DEVICE), + ('SKIP_CHECK_BUILD_GRADIENT', _graph_execution_manager._SkipCheck.SKIP_CHECK_BUILD_GRADIENT), + ('SKIP_CHECK_EXECUTION_AGENT', _graph_execution_manager._SkipCheck.SKIP_CHECK_EXECUTION_AGENT), 
+]) +def test_ortmodule_skip_check_load_from_os_env(policy_str, policy): + device = 'cuda' + N, D_in, H, D_out = 64, 784, 500, 10 + os.environ['ORTMODULE_SKIPCHECK_POLICY'] = policy_str + model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + ort_model = ORTModule(model) + + for training_mode in [False, True]: + assert ort_model._torch_module._execution_manager(training_mode)._skip_check == policy + + del os.environ['ORTMODULE_SKIPCHECK_POLICY']