Optimize Transpose CUDA Kernel (#10230)

* optimize transpose cuda

* fix comment typo
This commit is contained in:
Vincent Wang 2022-01-15 15:39:06 +08:00 committed by GitHub
parent a757bd7186
commit c12cafa524
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -101,7 +101,42 @@ Status Transpose::DoTranspose(const cudaDeviceProp& prop,
std::vector<int64_t> new_input_dims(input_dims.begin(), input_dims.end());
std::vector<int64_t> new_output_dims(output_dims.begin(), output_dims.end());
for (auto i = rank - 1; i > 0; i--) {
// Remove all dims with value 1.
std::vector<bool> dims_to_remove(new_rank, false);
int input_pos = 0;
int output_pos = 0;
int perm_pos = 0;
for (int i = 0; i < new_rank; ++i) {
if (new_input_dims[i] != 1) {
new_input_dims[input_pos++] = new_input_dims[i];
} else {
dims_to_remove[i] = true;
}
if (new_output_dims[i] != 1) {
new_output_dims[output_pos++] = new_output_dims[i];
}
}
for (int i = 0; i < new_rank; ++i) {
if (!dims_to_remove[new_permutations[i]]) {
new_permutations[perm_pos++] = new_permutations[i];
}
}
for (int i = new_rank - 1; i >= 0; --i) {
if (dims_to_remove[i]) {
for (int j = 0; j < perm_pos; ++j) {
if (new_permutations[j] > static_cast<size_t>(i)) {
new_permutations[j] -= 1;
}
}
}
}
ORT_ENFORCE(input_pos == output_pos && input_pos == perm_pos);
new_rank = input_pos;
new_input_dims.resize(new_rank);
new_output_dims.resize(new_rank);
new_permutations.resize(new_rank);
for (auto i = new_rank - 1; i > 0; i--) {
auto curr = new_permutations[i];
auto prev = new_permutations[i - 1];
if (prev + 1 == curr) {
@ -138,6 +173,13 @@ Status Transpose::DoTranspose(const cudaDeviceProp& prop,
new_input_dims.resize(new_rank);
new_output_dims.resize(new_rank);
if (new_rank <= 1) {
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output.MutableDataRaw(), input.DataRaw(),
input.Shape().Size() * input.DataType()->Size(), cudaMemcpyDeviceToDevice,
stream));
return Status::OK();
}
auto element_type = input.GetElementType();
size_t element_size = input.DataType()->Size();
if (element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT ||
@ -192,19 +234,10 @@ Status Transpose::DoTranspose(const cudaDeviceProp& prop,
stream, element_size, input_shape, tmp_input_strides, input.DataRaw(),
tmp_output_strides, output.MutableDataRaw(), gsl::narrow<int>(output.Shape().Size()),
grid_size, block_size);
} else if (CanDoTranspose4DParallelizeOneElementPerThread(
prop, element_size, new_rank, new_input_dims, new_permutations, grid_size, block_size)) {
// Trying to see if we can still do (best effort) more optimized transposing
// for the 4-D case before falling back to the generic case
TArray<int64_t> tmp_output_strides(new_rank);
for (auto i = 0; i < new_rank; i++) {
tmp_output_strides[static_cast<int32_t>(new_permutations[i])] = new_output_strides[i];
}
return Transpose4DParallelizeOneElementPerThread(
stream, element_size, input_shape, tmp_input_strides, input.DataRaw(),
tmp_output_strides, output.MutableDataRaw(), gsl::narrow<int>(output.Shape().Size()),
grid_size, block_size);
}
// We used to check if Transpose4DParallelizeOneElementPerThread can be used before falling back to generic case,
// But tests on lots of cases showing that Transpose4DParallelizeOneElementPerThread is not faster than generic case,
// and even much slower than generic case for some cases.
// General cases
TArray<int64_t> input_strides(new_rank);