mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Part of my BE project addressing NJT (nested jagged tensor) bugs surfaced via OpInfo tests. This PR implements the missing backward support for NJT matmul. For dense tensors, matmul dispatches to bmm. For historical reasons related to NST (nested strided tensors), however, NJT handles matmul directly and therefore cannot rely on the CompositeImplicit implementation of matmul to obtain the derivative formula. Pull Request resolved: https://github.com/pytorch/pytorch/pull/144587 Approved by: https://github.com/soulitzer ghstack dependencies: #144586 |
||
|---|---|---|
| .. | ||
| _internal | ||
| __init__.py | ||
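
The commit message above refers to the derivative formula for matmul. As a rough illustration only, the sketch below applies the standard dense formula (for `C = A @ B` with upstream gradient `G`, `dA = G @ B^T` and `dB = A^T @ G`) to each component of a jagged batch in plain Python. It is not the code from the PR: the helper names (`matmul`, `transpose`, `matmul_backward`) and the sample data are hypothetical, and the real NJT implementation operates on packed tensor storage rather than Python lists.

```python
def matmul(a, b):
    """Multiply two dense matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    """Return the transpose of a matrix given as a list of rows."""
    return [list(col) for col in zip(*m)]


def matmul_backward(a, b, grad_out):
    """Standard matmul derivative: dL/dA = G @ B^T and dL/dB = A^T @ G."""
    return matmul(grad_out, transpose(b)), matmul(transpose(a), grad_out)


# Hypothetical jagged batch: components share the inner dimension (2)
# but differ in their number of rows (1 and 3).
jagged_a = [
    [[1.0, 2.0]],
    [[3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
]
b = [[1.0, 0.0, 2.0], [0.0, 1.0, 3.0]]  # dense right-hand side, shape 2 x 3

grads_a = []                             # one gradient per jagged component
grad_b = [[0.0] * 3 for _ in range(2)]   # accumulated over all components
for a in jagged_a:
    out = matmul(a, b)
    # Upstream gradient of the scalar loss out.sum() is a matrix of ones.
    grad_out = [[1.0] * len(out[0]) for _ in out]
    ga, gb = matmul_backward(a, b, grad_out)
    grads_a.append(ga)
    grad_b = [[x + y for x, y in zip(r1, r2)] for r1, r2 in zip(grad_b, gb)]

print(grads_a)
print(grad_b)
```

Each component's gradient keeps that component's ragged shape, while the gradient for the shared dense operand is summed across components.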