mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Performs shape inference at runtime using user-provided real tensors.
- avoids the need for users to precompute shapes, which is difficult and error-prone
- lets us remove args from the PipelineStage ctor (in a later PR)
- deprecates the existing inference helper in the PipelineStage constructor for several reasons: it's problematic to have to reason about the stage submod being on the right device for shape inference

The current state as of this PR:
- Users should not pass any input or output shapes into the PipelineStage ctor, and shape inference will run automatically
- To override shape inference, they can continue to pass input/output args as previously

Currently, this does not add a barrier after shape inference, which essentially pipelines shape inference with the subsequent schedule action for that stage. If this complicates debugging, we could add in a barrier (it comes at a cost, but only during the first step).

Testing:
- Removed input args from all PP test cases, thus exposing them all to shape inference.
- Verified visually (nvidia-smi) that the torchtitan PP 3D test runs shape inference fine without creating extra CUDA contexts.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136912
Approved by: https://github.com/kwen2501, https://github.com/H-Huang

| | | |
|---|---|---|
| .. | ||
| fsdp | ||
| fully_shard | ||
| test_composability | ||
| test_checkpoint.py | ||
| test_compose.py | ||
| test_contract.py | ||
| test_replicate.py | ||
| test_replicate_with_compiler.py | ||
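
The idea in the commit message above, deriving stage shapes from real tensors at runtime instead of asking users to precompute them, can be illustrated with a minimal single-process sketch. This is not the implementation from the PR: the helper name `infer_stage_shapes` and the three toy stages are hypothetical, and no `PipelineStage` or distributed communication is involved. It only shows that one forward pass over a real example input is enough to recover every stage's input and output shape.

```python
import torch
import torch.nn as nn


def infer_stage_shapes(stages, example_input):
    """Run each stage once on a real tensor and record its shapes.

    Hypothetical helper for illustration only; returns a list of
    (input_shape, output_shape) tuples, one per stage.
    """
    shapes = []
    x = example_input
    # No gradients are needed just to discover shapes.
    with torch.no_grad():
        for stage in stages:
            y = stage(x)
            shapes.append((tuple(x.shape), tuple(y.shape)))
            x = y  # the output of this stage feeds the next one
    return shapes


# Toy "pipeline" of three stages (assumed for the example).
stages = [nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8)]
shapes = infer_stage_shapes(stages, torch.randn(4, 16))

for index, (in_shape, out_shape) in enumerate(shapes):
    print(f"stage {index}: {in_shape} -> {out_shape}")
```

In a real pipeline each stage lives on a different rank, so the inferred output metadata would have to be communicated to the next stage rather than passed along in a local loop as it is here.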