mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
improve perf for softmax (#6128)
* improve perf for both gathergrad and softmax * revert the change in gathergrad and will be done in another PR. * address comments from code review.
This commit is contained in:
parent
ea9cfa554a
commit
53307a5f2e
3 changed files with 406 additions and 0 deletions
210
onnxruntime/core/providers/rocm/math/softmax_impl.cu
Normal file
210
onnxruntime/core/providers/rocm/math/softmax_impl.cu
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/* Modifications Copyright (c) Microsoft. */
|
||||
|
||||
// The code below is mostly copied from Pytorch PersistentSoftmax.cuh
|
||||
#include "hip/hip_runtime.h"
|
||||
#include "core/providers/rocm/cu_inc/common.cuh"
|
||||
#include "core/providers/rocm/math/softmax_impl.cuh"
|
||||
#include "core/providers/rocm/math/softmax.h"
|
||||
|
||||
#include <limits>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
|
||||
// The softmax_warp_* methods perform softmax forward and backward propagation on samples spanning the fast dimension.
|
||||
// Each sample contains element_count scalar elements. element_count can be any integer value <= 1024.
|
||||
// The template arguments have the following meaning:
|
||||
// One "WARP" works on one "BATCH". One "BATCH" contains "WARP_BATCH" samples.
|
||||
// WARP_BATCH is equal to 1 when element_count is large, and > 1 when element_count is small.
|
||||
// A "WARP" contains "GPU_WARP_SIZE" threads, these threads are guaranteed to belong to the same warp.
|
||||
// This is important because it means only __shfl_ instructions are required for reductions.
|
||||
// Note that this means WARP_SIZE must be a power of two and <= architecture warp size.
|
||||
// Do not assume a warp size of 32: AMD GCN/CDNA GPUs execute 64-wide wavefronts, and this may differ on other architectures. Always use GPU_WARP_SIZE.
|
||||
// is_log_softmax is a flag indicating whether SoftMax or LogSoftMax should be computed.
|
||||
// The template can be instantiated with any floating point type for the type arguments input_t, output_t and acc_t.
|
||||
// This allows SoftMax to be fused with a cast immediately following the SoftMax.
|
||||
// For instance:
|
||||
// input_t=half, acc_t=float, output_t=half => read half tensor, float accumulators, write half tensor.
|
||||
// input_t=half, acc_t=float, output_t=float => read half tensor, float accumulators, write float tensor.
|
||||
// input_t=float, acc_t=float, output_t=half => read float tensor, float accumulators, write half tensor.
|
||||
|
||||
// Computes softmax (is_log_softmax == false) or log-softmax (is_log_softmax == true) over rows of
// `element_count` values, one row per group of WARP_SIZE threads.
//
// Launch layout expected (set up by dispatch_softmax_forward):
//   blockDim.x == WARP_SIZE   -> threadIdx.x is the lane index inside a row
//   blockDim.y                -> number of rows handled by one block
//   gridDim.x                 -> ceil(batch_size / (blockDim.y * WARP_BATCH))
//
// Parameters:
//   dst           output rows, written at the same offsets the input is read from
//   src           input rows
//   batch_size    number of rows
//   stride        distance (in elements) between the starts of consecutive rows
//   element_count number of valid elements per row; must be <= (1 << log2_elements)
//
// NOTE(review): exp/log are evaluated through expf/logf with an explicit cast to float, so the
// acc_t == double instantiation is only float-accurate in those steps - confirm this is intended.
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>
__global__ void softmax_warp_forward(output_t* dst, const input_t* src, int batch_size, int stride, int element_count) {
  // WARP_SIZE and WARP_BATCH must match warp_size and batches_per_warp computed in dispatch_softmax_forward.
  constexpr int next_power_of_two = 1 << log2_elements;
  constexpr int WARP_SIZE = (next_power_of_two < GPU_WARP_SIZE) ? next_power_of_two : GPU_WARP_SIZE;
  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
  // constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
  constexpr int WARP_BATCH = 1;

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  // For thread groups past the end of the input this is <= 0, which disables every load and store below.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH)
    local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the batch
  int local_idx = threadIdx.x;

  // The first element handled by this thread. The product is formed in 64-bit so that
  // first_batch * stride cannot wrap around for tensors holding more than 2^31 - 1 elements.
  const long long thread_offset = static_cast<long long>(first_batch) * stride + local_idx;
  src += thread_offset;
  dst += thread_offset;

  // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
  // but doing so would obfuscate the logic of the algorithm, so the nested loops are kept.
  // This should have no impact on performance because the loops are unrolled anyway.

  // Load data from global memory. Positions past the end of the row are filled with -infinity so
  // that they neither win the max reduction nor contribute to the sum (exp(-inf) == 0).
  acc_t elements[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      int element_index = local_idx + it * WARP_SIZE;
      if (element_index < batch_element_count) {
        elements[i][it] = src[i * element_count + it * WARP_SIZE];
      } else {
        elements[i][it] = -std::numeric_limits<acc_t>::infinity();
      }
    }
  }

  // compute max_value: per-thread maximum first, then a reduction across the WARP_SIZE lanes
  acc_t max_value[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    max_value[i] = elements[i][0];
#pragma unroll
    for (int it = 1; it < WARP_ITERATIONS; ++it) {
      max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
    }
  }
  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

  // Sum of exp(x - max). For plain softmax the exponentials overwrite `elements` because they are
  // reused as numerators; log-softmax keeps the raw inputs and only needs the sum.
  acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      if (is_log_softmax) {
        sum[i] += expf((float)(elements[i][it] - max_value[i]));
      } else {
        elements[i][it] = expf((float)(elements[i][it] - max_value[i]));
        sum[i] += elements[i][it];
      }
    }
  }
  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches)
      break;
    // log-softmax: x - (max + log(sum of exp(x - max)))
    if (is_log_softmax) sum[i] = max_value[i] + logf((float)(sum[i]));
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      int element_index = local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        if (is_log_softmax) {
          dst[i * element_count + it * WARP_SIZE] = elements[i][it] - sum[i];
        } else {
          dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
        }
      } else {
        // element_index only grows with `it`, so no later iteration can be in range
        break;
      }
    }
  }
}
|
||||
|
||||
// Emits one `case` of the dispatch switch below: launches the softmax_warp_forward instantiation
// for a row length padded to (1 << L2_ELEMENTS). Relies on the locals of dispatch_softmax_forward.
#define LAUNCH_SOFTMAX_WARP_FORWARD(L2_ELEMENTS) \
  case L2_ELEMENTS: \
    hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward<input_t, output_t, acc_t, L2_ELEMENTS, is_log_softmax>), \
                       dim3(blocks), dim3(threads), 0, 0, \
                       dst, src, batch_count, softmax_elements_stride, softmax_elements); \
    break;

// Host-side launcher for softmax_warp_forward.
//
// Parameters:
//   dst                     output buffer
//   src                     input buffer
//   softmax_elements        number of elements per row
//   softmax_elements_stride distance (in elements) between consecutive rows
//   batch_count             number of rows
//
// The kernel is launched with shared-memory size 0 on stream 0.
// Nothing is launched when softmax_elements == 0 or batch_count <= 0.
// NOTE(review): rows longer than 1024 elements (log2_elements > 10) fall into the `default` case
// and are silently skipped; callers presumably route those to a different implementation - confirm.
template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_softmax_forward(output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count) {
  // Empty rows need no work; an empty batch would produce a grid of 0 blocks, which is not a valid
  // launch configuration, so return before touching the device.
  if (softmax_elements == 0 || batch_count <= 0) {
    return;
  }

  const int log2_elements = log2_ceil(softmax_elements);
  const int next_power_of_two = 1 << log2_elements;

  // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
  const int warp_size = (next_power_of_two < GPU_WARP_SIZE) ? next_power_of_two : GPU_WARP_SIZE;

  // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
  // int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
  const int batches_per_warp = 1;

  // use 256 threads per block to maximize gpu utilization
  constexpr int threads_per_block = 256;

  const int warps_per_block = threads_per_block / warp_size;
  const int batches_per_block = warps_per_block * batches_per_warp;
  const int blocks = (batch_count + batches_per_block - 1) / batches_per_block;  // ceil-div
  // x: lanes cooperating on one row, y: rows per block
  dim3 threads(warp_size, warps_per_block, 1);

  // log2_elements is a runtime value but the kernel needs it as a template argument,
  // hence one explicit case per supported power of two.
  switch (log2_elements) {
    LAUNCH_SOFTMAX_WARP_FORWARD(0)   // 1
    LAUNCH_SOFTMAX_WARP_FORWARD(1)   // 2
    LAUNCH_SOFTMAX_WARP_FORWARD(2)   // 4
    LAUNCH_SOFTMAX_WARP_FORWARD(3)   // 8
    LAUNCH_SOFTMAX_WARP_FORWARD(4)   // 16
    LAUNCH_SOFTMAX_WARP_FORWARD(5)   // 32
    LAUNCH_SOFTMAX_WARP_FORWARD(6)   // 64
    LAUNCH_SOFTMAX_WARP_FORWARD(7)   // 128
    LAUNCH_SOFTMAX_WARP_FORWARD(8)   // 256
    LAUNCH_SOFTMAX_WARP_FORWARD(9)   // 512
    LAUNCH_SOFTMAX_WARP_FORWARD(10)  // 1024
    default:
      break;
  }
}

#undef LAUNCH_SOFTMAX_WARP_FORWARD
|
||||
|
||||
// Explicitly instantiates dispatch_softmax_forward for one (input, output, accumulator) type
// combination and one value of the is_log_softmax flag.
#define INSTANTIATE_DISPATCH_SOFTMAX_FORWARD(input_t, output_t, acc_t, is_log_softmax) \
  template void dispatch_softmax_forward<input_t, output_t, acc_t, is_log_softmax>( \
      output_t * dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count);

// Instantiates both the softmax (false) and log-softmax (true) variants for a type combination.
#define SPECIALIZED_SOFTMAX_IMPL(input_t, output_t, acc_t) \
  INSTANTIATE_DISPATCH_SOFTMAX_FORWARD(input_t, output_t, acc_t, false) \
  INSTANTIATE_DISPATCH_SOFTMAX_FORWARD(input_t, output_t, acc_t, true)

// Type combinations compiled into this translation unit; half tensors accumulate in float.
SPECIALIZED_SOFTMAX_IMPL(float, float, float)
SPECIALIZED_SOFTMAX_IMPL(half, half, float)
SPECIALIZED_SOFTMAX_IMPL(double, double, double)
|
||||
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,194 @@
|
|||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/* Modifications Copyright (c) Microsoft. */
|
||||
|
||||
// The code below is mostly copied from Pytorch PersistentSoftmax.cuh
|
||||
#include "hip/hip_runtime.h"
|
||||
#include "orttraining/training_ops/rocm/math/softmax_grad.h"
|
||||
#include "core/providers/rocm/cu_inc/common.cuh"
|
||||
#include "core/providers/rocm/math/softmax_impl.cuh"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
|
||||
// Computes the gradient of softmax (is_log_softmax == false) or log-softmax (is_log_softmax == true)
// with respect to its input, one row per group of WARP_SIZE threads.
//
//   softmax:     gradInput = (grad - sum(grad * output)) * output
//   log-softmax: gradInput = grad - exp(output) * sum(grad)
//
// Launch layout expected (set up by dispatch_softmax_backward):
//   blockDim.x == WARP_SIZE   -> lane index inside a row
//   blockDim.y                -> number of rows handled by one block
//   gridDim.x                 -> ceil(batch_size / (blockDim.y * WARP_BATCH))
//
// Parameters:
//   gradInput     result rows, written at the same offsets the inputs are read from
//   grad          incoming gradient with respect to the forward output
//   output        forward result (softmax or log-softmax values)
//   batch_size    number of rows
//   stride        distance (in elements) between the starts of consecutive rows
//   element_count number of valid elements per row; must be <= (1 << log2_elements)
//
// NOTE(review): the log-softmax path calls expf, so the acc_t == double instantiation is only
// float-accurate in that step - confirm this is intended.
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>
__global__ void softmax_warp_backward(output_t* gradInput, const input_t* grad, const input_t* output, int batch_size, int stride, int element_count) {
  // WARP_SIZE and WARP_BATCH must match warp_size and batches_per_warp computed in dispatch_softmax_backward.
  constexpr int next_power_of_two = 1 << log2_elements;
  constexpr int WARP_SIZE = (next_power_of_two < GPU_WARP_SIZE) ? next_power_of_two : GPU_WARP_SIZE;
  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
  // constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
  constexpr int WARP_BATCH = 1;
  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  // For thread groups past the end of the input this is <= 0, which disables every load and store below.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH)
    local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the batch
  int local_idx = threadIdx.x % WARP_SIZE;

  // The first element to process by the current thread. The product is formed in 64-bit so that
  // first_batch * stride cannot wrap around for tensors holding more than 2^31 - 1 elements.
  const long long thread_offset = static_cast<long long>(first_batch) * stride + local_idx;
  grad += thread_offset;
  output += thread_offset;
  gradInput += thread_offset;

  // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
  // but doing so would obfuscate the logic of the algorithm, so the nested loops are kept.
  // This should have no impact on performance because the loops are unrolled anyway.

  // Load data from global memory. Positions past the end of the row are zero-filled so they do
  // not contribute to either sum.
  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
  acc_t grad_output_reg[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      int element_index = local_idx + it * WARP_SIZE;
      if (element_index < batch_element_count) {
        grad_reg[i][it] = grad[i * element_count + it * WARP_SIZE];
        output_reg[i][it] = output[i * element_count + it * WARP_SIZE];
        grad_output_reg[i][it] = grad_reg[i][it] * output_reg[i][it];
      } else {
        grad_reg[i][it] = acc_t(0);
        output_reg[i][it] = acc_t(0);
        grad_output_reg[i][it] = acc_t(0);
      }
    }
  }

  // Row-wise reduction: sum(grad * output) for softmax, sum(grad) for log-softmax.
  acc_t sum[WARP_BATCH];
  if (!is_log_softmax) {
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      sum[i] = grad_output_reg[i][0];
#pragma unroll
      for (int it = 1; it < WARP_ITERATIONS; ++it) {
        sum[i] += grad_output_reg[i][it];
      }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
  } else {
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      sum[i] = grad_reg[i][0];
#pragma unroll
      for (int it = 1; it < WARP_ITERATIONS; ++it) {
        sum[i] += grad_reg[i][it];
      }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
  }

  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches)
      break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      int element_index = local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // compute gradients
        if (is_log_softmax) {
          gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - expf(output_reg[i][it]) * sum[i]);
        } else {
          gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - sum[i]) * output_reg[i][it];
        }
      }
    }
  }
}
|
||||
|
||||
// Emits one `case` of the dispatch switch below: launches the softmax_warp_backward instantiation
// for a row length padded to (1 << L2_ELEMENTS). Relies on the locals of dispatch_softmax_backward.
#define LAUNCH_SOFTMAX_WARP_BACKWARD(L2_ELEMENTS) \
  case L2_ELEMENTS: \
    hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_backward<input_t, output_t, acc_t, L2_ELEMENTS, is_log_softmax>), \
                       dim3(blocks), dim3(threads), 0, 0, \
                       grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); \
    break;

// Host-side launcher for softmax_warp_backward.
//
// Parameters:
//   grad_input              result buffer (gradient with respect to the forward input)
//   grad                    incoming gradient with respect to the forward output
//   output                  forward result (softmax or log-softmax values)
//   softmax_elements        number of elements per row
//   softmax_elements_stride distance (in elements) between consecutive rows
//   batch_count             number of rows
//
// The kernel is launched with shared-memory size 0 on stream 0.
// Nothing is launched when softmax_elements == 0 or batch_count <= 0.
// NOTE(review): rows longer than 1024 elements (log2_elements > 10) fall into the `default` case
// and are silently skipped; callers presumably route those to a different implementation - confirm.
template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_softmax_backward(output_t* grad_input, const input_t* grad, const input_t* output, int softmax_elements, int softmax_elements_stride, int batch_count) {
  // Empty rows need no work; an empty batch would produce a grid of 0 blocks, which is not a valid
  // launch configuration, so return before touching the device.
  if (softmax_elements == 0 || batch_count <= 0) {
    return;
  }

  const int log2_elements = log2_ceil(softmax_elements);
  const int next_power_of_two = 1 << log2_elements;

  // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
  const int warp_size = (next_power_of_two < GPU_WARP_SIZE) ? next_power_of_two : GPU_WARP_SIZE;

  // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
  // int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
  const int batches_per_warp = 1;

  // use 256 threads per block to maximize gpu utilization
  constexpr int threads_per_block = 256;

  const int warps_per_block = threads_per_block / warp_size;
  const int batches_per_block = warps_per_block * batches_per_warp;
  const int blocks = (batch_count + batches_per_block - 1) / batches_per_block;  // ceil-div
  // x: lanes cooperating on one row, y: rows per block
  dim3 threads(warp_size, warps_per_block, 1);

  // log2_elements is a runtime value but the kernel needs it as a template argument,
  // hence one explicit case per supported power of two.
  switch (log2_elements) {
    LAUNCH_SOFTMAX_WARP_BACKWARD(0)   // 1
    LAUNCH_SOFTMAX_WARP_BACKWARD(1)   // 2
    LAUNCH_SOFTMAX_WARP_BACKWARD(2)   // 4
    LAUNCH_SOFTMAX_WARP_BACKWARD(3)   // 8
    LAUNCH_SOFTMAX_WARP_BACKWARD(4)   // 16
    LAUNCH_SOFTMAX_WARP_BACKWARD(5)   // 32
    LAUNCH_SOFTMAX_WARP_BACKWARD(6)   // 64
    LAUNCH_SOFTMAX_WARP_BACKWARD(7)   // 128
    LAUNCH_SOFTMAX_WARP_BACKWARD(8)   // 256
    LAUNCH_SOFTMAX_WARP_BACKWARD(9)   // 512
    LAUNCH_SOFTMAX_WARP_BACKWARD(10)  // 1024
    default:
      break;
  }
}

#undef LAUNCH_SOFTMAX_WARP_BACKWARD
|
||||
|
||||
// Explicitly instantiates dispatch_softmax_backward for both the softmax (false) and log-softmax
// (true) variants of one (input, output, accumulator) type combination.
// The parameter list must mirror the template definition exactly:
// grad_input is output_t*, while grad and output are const input_t*. Writing them the other way
// round only compiles while input_t and output_t are the same type.
#define SPECIALIZED_SOFTMAX_GRAD_IMPL(input_t, output_t, acc_t) \
  template void dispatch_softmax_backward<input_t, output_t, acc_t, false>(output_t * grad_input, const input_t* grad, const input_t* output, int softmax_elements, int softmax_elements_stride, int batch_count); \
  template void dispatch_softmax_backward<input_t, output_t, acc_t, true>(output_t * grad_input, const input_t* grad, const input_t* output, int softmax_elements, int softmax_elements_stride, int batch_count);

// Type combinations compiled into this translation unit; half tensors accumulate in float.
SPECIALIZED_SOFTMAX_GRAD_IMPL(float, float, float)
SPECIALIZED_SOFTMAX_GRAD_IMPL(half, half, float)
SPECIALIZED_SOFTMAX_GRAD_IMPL(double, double, double)
|
||||
|
||||
}
|
||||
}
|
||||
|
|
@ -105,6 +105,7 @@ provider_excluded_files = [
|
|||
'math/matmul_integer.cu',
|
||||
'math/matmul_integer.cuh',
|
||||
'math/matmul_integer.h',
|
||||
'math/softmax_impl.cu',
|
||||
'math/softmax.cc',
|
||||
'math/topk.cc',
|
||||
'math/topk.h',
|
||||
|
|
@ -260,6 +261,7 @@ training_ops_excluded_files = [
|
|||
'math/scale.cc',
|
||||
'math/scale.cu',
|
||||
'math/scale.h',
|
||||
'math/softmax_grad_impl.cu',
|
||||
'math/softmax_grad.cc',
|
||||
'nn/batch_norm_grad.cc',
|
||||
'nn/batch_norm_grad.h',
|
||||
|
|
|
|||
Loading…
Reference in a new issue