Optimize _TileKernel for non-memcpy case (#9648)

* optimize _TileKernel for non-memcpy case

* fallback shape_rank >MAX_DIMS
This commit is contained in:
pengwa 2021-11-05 09:22:09 +08:00 committed by GitHub
parent a355bcbd73
commit ee167bd078
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 96 additions and 5 deletions

View file

@ -7,6 +7,35 @@
namespace onnxruntime {
namespace cuda {
constexpr int MAX_DIMS = 16;
template <typename T>
__global__ void _UnRolledTileKernel(
const size_t shape_rank,
const TArray<fast_divmod> fdm_input_shape,
const TArray<int64_t> input_strides,
const T* input_data,
const TArray<fast_divmod> fdm_output_strides,
T* output_data,
const CUDA_LONG N) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
CUDA_LONG input_index = 0;
CUDA_LONG output_index = id;
#pragma unroll
for (int dim = 0; dim < MAX_DIMS; ++dim) {
if (dim == shape_rank) {
break;
}
int out_coord, r;
fdm_output_strides[dim].divmod(output_index, out_coord, r);
output_index = r;
int in_coord = fdm_input_shape[dim].mod(out_coord);
input_index += input_strides[dim] * in_coord;
}
output_data[id] = input_data[input_index];
}
template <typename T>
__global__ void _TileKernel(
const size_t shape_rank,
@ -23,8 +52,7 @@ __global__ void _TileKernel(
int out_coord, r;
fdm_output_strides[dim].divmod(output_index, out_coord, r);
output_index = r;
int q, in_coord;
fdm_input_shape[dim].divmod(out_coord, q, in_coord);
int in_coord = fdm_input_shape[dim].mod(out_coord);
input_index += input_strides[dim] * in_coord;
}
output_data[id] = input_data[input_index];
@ -41,9 +69,15 @@ void TileImpl(
T* output_data,
const size_t N) {
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
_TileKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
shape_rank, fdm_input_shape, input_stride, input_data,
fdm_output_strides, output_data, (CUDA_LONG)N);
if (shape_rank > MAX_DIMS) {
_TileKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
shape_rank, fdm_input_shape, input_stride, input_data,
fdm_output_strides, output_data, (CUDA_LONG)N);
} else {
_UnRolledTileKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
shape_rank, fdm_input_shape, input_stride, input_data,
fdm_output_strides, output_data, (CUDA_LONG)N);
}
}
template <typename T>

View file

@ -44,6 +44,63 @@ void RunTestWrapper() {
// Tile3D
RunTest<T>({111, 112, 113, 122, 123, 124}, {2, 1, 3}, {1, 2, 1}, {3}, {111, 112, 113, 111, 112, 113, 122, 123, 124, 122, 123, 124}, {2, 2, 3});
// Tile4D
RunTest<T>(
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, // input
{1, 2, 3, 4}, // input dims
{2, 1, 2, 1}, // repeat
{4}, // repeat dims
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, // output
{2, 2, 6, 4} // output dims
);
// Tile5D
RunTest<T>(
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71}, // input
{2, 3, 2, 3, 2}, // input dims
{2, 1, 2, 1, 2}, // repeat
{5}, // repeat dims
{0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9,
8, 9, 10, 11, 10, 11, 0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5,
6, 7, 6, 7, 8, 9, 8, 9, 10, 11, 10, 11, 12, 13, 12, 13, 14, 15,
14, 15, 16, 17, 16, 17, 18, 19, 18, 19, 20, 21, 20, 21, 22, 23, 22, 23,
12, 13, 12, 13, 14, 15, 14, 15, 16, 17, 16, 17, 18, 19, 18, 19, 20, 21,
20, 21, 22, 23, 22, 23, 24, 25, 24, 25, 26, 27, 26, 27, 28, 29, 28, 29,
30, 31, 30, 31, 32, 33, 32, 33, 34, 35, 34, 35, 24, 25, 24, 25, 26, 27,
26, 27, 28, 29, 28, 29, 30, 31, 30, 31, 32, 33, 32, 33, 34, 35, 34, 35,
36, 37, 36, 37, 38, 39, 38, 39, 40, 41, 40, 41, 42, 43, 42, 43, 44, 45,
44, 45, 46, 47, 46, 47, 36, 37, 36, 37, 38, 39, 38, 39, 40, 41, 40, 41,
42, 43, 42, 43, 44, 45, 44, 45, 46, 47, 46, 47, 48, 49, 48, 49, 50, 51,
50, 51, 52, 53, 52, 53, 54, 55, 54, 55, 56, 57, 56, 57, 58, 59, 58, 59,
48, 49, 48, 49, 50, 51, 50, 51, 52, 53, 52, 53, 54, 55, 54, 55, 56, 57,
56, 57, 58, 59, 58, 59, 60, 61, 60, 61, 62, 63, 62, 63, 64, 65, 64, 65,
66, 67, 66, 67, 68, 69, 68, 69, 70, 71, 70, 71, 60, 61, 60, 61, 62, 63,
62, 63, 64, 65, 64, 65, 66, 67, 66, 67, 68, 69, 68, 69, 70, 71, 70, 71,
0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9,
8, 9, 10, 11, 10, 11, 0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5,
6, 7, 6, 7, 8, 9, 8, 9, 10, 11, 10, 11, 12, 13, 12, 13, 14, 15,
14, 15, 16, 17, 16, 17, 18, 19, 18, 19, 20, 21, 20, 21, 22, 23, 22, 23,
12, 13, 12, 13, 14, 15, 14, 15, 16, 17, 16, 17, 18, 19, 18, 19, 20, 21,
20, 21, 22, 23, 22, 23, 24, 25, 24, 25, 26, 27, 26, 27, 28, 29, 28, 29,
30, 31, 30, 31, 32, 33, 32, 33, 34, 35, 34, 35, 24, 25, 24, 25, 26, 27,
26, 27, 28, 29, 28, 29, 30, 31, 30, 31, 32, 33, 32, 33, 34, 35, 34, 35,
36, 37, 36, 37, 38, 39, 38, 39, 40, 41, 40, 41, 42, 43, 42, 43, 44, 45,
44, 45, 46, 47, 46, 47, 36, 37, 36, 37, 38, 39, 38, 39, 40, 41, 40, 41,
42, 43, 42, 43, 44, 45, 44, 45, 46, 47, 46, 47, 48, 49, 48, 49, 50, 51,
50, 51, 52, 53, 52, 53, 54, 55, 54, 55, 56, 57, 56, 57, 58, 59, 58, 59,
48, 49, 48, 49, 50, 51, 50, 51, 52, 53, 52, 53, 54, 55, 54, 55, 56, 57,
56, 57, 58, 59, 58, 59, 60, 61, 60, 61, 62, 63, 62, 63, 64, 65, 64, 65,
66, 67, 66, 67, 68, 69, 68, 69, 70, 71, 70, 71, 60, 61, 60, 61, 62, 63,
62, 63, 64, 65, 64, 65, 66, 67, 66, 67, 68, 69, 68, 69, 70, 71, 70, 71}, // output
{4, 3, 4, 3, 4} // output dims
);
// Tile1DWithOneRepeats
RunTest<T>({111, 112, 113, 122, 123, 124}, {2, 1, 3}, {1, 1, 1}, {3}, {111, 112, 113, 122, 123, 124}, {2, 1, 3});