mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Update ORTModule API docstrings (#8309)
This commit is contained in:
parent
8713d76dd1
commit
aa12d68c37
1 changed files with 33 additions and 17 deletions
|
|
@ -66,8 +66,24 @@ class ORTModule(torch.nn.Module):
|
|||
# This declaration is for automatic document generation purposes only
|
||||
# The actual forward implementation is bound during ORTModule initialization
|
||||
def forward(self, *inputs, **kwargs):
|
||||
'''Dummy documentation for forward method'''
|
||||
...
|
||||
'''Delegate the :meth:`~torch.nn.Module.forward` pass of PyTorch training to
|
||||
ONNX Runtime.
|
||||
|
||||
The first call to forward performs setup and checking steps. During this call,
|
||||
ORTModule determines whether the module can be trained with ONNX Runtime. For
|
||||
this reason, the first forward call execution takes longer than subsequent calls.
|
||||
Execution is interupted if ONNX Runtime cannot process the model for training.
|
||||
|
||||
Args:
|
||||
*inputs and **kwargs represent the positional, variable positional, keyword
|
||||
and variable keyword arguments defined in the user's PyTorch module's forward
|
||||
method. Values can be torch tensors and primitive types.
|
||||
|
||||
Returns:
|
||||
The output as expected from the forward method defined by the user's
|
||||
PyTorch module. Output values supported include tensors, nested sequences
|
||||
of tensors and nested dictionaries of tensor values.
|
||||
'''
|
||||
|
||||
def _replicate_for_data_parallel(self):
|
||||
"""Raises a NotImplementedError exception since ORTModule is not compatible with torch.nn.DataParallel
|
||||
|
|
@ -120,7 +136,7 @@ class ORTModule(torch.nn.Module):
|
|||
return self
|
||||
|
||||
def apply(self: T, fn: Callable[['Module'], None]) -> T:
|
||||
"""Override original method to delegate execution to the flattened PyTorch user module"""
|
||||
"""Override :meth:`~torch.nn.Module.apply` to delegate execution to ONNX Runtime"""
|
||||
|
||||
self._torch_module.apply(fn)
|
||||
return self
|
||||
|
|
@ -129,7 +145,7 @@ class ORTModule(torch.nn.Module):
|
|||
return self._torch_module.is_training()
|
||||
|
||||
def train(self: T, mode: bool = True) -> T:
|
||||
"""Override original method to delegate execution to the flattened PyTorch user module"""
|
||||
"""Override :meth:`~torch.nn.Module.train` to delegate execution to ONNX Runtime"""
|
||||
|
||||
self.training = mode
|
||||
# In a torch.nn.Module, _modules stores all dependent modules (sub-modules) of the current module.
|
||||
|
|
@ -142,54 +158,54 @@ class ORTModule(torch.nn.Module):
|
|||
return self
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False):
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
"""Override :meth:`~torch.nn.Module.state_dict` to delegate execution to ONNX Runtime"""
|
||||
|
||||
return self._torch_module.state_dict(
|
||||
destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||
|
||||
def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
|
||||
strict: bool = True):
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
"""Override :meth:`~torch.nn.Module.load_state_dict` to delegate execution to ONNX Runtime"""
|
||||
|
||||
return self._torch_module.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
"""Override :meth:`~torch.nn.Module.register_buffer`"""
|
||||
|
||||
self._torch_module.register_buffer(name, tensor, persistent=persistent)
|
||||
|
||||
def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
"""Override :meth:`~torch.nn.Module.register_parameter`"""
|
||||
|
||||
self._torch_module.register_parameter(name, param)
|
||||
|
||||
def get_parameter(self, target: str) -> torch.nn.Parameter:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
"""Override :meth:`~torch.nn.Module.get_parameter`"""
|
||||
|
||||
return self._torch_module.get_parameter(target)
|
||||
|
||||
def get_buffer(self, target: str) -> torch.Tensor:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
"""Override :meth:`~torch.nn.Module.get_buffer`"""
|
||||
|
||||
return self._torch_module.get_buffer(target)
|
||||
|
||||
def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
"""Override :meth:`~torch.nn.Module.parameters`"""
|
||||
|
||||
yield from self._torch_module.parameters(recurse=recurse)
|
||||
|
||||
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
"""Override :meth:`~torch.nn.Module.named_parameters`"""
|
||||
|
||||
yield from self._torch_module.named_parameters(prefix=prefix, recurse=recurse)
|
||||
|
||||
def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
"""Override :meth:`~torch.nn.Module.buffers`"""
|
||||
|
||||
yield from self._torch_module.buffers(recurse=recurse)
|
||||
|
||||
def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
"""Override :meth:`~torch.nn.Module.named_buffers`"""
|
||||
|
||||
yield from self._torch_module.named_buffers(prefix=prefix, recurse=recurse)
|
||||
|
||||
|
|
@ -201,16 +217,16 @@ class ORTModule(torch.nn.Module):
|
|||
missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
def named_children(self) -> Iterator[Tuple[str, 'Module']]:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
"""Override :meth:`~torch.nn.Module.named_children`"""
|
||||
|
||||
yield from self._torch_module.named_children()
|
||||
|
||||
def modules(self) -> Iterator['Module']:
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
"""Override :meth:`~torch.nn.Module.modules`"""
|
||||
|
||||
yield from self._torch_module.modules()
|
||||
|
||||
def named_modules(self, *args, **kwargs):
|
||||
"""Override original method to delegate execution to the original PyTorch user module"""
|
||||
"""Override :meth:`~torch.nn.Module.named_modules`"""
|
||||
|
||||
yield from self._torch_module.named_modules(*args, **kwargs)
|
||||
|
|
|
|||
Loading…
Reference in a new issue