mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Resolve issue with wrapped ORTModule load_state_dict (#7847)
* Encapsulate children modules inside a ModuleAccessor object to prevent erroneous iteration over children while loading the state dictionary * Add named_models, models, apply methods, change ModuleAccessor to ModuleMetadata and modify unit tests * Change ModuleMetadata module getter logic, raise NotImplementedError for add_module * Add comment explaining why overriding _load_from_state_dict method is needed
This commit is contained in:
parent
8140e3fde5
commit
ddf4aaaae1
3 changed files with 127 additions and 36 deletions
|
|
@ -97,3 +97,10 @@ def _create_iobinding(io_binding, inputs, model, device):
|
|||
|
||||
for value_info in model.graph.output:
|
||||
io_binding.bind_output(value_info.name, device.type, device_id=get_device_index(device))
|
||||
|
||||
class _PytorchModuleMetadata():
|
||||
"""Encapsulates modules and allows easy access as required"""
|
||||
|
||||
def __init__(self, original_module, flattened_module):
|
||||
self.original_module = original_module
|
||||
self.flattened_module = flattened_module
|
||||
|
|
|
|||
|
|
@ -5,12 +5,13 @@
|
|||
|
||||
from . import _io
|
||||
from ._graph_execution_manager_factory import GraphExecutionManagerFactory
|
||||
from ._utils import _PytorchModuleMetadata
|
||||
|
||||
from onnxruntime.training import register_custom_ops_pytorch_exporter
|
||||
|
||||
import functools
|
||||
import torch
|
||||
from typing import Iterator, Optional, Tuple, TypeVar
|
||||
from typing import Iterator, Optional, Tuple, TypeVar, Set, Callable
|
||||
|
||||
# Needed to override PyTorch methods
|
||||
T = TypeVar('T', bound='Module')
|
||||
|
|
@ -51,12 +52,11 @@ class ORTModule(torch.nn.Module):
|
|||
register_custom_ops_pytorch_exporter.register_custom_op(is_ortmodule=True)
|
||||
|
||||
# User module is wrapped to use its initializers and save computed gradients
|
||||
self._original_module = module
|
||||
# along with the module that flattens both input and output of the user module
|
||||
# inside _PytorchModuleMetadata
|
||||
self._module_metadata = _PytorchModuleMetadata(module, _io._FlattenedModule(module))
|
||||
|
||||
# Get the module that flattens both input and output
|
||||
self._flattened_module = _io._FlattenedModule(self._original_module)
|
||||
|
||||
self._execution_manager = GraphExecutionManagerFactory(self._flattened_module)
|
||||
self._execution_manager = GraphExecutionManagerFactory(self._module_metadata.flattened_module)
|
||||
|
||||
# IMPORTANT: DO NOT add code here
|
||||
# This declaration is for automatic document generation purposes only
|
||||
|
|
@ -65,57 +65,82 @@ class ORTModule(torch.nn.Module):
|
|||
'''Dummy documentation for forward method'''
|
||||
...
|
||||
|
||||
def _apply(self, fn):
|
||||
"""Override original method to delegate execution to the flattened PyTorch user module"""
|
||||
|
||||
# Delegation must happen to _flattened_module since methods depend on
|
||||
# _apply to recursively apply the internal setting changes
|
||||
self._module_metadata.flattened_module._apply(fn)
|
||||
return self
|
||||
|
||||
def apply(self: T, fn: Callable[['Module'], None]) -> T:
|
||||
"""Override original method to delegate execution to the flattened PyTorch user module"""
|
||||
|
||||
# Delegation must happen to _flattened_module since methods depend on
|
||||
# apply to recursively apply the internal setting changes
|
||||
self._module_metadata.flattened_module.apply(fn)
|
||||
return self
|
||||
|
||||
def _is_training(self):
|
||||
return self._flattened_module.training and torch.is_grad_enabled()
|
||||
return self.training and torch.is_grad_enabled()
|
||||
|
||||
def train(self: T, mode: bool = True) -> T:
|
||||
"""Override original method to delegate execution to the flattened PyTorch user module"""
|
||||
|
||||
# Since _modules is empty, the task needs to be delegated to _module.flattened_module.train
|
||||
# which will recursively update the original_module
|
||||
self.training = mode
|
||||
self._module_metadata.flattened_module.train(mode)
|
||||
return self
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False):
|
||||
"""Override original method to delegate execution to the base module"""
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
# Override the state_dict() method so that the state dict key names
|
||||
# do not contain the _flattened_module._original_module prefix
|
||||
return self._original_module.state_dict(
|
||||
# do not contain the flattened_module._original_module prefix
|
||||
return self._module_metadata.original_module.state_dict(
|
||||
destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||
|
||||
def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
|
||||
strict: bool = True):
|
||||
"""Override original method to delegate execution to the base module"""
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
# Override the load_state_dict() method so that the loaded state dict
|
||||
# key names does not need to contain the _flattened_module._original_module prefix
|
||||
return self._original_module.load_state_dict(
|
||||
# key names do not need to contain the _module.flattened_module._original_module prefix
|
||||
return self._module_metadata.original_module.load_state_dict(
|
||||
state_dict, strict=strict)
|
||||
|
||||
def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None:
|
||||
"""Override original method to delegate execution to the base module"""
|
||||
self._original_module.register_buffer(name, tensor, persistent=persistent)
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
self._module_metadata.original_module.register_buffer(name, tensor, persistent=persistent)
|
||||
|
||||
def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
|
||||
"""Override original method to delegate execution to the base module"""
|
||||
self._original_module.register_parameter(name, param)
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
self._module_metadata.original_module.register_parameter(name, param)
|
||||
|
||||
def get_parameter(self, target: str) -> torch.nn.Parameter:
|
||||
"""Override original method to delegate execution to the base module"""
|
||||
return self._original_module.get_parameter(target)
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
return self._module_metadata.original_module.get_parameter(target)
|
||||
|
||||
def get_buffer(self, target: str) -> torch.Tensor:
|
||||
"""Override original method to delegate execution to the base module"""
|
||||
return self._original_module.get_buffer(target)
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
return self._module_metadata.original_module.get_buffer(target)
|
||||
|
||||
def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
|
||||
"""Override original method to delegate execution to the base module"""
|
||||
yield from self._original_module.parameters(recurse=recurse)
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
yield from self._module_metadata.original_module.parameters(recurse=recurse)
|
||||
|
||||
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
|
||||
"""Override original method to delegate execution to the base module"""
|
||||
yield from self._original_module.named_parameters(prefix=prefix, recurse=recurse)
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
yield from self._module_metadata.original_module.named_parameters(prefix=prefix, recurse=recurse)
|
||||
|
||||
def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]:
|
||||
"""Override original method to delegate execution to the base module"""
|
||||
yield from self._original_module.buffers(recurse=recurse)
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
yield from self._module_metadata.original_module.buffers(recurse=recurse)
|
||||
|
||||
def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]:
|
||||
"""Override original method to delegate execution to the base module"""
|
||||
yield from self._original_module.named_buffers(prefix=prefix, recurse=recurse)
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
yield from self._module_metadata.original_module.named_buffers(prefix=prefix, recurse=recurse)
|
||||
|
||||
def _replicate_for_data_parallel(self):
|
||||
"""Raises a NotImplementedError exception since ORTModule is not compatible with torch.nn.DataParallel
|
||||
|
|
@ -135,3 +160,34 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
raise NotImplementedError("ORTModule is not compatible with torch.nn.DataParallel. "
|
||||
"Please use torch.nn.parallel.DistributedDataParallel instead.")
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs):
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
# PyTorch load_state_dict implementation does not recursively call load_state_dict on its sub-modules.
|
||||
# Instead, it creates a recursive function and invokes _load_from_state_dict on all child modules.
|
||||
# For the scenario where an ORTModule is a sub-module of another module, loading of the state
|
||||
# dictionary requires the _load_from_state_dict to be overridden to prevent an error.
|
||||
self._module_metadata.original_module._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
def named_children(self) -> Iterator[Tuple[str, 'Module']]:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
yield from self._module_metadata.original_module.named_children()
|
||||
|
||||
def modules(self) -> Iterator['Module']:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
yield from self._module_metadata.original_module.modules()
|
||||
|
||||
def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = ''):
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
yield from self._module_metadata.original_module.named_modules(memo, prefix)
|
||||
|
||||
def add_module(self, name: str, module: Optional['Module']) -> None:
|
||||
"""Raises a NotImplementedError exception since ORTModule does not support adding modules to it"""
|
||||
|
||||
raise NotImplementedError("ORTModule does not support adding modules to it.")
|
||||
|
|
|
|||
|
|
@ -1666,26 +1666,26 @@ def test_model_initializer_requires_grad_changes_from_one_forward_to_next():
|
|||
model.fc1.requires_grad_(True)
|
||||
model = ORTModule(model)
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
assert model._original_module.fc1.weight.grad is None
|
||||
assert model._original_module.fc1.bias.grad is None
|
||||
assert model._module_metadata.original_module.fc1.weight.grad is None
|
||||
assert model._module_metadata.original_module.fc1.bias.grad is None
|
||||
|
||||
# Make sure no exception is raised
|
||||
output = model(x)
|
||||
loss = torch.sum(output)
|
||||
loss.backward()
|
||||
training_session1 = model._execution_manager(model._is_training())._execution_agent
|
||||
weight_grad_2 = model._original_module.fc1.weight.grad
|
||||
bias_grad_2 = model._original_module.fc1.bias.grad
|
||||
weight_grad_2 = model._module_metadata.original_module.fc1.weight.grad
|
||||
bias_grad_2 = model._module_metadata.original_module.fc1.bias.grad
|
||||
assert weight_grad_2 is not None
|
||||
assert bias_grad_2 is not None
|
||||
|
||||
model._original_module.fc1.requires_grad_(False)
|
||||
model._module_metadata.original_module.fc1.requires_grad_(False)
|
||||
output = model(x)
|
||||
loss = torch.sum(output)
|
||||
loss.backward()
|
||||
training_session2 = model._execution_manager(model._is_training())._execution_agent
|
||||
weight_grad_3 = model._original_module.fc1.weight.grad
|
||||
bias_grad_3 = model._original_module.fc1.bias.grad
|
||||
weight_grad_3 = model._module_metadata.original_module.fc1.weight.grad
|
||||
bias_grad_3 = model._module_metadata.original_module.fc1.bias.grad
|
||||
|
||||
assert training_session1 != training_session2
|
||||
assert torch.equal(weight_grad_2, weight_grad_3)
|
||||
|
|
@ -2619,3 +2619,31 @@ def test_unused_parameters_does_not_unnecssarily_reinitilize(model):
|
|||
{})
|
||||
|
||||
assert not training_manager._reinitialize_graph_builder(input_info)
|
||||
|
||||
def test_load_state_dict_for_wrapped_ortmodule():
|
||||
class WrapperModule(torch.nn.Module):
|
||||
def __init__(self, ortmodule):
|
||||
super(WrapperModule, self).__init__()
|
||||
self._ortmodule = ortmodule
|
||||
|
||||
def forward(self, x):
|
||||
return self._ortmodule(x)
|
||||
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
model = ORTModule(copy.deepcopy(model))
|
||||
wrapper_module = WrapperModule(model)
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
_ = wrapper_module(x)
|
||||
|
||||
state_dict1 = wrapper_module.state_dict()
|
||||
list(next(iter(state_dict1.items())))[1] += 10
|
||||
wrapper_module.load_state_dict(state_dict1)
|
||||
state_dict2 = wrapper_module.state_dict()
|
||||
|
||||
assert state_dict1
|
||||
assert len(state_dict1.keys()) == len(state_dict2.keys())
|
||||
for param_name, param_value in state_dict1.items():
|
||||
assert param_name in state_dict2
|
||||
assert torch.equal(param_value, state_dict2[param_name])
|
||||
|
|
|
|||
Loading…
Reference in a new issue