When caching is enabled, an internal model fails with
```
assert_size_stride(bmm_9, (17, s0, 512), (54784, 512, 1))
AssertionError: expected size 17==17, stride 57344==54784 at dim=0
```
looking at this model, the exact problem is when the cache is hit on the forward graph, the generated code for backward fails since the strides of the outputs of forward, passed to backward as inputs, are not what we expected.
This PR changes the evaluation logic so that we defer evaluation of output stride exprs to load path as opposed to eagerly doing it on save path.
I have not been able to come up with a unit test repro for this problem.
Differential Revision: [D58796503](https://our.internmc.facebook.com/intern/diff/D58796503)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128997
Approved by: https://github.com/ezyang
Summary:
WARNING: This API is highly unstable and will be subject to change in the future.
Add a protoype to "decompose" an ExportedProgram into a joint graph form, so that we can compute the gradients on this graph.
Test Plan: buck test mode/opt caffe2/torch/fb/export:test_experimental
Differential Revision: D55657917
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128847
Approved by: https://github.com/tugsbayasgalan
I'd like to discuss the criteria that we regard an implementation as stable. If there is no existing standard, my initial proposal would be a 6 month period after the commit to regard it as stable. As a result, now Adam and AdamW on CUDA would be considered as stable, while the rest are of beta.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129006
Approved by: https://github.com/malfet
### bc-breaking for existing users of the private API:
- Existing policy functions must now change their return value to be [CheckpointPolicy](c0b40ab42e/torch/utils/checkpoint.py (L1204-L1230)) Enum instead of bool.
- To restore previous behavior, return `PREFER_RECOMPUTE` instead of `False` and `{PREFER,MUST}_SAVE` instead of `True` depending whether you prefer the compiler to override your policy.
- Policy function now accepts a `ctx` object instead of `mode` for its first argument.
- To restore previous behavior, `mode = "recompute" if ctx.is_recompute else "forward"`.
- Existing calls to `_pt2_selective_checkpoint_context_fn_gen` must be renamed to `create_selective_checkpoint_contexts `. The way you use the API remains the same. It would've been nice to do something different (not make the user have to use functools.partial?), but this was the easiest to compile (idk if this should actually be a constraint).
Related doc: https://docs.google.com/document/d/1BKyizkZPdri9mHqdDOLAUpkI7SbbKfLHRFVVpK9ZWqo/edit
Memory considerations:
- As with the existing SAC, cached values are cleared upon first use.
- We error if the user wishes to backward a second time on a region forwarded with SAC enabled.
In-place:
- We use version counting to enforce that if any cached tensor has been mutated. In-place operations not mutating cached tensors are allowed.
- `allow_cache_entry_mutation=True` can be passed to disable this check (useful in the case of auto AC where the user is cleverly also saves the output of the in-place)
Randomness, views
- Currently in this PR, we don't do anything special for randomness or views, the author of the policy function is expected to handle them properly. (Would it would be beneficial to error? - we either want to save all or recompute all random tensors)
Tensor object preservation
- ~We guarantee that if a tensor does not requires grad, and it is saved, then what you get out is the same tensor object.~ UPDATE: We guarantee that if a tensor is of non-differentiable dtype AND it is not a view, and it is saved, then what you get out is the same tensor object. This is a nice guarantee for nested tensors which care about the object identity of of the offsets tensor.
Policy function
- Enum values are `{MUST,PREFER}_{SAVE,RECOMPUTE}` (bikeshed welcome). Alternatively there was `{SAVE,RECOMPUTE}_{NON_,}OVERRIDABLE`. The former was preferred bc it seemed clearer that two `MUST` clashing should error, versus it is ambiguous whether two `NON_OVERRIDABLE` being stacked should silently ignore or error.
- The usage of Enum today. There actually is NO API to stack SAC policies today. The only thing the Enum should matter for in the near term is the compiler. The stacking SAC policy would be useful if someone wants to implement something like simple FSDP, but it is not perfect because with a policy of `PREFER_SAVE` you are actually saving more than autograd would save normally (would be fixed with AC v3).
- The number of times we call the policy_fn is something that should be documented as part of public API. We call the policy function for all ops except ~~detach~~ UPDATE : metadata ops listed in `torch.utils.checkpoint.SAC_IGNORED_OPS`) because these ops may be called a different number of times by AC itself between forward and recompute.
- The policy function can be a stateful object (we do NOT make separate copies of this object for forward/recompute, the user is expected to handle that via is_recompute see below).
Tensors guaranteed to be the same tensor as-is
- Policy function signature takes ctx object as its first argument. The ctx function is an object encapsulating info that may be useful to the user, it currently only holds "is_recompute". Adding this indirection gives us flexibility to add more attrs later if necessary.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125795
Approved by: https://github.com/Chillee, https://github.com/fmassa
Fixes#127908
## Description
Created docs to document the torch.cuda.cudart function to solve the issue #127908.
I tried to stick to the [guidelines to document a function](https://github.com/pytorch/pytorch/wiki/Docstring-Guidelines#documenting-a-function) but I was not sure if there is a consensus on how to handle the docs of a function that calls an internal function. So I went ahead and tried what the function will raise, etc. from the user endpoint and documented it (i.e. I am giving what actually _lazy_init() will raise).
Updated PR from #128298 since I made quite a big mistake in my branch. I apologize for the newbie mistake.
### Summary of Changes
- Added docs for torch.cuda.cudart
- Added the cudart function in the autosummary of docs/source/cuda.rst
## Checklist
- [X] The issue that is being fixed is referred in the description
- [X] Only one issue is addressed in this pull request
- [X] Labels from the issue that this PR is fixing are added to this pull request
- [X] No unnecesary issues are included into this pull request
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128741
Approved by: https://github.com/msaroufim
The following are all constrained under the ONNX exporter project scope.
- `personal_of_interest.rst`
- Moving folks no longer working on the project to emeritus.
- Adding @justinchuby, @titaiwangms, @shubhambhokare1 and @xadupre,
who have all made countless contributions to this project.
- `CODEOWNERS`
- Removing folks no longer working on the project.
- Updating new owners who will now be notified with PRs related to
the specific file paths.
- `merge_rules.yaml`
- Removing folks no longer working on the project.
🫡
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126364
Approved by: https://github.com/titaiwangms, https://github.com/justinchuby, https://github.com/albanD
Related doc: https://docs.google.com/document/d/1BKyizkZPdri9mHqdDOLAUpkI7SbbKfLHRFVVpK9ZWqo/edit
Memory considerations:
- As with the existing SAC, cached values are cleared upon first use.
- We error if the user wishes to backward a second time on a region forwarded with SAC enabled.
In-place:
- We use version counting to enforce that if any cached tensor has been mutated. In-place operations not mutating cached tensors are allowed.
- `allow_cache_entry_mutation=True` can be passed to disable this check (useful in the case of auto AC where the user is cleverly also saves the output of the in-place)
Randomness, views
- Currently in this PR, we don't do anything special for randomness or views, the author of the policy function is expected to handle them properly. (Would it would be beneficial to error? - we either want to save all or recompute all random tensors)
Tensor object preservation
- We guarantee that if a tensor does not requires grad, and it is saved, then what you get out is the same tensor object. If the tensor does require grad, we must detach to avoid creating a reference cycle. This is a nice guarantee for nested tensors which care about the object identity of of the offsets tensor.
Policy function
- Enum values are `{MUST,PREFER}_{SAVE,RECOMPUTE}` (bikeshed welcome). Alternatively there was `{SAVE,RECOMPUTE}_{NON_,}OVERRIDABLE`. The former was preferred bc it seemed clearer that two `MUST` clashing should error, versus it is ambiguous whether two `NON_OVERRIDABLE` being stacked should silently ignore or error.
- The usage of Enum today. There actually is NO API to stack SAC policies today. The only thing the Enum should matter for in the near term is the compiler. The stacking SAC policy would be useful if someone wants to implement something like simple FSDP, but it is not perfect because with a policy of `PREFER_SAVE` you are actually saving more than autograd would save normally (would be fixed with AC v3).
- The number of times we call the policy_fn is something documented part of public API. We call the policy function for all ops except detach because detach is itself called a different number of times by AC between forward and recompute.
- The policy function can be a stateful object (we do NOT make separate copies of this object for forward/recompute, the user is expected to handle that via is_recompute see below).
Tensors guaranteed to be the same tensor as-is
- Policy function signature takes ctx object as its first argument. The ctx function is an object encapsulating info that may be useful to the user, it currently only holds "is_recompute". Adding this indirection gives us flexibility to add more attrs later if necessary.
"bc-breaking" for existing users of the private API:
- Existing policy functions must now change their return value to use the Enum.
- Existing calls to `_pt2_selective_checkpoint_context_fn_gen` must be renamed to `gen_selective_checkpoint_context_fn`. The way you use the API remains the same. It would've been nice to do something different (not make the user have to use functools.partial?), but this was the easiest to compile (idk if this should actually be a constraint).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125795
Approved by: https://github.com/Chillee, https://github.com/fmassa
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.
What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...
Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
In case user modified stage module out of place, such as
mod = DDP(mod)
mod = torch.compile(mod)
They need a stage builder else than `pipe.build_stage()`.
This PR provides an API to do so:
```
def build_stage(
stage_module,
stage_index,
pipe.info(),
...
)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128273
Approved by: https://github.com/wconstab
----
- Bring PipelineStage/Schedule more front-and-center
- provide details on how to manually construct PipelineStage
- move tracer example and manual example below so the high-level flow
(e2e) is closer to the top
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128236
Approved by: https://github.com/H-Huang
ghstack dependencies: #128201, #128228
Changed the API of `pipeline()` to take microbatch instead of full batch as example args.
Main purpose is to:
- make this API more atomic;
- decouple tracing frontend from runtime info like `num_chunks`.
Side effects:
- Creates opportunity for varying `num_chunks` of schedules with the same `pipe` object.
- User has to create example microbatch input.
- Chunk spec stuff are now all moved to runtime side.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128163
Approved by: https://github.com/H-Huang
Renaming ManualPipelineStage to remove the "Manual" part. I needed to replace the existing `PipelineStage` which takes in the `pipe` argument, so I have renamed that to `TracerPipelineStage`. @kwen2501 will remove this entirely in favor of adding a util to `Pipe` to just create the stage directly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128157
Approved by: https://github.com/wconstab
Summary:
Part of the work helping export's automatic dynamic shapes / dynamic shapes refining based on suggested fixes.
Introduces a util function refine_dynamic_shapes_from_suggested_fixes() that takes the error message from a ConstraintViolationError message containing suggested dynamic shapes fixes, along with the original dynamic shapes spec, and returns the new spec. Written so that the suggested fixes from export can be directly parsed and used.
Example usage for the automatic dynamic shapes workflow:
```
# export, fail, parse & refine suggested fixes, re-export
try:
export(model, inps, dynamic_shapes=dynamic_shapes)
except torch._dynamo.exc.UserError as exc:
new_shapes = refine_dynamic_shapes_from_suggested_fixes(exc.msg, dynamic_shapes)
export(model, inps, dynamic_shapes=new_shapes)
```
For examples of behavior, see the added test and docstring. Will take suggestions for renaming the function to something else 😅
Test Plan: test_export tests
Differential Revision: D57409142
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127436
Approved by: https://github.com/avikchaudhuri
Fixes#126367.
## Description
Fixed a broken link in the pytorch/docs/source/torch.compiler_faq.rst doc and deleted a few words that were extra according to the issue tagged above.
## Checklist
- [X] The issue that is being fixed is referred in the description
- [X] Only one issue is addressed in this pull request
- [X] Labels from the issue that this PR is fixing are added to this pull request
- [X] No unnecesary issues are included into this pull request
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127938
Approved by: https://github.com/msaroufim
For a masked `tl.load` operation, the Triton language specifies that values masked out (i.e. where the mask evaluates to false) are undefined in the output of the load. Triton provides an optional `other` parameter which, when included, provides an explicit value to use for masked out values from the load. If the output from a masked load without the `other` parameter is used in a conditional, unexpected behavior can occur.
Despite the language specification, all Triton backends currently in use by PyTorch Inductor (NVIDIA, AMD, and Intel) 0-initialize masked loads if `other` is not present (we recently changed the Intel backend behavior to match NVIDIA and AMD because that's what our users expect, even if we are not following the Triton spec to the tee). This PR attempts to "future-proof" Inductor for new backends (or perhaps changes in the current backends? - we did not see any performance change from 0-initializing in the Intel XPU backend but one could imagine compiler optimizations to remove paths that depend on undefined) to add an explicit `other` in instances where later conditionals depend on the `tl.load` output. I also removed an exception to `other` behavior for boolean loads, which was put in place for a Triton bug that should be fixed. I added `other` to the getting started documentation as a clue that masked load behavior requires explicit initialization if, even though I don't expect `undef` values to cause the example code to fail if the underlying output is not 0-initialized. Finally, I added other to the `make_load` function in `select_algorithm.py`, though I wasn't able to determine if that function was actually being called.
Fixes#126535
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127311
Approved by: https://github.com/jansel