diff --git a/onnxruntime/core/providers/cuda/fpgeneric.cu b/onnxruntime/core/providers/cuda/fpgeneric.cu index 5c0df6332d..0d096148b5 100644 --- a/onnxruntime/core/providers/cuda/fpgeneric.cu +++ b/onnxruntime/core/providers/cuda/fpgeneric.cu @@ -48,67 +48,6 @@ __global__ void transposeNoOverlap(half* odata, const half* idata, const int m, odata[(y + j) * n + x] = tile[threadIdx.x][threadIdx.y + j]; } } -// set up curand state, need to move up layer to remove calling for each generate call -__global__ void setup_state(curandState* state, unsigned long long seed) { - curand_init(seed, 0, 0, state); -} - -__global__ void GenerateUniformHalf(curandState* state, half* result, int n) { - int id = blockIdx.x * blockDim.x + threadIdx.x; - if (id >= n) return; - - curandState localState = *state; - - float x; - skipahead(id, &localState); - x = curand_uniform(&localState); - - result[id] = x; - if (id == n - 1) *state = localState; -} - -__global__ void GenerateNormalHalf(curandState* state, half* result, int n, half mean, half stddev) { - int id = blockIdx.x * blockDim.x + threadIdx.x; - if (id >= n) return; - - curandState localState = *state; - - float x; - skipahead(id, &localState); - x = curand_normal(&localState); - - result[id] = (float)mean + (float)stddev * x; - if (id == n - 1) *state = localState; -} - -// kernels can convert matrix between half and float. speed currently not optimized, may need to add half2 -/* -__global__ void copyHalf2Float(float *odata, const half *idata, const int n) -{ - float tmp[COPY_TILE_DIM/COPY_BLOCK_DIM]; - - int x = blockIdx.x * COPY_TILE_DIM + threadIdx.x; - - for (int j = 0; j < COPY_TILE_DIM/COPY_BLOCK_DIM; j++) - tmp[j] = (float) idata[x + j*COPY_BLOCK_DIM]; - - for (int j = 0; j < COPY_TILE_DIM/COPY_BLOCK_DIM; j++) - if(x + j*COPY_BLOCK_DIM < n) odata[x + j*COPY_BLOCK_DIM] = tmp[j]; -} - -__global__ void copyFloat2Half(half *odata, const float *idata, const int n) -{ - float tmp[COPY_TILE_DIM/COPY_BLOCK_DIM]; - - int x = blockIdx.x * COPY_TILE_DIM + threadIdx.x; - - for (int j = 0; j < COPY_TILE_DIM/COPY_BLOCK_DIM; j++) - tmp[j] = idata[x + j*COPY_BLOCK_DIM]; - - for (int j = 0; j < COPY_TILE_DIM/COPY_BLOCK_DIM; j++) - if(x + j*COPY_BLOCK_DIM < n) odata[x + j*COPY_BLOCK_DIM] = tmp[j]; -} -*/ __global__ void CopyVectorHalf(const half* x, int incx, half* y, int incy, int n) { int id = blockIdx.x * blockDim.x + threadIdx.x; @@ -135,28 +74,4 @@ cublasStatus_t cublasCopyHelper(cublasHandle_t, int n, const half* x, int incx, dim3 dimBlock(COPY_BLOCK_DIM, 1, 1); CopyVectorHalf<<>>(x, incx, y, incy, n); return CUBLAS_STATUS_SUCCESS; -} - -curandStatus_t curandGenerateUniformHelper(curandGenerator_t, half* outputPtr, size_t num) { - curandState* devStates; - cudaMalloc((void**)&devStates, sizeof(curandState)); - setup_state<<<1, 1>>>(devStates, time(NULL)); // What does curandGenerateUniform actually doing? should also pass in state here - - dim3 dimGrid((unsigned int)(num + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM, 1, 1); - dim3 dimBlock(COPY_BLOCK_DIM, 1, 1); - GenerateUniformHalf<<>>(devStates, outputPtr, (int)num); - - return (curandStatus_t)0; -} - -curandStatus_t curandGenerateNormalHelper(curandGenerator_t, half* outputPtr, size_t n, half mean, half stddev) { - curandState* devStates; - cudaMalloc((void**)&devStates, sizeof(curandState)); - setup_state<<<1, 1>>>(devStates, time(NULL)); // What does curandGenerateUniform actually doing? should also pass in state here - - dim3 dimGrid((unsigned int)(n + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM, 1, 1); - dim3 dimBlock(COPY_BLOCK_DIM, 1, 1); - GenerateNormalHalf<<>>(devStates, outputPtr, (int)n, mean, stddev); - - return (curandStatus_t)0; -} +} \ No newline at end of file diff --git a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h index 91d480ceaf..4bef68b343 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h +++ b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h @@ -319,18 +319,6 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, #endif } -// axpy -inline cublasStatus_t cublasAxpyHelper(cublasHandle_t handle, int n, const float* alpha, const float* x, int incx, float* y, int incy) { - return cublasSaxpy(handle, n, alpha, x, incx, y, incy); -} -inline cublasStatus_t cublasAxpyHelper(cublasHandle_t handle, int n, const double* alpha, const double* x, int incx, double* y, int incy) { - return cublasDaxpy(handle, n, alpha, x, incx, y, incy); -} -inline cublasStatus_t cublasAxpyHelper(cublasHandle_t handle, int n, const half* alpha, const half* x, int incx, half* y, int incy) { - float tmp_alpha = onnxruntime::math::halfToFloat(*reinterpret_cast(alpha)); - return cublasAxpyEx(handle, n, (void*)&tmp_alpha, CUDA_R_32F, (void*)x, CUDA_R_16F, incx, (void*)y, CUDA_R_16F, incy, CUDA_R_32F); -} - // transpose using geam inline cublasStatus_t cublasTransposeHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, const float* alpha, const float* A, int lda, const float* beta, const float* B, int ldb, float* C, int ldc) { return cublasSgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); @@ -340,165 +328,6 @@ inline cublasStatus_t cublasTransposeHelper(cublasHandle_t handle, cublasOperati } cublasStatus_t cublasTransposeHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int); -// asum -inline cublasStatus_t cublasAsumHelper(cublasHandle_t handle, int n, const float* x, int incx, float* result) { - return cublasSasum(handle, n, x, incx, result); -} -inline cublasStatus_t cublasAsumHelper(cublasHandle_t handle, int n, const double* x, int incx, double* result) { - return cublasDasum(handle, n, x, incx, result); -} -inline cublasStatus_t cublasAsumHelper(cublasHandle_t, int n, const half* x, int incx, half* result) { - // pass in cudnn handle/descriptor to remove overhead? - cudnnHandle_t cudnnHandle; - cudnnTensorDescriptor_t srcTensorDesc, dstTensorDesc; - cudnnReduceTensorDescriptor_t reduceTensorDesc; - - cudnnCreate(&cudnnHandle); - cudnnCreateTensorDescriptor(&srcTensorDesc); - cudnnCreateTensorDescriptor(&dstTensorDesc); - cudnnCreateReduceTensorDescriptor(&reduceTensorDesc); - - cudnnSetTensor4dDescriptorEx(srcTensorDesc, CUDNN_DATA_HALF, 1, 1, 1, n, 1, 1, 1, incx); - cudnnSetTensor4dDescriptorEx(dstTensorDesc, CUDNN_DATA_HALF, 1, 1, 1, 1, 1, 1, 1, 1); - cudnnSetReduceTensorDescriptor(reduceTensorDesc, - CUDNN_REDUCE_TENSOR_NORM1, - CUDNN_DATA_FLOAT, - CUDNN_NOT_PROPAGATE_NAN, - CUDNN_REDUCE_TENSOR_NO_INDICES, - CUDNN_32BIT_INDICES); - - void* workspace = NULL; - size_t workspaceSizeInBytes = 0; - cudnnGetReductionWorkspaceSize(cudnnHandle, reduceTensorDesc, srcTensorDesc, dstTensorDesc, &workspaceSizeInBytes); - if (workspaceSizeInBytes > 0) cudaMalloc(&workspace, workspaceSizeInBytes); - - float alpha = 1.0f; - float beta = 0.0f; - - void* d_res; - cudaMalloc(&d_res, sizeof(half)); - - cudnnReduceTensor(cudnnHandle, - reduceTensorDesc, - NULL, - 0, - workspace, - workspaceSizeInBytes, - &alpha, - srcTensorDesc, - (void*)x, - &beta, - dstTensorDesc, - d_res); - - cudaMemcpy((void*)result, d_res, sizeof(half), cudaMemcpyDeviceToHost); - - cudnnDestroyReduceTensorDescriptor(reduceTensorDesc); - cudnnDestroyTensorDescriptor(srcTensorDesc); - cudnnDestroyTensorDescriptor(dstTensorDesc); - cudnnDestroy(cudnnHandle); - cudaFree(d_res); - cudaFree(workspace); - - return (cublasStatus_t)0; -} - -// amax -inline cublasStatus_t cublasAmaxHelper(cublasHandle_t handle, int n, const float* x, int incx, int* result) { - return cublasIsamax(handle, n, x, incx, result); -} -inline cublasStatus_t cublasAmaxHelper(cublasHandle_t handle, int n, const double* x, int incx, int* result) { - return cublasIdamax(handle, n, x, incx, result); -} -inline cublasStatus_t cublasAmaxHelper(cublasHandle_t, int n, const half* x, int incx, int* result) { - unsigned int h_result_uint = 0; - // pass in cudnn handle/descriptor to remove overhead? - cudnnHandle_t cudnnHandle; - cudnnTensorDescriptor_t srcTensorDesc, dstTensorDesc; - cudnnReduceTensorDescriptor_t reduceTensorDesc; - - cudnnCreate(&cudnnHandle); - cudnnCreateTensorDescriptor(&srcTensorDesc); - cudnnCreateTensorDescriptor(&dstTensorDesc); - cudnnCreateReduceTensorDescriptor(&reduceTensorDesc); - - cudnnSetTensor4dDescriptorEx(srcTensorDesc, CUDNN_DATA_HALF, 1, 1, 1, n, 1, 1, 1, incx); - cudnnSetTensor4dDescriptorEx(dstTensorDesc, CUDNN_DATA_HALF, 1, 1, 1, 1, 1, 1, 1, 1); - cudnnSetReduceTensorDescriptor(reduceTensorDesc, - CUDNN_REDUCE_TENSOR_AMAX, - CUDNN_DATA_FLOAT, - CUDNN_NOT_PROPAGATE_NAN, - CUDNN_REDUCE_TENSOR_FLATTENED_INDICES, - CUDNN_32BIT_INDICES); - - void* workspace = NULL; - size_t workspaceSizeInBytes = 0; - cudnnGetReductionWorkspaceSize(cudnnHandle, reduceTensorDesc, srcTensorDesc, dstTensorDesc, &workspaceSizeInBytes); - if (workspaceSizeInBytes > 0) cudaMalloc(&workspace, workspaceSizeInBytes); - - float alpha = 1.0f; - float beta = 0.0f; - void* d_max; - cudaMalloc(&d_max, sizeof(half)); - void* d_result_uint; - cudaMalloc(&d_result_uint, sizeof(unsigned int)); - - cudnnReduceTensor(cudnnHandle, - reduceTensorDesc, - d_result_uint, - sizeof(unsigned int), - workspace, - workspaceSizeInBytes, - &alpha, - srcTensorDesc, - (void*)x, - &beta, - dstTensorDesc, - d_max); - - cudaMemcpy(&h_result_uint, d_result_uint, sizeof(unsigned int), cudaMemcpyDeviceToHost); - - cudnnDestroyReduceTensorDescriptor(reduceTensorDesc); - cudnnDestroyTensorDescriptor(srcTensorDesc); - cudnnDestroyTensorDescriptor(dstTensorDesc); - cudnnDestroy(cudnnHandle); - cudaFree(workspace); - cudaFree(d_max); - cudaFree(d_result_uint); - - *result = (int)h_result_uint; - return (cublasStatus_t)0; -} - -// scal -inline cublasStatus_t cublasScalHelper(cublasHandle_t handle, int n, const float* alpha, float* x, int incx) { - return cublasSscal(handle, n, alpha, x, incx); -} -inline cublasStatus_t cublasScalHelper(cublasHandle_t handle, int n, const double* alpha, double* x, int incx) { - return cublasDscal(handle, n, alpha, x, incx); -} -inline cublasStatus_t cublasScalHelper(cublasHandle_t handle, int n, const half* alpha, half* x, int incx) { - float tmp_alpha = onnxruntime::math::halfToFloat(*reinterpret_cast(alpha)); - return cublasScalEx(handle, n, (void*)&tmp_alpha, CUDA_R_32F, (void*)x, CUDA_R_16F, incx, CUDA_R_32F); -} -inline cublasStatus_t cublasScalHelper(cublasHandle_t, int, const char*, char*, int) { - ORT_NOT_IMPLEMENTED("Unsupported template argument(char) in cublas_scal"); -} -inline cublasStatus_t cublasScalHelper(cublasHandle_t, int, const short*, short*, int) { - ORT_NOT_IMPLEMENTED("Unsupported template argument(short) in cublas_scal"); -} - -// dot -inline cublasStatus_t cublasDotHelper(cublasHandle_t handle, int n, const float* x, int incx, const float* y, int incy, float* result) { - return cublasSdot(handle, n, x, incx, y, incy, result); -} -inline cublasStatus_t cublasDotHelper(cublasHandle_t handle, int n, const double* x, int incx, const double* y, int incy, double* result) { - return cublasDdot(handle, n, x, incx, y, incy, result); -} -inline cublasStatus_t cublasDotHelper(cublasHandle_t handle, int n, const half* x, int incx, const half* y, int incy, half* result) { - return cublasDotEx(handle, n, (void*)x, CUDA_R_16F, incx, (void*)y, CUDA_R_16F, incy, (void*)result, CUDA_R_16F, CUDA_R_32F); -} - // copy inline cublasStatus_t cublasCopyHelper(cublasHandle_t handle, int n, const float* x, int incx, float* y, int incy) { return cublasScopy(handle, n, x, incx, y, incy); @@ -508,35 +337,6 @@ inline cublasStatus_t cublasCopyHelper(cublasHandle_t handle, int n, const doubl } cublasStatus_t cublasCopyHelper(cublasHandle_t handle, int n, const half* x, int incx, half* y, int incy); -// curand -inline curandStatus_t curandGenerateUniformHelper(curandGenerator_t generator, float* outputPtr, size_t num) { - return curandGenerateUniform(generator, outputPtr, num); -} -inline curandStatus_t curandGenerateUniformHelper(curandGenerator_t generator, double* outputPtr, size_t num) { - return curandGenerateUniformDouble(generator, outputPtr, num); -} -curandStatus_t curandGenerateUniformHelper(curandGenerator_t, half* outputPtr, size_t num); -inline curandStatus_t curandGenerateUniformHelper(curandGenerator_t, char*, size_t) { - ORT_NOT_IMPLEMENTED("Unsupported template argument(char) in GPUSparseMatrix"); -} -inline curandStatus_t curandGenerateUniformHelper(curandGenerator_t, short*, size_t) { - ORT_NOT_IMPLEMENTED("Unsupported template argument(short) in GPUSparseMatrix"); -} -inline curandStatus_t curandGenerateNormalHelper(curandGenerator_t generator, float* outputPtr, size_t n, float mean, float stddev) { - return curandGenerateNormal(generator, outputPtr, n, mean, stddev); -} -inline curandStatus_t curandGenerateNormalHelper(curandGenerator_t generator, double* outputPtr, size_t n, double mean, double stddev) { - return curandGenerateNormalDouble(generator, outputPtr, n, mean, stddev); -} -curandStatus_t curandGenerateNormalHelper(curandGenerator_t, half* outputPtr, size_t n, half mean, half stddev); - -inline curandStatus_t curandGenerateNormalHelper(curandGenerator_t, char*, size_t, char, char) { - ORT_NOT_IMPLEMENTED("Unsupported template argument(char) in GPUSparseMatrix"); -} - -inline curandStatus_t curandGenerateNormalHelper(curandGenerator_t, short*, size_t, short, short) { - ORT_NOT_IMPLEMENTED("Unsupported template argument(short) in GPUSparseMatrix"); -} diff --git a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.h b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.h index a974382268..4722ab6126 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.h +++ b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.h @@ -28,19 +28,6 @@ namespace onnxruntime { namespace cuda { -template -void HostApplyLayerNorm( - const cudaDeviceProp& prop, - T* output, - U* mean, - U* invvar, - const T* input, - int64_t n1, - int64_t n2, - double epsilon, - const T* gamma, - const T* beta); - template void HostLayerNormGradient( const cudaDeviceProp& prop,