mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[CUDA] Build nhwc ops by default (#22648)
### Description * Build cuda nhwc ops by default. * Deprecate `--enable_cuda_nhwc_ops` in build.py and add `--disable_cuda_nhwc_ops` option Note that it requires cuDNN 9.x. If you build with cuDNN 8, NHWC ops will be disabled automatically. ### Motivation and Context In general, NHWC is faster than NCHW for convolution in Nvidia GPUs with Tensor Cores, and this could improve performance for vision models. This is the first step to prefer NHWC for CUDA in 1.21 release. Next step is to do some tests on popular vision models. If it help in most models and devices, set `prefer_nhwc=1` as default cuda provider option.
This commit is contained in:
parent
ba22d7879a
commit
72186bbb71
8 changed files with 85 additions and 31 deletions
|
|
@ -86,7 +86,7 @@ option(onnxruntime_USE_CUDA "Build with CUDA support" OFF)
|
|||
# use. If you hit any problem with that, please do not report it to GTest. Turn OFF the following build option instead.
|
||||
cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS" OFF)
|
||||
|
||||
option(onnxruntime_USE_CUDA_NHWC_OPS "Build CUDA with NHWC op support" OFF)
|
||||
cmake_dependent_option(onnxruntime_USE_CUDA_NHWC_OPS "Build CUDA with NHWC op support" ON "onnxruntime_USE_CUDA" OFF)
|
||||
option(onnxruntime_CUDA_MINIMAL "Build CUDA without any operations apart from memcpy ops. Usefuel for a very minial TRT build" OFF)
|
||||
option(onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO "When building with CUDA support, generate device code line number information." OFF)
|
||||
option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF)
|
||||
|
|
|
|||
|
|
@ -56,7 +56,6 @@ RUN cd /code \
|
|||
--build_shared_lib --skip_tests \
|
||||
--config Release --build_wheel --update --build --parallel \
|
||||
--cmake_generator Ninja \
|
||||
--enable_cuda_nhwc_ops \
|
||||
--cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) "CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}" onnxruntime_BUILD_UNIT_TESTS=OFF
|
||||
|
||||
# Start second stage to copy the build artifacts
|
||||
|
|
|
|||
|
|
@ -925,6 +925,35 @@ Do not modify directly.*
|
|||
|WhisperBeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *in* logits_processor:**I**<br> *in* cross_qk_layer_head:**I**<br> *in* extra_decoding_ids:**I**<br> *in* temperature:**T**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**<br> *out* cross_qk:**V**<br> *out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
| |
|
||||
| |
|
||||
|**Operator Domain:** *com.ms.internal.nhwc*||||
|
||||
|AveragePool|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(float), tensor(float16)|
|
||||
|||10|**T** = tensor(float), tensor(float16)|
|
||||
|||[7, 9]|**T** = tensor(float), tensor(float16)|
|
||||
|BatchNormalization|*in* X:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *in* input_mean:**U**<br> *in* input_var:**U**<br> *out* Y:**T**<br> *out* running_mean:**U**<br> *out* running_var:**U**<br><br>or<br><br>*in* X:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *in* mean:**T**<br> *in* var:**T**<br> *out* Y:**T**<br> *out* mean:**T**<br> *out* var:**T**<br> *out* saved_mean:**T**<br> *out* saved_var:**T**<br><br>or<br><br>*in* X:**T**<br> *in* scale:**T1**<br> *in* B:**T1**<br> *in* input_mean:**T2**<br> *in* input_var:**T2**<br> *out* Y:**T**<br> *out* running_mean:**T2**<br> *out* running_var:**T2**|15+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(double), tensor(float), tensor(float16)|
|
||||
|||14|**T** = tensor(double), tensor(float), tensor(float16)<br/> **U** = tensor(double), tensor(float), tensor(float16)|
|
||||
|||[9, 13]|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|Conv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|11+|**T** = tensor(float), tensor(float16)|
|
||||
|||[1, 10]|**T** = tensor(float), tensor(float16)|
|
||||
|ConvTranspose|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|11+|**T** = tensor(float), tensor(float16)|
|
||||
|||[1, 10]|**T** = tensor(float), tensor(float16)|
|
||||
|DepthToSpace|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|GlobalAveragePool|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|GlobalMaxPool|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|GridSample|*in* X:**T1**<br> *in* grid:**T2**<br> *out* Y:**T1**|16+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|
||||
|LRN|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|MaxPool|*in* X:**T**<br> *out* Y:**T**<br><br>or<br><br>*in* X:**T**<br> *out* Y:**T**<br> *out* Indices:**I**|12+|**I** = tensor(int64)<br/> **T** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)|
|
||||
|||11|**I** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
|
||||
|||10|**I** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
|
||||
|||[8, 9]|**I** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
|
||||
|||[1, 7]|**T** = tensor(float), tensor(float16)|
|
||||
|SpaceToDepth|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
| |
|
||||
| |
|
||||
|
||||
|
||||
<a name="dmlexecutionprovider"/>
|
||||
|
|
|
|||
|
|
@ -191,7 +191,6 @@ build_onnxruntime_gpu_for_profiling() {
|
|||
--build_wheel --skip_tests \
|
||||
--cmake_generator Ninja \
|
||||
--compile_no_warning_as_error \
|
||||
--enable_cuda_nhwc_ops \
|
||||
--cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=$CUDA_ARCH \
|
||||
--cmake_extra_defines onnxruntime_ENABLE_NVTX_PROFILE=ON \
|
||||
--enable_cuda_line_info
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/graph/constants.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
|
||||
using namespace std;
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
|
@ -28,7 +29,8 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes,
|
|||
optional<float> epsilon = optional<float>(),
|
||||
OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess,
|
||||
const std::string& err_str = "",
|
||||
int opset = 7) {
|
||||
int opset = 7,
|
||||
bool exclude_cuda_nhwc = false) {
|
||||
OpTester test("Conv", opset);
|
||||
test.AddAttribute("group", attributes.group);
|
||||
test.AddAttribute("kernel_shape", attributes.kernel_shape);
|
||||
|
|
@ -65,6 +67,12 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes,
|
|||
// Disable TensorRT because weight as input is not supported
|
||||
excluded_providers.insert(kTensorrtExecutionProvider);
|
||||
|
||||
if (exclude_cuda_nhwc) {
|
||||
#ifdef ENABLE_CUDA_NHWC_OPS
|
||||
excluded_providers.insert(kCudaNHWCExecutionProvider);
|
||||
#endif
|
||||
}
|
||||
|
||||
// QNN SDK 2.10.0 has a bug that breaks support for dynamic bias inputs.
|
||||
excluded_providers.insert(kQnnExecutionProvider);
|
||||
|
||||
|
|
@ -197,10 +205,15 @@ TEST(ConvTest, Conv1D_Bias) {
|
|||
// as TF32 has a 10 bit mantissa.
|
||||
float epsilon = 1.1e-5f;
|
||||
|
||||
TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, false, epsilon);
|
||||
// This case is not supported by cuDNN frontend, and the fallback (legacy code) requires weight to 4D tensor for NHWC.
|
||||
constexpr bool exclude_cuda_nhwc = true;
|
||||
|
||||
TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, false, epsilon,
|
||||
OpTester::ExpectResult::kExpectSuccess, "", 10, exclude_cuda_nhwc);
|
||||
|
||||
// CoreML EP requires weight to be an initializer
|
||||
TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true, epsilon);
|
||||
TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true, epsilon,
|
||||
OpTester::ExpectResult::kExpectSuccess, "", 10, exclude_cuda_nhwc);
|
||||
}
|
||||
|
||||
// Conv47
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import shlex
|
|||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
|
|
@ -253,7 +254,12 @@ def parse_arguments():
|
|||
"--cudnn_home is not specified.",
|
||||
)
|
||||
parser.add_argument("--enable_cuda_line_info", action="store_true", help="Enable CUDA line info.")
|
||||
parser.add_argument("--enable_cuda_nhwc_ops", action="store_true", help="Enable CUDA NHWC ops in build.")
|
||||
|
||||
parser.add_argument(
|
||||
"--enable_cuda_nhwc_ops", action="store_true", help="Deprecated; default to enable CUDA NHWC ops in build."
|
||||
)
|
||||
|
||||
parser.add_argument("--disable_cuda_nhwc_ops", action="store_true", help="Disable CUDA NHWC ops in build.")
|
||||
|
||||
# Python bindings
|
||||
parser.add_argument("--enable_pybind", action="store_true", help="Enable Python Bindings.")
|
||||
|
|
@ -793,6 +799,11 @@ def parse_arguments():
|
|||
if args.cmake_generator is None and is_windows():
|
||||
args.cmake_generator = "Ninja" if args.build_wasm else "Visual Studio 17 2022"
|
||||
|
||||
if args.enable_cuda_nhwc_ops:
|
||||
warnings.warn(
|
||||
"The argument '--enable_cuda_nhwc_ops' is deprecated and is default to True. ", DeprecationWarning
|
||||
)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
|
|
@ -1074,7 +1085,7 @@ def generate_build_tree(
|
|||
"-Donnxruntime_USE_MPI=" + ("ON" if args.use_mpi else "OFF"),
|
||||
"-Donnxruntime_ENABLE_MEMORY_PROFILE=" + ("ON" if args.enable_memory_profile else "OFF"),
|
||||
"-Donnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO=" + ("ON" if args.enable_cuda_line_info else "OFF"),
|
||||
"-Donnxruntime_USE_CUDA_NHWC_OPS=" + ("ON" if args.enable_cuda_nhwc_ops else "OFF"),
|
||||
"-Donnxruntime_USE_CUDA_NHWC_OPS=" + ("ON" if args.use_cuda and not args.disable_cuda_nhwc_ops else "OFF"),
|
||||
"-Donnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB=" + ("ON" if args.build_wasm_static_lib else "OFF"),
|
||||
"-Donnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING="
|
||||
+ ("OFF" if args.disable_wasm_exception_catching else "ON"),
|
||||
|
|
|
|||
|
|
@ -123,7 +123,7 @@ stages:
|
|||
--parallel \
|
||||
--build_wheel \
|
||||
--enable_onnx_tests --use_cuda --cuda_version=11.8 --cuda_home=/usr/local/cuda-11.8 --cudnn_home=/usr/local/cuda-11.8 \
|
||||
--enable_cuda_profiling --enable_cuda_nhwc_ops \
|
||||
--enable_cuda_profiling \
|
||||
--enable_pybind --build_java \
|
||||
--use_cache \
|
||||
--cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=75;86' ; \
|
||||
|
|
|
|||
|
|
@ -3,28 +3,31 @@ set -ex
|
|||
#Every cuda container has this $CUDA_VERSION env var set.
|
||||
SHORT_CUDA_VERSION=$(echo $CUDA_VERSION | sed 's/\([[:digit:]]\+\.[[:digit:]]\+\)\.[[:digit:]]\+/\1/')
|
||||
|
||||
BUILD_ARGS=('--config' 'Release' '--update' '--build'
|
||||
'--skip_submodule_sync'
|
||||
'--build_shared_lib'
|
||||
'--parallel' '--use_binskim_compliant_compile_flags'
|
||||
'--build_wheel'
|
||||
'--enable_onnx_tests'
|
||||
'--use_cuda'
|
||||
"--cuda_version=$SHORT_CUDA_VERSION"
|
||||
"--cuda_home=/usr/local/cuda-$SHORT_CUDA_VERSION"
|
||||
"--cudnn_home=/usr/local/cuda-$SHORT_CUDA_VERSION"
|
||||
"--enable_cuda_profiling"
|
||||
"--enable_cuda_nhwc_ops"
|
||||
"--enable_pybind"
|
||||
"--build_java"
|
||||
"--cmake_extra_defines"
|
||||
"CMAKE_CUDA_ARCHITECTURES=75"
|
||||
"onnxruntime_BUILD_UNIT_TESTS=ON"
|
||||
"onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON")
|
||||
BUILD_ARGS=('--config'
|
||||
'Release'
|
||||
'--update'
|
||||
'--build'
|
||||
'--skip_submodule_sync'
|
||||
'--build_shared_lib'
|
||||
'--parallel'
|
||||
'--use_binskim_compliant_compile_flags'
|
||||
'--build_wheel'
|
||||
'--enable_onnx_tests'
|
||||
'--use_cuda'
|
||||
"--cuda_version=$SHORT_CUDA_VERSION"
|
||||
"--cuda_home=/usr/local/cuda-$SHORT_CUDA_VERSION"
|
||||
"--cudnn_home=/usr/local/cuda-$SHORT_CUDA_VERSION"
|
||||
"--enable_cuda_profiling"
|
||||
"--enable_pybind"
|
||||
"--build_java"
|
||||
"--cmake_extra_defines"
|
||||
"CMAKE_CUDA_ARCHITECTURES=75"
|
||||
"onnxruntime_BUILD_UNIT_TESTS=ON"
|
||||
"onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON")
|
||||
if [ -x "$(command -v ninja)" ]; then
|
||||
BUILD_ARGS+=('--cmake_generator' 'Ninja')
|
||||
fi
|
||||
|
||||
|
||||
if [ -d /build ]; then
|
||||
BUILD_ARGS+=('--build_dir' '/build')
|
||||
else
|
||||
|
|
@ -40,7 +43,7 @@ if [ -f /opt/python/cp312-cp312/bin/python3 ]; then
|
|||
else
|
||||
python3 tools/ci_build/build.py "${BUILD_ARGS[@]}"
|
||||
fi
|
||||
if [ -x "$(command -v ccache)" ]; then
|
||||
ccache -sv
|
||||
if [ -x "$(command -v ccache)" ]; then
|
||||
ccache -sv
|
||||
ccache -z
|
||||
fi
|
||||
|
|
|
|||
Loading…
Reference in a new issue