mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[CUDA][B200] Update the number of threads in avg_pool2d backward for SM 10.0 (#145669)
Fixes a register count issue when launching on SM 10.0. Originally authored by @bilal2vec.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145669
Approved by: https://github.com/nWEIdia, https://github.com/ngimel
This commit is contained in:
parent
99ddbb4802
commit
07b214402a
1 changed file with 9 additions and 3 deletions
|
|
@ -399,15 +399,21 @@ TORCH_IMPL_FUNC(avg_pool2d_backward_out_cuda) (
|
|||
return;
|
||||
}
|
||||
|
||||
const uint32_t num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
|
||||
const uint32_t num_blocks = ceil_div<uint32_t>(count, num_threads);
|
||||
|
||||
bool use_divisor = divisor_override.has_value();
|
||||
const auto divisor_override_value = use_divisor ? divisor_override.value() : 0;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000
|
||||
constexpr int double_threads = 768;
|
||||
#else
|
||||
constexpr int double_threads = 1024;
|
||||
#endif
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
|
||||
"avg_pool2d_backward_out_cuda_frame",
|
||||
[&] {
|
||||
const uint32_t num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, std::is_same<scalar_t, double>::value ? double_threads : 1024);
|
||||
const uint32_t num_blocks = ceil_div<uint32_t>(count, num_threads);
|
||||
|
||||
using accscalar_t = acc_type<scalar_t, true>;
|
||||
|
||||
const scalar_t *gradOutput_data = gradOutput.const_data_ptr<scalar_t>();
|
||||
|
|
|
|||
Loading…
Reference in a new issue