Disable cudnn transpose for int types (#26934)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26934

Disable cudnn transpose for int types

Did experiment with int + 4d/5d

Test Plan: buck test mode/dev-nosan caffe2/caffe2/python/operator_test:utility_ops_test

Reviewed By: houseroad

Differential Revision: D17607176

fbshipit-source-id: 83b9f9cf654b33d68b657f1b5a17d9bbd06df529
This commit is contained in:
Xiaomeng Yang 2019-09-27 11:33:16 -07:00 committed by Facebook GitHub Bot
parent 8fa9900c28
commit f77b295edc

View file

@@ -2,6 +2,7 @@
#include <algorithm>
#include <limits>
#include <type_traits>
#include <vector>
#include "caffe2/core/context_gpu.h"
@@ -65,14 +66,15 @@ class CuDNNTransposeOp final : public Operator<CUDAContext> {
if (X.numel() == 0) {
return true;
}
if (ndim < 3 || ndim > CUDNN_DIM_MAX ||
X.numel() > std::numeric_limits<std::int32_t>::max()) {
if (!IsFloatType<T>() || !IsCuDNNValidTensor(X)) {
math::Transpose<std::int64_t, T, CUDAContext>(
ndim, X_dims.data(), axes_.data(), X_data, Y_data, &context_);
return true;
}
if (X_dims != cached_X_dims_) {
if (cudnnTypeWrapper<T>::type != cached_dtype_ ||
X_dims != cached_X_dims_) {
SetTensorDescriptor(cudnnTypeWrapper<T>::type, X_dims, Y_dims);
cached_dtype_ = cudnnTypeWrapper<T>::type;
cached_X_dims_ = X_dims;
}
CUDNN_ENFORCE(cudnnTransformTensor(
@@ -87,6 +89,18 @@ class CuDNNTransposeOp final : public Operator<CUDAContext> {
}
private:
// Whether T is one of the floating-point element types (half, float,
// double) that the cuDNN transpose path handles; other types take the
// math::Transpose fallback in RunOnDevice.
template <typename T>
constexpr bool IsFloatType() const {
  using std::is_same;
  return is_same<T, at::Half>::value || is_same<T, float>::value ||
      is_same<T, double>::value;
}
// True when X can be described by a cuDNN tensor descriptor: rank within
// [3, CUDNN_DIM_MAX] and element count below the 32-bit signed limit
// (cuDNN descriptors use int dimensions).
bool IsCuDNNValidTensor(const Tensor& X) const {
  const int rank = X.dim();
  if (rank < 3 || rank > CUDNN_DIM_MAX) {
    return false;
  }
  return X.numel() < std::numeric_limits<std::int32_t>::max();
}
void SetTensorDescriptor(
const cudnnDataType_t data_type,
const std::vector<std::int64_t>& X_dims,
@@ -115,6 +129,7 @@ class CuDNNTransposeOp final : public Operator<CUDAContext> {
cudnnTensorDescriptor_t X_desc_;
cudnnTensorDescriptor_t Y_desc_;
cudnnDataType_t cached_dtype_ = cudnnTypeWrapper<float>::type;
std::vector<std::int64_t> cached_X_dims_;
std::vector<std::int32_t> axes_;
};