mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
Optimize _TileKernel for non-memcpy case (#9648)
* Optimize _TileKernel for the non-memcpy case * Fall back to _TileKernel when shape_rank > MAX_DIMS
This commit is contained in:
parent
a355bcbd73
commit
ee167bd078
2 changed files with 96 additions and 5 deletions
|
|
@ -7,6 +7,35 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
// Upper bound on the tensor rank handled by _UnRolledTileKernel: it is the fixed
// trip count of that kernel's unrolled per-dimension loop, and TileImpl launches
// the generic _TileKernel instead when shape_rank exceeds it.
constexpr int MAX_DIMS = 16;
|
||||
|
||||
// Tile kernel variant whose per-dimension loop has the compile-time bound
// MAX_DIMS, so that `#pragma unroll` can be applied to it. The loop exits early
// once `shape_rank` dimensions have been processed.
//
// For the output element at flat index `id`, the kernel peels off one output
// coordinate per dimension via `fdm_output_strides`, wraps it into the input
// extent via `fdm_input_shape`, and accumulates the matching input offset from
// `input_strides`. The element at that offset is then copied to `output_data[id]`.
//
// Parameters:
//   shape_rank          number of dimensions to process (loop stops at this count)
//   fdm_input_shape     per-dimension divisors used to wrap an output coordinate
//   input_strides       per-dimension multipliers used to build the input offset
//   input_data          source buffer
//   fdm_output_strides  per-dimension divisors used to split the flat output index
//   output_data         destination buffer
//   N                   element count handed to CALCULATE_ELEMENTWISE_INDEX_OR_EXIT
//
// NOTE(review): `id` is declared by CALCULATE_ELEMENTWISE_INDEX_OR_EXIT, which is
// defined outside this view; presumably it is this thread's flat index and the
// macro returns early when id >= N. Confirm against the macro definition.
template <typename T>
__global__ void _UnRolledTileKernel(
    const size_t shape_rank,
    const TArray<fast_divmod> fdm_input_shape,
    const TArray<int64_t> input_strides,
    const T* input_data,
    const TArray<fast_divmod> fdm_output_strides,
    T* output_data,
    const CUDA_LONG N) {
  CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);

  CUDA_LONG src_offset = 0;   // accumulated offset into input_data
  CUDA_LONG remaining = id;   // part of the output index not yet decomposed

#pragma unroll
  for (int axis = 0; axis < MAX_DIMS; ++axis) {
    // Only the first `shape_rank` entries of the arrays are meaningful here.
    if (axis == shape_rank) break;

    // Split off this axis' output coordinate; keep the remainder for later axes.
    int coord_out, rest;
    fdm_output_strides[axis].divmod(remaining, coord_out, rest);
    remaining = rest;

    // Wrap the output coordinate back into the input extent of this axis.
    const int coord_in = fdm_input_shape[axis].mod(coord_out);
    src_offset += input_strides[axis] * coord_in;
  }

  output_data[id] = input_data[src_offset];
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void _TileKernel(
|
||||
const size_t shape_rank,
|
||||
|
|
@ -23,8 +52,7 @@ __global__ void _TileKernel(
|
|||
int out_coord, r;
|
||||
fdm_output_strides[dim].divmod(output_index, out_coord, r);
|
||||
output_index = r;
|
||||
int q, in_coord;
|
||||
fdm_input_shape[dim].divmod(out_coord, q, in_coord);
|
||||
int in_coord = fdm_input_shape[dim].mod(out_coord);
|
||||
input_index += input_strides[dim] * in_coord;
|
||||
}
|
||||
output_data[id] = input_data[input_index];
|
||||
|
|
@ -41,9 +69,15 @@ void TileImpl(
|
|||
T* output_data,
|
||||
const size_t N) {
|
||||
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
|
||||
_TileKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
shape_rank, fdm_input_shape, input_stride, input_data,
|
||||
fdm_output_strides, output_data, (CUDA_LONG)N);
|
||||
if (shape_rank > MAX_DIMS) {
|
||||
_TileKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
shape_rank, fdm_input_shape, input_stride, input_data,
|
||||
fdm_output_strides, output_data, (CUDA_LONG)N);
|
||||
} else {
|
||||
_UnRolledTileKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
shape_rank, fdm_input_shape, input_stride, input_data,
|
||||
fdm_output_strides, output_data, (CUDA_LONG)N);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
@ -44,6 +44,63 @@ void RunTestWrapper() {
|
|||
// Tile3D
|
||||
RunTest<T>({111, 112, 113, 122, 123, 124}, {2, 1, 3}, {1, 2, 1}, {3}, {111, 112, 113, 111, 112, 113, 122, 123, 124, 122, 123, 124}, {2, 2, 3});
|
||||
|
||||
// Tile4D
|
||||
RunTest<T>(
|
||||
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, // input
|
||||
{1, 2, 3, 4}, // input dims
|
||||
{2, 1, 2, 1}, // repeat
|
||||
{4}, // repeat dims
|
||||
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, // output
|
||||
{2, 2, 6, 4} // output dims
|
||||
);
|
||||
|
||||
// Tile5D
|
||||
RunTest<T>(
|
||||
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
|
||||
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
|
||||
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
|
||||
54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71}, // input
|
||||
{2, 3, 2, 3, 2}, // input dims
|
||||
{2, 1, 2, 1, 2}, // repeat
|
||||
{5}, // repeat dims
|
||||
{0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9,
|
||||
8, 9, 10, 11, 10, 11, 0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5,
|
||||
6, 7, 6, 7, 8, 9, 8, 9, 10, 11, 10, 11, 12, 13, 12, 13, 14, 15,
|
||||
14, 15, 16, 17, 16, 17, 18, 19, 18, 19, 20, 21, 20, 21, 22, 23, 22, 23,
|
||||
12, 13, 12, 13, 14, 15, 14, 15, 16, 17, 16, 17, 18, 19, 18, 19, 20, 21,
|
||||
20, 21, 22, 23, 22, 23, 24, 25, 24, 25, 26, 27, 26, 27, 28, 29, 28, 29,
|
||||
30, 31, 30, 31, 32, 33, 32, 33, 34, 35, 34, 35, 24, 25, 24, 25, 26, 27,
|
||||
26, 27, 28, 29, 28, 29, 30, 31, 30, 31, 32, 33, 32, 33, 34, 35, 34, 35,
|
||||
36, 37, 36, 37, 38, 39, 38, 39, 40, 41, 40, 41, 42, 43, 42, 43, 44, 45,
|
||||
44, 45, 46, 47, 46, 47, 36, 37, 36, 37, 38, 39, 38, 39, 40, 41, 40, 41,
|
||||
42, 43, 42, 43, 44, 45, 44, 45, 46, 47, 46, 47, 48, 49, 48, 49, 50, 51,
|
||||
50, 51, 52, 53, 52, 53, 54, 55, 54, 55, 56, 57, 56, 57, 58, 59, 58, 59,
|
||||
48, 49, 48, 49, 50, 51, 50, 51, 52, 53, 52, 53, 54, 55, 54, 55, 56, 57,
|
||||
56, 57, 58, 59, 58, 59, 60, 61, 60, 61, 62, 63, 62, 63, 64, 65, 64, 65,
|
||||
66, 67, 66, 67, 68, 69, 68, 69, 70, 71, 70, 71, 60, 61, 60, 61, 62, 63,
|
||||
62, 63, 64, 65, 64, 65, 66, 67, 66, 67, 68, 69, 68, 69, 70, 71, 70, 71,
|
||||
0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9,
|
||||
8, 9, 10, 11, 10, 11, 0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5,
|
||||
6, 7, 6, 7, 8, 9, 8, 9, 10, 11, 10, 11, 12, 13, 12, 13, 14, 15,
|
||||
14, 15, 16, 17, 16, 17, 18, 19, 18, 19, 20, 21, 20, 21, 22, 23, 22, 23,
|
||||
12, 13, 12, 13, 14, 15, 14, 15, 16, 17, 16, 17, 18, 19, 18, 19, 20, 21,
|
||||
20, 21, 22, 23, 22, 23, 24, 25, 24, 25, 26, 27, 26, 27, 28, 29, 28, 29,
|
||||
30, 31, 30, 31, 32, 33, 32, 33, 34, 35, 34, 35, 24, 25, 24, 25, 26, 27,
|
||||
26, 27, 28, 29, 28, 29, 30, 31, 30, 31, 32, 33, 32, 33, 34, 35, 34, 35,
|
||||
36, 37, 36, 37, 38, 39, 38, 39, 40, 41, 40, 41, 42, 43, 42, 43, 44, 45,
|
||||
44, 45, 46, 47, 46, 47, 36, 37, 36, 37, 38, 39, 38, 39, 40, 41, 40, 41,
|
||||
42, 43, 42, 43, 44, 45, 44, 45, 46, 47, 46, 47, 48, 49, 48, 49, 50, 51,
|
||||
50, 51, 52, 53, 52, 53, 54, 55, 54, 55, 56, 57, 56, 57, 58, 59, 58, 59,
|
||||
48, 49, 48, 49, 50, 51, 50, 51, 52, 53, 52, 53, 54, 55, 54, 55, 56, 57,
|
||||
56, 57, 58, 59, 58, 59, 60, 61, 60, 61, 62, 63, 62, 63, 64, 65, 64, 65,
|
||||
66, 67, 66, 67, 68, 69, 68, 69, 70, 71, 70, 71, 60, 61, 60, 61, 62, 63,
|
||||
62, 63, 64, 65, 64, 65, 66, 67, 66, 67, 68, 69, 68, 69, 70, 71, 70, 71}, // output
|
||||
{4, 3, 4, 3, 4} // output dims
|
||||
);
|
||||
|
||||
// Tile1DWithOneRepeats
|
||||
RunTest<T>({111, 112, 113, 122, 123, 124}, {2, 1, 3}, {1, 1, 1}, {3}, {111, 112, 113, 122, 123, 124}, {2, 1, 3});
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue