diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.cc index ab9f9e1c35..2fc6685712 100644 --- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.cc +++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.cc @@ -297,7 +297,6 @@ std::unique_ptr Transpose(const Tensor& input, const std::vector::FinalizeOutput(const Tensor& candidate_outp } } +static bool IsTransposeReshapeForEinsum(const std::vector& perm, + const std::vector& input_dims, + std::vector& new_shape) { + // As long as the dims with values > 1 stay in the same order, it's a reshape. + // Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1). + size_t last_permuted_axis = 0; + for (size_t i = 0; i < perm.size(); ++i) { + if (input_dims[perm[i]] == 1) + continue; + if (perm[i] < last_permuted_axis) + return false; + last_permuted_axis = perm[i]; + } + new_shape = input_dims; + for (size_t i = 0; i < perm.size(); ++i) { + new_shape[i] = input_dims[perm[i]]; + } + return true; +} + template std::unique_ptr EinsumTypedComputeProcessor::PairwiseOperandProcess(const Tensor& left, const TensorShape& left_shape_override, @@ -165,6 +185,7 @@ std::unique_ptr EinsumTypedComputeProcessor::PairwiseOperandProcess(c } // Permutate the left operand so that the axes order go like this: [lro, lo, reduce_dims, ro] + std::vector reshaped_dims; std::vector left_permutation; left_permutation.reserve(lro.size() + lo.size() + reduce_dims.size() + ro.size()); left_permutation.insert(left_permutation.end(), lro.begin(), lro.end()); @@ -173,10 +194,21 @@ std::unique_ptr EinsumTypedComputeProcessor::PairwiseOperandProcess(c left_permutation.insert(left_permutation.end(), ro.begin(), ro.end()); if (EinsumOp::IsTransposeRequired(current_left ? current_left->Shape().GetDims().size() : left_dims.size(), left_permutation)) { - current_left = EinsumOp::Transpose(current_left ? 
*current_left : left, - current_left ? current_left->Shape().GetDims() : left_dims, - left_permutation, allocator_, einsum_ep_assets_, - device_transpose_func_); + if (current_left && IsTransposeReshapeForEinsum(left_permutation, + current_left->Shape().GetDims(), + reshaped_dims)) { + // This can be done because current_* tensors (if they exist) and output tensors are + // intermediate tensors and cannot be input tensors to the Einsum node itself + // (which are immutable). + // Covered by ExplicitEinsumAsTensorContractionReshapeLeft. + current_left->Reshape(reshaped_dims); + } else { + // Covered by ExplicitEinsumAsTensorContraction, DiagonalWithMatmul, ... + current_left = EinsumOp::Transpose(current_left ? *current_left : left, + current_left ? current_left->Shape().GetDims() : left_dims, + left_permutation, allocator_, einsum_ep_assets_, + device_transpose_func_); + } } // Permutate the right operand so that the axes order go like this: [lro, reduce_dims, ro, lo] @@ -188,10 +220,19 @@ std::unique_ptr EinsumTypedComputeProcessor::PairwiseOperandProcess(c right_permutation.insert(right_permutation.end(), lo.begin(), lo.end()); if (EinsumOp::IsTransposeRequired(current_right ? current_right->Shape().GetDims().size() : right_dims.size(), right_permutation)) { - current_right = EinsumOp::Transpose(current_right ? *current_right : right, - current_right ? current_right->Shape().GetDims() : right_dims, - right_permutation, allocator_, einsum_ep_assets_, - device_transpose_func_); + if (current_right && IsTransposeReshapeForEinsum(right_permutation, + current_right->Shape().GetDims(), + reshaped_dims)) { + // See note following the previous call of function IsTransposeReshapeForEinsum. + // Covered by ExplicitEinsumAsBatchedMatmulWithBroadcasting_1, ExplicitEinsumAsMatmul_2, ... + current_right->Reshape(reshaped_dims); + } else { + // Covered by DiagonalWithMatmul, ExplicitEinsumAsBatchedMatmul, ... + current_right = EinsumOp::Transpose(current_right ? 
*current_right : right, + current_right ? current_right->Shape().GetDims() : right_dims, + right_permutation, allocator_, einsum_ep_assets_, + device_transpose_func_); + } } // Calculate output size @@ -258,8 +299,16 @@ std::unique_ptr EinsumTypedComputeProcessor::PairwiseOperandProcess(c if (!is_final_pair) { // This is not the final pair - so bring the axes order to what the inputs conformed to if (EinsumOp::IsTransposeRequired(output_dims.size(), output_permutation)) { - output = EinsumOp::Transpose(*output, output_dims, output_permutation, allocator_, - einsum_ep_assets_, device_transpose_func_); + if (IsTransposeReshapeForEinsum(output_permutation, + output_dims, + reshaped_dims)) { + // See note following the previous call of function IsTransposeReshapeForEinsum. + // Covered by ExplicitEinsumAsTensorContractionReshapeFinal. + output->Reshape(reshaped_dims); + } else { + output = EinsumOp::Transpose(*output, output_dims, output_permutation, allocator_, + einsum_ep_assets_, device_transpose_func_); + } } } else { // This is the final pair - Transpose directly to the output ordering required and copy the contents to the op's output FinalizeOutput(*output, current_subscript_order); diff --git a/onnxruntime/test/providers/cpu/math/einsum_test.cc b/onnxruntime/test/providers/cpu/math/einsum_test.cc index c1f4fe79ab..79446b5273 100644 --- a/onnxruntime/test/providers/cpu/math/einsum_test.cc +++ b/onnxruntime/test/providers/cpu/math/einsum_test.cc @@ -510,6 +510,25 @@ TEST(Einsum, ExplicitEinsumAsTensorContraction) { test.Run(); } +TEST(Einsum, ExplicitEinsumAsTensorContractionReshapeFinal) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "sbcd,es,eh->bce"); + test.AddInput("x", {2, 2, 2, 2}, {1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f}); + test.AddInput("y", {2, 2}, {1.f, 2.f, -6.f, 2.f}); + test.AddInput("z", {2, 2}, {3.f, 4.f, 5.f, 6.f}); + test.AddOutput("o", {2, 2, 2}, {63.f, -132.f, 
63.f, -132.f, 63.f, -132.f, 63.f, -132.f}); + test.Run(); +} + +TEST(Einsum, ExplicitEinsumAsTensorContractionReshapeLeft) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "bsnh,btnh->bnts"); + test.AddInput("x", {2, 1, 2, 2}, {1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f}); + test.AddInput("y", {2, 2, 2, 1}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + test.AddOutput("o", {2, 2, 2, 1}, {3.f, 9.f, 6.f, 12.f, 15.f, 21.f, 18.f, 24.f}); + test.Run(); +} + // Implicit TEST(Einsum, ImplicitEinsumAsTensorContraction) { OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); @@ -520,7 +539,6 @@ TEST(Einsum, ImplicitEinsumAsTensorContraction) { test.Run(); } - // Test each theme for half support TEST(Einsum, ExplicitEinsumAsIdentity_1D_input_Half) { if (!HasCudaEnvironment(600)) {