mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
Removes unnecessary transpose in operator Einsum (#7141)
* remove one unnecessary transpose * add more unit test
This commit is contained in:
parent
d500c5952b
commit
b370ddbf5e
3 changed files with 78 additions and 12 deletions
|
|
@ -297,7 +297,6 @@ std::unique_ptr<Tensor> Transpose(const Tensor& input, const std::vector<int64_t
|
|||
if (!status.IsOK()) {
|
||||
ORT_THROW(ONNXRUNTIME, FAIL, "Einsum op: Transpose failed: ", status.ErrorMessage());
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -70,6 +70,26 @@ void EinsumTypedComputeProcessor<T>::FinalizeOutput(const Tensor& candidate_outp
|
|||
}
|
||||
}
|
||||
|
||||
// Decides whether permuting a tensor of shape `input_dims` with `perm` can be done as a
// plain reshape (no data movement) instead of a real transpose.
//
// As long as the dims with values > 1 stay in the same order, it's a reshape.
// Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1).
//
// perm:       axis permutation; perm[i] is the input axis that becomes output axis i.
// input_dims: dimensions of the tensor before the permutation.
// new_shape:  out-parameter. Receives the permuted dimensions when the function returns
//             true and is left untouched when it returns false. It may refer to the same
//             vector as `input_dims`.
//
// Returns true when the permutation is equivalent to a reshape. Returns false for a
// genuine transpose, and also when `perm` cannot be applied to `input_dims` (more entries
// than there are dims, or an axis that is out of range), so that the caller falls back to
// its regular transpose path instead of reading out of bounds here.
static bool IsTransposeReshapeForEinsum(const std::vector<size_t>& perm,
                                        const std::vector<int64_t>& input_dims,
                                        std::vector<int64_t>& new_shape) {
  const size_t rank = input_dims.size();
  if (perm.size() > rank)
    return false;

  // Axes of size 1 may move freely; every other axis must appear in increasing order.
  size_t last_permuted_axis = 0;
  for (const size_t axis : perm) {
    if (axis >= rank)
      return false;
    if (input_dims[axis] == 1)
      continue;
    if (axis < last_permuted_axis)
      return false;
    last_permuted_axis = axis;
  }

  // Build the result in a separate vector so that `input_dims` is never read after it has
  // been partially overwritten (this matters when `new_shape` aliases `input_dims`).
  // Starting from a copy keeps any trailing dims when `perm` is shorter than the rank.
  std::vector<int64_t> permuted_dims(input_dims);
  for (size_t i = 0; i < perm.size(); ++i) {
    permuted_dims[i] = input_dims[perm[i]];
  }
  new_shape.swap(permuted_dims);
  return true;
}
|
||||
|
||||
template <typename T>
|
||||
std::unique_ptr<Tensor> EinsumTypedComputeProcessor<T>::PairwiseOperandProcess(const Tensor& left,
|
||||
const TensorShape& left_shape_override,
|
||||
|
|
@ -165,6 +185,7 @@ std::unique_ptr<Tensor> EinsumTypedComputeProcessor<T>::PairwiseOperandProcess(c
|
|||
}
|
||||
|
||||
// Permutate the left operand so that the axes order go like this: [lro, lo, reduce_dims, ro]
|
||||
std::vector<int64_t> reshaped_dims;
|
||||
std::vector<size_t> left_permutation;
|
||||
left_permutation.reserve(lro.size() + lo.size() + reduce_dims.size() + ro.size());
|
||||
left_permutation.insert(left_permutation.end(), lro.begin(), lro.end());
|
||||
|
|
@ -173,10 +194,21 @@ std::unique_ptr<Tensor> EinsumTypedComputeProcessor<T>::PairwiseOperandProcess(c
|
|||
left_permutation.insert(left_permutation.end(), ro.begin(), ro.end());
|
||||
if (EinsumOp::IsTransposeRequired(current_left ? current_left->Shape().GetDims().size() : left_dims.size(),
|
||||
left_permutation)) {
|
||||
current_left = EinsumOp::Transpose(current_left ? *current_left : left,
|
||||
current_left ? current_left->Shape().GetDims() : left_dims,
|
||||
left_permutation, allocator_, einsum_ep_assets_,
|
||||
device_transpose_func_);
|
||||
if (current_left && IsTransposeReshapeForEinsum(left_permutation,
|
||||
current_left->Shape().GetDims(),
|
||||
reshaped_dims)) {
|
||||
// This can be done because current_* tensors (if they exist) and output tensors are
|
||||
// intermediate tensors and cannot be input tensors to the Einsum node itself
|
||||
// (which are immutable).
|
||||
// Covered by ExplicitEinsumAsTensorContractionReshapeLeft.
|
||||
current_left->Reshape(reshaped_dims);
|
||||
} else {
|
||||
// Covered by ExplicitEinsumAsTensorContraction, DiagonalWithMatmul, ...
|
||||
current_left = EinsumOp::Transpose(current_left ? *current_left : left,
|
||||
current_left ? current_left->Shape().GetDims() : left_dims,
|
||||
left_permutation, allocator_, einsum_ep_assets_,
|
||||
device_transpose_func_);
|
||||
}
|
||||
}
|
||||
|
||||
// Permutate the right operand so that the axes order go like this: [lro, reduce_dims, ro, lo]
|
||||
|
|
@ -188,10 +220,19 @@ std::unique_ptr<Tensor> EinsumTypedComputeProcessor<T>::PairwiseOperandProcess(c
|
|||
right_permutation.insert(right_permutation.end(), lo.begin(), lo.end());
|
||||
if (EinsumOp::IsTransposeRequired(current_right ? current_right->Shape().GetDims().size() : right_dims.size(),
|
||||
right_permutation)) {
|
||||
current_right = EinsumOp::Transpose(current_right ? *current_right : right,
|
||||
current_right ? current_right->Shape().GetDims() : right_dims,
|
||||
right_permutation, allocator_, einsum_ep_assets_,
|
||||
device_transpose_func_);
|
||||
if (current_right && IsTransposeReshapeForEinsum(right_permutation,
|
||||
current_right->Shape().GetDims(),
|
||||
reshaped_dims)) {
|
||||
// See note following the previous call of function IsTransposeReshapeForEinsum.
|
||||
// Covered by ExplicitEinsumAsBatchedMatmulWithBroadcasting_1, ExplicitEinsumAsMatmul_2, ...
|
||||
current_right->Reshape(reshaped_dims);
|
||||
} else {
|
||||
// Covered by DiagonalWithMatmul, ExplicitEinsumAsBatchedMatmul, ...
|
||||
current_right = EinsumOp::Transpose(current_right ? *current_right : right,
|
||||
current_right ? current_right->Shape().GetDims() : right_dims,
|
||||
right_permutation, allocator_, einsum_ep_assets_,
|
||||
device_transpose_func_);
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate output size
|
||||
|
|
@ -258,8 +299,16 @@ std::unique_ptr<Tensor> EinsumTypedComputeProcessor<T>::PairwiseOperandProcess(c
|
|||
|
||||
if (!is_final_pair) { // This is not the final pair - so bring the axes order to what the inputs conformed to
|
||||
if (EinsumOp::IsTransposeRequired(output_dims.size(), output_permutation)) {
|
||||
output = EinsumOp::Transpose(*output, output_dims, output_permutation, allocator_,
|
||||
einsum_ep_assets_, device_transpose_func_);
|
||||
if (IsTransposeReshapeForEinsum(output_permutation,
|
||||
output_dims,
|
||||
reshaped_dims)) {
|
||||
// See note following the previous call of function IsTransposeReshapeForEinsum.
|
||||
// Covered by ExplicitEinsumAsTensorContractionReshapeFinal.
|
||||
output->Reshape(reshaped_dims);
|
||||
} else {
|
||||
output = EinsumOp::Transpose(*output, output_dims, output_permutation, allocator_,
|
||||
einsum_ep_assets_, device_transpose_func_);
|
||||
}
|
||||
}
|
||||
} else { // This is the final pair - Transpose directly to the output ordering required and copy the contents to the op's output
|
||||
FinalizeOutput(*output, current_subscript_order);
|
||||
|
|
|
|||
|
|
@ -510,6 +510,25 @@ TEST(Einsum, ExplicitEinsumAsTensorContraction) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
// Three-operand contraction "sbcd,es,eh->bce" on opset 12.
// The comment in the non-final-pair output branch of PairwiseOperandProcess names this test
// as the one covering the case where the intermediate output is reshaped rather than
// transposed.
// Expected values: every output element [b,c,e] is
//   (sum over d of x) * (sum over s of y[e,s]) * (sum over h of z[e,h]),
// i.e. 3 * 3 * 7 = 63 for e = 0 and 3 * (-4) * 11 = -132 for e = 1.
TEST(Einsum, ExplicitEinsumAsTensorContractionReshapeFinal) {
|
||||
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
|
||||
test.AddAttribute<std::string>("equation", "sbcd,es,eh->bce");
|
||||
// x is indexed [s,b,c,d]; its values depend only on d (1 for d = 0, 2 for d = 1).
test.AddInput<float>("x", {2, 2, 2, 2}, {1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f});
|
||||
// y is indexed [e,s].
test.AddInput<float>("y", {2, 2}, {1.f, 2.f, -6.f, 2.f});
|
||||
// z is indexed [e,h]; h does not appear in the output, so it is summed away.
test.AddInput<float>("z", {2, 2}, {3.f, 4.f, 5.f, 6.f});
|
||||
// Output is indexed [b,c,e].
test.AddOutput<float>("o", {2, 2, 2}, {63.f, -132.f, 63.f, -132.f, 63.f, -132.f, 63.f, -132.f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
// Two-operand contraction "bsnh,btnh->bnts" on opset 12.
// The comment in the left-operand branch of PairwiseOperandProcess names this test as the
// one covering the case where the left operand is reshaped rather than transposed.
// Note that h has size 2 in x but size 1 in y, so the test also depends on the size-1
// dimension of y being broadcast against x.
// Expected values: output[b,n,t,s] = (1 + 2) * y[b,t,n,0].
TEST(Einsum, ExplicitEinsumAsTensorContractionReshapeLeft) {
|
||||
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
|
||||
test.AddAttribute<std::string>("equation", "bsnh,btnh->bnts");
|
||||
// x is indexed [b,s,n,h] with s of size 1; its values depend only on h (1, then 2).
test.AddInput<float>("x", {2, 1, 2, 2}, {1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f});
|
||||
// y is indexed [b,t,n,h] with h of size 1.
test.AddInput<float>("y", {2, 2, 2, 1}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
|
||||
// Output is indexed [b,n,t,s]; n and t are swapped relative to y's layout.
test.AddOutput<float>("o", {2, 2, 2, 1}, {3.f, 9.f, 6.f, 12.f, 15.f, 21.f, 18.f, 24.f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
// Implicit
|
||||
TEST(Einsum, ImplicitEinsumAsTensorContraction) {
|
||||
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
|
||||
|
|
@ -520,7 +539,6 @@ TEST(Einsum, ImplicitEinsumAsTensorContraction) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
|
||||
// Test each theme for half support
|
||||
TEST(Einsum, ExplicitEinsumAsIdentity_1D_input_Half) {
|
||||
if (!HasCudaEnvironment(600)) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue