diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_interface.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_interface.py
new file mode 100644
index 0000000000..3fbba5000d
--- /dev/null
+++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_interface.py
@@ -0,0 +1,41 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# _graph_execution_interface.py
+
+from abc import ABC
+
+
+class GraphExecutionInterface(ABC):
+    """Base interface for objects that execute the graph of a wrapped module.
+
+    The constructor stores the module and runs `_validate_module_type` on it.
+    Both `forward` and `_validate_module_type` raise NotImplementedError here and
+    must be overridden by a concrete implementation.
+    """
+
+    def __init__(self, module):
+        self._original_module = module
+
+        # Runs the subclass hook; the implementation in this base class always raises.
+        self._validate_module_type(module)
+
+    def forward(self, *inputs, **kwargs):
+        """Executes the forward method for ORTModule
+
+        This is an abstract method and must be overridden by a concrete implementation.
+        This is the only method that the user should call on a concrete instance of the GraphExecutionInterface
+
+        Positional and keyword inputs are accepted so that calling this stub with
+        arguments raises NotImplementedError rather than a TypeError about arity.
+        """
+
+        raise NotImplementedError(f"forward is not implemented for {type(self)}")
+
+    def _validate_module_type(self, module):
+        """Validates the type of the input module
+
+        This is an abstract method and must be overridden by a concrete implementation.
+        This hook is invoked from __init__; `forward` is the method users are expected to call.
+        """
+
+        raise NotImplementedError(f"_validate_module_type is not implemented for {type(self)}")
diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
index 7fc4bf0496..cec03077b2 100644
--- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
@@ -5,6 +5,7 @@ from . 
import _utils, _io, _logger, torch_cpp_extensions as _cpp_ext from ._custom_autograd_function_exporter import _post_process_after_export +from ._graph_execution_interface import GraphExecutionInterface from onnxruntime.training.ortmodule import ONNX_OPSET_VERSION from onnxruntime.capi import _pybind_state as C @@ -30,7 +31,7 @@ class RunStateInfo(object): self.state = state self.output_info = output_info -class GraphExecutionManager(ABC): +class GraphExecutionManager(GraphExecutionInterface): def __init__(self, module): """Manages building and execution of onnx graphs @@ -41,8 +42,9 @@ class GraphExecutionManager(ABC): the onnx graph, and ExecutionAgent to run the onnx graph. """ + super(GraphExecutionManager, self).__init__(module._original_module) + # Original and flattened (tranformed) output module - self._original_module = module._original_module self._flattened_module = module # Exported model @@ -126,6 +128,12 @@ class GraphExecutionManager(ABC): # WIP feature to enable caching in Gradient accumulation scenario. self._enable_grad_acc_optimization = False + def _validate_module_type(self, module): + """Raises a TypeError if the module is not a torch.nn.Module""" + + if not isinstance(module, torch.nn.Module): + raise TypeError(f"ORTModule only supports torch.nn.Module as input. {type(module)} is not supported.") + @staticmethod def execution_session_run_forward(execution_session, onnx_model, device, *inputs): """Runs the forward pass on `execution_session` with given `onnx_model`, `device` and `inputs` diff --git a/orttraining/orttraining/python/training/ortmodule/_torch_module.py b/orttraining/orttraining/python/training/ortmodule/_torch_module.py new file mode 100644 index 0000000000..e4f7ba5fcc --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/_torch_module.py @@ -0,0 +1,152 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# _torch_module.py + +from . 
import _io
+from ._graph_execution_manager_factory import GraphExecutionManagerFactory
+from ._torch_module_interface import TorchModuleInterface
+
+from collections import OrderedDict
+import functools
+import torch
+from typing import Iterator, Optional, Tuple, TypeVar, Callable
+
+
+T = TypeVar('T', bound='torch.nn.Module')
+
+
+class TorchModule(TorchModuleInterface):
+    def __init__(self, module: torch.nn.Module):
+        super(TorchModule, self).__init__(module)
+        self._flattened_module = _io._FlattenedModule(module)
+
+        def _forward(self, *inputs, **kwargs):
+            '''Forward pass starts here and continues at `_ORTModuleFunction.forward`
+
+            ONNX model is exported the first time this method is executed.
+            Next, we build a full training graph with module_gradient_graph_builder.
+            Finally, we instantiate the ONNX Runtime InferenceSession.
+            '''
+
+            return self._execution_manager(self.is_training()).forward(*inputs, **kwargs)
+
+        # Bind the per-instance _forward defined above as this object's forward method.
+        self.forward = _forward.__get__(self)
+        # Copy the forward signature (and docstring/name metadata) from the original PyTorch module.
+        functools.update_wrapper(
+            self.forward.__func__, self._original_module.forward.__func__)
+
+        self._execution_manager = GraphExecutionManagerFactory(self._flattened_module)
+
+    def _apply(self, fn):
+        """Override original method to delegate execution to the flattened PyTorch user module"""
+
+        # Delegation must happen to _flattened_module since methods depend on
+        # _apply to recursively apply the internal setting changes
+        self._flattened_module._apply(fn)
+        return self
+
+    def apply(self: T, fn: Callable[[T], None]) -> T:
+        """Override original method to delegate execution to the flattened PyTorch user module"""
+
+        # Delegation must happen to _flattened_module since methods depend on
+        # apply to recursively apply the internal setting changes
+        self._flattened_module.apply(fn)
+        return self
+
+    def is_training(self):
+        return self._flattened_module.training and torch.is_grad_enabled()
+
+    def train(self: T, mode: bool = True) -> T:
+        """Override original method to delegate execution to the flattened PyTorch user module"""
+
+        # Delegate the task to _flattened_module.train which will recursively
+        # update the _original_module
+        self._flattened_module.train(mode)
+        return self
+
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        """Override original method to delegate execution to the original PyTorch user module"""
+
+        # Override the state_dict() method so that the state dict key names
+        # do not contain the _flattened_module._original_module prefix
+        return self._original_module.state_dict(
+            destination=destination, prefix=prefix, keep_vars=keep_vars)
+
+    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]',
+                        strict: bool = True):
+        """Override original method to delegate execution to the original PyTorch user module"""
+
+        # Override the load_state_dict() method so that the loaded state dict
+        # key names do not need to contain the _flattened_module._original_module prefix
+        return self._original_module.load_state_dict(
+            state_dict, strict=strict)
+
+    def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None:
+        """Override original method to delegate execution to the original PyTorch user module"""
+
+        self._original_module.register_buffer(name, tensor, persistent=persistent)
+
+    def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
+        """Override original method to delegate execution to the original PyTorch user module"""
+
+        self._original_module.register_parameter(name, param)
+
+    def get_parameter(self, target: str) -> torch.nn.Parameter:
+        """Override original method to delegate execution to the original PyTorch user module"""
+
+        return self._original_module.get_parameter(target)
+
+    def get_buffer(self, target: str) -> torch.Tensor:
+        """Override original method to delegate execution to the original PyTorch user module"""
+
+        return self._original_module.get_buffer(target)
+
+    def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
+        """Override original method to delegate execution to the original PyTorch user module"""
+
+        yield from self._original_module.parameters(recurse=recurse)
+
+    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
+        """Override original method to delegate execution to the original PyTorch user module"""
+
+        yield from self._original_module.named_parameters(prefix=prefix, recurse=recurse)
+
+    def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]:
+        """Override original method to delegate execution to the original PyTorch user module"""
+
+        yield from self._original_module.buffers(recurse=recurse)
+
+    def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]:
+        """Override original method to delegate execution to the original PyTorch user module"""
+
+        yield from self._original_module.named_buffers(prefix=prefix, recurse=recurse)
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        """Override original method to delegate execution to the original PyTorch user module"""
+
+        # PyTorch load_state_dict implementation does not recursively call load_state_dict on its sub-modules.
+        # Instead, it creates a recursive function and invokes _load_from_state_dict on all child modules.
+        # For the scenario where an ORTModule is a sub-module of another module, loading of the state
+        # dictionary requires the _load_from_state_dict to be overridden to prevent an error.
+
+        self._original_module._load_from_state_dict(state_dict, prefix, local_metadata, strict,
+                                                    missing_keys, unexpected_keys, error_msgs)
+
+    def named_children(self) -> Iterator[Tuple[str, T]]:
+        """Override original method to delegate execution to the original PyTorch user module"""
+
+        yield from self._original_module.named_children()
+
+    def modules(self) -> Iterator[T]:
+        """Override original method to delegate execution to the original PyTorch user module"""
+
+        yield from self._original_module.modules()
+
+    def named_modules(self, *args, **kwargs):
+        """Override original method to delegate execution to the original PyTorch user module"""
+
+        # PyTorch >1.8.1 has an extra arg remove_duplicate that is not present in 1.8.1
+        # To support both, use args and kwargs (since user can call the method with only positional args or kwargs)
+        yield from self._original_module.named_modules(*args, **kwargs)
diff --git a/orttraining/orttraining/python/training/ortmodule/_torch_module_factory.py b/orttraining/orttraining/python/training/ortmodule/_torch_module_factory.py
new file mode 100644
index 0000000000..3aaf40dca6
--- /dev/null
+++ b/orttraining/orttraining/python/training/ortmodule/_torch_module_factory.py
@@ -0,0 +1,12 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# _torch_module_factory.py
+
+from ._torch_module import TorchModule
+
+
+class TorchModuleFactory:
+    def __call__(self, module):
+        """Creates a TorchModule instance based on the input module."""
+
+        return TorchModule(module)
diff --git a/orttraining/orttraining/python/training/ortmodule/_torch_module_interface.py b/orttraining/orttraining/python/training/ortmodule/_torch_module_interface.py
new file mode 100644
index 0000000000..8294f949f1
--- /dev/null
+++ b/orttraining/orttraining/python/training/ortmodule/_torch_module_interface.py
@@ -0,0 +1,115 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# _torch_module_interface.py
+
+from collections import OrderedDict
+import torch
+from typing import Iterator, Optional, Tuple, TypeVar, Callable
+
+
+T = TypeVar('T', bound='torch.nn.Module')
+
+
+class TorchModuleInterface:
+    """Abstract class that provides the function signatures for the torch.nn.Module
+
+    Concrete implementations should inherit from this class and provide necessary executions.
+    """
+
+    def __init__(self, module):
+        self._original_module = module
+
+    @property
+    def module(self):
+        """The original user provided module that this class manages.
+
+        This property provides access to methods and properties on the original module.
+        """
+
+        return self._original_module
+
+    ###################################################
+    # The methods below are part of torch.nn.Module API
+    ###################################################
+
+    def forward(self, *inputs, **kwargs):
+        """Executes the forward method for ORTModule
+
+        This is an abstract method and must be overridden by a concrete implementation.
+        """
+
+        raise NotImplementedError(f"forward is not implemented for {type(self)}.")
+
+    def _apply(self, fn):
+        """Stub for torch.nn.Module._apply; a concrete implementation must override it."""
+        raise NotImplementedError(f"_apply is not implemented for {type(self)}.")
+
+    def apply(self: T, fn: Callable[[T], None]) -> T:
+        """Stub for torch.nn.Module.apply; a concrete implementation must override it."""
+        raise NotImplementedError(f"apply is not implemented for {type(self)}.")
+
+    def is_training(self):
+        """Stub for querying whether training mode is active; a concrete implementation must override it."""
+        raise NotImplementedError(f"is_training is not implemented for {type(self)}.")
+
+    def train(self: T, mode: bool = True) -> T:
+        """Stub for torch.nn.Module.train; a concrete implementation must override it."""
+        raise NotImplementedError(f"train is not implemented for {type(self)}.")
+
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        """Stub for torch.nn.Module.state_dict; a concrete implementation must override it."""
+        raise NotImplementedError(f"state_dict is not implemented for {type(self)}.")
+
+    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]',
+                        strict: bool = True):
+        """Stub for torch.nn.Module.load_state_dict; a concrete implementation must override it."""
+        raise NotImplementedError(f"load_state_dict is not implemented for {type(self)}.")
+
+    def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None:
+        """Stub for torch.nn.Module.register_buffer; a concrete implementation must override it."""
+        raise NotImplementedError(f"register_buffer is not implemented for {type(self)}.")
+
+    def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
+        """Stub for torch.nn.Module.register_parameter; a concrete implementation must override it."""
+        raise NotImplementedError(f"register_parameter is not implemented for {type(self)}.")
+
+    def get_parameter(self, target: str) -> torch.nn.Parameter:
+        """Stub for torch.nn.Module.get_parameter; a concrete implementation must override it."""
+        raise NotImplementedError(f"get_parameter is not implemented for {type(self)}.")
+
+    def get_buffer(self, target: str) -> torch.Tensor:
+        """Stub for torch.nn.Module.get_buffer; a concrete implementation must override it."""
+        raise NotImplementedError(f"get_buffer is not implemented for {type(self)}.")
+
+    def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
+        """Stub for torch.nn.Module.parameters; a concrete implementation must override it."""
+        raise NotImplementedError(f"parameters is not implemented for {type(self)}.")
+
+    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
+        """Stub for torch.nn.Module.named_parameters; a concrete implementation must override it."""
+        raise NotImplementedError(f"named_parameters is not implemented for {type(self)}.")
+
+    def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]:
+        """Stub for torch.nn.Module.buffers; a concrete implementation must override it."""
+        raise NotImplementedError(f"buffers is not implemented for {type(self)}.")
+
+    def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]:
+        """Stub for torch.nn.Module.named_buffers; a concrete implementation must override it."""
+        raise NotImplementedError(f"named_buffers is not implemented for {type(self)}.")
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        """Stub for torch.nn.Module._load_from_state_dict; a concrete implementation must override it."""
+        raise NotImplementedError(f"_load_from_state_dict is not implemented for {type(self)}.")
+
+    def named_children(self) -> Iterator[Tuple[str, T]]:
+        """Stub for torch.nn.Module.named_children; a concrete implementation must override it."""
+        raise NotImplementedError(f"named_children is not implemented for {type(self)}.")
+
+    def modules(self) -> Iterator[T]:
+        """Stub for torch.nn.Module.modules; a concrete implementation must override it."""
+        raise NotImplementedError(f"modules is not implemented for {type(self)}.")
+
+    def named_modules(self, *args, **kwargs):
+        """Stub for torch.nn.Module.named_modules; a concrete implementation must override it."""
+        raise NotImplementedError(f"named_modules is not implemented for {type(self)}.")
diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py
index 751c5f1a46..98d553bcac 100644
--- a/orttraining/orttraining/python/training/ortmodule/_utils.py
+++ b/orttraining/orttraining/python/training/ortmodule/_utils.py
@@ -97,10 +97,3 @@ def _create_iobinding(io_binding, inputs, model, device):
 
     for value_info in model.graph.output:
         io_binding.bind_output(value_info.name, device.type, device_id=get_device_index(device))
-
-class _PytorchModuleMetadata():
-    """Encapsulates modules and allows easy access as required"""
-
-    def __init__(self, original_module, flattened_module):
-        self.original_module = original_module
-        self.flattened_module = flattened_module
diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py
index de8a9acc0f..daed3a9f3c 100644
--- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py
+++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py
@@ -3,9 +3,7 @@ # Licensed under the MIT License. 
# -------------------------------------------------------------------------- -from . import _io -from ._graph_execution_manager_factory import GraphExecutionManagerFactory -from ._utils import _PytorchModuleMetadata +from ._torch_module_factory import TorchModuleFactory from onnxruntime.training import register_custom_ops_pytorch_exporter @@ -24,8 +22,7 @@ class ORTModule(torch.nn.Module): """ def __init__(self, module): - assert isinstance( - module, torch.nn.Module), "'module' must be a torch.nn.Module" + self._torch_module = TorchModuleFactory()(module) # Create forward dynamically, so each ORTModule instance will have its own copy. # This is needed to be able to copy the forward signatures from the original PyTorch models @@ -38,26 +35,19 @@ class ORTModule(torch.nn.Module): Finally, we instantiate the ONNX Runtime InferenceSession. ''' - return self._execution_manager(self._is_training()).forward(*inputs, **kwargs) + return self._torch_module.forward(*inputs, **kwargs) # Bind the forward method. self.forward = _forward.__get__(self) - # Copy the forward signature from the PyTorch module. + # Copy the forward signature from the _torch_module's forward signature. 
functools.update_wrapper( - self.forward.__func__, module.forward.__func__) + self.forward.__func__, self._torch_module.forward.__func__) super(ORTModule, self).__init__() # Support contrib OPs register_custom_ops_pytorch_exporter.register_custom_op(is_ortmodule=True) - # User module is wrapped to use its initializers and save computed gradients - # along with the module that flattens both input and output of the user module - # inside _PytorchModuleMetadata - self._module_metadata = _PytorchModuleMetadata(module, _io._FlattenedModule(module)) - - self._execution_manager = GraphExecutionManagerFactory(self._module_metadata.flattened_module) - # IMPORTANT: DO NOT add code here # This declaration is for automatic document generation purposes only # The actual forward implementation is bound during ORTModule initialization @@ -65,83 +55,6 @@ class ORTModule(torch.nn.Module): '''Dummy documentation for forward method''' ... - def _apply(self, fn): - """Override original method to delegate execution to the flattened PyTorch user module""" - - # Delegation must happen to _flattened_module since methods depend on - # _apply to recursively apply the internal setting changes - self._module_metadata.flattened_module._apply(fn) - return self - - def apply(self: T, fn: Callable[['Module'], None]) -> T: - """Override original method to delegate execution to the flattened PyTorch user module""" - - # Delegation must happen to _flattened_module since methods depend on - # apply to recursively apply the internal setting changes - self._module_metadata.flattened_module.apply(fn) - return self - - def _is_training(self): - return self.training and torch.is_grad_enabled() - - def train(self: T, mode: bool = True) -> T: - """Override original method to delegate execution to the flattened PyTorch user module""" - - # Since _modules is empty, the task needs to be delegated to _module.flattened_module.train - # which will recursively update the original_module - self.training = mode - 
self._module_metadata.flattened_module.train(mode) - return self - - def state_dict(self, destination=None, prefix='', keep_vars=False): - """Override original method to delegate execution to the original PyTorch user module""" - - # Override the state_dict() method so that the state dict key names - # do not contain the flattened_module._original_module prefix - return self._module_metadata.original_module.state_dict( - destination=destination, prefix=prefix, keep_vars=keep_vars) - - def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', - strict: bool = True): - """Override original method to delegate execution to the original PyTorch user module""" - - # Override the load_state_dict() method so that the loaded state dict - # key names does not need to contain the _module.flattened_module._original_module prefix - return self._module_metadata.original_module.load_state_dict( - state_dict, strict=strict) - - def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None: - """Override original method to delegate execution to the original PyTorch user module""" - self._module_metadata.original_module.register_buffer(name, tensor, persistent=persistent) - - def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None: - """Override original method to delegate execution to the original PyTorch user module""" - self._module_metadata.original_module.register_parameter(name, param) - - def get_parameter(self, target: str) -> torch.nn.Parameter: - """Override original method to delegate execution to the original PyTorch user module""" - return self._module_metadata.original_module.get_parameter(target) - - def get_buffer(self, target: str) -> torch.Tensor: - """Override original method to delegate execution to the original PyTorch user module""" - return self._module_metadata.original_module.get_buffer(target) - - def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]: - 
"""Override original method to delegate execution to the original PyTorch user module""" - yield from self._module_metadata.original_module.parameters(recurse=recurse) - - def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]: - """Override original method to delegate execution to the original PyTorch user module""" - yield from self._module_metadata.original_module.named_parameters(prefix=prefix, recurse=recurse) - - def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]: - """Override original method to delegate execution to the original PyTorch user module""" - yield from self._module_metadata.original_module.buffers(recurse=recurse) - - def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]: - """Override original method to delegate execution to the original PyTorch user module""" - yield from self._module_metadata.original_module.named_buffers(prefix=prefix, recurse=recurse) - def _replicate_for_data_parallel(self): """Raises a NotImplementedError exception since ORTModule is not compatible with torch.nn.DataParallel @@ -161,34 +74,6 @@ class ORTModule(torch.nn.Module): raise NotImplementedError("ORTModule is not compatible with torch.nn.DataParallel. " "Please use torch.nn.parallel.DistributedDataParallel instead.") - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): - """Override original method to delegate execution to the original PyTorch user module""" - - # PyTorch load_state_dict implementation does not recursively call load_state_dict on its sub-modules. - # Instead, it creates a recursive function and invokes _load_from_state_dict on all child modules. - # For the scenario where an ORTModule is a sub-module of another module, loading of the state - # dictionary requires the _load_from_state_dict to be overridden to prevent an error. 
- self._module_metadata.original_module._load_from_state_dict(state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs) - - def named_children(self) -> Iterator[Tuple[str, 'Module']]: - """Override original method to delegate execution to the original PyTorch user module""" - - yield from self._module_metadata.original_module.named_children() - - def modules(self) -> Iterator['Module']: - """Override original method to delegate execution to the original PyTorch user module""" - - yield from self._module_metadata.original_module.modules() - - def named_modules(self, *args, **kwargs): - """Override original method to delegate execution to the original PyTorch user module""" - - # PyTorch >1.8.1 has an extra arg remove_duplicate that is not present in 1.8.1 - # To support both, use args and kwargs (since user can call the method with only positional args or kwargs) - yield from self._module_metadata.original_module.named_modules(*args, **kwargs) - def add_module(self, name: str, module: Optional['Module']) -> None: """Raises a NotImplementedError exception since ORTModule does not support adding modules to it""" @@ -207,4 +92,111 @@ class ORTModule(torch.nn.Module): # This `module` property enables HuggingFace Trainer to retrieve the underlying PreTrainedModel inside ORTModule # to save and load a complete checkpoint - return self._module_metadata.original_module + return self._torch_module.module + + ################################################################################ + # The methods below are part of torch.nn.Module API that are encapsulated through + # TorchModuleInterface + ################################################################################ + + def _apply(self, fn): + """Override original method to delegate execution to the flattened PyTorch user module""" + + self._torch_module._apply(fn) + return self + + def apply(self: T, fn: Callable[['Module'], None]) -> T: + """Override original method to delegate 
execution to the flattened PyTorch user module""" + + self._torch_module.apply(fn) + return self + + def _is_training(self): + return self._torch_module.is_training() + + def train(self: T, mode: bool = True) -> T: + """Override original method to delegate execution to the flattened PyTorch user module""" + + self.training = mode + # In a torch.nn.Module, _modules stores all dependent modules (sub-modules) of the current module. + # in a list so that torch.nn.Module can apply any changes to all sub-modules recursively. + # Although the _flattened_module and _original_module are dependent modules for ORTModule, + # they do not show up in _modules because they are abstracted away behind another class, + # TorchModule. In order to apply changes to those sub-modules, delegate the task to _torch_module + # which will recursively update the flattened_module and the original module. + self._torch_module.train(mode) + return self + + def state_dict(self, destination=None, prefix='', keep_vars=False): + """Override original method to delegate execution to the original PyTorch user module""" + + return self._torch_module.state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars) + + def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', + strict: bool = True): + """Override original method to delegate execution to the original PyTorch user module""" + + return self._torch_module.load_state_dict(state_dict, strict=strict) + + def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None: + """Override original method to delegate execution to the original PyTorch user module""" + + self._torch_module.register_buffer(name, tensor, persistent=persistent) + + def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None: + """Override original method to delegate execution to the original PyTorch user module""" + + self._torch_module.register_parameter(name, param) + + def get_parameter(self, 
target: str) -> torch.nn.Parameter: + """Override original method to delegate execution to the original PyTorch user module""" + + return self._torch_module.get_parameter(target) + + def get_buffer(self, target: str) -> torch.Tensor: + """Override original method to delegate execution to the original PyTorch user module""" + + return self._torch_module.get_buffer(target) + + def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]: + """Override original method to delegate execution to the original PyTorch user module""" + + yield from self._torch_module.parameters(recurse=recurse) + + def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]: + """Override original method to delegate execution to the original PyTorch user module""" + + yield from self._torch_module.named_parameters(prefix=prefix, recurse=recurse) + + def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]: + """Override original method to delegate execution to the original PyTorch user module""" + + yield from self._torch_module.buffers(recurse=recurse) + + def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]: + """Override original method to delegate execution to the original PyTorch user module""" + + yield from self._torch_module.named_buffers(prefix=prefix, recurse=recurse) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + """Override original method to delegate execution to the original PyTorch user module""" + + self._torch_module._load_from_state_dict(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def named_children(self) -> Iterator[Tuple[str, 'Module']]: + """Override original method to delegate execution to the original PyTorch user module""" + + yield from self._torch_module.named_children() + + def modules(self) -> Iterator['Module']: + 
"""Override original method to delegate execution to the original PyTorch user module""" + + yield from self._torch_module.modules() + + def named_modules(self, *args, **kwargs): + """Override original method to delegate execution to the original PyTorch user module""" + + yield from self._torch_module.named_modules(*args, **kwargs) diff --git a/orttraining/orttraining/test/python/_test_helpers.py b/orttraining/orttraining/test/python/_test_helpers.py index d39da1ae64..c1d6192c95 100644 --- a/orttraining/orttraining/test/python/_test_helpers.py +++ b/orttraining/orttraining/test/python/_test_helpers.py @@ -124,7 +124,7 @@ def assert_optim_state(expected_state, actual_state, rtol=1e-7, atol=0): def is_dynamic_axes(model): # Check inputs - for inp in model._execution_manager(model._is_training())._optimized_onnx_model.graph.input: + for inp in model._torch_module._execution_manager(model._is_training())._optimized_onnx_model.graph.input: shape = inp.type.tensor_type.shape if shape: for dim in shape.dim: @@ -132,7 +132,7 @@ def is_dynamic_axes(model): return False # Check outputs - for out in model._execution_manager(model._is_training())._optimized_onnx_model.graph.output: + for out in model._torch_module._execution_manager(model._is_training())._optimized_onnx_model.graph.output: shape = out.type.tensor_type.shape if shape: for dim in shape.dim: diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 98388700ef..ac73fd06fc 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -472,7 +472,7 @@ def test_torch_nn_module_to_api(original_device, to_argument): model = model.to(to_argument) x = x.to(to_argument) model(x) - assert _utils.get_device_str(model._execution_manager(model._is_training())._device) == \ + assert 
_utils.get_device_str(model._torch_module._execution_manager(model._is_training())._device) == \ _utils.get_device_str(torch.device(to_argument)) def test_model_without_device(): @@ -524,7 +524,7 @@ def test_input_requires_grad_saved(device): model = ORTModule(model) x = torch.randn(N, D_in, device=device, requires_grad=True) + 1 model(x) - assert model._execution_manager(model._is_training())._input_info.require_grad_names == ['input1'] + assert model._torch_module._execution_manager(model._is_training())._input_info.require_grad_names == ['input1'] @pytest.mark.parametrize("device", ['cuda', 'cpu']) def test_input_requires_grad_backward_creates_input_grad(device): @@ -1033,12 +1033,12 @@ def test_changes_input_requires_grad_reinitializes_module_gradient_graph_builder output_x.backward() assert x.grad is None module_gradient_graph_builder_training = \ - model._execution_manager(model._is_training())._graph_builder + model._torch_module._execution_manager(model._is_training())._graph_builder output_y = torch.sum(model(y)) output_y.backward() assert y.grad is not None assert module_gradient_graph_builder_training != \ - model._execution_manager(model._is_training())._graph_builder + model._torch_module._execution_manager(model._is_training())._graph_builder @pytest.mark.parametrize("device", ['cuda']) def test_input_requires_grad_backward_creates_input_grad_as_required0(device): @@ -1695,7 +1695,7 @@ def test_forward_data_and_model_on_different_devices(data_device, model_device): x = torch.randn(N, D_in, device=data_device) with pytest.raises(RuntimeError) as runtime_error: ort_model(x) - assert f"Input argument to forward found on device {torch.device(x.device)}, but expected it to be on module device {ort_model._execution_manager(ort_model._is_training())._device}." 
in str(runtime_error.value) + assert f"Input argument to forward found on device {torch.device(x.device)}, but expected it to be on module device {ort_model._torch_module._execution_manager(ort_model._is_training())._device}." in str(runtime_error.value) def test_forward_returns_none_type_as_output(): class NeuralNetNoneTypeOutput(torch.nn.Module): @@ -1799,26 +1799,26 @@ def test_model_initializer_requires_grad_changes_from_one_forward_to_next(): model.fc1.requires_grad_(True) model = ORTModule(model) x = torch.randn(N, D_in, device=device) - assert model._module_metadata.original_module.fc1.weight.grad is None - assert model._module_metadata.original_module.fc1.bias.grad is None + assert model.module.fc1.weight.grad is None + assert model.module.fc1.bias.grad is None # Make sure no exception is raised output = model(x) loss = torch.sum(output) loss.backward() - training_session1 = model._execution_manager(model._is_training())._execution_agent - weight_grad_2 = model._module_metadata.original_module.fc1.weight.grad - bias_grad_2 = model._module_metadata.original_module.fc1.bias.grad + training_session1 = model._torch_module._execution_manager(model._is_training())._execution_agent + weight_grad_2 = model.module.fc1.weight.grad + bias_grad_2 = model.module.fc1.bias.grad assert weight_grad_2 is not None assert bias_grad_2 is not None - model._module_metadata.original_module.fc1.requires_grad_(False) + model.module.fc1.requires_grad_(False) output = model(x) loss = torch.sum(output) loss.backward() - training_session2 = model._execution_manager(model._is_training())._execution_agent - weight_grad_3 = model._module_metadata.original_module.fc1.weight.grad - bias_grad_3 = model._module_metadata.original_module.fc1.bias.grad + training_session2 = model._torch_module._execution_manager(model._is_training())._execution_agent + weight_grad_3 = model.module.fc1.weight.grad + bias_grad_3 = model.module.fc1.bias.grad assert training_session1 != training_session2 assert 
torch.equal(weight_grad_2, weight_grad_3) @@ -2142,21 +2142,21 @@ def test_forward_dynamic_args(): for _ in range(10): output = model(*args_size1) assert output is not None - hash_args_size1 = hash(repr(model._execution_manager(model._is_training())._input_info.schema)) + hash_args_size1 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) assert hash_args_size1 is not None # Decrease number of inputs and train some more for _ in range(10): output = model(*args_size2) assert output is not None - hash_args_size2 = hash(repr(model._execution_manager(model._is_training())._input_info.schema)) + hash_args_size2 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) assert hash_args_size2 != hash_args_size1 # Increase number of inputs and train some more for _ in range(10): output = model(*args_size3) assert output is not None - hash_args_size3 = hash(repr(model._execution_manager(model._is_training())._input_info.schema)) + hash_args_size3 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) assert hash_args_size3 != hash_args_size2 @@ -2178,35 +2178,35 @@ def test_forward_dynamic_kwargs(): for _ in range(10): output = model(one) assert output is not None - hash_x = hash(repr(model._execution_manager(model._is_training())._input_info.schema)) + hash_x = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) assert hash_x is not None # Train with x and y as inputs for _ in range(10): output = model(one,y=one) assert output is not None - hash_x_y = hash(repr(model._execution_manager(model._is_training())._input_info.schema)) + hash_x_y = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) assert hash_x_y != hash_x # Train with x and z as inputs for _ in range(10): output = model(one,z=one) assert output is not None - hash_x_z = 
hash(repr(model._execution_manager(model._is_training())._input_info.schema)) + hash_x_z = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) assert hash_x_z != hash_x_y # Train with x, y and z as inputs for _ in range(10): output = model(one,y=one, z=one) assert output is not None - hash_x_y_z = hash(repr(model._execution_manager(model._is_training())._input_info.schema)) + hash_x_y_z = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) assert hash_x_y_z != hash_x_z # Return to original input with x as input for _ in range(10): output = model(one) assert output is not None - hash_x2 = hash(repr(model._execution_manager(model._is_training())._input_info.schema)) + hash_x2 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) assert hash_x2 != hash_x_y_z assert hash_x2 == hash_x @@ -2358,30 +2358,30 @@ def test_forward_call_default_input(): # Model only uses a,d out of a,b,c,d out = model(one, two, three, four) assert out.item() == 5.0 - if model.training: + if model._is_training(): out.sum().backward() out = model(one, two, c=three, d=four) assert out.item() == 5.0 - if model.training: + if model._is_training(): out.sum().backward() # Model only uses a,d,args[-1] out of a,b,c,d,*args out = model(one, two, three, four, *args) assert out.item() == 7.0 - if model.training: + if model._is_training(): out.sum().backward() # Model only uses a,d,args[-1],kw_0 out of a,b,c,d,*args,kw_0 out = model(one, two, three, four, *args, kw_0=kw_0) assert out.item() == 13.0 - if model.training: + if model._is_training(): out.sum().backward() # Model only uses a,d,args[-1],kwargs['kwargs_1'] out of a,b,c,d,*args,kw_0,**kwargs out = model(one, two, three, four, *args, **kwargs) assert out.item() == 15.0 - if model.training: + if model._is_training(): out.sum().backward() @@ -2424,7 +2424,7 @@ def test_forward_call_kwargs_input_unexpected_order(): y1, y2 = 
model(**{'input1': input1, 'input2': input2}) assert y1 is not None assert y2 is not None - if model.training: + if model._is_training(): loss = y1.sum() + y2.sum() loss.backward() @@ -2432,7 +2432,7 @@ def test_forward_call_kwargs_input_unexpected_order(): y1, y2 = model(**{'input2': input2, 'input1': input1}) assert y1 is not None assert y2 is not None - if model.training: + if model._is_training(): loss = y1.sum() + y2.sum() loss.backward() @@ -2469,11 +2469,11 @@ def test_forward_call_lots_None(): # ORTModule produces the same schema, thus not re-exporting # the model when `forward(a,b)` is used after `forward(**{'a': a, 'b': b})` # or vice-versa - model._execution_manager(model._is_training())._onnx_model = None + model._torch_module._execution_manager(model._is_training())._onnx_model = None out = model(a,b,c,d,e,f,y,z) assert out is not None assert out.item() == expected - if model.training: + if model._is_training(): loss = out.sum() loss.backward() @@ -2584,10 +2584,10 @@ def test_changing_bool_input_re_exports_model(bool_arguments): input1 = torch.randn(N, D_in, device=device) ort_model(input1, bool_arguments[0]) - exported_model1 = ort_model._execution_manager(ort_model._is_training())._onnx_model + exported_model1 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_model ort_model(input1, bool_arguments[1]) - exported_model2 = ort_model._execution_manager(ort_model._is_training())._onnx_model + exported_model2 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_model assert exported_model1 != exported_model2 @@ -2741,7 +2741,7 @@ def test_unused_parameters_does_not_unnecssarily_reinitilize(model): N, D_in, H1, H2, D_out = 64, 784, 500, 400, 10 model = model.to(device) ort_model = ORTModule(copy.deepcopy(model)) - training_manager = ort_model._execution_manager(ort_model._is_training()) + training_manager = ort_model._torch_module._execution_manager(ort_model._is_training()) x = torch.randn(N, D_in, 
device=device) _ = ort_model(x) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py index f86ec1ee79..383e26e607 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py @@ -379,8 +379,8 @@ def main(): model = ORTModule(model) # Just for future debugging - model._execution_manager(model._is_training())._save_onnx = False - model._execution_manager(model._is_training())._save_onnx_prefix = 'BertForSequenceClassification' + model._torch_module._execution_manager(model._is_training())._save_onnx = False + model._torch_module._execution_manager(model._is_training())._save_onnx_prefix = 'BertForSequenceClassification' # Tell pytorch to run this model on the GPU. if torch.cuda.is_available() and not args.no_cuda: diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py index 00c07d547f..7b6e0afd4c 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py @@ -379,10 +379,10 @@ def main(): if not args.pytorch_only: model = ORTModule(model) - - model._execution_manager(is_training=True)._save_onnx = True - model._execution_manager(is_training=True)._save_onnx_prefix = 'BertForSequenceClassification' - model._execution_manager(is_training=True)._enable_grad_acc_optimization = True + + model._torch_module._execution_manager(is_training=True)._save_onnx = True + model._torch_module._execution_manager(is_training=True)._save_onnx_prefix = 'BertForSequenceClassification' + 
model._torch_module._execution_manager(is_training=True)._enable_grad_acc_optimization = True # Tell pytorch to run this model on the GPU. if torch.cuda.is_available() and not args.no_cuda: