From b0c3d39e0d475e2dad6ad010863f563f8bd60d84 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Wed, 11 Dec 2024 13:15:44 -0800 Subject: [PATCH] [pipelining] Update tutorials and documentation (#143045) Pull Request resolved: https://github.com/pytorch/pytorch/pull/143045 Approved by: https://github.com/wconstab, https://github.com/kwen2501 --- docs/source/distributed.pipelining.rst | 9 +-------- torch/distributed/pipelining/_IR.py | 7 +++++++ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/source/distributed.pipelining.rst b/docs/source/distributed.pipelining.rst index 40ba3eb37f8..77aa8da7784 100644 --- a/docs/source/distributed.pipelining.rst +++ b/docs/source/distributed.pipelining.rst @@ -199,15 +199,8 @@ the model. For example: stage_index, num_stages, device, - input_args=example_input_microbatch, ) - -The ``PipelineStage`` requires an example argument ``input_args`` representing the runtime input to the stage, which would be one microbatch worth of input data. This argument is passed through the forward method of the stage module to determine the input and output shapes required for communication. - When composing with other Data or Model parallelism techniques, ``output_args`` may also be required, if the output shape/dtype of the model chunk will be affected. @@ -421,7 +414,7 @@ are subclasses of ``PipelineScheduleMulti``. Logging ******* -You can turn on additional logging using the `TORCH_LOGS` environment variable from [`torch._logging`](https://pytorch.org/docs/main/logging.html#module-torch._logging): +You can turn on additional logging using the `TORCH_LOGS` environment variable from `torch._logging <https://pytorch.org/docs/main/logging.html#module-torch._logging>`_: * `TORCH_LOGS=+pp` will display `logging.DEBUG` messages and all levels above it. * `TORCH_LOGS=pp` will display `logging.INFO` messages and above. 
diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 54cc11a6ae3..33703e859ce 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -1143,6 +1143,13 @@ class Pipe(torch.nn.Module): class SplitPoint(Enum): + """ + Enum representing the points at which a split can occur in the execution of a submodule. + Attributes: + BEGINNING: Represents adding a split point *before* the execution of a certain submodule in the `forward` function. + END: Represents adding a split point *after* the execution of a certain submodule in the `forward` function. + """ + BEGINNING = 1 END = 2