Implements nonzero_static on CUDA (#141838)

This adds CUDA support for `nonzero_static`, which was missing from https://github.com/pytorch/pytorch/pull/97417, using block-wide CUB primitives.
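Roughly, the kernels below implement the following pipeline: per-block nonzero counts, a prefix sum over the block aggregates, a block-wide scan that scatters flat indices into the output, tail padding with `fill_value`, and an unravel of the flat indices into per-dimension indices. A sequential, host-side sketch of that logic (illustrative only: `nonzero_static_sketch` and `chunk` are made-up names, and the real implementation processes the chunks in parallel on the GPU):

```python
import torch

def nonzero_static_sketch(x, size, fill_value=-1, chunk=8192):
    # `chunk` is a sequential stand-in for one CTA's tile of elements
    flat = x.reshape(-1)
    # 1) per-chunk nonzero counts (calc_block_sums with the nonzero transform)
    counts = [int((flat[i:i + chunk] != 0).sum()) for i in range(0, flat.numel(), chunk)]
    # 2) inclusive prefix sum over the per-chunk counts (compute_agg)
    offsets = torch.cumsum(torch.tensor(counts, dtype=torch.long), 0)
    # 3) each chunk scans its nonzero flags and writes flat indices at its
    #    cumulative offset, clamped to the first `size` slots (flag_kernel)
    out_flat = torch.full((size,), fill_value, dtype=torch.long, device=x.device)
    for ci, start in enumerate(range(0, flat.numel(), chunk)):
        out_start = 0 if ci == 0 else int(offsets[ci - 1])
        if out_start >= size:
            break
        idx = (flat[start:start + chunk] != 0).nonzero().reshape(-1) + start
        take = min(idx.numel(), size - out_start)
        out_flat[out_start:out_start + take] = idx[:take]
    # 4) slots past the total nonzero count keep fill_value (write_fill_value)
    total = min(int(offsets[-1]), size) if counts else 0
    # 5) unravel flat indices into per-dimension indices; padded rows keep
    #    fill_value (write_indices)
    out = torch.full((size, x.dim()), fill_value, dtype=torch.long, device=x.device)
    valid = out_flat[:total]
    for d in range(x.dim() - 1, -1, -1):
        out[:total, d] = valid % x.shape[d]
        valid = valid // x.shape[d]
    return out
```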

For `size` approximately equal to the number of nonzeros, performance is very close to the regular `nonzero`; for larger sizes, filling in the padding indices takes additional time.
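For example (illustrative values; the default `fill_value` is -1):

```python
import torch

x = torch.tensor([[1.2, 0.0], [3.4, 5.6]], device="cuda")

# size equal to the number of nonzeros: same rows as torch.nonzero
torch.nonzero_static(x, size=3)
# tensor([[0, 0],
#         [1, 0],
#         [1, 1]], device='cuda:0')

# larger size: extra rows are padded with fill_value
torch.nonzero_static(x, size=5, fill_value=-1)
# tensor([[ 0,  0],
#         [ 1,  0],
#         [ 1,  1],
#         [-1, -1],
#         [-1, -1]], device='cuda:0')
```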
Disabled for CUDA <= 11.4.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141838
Approved by: https://github.com/ezyang, https://github.com/malfet
Natalia Gimelshein 2024-12-11 06:44:48 +00:00 committed by PyTorch MergeBot
parent 1d3b0108a6
commit bdbdbeeb3d
5 changed files with 448 additions and 166 deletions


@ -349,7 +349,7 @@ __global__ void final_scan_kernel(const T* d_in, T* d_out, T* agg, int64_t nelem
// Per-thread tile data
T data[ITEMS_PER_THREAD];
int remaining = nelem - BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
int64_t remaining = nelem - BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
for (int i=0; i<iters_per_cta; i++){
// Load items into a blocked arrangement
if (remaining >= BLOCK_THREADS * ITEMS_PER_THREAD) {
@ -386,38 +386,57 @@ __global__ void final_scan_kernel(const T* d_in, T* d_out, T* agg, int64_t nelem
}
template <typename T, typename aggT, bool nonzero>
struct TransformFunctor {
__device__ aggT operator()(T value) const {
if constexpr (!nonzero) {
return value;
} else {
return (value != T(0)) ? 1 : 0;
}
}
};
template<int BLOCK_THREADS, int ITEMS_PER_THREAD, typename T>
__global__ void calc_block_sums(const T * d_in, T * agg, int64_t nelem, int iters_per_cta){
template<int BLOCK_THREADS, int ITEMS_PER_THREAD, bool nonzero, typename T, typename aggT>
__global__ void calc_block_sums(const T * d_in, aggT * agg, int64_t nelem, int iters_per_cta){
if (BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x >= nelem) return;
d_in += BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
d_in += BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * (int64_t)blockIdx.x;
using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_LOAD_STRIPED>;
using BlockReduceT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockReduce<T, BLOCK_THREADS>;
using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad<aggT, BLOCK_THREADS, ITEMS_PER_THREAD, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_LOAD_STRIPED>;
using BlockReduceT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockReduce<aggT, BLOCK_THREADS>;
// Shared memory
__shared__ union TempStorage
{
typename BlockLoadT::TempStorage load;
typename BlockReduceT::TempStorage reduce;
} temp_storage;
T data[ITEMS_PER_THREAD];
T agg_val = 0;
int64_t remaining = nelem - BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
aggT data[ITEMS_PER_THREAD];
aggT agg_val = 0;
int64_t remaining = nelem - BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * (int64_t)blockIdx.x;
TransformFunctor<T, aggT, nonzero> transform_functor;
auto iter_in = ROCM_HIPCUB(at_cuda_detail::cub)::TransformInputIterator<aggT, TransformFunctor<T, aggT, nonzero>, const T*>(d_in, transform_functor);
for (int i=0; i<iters_per_cta; i++){
if (remaining >= BLOCK_THREADS * ITEMS_PER_THREAD) {
BlockLoadT(temp_storage.load).Load(d_in, data);
BlockLoadT(temp_storage.load).Load(iter_in, data);
__syncthreads();
agg_val += BlockReduceT(temp_storage.reduce).Sum(data);
} else {
BlockLoadT(temp_storage.load).Load(d_in, data, remaining);
BlockLoadT(temp_storage.load).Load(iter_in, data, remaining, aggT(0));
__syncthreads();
agg_val += BlockReduceT(temp_storage.reduce).Sum(data);
}
d_in += BLOCK_THREADS * ITEMS_PER_THREAD;
iter_in += BLOCK_THREADS * ITEMS_PER_THREAD;
remaining -= BLOCK_THREADS * ITEMS_PER_THREAD;
if (remaining <= 0) return;
if (remaining <= 0) {
// for nonzero we need to write out the last block's
// accumulated value to be able to compute
// the total number of nonzeros
if (nonzero && threadIdx.x == 0) {
agg[blockIdx.x] = agg_val;
}
return;
}
__syncthreads();
}
@ -427,6 +446,13 @@ __global__ void calc_block_sums(const T * d_in, T * agg, int64_t nelem, int iter
}
template <typename T>
struct NonZeroOp {
__host__ __device__ __forceinline__ int operator()(const T& a) const {
return (a != T(0));
}
};
template<int size>
constexpr int block_threads(){
if constexpr (size >=16) {
@ -450,7 +476,7 @@ inline void inclusive_deterministic_scan(const scalar_t * input, scalar_t * out
grid_size = std::min(num_sms, grid_size);
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
auto agg = allocator.allocate(grid_size * sizeof(scalar_t));
calc_block_sums<BLOCK_THREADS, ITEMS_PER_THREAD>
calc_block_sums<BLOCK_THREADS, ITEMS_PER_THREAD, false>
<<<grid_size, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
input, (scalar_t*)agg.get(), num_items, iters_per_cta);
C10_CUDA_KERNEL_LAUNCH_CHECK();


@ -37,9 +37,12 @@ __global__ void write_indices(
int64_t* inp,
TensorDims<index_t> dims,
int ndim,
index_t n) {
auto index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < n) {
index_t n,
int64_t * total = nullptr,
int64_t fill_value = -1) {
auto index = threadIdx.x + (int64_t)blockIdx.x * blockDim.x;
bool cond = (total == nullptr || index < *total);
if (index < n && cond) {
index_t div = 1;
int64_t idx_flat = inp[index];
#pragma unroll
@ -50,9 +53,117 @@ __global__ void write_indices(
inp[index + dim * n] = (idx_flat / div) % dim_size;
div *= dim_size;
}
} else if (index < n) {
// 0th dim has correct values already
for (int dim = ndim - 1; dim > 0; dim--) {
inp[index + dim * n] = fill_value;
}
}
}
__global__ void write_fill_value(int64_t * inp, int64_t * total, int64_t fill_value, int64_t n){
int64_t total_val = *total;
// not aiming for vectorized stores
for (int64_t idx = total_val + (int64_t)blockIdx.x * blockDim.x + threadIdx.x; idx < n; idx += blockDim.x * gridDim.x) {
inp[idx] = fill_value;
}
}
template <int BLOCK_THREADS>
__global__ void compute_agg(int32_t * agg, int64_t * agg_cum, uint32_t n_blocks) {
using BlockScanT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockScan<int64_t, BLOCK_THREADS, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_SCAN_WARP_SCANS>;
__shared__ typename BlockScanT::TempStorage temp_storage;
int agg_data;
int64_t agg_cum_data;
agg_data = threadIdx.x < n_blocks ? agg[threadIdx.x] : 0;
BlockScanT(temp_storage).InclusiveSum(agg_data, agg_cum_data);
if (threadIdx.x < n_blocks) {
agg_cum[threadIdx.x] = agg_cum_data;
}
}
template<int BLOCK_THREADS, int ITEMS_PER_THREAD, typename T>
__global__ void flag_kernel(const T* d_in, int64_t * d_out, const int64_t * agg, int64_t input_nelem, int64_t output_nelem, int iters_per_cta) {
int64_t start_idx = BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * (int64_t)blockIdx.x;
if (start_idx >= input_nelem) return;
d_in += start_idx;
using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad<int, BLOCK_THREADS, ITEMS_PER_THREAD, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_LOAD_WARP_TRANSPOSE>;
// Specialize BlockScan type for our thread block
using BlockScanT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockScan<int, BLOCK_THREADS, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_SCAN_WARP_SCANS>;
using TransformInputIteratorT = ROCM_HIPCUB(at_cuda_detail::cub)::TransformInputIterator<int, NonZeroOp<T>, const T*>;
using BlockExchangeT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockExchange<int, BLOCK_THREADS, ITEMS_PER_THREAD>;
// Shared memory
__shared__ union TempStorage
{
typename BlockLoadT::TempStorage load;
typename BlockScanT::TempStorage scan;
typename BlockExchangeT::TempStorage exchange;
} temp_storage;
int64_t aggregate = blockIdx.x == 0 ? 0 : agg[blockIdx.x - 1];
d_out += aggregate;
TransformInputIteratorT t_input_itr(d_in, NonZeroOp<T>());
// Per-thread tile data
int data[ITEMS_PER_THREAD];
int out_indices[ITEMS_PER_THREAD];
int64_t remaining = input_nelem - start_idx;
int64_t out_remaining = output_nelem - aggregate;
for (int i=0; i<iters_per_cta; i++){
// Load items into a blocked arrangement
if (remaining >= BLOCK_THREADS * ITEMS_PER_THREAD) {
BlockLoadT(temp_storage.load).Load(t_input_itr, data);
} else {
BlockLoadT(temp_storage.load).Load(t_input_itr, data, remaining, int(0));
}
// Barrier for smem reuse
__syncthreads();
// Compute inclusive prefix sum
int aggregate;
__shared__ int aggregate_sh;
BlockScanT(temp_storage.scan).ExclusiveSum(data, out_indices, aggregate);
if (threadIdx.x == 0){
aggregate_sh = aggregate;
}
// Barrier for smem reuse
__syncthreads();
// striped arrangement will provide a slightly better
// coalescing for writes (although it's still bad because it's indirect indexing)
BlockExchangeT(temp_storage.exchange).BlockedToStriped(data);
__syncthreads();
BlockExchangeT(temp_storage.exchange).BlockedToStriped(out_indices);
for (int ii=0; ii<ITEMS_PER_THREAD; ii++){
if (data[ii] != 0 && out_indices[ii] < out_remaining) {
int64_t inp_idx = start_idx + threadIdx.x + blockDim.x * ii;
d_out[out_indices[ii]] = inp_idx;
}
}
out_remaining -= aggregate_sh;
remaining -= BLOCK_THREADS * ITEMS_PER_THREAD;
if (remaining <= 0 || out_remaining <= 0) return;
d_out += aggregate_sh;
t_input_itr += BLOCK_THREADS * ITEMS_PER_THREAD;
start_idx += BLOCK_THREADS * ITEMS_PER_THREAD;
__syncthreads();
}
}
} // anonymous namespace
template <typename scalar_t>
@ -183,6 +294,83 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) {
}
}
template <typename scalar_t>
void nonzero_static_cuda_out_impl(
const Tensor& self,
int64_t size,
int64_t fill_value,
Tensor& out) {
# if (defined(CUDA_VERSION) && CUDA_VERSION > 11040) || defined(USE_ROCM)
Tensor self_contiguous_ = self.contiguous();
// see comment in nonzero_cuda_out_impl on reqs for out
bool out_correct_size =
out.dim() == 2 && out.sizes()[0] == size && out.sizes()[1] == self.dim();
bool need_to_copy = out_correct_size && !out.t().is_contiguous();
if (!out_correct_size) {
out.resize_({self.dim(), size}).t();
}
if (out.numel() == 0) return;
// we need to allocate temporary out to then copy to user provided out
at::Tensor out_temp;
if (need_to_copy) {
out_temp =
Tensor(at::detail::empty_cuda({self.dim(), size}, out.options())).t();
}
int64_t* out_data_ptr = need_to_copy ? out_temp.mutable_data_ptr<int64_t>()
: out.mutable_data_ptr<int64_t>();
const scalar_t * in_data_ptr = self_contiguous_.const_data_ptr<scalar_t>();
constexpr int BLOCK_THREADS = 512; //block_threads<sizeof(scalar_t)>();
constexpr int ITEMS_PER_THREAD = 16;
auto grid_size = (self.numel() + BLOCK_THREADS * ITEMS_PER_THREAD - 1) / (BLOCK_THREADS * ITEMS_PER_THREAD);
const int64_t num_sms = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
int64_t target_blocks = sizeof(scalar_t) == 1 ? 2 * num_sms : num_sms;
const int iters_per_cta = (grid_size + target_blocks - 1)/target_blocks;
grid_size = (self.numel() + iters_per_cta * BLOCK_THREADS * ITEMS_PER_THREAD - 1) / (iters_per_cta * BLOCK_THREADS * ITEMS_PER_THREAD);
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
auto agg = allocator.allocate(grid_size * sizeof(int));
at::cuda::cub::calc_block_sums<BLOCK_THREADS, ITEMS_PER_THREAD, true>
<<<grid_size, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
in_data_ptr, (int*)agg.get(), self.numel(), iters_per_cta);
C10_CUDA_KERNEL_LAUNCH_CHECK();
auto agg_cum = allocator.allocate(grid_size * sizeof(int64_t));
// computing partial sums in int64 in the flag kernel
// leads to 20-30% slowdown, so compute them in a separate 2 us kernel
compute_agg<BLOCK_THREADS><<<1, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
(int*)agg.get(), (int64_t*)agg_cum.get(), grid_size
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
flag_kernel<BLOCK_THREADS, ITEMS_PER_THREAD>
<<<grid_size, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
in_data_ptr, out_data_ptr, (int64_t*)agg_cum.get(), self.numel(), size, iters_per_cta);
C10_CUDA_KERNEL_LAUNCH_CHECK();
int64_t out_grid = std::min(num_sms, (size + BLOCK_THREADS - 1)/BLOCK_THREADS);
write_fill_value<<<out_grid, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(out_data_ptr, (int64_t *)agg_cum.get() + grid_size - 1, fill_value, size);
if (self.dim() > 1) {
TensorDims<int64_t> dims;
for (int i = 0; i < self.dim(); i++) {
dims.sizes[i] = self.sizes()[i];
}
const int nthreads = 256;
const int nblocks = (size + nthreads - 1) / nthreads;
write_indices<<<nblocks, nthreads, 0, at::cuda::getCurrentCUDAStream()>>>(
out_data_ptr,
dims,
self.dim(),
size,
(int64_t *)agg_cum.get() + grid_size - 1,
fill_value);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
if (need_to_copy) {
out.copy_(out_temp);
}
#else
TORCH_CHECK(false, "Nonzero_static is not supported for cuda <= 11.4");
#endif
}
Tensor& nonzero_out_cuda(const Tensor& self, Tensor& out) {
TORCH_CHECK(
out.dtype() == at::kLong,
@ -216,4 +404,56 @@ Tensor nonzero_cuda(const Tensor& self) {
Tensor out = at::detail::empty_cuda({0}, self.options().dtype(kLong));
return at::native::nonzero_out_cuda(self, out);
}
Tensor& nonzero_static_out_cuda(
const Tensor& self,
int64_t size,
int64_t fill_value,
Tensor& out) {
TORCH_CHECK(
out.dtype() == at::kLong,
"nonzero_static: Expected out tensor to have scalar type ",
at::kLong,
" but got ",
out.dtype());
TORCH_CHECK(
self.device() == out.device(),
"expected self and out to be on the same device, but got out on ",
out.device(),
" and self on ",
self.device());
TORCH_CHECK(
self.dim() <= MAX_DIMS,
"nonzero_static is not supported for tensor with more than ",
MAX_DIMS,
" dimensions");
TORCH_CHECK(
size >= 0, "nonzero_static: 'size' must be an non-negative integer"
)
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
at::ScalarType::ComplexHalf,
at::ScalarType::Bool,
at::ScalarType::BFloat16,
at::ScalarType::Half,
self.scalar_type(),
"nonzero_cuda",
[&] {
nonzero_static_cuda_out_impl<scalar_t>(self, size, fill_value, out);
});
return out;
}
Tensor nonzero_static_cuda(
const Tensor& self,
int64_t size,
int64_t fill_value) {
TORCH_CHECK(
size >= 0, "nonzero_static: 'size' must be an non-negative integer"
)
Tensor out = Tensor(at::detail::empty_cuda(
{self.dim(), size}, self.options().dtype(kLong)))
.t();
return at::native::nonzero_static_out_cuda(self, size, fill_value, out);
}
} // namespace at::native


@ -9277,11 +9277,13 @@
- func: nonzero_static.out(Tensor self, *, int size, int fill_value=-1, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU: nonzero_static_out_cpu
CUDA: nonzero_static_out_cuda
- func: nonzero_static(Tensor self, *, int size, int fill_value=-1) -> Tensor
variants: method, function
dispatch:
CPU: nonzero_static_cpu
CUDA: nonzero_static_cuda
- func: nonzero_numpy(Tensor self) -> Tensor[]
variants: method, function


@ -5772,154 +5772,6 @@ utils_device.CURRENT_DEVICE == None""".split(
b = opt_fn(torch.tensor(True), torch.tensor([0.25, 0.25]))
self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), b))
def test_nonzero_static(self):
# invalid size
with self.assertRaisesRegex(
RuntimeError, "nonzero_static: 'size' must be an non-negative integer"
):
torch.nonzero_static(torch.tensor([8]), size=-2)
with self.assertRaisesRegex(
RuntimeError, "nonzero_static: 'size' must be an non-negative integer"
):
torch.nonzero_static(torch.tensor([8]), size=-2, out=torch.tensor(0))
# nonzero_static.out: out dtype mismatch
input_tensor = torch.tensor([8])
static_size = 1
out_tensor = torch.empty((static_size, input_tensor.dim()), dtype=torch.float)
with self.assertRaisesRegex(
RuntimeError, "nonzero_static: Expected out tensor to have scalar type Long"
):
torch.nonzero_static(input_tensor, size=static_size, out=out_tensor)
# nonzero_static.out: out resize (shrink)
input_tensor = torch.tensor([8])
static_size = 1
out_tensor = torch.empty((10, 10, 10, 10), dtype=torch.long)
self.assertTrue(
same(
torch.nonzero_static(input_tensor, size=static_size, out=out_tensor),
torch.tensor([0]),
)
)
self.assertTrue(
same(
out_tensor,
torch.tensor([0]),
)
)
# nonzero_static.out: out resize (enlarge)
input_tensor = torch.tensor([8])
static_size = 1
out_tensor = torch.empty((0), dtype=torch.long)
self.assertTrue(
same(
torch.nonzero_static(input_tensor, size=static_size, out=out_tensor),
torch.tensor([0]),
)
)
self.assertTrue(
same(
out_tensor,
torch.tensor([0]),
)
)
# 0 rank
input_tensor = torch.tensor(6)
static_size = 2
self.assertTrue(
same(
torch.nonzero_static(input_tensor, size=static_size),
torch.empty((static_size, input_tensor.dim()), dtype=torch.long),
)
)
# 0 size
input_tensor = torch.tensor([[[1]]])
static_size = 0
self.assertTrue(
same(
torch.nonzero_static(input_tensor, size=static_size),
torch.empty((static_size, input_tensor.dim()), dtype=torch.long),
)
)
# 1D input
input_tensor = torch.tensor([0, 8])
static_size = 1
self.assertTrue(
same(
torch.nonzero_static(input_tensor, size=static_size),
torch.tensor([1]),
)
)
input_tensor = torch.tensor([8, 0])
static_size = 2
self.assertTrue(
same(
torch.nonzero_static(input_tensor, size=static_size),
torch.tensor([[0], [-1]]), # padded with default fill_value "-1"
)
)
# 2D input
input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]])
static_size = 5
fill_value = -100
self.assertTrue(
torch._dynamo.utils.same(
torch.nonzero_static(
input_tensor, size=static_size, fill_value=fill_value
),
torch.tensor(
[
[0, 0],
[1, 0],
[1, 1],
[fill_value, fill_value],
[fill_value, fill_value],
]
),
)
)
input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]])
static_size = 2
fill_value = -100
self.assertTrue(
torch._dynamo.utils.same(
torch.nonzero_static(
input_tensor, size=static_size, fill_value=fill_value
),
torch.tensor([[0, 0], [1, 0]]),
)
)
# 3D input
input_tensor = torch.tensor([[[0, 0], [0, -3]], [[0, 0], [5, 0]]])
static_size = 4
fill_value = -999
self.assertTrue(
torch._dynamo.utils.same(
torch.nonzero_static(
input_tensor,
size=static_size,
fill_value=fill_value,
),
torch.tensor(
[
[0, 1, 1],
[1, 1, 0],
[fill_value, fill_value, fill_value],
[fill_value, fill_value, fill_value],
]
),
)
)
def test_cond_with_quantization(self):
from functorch.experimental.control_flow import cond


@ -1577,6 +1577,168 @@ class TestUnaryUfuncs(TestCase):
y = torch.nonzero(x)
self.assertEqual(y.view(-1), indices)
def test_nonzero_static(self, device):
# invalid size
with self.assertRaisesRegex(
RuntimeError, "nonzero_static: 'size' must be an non-negative integer"
):
torch.nonzero_static(torch.tensor([8], device=device), size=-2)
with self.assertRaisesRegex(
RuntimeError, "nonzero_static: 'size' must be an non-negative integer"
):
torch.nonzero_static(
torch.tensor([8], device=device),
size=-2,
out=torch.tensor(0, device=device),
)
# nonzero_static.out: out dtype mismatch
input_tensor = torch.tensor([8], device=device)
static_size = 1
out_tensor = torch.empty(
(static_size, input_tensor.dim()), dtype=torch.float, device=device
)
with self.assertRaisesRegex(
RuntimeError, "nonzero_static: Expected out tensor to have scalar type Long"
):
torch.nonzero_static(input_tensor, size=static_size, out=out_tensor)
# nonzero_static.out: out resize (shrink)
input_tensor = torch.tensor([8], device=device)
static_size = 1
out_tensor = torch.empty((10, 10, 10, 10), dtype=torch.long, device=device)
ref = torch.tensor([[0]], device=device)
self.assertEqual(
torch.nonzero_static(input_tensor, size=static_size, out=out_tensor), ref
)
self.assertEqual(out_tensor, ref)
# nonzero_static.out: out resize (enlarge)
input_tensor = torch.tensor([8], device=device)
static_size = 1
out_tensor = torch.empty((0), dtype=torch.long, device=device)
self.assertEqual(
torch.nonzero_static(input_tensor, size=static_size, out=out_tensor), ref
)
self.assertEqual(out_tensor, ref)
# 0 rank
input_tensor = torch.tensor(6, device=device)
static_size = 2
self.assertEqual(
torch.nonzero_static(input_tensor, size=static_size),
torch.empty(
(static_size, input_tensor.dim()), device=device, dtype=torch.long
),
)
# 0 size
input_tensor = torch.tensor([[[1]]], device=device)
static_size = 0
self.assertEqual(
torch.nonzero_static(input_tensor, size=static_size),
torch.empty(
(static_size, input_tensor.dim()), device=device, dtype=torch.long
),
)
# 1D input
input_tensor = torch.tensor([0, 8], device=device)
static_size = 1
self.assertEqual(
torch.nonzero_static(input_tensor, size=static_size),
torch.tensor([[1]], device=device),
)
input_tensor = torch.tensor([8, 0], device=device)
static_size = 2
self.assertEqual(
torch.nonzero_static(input_tensor, size=static_size),
torch.tensor(
[[0], [-1]], device=device
), # padded with default fill_value "-1"
)
# 2D input
input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]], device=device)
static_size = 5
fill_value = -100
self.assertEqual(
torch.nonzero_static(input_tensor, size=static_size, fill_value=fill_value),
torch.tensor(
[
[0, 0],
[1, 0],
[1, 1],
[fill_value, fill_value],
[fill_value, fill_value],
],
device=device,
),
)
input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]], device=device)
static_size = 2
fill_value = -100
self.assertEqual(
torch.nonzero_static(input_tensor, size=static_size, fill_value=fill_value),
torch.tensor([[0, 0], [1, 0]], device=device),
)
# 3D input
input_tensor = torch.tensor(
[[[0, 0], [0, -3]], [[0, 0], [5, 0]]], device=device
)
static_size = 4
fill_value = -999
self.assertEqual(
torch.nonzero_static(
input_tensor,
size=static_size,
fill_value=fill_value,
),
torch.tensor(
[
[0, 1, 1],
[1, 1, 0],
[fill_value, fill_value, fill_value],
[fill_value, fill_value, fill_value],
],
device=device,
),
)
@onlyCUDA
def test_nonzero_static_large(self, device):
# large enough to have multiple iters per SM even on H100
# with 132 sms
size_inp = 1024 * 16 * 132 + 1024 * 16
x = torch.zeros(size_inp, device=device)
# unique indices
indices = torch.randperm(size_inp, device=device)[: size_inp // 2]
sorted, _ = torch.sort(indices)
x[sorted] = 1
res = torch.nonzero_static(x, size=size_inp // 2).view(-1)
self.assertEqual(res, sorted)
# no oob writes
out = torch.full((size_inp,), 10, device=device, dtype=torch.int64)
res = torch.nonzero_static(x, size=size_inp // 4, out=out[: size_inp // 2])
self.assertEqual(out[: size_inp // 4], sorted[: size_inp // 4])
self.assertEqual(
out[size_inp // 4 :],
torch.tensor(10, device="cuda").expand_as(out[size_inp // 4 :]),
)
# correct fill for 2d
x = x.view(2, size_inp // 2)
ref = x.nonzero()
res = x.nonzero_static(size=size_inp // 2 + 2)
self.assertEqual(res.shape, [size_inp // 2 + 2, 2])
self.assertEqual(ref, res[: size_inp // 2])
self.assertEqual(
res[size_inp // 2 :],
torch.tensor(-1, device="cuda").expand_as(res[size_inp // 2 :]),
)
# TODO: rationalize with exp OpInfo
@dtypes(*floating_and_complex_types_and(torch.bfloat16))