Provide torch module interface for ORTModule (#8148)

* Interface for the module manager and implementation of the torch module manager
This commit is contained in:
baijumeswani 2021-07-01 09:15:16 -07:00 committed by GitHub
parent ce9d134952
commit f616cd07b4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 474 additions and 172 deletions

View file

@ -0,0 +1,30 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# _graph_execution_interface.py
from abc import ABC
class GraphExecutionInterface(ABC):
    """Interface for objects that execute the forward pass for a wrapped module.

    Concrete implementations must override both ``forward`` and
    ``_validate_module_type``; the versions defined here only raise
    ``NotImplementedError``.
    """

    def __init__(self, module):
        """Store ``module`` and validate its type.

        Args:
            module: The module to manage. It is stored on ``self._original_module``
                before being handed to ``_validate_module_type``.

        Raises:
            NotImplementedError: If the concrete class does not override
                ``_validate_module_type``.
        """
        self._original_module = module
        self._validate_module_type(module)

    def forward(self, *inputs, **kwargs):
        """Executes the forward method for ORTModule

        This is an abstract method and must be overridden by a concrete implementation.
        This is the only method that the user should call on a concrete instance of the GraphExecutionInterface

        Arbitrary positional and keyword arguments are accepted so that calling a class
        that did not override this method with real inputs reports the missing override
        (NotImplementedError) instead of a signature mismatch (TypeError).
        """
        raise NotImplementedError(f"forward is not implemented for {type(self)}")

    def _validate_module_type(self, module):
        """Validates the type of the input module

        This is an abstract method and must be overridden by a concrete implementation.
        It is invoked from ``__init__`` with the module passed to the constructor.
        """
        raise NotImplementedError(f"_validate_module_type is not implemented for {type(self)}")

View file

@ -5,6 +5,7 @@
from . import _utils, _io, _logger, torch_cpp_extensions as _cpp_ext
from ._custom_autograd_function_exporter import _post_process_after_export
from ._graph_execution_interface import GraphExecutionInterface
from onnxruntime.training.ortmodule import ONNX_OPSET_VERSION
from onnxruntime.capi import _pybind_state as C
@ -30,7 +31,7 @@ class RunStateInfo(object):
self.state = state
self.output_info = output_info
class GraphExecutionManager(ABC):
class GraphExecutionManager(GraphExecutionInterface):
def __init__(self, module):
"""Manages building and execution of onnx graphs
@ -41,8 +42,9 @@ class GraphExecutionManager(ABC):
the onnx graph, and ExecutionAgent to run the onnx graph.
"""
super(GraphExecutionManager, self).__init__(module._original_module)
# Original and flattened (transformed) output module
self._original_module = module._original_module
self._flattened_module = module
# Exported model
@ -126,6 +128,12 @@ class GraphExecutionManager(ABC):
# WIP feature to enable caching in Gradient accumulation scenario.
self._enable_grad_acc_optimization = False
def _validate_module_type(self, module):
    """Ensure that `module` is a torch.nn.Module.

    Returns None when the check passes.

    Raises:
        TypeError: If `module` is not an instance of torch.nn.Module.
    """
    if isinstance(module, torch.nn.Module):
        return
    raise TypeError(f"ORTModule only supports torch.nn.Module as input. {type(module)} is not supported.")
@staticmethod
def execution_session_run_forward(execution_session, onnx_model, device, *inputs):
"""Runs the forward pass on `execution_session` with given `onnx_model`, `device` and `inputs`

View file

@ -0,0 +1,152 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# _torch_module.py
from . import _io
from ._graph_execution_manager_factory import GraphExecutionManagerFactory
from ._torch_module_interface import TorchModuleInterface
from collections import OrderedDict
import functools
import torch
from typing import Iterator, Optional, Tuple, TypeVar, Callable
T = TypeVar('T', bound='torch.nn.Module')
class TorchModule(TorchModuleInterface):
    """TorchModuleInterface implementation backed by a graph execution manager.

    The user module is kept as ``_original_module`` (set by the base class) and is
    additionally wrapped by ``_io._FlattenedModule`` as ``_flattened_module``.
    ``_apply``, ``apply`` and ``train`` are delegated to the flattened module; state,
    parameter, buffer and sub-module accessors are delegated to the original module.
    """

    def __init__(self, module: torch.nn.Module):
        """Wrap ``module`` and bind a per-instance ``forward``.

        Args:
            module: The user module to wrap.
        """
        super(TorchModule, self).__init__(module)
        self._flattened_module = _io._FlattenedModule(module)

        def _forward(self, *inputs, **kwargs):
            '''Forward pass starts here and continues at `_ORTModuleFunction.forward`

            ONNX model is exported the first time this method is executed.
            Next, we build a full training graph with module_gradient_graph_builder.
            Finally, we instantiate the ONNX Runtime InferenceSession.
            '''

            return self._execution_manager(self.is_training()).forward(*inputs, **kwargs)

        # Bind the forward method. `_forward` is created per instance so that each
        # TorchModule can carry the signature of its own wrapped module.
        self.forward = _forward.__get__(self)
        # Copy the forward signature from the PyTorch module.
        # `forward` is normally a bound method, but it may have been replaced on the
        # instance by a plain callable (e.g. a functools.partial) which has no
        # `__func__`; fall back to the callable itself in that case.
        original_forward = self._original_module.forward
        functools.update_wrapper(
            self.forward.__func__, getattr(original_forward, '__func__', original_forward))

        self._execution_manager = GraphExecutionManagerFactory(self._flattened_module)

    def _apply(self, fn):
        """Override original method to delegate execution to the flattened PyTorch user module"""

        # Delegation must happen to _flattened_module since methods depend on
        # _apply to recursively apply the internal setting changes
        self._flattened_module._apply(fn)
        return self

    def apply(self: T, fn: Callable[[T], None]) -> T:
        """Override original method to delegate execution to the flattened PyTorch user module"""

        # Delegation must happen to _flattened_module since methods depend on
        # apply to recursively apply the internal setting changes
        self._flattened_module.apply(fn)
        return self

    def is_training(self):
        """Return True when the flattened module is in training mode and grad is enabled."""
        return self._flattened_module.training and torch.is_grad_enabled()

    def train(self: T, mode: bool = True) -> T:
        """Override original method to delegate execution to the flattened PyTorch user module"""

        # Delegate the task to _flattened_module.train which will recursively
        # update the original_module
        self._flattened_module.train(mode)
        return self

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Override original method to delegate execution to the original PyTorch user module"""

        # Override the state_dict() method so that the state dict key names
        # do not contain the flattened_module._original_module prefix
        return self._original_module.state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]',
                        strict: bool = True):
        """Override original method to delegate execution to the original PyTorch user module"""

        # Override the load_state_dict() method so that the loaded state dict
        # key names do not need to contain the flattened_module._original_module prefix
        return self._original_module.load_state_dict(
            state_dict, strict=strict)

    def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None:
        """Override original method to delegate execution to the original PyTorch user module"""
        self._original_module.register_buffer(name, tensor, persistent=persistent)

    def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
        """Override original method to delegate execution to the original PyTorch user module"""
        self._original_module.register_parameter(name, param)

    def get_parameter(self, target: str) -> torch.nn.Parameter:
        """Override original method to delegate execution to the original PyTorch user module"""
        return self._original_module.get_parameter(target)

    def get_buffer(self, target: str) -> torch.Tensor:
        """Override original method to delegate execution to the original PyTorch user module"""
        return self._original_module.get_buffer(target)

    def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
        """Override original method to delegate execution to the original PyTorch user module"""
        yield from self._original_module.parameters(recurse=recurse)

    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
        """Override original method to delegate execution to the original PyTorch user module"""
        yield from self._original_module.named_parameters(prefix=prefix, recurse=recurse)

    def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]:
        """Override original method to delegate execution to the original PyTorch user module"""
        yield from self._original_module.buffers(recurse=recurse)

    def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]:
        """Override original method to delegate execution to the original PyTorch user module"""
        yield from self._original_module.named_buffers(prefix=prefix, recurse=recurse)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Override original method to delegate execution to the original PyTorch user module"""

        # PyTorch load_state_dict implementation does not recursively call load_state_dict on its sub-modules.
        # Instead, it creates a recursive function and invokes _load_from_state_dict on all child modules.
        # For the scenario where an ORTModule is a sub-module of another module, loading of the state
        # dictionary requires the _load_from_state_dict to be overridden to prevent an error.
        self._original_module._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                                    missing_keys, unexpected_keys, error_msgs)

    def named_children(self) -> Iterator[Tuple[str, T]]:
        """Override original method to delegate execution to the original PyTorch user module"""
        yield from self._original_module.named_children()

    def modules(self) -> Iterator[T]:
        """Override original method to delegate execution to the original PyTorch user module"""
        yield from self._original_module.modules()

    def named_modules(self, *args, **kwargs):
        """Override original method to delegate execution to the original PyTorch user module"""

        # PyTorch >1.8.1 has an extra arg remove_duplicate that is not present in 1.8.1
        # To support both, use args and kwargs (since user can call the method with only positional args or kwargs)
        yield from self._original_module.named_modules(*args, **kwargs)

View file

@ -0,0 +1,12 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# _torch_module_factory.py
from ._torch_module import TorchModule
class TorchModuleFactory:
    """Callable factory that wraps a module in a TorchModule."""

    def __call__(self, module):
        """Build and return a TorchModule for `module`."""
        torch_module = TorchModule(module)
        return torch_module

View file

@ -0,0 +1,115 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# _torch_module_interface.py
from collections import OrderedDict
import torch
from typing import Iterator, Optional, Tuple, TypeVar, Callable
T = TypeVar('T', bound='torch.nn.Module')


class TorchModuleInterface:
    """Abstract class that provides the function signatures for the torch.nn.Module

    Concrete implementations should inherit from this class and provide necessary executions.
    Every method below other than ``__init__`` and the ``module`` property raises
    ``NotImplementedError`` until it is overridden.
    """

    def __init__(self, module):
        """Store the user provided ``module`` as ``self._original_module``."""
        self._original_module = module

    @property
    def module(self):
        """The original user provided module that this class manages.

        This property provides access to methods and properties on the original module.
        """

        return self._original_module

    ###################################################
    # The methods below are part of torch.nn.Module API
    ###################################################

    def forward(self, *inputs, **kwargs):
        """Executes the forward method for ORTModule

        This is an abstract method and must be overridden by a concrete implementation.

        Arbitrary positional and keyword arguments are accepted so that calling a class
        that did not override this method with real inputs reports the missing override
        (NotImplementedError) instead of a signature mismatch (TypeError).
        """

        raise NotImplementedError(f"forward is not implemented for {type(self)}.")

    def _apply(self, fn):
        raise NotImplementedError(f"_apply is not implemented for {type(self)}.")

    def apply(self: T, fn: Callable[[T], None]) -> T:
        raise NotImplementedError(f"apply is not implemented for {type(self)}.")

    def is_training(self):
        raise NotImplementedError(f"is_training is not implemented for {type(self)}.")

    def train(self: T, mode: bool = True) -> T:
        raise NotImplementedError(f"train is not implemented for {type(self)}.")

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        raise NotImplementedError(f"state_dict is not implemented for {type(self)}.")

    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]',
                        strict: bool = True):
        raise NotImplementedError(f"load_state_dict is not implemented for {type(self)}.")

    def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None:
        raise NotImplementedError(f"register_buffer is not implemented for {type(self)}.")

    def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
        raise NotImplementedError(f"register_parameter is not implemented for {type(self)}.")

    def get_parameter(self, target: str) -> torch.nn.Parameter:
        raise NotImplementedError(f"get_parameter is not implemented for {type(self)}.")

    def get_buffer(self, target: str) -> torch.Tensor:
        raise NotImplementedError(f"get_buffer is not implemented for {type(self)}.")

    def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
        raise NotImplementedError(f"parameters is not implemented for {type(self)}.")

    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
        raise NotImplementedError(f"named_parameters is not implemented for {type(self)}.")

    def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]:
        raise NotImplementedError(f"buffers is not implemented for {type(self)}.")

    def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]:
        raise NotImplementedError(f"named_buffers is not implemented for {type(self)}.")

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        raise NotImplementedError(f"_load_from_state_dict is not implemented for {type(self)}.")

    def named_children(self) -> Iterator[Tuple[str, T]]:
        raise NotImplementedError(f"named_children is not implemented for {type(self)}.")

    def modules(self) -> Iterator[T]:
        raise NotImplementedError(f"modules is not implemented for {type(self)}.")

    def named_modules(self, *args, **kwargs):
        raise NotImplementedError(f"named_modules is not implemented for {type(self)}.")

View file

@ -97,10 +97,3 @@ def _create_iobinding(io_binding, inputs, model, device):
for value_info in model.graph.output:
io_binding.bind_output(value_info.name, device.type, device_id=get_device_index(device))
class _PytorchModuleMetadata():
"""Encapsulates modules and allows easy access as required"""
def __init__(self, original_module, flattened_module):
self.original_module = original_module
self.flattened_module = flattened_module

View file

@ -3,9 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from . import _io
from ._graph_execution_manager_factory import GraphExecutionManagerFactory
from ._utils import _PytorchModuleMetadata
from ._torch_module_factory import TorchModuleFactory
from onnxruntime.training import register_custom_ops_pytorch_exporter
@ -24,8 +22,7 @@ class ORTModule(torch.nn.Module):
"""
def __init__(self, module):
assert isinstance(
module, torch.nn.Module), "'module' must be a torch.nn.Module"
self._torch_module = TorchModuleFactory()(module)
# Create forward dynamically, so each ORTModule instance will have its own copy.
# This is needed to be able to copy the forward signatures from the original PyTorch models
@ -38,26 +35,19 @@ class ORTModule(torch.nn.Module):
Finally, we instantiate the ONNX Runtime InferenceSession.
'''
return self._execution_manager(self._is_training()).forward(*inputs, **kwargs)
return self._torch_module.forward(*inputs, **kwargs)
# Bind the forward method.
self.forward = _forward.__get__(self)
# Copy the forward signature from the PyTorch module.
# Copy the forward signature from the _torch_module's forward signature.
functools.update_wrapper(
self.forward.__func__, module.forward.__func__)
self.forward.__func__, self._torch_module.forward.__func__)
super(ORTModule, self).__init__()
# Support contrib OPs
register_custom_ops_pytorch_exporter.register_custom_op(is_ortmodule=True)
# User module is wrapped to use its initializers and save computed gradients
# along with the module that flattens both input and output of the user module
# inside _PytorchModuleMetadata
self._module_metadata = _PytorchModuleMetadata(module, _io._FlattenedModule(module))
self._execution_manager = GraphExecutionManagerFactory(self._module_metadata.flattened_module)
# IMPORTANT: DO NOT add code here
# This declaration is for automatic document generation purposes only
# The actual forward implementation is bound during ORTModule initialization
@ -65,83 +55,6 @@ class ORTModule(torch.nn.Module):
'''Dummy documentation for forward method'''
...
def _apply(self, fn):
"""Override original method to delegate execution to the flattened PyTorch user module"""
# Delegation must happen to _flattened_module since methods depend on
# _apply to recursively apply the internal setting changes
self._module_metadata.flattened_module._apply(fn)
return self
def apply(self: T, fn: Callable[['Module'], None]) -> T:
"""Override original method to delegate execution to the flattened PyTorch user module"""
# Delegation must happen to _flattened_module since methods depend on
# apply to recursively apply the internal setting changes
self._module_metadata.flattened_module.apply(fn)
return self
def _is_training(self):
return self.training and torch.is_grad_enabled()
def train(self: T, mode: bool = True) -> T:
"""Override original method to delegate execution to the flattened PyTorch user module"""
# Since _modules is empty, the task needs to be delegated to _module.flattened_module.train
# which will recursively update the original_module
self.training = mode
self._module_metadata.flattened_module.train(mode)
return self
def state_dict(self, destination=None, prefix='', keep_vars=False):
"""Override original method to delegate execution to the original PyTorch user module"""
# Override the state_dict() method so that the state dict key names
# do not contain the flattened_module._original_module prefix
return self._module_metadata.original_module.state_dict(
destination=destination, prefix=prefix, keep_vars=keep_vars)
def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
strict: bool = True):
"""Override original method to delegate execution to the original PyTorch user module"""
# Override the load_state_dict() method so that the loaded state dict
# key names does not need to contain the _module.flattened_module._original_module prefix
return self._module_metadata.original_module.load_state_dict(
state_dict, strict=strict)
def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None:
"""Override original method to delegate execution to the original PyTorch user module"""
self._module_metadata.original_module.register_buffer(name, tensor, persistent=persistent)
def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
"""Override original method to delegate execution to the original PyTorch user module"""
self._module_metadata.original_module.register_parameter(name, param)
def get_parameter(self, target: str) -> torch.nn.Parameter:
"""Override original method to delegate execution to the original PyTorch user module"""
return self._module_metadata.original_module.get_parameter(target)
def get_buffer(self, target: str) -> torch.Tensor:
"""Override original method to delegate execution to the original PyTorch user module"""
return self._module_metadata.original_module.get_buffer(target)
def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
"""Override original method to delegate execution to the original PyTorch user module"""
yield from self._module_metadata.original_module.parameters(recurse=recurse)
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
"""Override original method to delegate execution to the original PyTorch user module"""
yield from self._module_metadata.original_module.named_parameters(prefix=prefix, recurse=recurse)
def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]:
"""Override original method to delegate execution to the original PyTorch user module"""
yield from self._module_metadata.original_module.buffers(recurse=recurse)
def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]:
"""Override original method to delegate execution to the original PyTorch user module"""
yield from self._module_metadata.original_module.named_buffers(prefix=prefix, recurse=recurse)
def _replicate_for_data_parallel(self):
"""Raises a NotImplementedError exception since ORTModule is not compatible with torch.nn.DataParallel
@ -161,34 +74,6 @@ class ORTModule(torch.nn.Module):
raise NotImplementedError("ORTModule is not compatible with torch.nn.DataParallel. "
"Please use torch.nn.parallel.DistributedDataParallel instead.")
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
"""Override original method to delegate execution to the original PyTorch user module"""
# PyTorch load_state_dict implementation does not recursively call load_state_dict on its sub-modules.
# Instead, it creates a recursive function and invokes _load_from_state_dict on all child modules.
# For the scenario where an ORTModule is a sub-module of another module, loading of the state
# dictionary requires the _load_from_state_dict to be overridden to prevent an error.
self._module_metadata.original_module._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
def named_children(self) -> Iterator[Tuple[str, 'Module']]:
"""Override original method to delegate execution to the original PyTorch user module"""
yield from self._module_metadata.original_module.named_children()
def modules(self) -> Iterator['Module']:
"""Override original method to delegate execution to the original PyTorch user module"""
yield from self._module_metadata.original_module.modules()
def named_modules(self, *args, **kwargs):
"""Override original method to delegate execution to the original PyTorch user module"""
# PyTorch >1.8.1 has an extra arg remove_duplicate that is not present in 1.8.1
# To support both, use args and kwargs (since user can call the method with only positional args or kwargs)
yield from self._module_metadata.original_module.named_modules(*args, **kwargs)
def add_module(self, name: str, module: Optional['Module']) -> None:
"""Raises a NotImplementedError exception since ORTModule does not support adding modules to it"""
@ -207,4 +92,111 @@ class ORTModule(torch.nn.Module):
# This `module` property enables HuggingFace Trainer to retrieve the underlying PreTrainedModel inside ORTModule
# to save and load a complete checkpoint
return self._module_metadata.original_module
return self._torch_module.module
################################################################################
# The methods below are part of torch.nn.Module API that are encapsulated through
# TorchModuleInterface
################################################################################
def _apply(self, fn):
    """Forward `_apply(fn)` to `self._torch_module` and return this instance."""
    delegate = self._torch_module
    delegate._apply(fn)
    return self
def apply(self: T, fn: Callable[['Module'], None]) -> T:
    """Forward `apply(fn)` to `self._torch_module` and return this instance."""
    delegate = self._torch_module
    delegate.apply(fn)
    return self
def _is_training(self):
    """Return whatever `self._torch_module.is_training()` reports."""
    training_flag = self._torch_module.is_training()
    return training_flag
def train(self: T, mode: bool = True) -> T:
    """Record the training mode on this instance and forward it to `self._torch_module`."""
    # The flattened and original modules sit behind `_torch_module` (a separate class)
    # and therefore are not among this module's registered sub-modules in `_modules`.
    # torch.nn.Module cannot reach them recursively, so after recording the mode on
    # this instance the change is forwarded to `_torch_module`, which updates the
    # flattened module and the original module.
    self.training = mode
    delegate = self._torch_module
    delegate.train(mode)
    return self
def state_dict(self, destination=None, prefix='', keep_vars=False):
    """Return the result of `self._torch_module.state_dict` called with the same keyword arguments."""
    forwarded = dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
    return self._torch_module.state_dict(**forwarded)
def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
                    strict: bool = True):
    """Return the result of `self._torch_module.load_state_dict(state_dict, strict=strict)`."""
    outcome = self._torch_module.load_state_dict(state_dict, strict=strict)
    return outcome
def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None:
    """Register a buffer by delegating to `self._torch_module.register_buffer`."""
    delegate = self._torch_module
    delegate.register_buffer(name, tensor, persistent=persistent)
def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
    """Register a parameter by delegating to `self._torch_module.register_parameter`."""
    delegate = self._torch_module
    delegate.register_parameter(name, param)
def get_parameter(self, target: str) -> torch.nn.Parameter:
    """Return the result of `self._torch_module.get_parameter(target)`."""
    found = self._torch_module.get_parameter(target)
    return found
def get_buffer(self, target: str) -> torch.Tensor:
    """Return the result of `self._torch_module.get_buffer(target)`."""
    found = self._torch_module.get_buffer(target)
    return found
def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
    """Yield every item produced by `self._torch_module.parameters(recurse=recurse)`."""
    source = self._torch_module.parameters(recurse=recurse)
    yield from source
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
    """Yield every item produced by `self._torch_module.named_parameters` for the same arguments."""
    source = self._torch_module.named_parameters(prefix=prefix, recurse=recurse)
    yield from source
def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]:
    """Yield every item produced by `self._torch_module.buffers(recurse=recurse)`."""
    source = self._torch_module.buffers(recurse=recurse)
    yield from source
def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]:
    """Yield every item produced by `self._torch_module.named_buffers` for the same arguments."""
    source = self._torch_module.named_buffers(prefix=prefix, recurse=recurse)
    yield from source
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    """Pass all arguments, in order, to `self._torch_module._load_from_state_dict`."""
    forwarded = (state_dict, prefix, local_metadata, strict,
                 missing_keys, unexpected_keys, error_msgs)
    self._torch_module._load_from_state_dict(*forwarded)
def named_children(self) -> Iterator[Tuple[str, 'Module']]:
    """Yield every item produced by `self._torch_module.named_children()`."""
    source = self._torch_module.named_children()
    yield from source
def modules(self) -> Iterator['Module']:
    """Yield every item produced by `self._torch_module.modules()`."""
    source = self._torch_module.modules()
    yield from source
def named_modules(self, *args, **kwargs):
    """Yield every item produced by `self._torch_module.named_modules` for the same arguments."""
    source = self._torch_module.named_modules(*args, **kwargs)
    yield from source

View file

@ -124,7 +124,7 @@ def assert_optim_state(expected_state, actual_state, rtol=1e-7, atol=0):
def is_dynamic_axes(model):
# Check inputs
for inp in model._execution_manager(model._is_training())._optimized_onnx_model.graph.input:
for inp in model._torch_module._execution_manager(model._is_training())._optimized_onnx_model.graph.input:
shape = inp.type.tensor_type.shape
if shape:
for dim in shape.dim:
@ -132,7 +132,7 @@ def is_dynamic_axes(model):
return False
# Check outputs
for out in model._execution_manager(model._is_training())._optimized_onnx_model.graph.output:
for out in model._torch_module._execution_manager(model._is_training())._optimized_onnx_model.graph.output:
shape = out.type.tensor_type.shape
if shape:
for dim in shape.dim:

View file

@ -472,7 +472,7 @@ def test_torch_nn_module_to_api(original_device, to_argument):
model = model.to(to_argument)
x = x.to(to_argument)
model(x)
assert _utils.get_device_str(model._execution_manager(model._is_training())._device) == \
assert _utils.get_device_str(model._torch_module._execution_manager(model._is_training())._device) == \
_utils.get_device_str(torch.device(to_argument))
def test_model_without_device():
@ -524,7 +524,7 @@ def test_input_requires_grad_saved(device):
model = ORTModule(model)
x = torch.randn(N, D_in, device=device, requires_grad=True) + 1
model(x)
assert model._execution_manager(model._is_training())._input_info.require_grad_names == ['input1']
assert model._torch_module._execution_manager(model._is_training())._input_info.require_grad_names == ['input1']
@pytest.mark.parametrize("device", ['cuda', 'cpu'])
def test_input_requires_grad_backward_creates_input_grad(device):
@ -1033,12 +1033,12 @@ def test_changes_input_requires_grad_reinitializes_module_gradient_graph_builder
output_x.backward()
assert x.grad is None
module_gradient_graph_builder_training = \
model._execution_manager(model._is_training())._graph_builder
model._torch_module._execution_manager(model._is_training())._graph_builder
output_y = torch.sum(model(y))
output_y.backward()
assert y.grad is not None
assert module_gradient_graph_builder_training != \
model._execution_manager(model._is_training())._graph_builder
model._torch_module._execution_manager(model._is_training())._graph_builder
@pytest.mark.parametrize("device", ['cuda'])
def test_input_requires_grad_backward_creates_input_grad_as_required0(device):
@ -1695,7 +1695,7 @@ def test_forward_data_and_model_on_different_devices(data_device, model_device):
x = torch.randn(N, D_in, device=data_device)
with pytest.raises(RuntimeError) as runtime_error:
ort_model(x)
assert f"Input argument to forward found on device {torch.device(x.device)}, but expected it to be on module device {ort_model._execution_manager(ort_model._is_training())._device}." in str(runtime_error.value)
assert f"Input argument to forward found on device {torch.device(x.device)}, but expected it to be on module device {ort_model._torch_module._execution_manager(ort_model._is_training())._device}." in str(runtime_error.value)
def test_forward_returns_none_type_as_output():
class NeuralNetNoneTypeOutput(torch.nn.Module):
@ -1799,26 +1799,26 @@ def test_model_initializer_requires_grad_changes_from_one_forward_to_next():
model.fc1.requires_grad_(True)
model = ORTModule(model)
x = torch.randn(N, D_in, device=device)
assert model._module_metadata.original_module.fc1.weight.grad is None
assert model._module_metadata.original_module.fc1.bias.grad is None
assert model.module.fc1.weight.grad is None
assert model.module.fc1.bias.grad is None
# Make sure no exception is raised
output = model(x)
loss = torch.sum(output)
loss.backward()
training_session1 = model._execution_manager(model._is_training())._execution_agent
weight_grad_2 = model._module_metadata.original_module.fc1.weight.grad
bias_grad_2 = model._module_metadata.original_module.fc1.bias.grad
training_session1 = model._torch_module._execution_manager(model._is_training())._execution_agent
weight_grad_2 = model.module.fc1.weight.grad
bias_grad_2 = model.module.fc1.bias.grad
assert weight_grad_2 is not None
assert bias_grad_2 is not None
model._module_metadata.original_module.fc1.requires_grad_(False)
model.module.fc1.requires_grad_(False)
output = model(x)
loss = torch.sum(output)
loss.backward()
training_session2 = model._execution_manager(model._is_training())._execution_agent
weight_grad_3 = model._module_metadata.original_module.fc1.weight.grad
bias_grad_3 = model._module_metadata.original_module.fc1.bias.grad
training_session2 = model._torch_module._execution_manager(model._is_training())._execution_agent
weight_grad_3 = model.module.fc1.weight.grad
bias_grad_3 = model.module.fc1.bias.grad
assert training_session1 != training_session2
assert torch.equal(weight_grad_2, weight_grad_3)
@ -2142,21 +2142,21 @@ def test_forward_dynamic_args():
for _ in range(10):
output = model(*args_size1)
assert output is not None
hash_args_size1 = hash(repr(model._execution_manager(model._is_training())._input_info.schema))
hash_args_size1 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema))
assert hash_args_size1 is not None
# Decrease number of inputs and train some more
for _ in range(10):
output = model(*args_size2)
assert output is not None
hash_args_size2 = hash(repr(model._execution_manager(model._is_training())._input_info.schema))
hash_args_size2 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema))
assert hash_args_size2 != hash_args_size1
# Increase number of inputs and train some more
for _ in range(10):
output = model(*args_size3)
assert output is not None
hash_args_size3 = hash(repr(model._execution_manager(model._is_training())._input_info.schema))
hash_args_size3 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema))
assert hash_args_size3 != hash_args_size2
@ -2178,35 +2178,35 @@ def test_forward_dynamic_kwargs():
for _ in range(10):
output = model(one)
assert output is not None
hash_x = hash(repr(model._execution_manager(model._is_training())._input_info.schema))
hash_x = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema))
assert hash_x is not None
# Train with x and y as inputs
for _ in range(10):
output = model(one,y=one)
assert output is not None
hash_x_y = hash(repr(model._execution_manager(model._is_training())._input_info.schema))
hash_x_y = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema))
assert hash_x_y != hash_x
# Train with x and z as inputs
for _ in range(10):
output = model(one,z=one)
assert output is not None
hash_x_z = hash(repr(model._execution_manager(model._is_training())._input_info.schema))
hash_x_z = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema))
assert hash_x_z != hash_x_y
# Train with x, y and z as inputs
for _ in range(10):
output = model(one,y=one, z=one)
assert output is not None
hash_x_y_z = hash(repr(model._execution_manager(model._is_training())._input_info.schema))
hash_x_y_z = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema))
assert hash_x_y_z != hash_x_z
# Return to original input with x as input
for _ in range(10):
output = model(one)
assert output is not None
hash_x2 = hash(repr(model._execution_manager(model._is_training())._input_info.schema))
hash_x2 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema))
assert hash_x2 != hash_x_y_z
assert hash_x2 == hash_x
@ -2358,30 +2358,30 @@ def test_forward_call_default_input():
# Model only uses a,d out of a,b,c,d
out = model(one, two, three, four)
assert out.item() == 5.0
if model.training:
if model._is_training():
out.sum().backward()
out = model(one, two, c=three, d=four)
assert out.item() == 5.0
if model.training:
if model._is_training():
out.sum().backward()
# Model only uses a,d,args[-1] out of a,b,c,d,*args
out = model(one, two, three, four, *args)
assert out.item() == 7.0
if model.training:
if model._is_training():
out.sum().backward()
# Model only uses a,d,args[-1],kw_0 out of a,b,c,d,*args,kw_0
out = model(one, two, three, four, *args, kw_0=kw_0)
assert out.item() == 13.0
if model.training:
if model._is_training():
out.sum().backward()
# Model only uses a,d,args[-1],kwargs['kwargs_1'] out of a,b,c,d,*args,kw_0,**kwargs
out = model(one, two, three, four, *args, **kwargs)
assert out.item() == 15.0
if model.training:
if model._is_training():
out.sum().backward()
@ -2424,7 +2424,7 @@ def test_forward_call_kwargs_input_unexpected_order():
y1, y2 = model(**{'input1': input1, 'input2': input2})
assert y1 is not None
assert y2 is not None
if model.training:
if model._is_training():
loss = y1.sum() + y2.sum()
loss.backward()
@ -2432,7 +2432,7 @@ def test_forward_call_kwargs_input_unexpected_order():
y1, y2 = model(**{'input2': input2, 'input1': input1})
assert y1 is not None
assert y2 is not None
if model.training:
if model._is_training():
loss = y1.sum() + y2.sum()
loss.backward()
@ -2469,11 +2469,11 @@ def test_forward_call_lots_None():
# ORTModule produces the same schema, thus not re-exporting
# the model when `forward(a,b)` is used after `forward(**{'a': a, 'b': b})`
# or vice-versa
model._execution_manager(model._is_training())._onnx_model = None
model._torch_module._execution_manager(model._is_training())._onnx_model = None
out = model(a,b,c,d,e,f,y,z)
assert out is not None
assert out.item() == expected
if model.training:
if model._is_training():
loss = out.sum()
loss.backward()
@ -2584,10 +2584,10 @@ def test_changing_bool_input_re_exports_model(bool_arguments):
input1 = torch.randn(N, D_in, device=device)
ort_model(input1, bool_arguments[0])
exported_model1 = ort_model._execution_manager(ort_model._is_training())._onnx_model
exported_model1 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_model
ort_model(input1, bool_arguments[1])
exported_model2 = ort_model._execution_manager(ort_model._is_training())._onnx_model
exported_model2 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_model
assert exported_model1 != exported_model2
@ -2741,7 +2741,7 @@ def test_unused_parameters_does_not_unnecssarily_reinitilize(model):
N, D_in, H1, H2, D_out = 64, 784, 500, 400, 10
model = model.to(device)
ort_model = ORTModule(copy.deepcopy(model))
training_manager = ort_model._execution_manager(ort_model._is_training())
training_manager = ort_model._torch_module._execution_manager(ort_model._is_training())
x = torch.randn(N, D_in, device=device)
_ = ort_model(x)

View file

@ -379,8 +379,8 @@ def main():
model = ORTModule(model)
# Just for future debugging
model._execution_manager(model._is_training())._save_onnx = False
model._execution_manager(model._is_training())._save_onnx_prefix = 'BertForSequenceClassification'
model._torch_module._execution_manager(model._is_training())._save_onnx = False
model._torch_module._execution_manager(model._is_training())._save_onnx_prefix = 'BertForSequenceClassification'
# Tell pytorch to run this model on the GPU.
if torch.cuda.is_available() and not args.no_cuda:

View file

@ -379,10 +379,10 @@ def main():
if not args.pytorch_only:
model = ORTModule(model)
model._execution_manager(is_training=True)._save_onnx = True
model._execution_manager(is_training=True)._save_onnx_prefix = 'BertForSequenceClassification'
model._execution_manager(is_training=True)._enable_grad_acc_optimization = True
model._torch_module._execution_manager(is_training=True)._save_onnx = True
model._torch_module._execution_manager(is_training=True)._save_onnx_prefix = 'BertForSequenceClassification'
model._torch_module._execution_manager(is_training=True)._enable_grad_acc_optimization = True
# Tell pytorch to run this model on the GPU.
if torch.cuda.is_available() and not args.no_cuda: