[pytorch] fix blasLt on windows (#125792)

Summary:
It seems like required functions are not available due to `_MSC_VER` guard. Does anyone have more context why this functionality has been disabled for windows?

I'm also unsure how this currently compiles in OSS land on windows, as there doesn't seem to be any preprocessor protection around `scaled_gemm` getting pulled in.

Test Plan:
Fix compilation errors like this
```
C:\open\fbsource\xplat\caffe2\aten\src\ATen\cuda\tunable\TunableGemm.h(74): error C2039: 'scaled_gemm': is not a member of 'at::cuda::blas'
C:\open\fbsource\xplat\caffe2\aten\src\ATen\cuda\CUDABlas.h(19): note: see declaration of 'at::cuda::blas'
C:\open\fbsource\xplat\caffe2\aten\src\ATen\cuda\tunable\TunableGemm.h(74): note: the template instantiation context (the oldest one first) is
C:\open\fbsource\xplat\caffe2\aten\src\ATen\cuda\tunable\TunableGemm.h(71): note: while compiling class template 'at::cuda::tunable::DefaultScaledGemmOp'
Action failed: fbsource//xplat/caffe2:ATen_cuda_lib_ovrsource (cxx_compile aten/src/ATen/native/cuda/Blas.cpp)
```

Differential Revision: D57087985

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125792
Approved by: https://github.com/malfet, https://github.com/eqy
This commit is contained in:
Michael Ranieri 2024-05-09 01:54:25 +00:00 committed by PyTorch MergeBot
parent 902a74c1d6
commit fdff9920f6
5 changed files with 50 additions and 28 deletions

View file

@ -236,7 +236,7 @@ namespace at::cuda::blas {
CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, num_batches); \
} while (0)
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
// only for rocm 5.7 where we first supported hipblaslt, it was difficult
@ -375,7 +375,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
template <typename Dtype>
inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
cudaDataType_t abcType = CUDA_R_32F;
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
cudaDataType_t scaleType = CUDA_R_32F;
@ -1235,7 +1235,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
}
}
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
template <typename Dtype>
void gemm_and_bias(
@ -1745,7 +1745,7 @@ void int8_gemm(
TORCH_CHECK(false, "int8_gemm is only supported for ROCm 6.0 and above");
#endif // !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
}
#endif // (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#endif // !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
// ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not.
#if defined(USE_ROCM) && ROCM_VERSION < 50600

View file

@ -82,7 +82,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
template <>
void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
enum GEMMAndBiasActivationEpilogue {
None,
RELU,

View file

@ -9,7 +9,7 @@
// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
// added bf16 support
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#include <cublasLt.h>
#endif
@ -82,7 +82,7 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
/* Handles */
TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
#endif

View file

@ -191,7 +191,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
return handle;
}
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
cublasLtHandle_t getCurrentCUDABlasLtHandle() {
#ifdef USE_ROCM
c10::DeviceIndex device = 0;

View file

@ -157,7 +157,7 @@ enum class Activation {
GELU,
};
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
switch (a) {
case Activation::None:
@ -236,7 +236,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
at::ScalarType scalar_type = self.scalar_type();
c10::MaybeOwned<Tensor> self_;
if (&result != &self) {
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && ROCM_VERSION >= 50700
#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11040)) || (defined(USE_ROCM) && (ROCM_VERSION >= 50700))
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
@ -334,8 +334,9 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && (ROCM_VERSION >= 50700))
if (useLtInterface) {
#if defined(USE_ROCM)
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
@ -353,28 +354,49 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
#if defined(USE_ROCM)
// This condition is needed for mm case on ROCm for hipblasLt path.
// Passing the bias ptr as null to avoid accuracy issues for mm case.
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
#else
self.const_data_ptr<scalar_t>(),
#endif
args.result->data_ptr<scalar_t>(),
args.result_ld,
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11080) || defined(USE_ROCM)
activation_to_gemm_and_blas_arg(activation)
#else
// GELU is not supported (and does not compile!) prior
// to CUDA 11.4. Have observed accuracy issues with
// GELU epilogue in 11.4; disabling the GELU epilogue
// path for CUDA version < 11.8.
activation != Activation::GELU
? activation_to_gemm_and_blas_arg(activation)
: cuda::blas::GEMMAndBiasActivationEpilogue::None
#endif
);
});
#else
auto activation_epilogue = activation_to_gemm_and_blas_arg(activation);
#if (defined(CUDA_VERSION) && (CUDA_VERSION < 11080))
// GELU is not supported (and does not compile!) prior
// to CUDA 11.4. Have observed accuracy issues with
// GELU epilogue in 11.4; disabling the GELU epilogue
// path for CUDA version < 11.8.
if (activation == Activation::GELU)
activation_epilogue = cuda::blas::GEMMAndBiasActivationEpilogue::None;
#endif
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
scalar_type,
"addmm_cuda_lt",
[&] {
at::cuda::blas::gemm_and_bias<scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self.const_data_ptr<scalar_t>(),
args.result->data_ptr<scalar_t>(),
args.result_ld,
activation_epilogue
);
});
#endif
} else
#endif
{
@ -748,7 +770,7 @@ Tensor& _int_mm_out_cuda(const Tensor& self, const Tensor& mat2, Tensor& result)
TORCH_CHECK(result.is_contiguous(), "Expected result to be contiguous.");
#if (!defined(USE_ROCM) && !defined(_MSC_VER) && defined(CUDA_VERSION) && CUDA_VERSION >= 11070) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
#if (!defined(USE_ROCM) && defined(CUDA_VERSION) && (CUDA_VERSION >= 11070)) || (defined(USE_ROCM) && (ROCM_VERSION >= 60000))
cublasCommonArgs args(self, mat2, result);
at::cuda::blas::int8_gemm(
@ -768,7 +790,7 @@ Tensor& _int_mm_out_cuda(const Tensor& self, const Tensor& mat2, Tensor& result)
result.copy_(*args.result);
}
#else
#if !defined(USE_ROCM) && !defined(_MSC_VER) && defined(CUDA_VERSION)
#if !defined(USE_ROCM) && defined(CUDA_VERSION)
TORCH_CHECK(false, "_int_mm_out_cuda not compiled for CUDA ", CUDA_VERSION);
#else
TORCH_CHECK(false, "_int_mm_out_cuda not compiled for this platform.");
@ -888,7 +910,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
at::native::resize_output(amax, {});
#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && (ROCM_VERSION >= 60000))
cublasCommonArgs args(mat1, mat2, out);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");