Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00
# Summary
This PR adds an optional kwarg to `torch.nn.functional.scaled_dot_product_attention()`. The new kwarg is a scaling factor that is applied after the `q@k.T` step of the computation. I updated the efficient kernel to support it, and the flash and math implementations received minimal updates to support it as well.

This will reduce the complexity of #94729, and the feature has been requested by a couple of users.

# Review Highlights
- As far as I know, I did this the correct way and it is both BC and FC compliant. However, I always seem to break internal workloads, so I would love it if someone could advise whether I did this right.
- I named the optional arg `scale`. This is probably dumb, and I should name it `scale_factor`. Renaming is annoying, so I will only make this change if someone thinks we should rename it.
- `scale` is interpreted as `Q@K.T * (scale)`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95259
Approved by: https://github.com/cpuhrsch
| Name |
|---|
| conda |
| src |
| tools |
| CMakeLists.txt |
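The `scale` semantics in the commit message above (`Q@K.T * (scale)`, applied before the softmax) can be checked against a plain-Python reference of the math path. This is a dependency-free sketch, not PyTorch's implementation: `reference_sdpa`, `matmul`, `transpose`, and `softmax_rows` are hypothetical helper names, and falling back to `1/sqrt(d_k)` when `scale` is `None` is an assumption of the sketch (the usual scaled dot product attention convention). In PyTorch itself, the value would be passed as the `scale=` keyword that this PR adds to `torch.nn.functional.scaled_dot_product_attention()`.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain nested-list matrix product a @ b."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m: Matrix) -> Matrix:
    """Return m.T as a nested list."""
    return [list(col) for col in zip(*m)]


def softmax_rows(m: Matrix) -> Matrix:
    """Numerically stable softmax applied independently to each row."""
    out = []
    for row in m:
        peak = max(row)
        exps = [math.exp(x - peak) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def reference_sdpa(q: Matrix, k: Matrix, v: Matrix,
                   scale: Optional[float] = None) -> Matrix:
    """Hypothetical math-only sketch: softmax((Q @ K.T) * scale) @ V.

    Assumption: scale=None falls back to 1/sqrt(d_k). The `is None` check
    matters, because an explicit scale=0.0 must not trigger the fallback.
    """
    d_k = len(q[0])
    if scale is None:
        scale = 1.0 / math.sqrt(d_k)
    scores = matmul(q, transpose(k))                    # Q @ K.T
    scaled = [[s * scale for s in row] for row in scores]  # ... * (scale)
    return matmul(softmax_rows(scaled), v)


q = [[1.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[10.0], [20.0]]
print(reference_sdpa(q, k, v, scale=0.0))  # → [[15.0]]
```

With `scale=0.0` every score becomes zero, so both keys get equal weight and the output is the plain average of the values; a large `scale` instead sharpens the softmax toward the best-matching key.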