Resolve issue with wrapped ORTModule load_state_dict (#7847)

* Encapsulate children modules inside a ModuleAccessor object to prevent erroneuos iteration over children while loading the state dictionary

* Add named_models, models, apply methods, change ModuleAccessor to ModuleMetadata and modify unit tests

* Change ModuleMetadata module getter logic, raise NotImplementedError for add_modules

* Add comment explaining why overriding _load_from_state_dict method is needed
This commit is contained in:
baijumeswani 2021-05-27 16:11:37 -07:00 committed by GitHub
parent 8140e3fde5
commit ddf4aaaae1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 127 additions and 36 deletions

View file

@ -97,3 +97,10 @@ def _create_iobinding(io_binding, inputs, model, device):
for value_info in model.graph.output:
io_binding.bind_output(value_info.name, device.type, device_id=get_device_index(device))
class _PytorchModuleMetadata():
"""Encapsulates modules and allows easy access as required"""
def __init__(self, original_module, flattened_module):
self.original_module = original_module
self.flattened_module = flattened_module

View file

@ -5,12 +5,13 @@
from . import _io
from ._graph_execution_manager_factory import GraphExecutionManagerFactory
from ._utils import _PytorchModuleMetadata
from onnxruntime.training import register_custom_ops_pytorch_exporter
import functools
import torch
from typing import Iterator, Optional, Tuple, TypeVar
from typing import Iterator, Optional, Tuple, TypeVar, Set, Callable
# Needed to override PyTorch methods
T = TypeVar('T', bound='Module')
@ -51,12 +52,11 @@ class ORTModule(torch.nn.Module):
register_custom_ops_pytorch_exporter.register_custom_op(is_ortmodule=True)
# User module is wrapped to use its initializers and save computed gradients
self._original_module = module
# along with the module that flattens both input and output of the user module
# inside _PytorchModuleMetadata
self._module_metadata = _PytorchModuleMetadata(module, _io._FlattenedModule(module))
# Get the module that flattens both input and output
self._flattened_module = _io._FlattenedModule(self._original_module)
self._execution_manager = GraphExecutionManagerFactory(self._flattened_module)
self._execution_manager = GraphExecutionManagerFactory(self._module_metadata.flattened_module)
# IMPORTANT: DO NOT add code here
# This declaration is for automatic document generation purposes only
@ -65,57 +65,82 @@ class ORTModule(torch.nn.Module):
'''Dummy documentation for forward method'''
...
def _apply(self, fn):
"""Override original method to delegate execution to the flattened PyTorch user module"""
# Delegation must happen to _flattened_module since methods depend on
# _apply to recursively apply the internal setting changes
self._module_metadata.flattened_module._apply(fn)
return self
def apply(self: T, fn: Callable[['Module'], None]) -> T:
"""Override original method to delegate execution to the flattened PyTorch user module"""
# Delegation must happen to _flattened_module since methods depend on
# apply to recursively apply the internal setting changes
self._module_metadata.flattened_module.apply(fn)
return self
def _is_training(self):
return self._flattened_module.training and torch.is_grad_enabled()
return self.training and torch.is_grad_enabled()
def train(self: T, mode: bool = True) -> T:
"""Override original method to delegate execution to the flattened PyTorch user module"""
# Since _modules is empty, the task needs to be delegated to _module.flattened_module.train
# which will recursively update the original_module
self.training = mode
self._module_metadata.flattened_module.train(mode)
return self
def state_dict(self, destination=None, prefix='', keep_vars=False):
"""Override original method to delegate execution to the base module"""
"""Override original method to delegate execution to the original PyTorch user module"""
# Override the state_dict() method so that the state dict key names
# do not contain the _flattened_module._original_module prefix
return self._original_module.state_dict(
# do not contain the flattened_module._original_module prefix
return self._module_metadata.original_module.state_dict(
destination=destination, prefix=prefix, keep_vars=keep_vars)
def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
strict: bool = True):
"""Override original method to delegate execution to the base module"""
"""Override original method to delegate execution to the original PyTorch user module"""
# Override the load_state_dict() method so that the loaded state dict
# key names does not need to contain the _flattened_module._original_module prefix
return self._original_module.load_state_dict(
# key names does not need to contain the _module.flattened_module._original_module prefix
return self._module_metadata.original_module.load_state_dict(
state_dict, strict=strict)
def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None:
"""Override original method to delegate execution to the base module"""
self._original_module.register_buffer(name, tensor, persistent=persistent)
"""Override original method to delegate execution to the original PyTorch user module"""
self._module_metadata.original_module.register_buffer(name, tensor, persistent=persistent)
def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
"""Override original method to delegate execution to the base module"""
self._original_module.register_parameter(name, param)
"""Override original method to delegate execution to the original PyTorch user module"""
self._module_metadata.original_module.register_parameter(name, param)
def get_parameter(self, target: str) -> torch.nn.Parameter:
"""Override original method to delegate execution to the base module"""
return self._original_module.get_parameter(target)
"""Override original method to delegate execution to the original PyTorch user module"""
return self._module_metadata.original_module.get_parameter(target)
def get_buffer(self, target: str) -> torch.Tensor:
"""Override original method to delegate execution to the base module"""
return self._original_module.get_buffer(target)
"""Override original method to delegate execution to the original PyTorch user module"""
return self._module_metadata.original_module.get_buffer(target)
def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
"""Override original method to delegate execution to the base module"""
yield from self._original_module.parameters(recurse=recurse)
"""Override original method to delegate execution to the original PyTorch user module"""
yield from self._module_metadata.original_module.parameters(recurse=recurse)
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
"""Override original method to delegate execution to the base module"""
yield from self._original_module.named_parameters(prefix=prefix, recurse=recurse)
"""Override original method to delegate execution to the original PyTorch user module"""
yield from self._module_metadata.original_module.named_parameters(prefix=prefix, recurse=recurse)
def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]:
"""Override original method to delegate execution to the base module"""
yield from self._original_module.buffers(recurse=recurse)
"""Override original method to delegate execution to the original PyTorch user module"""
yield from self._module_metadata.original_module.buffers(recurse=recurse)
def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]:
"""Override original method to delegate execution to the base module"""
yield from self._original_module.named_buffers(prefix=prefix, recurse=recurse)
"""Override original method to delegate execution to the original PyTorch user module"""
yield from self._module_metadata.original_module.named_buffers(prefix=prefix, recurse=recurse)
def _replicate_for_data_parallel(self):
"""Raises a NotImplementedError exception since ORTModule is not compatible with torch.nn.DataParallel
@ -135,3 +160,34 @@ class ORTModule(torch.nn.Module):
raise NotImplementedError("ORTModule is not compatible with torch.nn.DataParallel. "
"Please use torch.nn.parallel.DistributedDataParallel instead.")
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
"""Override original method to delegate execution to the original PyTorch user module"""
# PyTorch load_state_dict implementation does not recursively call load_state_dict on its sub-modules.
# Instead, it creates a recursive function and invokes _load_from_state_dict on all child modules.
# For the scenario where an ORTModule is a sub-module of another module, loading of the state
# dictionary requires the _load_from_state_dict to be overridden to prevent an error.
self._module_metadata.original_module._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
def named_children(self) -> Iterator[Tuple[str, 'Module']]:
"""Override original method to delegate execution to the original PyTorch user module"""
yield from self._module_metadata.original_module.named_children()
def modules(self) -> Iterator['Module']:
"""Override original method to delegate execution to the original PyTorch user module"""
yield from self._module_metadata.original_module.modules()
def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = ''):
"""Override original method to delegate execution to the original PyTorch user module"""
yield from self._module_metadata.original_module.named_modules(memo, prefix)
def add_module(self, name: str, module: Optional['Module']) -> None:
"""Raises a NotImplementedError exception since ORTModule does not support adding modules to it"""
raise NotImplementedError("ORTModule does not support adding modules to it.")

View file

@ -1666,26 +1666,26 @@ def test_model_initializer_requires_grad_changes_from_one_forward_to_next():
model.fc1.requires_grad_(True)
model = ORTModule(model)
x = torch.randn(N, D_in, device=device)
assert model._original_module.fc1.weight.grad is None
assert model._original_module.fc1.bias.grad is None
assert model._module_metadata.original_module.fc1.weight.grad is None
assert model._module_metadata.original_module.fc1.bias.grad is None
# Make sure no exception is raised
output = model(x)
loss = torch.sum(output)
loss.backward()
training_session1 = model._execution_manager(model._is_training())._execution_agent
weight_grad_2 = model._original_module.fc1.weight.grad
bias_grad_2 = model._original_module.fc1.bias.grad
weight_grad_2 = model._module_metadata.original_module.fc1.weight.grad
bias_grad_2 = model._module_metadata.original_module.fc1.bias.grad
assert weight_grad_2 is not None
assert bias_grad_2 is not None
model._original_module.fc1.requires_grad_(False)
model._module_metadata.original_module.fc1.requires_grad_(False)
output = model(x)
loss = torch.sum(output)
loss.backward()
training_session2 = model._execution_manager(model._is_training())._execution_agent
weight_grad_3 = model._original_module.fc1.weight.grad
bias_grad_3 = model._original_module.fc1.bias.grad
weight_grad_3 = model._module_metadata.original_module.fc1.weight.grad
bias_grad_3 = model._module_metadata.original_module.fc1.bias.grad
assert training_session1 != training_session2
assert torch.equal(weight_grad_2, weight_grad_3)
@ -2619,3 +2619,31 @@ def test_unused_parameters_does_not_unnecssarily_reinitilize(model):
{})
assert not training_manager._reinitialize_graph_builder(input_info)
def test_load_state_dict_for_wrapped_ortmodule():
class WrapperModule(torch.nn.Module):
def __init__(self, ortmodule):
super(WrapperModule, self).__init__()
self._ortmodule = ortmodule
def forward(self, x):
return self._ortmodule(x)
device = 'cuda'
N, D_in, H, D_out = 64, 784, 500, 10
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
model = ORTModule(copy.deepcopy(model))
wrapper_module = WrapperModule(model)
x = torch.randn(N, D_in, device=device)
_ = wrapper_module(x)
state_dict1 = wrapper_module.state_dict()
list(next(iter(state_dict1.items())))[1] += 10
wrapper_module.load_state_dict(state_dict1)
state_dict2 = wrapper_module.state_dict()
assert state_dict1
assert len(state_dict1.keys()) == len(state_dict2.keys())
for param_name, param_value in state_dict1.items():
assert param_name in state_dict2
assert torch.equal(param_value, state_dict2[param_name])