From bdbdbeeb3d88a65eb59c91257e61d1accaab646e Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Wed, 11 Dec 2024 06:44:48 +0000 Subject: [PATCH] Implements nonzero_static on cuda (#141838) using blockwide cub primitives. This adds CUDA functionality for nonzero_static, which was missing in https://github.com/pytorch/pytorch/pull/97417. For `size` approx equal to number of nonzeros, the perf is very close to the regular version, for larger sizes filling in padding indices takes additional time. Disabled for cuda <=11.4 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141838 Approved by: https://github.com/ezyang, https://github.com/malfet --- aten/src/ATen/cuda/cub.cuh | 56 +++-- aten/src/ATen/native/cuda/Nonzero.cu | 246 ++++++++++++++++++++- aten/src/ATen/native/native_functions.yaml | 2 + test/dynamo/test_misc.py | 148 ------------- test/test_unary_ufuncs.py | 162 ++++++++++++++ 5 files changed, 448 insertions(+), 166 deletions(-) diff --git a/aten/src/ATen/cuda/cub.cuh b/aten/src/ATen/cuda/cub.cuh index b214b182214..d75523f1ef9 100644 --- a/aten/src/ATen/cuda/cub.cuh +++ b/aten/src/ATen/cuda/cub.cuh @@ -349,7 +349,7 @@ __global__ void final_scan_kernel(const T* d_in, T* d_out, T* agg, int64_t nelem // Per-thread tile data T data[ITEMS_PER_THREAD]; - int remaining = nelem - BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x; + int64_t remaining = nelem - BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x; for (int i=0; i= BLOCK_THREADS * ITEMS_PER_THREAD) { @@ -386,38 +386,57 @@ __global__ void final_scan_kernel(const T* d_in, T* d_out, T* agg, int64_t nelem } +template +struct TransformFunctor { + __device__ aggT operator()(T value) const { + if constexpr (!nonzero) { + return value; + } else { + return (value != T(0)) ? 
1 : 0; + } + } +}; - -template -__global__ void calc_block_sums(const T * d_in, T * agg, int64_t nelem, int iters_per_cta){ +template +__global__ void calc_block_sums(const T * d_in, aggT * agg, int64_t nelem, int iters_per_cta){ if (BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x >= nelem) return; - d_in += BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x; + d_in += BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * (int64_t)blockIdx.x; - using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad; - using BlockReduceT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockReduce; + using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad; + using BlockReduceT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockReduce; // Shared memory __shared__ union TempStorage { typename BlockLoadT::TempStorage load; typename BlockReduceT::TempStorage reduce; } temp_storage; - T data[ITEMS_PER_THREAD]; - T agg_val = 0; - int64_t remaining = nelem - BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x; + aggT data[ITEMS_PER_THREAD]; + aggT agg_val = 0; + int64_t remaining = nelem - BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * (int64_t)blockIdx.x; + TransformFunctor transform_functor; + auto iter_in = ROCM_HIPCUB(at_cuda_detail::cub)::TransformInputIterator, const T*>(d_in, transform_functor); for (int i=0; i= BLOCK_THREADS * ITEMS_PER_THREAD) { - BlockLoadT(temp_storage.load).Load(d_in, data); + BlockLoadT(temp_storage.load).Load(iter_in, data); __syncthreads(); agg_val += BlockReduceT(temp_storage.reduce).Sum(data); } else { - BlockLoadT(temp_storage.load).Load(d_in, data, remaining); + BlockLoadT(temp_storage.load).Load(iter_in, data, remaining, aggT(0)); __syncthreads(); agg_val += BlockReduceT(temp_storage.reduce).Sum(data); } - d_in += BLOCK_THREADS * ITEMS_PER_THREAD; + iter_in += BLOCK_THREADS * ITEMS_PER_THREAD; remaining -= BLOCK_THREADS * ITEMS_PER_THREAD; - if (remaining <= 0) return; + if (remaining <= 0) { + // for nonzeros we need to write out last blocks + // accumulated value to be able to compute + // total number of nonzeros + if (nonzero && threadIdx.x == 0) { + agg[blockIdx.x] = agg_val; + } + return; + } __syncthreads(); } @@ -427,6 +446,13 @@ __global__ void calc_block_sums(const T * d_in, T * agg, int64_t nelem, int iter } +template +struct NonZeroOp { + __host__ __device__ __forceinline__ int operator()(const T& a) const { + return (a != T(0)); + } +}; + template constexpr int block_threads(){ if constexpr (size >=16) { @@ -450,7 +476,7 @@ inline void inclusive_deterministic_scan(const scalar_t * input, scalar_t * out grid_size = std::min(num_sms, grid_size); auto& allocator = *c10::cuda::CUDACachingAllocator::get(); auto agg = allocator.allocate(grid_size * sizeof(scalar_t)); - calc_block_sums + calc_block_sums <<>>( input, (scalar_t*)agg.get(), num_items, iters_per_cta); C10_CUDA_KERNEL_LAUNCH_CHECK(); diff --git a/aten/src/ATen/native/cuda/Nonzero.cu b/aten/src/ATen/native/cuda/Nonzero.cu index e7c6c9d19d8..294a8c4a991 100644 --- a/aten/src/ATen/native/cuda/Nonzero.cu +++ b/aten/src/ATen/native/cuda/Nonzero.cu @@ -37,9 +37,12 @@ __global__ void write_indices( int64_t* inp, TensorDims dims, int ndim, - index_t n) { - auto index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < n) { + index_t n, + int64_t * total = nullptr, + int64_t fill_value = -1) { + auto index = threadIdx.x + (int64_t)blockIdx.x * blockDim.x; + bool cond = (total == nullptr || index < *total); + if (index < n && cond) { index_t div = 1; int64_t idx_flat = inp[index]; #pragma 
unroll @@ -50,9 +53,117 @@ __global__ void write_indices( inp[index + dim * n] = (idx_flat / div) % dim_size; div *= dim_size; } + } else if (index < n) { + // 0th dim has correct values already + for (int dim = ndim - 1; dim > 0; dim--) { + inp[index + dim * n] = fill_value; + } } } +__global__ void write_fill_value(int64_t * inp, int64_t * total, int64_t fill_value, int64_t n){ + int64_t total_val = *total; + // not aiming for vectorized stores + + for (int64_t idx = total_val + (int64_t)blockIdx.x * blockDim.x + threadIdx.x; idx < n; idx += blockDim.x * gridDim.x) { + inp[idx] = fill_value; + } +} + +template +__global__ void compute_agg(int32_t * agg, int64_t * agg_cum, uint32_t n_blocks) { + + using BlockScanT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockScan; + __shared__ typename BlockScanT::TempStorage temp_storage; + int agg_data; + int64_t agg_cum_data; + agg_data = threadIdx.x < n_blocks ? agg[threadIdx.x] : 0; + BlockScanT(temp_storage).InclusiveSum(agg_data, agg_cum_data); + if (threadIdx.x < n_blocks) { + agg_cum[threadIdx.x] = agg_cum_data; + } +} + +template +__global__ void flag_kernel(const T* d_in, int64_t * d_out, const int64_t * agg, int64_t input_nelem, int64_t output_nelem, int iters_per_cta) { + int64_t start_idx = BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * (int64_t)blockIdx.x; + if (start_idx >= input_nelem) return; + d_in += start_idx; + + using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad; + + // Specialize BlockScan type for our thread block + using BlockScanT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockScan; + using TransformInputIteratorT = ROCM_HIPCUB(at_cuda_detail::cub)::TransformInputIterator, const T*>; + using BlockExchangeT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockExchange; + + // Shared memory + __shared__ union TempStorage + { + typename BlockLoadT::TempStorage load; + typename BlockScanT::TempStorage scan; + typename BlockExchangeT::TempStorage exchange; + } temp_storage; + + int64_t aggregate = blockIdx.x == 0 ? 
0 : agg[blockIdx.x - 1]; + d_out += aggregate; + + TransformInputIteratorT t_input_itr(d_in, NonZeroOp()); + + // Per-thread tile data + int data[ITEMS_PER_THREAD]; + int out_indices[ITEMS_PER_THREAD]; + + int64_t remaining = input_nelem - start_idx; + int64_t out_remaining = output_nelem - aggregate; + for (int i=0; i= BLOCK_THREADS * ITEMS_PER_THREAD) { + BlockLoadT(temp_storage.load).Load(t_input_itr, data); + } else { + BlockLoadT(temp_storage.load).Load(t_input_itr, data, remaining, int(0)); + } + + // Barrier for smem reuse + __syncthreads(); + + // Compute inclusive prefix sum + int aggregate; + __shared__ int aggregate_sh; + BlockScanT(temp_storage.scan).ExclusiveSum(data, out_indices, aggregate); + + if (threadIdx.x == 0){ + aggregate_sh = aggregate; + } + + // Barrier for smem reuse + __syncthreads(); + // striped arrangement will provide a slightly better + // coalescing for writes (although it's still bad because it's indirect indexing) + BlockExchangeT(temp_storage.exchange).BlockedToStriped(data); + __syncthreads(); + BlockExchangeT(temp_storage.exchange).BlockedToStriped(out_indices); + for (int ii=0; ii @@ -183,6 +294,83 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) { } } +template +void nonzero_static_cuda_out_impl( + const Tensor& self, + int64_t size, + int64_t fill_value, + Tensor& out) { +# if (defined(CUDA_VERSION) && CUDA_VERSION > 11040) || defined(USE_ROCM) + + Tensor self_contiguous_ = self.contiguous(); + // see comment in nonzero_cuda_out_impl on reqs for out + bool out_correct_size = + out.dim() == 2 && out.sizes()[0] == size && out.sizes()[1] == self.dim(); + bool need_to_copy = out_correct_size && !out.t().is_contiguous(); + if (!out_correct_size) { + out.resize_({self.dim(), size}).t(); + } + if (out.numel() == 0) return; + // we need to allocate temporary out to then copy to user provided out + at::Tensor out_temp; + if (need_to_copy) { + out_temp = + Tensor(at::detail::empty_cuda({self.dim(), size}, out.options())).t(); + } + int64_t* out_data_ptr = need_to_copy ? out_temp.mutable_data_ptr() + : out.mutable_data_ptr(); + + const scalar_t * in_data_ptr = self_contiguous_.const_data_ptr(); + constexpr int BLOCK_THREADS = 512; //block_threads(); + constexpr int ITEMS_PER_THREAD = 16; + auto grid_size = (self.numel() + BLOCK_THREADS * ITEMS_PER_THREAD - 1) / (BLOCK_THREADS * ITEMS_PER_THREAD); + const int64_t num_sms = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + int64_t target_blocks = sizeof(scalar_t) == 1 ? 
2 * num_sms : num_sms; + const int iters_per_cta = (grid_size + target_blocks - 1)/target_blocks; + grid_size = (self.numel() + iters_per_cta * BLOCK_THREADS * ITEMS_PER_THREAD - 1) / (iters_per_cta * BLOCK_THREADS * ITEMS_PER_THREAD); + auto& allocator = *c10::cuda::CUDACachingAllocator::get(); + auto agg = allocator.allocate(grid_size * sizeof(int)); + at::cuda::cub::calc_block_sums + <<>>( + in_data_ptr, (int*)agg.get(), self.numel(), iters_per_cta); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + auto agg_cum = allocator.allocate(grid_size * sizeof(int64_t)); + // computing partial sums in int64 in the flag kernel + // leads to 20-30% slowdown, so compute them in a separate 2 us kernel + compute_agg<<<1, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>( + (int*)agg.get(), (int64_t*)agg_cum.get(), grid_size + ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + flag_kernel + <<>>( + in_data_ptr, out_data_ptr, (int64_t*)agg_cum.get(), self.numel(), size, iters_per_cta); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + int64_t out_grid = std::min(num_sms, (size + BLOCK_THREADS - 1)/BLOCK_THREADS); + write_fill_value<<>>(out_data_ptr, (int64_t *)agg_cum.get() + grid_size - 1, fill_value, size); + if (self.dim() > 1) { + TensorDims dims; + for (int i = 0; i < self.dim(); i++) { + dims.sizes[i] = self.sizes()[i]; + } + const int nthreads = 256; + const int nblocks = (size + nthreads - 1) / nthreads; + write_indices<<>>( + out_data_ptr, + dims, + self.dim(), + size, + (int64_t *)agg_cum.get() + grid_size - 1, + fill_value); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + if (need_to_copy) { + out.copy_(out_temp); + } +#else + TORCH_CHECK(false, "Nonzero_static is not supported for cuda <= 11.4"); +#endif +} + Tensor& nonzero_out_cuda(const Tensor& self, Tensor& out) { TORCH_CHECK( out.dtype() == at::kLong, @@ -216,4 +404,56 @@ Tensor nonzero_cuda(const Tensor& self) { Tensor out = at::detail::empty_cuda({0}, self.options().dtype(kLong)); return at::native::nonzero_out_cuda(self, out); } + +Tensor& nonzero_static_out_cuda( + const Tensor& self, + int64_t size, + int64_t fill_value, + Tensor& out) { + TORCH_CHECK( + out.dtype() == at::kLong, + "nonzero_static: Expected out tensor to have scalar type ", + at::kLong, + " but got ", + out.dtype()); + TORCH_CHECK( + self.device() == out.device(), + "expected self and out to be on the same device, but got out on ", + out.device(), + " and self on ", + self.device()); + TORCH_CHECK( + self.dim() <= MAX_DIMS, + "nonzero_static is not supported for tensor with more than ", + MAX_DIMS, + " dimensions"); + TORCH_CHECK( + size >= 0, "nonzero_static: 'size' must be an non-negative integer" + ) + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( + at::ScalarType::ComplexHalf, + at::ScalarType::Bool, + at::ScalarType::BFloat16, + at::ScalarType::Half, + self.scalar_type(), + "nonzero_cuda", + [&] { + nonzero_static_cuda_out_impl(self, size, fill_value, out); + }); + return out; +} + +Tensor nonzero_static_cuda( + const Tensor& self, + int64_t size, + int64_t fill_value) { + TORCH_CHECK( + size >= 0, "nonzero_static: 'size' must be an non-negative integer" + ) + Tensor out = Tensor(at::detail::empty_cuda( + {self.dim(), size}, self.options().dtype(kLong))) + .t(); + return at::native::nonzero_static_out_cuda(self, size, fill_value, out); +} + } // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 3f7df7676f7..ede2e97aead 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ 
-9277,11 +9277,13 @@ - func: nonzero_static.out(Tensor self, *, int size, int fill_value=-1, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: nonzero_static_out_cpu + CUDA: nonzero_static_out_cuda - func: nonzero_static(Tensor self, *, int size, int fill_value=-1) -> Tensor variants: method, function dispatch: CPU: nonzero_static_cpu + CUDA: nonzero_static_cuda - func: nonzero_numpy(Tensor self) -> Tensor[] variants: method, function diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index f1292198e36..72572e8e84b 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -5772,154 +5772,6 @@ utils_device.CURRENT_DEVICE == None""".split( b = opt_fn(torch.tensor(True), torch.tensor([0.25, 0.25])) self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), b)) - def test_nonzero_static(self): - # invalid size - with self.assertRaisesRegex( - RuntimeError, "nonzero_static: 'size' must be an non-negative integer" - ): - torch.nonzero_static(torch.tensor([8]), size=-2) - - with self.assertRaisesRegex( - RuntimeError, "nonzero_static: 'size' must be an non-negative integer" - ): - torch.nonzero_static(torch.tensor([8]), size=-2, out=torch.tensor(0)) - - # nonzero_static.out: out dtype mismatch - input_tensor = torch.tensor([8]) - static_size = 1 - out_tensor = torch.empty((static_size, input_tensor.dim()), dtype=torch.float) - with self.assertRaisesRegex( - RuntimeError, "nonzero_static: Expected out tensor to have scalar type Long" - ): - torch.nonzero_static(input_tensor, size=static_size, out=out_tensor) - - # nonzero_static.out: out resize (shrink) - input_tensor = torch.tensor([8]) - static_size = 1 - out_tensor = torch.empty((10, 10, 10, 10), dtype=torch.long) - self.assertTrue( - same( - torch.nonzero_static(input_tensor, size=static_size, out=out_tensor), - torch.tensor([0]), - ) - ) - self.assertTrue( - same( - out_tensor, - torch.tensor([0]), - ) - ) - - # nonzero_static.out: out resize (enlarge) - input_tensor = torch.tensor([8]) - static_size = 1 - out_tensor = torch.empty((0), dtype=torch.long) - self.assertTrue( - same( - torch.nonzero_static(input_tensor, size=static_size, out=out_tensor), - torch.tensor([0]), - ) - ) - self.assertTrue( - same( - out_tensor, - torch.tensor([0]), - ) - ) - - # 0 rank - input_tensor = torch.tensor(6) - static_size = 2 - self.assertTrue( - same( - torch.nonzero_static(input_tensor, size=static_size), - torch.empty((static_size, input_tensor.dim()), dtype=torch.long), - ) - ) - - # 0 size - input_tensor = torch.tensor([[[1]]]) - static_size = 0 - self.assertTrue( - same( - torch.nonzero_static(input_tensor, size=static_size), - torch.empty((static_size, input_tensor.dim()), dtype=torch.long), - ) - ) - - # 1D input - input_tensor = torch.tensor([0, 8]) - static_size = 1 - self.assertTrue( - same( - torch.nonzero_static(input_tensor, size=static_size), - torch.tensor([1]), - ) - ) - - input_tensor = torch.tensor([8, 0]) - static_size = 2 - self.assertTrue( - same( - torch.nonzero_static(input_tensor, size=static_size), - torch.tensor([[0], [-1]]), # padded with default fill_value "-1" - ) - ) - - # 2D input - input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]]) - static_size = 5 - fill_value = -100 - self.assertTrue( - torch._dynamo.utils.same( - torch.nonzero_static( - input_tensor, size=static_size, fill_value=fill_value - ), - torch.tensor( - [ - [0, 0], - [1, 0], - [1, 1], - [fill_value, fill_value], - [fill_value, fill_value], - ] - ), - ) - ) - input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]]) - static_size = 2 - 
fill_value = -100 - self.assertTrue( - torch._dynamo.utils.same( - torch.nonzero_static( - input_tensor, size=static_size, fill_value=fill_value - ), - torch.tensor([[0, 0], [1, 0]]), - ) - ) - - # 3D input - input_tensor = torch.tensor([[[0, 0], [0, -3]], [[0, 0], [5, 0]]]) - static_size = 4 - fill_value = -999 - self.assertTrue( - torch._dynamo.utils.same( - torch.nonzero_static( - input_tensor, - size=static_size, - fill_value=fill_value, - ), - torch.tensor( - [ - [0, 1, 1], - [1, 1, 0], - [fill_value, fill_value, fill_value], - [fill_value, fill_value, fill_value], - ] - ), - ) - ) - def test_cond_with_quantization(self): from functorch.experimental.control_flow import cond diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index 67c56145283..7ea1155165f 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -1577,6 +1577,168 @@ class TestUnaryUfuncs(TestCase): y = torch.nonzero(x) self.assertEqual(y.view(-1), indices) + def test_nonzero_static(self, device): + # invalid size + with self.assertRaisesRegex( + RuntimeError, "nonzero_static: 'size' must be an non-negative integer" + ): + torch.nonzero_static(torch.tensor([8], device=device), size=-2) + + with self.assertRaisesRegex( + RuntimeError, "nonzero_static: 'size' must be an non-negative integer" + ): + torch.nonzero_static( + torch.tensor([8], device=device), + size=-2, + out=torch.tensor(0, device=device), + ) + + # nonzero_static.out: out dtype mismatch + input_tensor = torch.tensor([8], device=device) + static_size = 1 + out_tensor = torch.empty( + (static_size, input_tensor.dim()), dtype=torch.float, device=device + ) + with self.assertRaisesRegex( + RuntimeError, "nonzero_static: Expected out tensor to have scalar type Long" + ): + torch.nonzero_static(input_tensor, size=static_size, out=out_tensor) + + # nonzero_static.out: out resize (shrink) + input_tensor = torch.tensor([8], device=device) + static_size = 1 + out_tensor = torch.empty((10, 10, 10, 10), dtype=torch.long, device=device) + ref = torch.tensor([[0]], device=device) + self.assertEqual( + torch.nonzero_static(input_tensor, size=static_size, out=out_tensor), ref + ) + self.assertEqual(out_tensor, ref) + + # nonzero_static.out: out resize (enlarge) + input_tensor = torch.tensor([8], device=device) + static_size = 1 + out_tensor = torch.empty((0), dtype=torch.long, device=device) + self.assertEqual( + torch.nonzero_static(input_tensor, size=static_size, out=out_tensor), ref + ) + self.assertEqual(out_tensor, ref) + + # 0 rank + input_tensor = torch.tensor(6, device=device) + static_size = 2 + self.assertEqual( + torch.nonzero_static(input_tensor, size=static_size), + torch.empty( + (static_size, input_tensor.dim()), device=device, dtype=torch.long + ), + ) + + # 0 size + input_tensor = torch.tensor([[[1]]], device=device) + static_size = 0 + self.assertEqual( + torch.nonzero_static(input_tensor, size=static_size), + torch.empty( + (static_size, input_tensor.dim()), device=device, dtype=torch.long + ), + ) + + # 1D input + input_tensor = torch.tensor([0, 8], device=device) + static_size = 1 + self.assertEqual( + torch.nonzero_static(input_tensor, size=static_size), + torch.tensor([[1]], device=device), + ) + + input_tensor = torch.tensor([8, 0], device=device) + static_size = 2 + self.assertEqual( + torch.nonzero_static(input_tensor, size=static_size), + torch.tensor( + [[0], [-1]], device=device + ), # padded with default fill_value "-1" + ) + + # 2D input + input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]], device=device) + 
static_size = 5 + fill_value = -100 + self.assertEqual( + torch.nonzero_static(input_tensor, size=static_size, fill_value=fill_value), + torch.tensor( + [ + [0, 0], + [1, 0], + [1, 1], + [fill_value, fill_value], + [fill_value, fill_value], + ], + device=device, + ), + ) + input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]], device=device) + static_size = 2 + fill_value = -100 + self.assertEqual( + torch.nonzero_static(input_tensor, size=static_size, fill_value=fill_value), + torch.tensor([[0, 0], [1, 0]], device=device), + ) + + # 3D input + input_tensor = torch.tensor( + [[[0, 0], [0, -3]], [[0, 0], [5, 0]]], device=device + ) + static_size = 4 + fill_value = -999 + self.assertEqual( + torch.nonzero_static( + input_tensor, + size=static_size, + fill_value=fill_value, + ), + torch.tensor( + [ + [0, 1, 1], + [1, 1, 0], + [fill_value, fill_value, fill_value], + [fill_value, fill_value, fill_value], + ], + device=device, + ), + ) + + @onlyCUDA + def test_nonzero_static_large(self, device): + # large enough to have multiple iters per SM even on H100 + # with 132 sms + size_inp = 1024 * 16 * 132 + 1024 * 16 + x = torch.zeros(size_inp, device=device) + # unique indices + indices = torch.randperm(size_inp, device=device)[: size_inp // 2] + sorted, _ = torch.sort(indices) + x[sorted] = 1 + res = torch.nonzero_static(x, size=size_inp // 2).view(-1) + self.assertEqual(res, sorted) + # no oob writes + out = torch.full((size_inp,), 10, device=device, dtype=torch.int64) + res = torch.nonzero_static(x, size=size_inp // 4, out=out[: size_inp // 2]) + self.assertEqual(out[: size_inp // 4], sorted[: size_inp // 4]) + self.assertEqual( + out[size_inp // 4 :], + torch.tensor(10, device="cuda").expand_as(out[size_inp // 4 :]), + ) + # correct fill for 2d + x = x.view(2, size_inp // 2) + ref = x.nonzero() + res = x.nonzero_static(size=size_inp // 2 + 2) + self.assertEqual(res.shape, [size_inp // 2 + 2, 2]) + self.assertEqual(ref, res[: size_inp // 2]) + self.assertEqual( + res[size_inp // 2 :], + torch.tensor(-1, device="cuda").expand_as(res[size_inp // 2 :]), + ) + # TODO: rationalize with exp OpInfo @dtypes(*floating_and_complex_types_and(torch.bfloat16))
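
For reference, a minimal usage sketch of the op this patch enables on CUDA, mirroring the 2D case in the new test_nonzero_static test above (it assumes a build that includes this change; the expected values in the comments are taken from that test):

import torch

# nonzero_static returns exactly `size` rows, padding with `fill_value` when
# there are fewer nonzeros than `size` and truncating when there are more.
x = torch.tensor([[1.2, 0.0], [3.4, 5.6]], device="cuda")
out = torch.nonzero_static(x, size=5, fill_value=-100)
# Per the test above, `out` is:
# tensor([[   0,    0],
#         [   1,    0],
#         [   1,    1],
#         [-100, -100],
#         [-100, -100]], device='cuda:0')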
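
And a rough host-side model of the two-pass scheme the kernels follow: calc_block_sums computes per-block nonzero counts, compute_agg scans them into block offsets, flag_kernel ranks each block's nonzeros locally and scatters their flat indices, and write_fill_value pads the tail. This is an illustrative sketch on a flattened input, not the kernel logic itself, and it omits the multi-dimensional index conversion done by write_indices:

import torch

def nonzero_static_model(x_flat, size, fill_value=-1, block=8):
    n = x_flat.numel()
    # Pass 1: per-block nonzero counts (calc_block_sums with nonzero=true) ...
    flags = (x_flat != 0).to(torch.int64)
    counts = torch.tensor([int(flags[i:i + block].sum()) for i in range(0, n, block)])
    # ... scanned into cumulative block offsets (compute_agg).
    offsets = torch.cumsum(counts, dim=0)
    # Pass 2: each block ranks its nonzeros locally and writes flat indices at
    # (block offset + local rank), clipped to `size` (flag_kernel); everything
    # past the total nonzero count stays at fill_value (write_fill_value).
    out = torch.full((size,), fill_value, dtype=torch.int64)
    for b, start in enumerate(range(0, n, block)):
        base = 0 if b == 0 else int(offsets[b - 1])
        for j in range(start, min(start + block, n)):
            if x_flat[j] != 0 and base < size:
                out[base] = j
                base += 1
    return out

x = torch.tensor([0, 3, 0, 0, 7, 1, 0, 0, 0, 2])
print(nonzero_static_model(x, size=6))           # tensor([1, 4, 5, 9, -1, -1])
print(torch.nonzero_static(x, size=6).view(-1))  # matches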