Removes unnecessary transpose in operator Einsum (#7141)

* remove one unnecessary transpose
* add more unit test
This commit is contained in:
Xavier Dupré 2021-03-31 09:59:08 +02:00 committed by GitHub
parent d500c5952b
commit b370ddbf5e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 78 additions and 12 deletions

View file

@ -297,7 +297,6 @@ std::unique_ptr<Tensor> Transpose(const Tensor& input, const std::vector<int64_t
if (!status.IsOK()) {
ORT_THROW(ONNXRUNTIME, FAIL, "Einsum op: Transpose failed: ", status.ErrorMessage());
}
return output;
}

View file

@ -70,6 +70,26 @@ void EinsumTypedComputeProcessor<T>::FinalizeOutput(const Tensor& candidate_outp
}
}
static bool IsTransposeReshapeForEinsum(const std::vector<size_t>& perm,
const std::vector<int64_t>& input_dims,
std::vector<int64_t>& new_shape) {
// As long as the dims with values > 1 stay in the same order, it's a reshape.
// Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1).
size_t last_permuted_axis = 0;
for (size_t i = 0; i < perm.size(); ++i) {
if (input_dims[perm[i]] == 1)
continue;
if (perm[i] < last_permuted_axis)
return false;
last_permuted_axis = perm[i];
}
new_shape = input_dims;
for (size_t i = 0; i < perm.size(); ++i) {
new_shape[i] = input_dims[perm[i]];
}
return true;
}
template <typename T>
std::unique_ptr<Tensor> EinsumTypedComputeProcessor<T>::PairwiseOperandProcess(const Tensor& left,
const TensorShape& left_shape_override,
@ -165,6 +185,7 @@ std::unique_ptr<Tensor> EinsumTypedComputeProcessor<T>::PairwiseOperandProcess(c
}
// Permutate the left operand so that the axes order go like this: [lro, lo, reduce_dims, ro]
std::vector<int64_t> reshaped_dims;
std::vector<size_t> left_permutation;
left_permutation.reserve(lro.size() + lo.size() + reduce_dims.size() + ro.size());
left_permutation.insert(left_permutation.end(), lro.begin(), lro.end());
@ -173,10 +194,21 @@ std::unique_ptr<Tensor> EinsumTypedComputeProcessor<T>::PairwiseOperandProcess(c
left_permutation.insert(left_permutation.end(), ro.begin(), ro.end());
if (EinsumOp::IsTransposeRequired(current_left ? current_left->Shape().GetDims().size() : left_dims.size(),
left_permutation)) {
current_left = EinsumOp::Transpose(current_left ? *current_left : left,
current_left ? current_left->Shape().GetDims() : left_dims,
left_permutation, allocator_, einsum_ep_assets_,
device_transpose_func_);
if (current_left && IsTransposeReshapeForEinsum(left_permutation,
current_left->Shape().GetDims(),
reshaped_dims)) {
// This can be done because curent_* tensors (if they exist) and output tensors are
// intermediate tensors and cannot be input tensors to the Einsum node itself
// (which are immutable).
// Covered by ExplicitEinsumAsTensorContractionReshapeLeft.
current_left->Reshape(reshaped_dims);
} else {
// Covered by ExplicitEinsumAsTensorContraction, DiagonalWithMatmul, ...
current_left = EinsumOp::Transpose(current_left ? *current_left : left,
current_left ? current_left->Shape().GetDims() : left_dims,
left_permutation, allocator_, einsum_ep_assets_,
device_transpose_func_);
}
}
// Permutate the right operand so that the axes order go like this: [lro, reduce_dims, ro, lo]
@ -188,10 +220,19 @@ std::unique_ptr<Tensor> EinsumTypedComputeProcessor<T>::PairwiseOperandProcess(c
right_permutation.insert(right_permutation.end(), lo.begin(), lo.end());
if (EinsumOp::IsTransposeRequired(current_right ? current_right->Shape().GetDims().size() : right_dims.size(),
right_permutation)) {
current_right = EinsumOp::Transpose(current_right ? *current_right : right,
current_right ? current_right->Shape().GetDims() : right_dims,
right_permutation, allocator_, einsum_ep_assets_,
device_transpose_func_);
if (current_right && IsTransposeReshapeForEinsum(right_permutation,
current_right->Shape().GetDims(),
reshaped_dims)) {
// See note following the previous call of function IsTransposeReshapeForEinsum.
// Covered by ExplicitEinsumAsBatchedMatmulWithBroadcasting_1, ExplicitEinsumAsMatmul_2, ...
current_right->Reshape(reshaped_dims);
} else {
// Covered by DiagonalWithMatmul, ExplicitEinsumAsBatchedMatmul, ...
current_right = EinsumOp::Transpose(current_right ? *current_right : right,
current_right ? current_right->Shape().GetDims() : right_dims,
right_permutation, allocator_, einsum_ep_assets_,
device_transpose_func_);
}
}
// Calculate output size
@ -258,8 +299,16 @@ std::unique_ptr<Tensor> EinsumTypedComputeProcessor<T>::PairwiseOperandProcess(c
if (!is_final_pair) { // This is not the final pair - so bring the axes order to what the inputs conformed to
if (EinsumOp::IsTransposeRequired(output_dims.size(), output_permutation)) {
output = EinsumOp::Transpose(*output, output_dims, output_permutation, allocator_,
einsum_ep_assets_, device_transpose_func_);
if (IsTransposeReshapeForEinsum(output_permutation,
output_dims,
reshaped_dims)) {
// See note following the previous call of function IsTransposeReshapeForEinsum.
// Covered by ExplicitEinsumAsTensorContractionReshapeFinal.
output->Reshape(reshaped_dims);
} else {
output = EinsumOp::Transpose(*output, output_dims, output_permutation, allocator_,
einsum_ep_assets_, device_transpose_func_);
}
}
} else { // This is the final pair - Transpose directly to the output ordering required and copy the contents to the op's output
FinalizeOutput(*output, current_subscript_order);

View file

@ -510,6 +510,25 @@ TEST(Einsum, ExplicitEinsumAsTensorContraction) {
test.Run();
}
TEST(Einsum, ExplicitEinsumAsTensorContractionReshapeFinal) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "sbcd,es,eh->bce");
test.AddInput<float>("x", {2, 2, 2, 2}, {1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f});
test.AddInput<float>("y", {2, 2}, {1.f, 2.f, -6.f, 2.f});
test.AddInput<float>("z", {2, 2}, {3.f, 4.f, 5.f, 6.f});
test.AddOutput<float>("o", {2, 2, 2}, {63.f, -132.f, 63.f, -132.f, 63.f, -132.f, 63.f, -132.f});
test.Run();
}
TEST(Einsum, ExplicitEinsumAsTensorContractionReshapeLeft) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "bsnh,btnh->bnts");
test.AddInput<float>("x", {2, 1, 2, 2}, {1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f});
test.AddInput<float>("y", {2, 2, 2, 1}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
test.AddOutput<float>("o", {2, 2, 2, 1}, {3.f, 9.f, 6.f, 12.f, 15.f, 21.f, 18.f, 24.f});
test.Run();
}
// Implicit
TEST(Einsum, ImplicitEinsumAsTensorContraction) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
@ -520,7 +539,6 @@ TEST(Einsum, ImplicitEinsumAsTensorContraction) {
test.Run();
}
// Test each theme for half support
TEST(Einsum, ExplicitEinsumAsIdentity_1D_input_Half) {
if (!HasCudaEnvironment(600)) {