diff --git a/onnxruntime/core/providers/cuda/tensor/transpose.cc b/onnxruntime/core/providers/cuda/tensor/transpose.cc index 99922b88c1..d785c8422b 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose.cc +++ b/onnxruntime/core/providers/cuda/tensor/transpose.cc @@ -101,7 +101,42 @@ Status Transpose::DoTranspose(const cudaDeviceProp& prop, std::vector new_input_dims(input_dims.begin(), input_dims.end()); std::vector new_output_dims(output_dims.begin(), output_dims.end()); - for (auto i = rank - 1; i > 0; i--) { + // Remove all dims with value 1. + std::vector dims_to_remove(new_rank, false); + int input_pos = 0; + int output_pos = 0; + int perm_pos = 0; + for (int i = 0; i < new_rank; ++i) { + if (new_input_dims[i] != 1) { + new_input_dims[input_pos++] = new_input_dims[i]; + } else { + dims_to_remove[i] = true; + } + if (new_output_dims[i] != 1) { + new_output_dims[output_pos++] = new_output_dims[i]; + } + } + for (int i = 0; i < new_rank; ++i) { + if (!dims_to_remove[new_permutations[i]]) { + new_permutations[perm_pos++] = new_permutations[i]; + } + } + for (int i = new_rank - 1; i >= 0; --i) { + if (dims_to_remove[i]) { + for (int j = 0; j < perm_pos; ++j) { + if (new_permutations[j] > static_cast(i)) { + new_permutations[j] -= 1; + } + } + } + } + ORT_ENFORCE(input_pos == output_pos && input_pos == perm_pos); + new_rank = input_pos; + new_input_dims.resize(new_rank); + new_output_dims.resize(new_rank); + new_permutations.resize(new_rank); + + for (auto i = new_rank - 1; i > 0; i--) { auto curr = new_permutations[i]; auto prev = new_permutations[i - 1]; if (prev + 1 == curr) { @@ -138,6 +173,13 @@ Status Transpose::DoTranspose(const cudaDeviceProp& prop, new_input_dims.resize(new_rank); new_output_dims.resize(new_rank); + if (new_rank <= 1) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output.MutableDataRaw(), input.DataRaw(), + input.Shape().Size() * input.DataType()->Size(), cudaMemcpyDeviceToDevice, + stream)); + return Status::OK(); + } + auto element_type 
= input.GetElementType(); size_t element_size = input.DataType()->Size(); if (element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT || @@ -192,19 +234,10 @@ Status Transpose::DoTranspose(const cudaDeviceProp& prop, stream, element_size, input_shape, tmp_input_strides, input.DataRaw(), tmp_output_strides, output.MutableDataRaw(), gsl::narrow(output.Shape().Size()), grid_size, block_size); - } else if (CanDoTranspose4DParallelizeOneElementPerThread( - prop, element_size, new_rank, new_input_dims, new_permutations, grid_size, block_size)) { - // Trying to see if we can still do (best effort) more optimized transposing - // for the 4-D case before falling back to the generic case - TArray tmp_output_strides(new_rank); - for (auto i = 0; i < new_rank; i++) { - tmp_output_strides[static_cast(new_permutations[i])] = new_output_strides[i]; - } - return Transpose4DParallelizeOneElementPerThread( - stream, element_size, input_shape, tmp_input_strides, input.DataRaw(), - tmp_output_strides, output.MutableDataRaw(), gsl::narrow(output.Shape().Size()), - grid_size, block_size); } + // We used to check whether Transpose4DParallelizeOneElementPerThread could be used before falling back to the generic case, + // but tests on many cases showed that Transpose4DParallelizeOneElementPerThread is not faster than the generic case, + // and is even much slower than the generic case in some cases. // General cases TArray input_strides(new_rank);