I'm sick of reductions not working properly - spotty dim coverage, missing backwards, etc. This PR fixes quite a bit. It applies to the following ops:

* `sum` / `mean` / `prod`
* `all` / `any`
* `amin` / `amax`
* `min` / `max`
* `argmin` / `argmax`

The general reduction logic has been factored out into a helper `_apply_reduction(func, func_name, identity_element, *args, **kwargs)`. The idea is that by providing a valid identity element, we can utilize conversions to padded dense when needed for reducing over the ragged dim.

Extensive test coverage includes:

* reductions across the ragged dim
* reductions across non-batch, non-ragged dims
* reductions across both batch and ragged dims
* multiple-dim reductions (for ops that support this)
* full reduction -> scalar

Bonus: the PR includes backwards fixes for `sum` and `mean`, which have never worked.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139317
Approved by: https://github.com/cpuhrsch
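To illustrate the identity-element trick: padding ragged rows out to a rectangular shape with the op's identity element leaves each row's reduction result unchanged, which is what makes the padded-dense fallback correct. The sketch below is a hypothetical pure-Python analogue of the PR's `_apply_reduction` helper (the real helper operates on nested tensors and dispatches through ATen; the name and signature here are simplified for illustration).

```python
import math

def apply_reduction_sketch(func, identity_element, ragged_rows):
    # Pad every ragged row to the max row length using the identity
    # element, then reduce each (now equal-length) row with func.
    # Because identity ⊕ x == x for the op, padding does not change
    # the per-row result.
    max_len = max(len(row) for row in ragged_rows)
    padded = [row + [identity_element] * (max_len - len(row))
              for row in ragged_rows]
    return [func(row) for row in padded]

rows = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]
print(apply_reduction_sketch(sum, 0.0, rows))        # sum: identity is 0
print(apply_reduction_sketch(math.prod, 1.0, rows))  # prod: identity is 1
print(apply_reduction_sketch(max, -math.inf, rows))  # amax: identity is -inf
```

Note how each op needs its own identity (`0` for `sum`, `1` for `prod`, `-inf` for `amax`); supplying the wrong one would silently corrupt results for short rows, which is why the helper takes `identity_element` explicitly.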