Update ORTModule API docstrings (#8309)

This commit is contained in:
Nat Kershaw (MSFT) 2021-08-16 16:53:01 -07:00 committed by GitHub
parent 8713d76dd1
commit aa12d68c37
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -66,8 +66,24 @@ class ORTModule(torch.nn.Module):
# This declaration is for automatic document generation purposes only
# The actual forward implementation is bound during ORTModule initialization
def forward(self, *inputs, **kwargs):
'''Dummy documentation for forward method'''
...
'''Delegate the :meth:`~torch.nn.Module.forward` pass of PyTorch training to
ONNX Runtime.
The first call to forward performs setup and checking steps. During this call,
ORTModule determines whether the module can be trained with ONNX Runtime. For
this reason, the first forward call execution takes longer than subsequent calls.
Execution is interupted if ONNX Runtime cannot process the model for training.
Args:
*inputs and **kwargs represent the positional, variable positional, keyword
and variable keyword arguments defined in the user's PyTorch module's forward
method. Values can be torch tensors and primitive types.
Returns:
The output as expected from the forward method defined by the user's
PyTorch module. Output values supported include tensors, nested sequences
of tensors and nested dictionaries of tensor values.
'''
def _replicate_for_data_parallel(self):
"""Raises a NotImplementedError exception since ORTModule is not compatible with torch.nn.DataParallel
@ -120,7 +136,7 @@ class ORTModule(torch.nn.Module):
return self
def apply(self: T, fn: Callable[['Module'], None]) -> T:
"""Override original method to delegate execution to the flattened PyTorch user module"""
"""Override :meth:`~torch.nn.Module.apply` to delegate execution to ONNX Runtime"""
self._torch_module.apply(fn)
return self
@ -129,7 +145,7 @@ class ORTModule(torch.nn.Module):
return self._torch_module.is_training()
def train(self: T, mode: bool = True) -> T:
"""Override original method to delegate execution to the flattened PyTorch user module"""
"""Override :meth:`~torch.nn.Module.train` to delegate execution to ONNX Runtime"""
self.training = mode
# In a torch.nn.Module, _modules stores all dependent modules (sub-modules) of the current module.
@ -142,54 +158,54 @@ class ORTModule(torch.nn.Module):
return self
def state_dict(self, destination=None, prefix='', keep_vars=False):
"""Override original method to delegate execution to the original PyTorch user module"""
"""Override :meth:`~torch.nn.Module.state_dict` to delegate execution to ONNX Runtime"""
return self._torch_module.state_dict(
destination=destination, prefix=prefix, keep_vars=keep_vars)
def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
strict: bool = True):
"""Override original method to delegate execution to the original PyTorch user module"""
"""Override :meth:`~torch.nn.Module.load_state_dict` to delegate execution to ONNX Runtime"""
return self._torch_module.load_state_dict(state_dict, strict=strict)
def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None:
"""Override original method to delegate execution to the original PyTorch user module"""
"""Override :meth:`~torch.nn.Module.register_buffer`"""
self._torch_module.register_buffer(name, tensor, persistent=persistent)
def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
"""Override original method to delegate execution to the original PyTorch user module"""
"""Override :meth:`~torch.nn.Module.register_parameter`"""
self._torch_module.register_parameter(name, param)
def get_parameter(self, target: str) -> torch.nn.Parameter:
"""Override original method to delegate execution to the original PyTorch user module"""
"""Override :meth:`~torch.nn.Module.get_parameter`"""
return self._torch_module.get_parameter(target)
def get_buffer(self, target: str) -> torch.Tensor:
"""Override original method to delegate execution to the original PyTorch user module"""
"""Override :meth:`~torch.nn.Module.get_buffer`"""
return self._torch_module.get_buffer(target)
def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
"""Override original method to delegate execution to the original PyTorch user module"""
"""Override :meth:`~torch.nn.Module.parameters`"""
yield from self._torch_module.parameters(recurse=recurse)
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
"""Override original method to delegate execution to the original PyTorch user module"""
"""Override :meth:`~torch.nn.Module.named_parameters`"""
yield from self._torch_module.named_parameters(prefix=prefix, recurse=recurse)
def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]:
"""Override original method to delegate execution to the original PyTorch user module"""
"""Override :meth:`~torch.nn.Module.buffers`"""
yield from self._torch_module.buffers(recurse=recurse)
def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]:
"""Override original method to delegate execution to the original PyTorch user module"""
"""Override :meth:`~torch.nn.Module.named_buffers`"""
yield from self._torch_module.named_buffers(prefix=prefix, recurse=recurse)
@ -201,16 +217,16 @@ class ORTModule(torch.nn.Module):
missing_keys, unexpected_keys, error_msgs)
def named_children(self) -> Iterator[Tuple[str, 'Module']]:
"""Override original method to delegate execution to the original PyTorch user module"""
"""Override :meth:`~torch.nn.Module.named_children`"""
yield from self._torch_module.named_children()
def modules(self) -> Iterator['Module']:
"""Override original method to delegate execution to the original PyTorch user module"""
"""Override :meth:`~torch.nn.Module.modules`"""
yield from self._torch_module.modules()
def named_modules(self, *args, **kwargs):
"""Override original method to delegate execution to the original PyTorch user module"""
"""Override :meth:`~torch.nn.Module.named_modules`"""
yield from self._torch_module.named_modules(*args, **kwargs)