pytorch/torch/nested/_internal
Joel Schlosser 128f3627b1 Implement backward for NJT matmul (#144587)
Part of my BE project addressing NJT bugs surfaced via OpInfo tests.

This PR implements the missing backward support for NJT matmul. For dense tensors, matmul dispatches to bmm. For historical reasons related to NST, however, NJT handles matmul directly, so it can't rely on the CompositeImplicit impl of matmul to get the derivative formula.
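For reference, the derivative formula in question is the standard one for C = A @ B: dL/dA = dL/dC @ B^T and dL/dB = A^T @ dL/dC. The sketch below is not the code from this PR. It is a plain-Python illustration, with lists of rows standing in for tensors, that applies the formula to each component of a ragged batch multiplied by one shared dense matrix. Every name in it (`matmul_backward`, `components`, `loss`, and so on) is hypothetical, and the sum-of-entries loss is an assumption chosen so that the upstream gradient is all ones.

```python
# Illustrative sketch only: plain-Python matrices (lists of rows), not torch tensors.

def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    """Swap rows and columns."""
    return [list(col) for col in zip(*a)]

def matmul_backward(a, b, grad_out):
    """For C = A @ B, return (dL/dA, dL/dB) given grad_out = dL/dC."""
    grad_a = matmul(grad_out, transpose(b))   # dL/dC @ B^T
    grad_b = matmul(transpose(a), grad_out)   # A^T @ dL/dC
    return grad_a, grad_b

def loss(a, b):
    """Assumed scalar loss: the sum of all entries of A @ B."""
    return sum(sum(row) for row in matmul(a, b))

def numeric_grad_a(a, b, eps=1e-6):
    """Central-difference estimate of dL/dA, used only as a cross-check."""
    grad = [[0.0] * len(a[0]) for _ in a]
    for i in range(len(a)):
        for j in range(len(a[0])):
            saved = a[i][j]
            a[i][j] = saved + eps
            plus = loss(a, b)
            a[i][j] = saved - eps
            minus = loss(a, b)
            a[i][j] = saved  # restore the original entry
            grad[i][j] = (plus - minus) / (2 * eps)
    return grad

# A ragged batch: components share the inner dimension (3) but differ in row count.
components = [
    [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],                     # 2 x 3
    [[0.5, -1.0, 2.0], [1.5, 0.0, -2.0], [3.0, 1.0, 1.0]],  # 3 x 3
]
b = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]                    # shared dense 3 x 2

grad_as = []                                # one gradient per ragged component
grad_b = [[0.0] * len(b[0]) for _ in b]     # accumulated over all components
for a in components:
    ones = [[1.0] * len(b[0]) for _ in a]   # dL/dC for the sum-of-entries loss
    grad_a, grad_b_part = matmul_backward(a, b, ones)
    grad_as.append(grad_a)
    grad_b = [[x + y for x, y in zip(r1, r2)]
              for r1, r2 in zip(grad_b, grad_b_part)]
```

Because the dense operand is shared, its gradient is the sum of the per-component contributions, while each ragged component gets a gradient with its own shape.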

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144587
Approved by: https://github.com/soulitzer
ghstack dependencies: #144586
2025-01-21 18:27:50 +00:00
__init__.py
nested_int.py
nested_tensor.py PEP585 update - torch/_higher_order_ops torch/_subclasses torch/backends torch/compiler torch/cuda torch/masked torch/mtia torch/nested (#145202) 2025-01-20 22:37:26 +00:00
ops.py Implement backward for NJT matmul (#144587) 2025-01-21 18:27:50 +00:00
sdpa.py PEP585 update - torch/_higher_order_ops torch/_subclasses torch/backends torch/compiler torch/cuda torch/masked torch/mtia torch/nested (#145202) 2025-01-20 22:37:26 +00:00