mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Transpose kernel fix for illegal memory access error (#5294)
* transpose fix
* minor update per comments

Co-authored-by: Ethan Tao <ettao@microsoft.com>
This commit is contained in:
parent
1a04b8f8b7
commit
b18a8bc74f
1 changed file with 7 additions and 2 deletions
|
|
@@ -22,14 +22,19 @@
|
|||
// kernel(s) for half functions with no library support
|
||||
namespace {
|
||||
|
||||
// TODO - refactor the function with similar logic in Transpose3DKernel using 16x16 Tile
|
||||
__global__ void transposeNoOverlap(half* odata, const half* idata, const int m, const int n) {
|
||||
__shared__ half tile[TRANS_TILE_DIM][TRANS_TILE_DIM + 1];
|
||||
|
||||
int x = blockIdx.x * TRANS_TILE_DIM + threadIdx.x;
|
||||
int y = blockIdx.y * TRANS_TILE_DIM + threadIdx.y;
|
||||
|
||||
for (int j = 0; j < TRANS_TILE_DIM; j += BLOCK_ROWS)
|
||||
tile[threadIdx.y + j][threadIdx.x] = idata[(y + j) * m + x];
|
||||
if (x < m) {
|
||||
for (int j = 0; j < TRANS_TILE_DIM; j += BLOCK_ROWS) {
|
||||
if (j >= (n - y)) continue;
|
||||
tile[threadIdx.y + j][threadIdx.x] = idata[(y + j) * m + x];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue