Update reduction_all.cu

This commit is contained in:
Jesse Benson 2021-01-12 18:18:57 -08:00 committed by Jesse Benson
parent 5fc377f21e
commit c4b6559be9

View file

@ -12,12 +12,12 @@
namespace onnxruntime {
namespace rocm {
template<typename Tin, typename Tout>
template <typename Tin, typename Tout>
__global__ void ScalarSqrtKernel(Tin* input, Tout* output) {
*output = (Tout)_Sqrt(*input);
}
template<typename Tin, typename Tout>
template <typename Tin, typename Tout>
void ScalarSqrt(Tin* input, Tout* output) {
hipLaunchKernelGGL(ScalarSqrtKernel, dim3(1), dim3(1), 0, 0, input, output);
}
@ -61,7 +61,7 @@ __global__ void MultiTensorReduceKernel(ChunkGroup<1> chunk_group, TOut* output)
const int wid = threadIdx.x / GPU_WARP_SIZE;
// Shape is 2 x warp_count_in_block.
HIP_DYNAMIC_SHARED( unsigned char, shared_memory_)
extern __shared__ unsigned char shared_memory_[];
TBuf* shared_memory = reinterpret_cast<TBuf*>(shared_memory_);
if (lid == 0) {
@ -79,7 +79,7 @@ __global__ void MultiTensorReduceKernel(ChunkGroup<1> chunk_group, TOut* output)
}
if (threadIdx.x == 0) {
atomic_add(w_norm, TOutOp()(shared_memory[0]));
atomic_add(w_norm, TOutOp()(TOut(shared_memory[0])));
}
}
@ -100,7 +100,7 @@ void MultiTensorReduce(ChunkGroup<1> chunk_group, TOut* output) {
template <typename TIn, typename TOut>
void MultiTensorReduceL2<TIn, TOut>::operator()(ChunkGroup<1> chunk_group, TOut* output) {
using TBuf = AccumulationType_t<TIn>;
MultiTensorReduce<TIn, TOut, TBuf, Square<TBuf, TIn>, Cast<TOut, TBuf>>(chunk_group, output);
MultiTensorReduce<TIn, TOut, TBuf, Square2, Identity2>(chunk_group, output);
}
#define INSTANTIATE_MULTI_TENSOR_REDUCTION_L2_FUNCTOR(TIn, TOut) \