mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-24 02:47:54 +00:00
Provide torch module interface for ORTModule (#8148)
* Interface for the module manager and implementation of the torch module manager
This commit is contained in:
parent
ce9d134952
commit
f616cd07b4
11 changed files with 474 additions and 172 deletions
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# _graph_execution_interface.py
|
||||
|
||||
from abc import ABC
|
||||
|
||||
|
||||
class GraphExecutionInterface(ABC):
    """Base interface for objects that execute a graph on behalf of ORTModule.

    The constructor stores the wrapped module and immediately asks the concrete
    class to validate it. Both `forward` and `_validate_module_type` raise
    NotImplementedError here and must be overridden by a concrete implementation.
    """

    def __init__(self, module):
        """Stores `module` and validates its type.

        Args:
            module: The module this instance manages. Its accepted type is
                decided by the concrete `_validate_module_type` override.

        Raises:
            NotImplementedError: If `_validate_module_type` is not overridden.
        """

        self._original_module = module

        # The module is stored before validation, so an override of
        # _validate_module_type may rely on self._original_module being set.
        self._validate_module_type(module)

    def forward(self):
        """Executes the forward method for ORTModule

        This is an abstract method and must be overridden by a concrete implementation.
        This is the only method that the user should call on a concrete instance of the GraphExecutionInterface

        Raises:
            NotImplementedError: Always, unless overridden.
        """

        raise NotImplementedError(f"forward is not implemented for {type(self)}")

    def _validate_module_type(self, module):
        """Validates the type of the input module

        This is an abstract method and must be overridden by a concrete implementation.
        It is invoked by `__init__` with the module being wrapped; it is an internal
        hook and is not meant to be called by the user.

        Args:
            module: The module whose type should be checked.

        Raises:
            NotImplementedError: Always, unless overridden.
        """

        raise NotImplementedError(f"_validate_module_type is not implemented for {type(self)}")
|
||||
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
from . import _utils, _io, _logger, torch_cpp_extensions as _cpp_ext
|
||||
from ._custom_autograd_function_exporter import _post_process_after_export
|
||||
from ._graph_execution_interface import GraphExecutionInterface
|
||||
from onnxruntime.training.ortmodule import ONNX_OPSET_VERSION
|
||||
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
|
|
@ -30,7 +31,7 @@ class RunStateInfo(object):
|
|||
self.state = state
|
||||
self.output_info = output_info
|
||||
|
||||
class GraphExecutionManager(ABC):
|
||||
class GraphExecutionManager(GraphExecutionInterface):
|
||||
def __init__(self, module):
|
||||
"""Manages building and execution of onnx graphs
|
||||
|
||||
|
|
@ -41,8 +42,9 @@ class GraphExecutionManager(ABC):
|
|||
the onnx graph, and ExecutionAgent to run the onnx graph.
|
||||
"""
|
||||
|
||||
super(GraphExecutionManager, self).__init__(module._original_module)
|
||||
|
||||
# Original and flattened (transformed) output module
|
||||
self._original_module = module._original_module
|
||||
self._flattened_module = module
|
||||
|
||||
# Exported model
|
||||
|
|
@ -126,6 +128,12 @@ class GraphExecutionManager(ABC):
|
|||
# WIP feature to enable caching in Gradient accumulation scenario.
|
||||
self._enable_grad_acc_optimization = False
|
||||
|
||||
def _validate_module_type(self, module):
|
||||
"""Raises a TypeError if the module is not a torch.nn.Module"""
|
||||
|
||||
if not isinstance(module, torch.nn.Module):
|
||||
raise TypeError(f"ORTModule only supports torch.nn.Module as input. {type(module)} is not supported.")
|
||||
|
||||
@staticmethod
|
||||
def execution_session_run_forward(execution_session, onnx_model, device, *inputs):
|
||||
"""Runs the forward pass on `execution_session` with given `onnx_model`, `device` and `inputs`
|
||||
|
|
|
|||
|
|
@ -0,0 +1,152 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# _torch_module.py
|
||||
|
||||
from . import _io
|
||||
from ._graph_execution_manager_factory import GraphExecutionManagerFactory
|
||||
from ._torch_module_interface import TorchModuleInterface
|
||||
|
||||
from collections import OrderedDict
|
||||
import functools
|
||||
import torch
|
||||
from typing import Iterator, Optional, Tuple, TypeVar, Callable
|
||||
|
||||
|
||||
T = TypeVar('T', bound='torch.nn.Module')
|
||||
|
||||
|
||||
class TorchModule(TorchModuleInterface):
    """Concrete TorchModuleInterface backed by a flattened copy of the user module.

    Calls that must reach every sub-module (`_apply`, `apply`, `train`) are routed
    to the flattened module, while state/parameter/buffer/module queries are routed
    to the original user module so that names carry no wrapper prefix.
    """

    def __init__(self, module: torch.nn.Module):
        """Wraps `module`, binds a per-instance `forward` and creates the execution manager."""

        super().__init__(module)
        self._flattened_module = _io._FlattenedModule(module)

        def _forward(self, *inputs, **kwargs):
            '''Forward pass starts here and continues at `_ORTModuleFunction.forward`

            ONNX model is exported the first time this method is executed.
            Next, we build a full training graph with module_gradient_graph_builder.
            Finally, we instantiate the ONNX Runtime InferenceSession.
            '''

            manager = self._execution_manager(self.is_training())
            return manager.forward(*inputs, **kwargs)

        # Each instance gets its own bound forward so that its metadata can be
        # overwritten below without affecting other instances.
        bound_forward = _forward.__get__(self)
        self.forward = bound_forward

        # Mirror the signature/metadata of the user module's forward onto ours.
        user_forward = self._original_module.forward.__func__
        functools.update_wrapper(bound_forward.__func__, user_forward)

        self._execution_manager = GraphExecutionManagerFactory(self._flattened_module)

    def _apply(self, fn):
        """Override original method to delegate execution to the flattened PyTorch user module"""

        # The flattened module is the target because other methods rely on
        # _apply to push internal setting changes down recursively.
        flattened = self._flattened_module
        flattened._apply(fn)
        return self

    def apply(self: T, fn: Callable[[T], None]) -> T:
        """Override original method to delegate execution to the flattened PyTorch user module"""

        # The flattened module is the target because other methods rely on
        # apply to push internal setting changes down recursively.
        flattened = self._flattened_module
        flattened.apply(fn)
        return self

    def is_training(self):
        """Returns the flattened module's `training` flag combined with `torch.is_grad_enabled()`."""

        flattened = self._flattened_module
        return flattened.training and torch.is_grad_enabled()

    def train(self: T, mode: bool = True) -> T:
        """Override original method to delegate execution to the flattened PyTorch user module"""

        # Hand the request to the flattened module, which will recursively
        # update the original module.
        flattened = self._flattened_module
        flattened.train(mode)
        return self

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Override original method to delegate execution to the original PyTorch user module"""

        # Querying the original module keeps the flattened_module._original_module
        # prefix out of the state dict key names.
        original = self._original_module
        return original.state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]',
                        strict: bool = True):
        """Override original method to delegate execution to the original PyTorch user module"""

        # Loading through the original module means the incoming key names do not
        # need to contain the _module.flattened_module._original_module prefix.
        original = self._original_module
        return original.load_state_dict(state_dict, strict=strict)

    def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None:
        """Override original method to delegate execution to the original PyTorch user module"""

        original = self._original_module
        original.register_buffer(name, tensor, persistent=persistent)

    def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
        """Override original method to delegate execution to the original PyTorch user module"""

        original = self._original_module
        original.register_parameter(name, param)

    def get_parameter(self, target: str) -> torch.nn.Parameter:
        """Override original method to delegate execution to the original PyTorch user module"""

        original = self._original_module
        return original.get_parameter(target)

    def get_buffer(self, target: str) -> torch.Tensor:
        """Override original method to delegate execution to the original PyTorch user module"""

        original = self._original_module
        return original.get_buffer(target)

    def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
        """Override original method to delegate execution to the original PyTorch user module"""

        original = self._original_module
        yield from original.parameters(recurse=recurse)

    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
        """Override original method to delegate execution to the original PyTorch user module"""

        original = self._original_module
        yield from original.named_parameters(prefix=prefix, recurse=recurse)

    def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]:
        """Override original method to delegate execution to the original PyTorch user module"""

        original = self._original_module
        yield from original.buffers(recurse=recurse)

    def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]:
        """Override original method to delegate execution to the original PyTorch user module"""

        original = self._original_module
        yield from original.named_buffers(prefix=prefix, recurse=recurse)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Override original method to delegate execution to the original PyTorch user module"""

        # PyTorch's load_state_dict does not call load_state_dict on sub-modules.
        # It builds a recursive function that invokes _load_from_state_dict on every
        # child module instead. When an ORTModule is a sub-module of another module,
        # this override is required for loading the state dictionary without error.
        original = self._original_module
        original._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                       missing_keys, unexpected_keys, error_msgs)

    def named_children(self) -> Iterator[Tuple[str, T]]:
        """Override original method to delegate execution to the original PyTorch user module"""

        original = self._original_module
        yield from original.named_children()

    def modules(self) -> Iterator[T]:
        """Override original method to delegate execution to the original PyTorch user module"""

        original = self._original_module
        yield from original.modules()

    def named_modules(self, *args, **kwargs):
        """Override original method to delegate execution to the original PyTorch user module"""

        # PyTorch >1.8.1 accepts an extra remove_duplicate argument that 1.8.1 lacks.
        # Forwarding *args/**kwargs untouched supports both, whether the caller
        # passes positional or keyword arguments.
        original = self._original_module
        yield from original.named_modules(*args, **kwargs)
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# _torch_module_factory.py
|
||||
|
||||
from ._torch_module import TorchModule
|
||||
|
||||
|
||||
class TorchModuleFactory:
    """Callable factory that produces TorchModule wrappers."""

    def __call__(self, module):
        """Creates a TorchModule instance based on the input module."""

        torch_module = TorchModule(module)
        return torch_module
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# _torch_module_interface.py
|
||||
|
||||
from collections import OrderedDict
|
||||
import torch
|
||||
from typing import Iterator, Optional, Tuple, TypeVar, Callable
|
||||
|
||||
|
||||
T = TypeVar('T', bound='torch.nn.Module')
|
||||
|
||||
|
||||
class TorchModuleInterface:
    """Abstract class that provides the function signatures for the torch.nn.Module

    Concrete implementations should inherit from this class and provide necessary executions.
    Every method below other than `__init__` and `module` raises NotImplementedError.
    """

    def __init__(self, module):
        """Stores the user provided module."""

        self._original_module = module

    @property
    def module(self):
        """The original user provided module that this class manages.

        This property provides access to methods and properties on the original module.
        """

        original = self._original_module
        return original

    ###################################################
    # The methods below are part of torch.nn.Module API
    ###################################################

    def forward(self):
        """Executes the forward method for ORTModule

        This is an abstract method and must be overridden by a concrete implementation.
        """

        message = f"forward is not implemented for {type(self)}."
        raise NotImplementedError(message)

    def _apply(self, fn):
        """Placeholder for `_apply`; must be overridden by a concrete implementation."""

        message = f"_apply is not implemented for {type(self)}."
        raise NotImplementedError(message)

    def apply(self: T, fn: Callable[[T], None]) -> T:
        """Placeholder for `apply`; must be overridden by a concrete implementation."""

        message = f"apply is not implemented for {type(self)}."
        raise NotImplementedError(message)

    def is_training(self):
        """Placeholder for `is_training`; must be overridden by a concrete implementation."""

        message = f"is_training is not implemented for {type(self)}."
        raise NotImplementedError(message)

    def train(self: T, mode: bool = True) -> T:
        """Placeholder for `train`; must be overridden by a concrete implementation."""

        message = f"train is not implemented for {type(self)}."
        raise NotImplementedError(message)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Placeholder for `state_dict`; must be overridden by a concrete implementation."""

        message = f"state_dict is not implemented for {type(self)}."
        raise NotImplementedError(message)

    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]',
                        strict: bool = True):
        """Placeholder for `load_state_dict`; must be overridden by a concrete implementation."""

        message = f"load_state_dict is not implemented for {type(self)}."
        raise NotImplementedError(message)

    def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None:
        """Placeholder for `register_buffer`; must be overridden by a concrete implementation."""

        message = f"register_buffer is not implemented for {type(self)}."
        raise NotImplementedError(message)

    def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
        """Placeholder for `register_parameter`; must be overridden by a concrete implementation."""

        message = f"register_parameter is not implemented for {type(self)}."
        raise NotImplementedError(message)

    def get_parameter(self, target: str) -> torch.nn.Parameter:
        """Placeholder for `get_parameter`; must be overridden by a concrete implementation."""

        message = f"get_parameter is not implemented for {type(self)}."
        raise NotImplementedError(message)

    def get_buffer(self, target: str) -> torch.Tensor:
        """Placeholder for `get_buffer`; must be overridden by a concrete implementation."""

        message = f"get_buffer is not implemented for {type(self)}."
        raise NotImplementedError(message)

    def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
        """Placeholder for `parameters`; must be overridden by a concrete implementation."""

        message = f"parameters is not implemented for {type(self)}."
        raise NotImplementedError(message)

    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
        """Placeholder for `named_parameters`; must be overridden by a concrete implementation."""

        message = f"named_parameters is not implemented for {type(self)}."
        raise NotImplementedError(message)

    def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]:
        """Placeholder for `buffers`; must be overridden by a concrete implementation."""

        message = f"buffers is not implemented for {type(self)}."
        raise NotImplementedError(message)

    def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]:
        """Placeholder for `named_buffers`; must be overridden by a concrete implementation."""

        message = f"named_buffers is not implemented for {type(self)}."
        raise NotImplementedError(message)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Placeholder for `_load_from_state_dict`; must be overridden by a concrete implementation."""

        message = f"_load_from_state_dict is not implemented for {type(self)}."
        raise NotImplementedError(message)

    def named_children(self) -> Iterator[Tuple[str, T]]:
        """Placeholder for `named_children`; must be overridden by a concrete implementation."""

        message = f"named_children is not implemented for {type(self)}."
        raise NotImplementedError(message)

    def modules(self) -> Iterator[T]:
        """Placeholder for `modules`; must be overridden by a concrete implementation."""

        message = f"modules is not implemented for {type(self)}."
        raise NotImplementedError(message)

    def named_modules(self, *args, **kwargs):
        """Placeholder for `named_modules`; must be overridden by a concrete implementation."""

        message = f"named_modules is not implemented for {type(self)}."
        raise NotImplementedError(message)
|
||||
|
|
@ -97,10 +97,3 @@ def _create_iobinding(io_binding, inputs, model, device):
|
|||
|
||||
for value_info in model.graph.output:
|
||||
io_binding.bind_output(value_info.name, device.type, device_id=get_device_index(device))
|
||||
|
||||
class _PytorchModuleMetadata():
|
||||
"""Encapsulates modules and allows easy access as required"""
|
||||
|
||||
def __init__(self, original_module, flattened_module):
|
||||
self.original_module = original_module
|
||||
self.flattened_module = flattened_module
|
||||
|
|
|
|||
|
|
@ -3,9 +3,7 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from . import _io
|
||||
from ._graph_execution_manager_factory import GraphExecutionManagerFactory
|
||||
from ._utils import _PytorchModuleMetadata
|
||||
from ._torch_module_factory import TorchModuleFactory
|
||||
|
||||
from onnxruntime.training import register_custom_ops_pytorch_exporter
|
||||
|
||||
|
|
@ -24,8 +22,7 @@ class ORTModule(torch.nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, module):
|
||||
assert isinstance(
|
||||
module, torch.nn.Module), "'module' must be a torch.nn.Module"
|
||||
self._torch_module = TorchModuleFactory()(module)
|
||||
|
||||
# Create forward dynamically, so each ORTModule instance will have its own copy.
|
||||
# This is needed to be able to copy the forward signatures from the original PyTorch models
|
||||
|
|
@ -38,26 +35,19 @@ class ORTModule(torch.nn.Module):
|
|||
Finally, we instantiate the ONNX Runtime InferenceSession.
|
||||
'''
|
||||
|
||||
return self._execution_manager(self._is_training()).forward(*inputs, **kwargs)
|
||||
return self._torch_module.forward(*inputs, **kwargs)
|
||||
|
||||
# Bind the forward method.
|
||||
self.forward = _forward.__get__(self)
|
||||
# Copy the forward signature from the PyTorch module.
|
||||
# Copy the forward signature from the _torch_module's forward signature.
|
||||
functools.update_wrapper(
|
||||
self.forward.__func__, module.forward.__func__)
|
||||
self.forward.__func__, self._torch_module.forward.__func__)
|
||||
|
||||
super(ORTModule, self).__init__()
|
||||
|
||||
# Support contrib OPs
|
||||
register_custom_ops_pytorch_exporter.register_custom_op(is_ortmodule=True)
|
||||
|
||||
# User module is wrapped to use its initializers and save computed gradients
|
||||
# along with the module that flattens both input and output of the user module
|
||||
# inside _PytorchModuleMetadata
|
||||
self._module_metadata = _PytorchModuleMetadata(module, _io._FlattenedModule(module))
|
||||
|
||||
self._execution_manager = GraphExecutionManagerFactory(self._module_metadata.flattened_module)
|
||||
|
||||
# IMPORTANT: DO NOT add code here
|
||||
# This declaration is for automatic document generation purposes only
|
||||
# The actual forward implementation is bound during ORTModule initialization
|
||||
|
|
@ -65,83 +55,6 @@ class ORTModule(torch.nn.Module):
|
|||
'''Dummy documentation for forward method'''
|
||||
...
|
||||
|
||||
def _apply(self, fn):
|
||||
"""Override original method to delegate execution to the flattened PyTorch user module"""
|
||||
|
||||
# Delegation must happen to _flattened_module since methods depend on
|
||||
# _apply to recursively apply the internal setting changes
|
||||
self._module_metadata.flattened_module._apply(fn)
|
||||
return self
|
||||
|
||||
def apply(self: T, fn: Callable[['Module'], None]) -> T:
|
||||
"""Override original method to delegate execution to the flattened PyTorch user module"""
|
||||
|
||||
# Delegation must happen to _flattened_module since methods depend on
|
||||
# apply to recursively apply the internal setting changes
|
||||
self._module_metadata.flattened_module.apply(fn)
|
||||
return self
|
||||
|
||||
def _is_training(self):
|
||||
return self.training and torch.is_grad_enabled()
|
||||
|
||||
def train(self: T, mode: bool = True) -> T:
|
||||
"""Override original method to delegate execution to the flattened PyTorch user module"""
|
||||
|
||||
# Since _modules is empty, the task needs to be delegated to _module.flattened_module.train
|
||||
# which will recursively update the original_module
|
||||
self.training = mode
|
||||
self._module_metadata.flattened_module.train(mode)
|
||||
return self
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False):
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
# Override the state_dict() method so that the state dict key names
|
||||
# do not contain the flattened_module._original_module prefix
|
||||
return self._module_metadata.original_module.state_dict(
|
||||
destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||
|
||||
def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
|
||||
strict: bool = True):
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
# Override the load_state_dict() method so that the loaded state dict
|
||||
# key names does not need to contain the _module.flattened_module._original_module prefix
|
||||
return self._module_metadata.original_module.load_state_dict(
|
||||
state_dict, strict=strict)
|
||||
|
||||
def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
self._module_metadata.original_module.register_buffer(name, tensor, persistent=persistent)
|
||||
|
||||
def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
self._module_metadata.original_module.register_parameter(name, param)
|
||||
|
||||
def get_parameter(self, target: str) -> torch.nn.Parameter:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
return self._module_metadata.original_module.get_parameter(target)
|
||||
|
||||
def get_buffer(self, target: str) -> torch.Tensor:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
return self._module_metadata.original_module.get_buffer(target)
|
||||
|
||||
def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
yield from self._module_metadata.original_module.parameters(recurse=recurse)
|
||||
|
||||
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
yield from self._module_metadata.original_module.named_parameters(prefix=prefix, recurse=recurse)
|
||||
|
||||
def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
yield from self._module_metadata.original_module.buffers(recurse=recurse)
|
||||
|
||||
def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
yield from self._module_metadata.original_module.named_buffers(prefix=prefix, recurse=recurse)
|
||||
|
||||
def _replicate_for_data_parallel(self):
|
||||
"""Raises a NotImplementedError exception since ORTModule is not compatible with torch.nn.DataParallel
|
||||
|
||||
|
|
@ -161,34 +74,6 @@ class ORTModule(torch.nn.Module):
|
|||
raise NotImplementedError("ORTModule is not compatible with torch.nn.DataParallel. "
|
||||
"Please use torch.nn.parallel.DistributedDataParallel instead.")
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs):
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
# PyTorch load_state_dict implementation does not recursively call load_state_dict on its sub-modules.
|
||||
# Instead, it creates a recursive function and invokes _load_from_state_dict on all child modules.
|
||||
# For the scenario where an ORTModule is a sub-module of another module, loading of the state
|
||||
# dictionary requires the _load_from_state_dict to be overridden to prevent an error.
|
||||
self._module_metadata.original_module._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
def named_children(self) -> Iterator[Tuple[str, 'Module']]:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
yield from self._module_metadata.original_module.named_children()
|
||||
|
||||
def modules(self) -> Iterator['Module']:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
yield from self._module_metadata.original_module.modules()
|
||||
|
||||
def named_modules(self, *args, **kwargs):
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
# PyTorch >1.8.1 has an extra arg remove_duplicate that is not present in 1.8.1
|
||||
# To support both, use args and kwargs (since user can call the method with only positional args or kwargs)
|
||||
yield from self._module_metadata.original_module.named_modules(*args, **kwargs)
|
||||
|
||||
def add_module(self, name: str, module: Optional['Module']) -> None:
|
||||
"""Raises a NotImplementedError exception since ORTModule does not support adding modules to it"""
|
||||
|
||||
|
|
@ -207,4 +92,111 @@ class ORTModule(torch.nn.Module):
|
|||
# This `module` property enables HuggingFace Trainer to retrieve the underlying PreTrainedModel inside ORTModule
|
||||
# to save and load a complete checkpoint
|
||||
|
||||
return self._module_metadata.original_module
|
||||
return self._torch_module.module
|
||||
|
||||
################################################################################
|
||||
# The methods below are part of torch.nn.Module API that are encapsulated through
|
||||
# TorchModuleInterface
|
||||
################################################################################
|
||||
|
||||
def _apply(self, fn):
|
||||
"""Override original method to delegate execution to the flattened PyTorch user module"""
|
||||
|
||||
self._torch_module._apply(fn)
|
||||
return self
|
||||
|
||||
def apply(self: T, fn: Callable[['Module'], None]) -> T:
|
||||
"""Override original method to delegate execution to the flattened PyTorch user module"""
|
||||
|
||||
self._torch_module.apply(fn)
|
||||
return self
|
||||
|
||||
def _is_training(self):
|
||||
return self._torch_module.is_training()
|
||||
|
||||
def train(self: T, mode: bool = True) -> T:
|
||||
"""Override original method to delegate execution to the flattened PyTorch user module"""
|
||||
|
||||
self.training = mode
|
||||
# In a torch.nn.Module, _modules stores all dependent modules (sub-modules) of the current module
|
||||
# in a list so that torch.nn.Module can apply any changes to all sub-modules recursively.
|
||||
# Although the _flattened_module and _original_module are dependent modules for ORTModule,
|
||||
# they do not show up in _modules because they are abstracted away behind another class,
|
||||
# TorchModule. In order to apply changes to those sub-modules, delegate the task to _torch_module
|
||||
# which will recursively update the flattened_module and the original module.
|
||||
self._torch_module.train(mode)
|
||||
return self
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False):
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
return self._torch_module.state_dict(
|
||||
destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||
|
||||
def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
|
||||
strict: bool = True):
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
return self._torch_module.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
self._torch_module.register_buffer(name, tensor, persistent=persistent)
|
||||
|
||||
def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
self._torch_module.register_parameter(name, param)
|
||||
|
||||
def get_parameter(self, target: str) -> torch.nn.Parameter:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
return self._torch_module.get_parameter(target)
|
||||
|
||||
def get_buffer(self, target: str) -> torch.Tensor:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
return self._torch_module.get_buffer(target)
|
||||
|
||||
def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
yield from self._torch_module.parameters(recurse=recurse)
|
||||
|
||||
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
yield from self._torch_module.named_parameters(prefix=prefix, recurse=recurse)
|
||||
|
||||
def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
yield from self._torch_module.buffers(recurse=recurse)
|
||||
|
||||
def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
yield from self._torch_module.named_buffers(prefix=prefix, recurse=recurse)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs):
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
self._torch_module._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
def named_children(self) -> Iterator[Tuple[str, 'Module']]:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
yield from self._torch_module.named_children()
|
||||
|
||||
def modules(self) -> Iterator['Module']:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
yield from self._torch_module.modules()
|
||||
|
||||
def named_modules(self, *args, **kwargs):
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
|
||||
yield from self._torch_module.named_modules(*args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -124,7 +124,7 @@ def assert_optim_state(expected_state, actual_state, rtol=1e-7, atol=0):
|
|||
|
||||
def is_dynamic_axes(model):
|
||||
# Check inputs
|
||||
for inp in model._execution_manager(model._is_training())._optimized_onnx_model.graph.input:
|
||||
for inp in model._torch_module._execution_manager(model._is_training())._optimized_onnx_model.graph.input:
|
||||
shape = inp.type.tensor_type.shape
|
||||
if shape:
|
||||
for dim in shape.dim:
|
||||
|
|
@ -132,7 +132,7 @@ def is_dynamic_axes(model):
|
|||
return False
|
||||
|
||||
# Check outputs
|
||||
for out in model._execution_manager(model._is_training())._optimized_onnx_model.graph.output:
|
||||
for out in model._torch_module._execution_manager(model._is_training())._optimized_onnx_model.graph.output:
|
||||
shape = out.type.tensor_type.shape
|
||||
if shape:
|
||||
for dim in shape.dim:
|
||||
|
|
|
|||
|
|
@ -472,7 +472,7 @@ def test_torch_nn_module_to_api(original_device, to_argument):
|
|||
model = model.to(to_argument)
|
||||
x = x.to(to_argument)
|
||||
model(x)
|
||||
assert _utils.get_device_str(model._execution_manager(model._is_training())._device) == \
|
||||
assert _utils.get_device_str(model._torch_module._execution_manager(model._is_training())._device) == \
|
||||
_utils.get_device_str(torch.device(to_argument))
|
||||
|
||||
def test_model_without_device():
|
||||
|
|
@ -524,7 +524,7 @@ def test_input_requires_grad_saved(device):
|
|||
model = ORTModule(model)
|
||||
x = torch.randn(N, D_in, device=device, requires_grad=True) + 1
|
||||
model(x)
|
||||
assert model._execution_manager(model._is_training())._input_info.require_grad_names == ['input1']
|
||||
assert model._torch_module._execution_manager(model._is_training())._input_info.require_grad_names == ['input1']
|
||||
|
||||
@pytest.mark.parametrize("device", ['cuda', 'cpu'])
|
||||
def test_input_requires_grad_backward_creates_input_grad(device):
|
||||
|
|
@ -1033,12 +1033,12 @@ def test_changes_input_requires_grad_reinitializes_module_gradient_graph_builder
|
|||
output_x.backward()
|
||||
assert x.grad is None
|
||||
module_gradient_graph_builder_training = \
|
||||
model._execution_manager(model._is_training())._graph_builder
|
||||
model._torch_module._execution_manager(model._is_training())._graph_builder
|
||||
output_y = torch.sum(model(y))
|
||||
output_y.backward()
|
||||
assert y.grad is not None
|
||||
assert module_gradient_graph_builder_training != \
|
||||
model._execution_manager(model._is_training())._graph_builder
|
||||
model._torch_module._execution_manager(model._is_training())._graph_builder
|
||||
|
||||
@pytest.mark.parametrize("device", ['cuda'])
|
||||
def test_input_requires_grad_backward_creates_input_grad_as_required0(device):
|
||||
|
|
@ -1695,7 +1695,7 @@ def test_forward_data_and_model_on_different_devices(data_device, model_device):
|
|||
x = torch.randn(N, D_in, device=data_device)
|
||||
with pytest.raises(RuntimeError) as runtime_error:
|
||||
ort_model(x)
|
||||
assert f"Input argument to forward found on device {torch.device(x.device)}, but expected it to be on module device {ort_model._execution_manager(ort_model._is_training())._device}." in str(runtime_error.value)
|
||||
assert f"Input argument to forward found on device {torch.device(x.device)}, but expected it to be on module device {ort_model._torch_module._execution_manager(ort_model._is_training())._device}." in str(runtime_error.value)
|
||||
|
||||
def test_forward_returns_none_type_as_output():
|
||||
class NeuralNetNoneTypeOutput(torch.nn.Module):
|
||||
|
|
@ -1799,26 +1799,26 @@ def test_model_initializer_requires_grad_changes_from_one_forward_to_next():
|
|||
model.fc1.requires_grad_(True)
|
||||
model = ORTModule(model)
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
assert model._module_metadata.original_module.fc1.weight.grad is None
|
||||
assert model._module_metadata.original_module.fc1.bias.grad is None
|
||||
assert model.module.fc1.weight.grad is None
|
||||
assert model.module.fc1.bias.grad is None
|
||||
|
||||
# Make sure no exception is raised
|
||||
output = model(x)
|
||||
loss = torch.sum(output)
|
||||
loss.backward()
|
||||
training_session1 = model._execution_manager(model._is_training())._execution_agent
|
||||
weight_grad_2 = model._module_metadata.original_module.fc1.weight.grad
|
||||
bias_grad_2 = model._module_metadata.original_module.fc1.bias.grad
|
||||
training_session1 = model._torch_module._execution_manager(model._is_training())._execution_agent
|
||||
weight_grad_2 = model.module.fc1.weight.grad
|
||||
bias_grad_2 = model.module.fc1.bias.grad
|
||||
assert weight_grad_2 is not None
|
||||
assert bias_grad_2 is not None
|
||||
|
||||
model._module_metadata.original_module.fc1.requires_grad_(False)
|
||||
model.module.fc1.requires_grad_(False)
|
||||
output = model(x)
|
||||
loss = torch.sum(output)
|
||||
loss.backward()
|
||||
training_session2 = model._execution_manager(model._is_training())._execution_agent
|
||||
weight_grad_3 = model._module_metadata.original_module.fc1.weight.grad
|
||||
bias_grad_3 = model._module_metadata.original_module.fc1.bias.grad
|
||||
training_session2 = model._torch_module._execution_manager(model._is_training())._execution_agent
|
||||
weight_grad_3 = model.module.fc1.weight.grad
|
||||
bias_grad_3 = model.module.fc1.bias.grad
|
||||
|
||||
assert training_session1 != training_session2
|
||||
assert torch.equal(weight_grad_2, weight_grad_3)
|
||||
|
|
@ -2142,21 +2142,21 @@ def test_forward_dynamic_args():
|
|||
for _ in range(10):
|
||||
output = model(*args_size1)
|
||||
assert output is not None
|
||||
hash_args_size1 = hash(repr(model._execution_manager(model._is_training())._input_info.schema))
|
||||
hash_args_size1 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema))
|
||||
assert hash_args_size1 is not None
|
||||
|
||||
# Decrease number of inputs and train some more
|
||||
for _ in range(10):
|
||||
output = model(*args_size2)
|
||||
assert output is not None
|
||||
hash_args_size2 = hash(repr(model._execution_manager(model._is_training())._input_info.schema))
|
||||
hash_args_size2 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema))
|
||||
assert hash_args_size2 != hash_args_size1
|
||||
|
||||
# Increase number of inputs and train some more
|
||||
for _ in range(10):
|
||||
output = model(*args_size3)
|
||||
assert output is not None
|
||||
hash_args_size3 = hash(repr(model._execution_manager(model._is_training())._input_info.schema))
|
||||
hash_args_size3 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema))
|
||||
assert hash_args_size3 != hash_args_size2
|
||||
|
||||
|
||||
|
|
@ -2178,35 +2178,35 @@ def test_forward_dynamic_kwargs():
|
|||
for _ in range(10):
|
||||
output = model(one)
|
||||
assert output is not None
|
||||
hash_x = hash(repr(model._execution_manager(model._is_training())._input_info.schema))
|
||||
hash_x = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema))
|
||||
assert hash_x is not None
|
||||
|
||||
# Train with x and y as inputs
|
||||
for _ in range(10):
|
||||
output = model(one,y=one)
|
||||
assert output is not None
|
||||
hash_x_y = hash(repr(model._execution_manager(model._is_training())._input_info.schema))
|
||||
hash_x_y = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema))
|
||||
assert hash_x_y != hash_x
|
||||
|
||||
# Train with x and z as inputs
|
||||
for _ in range(10):
|
||||
output = model(one,z=one)
|
||||
assert output is not None
|
||||
hash_x_z = hash(repr(model._execution_manager(model._is_training())._input_info.schema))
|
||||
hash_x_z = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema))
|
||||
assert hash_x_z != hash_x_y
|
||||
|
||||
# Train with x, y and z as inputs
|
||||
for _ in range(10):
|
||||
output = model(one,y=one, z=one)
|
||||
assert output is not None
|
||||
hash_x_y_z = hash(repr(model._execution_manager(model._is_training())._input_info.schema))
|
||||
hash_x_y_z = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema))
|
||||
assert hash_x_y_z != hash_x_z
|
||||
|
||||
# Return to original input with x as input
|
||||
for _ in range(10):
|
||||
output = model(one)
|
||||
assert output is not None
|
||||
hash_x2 = hash(repr(model._execution_manager(model._is_training())._input_info.schema))
|
||||
hash_x2 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema))
|
||||
assert hash_x2 != hash_x_y_z
|
||||
assert hash_x2 == hash_x
|
||||
|
||||
|
|
@ -2358,30 +2358,30 @@ def test_forward_call_default_input():
|
|||
# Model only uses a,d out of a,b,c,d
|
||||
out = model(one, two, three, four)
|
||||
assert out.item() == 5.0
|
||||
if model.training:
|
||||
if model._is_training():
|
||||
out.sum().backward()
|
||||
|
||||
out = model(one, two, c=three, d=four)
|
||||
assert out.item() == 5.0
|
||||
if model.training:
|
||||
if model._is_training():
|
||||
out.sum().backward()
|
||||
|
||||
# Model only uses a,d,args[-1] out of a,b,c,d,*args
|
||||
out = model(one, two, three, four, *args)
|
||||
assert out.item() == 7.0
|
||||
if model.training:
|
||||
if model._is_training():
|
||||
out.sum().backward()
|
||||
|
||||
# Model only uses a,d,args[-1],kw_0 out of a,b,c,d,*args,kw_0
|
||||
out = model(one, two, three, four, *args, kw_0=kw_0)
|
||||
assert out.item() == 13.0
|
||||
if model.training:
|
||||
if model._is_training():
|
||||
out.sum().backward()
|
||||
|
||||
# Model only uses a,d,args[-1],kwargs['kwargs_1'] out of a,b,c,d,*args,kw_0,**kwargs
|
||||
out = model(one, two, three, four, *args, **kwargs)
|
||||
assert out.item() == 15.0
|
||||
if model.training:
|
||||
if model._is_training():
|
||||
out.sum().backward()
|
||||
|
||||
|
||||
|
|
@ -2424,7 +2424,7 @@ def test_forward_call_kwargs_input_unexpected_order():
|
|||
y1, y2 = model(**{'input1': input1, 'input2': input2})
|
||||
assert y1 is not None
|
||||
assert y2 is not None
|
||||
if model.training:
|
||||
if model._is_training():
|
||||
loss = y1.sum() + y2.sum()
|
||||
loss.backward()
|
||||
|
||||
|
|
@ -2432,7 +2432,7 @@ def test_forward_call_kwargs_input_unexpected_order():
|
|||
y1, y2 = model(**{'input2': input2, 'input1': input1})
|
||||
assert y1 is not None
|
||||
assert y2 is not None
|
||||
if model.training:
|
||||
if model._is_training():
|
||||
loss = y1.sum() + y2.sum()
|
||||
loss.backward()
|
||||
|
||||
|
|
@ -2469,11 +2469,11 @@ def test_forward_call_lots_None():
|
|||
# ORTModule produces the same schema, thus not re-exporting
|
||||
# the model when `forward(a,b)` is used after `forward(**{'a': a, 'b': b})`
|
||||
# or vice-versa
|
||||
model._execution_manager(model._is_training())._onnx_model = None
|
||||
model._torch_module._execution_manager(model._is_training())._onnx_model = None
|
||||
out = model(a,b,c,d,e,f,y,z)
|
||||
assert out is not None
|
||||
assert out.item() == expected
|
||||
if model.training:
|
||||
if model._is_training():
|
||||
loss = out.sum()
|
||||
loss.backward()
|
||||
|
||||
|
|
@ -2584,10 +2584,10 @@ def test_changing_bool_input_re_exports_model(bool_arguments):
|
|||
|
||||
input1 = torch.randn(N, D_in, device=device)
|
||||
ort_model(input1, bool_arguments[0])
|
||||
exported_model1 = ort_model._execution_manager(ort_model._is_training())._onnx_model
|
||||
exported_model1 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_model
|
||||
|
||||
ort_model(input1, bool_arguments[1])
|
||||
exported_model2 = ort_model._execution_manager(ort_model._is_training())._onnx_model
|
||||
exported_model2 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_model
|
||||
|
||||
assert exported_model1 != exported_model2
|
||||
|
||||
|
|
@ -2741,7 +2741,7 @@ def test_unused_parameters_does_not_unnecssarily_reinitilize(model):
|
|||
N, D_in, H1, H2, D_out = 64, 784, 500, 400, 10
|
||||
model = model.to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(model))
|
||||
training_manager = ort_model._execution_manager(ort_model._is_training())
|
||||
training_manager = ort_model._torch_module._execution_manager(ort_model._is_training())
|
||||
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
_ = ort_model(x)
|
||||
|
|
|
|||
|
|
@ -379,8 +379,8 @@ def main():
|
|||
model = ORTModule(model)
|
||||
|
||||
# Just for future debugging
|
||||
model._execution_manager(model._is_training())._save_onnx = False
|
||||
model._execution_manager(model._is_training())._save_onnx_prefix = 'BertForSequenceClassification'
|
||||
model._torch_module._execution_manager(model._is_training())._save_onnx = False
|
||||
model._torch_module._execution_manager(model._is_training())._save_onnx_prefix = 'BertForSequenceClassification'
|
||||
|
||||
# Tell pytorch to run this model on the GPU.
|
||||
if torch.cuda.is_available() and not args.no_cuda:
|
||||
|
|
|
|||
|
|
@ -379,10 +379,10 @@ def main():
|
|||
if not args.pytorch_only:
|
||||
model = ORTModule(model)
|
||||
|
||||
|
||||
model._execution_manager(is_training=True)._save_onnx = True
|
||||
model._execution_manager(is_training=True)._save_onnx_prefix = 'BertForSequenceClassification'
|
||||
model._execution_manager(is_training=True)._enable_grad_acc_optimization = True
|
||||
|
||||
model._torch_module._execution_manager(is_training=True)._save_onnx = True
|
||||
model._torch_module._execution_manager(is_training=True)._save_onnx_prefix = 'BertForSequenceClassification'
|
||||
model._torch_module._execution_manager(is_training=True)._enable_grad_acc_optimization = True
|
||||
|
||||
# Tell pytorch to run this model on the GPU.
|
||||
if torch.cuda.is_available() and not args.no_cuda:
|
||||
|
|
|
|||
Loading…
Reference in a new issue