pytorch/aten
Driss Guessous 11aab72dc9 [SDPA] Add an optional scale kwarg (#95259)
# Summary
This PR adds an optional kwarg to torch.nn.functional.scaled_dot_product_attention().
The new kwarg is a scaling factor that is applied after the q@k.T step of the computation. The efficient kernel was updated to support it, and the flash and math kernels were minimally updated to support it as well.

This will reduce the complexity of #94729 and has been requested by a couple of users.

# Review Highlights
- As far as I know I did this the correct way and this is both BC and FC compliant. However, I always seem to break internal workloads, so I would love it if someone could advise whether I did this right.
- I named the optional arg 'scale'. This is probably dumb and I should name it 'scale_factor'. I will make this change, but it is annoying, so it will require someone thinking we should rename.
- 'scale' is interpreted as `Q@K.T * (scale)`
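
To illustrate how 'scale' enters the computation, here is a minimal plain-Python sketch of the math path: `softmax((Q@K.T) * scale) @ V` on small 2-D lists. It is not the code from this PR or any of the kernels; the function name `sdpa_reference` is hypothetical, and the fallback of `1/sqrt(E)` (with `E` the last dimension of the query) when `scale` is omitted is assumed here to mirror the behavior of scaled_dot_product_attention.

```python
import math


def sdpa_reference(q, k, v, scale=None):
    """Hypothetical reference: softmax((q @ k^T) * scale) @ v on 2-D lists.

    q: L x E, k: S x E, v: S x Ev (lists of lists of floats).
    """
    if scale is None:
        # Assumed default: 1 / sqrt(E), E = embedding size of the query.
        scale = 1.0 / math.sqrt(len(q[0]))

    out = []
    for q_row in q:
        # The scale is applied right after the q @ k^T step.
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]

        # Numerically stable softmax over the key dimension.
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]

        # Weighted sum of the value rows.
        out.append(
            [sum(w * v_row[j] for w, v_row in zip(weights, v)) for j in range(len(v[0]))]
        )
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]

default_out = sdpa_reference(q, k, v)            # uses 1/sqrt(2)
flat_out = sdpa_reference(q, k, v, scale=0.0)    # uniform attention weights
sharp_out = sdpa_reference(q, k, v, scale=100.0) # nearly one-hot attention
```

A scale of 0 makes every score equal, so each output row is the mean of the value rows; a large scale sharpens the softmax toward the best-matching key. With a PyTorch build that includes this change, the corresponding call would presumably be `torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=...)` on tensors.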

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95259
Approved by: https://github.com/cpuhrsch
2023-03-08 18:07:40 +00:00
..
conda
src [SDPA] Add an optional scale kwarg (#95259) 2023-03-08 18:07:40 +00:00
tools
CMakeLists.txt Remove non-existing third_party/catch from CMake (#95420) 2023-02-24 08:00:07 +00:00