tensordot: performance improvement with complete (to a scalar) contractions

This commit is contained in:
nikitaved 2025-01-29 17:13:18 +01:00
parent 6667e5d786
commit 4b0808cfc0
4 changed files with 38 additions and 8 deletions

View file

@ -19,6 +19,7 @@
#include <ATen/ops/addmm.h>
#include <ATen/ops/bilinear_native.h>
#include <ATen/ops/bmm.h>
#include <ATen/ops/dot.h>
#include <ATen/ops/einsum_native.h>
#include <ATen/ops/linear_native.h>
#include <ATen/ops/matmul.h>
@ -817,11 +818,35 @@ Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1,
rsizes.emplace_back(t2.sym_size(i));
}
}
// permute and reshape for matrix multiplication
t1 = t1.permute(p1).reshape_symint({size1, csize});
t2 = t2.permute(p2).reshape_symint({csize, size2});
// multiply and reshape to target size
return at::mm(t1, t2).reshape_symint(rsizes);
// permute to align for contraction
t1 = t1.permute(p1);
t2 = t2.permute(p2);
// Full contraction (size1 == 1 and size2 == 1) is much faster when done with dot ...
// TODO(@nikitaved): there are other cases where dot outperforms gemms,
// like, for example, when the non-contracted dims are relatively small.
// NOTE(@nikitaved): contract with gemm when on MPS,
// otherwise issues with the tests xpassing/xfailing
// when enabling the fast-path with dot.
// TODO: resolve that
if (size1 == 1 && size2 == 1 && (t1.device().type() != at::kMPS)) {
if (t1.is_contiguous() && t2.is_contiguous()) {
// If t1 and t2 are both contiguous, then flatten is a view,
// then dot is the method of choice
return at::dot(t1.flatten(), t2.flatten());
} else {
// Otherwise mul + sum can be faster as it avoids at most 2x contiguous() calls
return (t1.squeeze() * t2.squeeze()).sum();
}
} else {
// ... otherwise contract with a GEMM
// reshape for matrix multiplication
t1 = t1.reshape_symint({size1, csize});
t2 = t2.reshape_symint({csize, size2});
// multiply and reshape to target size
return at::mm(t1, t2).reshape_symint(rsizes);
}
}
Tensor &tensordot_out(const Tensor& input1, const Tensor& input2, IntArrayRef dims1, IntArrayRef dims2, Tensor& result) {

View file

@ -132,7 +132,6 @@ dtensor_fails = {
xfail("cumulative_trapezoid"),
xfail("diagonal_scatter"),
xfail("dist"),
xfail("dot"),
xfail("empty"),
xfail("empty_strided"),
xfail("empty_like"),

View file

@ -201,6 +201,11 @@ def _scaled_mm_like_strategy(
return mm_strategy
@register_op_strategy(aten.dot.default)
def dot_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
return _mm_like_strategy("i,i->", mesh, op_schema)
@register_op_strategy(aten.mm.default)
def mm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
return _mm_like_strategy("mk,kn->mn", mesh, op_schema)

View file

@ -7091,12 +7091,13 @@ def sample_inputs_tensordot(self, device, dtype, requires_grad, **kwargs):
cases = (
((2, 2, 2), (2, 2, 2), (2)),
((2, 2, 1), (2, 1, 2), ([0, 1], [2, 0])),
((1, 1, 1), (2, 1, 2), ([0, 1], [2, 0])),
)
for first_shape, second_shape, dims in cases:
yield SampleInput(make_tensor(first_shape, dtype=dtype, device=device,
requires_grad=requires_grad),
requires_grad=requires_grad, low=-1, high=+2),
make_tensor(second_shape, dtype=dtype, device=device,
requires_grad=requires_grad),
requires_grad=requires_grad, low=-1, high=+2),
dims=dims)
def sample_inputs_kron(op_info, device, dtype, requires_grad, **kwargs):