diff --git a/onnxruntime/core/providers/cuda/tensor/tile_impl.cu b/onnxruntime/core/providers/cuda/tensor/tile_impl.cu index d5b6cc931e..adbf095c55 100644 --- a/onnxruntime/core/providers/cuda/tensor/tile_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/tile_impl.cu @@ -7,6 +7,35 @@ namespace onnxruntime { namespace cuda { +constexpr int MAX_DIMS = 16; + +template +__global__ void _UnRolledTileKernel( + const size_t shape_rank, + const TArray fdm_input_shape, + const TArray input_strides, + const T* input_data, + const TArray fdm_output_strides, + T* output_data, + const CUDA_LONG N) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + CUDA_LONG input_index = 0; + CUDA_LONG output_index = id; + + #pragma unroll + for (int dim = 0; dim < MAX_DIMS; ++dim) { + if (dim == shape_rank) { + break; + } + int out_coord, r; + fdm_output_strides[dim].divmod(output_index, out_coord, r); + output_index = r; + int in_coord = fdm_input_shape[dim].mod(out_coord); + input_index += input_strides[dim] * in_coord; + } + output_data[id] = input_data[input_index]; +} + template __global__ void _TileKernel( const size_t shape_rank, @@ -23,8 +52,7 @@ __global__ void _TileKernel( int out_coord, r; fdm_output_strides[dim].divmod(output_index, out_coord, r); output_index = r; - int q, in_coord; - fdm_input_shape[dim].divmod(out_coord, q, in_coord); + int in_coord = fdm_input_shape[dim].mod(out_coord); input_index += input_strides[dim] * in_coord; } output_data[id] = input_data[input_index]; @@ -41,9 +69,15 @@ void TileImpl( T* output_data, const size_t N) { int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); - _TileKernel<<>>( - shape_rank, fdm_input_shape, input_stride, input_data, - fdm_output_strides, output_data, (CUDA_LONG)N); + if (shape_rank > MAX_DIMS) { + _TileKernel<<>>( + shape_rank, fdm_input_shape, input_stride, input_data, + fdm_output_strides, output_data, (CUDA_LONG)N); + } else { + _UnRolledTileKernel<<>>( + shape_rank, fdm_input_shape, input_stride, input_data, + fdm_output_strides, output_data, (CUDA_LONG)N); + } } template diff --git a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc index e7636fbba1..f992992c9d 100644 --- a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc @@ -44,6 +44,63 @@ void RunTestWrapper() { // Tile3D RunTest({111, 112, 113, 122, 123, 124}, {2, 1, 3}, {1, 2, 1}, {3}, {111, 112, 113, 111, 112, 113, 122, 123, 124, 122, 123, 124}, {2, 2, 3}); + // Tile4D + RunTest( + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, // input + {1, 2, 3, 4}, // input dims + {2, 1, 2, 1}, // repeat + {4}, // repeat dims + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, // output + {2, 2, 6, 4} // output dims + ); + + // Tile5D + RunTest( + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, + 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71}, // input + {2, 3, 2, 3, 2}, // input dims + {2, 1, 2, 1, 2}, // repeat + {5}, // repeat dims + {0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9, + 8, 9, 10, 11, 10, 11, 0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5, + 6, 7, 6, 7, 8, 9, 8, 9, 10, 11, 10, 11, 12, 13, 12, 13, 14, 15, + 14, 15, 16, 17, 16, 17, 18, 19, 18, 19, 20, 21, 20, 21, 22, 23, 22, 23, + 12, 13, 12, 13, 14, 15, 14, 15, 16, 17, 16, 17, 18, 19, 18, 19, 20, 21, + 20, 21, 22, 23, 22, 23, 24, 25, 24, 25, 26, 27, 26, 27, 28, 29, 28, 29, + 30, 31, 30, 31, 32, 33, 32, 33, 34, 35, 34, 35, 24, 25, 24, 25, 26, 27, + 26, 27, 28, 29, 28, 29, 30, 31, 30, 31, 32, 33, 32, 33, 34, 35, 34, 35, + 36, 37, 36, 37, 38, 39, 38, 39, 40, 41, 40, 41, 42, 43, 42, 43, 44, 45, + 44, 45, 46, 47, 46, 47, 36, 37, 36, 37, 38, 39, 38, 39, 40, 41, 40, 41, + 42, 43, 42, 43, 44, 45, 44, 45, 46, 47, 46, 47, 48, 49, 48, 49, 50, 51, + 50, 51, 52, 53, 52, 53, 54, 55, 54, 55, 56, 57, 56, 57, 58, 59, 58, 59, + 48, 49, 48, 49, 50, 51, 50, 51, 52, 53, 52, 53, 54, 55, 54, 55, 56, 57, + 56, 57, 58, 59, 58, 59, 60, 61, 60, 61, 62, 63, 62, 63, 64, 65, 64, 65, + 66, 67, 66, 67, 68, 69, 68, 69, 70, 71, 70, 71, 60, 61, 60, 61, 62, 63, + 62, 63, 64, 65, 64, 65, 66, 67, 66, 67, 68, 69, 68, 69, 70, 71, 70, 71, + 0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9, + 8, 9, 10, 11, 10, 11, 0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5, + 6, 7, 6, 7, 8, 9, 8, 9, 10, 11, 10, 11, 12, 13, 12, 13, 14, 15, + 14, 15, 16, 17, 16, 17, 18, 19, 18, 19, 20, 21, 20, 21, 22, 23, 22, 23, + 12, 13, 12, 13, 14, 15, 14, 15, 16, 17, 16, 17, 18, 19, 18, 19, 20, 21, + 20, 21, 22, 23, 22, 23, 24, 25, 24, 25, 26, 27, 26, 27, 28, 29, 28, 29, + 30, 31, 30, 31, 32, 33, 32, 33, 34, 35, 34, 35, 24, 25, 24, 25, 26, 27, + 26, 27, 28, 29, 28, 29, 30, 31, 30, 31, 32, 33, 32, 33, 34, 35, 34, 35, + 36, 37, 36, 37, 38, 39, 38, 39, 40, 41, 40, 41, 42, 43, 42, 43, 44, 45, + 44, 45, 46, 47, 46, 47, 36, 37, 36, 37, 38, 39, 38, 39, 40, 41, 40, 41, + 42, 43, 42, 43, 44, 45, 44, 45, 46, 47, 46, 47, 48, 49, 48, 49, 50, 51, + 50, 51, 52, 53, 52, 53, 54, 55, 54, 55, 56, 57, 56, 57, 58, 59, 58, 59, + 48, 49, 48, 49, 50, 51, 50, 51, 52, 53, 52, 53, 54, 55, 54, 55, 56, 57, + 56, 57, 58, 59, 58, 59, 60, 61, 60, 61, 62, 63, 62, 63, 64, 65, 64, 65, + 66, 67, 66, 67, 68, 69, 68, 69, 70, 71, 70, 71, 60, 61, 60, 61, 62, 63, + 62, 63, 64, 65, 64, 65, 66, 67, 66, 67, 68, 69, 68, 69, 70, 71, 70, 71}, // output + {4, 3, 4, 3, 4} // output dims + ); + // Tile1DWithOneRepeats RunTest({111, 112, 113, 122, 123, 124}, {2, 1, 3}, {1, 1, 1}, {3}, {111, 112, 113, 122, 123, 124}, {2, 1, 3});