mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-19 21:32:23 +00:00
[CUDA] Add use_tf32 cuda provider option (for FP32 Conv) (#19426)
Follow-up to https://github.com/microsoft/onnxruntime/pull/19357 that applies the use_tf32 option to FP32 cuDNN convolution. When use_tf32 = 0, TF32 is disabled in cuDNN convolution for FP32 inputs.

Reference: https://docs.nvidia.com/deeplearning/cudnn/api/cudnn-graph-library.html#cudnnmathtype-t

**CUDNN_FMA_MATH**
- Restricted to only kernels that use FMA instructions.
- On pre-NVIDIA A100 GPU devices, CUDNN_DEFAULT_MATH and CUDNN_FMA_MATH have the same behavior: Tensor Core kernels will not be selected.
- With NVIDIA Ampere architecture and CUDA toolkit 11, CUDNN_DEFAULT_MATH permits TF32 Tensor Core operation and CUDNN_FMA_MATH does not.
- The TF32 behavior for CUDNN_DEFAULT_MATH and the other Tensor Core math types can be explicitly disabled by the environment variable NVIDIA_TF32_OVERRIDE=0.
This commit is contained in:
parent
e5ce81ae84
commit
3afb38cfb7
7 changed files with 35 additions and 12 deletions
|
|
@ -326,7 +326,8 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
|
|||
|
||||
ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
|
||||
gsl::narrow_cast<int>(conv_attrs_.group),
|
||||
CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType<CudaT>()));
|
||||
CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType<CudaT>(),
|
||||
UseTF32()));
|
||||
|
||||
if (context->InputCount() >= 3) {
|
||||
const Tensor* B = context->Input<Tensor>(2);
|
||||
|
|
@ -351,8 +352,13 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
|
|||
|
||||
if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) {
|
||||
// set math type to tensor core before algorithm search
|
||||
if constexpr (std::is_same<T, MLFloat16>::value)
|
||||
if constexpr (std::is_same<T, MLFloat16>::value) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH));
|
||||
} else if constexpr (std::is_same<T, float>::value) {
|
||||
if (!UseTF32()) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH));
|
||||
}
|
||||
}
|
||||
|
||||
cudnnConvolutionFwdAlgoPerf_t perf;
|
||||
int algo_count = 1;
|
||||
|
|
@ -399,6 +405,8 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
|
|||
CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory));
|
||||
if (std::is_same<T, MLFloat16>::value) {
|
||||
perf.mathType = CUDNN_TENSOR_OP_MATH;
|
||||
} else if (std::is_same<T, float>::value && !UseTF32()) {
|
||||
perf.mathType = CUDNN_FMA_MATH;
|
||||
} else {
|
||||
perf.mathType = CUDNN_DEFAULT_MATH;
|
||||
}
|
||||
|
|
@ -480,7 +488,8 @@ Status CudnnConvolutionDescriptor::Set(
|
|||
const gsl::span<const int64_t>& dilations,
|
||||
int groups,
|
||||
cudnnConvolutionMode_t mode,
|
||||
cudnnDataType_t data_type) {
|
||||
cudnnDataType_t data_type,
|
||||
bool use_tf32) {
|
||||
if (!desc_)
|
||||
CUDNN_RETURN_IF_ERROR(cudnnCreateConvolutionDescriptor(&desc_));
|
||||
|
||||
|
|
@ -513,6 +522,8 @@ Status CudnnConvolutionDescriptor::Set(
|
|||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_DEFAULT_MATH));
|
||||
if (data_type == CUDNN_DATA_HALF) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_TENSOR_OP_MATH));
|
||||
} else if (data_type == CUDNN_DATA_FLOAT && !use_tf32) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_FMA_MATH));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -29,7 +29,8 @@ class CudnnConvolutionDescriptor final {
|
|||
const gsl::span<const int64_t>& dilations,
|
||||
int groups,
|
||||
cudnnConvolutionMode_t mode,
|
||||
cudnnDataType_t data_type);
|
||||
cudnnDataType_t data_type,
|
||||
bool use_tf32);
|
||||
|
||||
operator cudnnConvolutionDescriptor_t() const { return desc_; }
|
||||
|
||||
|
|
|
|||
|
|
@ -167,7 +167,8 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
|
|||
cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION;
|
||||
ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations,
|
||||
gsl::narrow_cast<int>(conv_transpose_attrs_.group), mode,
|
||||
CudnnTensor::GetDataType<CudaT>()));
|
||||
CudnnTensor::GetDataType<CudaT>(),
|
||||
UseTF32()));
|
||||
|
||||
if (has_bias) {
|
||||
const auto& b_shape = p.B->Shape();
|
||||
|
|
@ -187,8 +188,13 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
|
|||
GetScratchBuffer<void>(AlgoSearchWorkspaceSize, context->GetComputeStream());
|
||||
|
||||
// set math type to tensor core before algorithm search
|
||||
if constexpr (std::is_same<T, MLFloat16>::value)
|
||||
if constexpr (std::is_same<T, MLFloat16>::value) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH));
|
||||
} else if constexpr (std::is_same<T, float>::value) {
|
||||
if (!UseTF32()) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH));
|
||||
}
|
||||
}
|
||||
|
||||
cudnnConvolutionBwdDataAlgoPerf_t perf;
|
||||
int algo_count = 1;
|
||||
|
|
|
|||
|
|
@ -114,7 +114,8 @@ Status ConvGrad<T>::PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor&
|
|||
ORT_RETURN_IF_ERROR(args_.y_tensor.Set(dy_dims, args_.params.data_type));
|
||||
ORT_RETURN_IF_ERROR(args_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
|
||||
gsl::narrow_cast<int>(conv_attrs_.group), CUDNN_CROSS_CORRELATION,
|
||||
args_.params.data_type));
|
||||
args_.params.data_type,
|
||||
UseTF32()));
|
||||
|
||||
if (dB) {
|
||||
const TensorShape& db_shape = dB->Shape();
|
||||
|
|
|
|||
|
|
@ -233,11 +233,13 @@ bool ConvParamsEqual::operator()(const ConvParams& a, const ConvParams& b) const
|
|||
}
|
||||
|
||||
template <typename T_Perf>
|
||||
Status AlgoIterator<T_Perf>::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results) {
|
||||
Status AlgoIterator<T_Perf>::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results, bool use_tf32) {
|
||||
perf_results.resize(1);
|
||||
perf_results[0].algo = AlgoSearch<T_Perf>::DEFAULT_ALGO;
|
||||
if (args.params.data_type == CUDNN_DATA_HALF) {
|
||||
perf_results[0].mathType = CUDNN_TENSOR_OP_MATH;
|
||||
} else if (args.params.data_type == CUDNN_DATA_FLOAT && !use_tf32) {
|
||||
perf_results[0].mathType = CUDNN_FMA_MATH;
|
||||
} else {
|
||||
perf_results[0].mathType = CUDNN_DEFAULT_MATH;
|
||||
}
|
||||
|
|
@ -256,7 +258,7 @@ Status AlgoIterator<T_Perf>::TryAll(const CUDAExecutionProvider* provider, const
|
|||
|
||||
std::vector<T_Perf> perf_results;
|
||||
ORT_RETURN_IF_ERROR(args_.params.algo_mode == OrtCudnnConvAlgoSearchDefault
|
||||
? OnlyDefaultAlgorithm(args_, perf_results)
|
||||
? OnlyDefaultAlgorithm(args_, perf_results, provider->UseTF32())
|
||||
: AlgoSearch<T_Perf>::FindAlgorithms(args_, provider, allocator, perf_results));
|
||||
for (auto& algo_perf : perf_results) {
|
||||
if (f(algo_perf) == Status::OK()) {
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ class AlgoIterator {
|
|||
Status TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator,
|
||||
std::function<Status(const T_Perf& perf)> f);
|
||||
|
||||
static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results);
|
||||
static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results, bool use_tf32);
|
||||
|
||||
private:
|
||||
const ConvArgs& args_;
|
||||
|
|
|
|||
|
|
@ -182,7 +182,8 @@ Status ConvTransposeGrad<T>::PrepareConvForwardArgs(const Tensor& X, const Tenso
|
|||
ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type));
|
||||
ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
|
||||
gsl::narrow_cast<int>(conv_attrs_.group), CUDNN_CROSS_CORRELATION,
|
||||
args.params.data_type));
|
||||
args.params.data_type,
|
||||
UseTF32()));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
@ -287,7 +288,8 @@ Status ConvTransposeGrad<T>::PrepareConvBackwardFilterArgs(const Tensor& X, cons
|
|||
ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type));
|
||||
ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
|
||||
gsl::narrow_cast<int>(conv_attrs_.group), CUDNN_CROSS_CORRELATION,
|
||||
args.params.data_type));
|
||||
args.params.data_type,
|
||||
UseTF32()));
|
||||
|
||||
if (dB) {
|
||||
const auto& b_shape = dB->Shape();
|
||||
|
|
|
|||
Loading…
Reference in a new issue