From aa12d68c376ac63ee14b3dce62dbfd6c955c8f01 Mon Sep 17 00:00:00 2001 From: "Nat Kershaw (MSFT)" Date: Mon, 16 Aug 2021 16:53:01 -0700 Subject: [PATCH] Update ORTModule API docstrings (#8309) --- .../python/training/ortmodule/ortmodule.py | 50 ++++++++++++------- 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index 998cdc4010..036bd514f0 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -66,8 +66,24 @@ class ORTModule(torch.nn.Module): # This declaration is for automatic document generation purposes only # The actual forward implementation is bound during ORTModule initialization def forward(self, *inputs, **kwargs): - '''Dummy documentation for forward method''' - ... + '''Delegate the :meth:`~torch.nn.Module.forward` pass of PyTorch training to + ONNX Runtime. + + The first call to forward performs setup and checking steps. During this call, + ORTModule determines whether the module can be trained with ONNX Runtime. For + this reason, the first forward call execution takes longer than subsequent calls. + Execution is interupted if ONNX Runtime cannot process the model for training. + + Args: + *inputs and **kwargs represent the positional, variable positional, keyword + and variable keyword arguments defined in the user's PyTorch module's forward + method. Values can be torch tensors and primitive types. + + Returns: + The output as expected from the forward method defined by the user's + PyTorch module. Output values supported include tensors, nested sequences + of tensors and nested dictionaries of tensor values. + ''' def _replicate_for_data_parallel(self): """Raises a NotImplementedError exception since ORTModule is not compatible with torch.nn.DataParallel @@ -120,7 +136,7 @@ class ORTModule(torch.nn.Module): return self def apply(self: T, fn: Callable[['Module'], None]) -> T: - """Override original method to delegate execution to the flattened PyTorch user module""" + """Override :meth:`~torch.nn.Module.apply` to delegate execution to ONNX Runtime""" self._torch_module.apply(fn) return self @@ -129,7 +145,7 @@ class ORTModule(torch.nn.Module): return self._torch_module.is_training() def train(self: T, mode: bool = True) -> T: - """Override original method to delegate execution to the flattened PyTorch user module""" + """Override :meth:`~torch.nn.Module.train` to delegate execution to ONNX Runtime""" self.training = mode # In a torch.nn.Module, _modules stores all dependent modules (sub-modules) of the current module. @@ -142,54 +158,54 @@ class ORTModule(torch.nn.Module): return self def state_dict(self, destination=None, prefix='', keep_vars=False): - """Override original method to delegate execution to the original PyTorch user module""" + """Override :meth:`~torch.nn.Module.state_dict` to delegate execution to ONNX Runtime""" return self._torch_module.state_dict( destination=destination, prefix=prefix, keep_vars=keep_vars) def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', strict: bool = True): - """Override original method to delegate execution to the original PyTorch user module""" + """Override :meth:`~torch.nn.Module.load_state_dict` to delegate execution to ONNX Runtime""" return self._torch_module.load_state_dict(state_dict, strict=strict) def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None: - """Override original method to delegate execution to the original PyTorch user module""" + """Override :meth:`~torch.nn.Module.register_buffer`""" self._torch_module.register_buffer(name, tensor, persistent=persistent) def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None: - """Override original method to delegate execution to the original PyTorch user module""" + """Override :meth:`~torch.nn.Module.register_parameter`""" self._torch_module.register_parameter(name, param) def get_parameter(self, target: str) -> torch.nn.Parameter: - """Override original method to delegate execution to the original PyTorch user module""" + """Override :meth:`~torch.nn.Module.get_parameter`""" return self._torch_module.get_parameter(target) def get_buffer(self, target: str) -> torch.Tensor: - """Override original method to delegate execution to the original PyTorch user module""" + """Override :meth:`~torch.nn.Module.get_buffer`""" return self._torch_module.get_buffer(target) def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]: - """Override original method to delegate execution to the original PyTorch user module""" + """Override :meth:`~torch.nn.Module.parameters`""" yield from self._torch_module.parameters(recurse=recurse) def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]: - """Override original method to delegate execution to the original PyTorch user module""" + """Override :meth:`~torch.nn.Module.named_parameters`""" yield from self._torch_module.named_parameters(prefix=prefix, recurse=recurse) def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]: - """Override original method to delegate execution to the original PyTorch user module""" + """Override :meth:`~torch.nn.Module.buffers`""" yield from self._torch_module.buffers(recurse=recurse) def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]: - """Override original method to delegate execution to the original PyTorch user module""" + """Override :meth:`~torch.nn.Module.named_buffers`""" yield from self._torch_module.named_buffers(prefix=prefix, recurse=recurse) @@ -201,16 +217,16 @@ class ORTModule(torch.nn.Module): missing_keys, unexpected_keys, error_msgs) def named_children(self) -> Iterator[Tuple[str, 'Module']]: - """Override original method to delegate execution to the original PyTorch user module""" + """Override :meth:`~torch.nn.Module.named_children`""" yield from self._torch_module.named_children() def modules(self) -> Iterator['Module']: - """Override original method to delegate execution to the original PyTorch user module""" + """Override :meth:`~torch.nn.Module.modules`""" yield from self._torch_module.modules() def named_modules(self, *args, **kwargs): - """Override original method to delegate execution to the original PyTorch user module""" + """Override :meth:`~torch.nn.Module.named_modules`""" yield from self._torch_module.named_modules(*args, **kwargs)