mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[pytorch] fix blasLt on windows (#125792)
Summary: It seems like required functions are not available due to `_MSC_VER` guard. Does anyone have more context why this functionality has been disabled for windows? I'm also unsure how this currently compiles in OSS land on windows, as there doesn't seem to be any preprocessor protection around `scaled_gemm` getting pulled in. Test Plan: Fix compilation errors like this ``` C:\open\fbsource\xplat\caffe2\aten\src\ATen\cuda\tunable\TunableGemm.h(74): error C2039: 'scaled_gemm': is not a member of 'at::cuda::blas' C:\open\fbsource\xplat\caffe2\aten\src\ATen\cuda\CUDABlas.h(19): note: see declaration of 'at::cuda::blas' C:\open\fbsource\xplat\caffe2\aten\src\ATen\cuda\tunable\TunableGemm.h(74): note: the template instantiation context (the oldest one first) is C:\open\fbsource\xplat\caffe2\aten\src\ATen\cuda\tunable\TunableGemm.h(71): note: while compiling class template 'at::cuda::tunable::DefaultScaledGemmOp' Action failed: fbsource//xplat/caffe2:ATen_cuda_lib_ovrsource (cxx_compile aten/src/ATen/native/cuda/Blas.cpp) ``` Differential Revision: D57087985 Pull Request resolved: https://github.com/pytorch/pytorch/pull/125792 Approved by: https://github.com/malfet, https://github.com/eqy
This commit is contained in:
parent
902a74c1d6
commit
fdff9920f6
5 changed files with 50 additions and 28 deletions
|
|
@ -236,7 +236,7 @@ namespace at::cuda::blas {
|
|||
CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, num_batches); \
|
||||
} while (0)
|
||||
|
||||
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
|
||||
#if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
|
||||
// only for rocm 5.7 where we first supported hipblaslt, it was difficult
|
||||
|
|
@ -375,7 +375,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
|
|||
|
||||
template <typename Dtype>
|
||||
inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
|
||||
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
|
||||
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
|
||||
cudaDataType_t abcType = CUDA_R_32F;
|
||||
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
|
||||
cudaDataType_t scaleType = CUDA_R_32F;
|
||||
|
|
@ -1235,7 +1235,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
|||
}
|
||||
}
|
||||
|
||||
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
|
||||
template <typename Dtype>
|
||||
void gemm_and_bias(
|
||||
|
|
@ -1745,7 +1745,7 @@ void int8_gemm(
|
|||
TORCH_CHECK(false, "int8_gemm is only supported for ROCm 6.0 and above");
|
||||
#endif // !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
|
||||
}
|
||||
#endif // (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
#endif // !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
|
||||
// ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not.
|
||||
#if defined(USE_ROCM) && ROCM_VERSION < 50600
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
|
|||
template <>
|
||||
void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
|
||||
|
||||
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
enum GEMMAndBiasActivationEpilogue {
|
||||
None,
|
||||
RELU,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
|
||||
// added bf16 support
|
||||
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
#include <cublasLt.h>
|
||||
#endif
|
||||
|
||||
|
|
@ -82,7 +82,7 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
|
|||
/* Handles */
|
||||
TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
|
||||
TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
|
||||
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -191,7 +191,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
|
|||
return handle;
|
||||
}
|
||||
|
||||
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
cublasLtHandle_t getCurrentCUDABlasLtHandle() {
|
||||
#ifdef USE_ROCM
|
||||
c10::DeviceIndex device = 0;
|
||||
|
|
|
|||
|
|
@ -157,7 +157,7 @@ enum class Activation {
|
|||
GELU,
|
||||
};
|
||||
|
||||
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
|
||||
switch (a) {
|
||||
case Activation::None:
|
||||
|
|
@ -236,7 +236,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
|||
at::ScalarType scalar_type = self.scalar_type();
|
||||
c10::MaybeOwned<Tensor> self_;
|
||||
if (&result != &self) {
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && ROCM_VERSION >= 50700
|
||||
#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11040)) || (defined(USE_ROCM) && (ROCM_VERSION >= 50700))
|
||||
// Strangely, if mat2 has only 1 row or column, we get
|
||||
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
|
||||
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
|
||||
|
|
@ -334,8 +334,9 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
|||
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
|
||||
|
||||
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
#if !defined(USE_ROCM) || (defined(USE_ROCM) && (ROCM_VERSION >= 50700))
|
||||
if (useLtInterface) {
|
||||
#if defined(USE_ROCM)
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
|
|
@ -353,28 +354,49 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
|||
args.lda,
|
||||
args.matb->const_data_ptr<scalar_t>(),
|
||||
args.ldb,
|
||||
#if defined(USE_ROCM)
|
||||
// This condition is needed for mm case on ROCm for hipblasLt path.
|
||||
// Passing the bias ptr as null to avoid accuracy issues for mm case.
|
||||
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
|
||||
#else
|
||||
self.const_data_ptr<scalar_t>(),
|
||||
#endif
|
||||
args.result->data_ptr<scalar_t>(),
|
||||
args.result_ld,
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11080) || defined(USE_ROCM)
|
||||
activation_to_gemm_and_blas_arg(activation)
|
||||
#else
|
||||
// GELU is not supported (and does not compile!) prior
|
||||
// to CUDA 11.4. Have observed accuracy issues with
|
||||
// GELU epilogue in 11.4; disabling the GELU epilogue
|
||||
// path for CUDA version < 11.8.
|
||||
activation != Activation::GELU
|
||||
? activation_to_gemm_and_blas_arg(activation)
|
||||
: cuda::blas::GEMMAndBiasActivationEpilogue::None
|
||||
#endif
|
||||
);
|
||||
});
|
||||
#else
|
||||
auto activation_epilogue = activation_to_gemm_and_blas_arg(activation);
|
||||
#if (defined(CUDA_VERSION) && (CUDA_VERSION < 11080))
|
||||
// GELU is not supported (and does not compile!) prior
|
||||
// to CUDA 11.4. Have observed accuracy issues with
|
||||
// GELU epilogue in 11.4; disabling the GELU epilogue
|
||||
// path for CUDA version < 11.8.
|
||||
if (activation == Activation::GELU)
|
||||
activation_epilogue = cuda::blas::GEMMAndBiasActivationEpilogue::None;
|
||||
#endif
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
scalar_type,
|
||||
"addmm_cuda_lt",
|
||||
[&] {
|
||||
at::cuda::blas::gemm_and_bias<scalar_t>(
|
||||
args.transa == 't',
|
||||
args.transb == 't',
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
alpha.to<at::opmath_type<scalar_t>>(),
|
||||
args.mata->const_data_ptr<scalar_t>(),
|
||||
args.lda,
|
||||
args.matb->const_data_ptr<scalar_t>(),
|
||||
args.ldb,
|
||||
self.const_data_ptr<scalar_t>(),
|
||||
args.result->data_ptr<scalar_t>(),
|
||||
args.result_ld,
|
||||
activation_epilogue
|
||||
);
|
||||
});
|
||||
#endif
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
|
|
@ -748,7 +770,7 @@ Tensor& _int_mm_out_cuda(const Tensor& self, const Tensor& mat2, Tensor& result)
|
|||
|
||||
TORCH_CHECK(result.is_contiguous(), "Expected result to be contiguous.");
|
||||
|
||||
#if (!defined(USE_ROCM) && !defined(_MSC_VER) && defined(CUDA_VERSION) && CUDA_VERSION >= 11070) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
|
||||
#if (!defined(USE_ROCM) && defined(CUDA_VERSION) && (CUDA_VERSION >= 11070)) || (defined(USE_ROCM) && (ROCM_VERSION >= 60000))
|
||||
cublasCommonArgs args(self, mat2, result);
|
||||
|
||||
at::cuda::blas::int8_gemm(
|
||||
|
|
@ -768,7 +790,7 @@ Tensor& _int_mm_out_cuda(const Tensor& self, const Tensor& mat2, Tensor& result)
|
|||
result.copy_(*args.result);
|
||||
}
|
||||
#else
|
||||
#if !defined(USE_ROCM) && !defined(_MSC_VER) && defined(CUDA_VERSION)
|
||||
#if !defined(USE_ROCM) && defined(CUDA_VERSION)
|
||||
TORCH_CHECK(false, "_int_mm_out_cuda not compiled for CUDA ", CUDA_VERSION);
|
||||
#else
|
||||
TORCH_CHECK(false, "_int_mm_out_cuda not compiled for this platform.");
|
||||
|
|
@ -888,7 +910,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
|||
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
|
||||
at::native::resize_output(amax, {});
|
||||
|
||||
#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
|
||||
#if !defined(USE_ROCM) || (defined(USE_ROCM) && (ROCM_VERSION >= 60000))
|
||||
cublasCommonArgs args(mat1, mat2, out);
|
||||
const auto out_dtype_ = args.result->scalar_type();
|
||||
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
|
||||
|
|
|
|||
Loading…
Reference in a new issue