diff --git a/onnxruntime/core/providers/rocm/math/softmax_impl.cu b/onnxruntime/core/providers/rocm/math/softmax_impl.cu
new file mode 100644
index 0000000000..94f8e4fc54
--- /dev/null
+++ b/onnxruntime/core/providers/rocm/math/softmax_impl.cu
@@ -0,0 +1,210 @@
+/**
+* Copyright (c) 2016-present, Facebook, Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+/* Modifications Copyright (c) Microsoft. */
+
+// The code below is mostly copied from Pytorch PersistentSoftmax.cuh
+#include "hip/hip_runtime.h"
+#include "core/providers/rocm/cu_inc/common.cuh"
+#include "core/providers/rocm/math/softmax_impl.cuh"
+#include "core/providers/rocm/math/softmax.h"
+
+#include <cstdint>
+#include <limits>
+
+namespace onnxruntime {
+namespace rocm {
+
+// The softmax_warp_* methods perform softmax forward and backward propagation on samples spanning the fast dimension.
+// Each sample contains element_count scalar elements. element_count can be any integer value <= 1024.
+// The template arguments have the following meaning:
+// One "WARP" works on one "BATCH". One "BATCH" contains "WARP_BATCH" samples.
+// WARP_BATCH is equal to 1 when element_count is large, and > 1 when element_count is small.
+// A "WARP" contains "GPU_WARP_SIZE" threads, these treads are guaranteed to belong to the same warp.
+// This is important because it means only __shfl_ instructions are required for reductions.
+// Note that this means WARP_SIZE must be a power of two and <= architecture warp size.
+// ROCM warp size is 32 for all existing GPU architecures, but there is no guarantee this will not change for future arch. +// is_log_softmax is a flag indicating whether SoftMax or LogSoftMax should be computed. +// The template can be instantiated with any floating point type for the type arguments input_t, output_t and acc_t. +// This allows SoftMax to be fused with a cast immediately following the SoftMax. +// For instance: +// input_t=half, acc_t=float, output_t=half => read half tensor, float accumulators, write half tensor. +// input_t=half, acc_t=float, output_t=float => read half tensor, float accumulators, write float tensor. +// input_t_float, acc_t=float, output_t=half => read float tensor, float accumulators, write half tensor. + +template +__global__ void softmax_warp_forward(output_t* dst, const input_t* src, int batch_size, int stride, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < GPU_WARP_SIZE) ? next_power_of_two : GPU_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + // constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int WARP_BATCH = 1; + + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. 
compute the index within the batch + int local_idx = threadIdx.x; + + src += first_batch * stride + local_idx; + dst += first_batch * stride + local_idx; + + // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop, + // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep + // the nested loops. + // This should have no impact on performance because the loops are unrolled anyway. + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + elements[i][it] = src[i * element_count + it * WARP_SIZE]; + } else { + elements[i][it] = -std::numeric_limits::infinity(); + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH]{0.0f}; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (is_log_softmax) { + sum[i] += expf((float)(elements[i][it] - max_value[i])); + } else { + elements[i][it] = expf((float)(elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + } + warp_reduce(sum); + +// store result +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + if (is_log_softmax) sum[i] = max_value[i] + logf((float)(sum[i])); +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index < element_count) { + if (is_log_softmax) { + dst[i * element_count + it * WARP_SIZE] = elements[i][it] - sum[i]; + } else { + dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i]; + } + } else { + break; + } + } + } +} + +template +void dispatch_softmax_forward(output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count) { + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < GPU_WARP_SIZE) ? next_power_of_two : GPU_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + // int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + int batches_per_warp = 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 256; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = (batch_count + batches_per_block - 1) / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward), dim3(blocks), dim3(threads), 0, 0, dst, src, batch_count, softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward), dim3(blocks), dim3(threads), 0, 0, dst, src, batch_count, softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward), dim3(blocks), dim3(threads), 0, 0, dst, src, batch_count, softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward), dim3(blocks), dim3(threads), 0, 0, dst, src, batch_count, softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward), dim3(blocks), dim3(threads), 0, 0, dst, src, batch_count, softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward), dim3(blocks), dim3(threads), 0, 0, dst, src, batch_count, softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward), dim3(blocks), dim3(threads), 0, 0, dst, src, batch_count, softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward), dim3(blocks), dim3(threads), 0, 0, dst, src, batch_count, softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + 
hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward), dim3(blocks), dim3(threads), 0, 0, dst, src, batch_count, softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward), dim3(blocks), dim3(threads), 0, 0, dst, src, batch_count, softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward), dim3(blocks), dim3(threads), 0, 0, dst, src, batch_count, softmax_elements_stride, softmax_elements); + break; + default: + break; + } + } +} + +#define SPECIALIZED_SOFTMAX_IMPL(input_t, output_t, acc_t) \ +template void dispatch_softmax_forward(output_t * dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count); \ +template void dispatch_softmax_forward(output_t * dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count); + +SPECIALIZED_SOFTMAX_IMPL(float, float, float) +SPECIALIZED_SOFTMAX_IMPL(half, half, float) +SPECIALIZED_SOFTMAX_IMPL(double, double, double) + +} +} \ No newline at end of file diff --git a/orttraining/orttraining/training_ops/rocm/math/softmax_grad_impl.cu b/orttraining/orttraining/training_ops/rocm/math/softmax_grad_impl.cu new file mode 100644 index 0000000000..c9c60c0706 --- /dev/null +++ b/orttraining/orttraining/training_ops/rocm/math/softmax_grad_impl.cu @@ -0,0 +1,194 @@ +/** +* Copyright (c) 2016-present, Facebook, Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +/* Modifications Copyright (c) Microsoft. */ + +// The code below is mostly copied from Pytorch PersistentSoftmax.cuh +#include "hip/hip_runtime.h" +#include "orttraining/training_ops/rocm/math/softmax_grad.h" +#include "core/providers/rocm/cu_inc/common.cuh" +#include "core/providers/rocm/math/softmax_impl.cuh" + +namespace onnxruntime { +namespace rocm { + +template +__global__ void softmax_warp_backward(output_t* gradInput, const input_t* grad, const input_t* output, int batch_size, int stride, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < GPU_WARP_SIZE) ? next_power_of_two : GPU_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + // constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int WARP_BATCH = 1; + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x % WARP_SIZE; + + // the first element to process by the current thread + int thread_offset = first_batch * stride + local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop, + // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep + // the nested loops. 
+ // This should have no impact on performance because the loops are unrolled anyway. + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]; + acc_t grad_output_reg[WARP_BATCH][WARP_ITERATIONS]; + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + grad_reg[i][it] = grad[i * element_count + it * WARP_SIZE]; + output_reg[i][it] = output[i * element_count + it * WARP_SIZE]; + grad_output_reg[i][it] = grad_reg[i][it] * output_reg[i][it]; + } else { + grad_reg[i][it] = acc_t(0); + output_reg[i][it] = acc_t(0); + grad_output_reg[i][it] = acc_t(0); + } + } + } + + acc_t sum[WARP_BATCH]; + if (!is_log_softmax) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_output_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_output_reg[i][it]; + } + } + warp_reduce(sum); + } + else { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + } + +// store result +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + if (is_log_softmax) { + gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - expf(output_reg[i][it]) * sum[i]); + } else { + gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - sum[i] ) * output_reg[i][it]; + } + } + } + } +} + +template +void dispatch_softmax_backward(output_t* grad_input, const input_t* grad, const input_t* output, int 
softmax_elements, int softmax_elements_stride, int batch_count) { + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < GPU_WARP_SIZE) ? next_power_of_two : GPU_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. + // int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + int batches_per_warp = 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 256; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = (batch_count + batches_per_block - 1) / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_backward), dim3(blocks), dim3(threads), 0, 0, grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_backward), dim3(blocks), dim3(threads), 0, 0, grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_backward), dim3(blocks), dim3(threads), 0, 0, grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_backward), dim3(blocks), dim3(threads), 0, 0, grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_backward), dim3(blocks), dim3(threads), 0, 0, grad_input, grad, output, 
batch_count, softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_backward), dim3(blocks), dim3(threads), 0, 0, grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_backward), dim3(blocks), dim3(threads), 0, 0, grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_backward), dim3(blocks), dim3(threads), 0, 0, grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_backward), dim3(blocks), dim3(threads), 0, 0, grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_backward), dim3(blocks), dim3(threads), 0, 0, grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_backward), dim3(blocks), dim3(threads), 0, 0, grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); + break; + default: + break; + } + } +} + +#define SPECIALIZED_SOFTMAX_GRAD_IMPL(input_t, output_t, acc_t) \ +template void dispatch_softmax_backward(input_t * grad_input, const output_t* grad, const output_t* output, int softmax_elements, int softmax_elements_stride, int batch_count); \ +template void dispatch_softmax_backward(input_t * grad_input, const output_t* grad, const output_t* output, int softmax_elements, int softmax_elements_stride, int batch_count); + +SPECIALIZED_SOFTMAX_GRAD_IMPL(float, float, float) +SPECIALIZED_SOFTMAX_GRAD_IMPL(half, half, float) +SPECIALIZED_SOFTMAX_GRAD_IMPL(double, double, double) + +} +} \ No newline at end of file diff --git a/tools/ci_build/amd_hipify.py 
b/tools/ci_build/amd_hipify.py index 7c4fd7d7b1..d471019b1f 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -105,6 +105,7 @@ provider_excluded_files = [ 'math/matmul_integer.cu', 'math/matmul_integer.cuh', 'math/matmul_integer.h', + 'math/softmax_impl.cu', 'math/softmax.cc', 'math/topk.cc', 'math/topk.h', @@ -260,6 +261,7 @@ training_ops_excluded_files = [ 'math/scale.cc', 'math/scale.cu', 'math/scale.h', + 'math/softmax_grad_impl.cu', 'math/softmax_grad.cc', 'nn/batch_norm_grad.cc', 'nn/batch_norm_grad.h',