pytorch/torch/nested/_internal
Joel Schlosser b63b81410c Fix NJT frexp() to handle both outputs (#144585)
Part of my BE project addressing NJT bugs surfaced via OpInfo tests.

Before this PR, `frexp()` for NJT was handled via the unary pointwise fallback. However, the op returns a tuple, which the fallback doesn't handle. This PR defines an explicit impl for `frexp()` that wraps both returned tensors `(mantissa, exponent)` as NJTs.
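The difference between the fallback path and an explicit impl can be sketched in plain Python. This is not the `ops.py` code. `JaggedSketch`, `buffer_frexp`, `unary_pointwise_fallback`, and `frexp_impl` are hypothetical names. The sketch assumes an NJT can be modeled as a flat values buffer plus an offsets list, and it uses `math.frexp` as a stand-in for `torch.frexp`.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass
class JaggedSketch:
    """Hypothetical stand-in for an NJT: a flat values buffer plus offsets
    marking where each ragged component starts and ends."""
    values: list
    offsets: List[int]

    def unbind(self) -> List[list]:
        # Split the flat buffer back into its ragged components.
        return [self.values[s:e] for s, e in zip(self.offsets, self.offsets[1:])]


def buffer_frexp(values: List[float]) -> Tuple[List[float], List[int]]:
    # Like frexp on a dense buffer: returns (mantissa, exponent) for every element.
    pairs = [math.frexp(v) for v in values]
    return [m for m, _ in pairs], [e for _, e in pairs]


def unary_pointwise_fallback(func: Callable, inp: JaggedSketch) -> JaggedSketch:
    # Hypothetical generic path: assumes func returns ONE buffer, so a tuple
    # result is silently wrapped as if it were a single values buffer.
    return JaggedSketch(func(inp.values), inp.offsets)


def frexp_impl(inp: JaggedSketch) -> Tuple[JaggedSketch, JaggedSketch]:
    # Explicit impl: unpack both outputs and wrap each one, reusing the
    # input's offsets so both results keep the same ragged structure.
    mantissa, exponent = buffer_frexp(inp.values)
    return JaggedSketch(mantissa, inp.offsets), JaggedSketch(exponent, inp.offsets)


nt = JaggedSketch([1.0, 8.0, 0.5, 3.0, 10.0], [0, 2, 5])
mant, exp = frexp_impl(nt)
print(mant.unbind())  # [[0.5, 0.5], [0.5, 0.75, 0.625]]
print(exp.unbind())   # [[1, 4], [0, 2, 4]]

broken = unary_pointwise_fallback(buffer_frexp, nt)
print(len(broken.values))  # 2: the tuple, not the 5 mantissas
```

In the sketch, the fallback produces a "buffer" whose two entries are the mantissa and exponent lists, so its offsets no longer line up with anything. The explicit impl keeps one buffer per output and shares the offsets between them.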
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144585
Approved by: https://github.com/soulitzer
ghstack dependencies: #144582, #144583, #144584
2025-01-18 15:59:56 +00:00
__init__.py
nested_int.py Switch to using Python nested int (#141166) 2024-12-02 19:17:30 +00:00
nested_tensor.py Switch to using Python nested int (#141166) 2024-12-02 19:17:30 +00:00
ops.py Fix NJT frexp() to handle both outputs (#144585) 2025-01-18 15:59:56 +00:00
sdpa.py