mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-02 23:39:58 +00:00
Update reduction_all.cu
This commit is contained in:
parent
5fc377f21e
commit
c4b6559be9
1 changed files with 5 additions and 5 deletions
|
|
@ -12,12 +12,12 @@
|
|||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
|
||||
template<typename Tin, typename Tout>
|
||||
template <typename Tin, typename Tout>
|
||||
__global__ void ScalarSqrtKernel(Tin* input, Tout* output) {
|
||||
*output = (Tout)_Sqrt(*input);
|
||||
}
|
||||
|
||||
template<typename Tin, typename Tout>
|
||||
template <typename Tin, typename Tout>
|
||||
void ScalarSqrt(Tin* input, Tout* output) {
|
||||
hipLaunchKernelGGL(ScalarSqrtKernel, dim3(1), dim3(1), 0, 0, input, output);
|
||||
}
|
||||
|
|
@ -61,7 +61,7 @@ __global__ void MultiTensorReduceKernel(ChunkGroup<1> chunk_group, TOut* output)
|
|||
const int wid = threadIdx.x / GPU_WARP_SIZE;
|
||||
|
||||
// Shape is 2 x warp_count_in_block.
|
||||
HIP_DYNAMIC_SHARED( unsigned char, shared_memory_)
|
||||
extern __shared__ unsigned char shared_memory_[];
|
||||
TBuf* shared_memory = reinterpret_cast<TBuf*>(shared_memory_);
|
||||
|
||||
if (lid == 0) {
|
||||
|
|
@ -79,7 +79,7 @@ __global__ void MultiTensorReduceKernel(ChunkGroup<1> chunk_group, TOut* output)
|
|||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
atomic_add(w_norm, TOutOp()(shared_memory[0]));
|
||||
atomic_add(w_norm, TOutOp()(TOut(shared_memory[0])));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -100,7 +100,7 @@ void MultiTensorReduce(ChunkGroup<1> chunk_group, TOut* output) {
|
|||
template <typename TIn, typename TOut>
|
||||
void MultiTensorReduceL2<TIn, TOut>::operator()(ChunkGroup<1> chunk_group, TOut* output) {
|
||||
using TBuf = AccumulationType_t<TIn>;
|
||||
MultiTensorReduce<TIn, TOut, TBuf, Square<TBuf, TIn>, Cast<TOut, TBuf>>(chunk_group, output);
|
||||
MultiTensorReduce<TIn, TOut, TBuf, Square2, Identity2>(chunk_group, output);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_MULTI_TENSOR_REDUCTION_L2_FUNCTOR(TIn, TOut) \
|
||||
|
|
|
|||
Loading…
Reference in a new issue