From 4b0808cfc082c7fefcb582b004b1d99e9fb895f7 Mon Sep 17 00:00:00 2001
From: nikitaved
Date: Wed, 29 Jan 2025 17:13:18 +0100
Subject: [PATCH] tensordot: performance improvement for complete (to a
 scalar) contractions

---
 aten/src/ATen/native/Linear.cpp              | 35 ++++++++++++++++---
 test/distributed/tensor/test_dtensor_ops.py  |  1 -
 torch/distributed/tensor/_ops/_matrix_ops.py |  5 +++
 .../_internal/common_methods_invocations.py  |  5 +--
 4 files changed, 38 insertions(+), 8 deletions(-)

diff --git a/aten/src/ATen/native/Linear.cpp b/aten/src/ATen/native/Linear.cpp
index 5adcdc4daa4..16ccb4a644e 100644
--- a/aten/src/ATen/native/Linear.cpp
+++ b/aten/src/ATen/native/Linear.cpp
@@ -19,6 +19,7 @@
 #include
 #include
 #include
+#include <ATen/ops/dot.h>
 #include
 #include
 #include
@@ -817,11 +818,35 @@ Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1,
       rsizes.emplace_back(t2.sym_size(i));
     }
   }
-  // permute and reshape for matrix multiplication
-  t1 = t1.permute(p1).reshape_symint({size1, csize});
-  t2 = t2.permute(p2).reshape_symint({csize, size2});
-  // multiply and reshape to target size
-  return at::mm(t1, t2).reshape_symint(rsizes);
+
+  // permute to align for contraction
+  t1 = t1.permute(p1);
+  t2 = t2.permute(p2);
+
+  // A full contraction (size1 == 1 and size2 == 1) is much faster when done with dot ...
+  // TODO(@nikitaved): there are other cases where dot outperforms gemms,
+  // e.g. when the non-contracted dims are relatively small.
+  // NOTE(@nikitaved): contract with a gemm on MPS, otherwise tests
+  // xpass/xfail when the dot fast path is enabled. TODO: resolve that.
+  if (size1 == 1 && size2 == 1 && (t1.device().type() != at::kMPS)) {
+    if (t1.is_contiguous() && t2.is_contiguous()) {
+      // If t1 and t2 are both contiguous, flatten is a view,
+      // so dot is the method of choice.
+      return at::dot(t1.flatten(), t2.flatten());
+    } else {
+      // Otherwise mul + sum can be faster, since it avoids
+      // up to two contiguous() calls.
+      return (t1.squeeze() * t2.squeeze()).sum();
+    }
+  } else {
+    // ... otherwise contract with a GEMM:
+    // reshape for matrix multiplication
+    t1 = t1.reshape_symint({size1, csize});
+    t2 = t2.reshape_symint({csize, size2});
+    // multiply and reshape to target size
+    return at::mm(t1, t2).reshape_symint(rsizes);
+  }
 }
 
 Tensor &tensordot_out(const Tensor& input1, const Tensor& input2, IntArrayRef dims1, IntArrayRef dims2, Tensor& result) {
diff --git a/test/distributed/tensor/test_dtensor_ops.py b/test/distributed/tensor/test_dtensor_ops.py
index 647160eb718..fe14addce31 100644
--- a/test/distributed/tensor/test_dtensor_ops.py
+++ b/test/distributed/tensor/test_dtensor_ops.py
@@ -132,7 +132,6 @@ dtensor_fails = {
     xfail("cumulative_trapezoid"),
     xfail("diagonal_scatter"),
     xfail("dist"),
-    xfail("dot"),
     xfail("empty"),
     xfail("empty_strided"),
     xfail("empty_like"),
diff --git a/torch/distributed/tensor/_ops/_matrix_ops.py b/torch/distributed/tensor/_ops/_matrix_ops.py
index b2127b05a38..27c2a7eaa57 100644
--- a/torch/distributed/tensor/_ops/_matrix_ops.py
+++ b/torch/distributed/tensor/_ops/_matrix_ops.py
@@ -201,6 +201,11 @@ def _scaled_mm_like_strategy(
     return mm_strategy
 
 
+@register_op_strategy(aten.dot.default)
+def dot_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
+    return _mm_like_strategy("i,i->", mesh, op_schema)
+
+
 @register_op_strategy(aten.mm.default)
 def mm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
     return _mm_like_strategy("mk,kn->mn", mesh, op_schema)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 087c874c60c..6a01bc5e5d6 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -7091,12 +7091,13 @@ def sample_inputs_tensordot(self, device, dtype, requires_grad, **kwargs):
     cases = (
         ((2, 2, 2), (2, 2, 2), (2)),
         ((2, 2, 1), (2, 1, 2), ([0, 1], [2, 0])),
+        ((1, 1, 1), (2, 1, 2), ([0, 1], [2, 0])),
     )
     for first_shape, second_shape, dims in cases:
         yield SampleInput(make_tensor(first_shape, dtype=dtype, device=device,
-                                      requires_grad=requires_grad),
+                                      requires_grad=requires_grad, low=-1, high=+2),
                           make_tensor(second_shape, dtype=dtype, device=device,
-                                      requires_grad=requires_grad),
+                                      requires_grad=requires_grad, low=-1, high=+2),
                           dims=dims)
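
Note: as a rough sanity check of the new fast path, a full tensordot
contraction can be compared against an explicit dot product from Python.
The snippet below is an illustrative sketch, not part of the patch: the
shapes are arbitrary, and the timings come from torch.utils.benchmark
(which handles warmup and CUDA synchronization).

    import torch
    from torch.utils.benchmark import Timer

    # Arbitrary example shapes; contracting every dimension of both
    # operands yields a scalar, i.e. the size1 == 1 && size2 == 1 branch.
    a = torch.randn(64, 64, 64)
    b = torch.randn(64, 64, 64)

    # Full contraction to a scalar: previously lowered to at::mm on
    # [1, csize] x [csize, 1] matrices, now dispatched to at::dot
    # (on non-MPS devices).
    t_tensordot = Timer(
        "torch.tensordot(a, b, dims=([0, 1, 2], [0, 1, 2]))",
        globals={"a": a, "b": b},
    ).blocked_autorange()

    # The equivalent explicit dot product on flattened inputs.
    t_dot = Timer(
        "torch.dot(a.flatten(), b.flatten())",
        globals={"a": a, "b": b},
    ).blocked_autorange()

    print(t_tensordot)
    print(t_dot)

For contiguous inputs flatten() is a view, so the dot branch avoids the
copies that reshaping into matrices could trigger; for non-contiguous
inputs the patch falls back to mul + sum for the same reason.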