Mirror of https://github.com/saymrwulf/pytorch.git
Synced 2026-05-14 20:57:59 +00:00
Fixes #146404

Changes the matmul and matmul_backward operations for nested jagged tensors to support backpropagation when the output is a regular strided tensor. This required making the nested matmul operation work when the nested tensor isn't 'self', i.e. `A@B` where `A` isn't nested but `B` is. The operation schemas had to be updated to reflect that either input, as well as the gradient, can be a strided tensor, so an extra assertion is added for the edge case where neither input is nested. Unit tests are also added.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146405
Approved by: https://github.com/soulitzer, https://github.com/jbschlosser
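To illustrate why the backward pass needs `A@B` with a strided `A` and a nested `B`, here is a minimal conceptual sketch in plain Python. It is not PyTorch's implementation and uses no PyTorch API: the jagged batch is modeled as a list of matrices with different row counts, and the helper names `matmul` and `transpose` as well as the sample values are assumptions made for illustration.

```python
def matmul(a, b):
    """Plain 2-D matrix product on lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    """Transpose a 2-D matrix given as a list of rows."""
    return [list(col) for col in zip(*m)]


# A jagged batch: each component has its own ragged length j_i.
X = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],  # shape (3, 2)
    [[1.0, 0.0]],                          # shape (1, 2)
]
Y = [
    [[1.0], [1.0], [1.0]],  # shape (3, 1)
    [[2.0]],                # shape (1, 1)
]

# Forward: X_i^T @ Y_i contracts the ragged dimension, so every
# component of the output has the same dense shape (2, 1).
out = [matmul(transpose(x), y) for x, y in zip(X, Y)]

# Upstream gradient: dense, same shape as each output component.
grad_out = [[[1.0], [1.0]] for _ in out]

# Gradient w.r.t. X_i^T is grad_out_i @ Y_i^T: dense @ jagged.
grad_Xt = [matmul(g, transpose(y)) for g, y in zip(grad_out, Y)]

# Gradient w.r.t. Y_i is X_i @ grad_out_i: jagged @ dense.
grad_Y = [matmul(x, g) for x, g in zip(X, grad_out)]

print(out)
print(grad_Xt)
print(grad_Y)
```

The gradient with respect to the first operand has a dense left factor and a jagged right factor, which is the situation the commit message describes where the nested tensor is not 'self'.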
| Name |
|---|
| `_internal` |
| `__init__.py` |
| `_comparison.py` |
| `_creation.py` |
| `_utils.py` |