mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Implements nonzero_static on cuda (#141838)
using blockwide cub primitives. This adds CUDA functionality for nonzero_static, which was missing in https://github.com/pytorch/pytorch/pull/97417. For `size` approx equal to number of nonzeros, the perf is very close to the regular version, for larger sizes filling in padding indices takes additional time. Disabled for cuda <=11.4 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141838 Approved by: https://github.com/ezyang, https://github.com/malfet
This commit is contained in:
parent
1d3b0108a6
commit
bdbdbeeb3d
5 changed files with 448 additions and 166 deletions
|
|
@ -349,7 +349,7 @@ __global__ void final_scan_kernel(const T* d_in, T* d_out, T* agg, int64_t nelem
|
|||
// Per-thread tile data
|
||||
T data[ITEMS_PER_THREAD];
|
||||
|
||||
int remaining = nelem - BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
|
||||
int64_t remaining = nelem - BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
|
||||
for (int i=0; i<iters_per_cta; i++){
|
||||
// Load items into a blocked arrangement
|
||||
if (remaining >= BLOCK_THREADS * ITEMS_PER_THREAD) {
|
||||
|
|
@ -386,38 +386,57 @@ __global__ void final_scan_kernel(const T* d_in, T* d_out, T* agg, int64_t nelem
|
|||
|
||||
}
|
||||
|
||||
template <typename T, typename aggT, bool nonzero>
|
||||
struct TransformFunctor {
|
||||
__device__ aggT operator()(T value) const {
|
||||
if constexpr (!nonzero) {
|
||||
return value;
|
||||
} else {
|
||||
return (value != T(0)) ? 1 : 0;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template<int BLOCK_THREADS, int ITEMS_PER_THREAD, typename T>
|
||||
__global__ void calc_block_sums(const T * d_in, T * agg, int64_t nelem, int iters_per_cta){
|
||||
template<int BLOCK_THREADS, int ITEMS_PER_THREAD, bool nonzero, typename T, typename aggT>
|
||||
__global__ void calc_block_sums(const T * d_in, aggT * agg, int64_t nelem, int iters_per_cta){
|
||||
if (BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x >= nelem) return;
|
||||
d_in += BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
|
||||
d_in += BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * (int64_t)blockIdx.x;
|
||||
|
||||
using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_LOAD_STRIPED>;
|
||||
using BlockReduceT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockReduce<T, BLOCK_THREADS>;
|
||||
using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad<aggT, BLOCK_THREADS, ITEMS_PER_THREAD, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_LOAD_STRIPED>;
|
||||
using BlockReduceT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockReduce<aggT, BLOCK_THREADS>;
|
||||
// Shared memory
|
||||
__shared__ union TempStorage
|
||||
{
|
||||
typename BlockLoadT::TempStorage load;
|
||||
typename BlockReduceT::TempStorage reduce;
|
||||
} temp_storage;
|
||||
T data[ITEMS_PER_THREAD];
|
||||
T agg_val = 0;
|
||||
int64_t remaining = nelem - BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
|
||||
aggT data[ITEMS_PER_THREAD];
|
||||
aggT agg_val = 0;
|
||||
int64_t remaining = nelem - BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * (int64_t)blockIdx.x;
|
||||
TransformFunctor<T, aggT, nonzero> transform_functor;
|
||||
auto iter_in = ROCM_HIPCUB(at_cuda_detail::cub)::TransformInputIterator<aggT, TransformFunctor<T, aggT, nonzero>, const T*>(d_in, transform_functor);
|
||||
for (int i=0; i<iters_per_cta; i++){
|
||||
if (remaining >= BLOCK_THREADS * ITEMS_PER_THREAD) {
|
||||
BlockLoadT(temp_storage.load).Load(d_in, data);
|
||||
BlockLoadT(temp_storage.load).Load(iter_in, data);
|
||||
__syncthreads();
|
||||
agg_val += BlockReduceT(temp_storage.reduce).Sum(data);
|
||||
|
||||
} else {
|
||||
BlockLoadT(temp_storage.load).Load(d_in, data, remaining);
|
||||
BlockLoadT(temp_storage.load).Load(iter_in, data, remaining, aggT(0));
|
||||
__syncthreads();
|
||||
agg_val += BlockReduceT(temp_storage.reduce).Sum(data);
|
||||
}
|
||||
d_in += BLOCK_THREADS * ITEMS_PER_THREAD;
|
||||
iter_in += BLOCK_THREADS * ITEMS_PER_THREAD;
|
||||
remaining -= BLOCK_THREADS * ITEMS_PER_THREAD;
|
||||
if (remaining <= 0) return;
|
||||
if (remaining <= 0) {
|
||||
// for nonzeros we need to write out last blocks
|
||||
// accumulated value to be able to compute
|
||||
// total number of nonzeros
|
||||
if (nonzero && threadIdx.x == 0) {
|
||||
agg[blockIdx.x] = agg_val;
|
||||
}
|
||||
return;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
|
|
@ -427,6 +446,13 @@ __global__ void calc_block_sums(const T * d_in, T * agg, int64_t nelem, int iter
|
|||
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct NonZeroOp {
|
||||
__host__ __device__ __forceinline__ int operator()(const T& a) const {
|
||||
return (a != T(0));
|
||||
}
|
||||
};
|
||||
|
||||
template<int size>
|
||||
constexpr int block_threads(){
|
||||
if constexpr (size >=16) {
|
||||
|
|
@ -450,7 +476,7 @@ inline void inclusive_deterministic_scan(const scalar_t * input, scalar_t * out
|
|||
grid_size = std::min(num_sms, grid_size);
|
||||
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
|
||||
auto agg = allocator.allocate(grid_size * sizeof(scalar_t));
|
||||
calc_block_sums<BLOCK_THREADS, ITEMS_PER_THREAD>
|
||||
calc_block_sums<BLOCK_THREADS, ITEMS_PER_THREAD, false>
|
||||
<<<grid_size, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
input, (scalar_t*)agg.get(), num_items, iters_per_cta);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
|
|
|
|||
|
|
@ -37,9 +37,12 @@ __global__ void write_indices(
|
|||
int64_t* inp,
|
||||
TensorDims<index_t> dims,
|
||||
int ndim,
|
||||
index_t n) {
|
||||
auto index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (index < n) {
|
||||
index_t n,
|
||||
int64_t * total = nullptr,
|
||||
int64_t fill_value = -1) {
|
||||
auto index = threadIdx.x + (int64_t)blockIdx.x * blockDim.x;
|
||||
bool cond = (total == nullptr || index < *total);
|
||||
if (index < n && cond) {
|
||||
index_t div = 1;
|
||||
int64_t idx_flat = inp[index];
|
||||
#pragma unroll
|
||||
|
|
@ -50,9 +53,117 @@ __global__ void write_indices(
|
|||
inp[index + dim * n] = (idx_flat / div) % dim_size;
|
||||
div *= dim_size;
|
||||
}
|
||||
} else if (index < n) {
|
||||
// 0th dim has correct values already
|
||||
for (int dim = ndim - 1; dim > 0; dim--) {
|
||||
inp[index + dim * n] = fill_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void write_fill_value(int64_t * inp, int64_t * total, int64_t fill_value, int64_t n){
|
||||
int64_t total_val = *total;
|
||||
// not aiming for vectorized stores
|
||||
|
||||
for (int64_t idx = total_val + (int64_t)blockIdx.x * blockDim.x + threadIdx.x; idx < n; idx += blockDim.x * gridDim.x) {
|
||||
inp[idx] = fill_value;
|
||||
}
|
||||
}
|
||||
|
||||
template <int BLOCK_THREADS>
|
||||
__global__ void compute_agg(int32_t * agg, int64_t * agg_cum, uint32_t n_blocks) {
|
||||
|
||||
using BlockScanT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockScan<int64_t, BLOCK_THREADS, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_SCAN_WARP_SCANS>;
|
||||
__shared__ typename BlockScanT::TempStorage temp_storage;
|
||||
int agg_data;
|
||||
int64_t agg_cum_data;
|
||||
agg_data = threadIdx.x < n_blocks ? agg[threadIdx.x] : 0;
|
||||
BlockScanT(temp_storage).InclusiveSum(agg_data, agg_cum_data);
|
||||
if (threadIdx.x < n_blocks) {
|
||||
agg_cum[threadIdx.x] = agg_cum_data;
|
||||
}
|
||||
}
|
||||
|
||||
template<int BLOCK_THREADS, int ITEMS_PER_THREAD, typename T>
|
||||
__global__ void flag_kernel(const T* d_in, int64_t * d_out, const int64_t * agg, int64_t input_nelem, int64_t output_nelem, int iters_per_cta) {
|
||||
int64_t start_idx = BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * (int64_t)blockIdx.x;
|
||||
if (start_idx >= input_nelem) return;
|
||||
d_in += start_idx;
|
||||
|
||||
using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad<int, BLOCK_THREADS, ITEMS_PER_THREAD, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
|
||||
// Specialize BlockScan type for our thread block
|
||||
using BlockScanT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockScan<int, BLOCK_THREADS, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_SCAN_WARP_SCANS>;
|
||||
using TransformInputIteratorT = ROCM_HIPCUB(at_cuda_detail::cub)::TransformInputIterator<int, NonZeroOp<T>, const T*>;
|
||||
using BlockExchangeT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockExchange<int, BLOCK_THREADS, ITEMS_PER_THREAD>;
|
||||
|
||||
// Shared memory
|
||||
__shared__ union TempStorage
|
||||
{
|
||||
typename BlockLoadT::TempStorage load;
|
||||
typename BlockScanT::TempStorage scan;
|
||||
typename BlockExchangeT::TempStorage exchange;
|
||||
} temp_storage;
|
||||
|
||||
int64_t aggregate = blockIdx.x == 0 ? 0 : agg[blockIdx.x - 1];
|
||||
d_out += aggregate;
|
||||
|
||||
TransformInputIteratorT t_input_itr(d_in, NonZeroOp<T>());
|
||||
|
||||
// Per-thread tile data
|
||||
int data[ITEMS_PER_THREAD];
|
||||
int out_indices[ITEMS_PER_THREAD];
|
||||
|
||||
int64_t remaining = input_nelem - start_idx;
|
||||
int64_t out_remaining = output_nelem - aggregate;
|
||||
for (int i=0; i<iters_per_cta; i++){
|
||||
|
||||
// Load items into a blocked arrangement
|
||||
if (remaining >= BLOCK_THREADS * ITEMS_PER_THREAD) {
|
||||
BlockLoadT(temp_storage.load).Load(t_input_itr, data);
|
||||
} else {
|
||||
BlockLoadT(temp_storage.load).Load(t_input_itr, data, remaining, int(0));
|
||||
}
|
||||
|
||||
// Barrier for smem reuse
|
||||
__syncthreads();
|
||||
|
||||
// Compute inclusive prefix sum
|
||||
int aggregate;
|
||||
__shared__ int aggregate_sh;
|
||||
BlockScanT(temp_storage.scan).ExclusiveSum(data, out_indices, aggregate);
|
||||
|
||||
if (threadIdx.x == 0){
|
||||
aggregate_sh = aggregate;
|
||||
}
|
||||
|
||||
// Barrier for smem reuse
|
||||
__syncthreads();
|
||||
// striped arrangement will provide a slightly better
|
||||
// coalescing for writes (although it's still bad because it's indirect indexing)
|
||||
BlockExchangeT(temp_storage.exchange).BlockedToStriped(data);
|
||||
__syncthreads();
|
||||
BlockExchangeT(temp_storage.exchange).BlockedToStriped(out_indices);
|
||||
for (int ii=0; ii<ITEMS_PER_THREAD; ii++){
|
||||
if (data[ii] != 0 && out_indices[ii] < out_remaining) {
|
||||
int64_t inp_idx = start_idx + threadIdx.x + blockDim.x * ii;
|
||||
d_out[out_indices[ii]] = inp_idx;
|
||||
}
|
||||
}
|
||||
|
||||
out_remaining -= aggregate_sh;
|
||||
remaining -= BLOCK_THREADS * ITEMS_PER_THREAD;
|
||||
if (remaining <= 0 || out_remaining <= 0) return;
|
||||
d_out += aggregate_sh;
|
||||
t_input_itr += BLOCK_THREADS * ITEMS_PER_THREAD;
|
||||
start_idx += BLOCK_THREADS * ITEMS_PER_THREAD;
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
template <typename scalar_t>
|
||||
|
|
@ -183,6 +294,83 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) {
|
|||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void nonzero_static_cuda_out_impl(
|
||||
const Tensor& self,
|
||||
int64_t size,
|
||||
int64_t fill_value,
|
||||
Tensor& out) {
|
||||
# if (defined(CUDA_VERSION) && CUDA_VERSION > 11040) || defined(USE_ROCM)
|
||||
|
||||
Tensor self_contiguous_ = self.contiguous();
|
||||
// see comment in nonzero_cuda_out_impl on reqs for out
|
||||
bool out_correct_size =
|
||||
out.dim() == 2 && out.sizes()[0] == size && out.sizes()[1] == self.dim();
|
||||
bool need_to_copy = out_correct_size && !out.t().is_contiguous();
|
||||
if (!out_correct_size) {
|
||||
out.resize_({self.dim(), size}).t();
|
||||
}
|
||||
if (out.numel() == 0) return;
|
||||
// we need to allocate temporary out to then copy to user provided out
|
||||
at::Tensor out_temp;
|
||||
if (need_to_copy) {
|
||||
out_temp =
|
||||
Tensor(at::detail::empty_cuda({self.dim(), size}, out.options())).t();
|
||||
}
|
||||
int64_t* out_data_ptr = need_to_copy ? out_temp.mutable_data_ptr<int64_t>()
|
||||
: out.mutable_data_ptr<int64_t>();
|
||||
|
||||
const scalar_t * in_data_ptr = self_contiguous_.const_data_ptr<scalar_t>();
|
||||
constexpr int BLOCK_THREADS = 512; //block_threads<sizeof(scalar_t)>();
|
||||
constexpr int ITEMS_PER_THREAD = 16;
|
||||
auto grid_size = (self.numel() + BLOCK_THREADS * ITEMS_PER_THREAD - 1) / (BLOCK_THREADS * ITEMS_PER_THREAD);
|
||||
const int64_t num_sms = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
|
||||
int64_t target_blocks = sizeof(scalar_t) == 1 ? 2 * num_sms : num_sms;
|
||||
const int iters_per_cta = (grid_size + target_blocks - 1)/target_blocks;
|
||||
grid_size = (self.numel() + iters_per_cta * BLOCK_THREADS * ITEMS_PER_THREAD - 1) / (iters_per_cta * BLOCK_THREADS * ITEMS_PER_THREAD);
|
||||
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
|
||||
auto agg = allocator.allocate(grid_size * sizeof(int));
|
||||
at::cuda::cub::calc_block_sums<BLOCK_THREADS, ITEMS_PER_THREAD, true>
|
||||
<<<grid_size, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
in_data_ptr, (int*)agg.get(), self.numel(), iters_per_cta);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
auto agg_cum = allocator.allocate(grid_size * sizeof(int64_t));
|
||||
// computing partial sums in int64 in the flag kernel
|
||||
// leads to 20-30% slowdown, so compute them in a separate 2 us kernel
|
||||
compute_agg<BLOCK_THREADS><<<1, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
(int*)agg.get(), (int64_t*)agg_cum.get(), grid_size
|
||||
);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
flag_kernel<BLOCK_THREADS, ITEMS_PER_THREAD>
|
||||
<<<grid_size, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
in_data_ptr, out_data_ptr, (int64_t*)agg_cum.get(), self.numel(), size, iters_per_cta);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
int64_t out_grid = std::min(num_sms, (size + BLOCK_THREADS - 1)/BLOCK_THREADS);
|
||||
write_fill_value<<<out_grid, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(out_data_ptr, (int64_t *)agg_cum.get() + grid_size - 1, fill_value, size);
|
||||
if (self.dim() > 1) {
|
||||
TensorDims<int64_t> dims;
|
||||
for (int i = 0; i < self.dim(); i++) {
|
||||
dims.sizes[i] = self.sizes()[i];
|
||||
}
|
||||
const int nthreads = 256;
|
||||
const int nblocks = (size + nthreads - 1) / nthreads;
|
||||
write_indices<<<nblocks, nthreads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
out_data_ptr,
|
||||
dims,
|
||||
self.dim(),
|
||||
size,
|
||||
(int64_t *)agg_cum.get() + grid_size - 1,
|
||||
fill_value);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
if (need_to_copy) {
|
||||
out.copy_(out_temp);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false, "Nonzero_static is not supported for cuda <= 11.4");
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor& nonzero_out_cuda(const Tensor& self, Tensor& out) {
|
||||
TORCH_CHECK(
|
||||
out.dtype() == at::kLong,
|
||||
|
|
@ -216,4 +404,56 @@ Tensor nonzero_cuda(const Tensor& self) {
|
|||
Tensor out = at::detail::empty_cuda({0}, self.options().dtype(kLong));
|
||||
return at::native::nonzero_out_cuda(self, out);
|
||||
}
|
||||
|
||||
Tensor& nonzero_static_out_cuda(
|
||||
const Tensor& self,
|
||||
int64_t size,
|
||||
int64_t fill_value,
|
||||
Tensor& out) {
|
||||
TORCH_CHECK(
|
||||
out.dtype() == at::kLong,
|
||||
"nonzero_static: Expected out tensor to have scalar type ",
|
||||
at::kLong,
|
||||
" but got ",
|
||||
out.dtype());
|
||||
TORCH_CHECK(
|
||||
self.device() == out.device(),
|
||||
"expected self and out to be on the same device, but got out on ",
|
||||
out.device(),
|
||||
" and self on ",
|
||||
self.device());
|
||||
TORCH_CHECK(
|
||||
self.dim() <= MAX_DIMS,
|
||||
"nonzero_static is not supported for tensor with more than ",
|
||||
MAX_DIMS,
|
||||
" dimensions");
|
||||
TORCH_CHECK(
|
||||
size >= 0, "nonzero_static: 'size' must be an non-negative integer"
|
||||
)
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
|
||||
at::ScalarType::ComplexHalf,
|
||||
at::ScalarType::Bool,
|
||||
at::ScalarType::BFloat16,
|
||||
at::ScalarType::Half,
|
||||
self.scalar_type(),
|
||||
"nonzero_cuda",
|
||||
[&] {
|
||||
nonzero_static_cuda_out_impl<scalar_t>(self, size, fill_value, out);
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
Tensor nonzero_static_cuda(
|
||||
const Tensor& self,
|
||||
int64_t size,
|
||||
int64_t fill_value) {
|
||||
TORCH_CHECK(
|
||||
size >= 0, "nonzero_static: 'size' must be an non-negative integer"
|
||||
)
|
||||
Tensor out = Tensor(at::detail::empty_cuda(
|
||||
{self.dim(), size}, self.options().dtype(kLong)))
|
||||
.t();
|
||||
return at::native::nonzero_static_out_cuda(self, size, fill_value, out);
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
|
|
|||
|
|
@ -9277,11 +9277,13 @@
|
|||
- func: nonzero_static.out(Tensor self, *, int size, int fill_value=-1, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU: nonzero_static_out_cpu
|
||||
CUDA: nonzero_static_out_cuda
|
||||
|
||||
- func: nonzero_static(Tensor self, *, int size, int fill_value=-1) -> Tensor
|
||||
variants: method, function
|
||||
dispatch:
|
||||
CPU: nonzero_static_cpu
|
||||
CUDA: nonzero_static_cuda
|
||||
|
||||
- func: nonzero_numpy(Tensor self) -> Tensor[]
|
||||
variants: method, function
|
||||
|
|
|
|||
|
|
@ -5772,154 +5772,6 @@ utils_device.CURRENT_DEVICE == None""".split(
|
|||
b = opt_fn(torch.tensor(True), torch.tensor([0.25, 0.25]))
|
||||
self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), b))
|
||||
|
||||
def test_nonzero_static(self):
|
||||
# invalid size
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "nonzero_static: 'size' must be an non-negative integer"
|
||||
):
|
||||
torch.nonzero_static(torch.tensor([8]), size=-2)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "nonzero_static: 'size' must be an non-negative integer"
|
||||
):
|
||||
torch.nonzero_static(torch.tensor([8]), size=-2, out=torch.tensor(0))
|
||||
|
||||
# nonzero_static.out: out dtype mismatch
|
||||
input_tensor = torch.tensor([8])
|
||||
static_size = 1
|
||||
out_tensor = torch.empty((static_size, input_tensor.dim()), dtype=torch.float)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "nonzero_static: Expected out tensor to have scalar type Long"
|
||||
):
|
||||
torch.nonzero_static(input_tensor, size=static_size, out=out_tensor)
|
||||
|
||||
# nonzero_static.out: out resize (shrink)
|
||||
input_tensor = torch.tensor([8])
|
||||
static_size = 1
|
||||
out_tensor = torch.empty((10, 10, 10, 10), dtype=torch.long)
|
||||
self.assertTrue(
|
||||
same(
|
||||
torch.nonzero_static(input_tensor, size=static_size, out=out_tensor),
|
||||
torch.tensor([0]),
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
same(
|
||||
out_tensor,
|
||||
torch.tensor([0]),
|
||||
)
|
||||
)
|
||||
|
||||
# nonzero_static.out: out resize (enlarge)
|
||||
input_tensor = torch.tensor([8])
|
||||
static_size = 1
|
||||
out_tensor = torch.empty((0), dtype=torch.long)
|
||||
self.assertTrue(
|
||||
same(
|
||||
torch.nonzero_static(input_tensor, size=static_size, out=out_tensor),
|
||||
torch.tensor([0]),
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
same(
|
||||
out_tensor,
|
||||
torch.tensor([0]),
|
||||
)
|
||||
)
|
||||
|
||||
# 0 rank
|
||||
input_tensor = torch.tensor(6)
|
||||
static_size = 2
|
||||
self.assertTrue(
|
||||
same(
|
||||
torch.nonzero_static(input_tensor, size=static_size),
|
||||
torch.empty((static_size, input_tensor.dim()), dtype=torch.long),
|
||||
)
|
||||
)
|
||||
|
||||
# 0 size
|
||||
input_tensor = torch.tensor([[[1]]])
|
||||
static_size = 0
|
||||
self.assertTrue(
|
||||
same(
|
||||
torch.nonzero_static(input_tensor, size=static_size),
|
||||
torch.empty((static_size, input_tensor.dim()), dtype=torch.long),
|
||||
)
|
||||
)
|
||||
|
||||
# 1D input
|
||||
input_tensor = torch.tensor([0, 8])
|
||||
static_size = 1
|
||||
self.assertTrue(
|
||||
same(
|
||||
torch.nonzero_static(input_tensor, size=static_size),
|
||||
torch.tensor([1]),
|
||||
)
|
||||
)
|
||||
|
||||
input_tensor = torch.tensor([8, 0])
|
||||
static_size = 2
|
||||
self.assertTrue(
|
||||
same(
|
||||
torch.nonzero_static(input_tensor, size=static_size),
|
||||
torch.tensor([[0], [-1]]), # padded with default fill_value "-1"
|
||||
)
|
||||
)
|
||||
|
||||
# 2D input
|
||||
input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]])
|
||||
static_size = 5
|
||||
fill_value = -100
|
||||
self.assertTrue(
|
||||
torch._dynamo.utils.same(
|
||||
torch.nonzero_static(
|
||||
input_tensor, size=static_size, fill_value=fill_value
|
||||
),
|
||||
torch.tensor(
|
||||
[
|
||||
[0, 0],
|
||||
[1, 0],
|
||||
[1, 1],
|
||||
[fill_value, fill_value],
|
||||
[fill_value, fill_value],
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]])
|
||||
static_size = 2
|
||||
fill_value = -100
|
||||
self.assertTrue(
|
||||
torch._dynamo.utils.same(
|
||||
torch.nonzero_static(
|
||||
input_tensor, size=static_size, fill_value=fill_value
|
||||
),
|
||||
torch.tensor([[0, 0], [1, 0]]),
|
||||
)
|
||||
)
|
||||
|
||||
# 3D input
|
||||
input_tensor = torch.tensor([[[0, 0], [0, -3]], [[0, 0], [5, 0]]])
|
||||
static_size = 4
|
||||
fill_value = -999
|
||||
self.assertTrue(
|
||||
torch._dynamo.utils.same(
|
||||
torch.nonzero_static(
|
||||
input_tensor,
|
||||
size=static_size,
|
||||
fill_value=fill_value,
|
||||
),
|
||||
torch.tensor(
|
||||
[
|
||||
[0, 1, 1],
|
||||
[1, 1, 0],
|
||||
[fill_value, fill_value, fill_value],
|
||||
[fill_value, fill_value, fill_value],
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
def test_cond_with_quantization(self):
|
||||
from functorch.experimental.control_flow import cond
|
||||
|
||||
|
|
|
|||
|
|
@ -1577,6 +1577,168 @@ class TestUnaryUfuncs(TestCase):
|
|||
y = torch.nonzero(x)
|
||||
self.assertEqual(y.view(-1), indices)
|
||||
|
||||
def test_nonzero_static(self, device):
|
||||
# invalid size
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "nonzero_static: 'size' must be an non-negative integer"
|
||||
):
|
||||
torch.nonzero_static(torch.tensor([8], device=device), size=-2)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "nonzero_static: 'size' must be an non-negative integer"
|
||||
):
|
||||
torch.nonzero_static(
|
||||
torch.tensor([8], device=device),
|
||||
size=-2,
|
||||
out=torch.tensor(0, device=device),
|
||||
)
|
||||
|
||||
# nonzero_static.out: out dtype mismatch
|
||||
input_tensor = torch.tensor([8], device=device)
|
||||
static_size = 1
|
||||
out_tensor = torch.empty(
|
||||
(static_size, input_tensor.dim()), dtype=torch.float, device=device
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "nonzero_static: Expected out tensor to have scalar type Long"
|
||||
):
|
||||
torch.nonzero_static(input_tensor, size=static_size, out=out_tensor)
|
||||
|
||||
# nonzero_static.out: out resize (shrink)
|
||||
input_tensor = torch.tensor([8], device=device)
|
||||
static_size = 1
|
||||
out_tensor = torch.empty((10, 10, 10, 10), dtype=torch.long, device=device)
|
||||
ref = torch.tensor([[0]], device=device)
|
||||
self.assertEqual(
|
||||
torch.nonzero_static(input_tensor, size=static_size, out=out_tensor), ref
|
||||
)
|
||||
self.assertEqual(out_tensor, ref)
|
||||
|
||||
# nonzero_static.out: out resize (enlarge)
|
||||
input_tensor = torch.tensor([8], device=device)
|
||||
static_size = 1
|
||||
out_tensor = torch.empty((0), dtype=torch.long, device=device)
|
||||
self.assertEqual(
|
||||
torch.nonzero_static(input_tensor, size=static_size, out=out_tensor), ref
|
||||
)
|
||||
self.assertEqual(out_tensor, ref)
|
||||
|
||||
# 0 rank
|
||||
input_tensor = torch.tensor(6, device=device)
|
||||
static_size = 2
|
||||
self.assertEqual(
|
||||
torch.nonzero_static(input_tensor, size=static_size),
|
||||
torch.empty(
|
||||
(static_size, input_tensor.dim()), device=device, dtype=torch.long
|
||||
),
|
||||
)
|
||||
|
||||
# 0 size
|
||||
input_tensor = torch.tensor([[[1]]], device=device)
|
||||
static_size = 0
|
||||
self.assertEqual(
|
||||
torch.nonzero_static(input_tensor, size=static_size),
|
||||
torch.empty(
|
||||
(static_size, input_tensor.dim()), device=device, dtype=torch.long
|
||||
),
|
||||
)
|
||||
|
||||
# 1D input
|
||||
input_tensor = torch.tensor([0, 8], device=device)
|
||||
static_size = 1
|
||||
self.assertEqual(
|
||||
torch.nonzero_static(input_tensor, size=static_size),
|
||||
torch.tensor([[1]], device=device),
|
||||
)
|
||||
|
||||
input_tensor = torch.tensor([8, 0], device=device)
|
||||
static_size = 2
|
||||
self.assertEqual(
|
||||
torch.nonzero_static(input_tensor, size=static_size),
|
||||
torch.tensor(
|
||||
[[0], [-1]], device=device
|
||||
), # padded with default fill_value "-1"
|
||||
)
|
||||
|
||||
# 2D input
|
||||
input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]], device=device)
|
||||
static_size = 5
|
||||
fill_value = -100
|
||||
self.assertEqual(
|
||||
torch.nonzero_static(input_tensor, size=static_size, fill_value=fill_value),
|
||||
torch.tensor(
|
||||
[
|
||||
[0, 0],
|
||||
[1, 0],
|
||||
[1, 1],
|
||||
[fill_value, fill_value],
|
||||
[fill_value, fill_value],
|
||||
],
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]], device=device)
|
||||
static_size = 2
|
||||
fill_value = -100
|
||||
self.assertEqual(
|
||||
torch.nonzero_static(input_tensor, size=static_size, fill_value=fill_value),
|
||||
torch.tensor([[0, 0], [1, 0]], device=device),
|
||||
)
|
||||
|
||||
# 3D input
|
||||
input_tensor = torch.tensor(
|
||||
[[[0, 0], [0, -3]], [[0, 0], [5, 0]]], device=device
|
||||
)
|
||||
static_size = 4
|
||||
fill_value = -999
|
||||
self.assertEqual(
|
||||
torch.nonzero_static(
|
||||
input_tensor,
|
||||
size=static_size,
|
||||
fill_value=fill_value,
|
||||
),
|
||||
torch.tensor(
|
||||
[
|
||||
[0, 1, 1],
|
||||
[1, 1, 0],
|
||||
[fill_value, fill_value, fill_value],
|
||||
[fill_value, fill_value, fill_value],
|
||||
],
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
|
||||
@onlyCUDA
|
||||
def test_nonzero_static_large(self, device):
|
||||
# large enough to have multiple iters per SM even on H100
|
||||
# with 132 sms
|
||||
size_inp = 1024 * 16 * 132 + 1024 * 16
|
||||
x = torch.zeros(size_inp, device=device)
|
||||
# unique indices
|
||||
indices = torch.randperm(size_inp, device=device)[: size_inp // 2]
|
||||
sorted, _ = torch.sort(indices)
|
||||
x[sorted] = 1
|
||||
res = torch.nonzero_static(x, size=size_inp // 2).view(-1)
|
||||
self.assertEqual(res, sorted)
|
||||
# no oob writes
|
||||
out = torch.full((size_inp,), 10, device=device, dtype=torch.int64)
|
||||
res = torch.nonzero_static(x, size=size_inp // 4, out=out[: size_inp // 2])
|
||||
self.assertEqual(out[: size_inp // 4], sorted[: size_inp // 4])
|
||||
self.assertEqual(
|
||||
out[size_inp // 4 :],
|
||||
torch.tensor(10, device="cuda").expand_as(out[size_inp // 4 :]),
|
||||
)
|
||||
# correct fill for 2d
|
||||
x = x.view(2, size_inp // 2)
|
||||
ref = x.nonzero()
|
||||
res = x.nonzero_static(size=size_inp // 2 + 2)
|
||||
self.assertEqual(res.shape, [size_inp // 2 + 2, 2])
|
||||
self.assertEqual(ref, res[: size_inp // 2])
|
||||
self.assertEqual(
|
||||
res[size_inp // 2 :],
|
||||
torch.tensor(-1, device="cuda").expand_as(res[size_inp // 2 :]),
|
||||
)
|
||||
|
||||
# TODO: rationalize with exp OpInfo
|
||||
|
||||
@dtypes(*floating_and_complex_types_and(torch.bfloat16))
|
||||
|
|
|
|||
Loading…
Reference in a new issue