pytorch/torch/nested/_internal
Michael Diggin e01a5e9e1e Small improvements to NJT matrix multiplies (#146405)
Fixes #146404

Updates the matmul and matmul_backward operations for nested jagged tensors (NJTs) to support backpropagation when the output is a regular strided tensor.
This required making the nested matmul operation work when the nested tensor is not `self`, i.e., `A@B` where `A` isn't nested but `B` is.

The operation schemas had to be updated to reflect that either input (and the gradient) can be a strided tensor, so an extra assertion is added for the edge case where neither input is nested.
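
As a rough illustration of why a strided left operand can appear during backpropagation, the sketch below models a jagged batch as a plain Python list of per-component matrices. It is not the code in `ops.py` and does not use torch; the helper names `matmul2d` and `transpose2d` and the sample values are assumptions made for this example. Contracting over the ragged dimension gives a regular (non-ragged) output, and the gradient with respect to the transposed left input is then a dense-times-ragged product.

```python
def matmul2d(a, b):
    """Plain 2D matrix product on lists of rows (hypothetical helper)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose2d(m):
    """Transpose a 2D list of rows (hypothetical helper)."""
    return [list(col) for col in zip(*m)]


# Two batch components with ragged lengths j = 2 and j = 3.
x = [[[1, 2], [3, 4]], [[1, 0], [0, 1], [1, 1]]]  # each component is (j_i, D=2)
y = [[[1], [1]], [[2], [0], [1]]]                 # each component is (j_i, E=1)

# Forward: x_i^T @ y_i contracts the ragged dimension, so every component
# has the same (D, E) shape and the result can be a regular strided tensor.
out = [matmul2d(transpose2d(xi), yi) for xi, yi in zip(x, y)]

# Upstream gradient: one dense (D, E) matrix of ones per component.
grad_out = [[[1], [1]], [[1], [1]]]

# Gradient w.r.t. x_i^T is grad_i @ y_i^T: dense on the left, ragged on the right.
grad_x_t = [matmul2d(g, transpose2d(yi)) for g, yi in zip(grad_out, y)]

# Gradient w.r.t. y_i is x_i @ grad_i: ragged on the left, dense on the right.
grad_y = [matmul2d(xi, g) for xi, g in zip(x, grad_out)]

print(out)
print([(len(g), len(g[0])) for g in grad_x_t])
```

Here `grad_x_t` has component shapes (2, 2) and (2, 3), so it is ragged again even though the left operand of the product was dense.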

Unit tests are also added.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146405
Approved by: https://github.com/soulitzer, https://github.com/jbschlosser
2025-02-06 04:51:12 +00:00
__init__.py
nested_int.py Switch to using Python nested int (#141166) 2024-12-02 19:17:30 +00:00
nested_tensor.py PEP585 update - torch/_higher_order_ops torch/_subclasses torch/backends torch/compiler torch/cuda torch/masked torch/mtia torch/nested (#145202) 2025-01-20 22:37:26 +00:00
ops.py Small improvements to NJT matrix multiplies (#146405) 2025-02-06 04:51:12 +00:00
sdpa.py PEP585 update - torch/_higher_order_ops torch/_subclasses torch/backends torch/compiler torch/cuda torch/masked torch/mtia torch/nested (#145202) 2025-01-20 22:37:26 +00:00