Commit graph

70 commits

Author SHA1 Message Date
Guilherme Leobas
6a9a02acbe Set enable_faithful_generator_behavior flag to True (#142513)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142513
Approved by: https://github.com/zou3519
ghstack dependencies: #141055, #144421, #144422, #144423, #144424, #144420, #145223
2025-02-08 22:42:12 +00:00
Guilherme Leobas
8603a1c870 Suport generators (#141055)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141055
Approved by: https://github.com/zou3519
2025-02-08 22:42:12 +00:00
Ryan Guo
5a4d959cdb [dynamo] Properly model torch profiler context objects (#145537)
Prior to this patch, Dynamo conveniently modelled torch profiler context
objects (e.g., `torch.profiler.profile`) as `NullContextVariable`
because `torch.compile` ignore the effect of these profiler contexts.

However, the semantics of these profiler contexts diverges from
`contextlib.nullcontext` in the `__enter__` function, where the former
returns `self` and the latter returns `None`. This causes subtle error
as observed in #125021.

This patch adds back a `ProfilerContextVariable`, which addresses the
aforementioned semantic discrepency.

Fixes #125021.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145537
Approved by: https://github.com/zou3519, https://github.com/williamwen42
2025-01-28 00:03:36 +00:00
Ryan Guo
66631bc84b [dynamo] Fix read/write conflicts in a cuda test (#145658)
Prior to this patch, the `test_cuda_event_created_outside_of_graph`
is flaky in CI, and that's because we have read and write to the same
`foo` tensor buffer from 2 different streams. This patch eliminates that
by adding a synchronization to wait till read finishes before starting
the write.

Fixes #133837, #133828.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145658
Approved by: https://github.com/yifuwang
2025-01-27 19:55:57 +00:00
Tom Ritchford
f1cbf4b1b5 Enable ruff's unused variable checking everywhere in pytorch (#136965)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136965
Approved by: https://github.com/cyyever, https://github.com/albanD
2024-12-22 02:33:11 +00:00
Guilherme Leobas
673cc88fd6 Add support for contextmanager in Dynamo (#136033)
Fixes #130559

* Intro

This PR adds support for `@contextmanager` in Dynamo. We chose to limit the
scope of this work to only `@contextmanager` and plan to handle generators fully
in #141055 (still in draft).

* Motivation

Dynamo lacks support for generator functions. When it encounters one, it traces
it as if it were a regular function. This is problematic because it can lead to
incorrect behavior. To illustrate, consider the test case below:

```python
import torch
import contextlib

@contextlib.contextmanager
def set_default_dtype(dtype):
    old_dtype = torch.get_default_dtype()
    try:
        torch.set_default_dtype(dtype)
        yield
    finally:
        torch.set_default_dtype(old_dtype)

@torch.compile(backend="eager", fullgraph=True)
def fn():
    with set_default_dtype(torch.float64):
        x = torch.tensor([3.0, 3.0 + 5.0j])
    return x
```

Before this work, Dynamo would not stop at the `yield`, and the graph produced
would contain both calls to `set_default_dtype` executed one after the other.
This is incorrect because the context manager should execute code before and
after the `yield`.

* List of changes

`YIELD_VALUE` now raises an exception (`YieldValueOp`) to signal that control
flow must be suspended and returned to the caller. Additionally, `RETURN_VALUE`
behaves differently in a generator function. Unlike regular functions, where
`RETURN_VALUE` indicates the final result, in generators it signifies that the
generator is exhausted and implicitly raises `StopIteration`.

A new `VariableTracker` named `FunctionDecoratedByContextlibContextManagerVariable`
was introduced to handle `@contextmanager`. This variable tracker acts not just
as a wrapper for the original function but also maintains an internal `tx`
(InstructionTranslator) object to suspend and return control flow to the parent
tracer when a `yield` is encountered.

* Corner cases

Returning a context manager from a compiled function is not supported. This
would require PyTorch to synchronize the generator state between Dynamo and the
interpreter. Any attempt to return it will result in an `IncorrectUsage`
exception.

Graph breaks require special handling as well. In the event of a graph break,
the frame associated with the context manager is skipped, and the context
manager runs in eager mode.

* This PR is breaking my code

There is a configuration flag (`enable_trace_contextlib`) that can be set to
`False` to disable tracing context managers. If this still causes crashes,
please revert this PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136033
Approved by: https://github.com/zou3519
2024-12-20 12:02:20 +00:00
Tom Ritchford
d25e6e623f Fix unused Python variables in test/[a-d]* (#134665)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134665
Approved by: https://github.com/albanD
2024-12-13 22:13:12 +00:00
Yuanhao Ji
16bc82a015 [Dynamo] Replace torch._dynamo.optimize() with torch.compile() [6/N] (#140688)
related commits:

- #139706
- #140238
- #140247
- #140253
- #140663
- #140688

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140688
Approved by: https://github.com/williamwen42
2024-11-18 04:09:09 +00:00
FFFrog
e14b58ffbd Using device-agnostic autocast api (#136613)
- using torch.autocast(device_str="cuda") instead of torch.cuda.amp.autocast()
- using torch.autocast(device_str="cpu") instead of torch.cpu.amp.autocast()

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136613
Approved by: https://github.com/shink, https://github.com/cyyever, https://github.com/kwen2501
2024-09-27 07:16:24 +00:00
Justin Chu
58274e4655 Remove onnx imports in dynamo (#136334)
Remove imports of the ``torch.onnx.operators`` module in dynamo. Since ONNX depends on dynamo, this import line causes a circular dependency. Judging from the source they are not actually needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136334
Approved by: https://github.com/xadupre, https://github.com/jansel, https://github.com/titaiwangms
2024-09-24 14:54:23 +00:00
William Wen
63d6cd351a [dynamo] support torch.nn.attention.sdpa_kernel context manager (#135404)
Fixes https://github.com/pytorch/pytorch/issues/134608

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135404
Approved by: https://github.com/jansel, https://github.com/drisspg
2024-09-12 22:04:48 +00:00
Will Feng
f57b00704e [Traceable FSDP2][Dynamo] Support reconstructing CUDA event object within Dynamo graph (#133635)
`torch.cuda.Event` objects are different from `torch.cuda.Stream` in that events are not pooled, meaning we can't look up a previously created CUDA event object by ID. This prevents CUDA event object created outside of the Dynamo graph from being used within the graph (since Dynamo needs a way to emit a `call_function` line in the graph that does the retrieval of the event object for downstream op use). This PR adds a simple object pool within Dynamo utility, to support looking up CUDA event object by ID from within the Dynamo graph.

After this PR, if a user creates a CUDA event object outside of the graph and use that event within the graph, the behavior will exactly match eager.

Test commands:
- `pytest -rA test/dynamo/test_ctx_manager.py::CtxManagerTests::test_cuda_event_created_outside_of_graph`
- `pytest -rA test/dynamo/test_ctx_manager.py::CtxManagerTests::test_cuda_event_across_graph_break`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133635
Approved by: https://github.com/yifuwang
ghstack dependencies: #133532, #133531, #133636
2024-08-16 20:40:46 +00:00
Yanbo Liang
770086fe39 [Dynamo] Support torch.cuda.device ctx manager (#133385)
Fixes #128059

I'm not sure if this is the right way, since Inductor doesn't always respect the device id set by users, so probably we should just wrap it as null context manager and print a warning. cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @jansel @anijain2305 @mlazos @williamwen42

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133385
Approved by: https://github.com/jansel
2024-08-16 17:05:55 +00:00
Will Feng
1206958d89 [Dynamo] add EventVariable reconstruct (#133236)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133236
Approved by: https://github.com/yifuwang
2024-08-14 02:56:11 +00:00
YangQun1
589aef4bb0 Fix py codegen to delete values that don't have any users (#131028)
Fixes #131025

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131028
Approved by: https://github.com/ezyang
2024-08-01 03:18:37 +00:00
Guilherme Leobas
1e9cdf7d91 Relax constraints for creating a GenericContextWrappingVariable (#129091)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129091
Approved by: https://github.com/yanboliang, https://github.com/zou3519
2024-07-29 15:40:59 +00:00
Xuehai Pan
918ece4f4d [BE][Easy][11/19] enforce style for empty lines in import segments in test/dy*/ (#129762)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129762
Approved by: https://github.com/anijain2305
2024-07-27 17:43:53 +00:00
PyTorch MergeBot
c3679bed35 Revert "Fix py codegen to delete values that don't have any users (#131028)"
This reverts commit 91aba7baac.

Reverted https://github.com/pytorch/pytorch/pull/131028 on behalf of https://github.com/clee2000 due to broke inductor/test_triton_kernels inductor/test_triton_kernels.py::KernelTests::test_triton_kernel_functionalize [GH job link](https://github.com/pytorch/pytorch/actions/runs/10094659640/job/27915271250) [HUD commit link](91aba7baac) ([comment](https://github.com/pytorch/pytorch/pull/131028#issuecomment-2251058374))
2024-07-25 17:42:18 +00:00
YangQun1
91aba7baac Fix py codegen to delete values that don't have any users (#131028)
Fixes #131025

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131028
Approved by: https://github.com/ezyang
2024-07-25 13:04:23 +00:00
PyTorch MergeBot
8ffd109a00 Revert "Fix py codegen to delete values that don't have any users (#131028)"
This reverts commit 466c167b71.

Reverted https://github.com/pytorch/pytorch/pull/131028 on behalf of https://github.com/atalman due to breaks CI ([comment](https://github.com/pytorch/pytorch/pull/131028#issuecomment-2247771530))
2024-07-24 12:21:43 +00:00
YangQun1
466c167b71 Fix py codegen to delete values that don't have any users (#131028)
Fixes #131025

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131028
Approved by: https://github.com/ezyang
2024-07-24 01:03:56 +00:00
PyTorch MergeBot
0b134c15cd Revert "Relax constraints for creating a GenericContextWrappingVariable (#129091)"
This reverts commit 882fd91869.

Reverted https://github.com/pytorch/pytorch/pull/129091 on behalf of https://github.com/clee2000 due to test_jit started failing on main after this stack https://github.com/pytorch/pytorch/actions/runs/9980754603/job/27583474357 a8bd2933d9 ([comment](https://github.com/pytorch/pytorch/pull/129091#issuecomment-2234269541))
2024-07-17 20:59:40 +00:00
Guilherme Leobas
882fd91869 Relax constraints for creating a GenericContextWrappingVariable (#129091)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129091
Approved by: https://github.com/yanboliang, https://github.com/zou3519
2024-07-17 20:07:06 +00:00
Xuehai Pan
a7c596870d [BE][Eazy] remove torch.torch.xxx usages (#127800)
NB: `torch` is exposed in `torch/__init__.py`. So there can be `torch.torch.torch.xxx`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127800
Approved by: https://github.com/peterbell10, https://github.com/kit1980, https://github.com/malfet
2024-06-05 21:53:49 +00:00
William Wen
255a3afbf1 [dynamo] don't LOAD_FAST local context variables in modified bytecode (#125719)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125719
Approved by: https://github.com/jansel
2024-05-08 20:39:06 +00:00
Edward Z. Yang
ecd62746e3 Also pull size/stride info from example_value (#125505)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125505
Approved by: https://github.com/jansel
2024-05-05 22:27:46 +00:00
William Wen
1a20b4ef3f [dynamo] handle inactive nullcontexts across graph breaks (#125518)
whoops

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125518
Approved by: https://github.com/yanboliang
2024-05-04 12:52:20 +00:00
William Wen
f2ab96a57e [dynamo] fix crash when context manager is passed to a function (#125321)
Fix https://github.com/pytorch/pytorch/issues/125274. Main change was to reconstruct `ContextWrappingVariables` as objects in general, but we can replace them with the class on the caller side when generating the resume function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125321
Approved by: https://github.com/jansel
2024-05-03 23:01:30 +00:00
William Wen
0506e95433 [dynamo] support inactive context managers across graph breaks (#125203)
Fix https://github.com/pytorch/pytorch/issues/124900.

When we reconstruct `ContextWrappingVariables`s, we only reconstruct the context class, not the object. Normally, contexts are active (via `with ctx:`) and we initialize the context object in the resume function. But for the case of inactive contexts (contexts declared ahead of time before the `with` block), we do not reconstruct them properly in the optimized bytecode or resume function. So this PR adds initialization for inactive contexts in the resume function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125203
Approved by: https://github.com/jansel
2024-05-01 01:49:09 +00:00
Will Feng
7b02910163 [Compile FSDP2][2/n] Support streams created outside of compile region (#123487)
FSDP2 creates CUDA streams outside of compile region in its 1st iteration eager run, and then torch.compile will attempt to record method calls on these streams (e.g. `stream.record_event()`) in >1st iteration compiled run.

Before this PR, stream proxy is None which causes "None doesn't have attribute record_event" error when we try to call `record_event()` on it. After this PR, stream proxy has the correct value which makes calling methods on it possible.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123487
Approved by: https://github.com/jansel
2024-04-06 08:42:42 +00:00
ydwu4
5410385c42 [dynamo] support comparing stream with constant (#119199)
Before the pr, we have a graph break for:
```python
def f():
    if torch.cuda.current_stream() is not None:
        return torch.randn(2, 2)
torch.compile(f, backend="eager", fullgraph=True)()
```
This pr supports comparson ops of StreamVariable and ConstantVariable by returning a constant.

It's safe to return a constant in this case becuase the StreamVariable is guarded by ID_MATCH when created.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119199
Approved by: https://github.com/yifuwang, https://github.com/anijain2305, https://github.com/jansel
2024-02-06 19:26:03 +00:00
rzou
728789d850 Deflake stream tests, part 2 (#118391)
I missed these the first time around, some more streams need to be
synchronized.

Fixes https://github.com/pytorch/pytorch/issues/112694

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118391
Approved by: https://github.com/ydwu4, https://github.com/yanboliang
2024-01-26 19:10:53 +00:00
rzou
61865205b6 Deflake Dynamo stream tests (#118205)
streams need to be synchronized, otherwise, there is undefined behavior.
This PR adds the necessary synchronization. This exposed some bugs
(https://github.com/pytorch/pytorch/issues/118204), so I just marked the
tests as expectedFailure.

Test Plan:
- tested locally

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118205
Approved by: https://github.com/yanboliang
2024-01-24 23:31:47 +00:00
voznesenskym
fed45aee54 Replace invoking self.value if there is a user defined init, avoiding arbitrary code execution (#117818)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117818
Approved by: https://github.com/ezyang
2024-01-23 03:14:58 +00:00
voznesenskym
f008efa8e7 Reconstruct streams via global registration, temporary impl to unblock FSDP (#117386)
This is a placeholder implementation for reconstructing streams via global storage to unblock FSDP, pending proper stream support design

This PR does a few things:

1) fixes registration for devices with indices. We were only supporting "cuda", we now support "cuda:k" interfaces where k is # of gpu

2) Changes the stream objects in dynamo to take devices as device types, instead of strings, and updates the string based device APIs to gracefully take device types.

3) Introduces a reconstruct-by-global (using existing cleanup hook structures) to streams as a placeholder impl for now

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117386
Approved by: https://github.com/jansel
2024-01-13 07:03:33 +00:00
Xinya Zhang
e3ca7346ce Re-add initial Flash Attention support on ROCM (#115981)
Note about the Updates:

This PR:
1. skips more flash attention related UTs on MI200
2. Fix additional ATen compiling errors after hipification
3. Fix the author "root" of a specific commit
4. Includes the patch from Nikita in favor of block level static initialization.

CAVEAT: This revised PR has a commit that modifies the CI to force its running on MI200 nodes. That specific commit must be reverted before merge.

Original PR (https://github.com/pytorch/pytorch/pull/114309) Note:

This pull requests add initial Flash Attention support for AMD/ROCM platform. It added a specialized Triton repository/branch as a compile-time dependency for Flash Attention math library on AMD/ROCM. This triton submodule is not used at runtime and will not be shipped to the final pytorch package. We have the plan to release this specialized Triton as a separate project.

Know limitations:

- Only supports MI200 series GPU (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`.
- Only supports power of two sequence lengths.
- No support for varlen APIs.
- Only support head dimension 16,32,64,128.
- Performance is still being optimized.

Fixes #112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115981
Approved by: https://github.com/malfet
2024-01-04 22:21:31 +00:00
Yanbo Liang
7f40640342 [Dynamo] Support torch.amp.autocast as decorator (#114845)
Fixes #114818

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114845
Approved by: https://github.com/jansel
2023-11-30 23:54:57 +00:00
Jon Chuang
5ccd22502f [contextlib] Wrapping a function with set_grad_enabled will consume its global mutation (#113359)
Fixes https://github.com/pytorch/pytorch/issues/113298

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113359
Approved by: https://github.com/soulitzer, https://github.com/jansel
2023-11-09 19:16:20 +00:00
Jon Chuang
0093e23e52 [dynamo] GradModeVariable should only be eagerly initialized when doing the equivalent of set_grad_enabled (#113293)
Grad mode variable was previously initialized eagerly when called - which is wrong when not explicitly using it in `set_grad_enabled`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113293
Approved by: https://github.com/jansel
2023-11-09 06:00:14 +00:00
Jon Chuang
5af97fedd2 [dynamo] Fix context wrapping grad mode variable (#111534)
Fixes https://github.com/pytorch/pytorch/issues/111528

Makes use of `ContextWrappingVariable` so that the function will enter the grad mode whenever it is called, and exit once it is done calling.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111534
Approved by: https://github.com/jansel
2023-10-22 20:55:48 +00:00
Chen, Zejun
8e60d646b9 [dynamo][stream]support device-agnostic stream in dynamo and capture stream/event method in fx graph (#108312)
This PR implements 2 things:
1. support the device agnostic stream and runtime APIs captured by the dynamo.
2. support the stream methods(include the event) captured by the dynamo.

Here are details for 1st.
Previously the stream captured in dynamo was tightly bind to CUDA. Here we implement a global singleton container named `StreamMethodContainer` for different backends to register their associated stream methods to dynamo. When import the backend’s product, the stream operations can be registered directly by calling

```
device_stream_method = {'current_stream': method_1,
                         'create_stream_context': method_2,
                         'set_stream': method_3,
                         'set_stream_by_id': method_4}
torch._dynamo.stream.register_stream_method(device_name, device_stream_method)
```

Stream methods need to be passed in this API according to the precise semantics represented by the dict key in `device_stream_method`. After register, these methods can be used by dynamo to capture the stream operations in users’ script, for example, get the current stream or set the specific stream. Additionally, the wrapped stream variable and the stream context variable are changed to be the device-agnostic, the proxy functions of these variables are assigned by the associated methods in the container. All of this are illustrated in the below. Below is a illustration.

![image](https://github.com/pytorch/pytorch/assets/74231238/37ac7350-c539-4167-9886-c3744ecab65d)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108312
Approved by: https://github.com/jansel, https://github.com/jgong5
2023-10-22 13:22:58 +00:00
drisspg
ad90ab31f2 Flash Attention v2 (#105602)
# Summary
## PR Dependencies
I don't use ghstack :( this is a PR where it would have been helpful. That beings said I am going to peel off some PRs to make reviewing this easier:
- [x] Separate build flags for Flash and MemEff: #107985

### Description
This pull request updates the version of _scaled_dot_product_flash_attention from version 1 to version 2. The changes are based on the flash attention code originally authored by @tridao

### Changes Made
The majority of the changes in this pull request involve:

- Copying over the flash_attention sources.
- Updating header files.
- Removing padding and slicing code from within the flash_attention kernel and relocating it to the composite implicit region of the SDPA. This was need to make the kernel functional and appease autograd.
- Introducing a simple kernel generator to generate different instantiations of the forward and backward flash templates.
- Adding conditional compilation (ifdef) to prevent building when nvcc is invoked with gencode < sm80.
- Introducing a separate dependent option for mem_eff_attention, as flash_attention v2 lacks support for Windows and cannot be built for sm50 generation codes.
- Modifying build.sh to reduce parallelization on sm86 runners and to lower the maximum parallelization on the manywheel builds. This adjustment was made to address out-of-memory issues during the compilation of FlashAttentionV2 sources.
- Adding/Updating tests.

### Notes for Reviewers
This is not a fun review, and I apologize in advance.
Most of the files-changed are in the flash_attn/ folder. The only files of interest here IMO:
- aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
- aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py ( this has been incorporated upstream to flash-attention github)

There are a number of files all related to avoiding OOMs in CI/CD. These are typically shell scripts.

### Follow up items
- Include the updates from e07aa036db and 9e5e8bc91e | https://github.com/pytorch/pytorch/issues/108108

### Work Items
- [x] I don't think Windows will be supported for 3.1.0 - Need to update cmakee
- [x] Let multi_query/attention pass through and test | UPDATE: I have the fast path implemented here: https://github.com/pytorch/pytorch/pull/106730 but since this will require changes to semantics of math to call repeat_interleave, I think this should be done as a followup.
- [x] Had to drop cutlass back to 3.0.0 to get it to compile. Need to figure out how to upgrade to 3.1.0 and later. Spoke with Tri and he is going to be taking a look. Note: compiling with clang currently errors for the cute headers.
- [x] Update test exercise above codepath
- [x] Still need to disable on seq_len % 128 != 0 for backward( Tri beat me to it a4f148b6ab)
- [x] Add determinism warning to BWD, Tri got to this one as well: 1c41d2b
- [x] Update dispatcher to universally prefer FlashV2
- [x] Update tests to exercise new head_dims
- [x] Move the head_dim padding from kernel to top level composite implicit function in order to make it purely functional
- [x] Create template generator script
- [x] Initial cmake support for building kernels/ folder
- [x] Replay CudaGraph changes

### Results
#### Forward only
The TFlops are reported here are on a100 that is underclocked.
![flashv2_tflops_vs_seq_len](https://github.com/pytorch/pytorch/assets/32754868/152de46d-8fa6-42f0-9a9c-ef1eb7ae29e7)

#### Forward+Backward
Ran a sweep and for large compute bound sizes we do see a ~2x performance increase for forw+back.
<img width="1684" alt="Screenshot 2023-07-20 at 3 47 47 PM" src="https://github.com/pytorch/pytorch/assets/32754868/fdd26e07-0077-4878-a417-f3a418b6fb3b">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105602
Approved by: https://github.com/huydhn, https://github.com/cpuhrsch
2023-09-13 13:59:05 +00:00
Michael Voznesensky
de0b18fad9 Use user directed names for variables where possible (#109092)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109092
Approved by: https://github.com/ezyang
ghstack dependencies: #108846
2023-09-13 07:44:04 +00:00
Huy Do
a9c663c269 Revert "Flash Attention v2 (#105602)" (#108827)
This reverts commit add45aea1c.

There are some conflicts on some benchmark csv file https://github.com/pytorch/pytorch/pull/105602#issuecomment-1710988951 so I need to revert this manually.

The diff has been reverted internally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108827
Approved by: https://github.com/kit1980
2023-09-08 07:43:04 +00:00
PyTorch MergeBot
e45b290127 Revert "Revert "Flash Attention v2 (#105602)" (#108827)"
This reverts commit 24e9bbe22a.

Reverted https://github.com/pytorch/pytorch/pull/108827 on behalf of https://github.com/huydhn due to I need to land this revert properly as there are new failures showing up on trunk ([comment](https://github.com/pytorch/pytorch/pull/108827#issuecomment-1711020924))
2023-09-08 03:25:45 +00:00
Huy Do
24e9bbe22a Revert "Flash Attention v2 (#105602)" (#108827)
This reverts commit add45aea1c.

There are some conflicts on some benchmark csv file https://github.com/pytorch/pytorch/pull/105602#issuecomment-1710988951 so I need to revert this manually.

The diff has been reverted internally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108827
Approved by: https://github.com/kit1980
2023-09-08 02:54:20 +00:00
drisspg
add45aea1c Flash Attention v2 (#105602)
# Summary
## PR Dependencies
I don't use ghstack :( this is a PR where it would have been helpful. That beings said I am going to peel off some PRs to make reviewing this easier:
- [x] Separate build flags for Flash and MemEff: #107985

### Description
This pull request updates the version of _scaled_dot_product_flash_attention from version 1 to version 2. The changes are based on the flash attention code originally authored by @tridao

### Changes Made
The majority of the changes in this pull request involve:

- Copying over the flash_attention sources.
- Updating header files.
- Removing padding and slicing code from within the flash_attention kernel and relocating it to the composite implicit region of the SDPA. This was need to make the kernel functional and appease autograd.
- Introducing a simple kernel generator to generate different instantiations of the forward and backward flash templates.
- Adding conditional compilation (ifdef) to prevent building when nvcc is invoked with gencode < sm80.
- Introducing a separate dependent option for mem_eff_attention, as flash_attention v2 lacks support for Windows and cannot be built for sm50 generation codes.
- Modifying build.sh to reduce parallelization on sm86 runners and to lower the maximum parallelization on the manywheel builds. This adjustment was made to address out-of-memory issues during the compilation of FlashAttentionV2 sources.
- Adding/Updating tests.

### Notes for Reviewers
This is not a fun review, and I apologize in advance.
Most of the files-changed are in the flash_attn/ folder. The only files of interest here IMO:
- aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
- aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py ( this has been incorporated upstream to flash-attention github)

There are a number of files all related to avoiding OOMs in CI/CD. These are typically shell scripts.

### Follow up items
- Include the updates from e07aa036db and 9e5e8bc91e | https://github.com/pytorch/pytorch/issues/108108

### Work Items
- [x] I don't think Windows will be supported for 3.1.0 - Need to update cmakee
- [x] Let multi_query/attention pass through and test | UPDATE: I have the fast path implemented here: https://github.com/pytorch/pytorch/pull/106730 but since this will require changes to semantics of math to call repeat_interleave, I think this should be done as a followup.
- [x] Had to drop cutlass back to 3.0.0 to get it to compile. Need to figure out how to upgrade to 3.1.0 and later. Spoke with Tri and he is going to be taking a look. Note: compiling with clang currently errors for the cute headers.
- [x] Update test exercise above codepath
- [x] Still need to disable on seq_len % 128 != 0 for backward( Tri beat me to it a4f148b6ab)
- [x] Add determinism warning to BWD, Tri got to this one as well: 1c41d2b
- [x] Update dispatcher to universally prefer FlashV2
- [x] Update tests to exercise new head_dims
- [x] Move the head_dim padding from kernel to top level composite implicit function in order to make it purely functional
- [x] Create template generator script
- [x] Initial cmake support for building kernels/ folder
- [x] Replay CudaGraph changes

### Results
#### Forward only
The TFlops are reported here are on a100 that is underclocked.
![flashv2_tflops_vs_seq_len](https://github.com/pytorch/pytorch/assets/32754868/152de46d-8fa6-42f0-9a9c-ef1eb7ae29e7)

#### Forward+Backward
Ran a sweep and for large compute bound sizes we do see a ~2x performance increase for forw+back.
<img width="1684" alt="Screenshot 2023-07-20 at 3 47 47 PM" src="https://github.com/pytorch/pytorch/assets/32754868/fdd26e07-0077-4878-a417-f3a418b6fb3b">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105602
Approved by: https://github.com/huydhn, https://github.com/cpuhrsch
2023-09-01 22:14:44 +00:00
kshitij12345
11602ac564 [dynamo] fix disable_saved_tensors_hooks - graph break (#106875)
```python
def wrapper_fn(x):
    with torch.autograd.graph.disable_saved_tensors_hooks("ERROR"):
        y = x + 1
        print("HI")
        return y + 2

x = torch.randn(())

a = wrapper_fn(x)
opt = torch.compile(wrapper_fn, backend='eager', fullgraph=False)
e = opt(x)
```

Without the fix fails with,
```
Traceback (most recent call last):
  File "/home/kshiteej/Pytorch/pytorch_functorch/test/test_trace_grad.py", line 182, in <module>
    e = opt(x)
  File "/home/kshiteej/Pytorch/pytorch_functorch/torch/_dynamo/eval_frame.py", line 333, in _fn
    return fn(*args, **kwargs)
  File "/home/kshiteej/Pytorch/pytorch_functorch/test/test_trace_grad.py", line 165, in wrapper_fn
    def wrapper_fn(x):
AttributeError: module 'torch.autograd.graph' has no attribute 'disable_saved_tensors_hook'
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106875
Approved by: https://github.com/zou3519
2023-08-19 11:41:40 +00:00
Edward Z. Yang
7b9d250f06 Change _dynamo.export to be export(f)(*args, **kwargs) (#106109)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106109
Approved by: https://github.com/voznesenskym
2023-07-27 21:41:13 +00:00
kshitij12345
920b446da9 dynamo: support disable_saved_tensors_hooks (#104869)
Functorch transforms use this context manager which will lead to graph-breaks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104869
Approved by: https://github.com/zou3519
2023-07-26 07:27:37 +00:00