mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Disable cudnn transpose for int types (#26934)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26934 Disable cudnn transpose for int types Did experiment with int + 4d/5d Test Plan: buck test mode/dev-nosan caffe2/caffe2/python/operator_test:utility_ops_test Reviewed By: houseroad Differential Revision: D17607176 fbshipit-source-id: 83b9f9cf654b33d68b657f1b5a17d9bbd06df529
This commit is contained in:
parent
8fa9900c28
commit
f77b295edc
1 changed files with 18 additions and 3 deletions
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
|
|
@ -65,14 +66,15 @@ class CuDNNTransposeOp final : public Operator<CUDAContext> {
|
|||
if (X.numel() == 0) {
|
||||
return true;
|
||||
}
|
||||
if (ndim < 3 || ndim > CUDNN_DIM_MAX ||
|
||||
X.numel() > std::numeric_limits<std::int32_t>::max()) {
|
||||
if (!IsFloatType<T>() || !IsCuDNNValidTensor(X)) {
|
||||
math::Transpose<std::int64_t, T, CUDAContext>(
|
||||
ndim, X_dims.data(), axes_.data(), X_data, Y_data, &context_);
|
||||
return true;
|
||||
}
|
||||
if (X_dims != cached_X_dims_) {
|
||||
if (cudnnTypeWrapper<T>::type != cached_dtype_ ||
|
||||
X_dims != cached_X_dims_) {
|
||||
SetTensorDescriptor(cudnnTypeWrapper<T>::type, X_dims, Y_dims);
|
||||
cached_dtype_ = cudnnTypeWrapper<T>::type;
|
||||
cached_X_dims_ = X_dims;
|
||||
}
|
||||
CUDNN_ENFORCE(cudnnTransformTensor(
|
||||
|
|
@ -87,6 +89,18 @@ class CuDNNTransposeOp final : public Operator<CUDAContext> {
|
|||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
constexpr bool IsFloatType() const {
|
||||
return std::is_same<T, float>::value || std::is_same<T, double>::value ||
|
||||
std::is_same<T, at::Half>::value;
|
||||
}
|
||||
|
||||
bool IsCuDNNValidTensor(const Tensor& X) const {
|
||||
const int ndim = X.dim();
|
||||
return ndim >= 3 && ndim <= CUDNN_DIM_MAX &&
|
||||
X.numel() < std::numeric_limits<int32_t>::max();
|
||||
}
|
||||
|
||||
void SetTensorDescriptor(
|
||||
const cudnnDataType_t data_type,
|
||||
const std::vector<std::int64_t>& X_dims,
|
||||
|
|
@ -115,6 +129,7 @@ class CuDNNTransposeOp final : public Operator<CUDAContext> {
|
|||
cudnnTensorDescriptor_t X_desc_;
|
||||
cudnnTensorDescriptor_t Y_desc_;
|
||||
|
||||
cudnnDataType_t cached_dtype_ = cudnnTypeWrapper<float>::type;
|
||||
std::vector<std::int64_t> cached_X_dims_;
|
||||
std::vector<std::int32_t> axes_;
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue