[CUDA][B200] Update the number of threads in avg_pool2d backward for SM 10.0 (#145669)

Fixes a register count issue when launching on SM 10.0: with double-precision inputs the kernel's register usage exceeds the per-block budget at 1024 threads, so the launch is capped at 768 threads per block for `double` on that architecture. Originally authored by @bilal2vec.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145669
Approved by: https://github.com/nWEIdia, https://github.com/ngimel
This commit is contained in:
eqy 2025-02-06 18:57:30 +00:00 committed by PyTorch MergeBot
parent 99ddbb4802
commit 07b214402a

View file

@ -399,15 +399,21 @@ TORCH_IMPL_FUNC(avg_pool2d_backward_out_cuda) (
return;
}
const uint32_t num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
const uint32_t num_blocks = ceil_div<uint32_t>(count, num_threads);
bool use_divisor = divisor_override.has_value();
const auto divisor_override_value = use_divisor ? divisor_override.value() : 0;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000
constexpr int double_threads = 768;
#else
constexpr int double_threads = 1024;
#endif
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
"avg_pool2d_backward_out_cuda_frame",
[&] {
const uint32_t num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, std::is_same<scalar_t, double>::value ? double_threads : 1024);
const uint32_t num_blocks = ceil_div<uint32_t>(count, num_threads);
using accscalar_t = acc_type<scalar_t, true>;
const scalar_t *gradOutput_data = gradOutput.const_data_ptr<scalar_t>();