pytorch/caffe2/utils/GpuScanUtils.cuh
Pruthvi Madugundu 085e2f7bdd [ROCm] Changes not to rely on CUDA_VERSION or HIP_VERSION (#65610)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65610

- Replace HIP_PLATFORM_HCC with USE_ROCM
- Don't rely on CUDA_VERSION or HIP_VERSION; use USE_ROCM and ROCM_VERSION instead.

- In the next PR
   - Will be removing the mapping from CUDA_VERSION to HIP_VERSION and CUDA to HIP in hipify.
   - HIP_PLATFORM_HCC is deprecated, so will add HIP_PLATFORM_AMD to support HIP host code compilation on gcc.

cc jeffdaily sunway513 jithunnair-amd ROCmSupport amathews-amd

Reviewed By: jbschlosser

Differential Revision: D30909053

Pulled By: ezyang

fbshipit-source-id: 224a966ebf1aaec79beccbbd686fdf3d49267e06
2021-09-29 09:55:43 -07:00

133 lines
3.5 KiB
Text

#ifndef CAFFE2_UTILS_GPU_SCAN_UTILS_H_
#define CAFFE2_UTILS_GPU_SCAN_UTILS_H_
#include "caffe2/utils/GpuDefs.cuh"
namespace caffe2 {
// from the cutorch library; can probably be replaced with their CUB
// equivalents
// Collection of in-kernel scan / prefix sum utilities
// Block-wide inclusive prefix scan (Hillis-Steele style) in shared memory.
// On return, *out holds binop applied over the inputs of all threads with
// index <= threadIdx.x. `smem` must hold at least blockDim.x elements.
// KillWARDependency adds a trailing barrier so the caller may immediately
// reuse `smem` without a write-after-read hazard.
template <typename T, bool KillWARDependency, class BinaryFunction>
__device__ void inclusivePrefixScan(T* smem, T in, T* out, BinaryFunction binop) {
// FIXME: this is a slow, simple implementation; need up/down sweep,
// prevent smem conflicts
smem[threadIdx.x] = in;
__syncthreads();

for (int stride = 1; stride < blockDim.x; stride <<= 1) {
  // Only threads far enough from the left edge combine with a neighbor.
  const bool participates = threadIdx.x >= stride;
  T combined = 0;
  if (participates) {
    combined = binop(smem[threadIdx.x - stride], smem[threadIdx.x]);
  }
  // Barrier between reading neighbors and overwriting them, and again
  // before the next round reads the freshly written values. Reached by
  // all threads regardless of `participates`, so no divergent-barrier UB.
  __syncthreads();
  if (participates) {
    smem[threadIdx.x] = combined;
  }
  __syncthreads();
}

*out = smem[threadIdx.x];

// Prevent write-after-read dependencies on smem usage above if necessary
if (KillWARDependency) {
  __syncthreads();
}
}
// Exclusive prefix sum using shared memory
// On return, *out holds the combined inputs of all threads with index
// strictly less than threadIdx.x, and *carry holds the block-wide total.
// `smem` must hold at least blockDim.x elements.
// NOTE(review): the inclusive->exclusive conversion below uses `*out -= in`,
// which assumes binop is addition-like (subtraction is its inverse) — confirm
// before using with a non-additive binop.
template <typename T, bool KillWARDependency, class BinaryFunction>
__device__ void exclusivePrefixScan(T* smem, T in, T* out, T* carry, BinaryFunction binop) {
// FIXME: crappy implementation
// We kill write-after-read dependencies separately below, hence the `false`
inclusivePrefixScan<T, false, BinaryFunction>(smem, in, out, binop);
// Inclusive -> exclusive: remove this thread's own contribution.
*out -= in;
// The last slot holds the inclusive sum over the whole block, i.e. the carry.
*carry = smem[blockDim.x - 1];
// Prevent write-after-read dependencies on smem usage above if necessary
if (KillWARDependency) {
__syncthreads();
}
}
// Inclusive prefix sum for binary vars using intra-warp voting +
// shared memory
// On return, *out is the count of threads with index <= threadIdx.x whose
// `in` predicate was true (combined via binop). `smem` needs at least
// blockDim.x / kWarpSize elements.
// NOTE(review): the cross-warp pass iterates blockDim.x / kWarpSize times
// (integer division floors), so a partial final warp's count would never be
// folded in — this appears to assume blockDim.x is a multiple of kWarpSize;
// confirm against callers.
template <typename T, bool KillWARDependency, class BinaryFunction>
__device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) {
// Within-warp, we use warp voting.
#if defined(USE_ROCM)
// ROCm wavefronts are 64 lanes wide, hence the 64-bit ballot and __popcll.
unsigned long long int vote = __ballot(in);
// Bits at or below this lane that voted true = this lane's inclusive count.
T index = __popcll(getLaneMaskLe() & vote);
// Total true predicates across the whole wavefront.
T carry = __popcll(vote);
#else
// NOTE(review): __activemask() is used as the participant mask; correct only
// if all launched lanes of the warp reach this call converged (the
// __syncthreads() below implies block-wide participation anyway).
T vote = __ballot_sync(__activemask(), in);
T index = __popc(getLaneMaskLe() & vote);
T carry = __popc(vote);
#endif // USE_ROCM
int warp = threadIdx.x / kWarpSize;
// Per each warp, write out a value
if (getLaneId() == 0) {
smem[warp] = carry;
}
__syncthreads();
// Sum across warps in one thread. This appears to be faster than a
// warp shuffle scan for CC 3.0+
if (threadIdx.x == 0) {
// Serial inclusive scan over the per-warp totals; after the loop,
// smem[i] holds the combined counts of warps 0..i.
int current = 0;
for (int i = 0; i < blockDim.x / kWarpSize; ++i) {
T v = smem[i];
smem[i] = binop(smem[i], current);
current = binop(current, v);
}
}
__syncthreads();
// load the carry from the preceding warp
if (warp >= 1) {
index = binop(index, smem[warp - 1]);
}
*out = index;
// Prevent write-after-read dependencies on smem usage above if necessary
if (KillWARDependency) {
__syncthreads();
}
}
// Exclusive prefix sum for binary vars using intra-warp voting +
// shared memory
// On return, *out is the count of threads with index strictly below
// threadIdx.x whose `in` predicate was true, and *carry is the block-wide
// total. Shares inclusiveBinaryPrefixScan's smem-size and blockDim.x
// assumptions.
template <typename T, bool KillWARDependency, class BinaryFunction>
__device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, BinaryFunction binop) {
// WAR hazards are killed once, below, so the inclusive pass skips its barrier.
inclusiveBinaryPrefixScan<T, false, BinaryFunction>(smem, in, out, binop);
// Inclusive to exclusive
*out -= (T) in;
// The outgoing carry for all threads is the last warp's sum
// NOTE(review): the two branches index the last warp differently — ROCm
// rounds up (DivUp) while CUDA floors. They agree only when blockDim.x is a
// multiple of kWarpSize; confirm this intentional platform difference.
#if defined(USE_ROCM)
*carry = smem[math::DivUp<int>(blockDim.x, kWarpSize) - 1];
#else
*carry = smem[(blockDim.x / kWarpSize) - 1];
#endif // USE_ROCM
// Prevent write-after-read dependencies on smem usage above if necessary
if (KillWARDependency) {
__syncthreads();
}
}
} // namespace caffe2
#endif // CAFFE2_UTILS_GPU_SCAN_UTILS_H_