Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00
Part of my BE project addressing NJT (nested jagged tensor) bugs surfaced via OpInfo tests. Before this PR, `frexp()` for NJT was handled by the unary pointwise fallback. However, the op returns a tuple, and the fallback doesn't handle that. This PR defines an explicit impl for `frexp()` that wraps both returned outputs, `(mantissa, exponent)`, as NJTs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144585
Approved by: https://github.com/soulitzer
ghstack dependencies: #144582, #144583, #144584
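The difference between the pointwise fallback and an explicit rule for a tuple-returning op can be sketched in plain Python. This is a toy model, not PyTorch's implementation. `ToyJagged`, `unary_pointwise_fallback`, and `frexp_impl` are hypothetical names. The jagged tensor is assumed to be a flat `values` buffer plus an `offsets` list that marks row boundaries, and `math.frexp` stands in for the per-element `frexp()` kernel.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass
class ToyJagged:
    """Hypothetical stand-in for an NJT: a flat buffer plus row offsets."""
    values: list          # flat per-element data for all rows
    offsets: List[int]    # row i spans values[offsets[i]:offsets[i + 1]]

    def rows(self) -> List[list]:
        return [self.values[s:e] for s, e in zip(self.offsets, self.offsets[1:])]


def unary_pointwise_fallback(op: Callable, x: ToyJagged) -> ToyJagged:
    # Apply op to every element of the flat buffer and re-wrap the result
    # with the input's offsets. Assumes op returns exactly one value.
    out = [op(v) for v in x.values]
    if any(isinstance(o, tuple) for o in out):
        # Toy behavior: reject multi-output ops so the limitation is visible
        # (the real fallback's failure mode may differ).
        raise TypeError("pointwise fallback expects one output per element")
    return ToyJagged(out, list(x.offsets))


def frexp_impl(x: ToyJagged) -> Tuple[ToyJagged, ToyJagged]:
    # Explicit rule for a tuple-returning op: split the per-element pairs
    # into two buffers and wrap EACH one with the input's offsets.
    pairs = [math.frexp(v) for v in x.values]
    mantissa = ToyJagged([m for m, _ in pairs], list(x.offsets))
    exponent = ToyJagged([e for _, e in pairs], list(x.offsets))
    return mantissa, exponent


if __name__ == "__main__":
    nt = ToyJagged([8.0, 0.5, 3.0], [0, 2, 3])
    mantissa, exponent = frexp_impl(nt)
    print(mantissa.rows(), exponent.rows())
```

For example, `frexp_impl(ToyJagged([8.0, 0.5, 3.0], [0, 2, 3]))` yields mantissa rows `[[0.5, 0.5], [0.75]]` and exponent rows `[[4, 0], [2]]`. Both outputs share the input's offsets. The toy fallback, by contrast, rejects `math.frexp` because it produces a tuple per element.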
| Name |
|---|
| __init__.py |
| nested_int.py |
| nested_tensor.py |
| ops.py |
| sdpa.py |