diff --git a/.gitmodules b/.gitmodules index f5277a4561..1db5171adb 100644 --- a/.gitmodules +++ b/.gitmodules @@ -22,3 +22,6 @@ [submodule "cmake/external/nsync"] path = cmake/external/nsync url = https://github.com/google/nsync +[submodule "cmake/external/onnx-tensorrt"] + path = cmake/external/onnx-tensorrt + url = https://github.com/stevenlix/onnx-tensorrt.git diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index aeb0de07ac..0be7f80919 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -64,6 +64,7 @@ option(onnxruntime_BUILD_SHARED_LIB "Build a shared library" OFF) option(onnxruntime_ENABLE_MICROSOFT_INTERNAL "Use this option to enable/disable microsoft internal only code" OFF) option(onnxruntime_USE_NUPHAR "Build with Nupha" OFF) option(onnxruntime_USE_BRAINSLICE "Build with BrainSlice" OFF) +option(onnxruntime_USE_TRT "Build with TensorRT support" OFF) set(protobuf_BUILD_TESTS OFF CACHE BOOL "Build protobuf tests" FORCE) #nsync tests failed on Mac Build diff --git a/cmake/external/onnx-tensorrt b/cmake/external/onnx-tensorrt new file mode 160000 index 0000000000..493487a720 --- /dev/null +++ b/cmake/external/onnx-tensorrt @@ -0,0 +1 @@ +Subproject commit 493487a7203d8a059a9b2288807cb700857fe5ca diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index b20848cd41..25b6f36c0f 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -23,5 +23,6 @@ constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider"; constexpr const char* kMklDnnExecutionProvider = "MKLDNNExecutionProvider"; constexpr const char* kNupharExecutionProvider = "NupharExecutionProvider"; constexpr const char* kBrainSliceExecutionProvider = "BrainSliceExecutionProvider"; +constexpr const char* kTRTExecutionProvider = "TRTExecutionProvider"; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 
9136f87593..8be7ec9ec4 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -131,6 +131,12 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, ORT_ENFORCE(p_input_provider); } + //no copy for TRT + if (required_provider_type == onnxruntime::kTRTExecutionProvider) { + new_mlvalue = orig_mlvalue; + return Status::OK(); + } + auto input_provider_type = p_input_provider->Type(); if (input_provider_type == required_provider_type && input_tensor_loc.mem_type == OrtMemTypeDefault) { new_mlvalue = orig_mlvalue; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 2e2cdd7204..742da9fb3c 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -299,7 +299,8 @@ class InferenceSession::Impl { for (auto& provider : providers) { if (provider->Type() != onnxruntime::kCpuExecutionProvider && provider->Type() != onnxruntime::kMklDnnExecutionProvider && - provider->Type() != onnxruntime::kNupharExecutionProvider) { + provider->Type() != onnxruntime::kNupharExecutionProvider && + provider->Type() != onnxruntime::kTRTExecutionProvider) { TransformerMemcpyImpl copy_impl(graph, provider->Type()); copy_impl.ModifyGraph(kernel_registry_manager); } diff --git a/onnxruntime/test/framework/test_utils.cc b/onnxruntime/test/framework/test_utils.cc index 19df409604..f666507679 100644 --- a/onnxruntime/test/framework/test_utils.cc +++ b/onnxruntime/test/framework/test_utils.cc @@ -17,5 +17,12 @@ IExecutionProvider* TestCudaExecutionProvider() { return &cuda_provider; } #endif + +#ifdef USE_TRT +IExecutionProvider* TestTRTExecutionProvider() { + static TRTExecutionProvider trt_provider; + return &trt_provider; +} +#endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/framework/test_utils.h b/onnxruntime/test/framework/test_utils.h index 24f8851102..255b7809c6 100644 --- 
a/onnxruntime/test/framework/test_utils.h +++ b/onnxruntime/test/framework/test_utils.h @@ -9,6 +9,9 @@ #ifdef USE_CUDA #include "core/providers/cuda/cuda_execution_provider.h" #endif +#ifdef USE_TRT +#include "core/providers/trt/trt_execution_provider.h" +#endif namespace onnxruntime { namespace test { @@ -18,6 +21,10 @@ IExecutionProvider* TestCPUExecutionProvider(); IExecutionProvider* TestCudaExecutionProvider(); #endif +#ifdef USE_TRT +IExecutionProvider* TestTRTExecutionProvider(); +#endif + template void CreateMLValue(AllocatorPtr alloc, const std::vector& dims, diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index bbb8ded46e..4b3a8043dc 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -82,6 +82,7 @@ int real_main(int argc, char* argv[]) { bool enable_cuda = false; bool enable_mkl = false; bool enable_nuphar = false; + bool enable_trt = false; OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING; { int ch; @@ -131,6 +132,8 @@ int real_main(int argc, char* argv[]) { enable_mkl = true; } else if (!MyStrCmp(optarg, ORT_TSTR("nuphar"))) { enable_nuphar = true; + } else if (!MyStrCmp(optarg, ORT_TSTR("trt"))) { + enable_trt = true; } else { usage(); return -1; @@ -229,6 +232,17 @@ int real_main(int argc, char* argv[]) { #else fprintf(stderr, "MKL-DNN is not supported in this build"); return -1; +#endif + } + if (enable_trt) { +#ifdef USE_TRT + OrtProviderFactoryInterface** f; + ORT_THROW_ON_ERROR(OrtCreateTRTExecutionProviderFactory(0, &f)); + sf.AppendExecutionProvider(f); + OrtReleaseObject(f); +#else + fprintf(stderr, "TensorRT is not supported in this build"); + return -1; #endif } TestEnv args(tests, stat, sf); diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index b61cb6e12d..bae6ccc5bc 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -62,6 +62,8 @@ namespace 
perftest { test_config.machine_config.provider_type_name = onnxruntime::kMklDnnExecutionProvider; } else if (!strcmp(optarg, "brainslice")) { test_config.machine_config.provider_type_name = onnxruntime::kBrainSliceExecutionProvider; + } else if (!strcmp(optarg, "trt")) { + test_config.machine_config.provider_type_name = onnxruntime::kTRTExecutionProvider; } else { return false; } diff --git a/onnxruntime/test/perftest/testenv.cc b/onnxruntime/test/perftest/testenv.cc index f661a1d7f8..fbfbb04e12 100644 --- a/onnxruntime/test/perftest/testenv.cc +++ b/onnxruntime/test/perftest/testenv.cc @@ -74,6 +74,15 @@ Status SessionFactory::create(std::shared_ptr<::onnxruntime::InferenceSession>& FACTORY_PTR_HOLDER; #else ORT_THROW("This executable was not built with BrainSlice"); +#endif + } else if (provider == onnxruntime::kTRTExecutionProvider) { +#if USE_TRT + OrtProviderFactoryInterface** f; + ORT_THROW_ON_ERROR(OrtCreateTRTExecutionProviderFactory(0, &f)); + RegisterExecutionProvider(sess.get(), f); + FACTORY_PTR_HOLDER; +#else + ORT_THROW("TensorRT is not supported in this build"); #endif } //TODO: add more diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 6d386c648b..9d08cd8ed4 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -71,5 +71,18 @@ std::unique_ptr DefaultBrainSliceExecutionProvider() { #endif } +std::unique_ptr DefaultTRTExecutionProvider() { +#ifdef USE_TRT + OrtProviderFactoryInterface** f; + ORT_THROW_ON_ERROR(OrtCreateTRTExecutionProviderFactory(0, &f)); + FACTORY_PTR_HOLDER; + OrtProvider* out; + ORT_THROW_ON_ERROR((*f)->CreateProvider(f, &out)); + return std::unique_ptr((IExecutionProvider*)out); +#else + return nullptr; +#endif +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index d9bc16c369..5c4f23a298 100644 --- 
a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -12,6 +12,7 @@ std::unique_ptr DefaultCudaExecutionProvider(); std::unique_ptr DefaultMkldnnExecutionProvider(bool enable_arena = true); std::unique_ptr DefaultNupharExecutionProvider(); std::unique_ptr DefaultBrainSliceExecutionProvider(); +std::unique_ptr DefaultTRTExecutionProvider(); } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/util/include/providers.h b/onnxruntime/test/util/include/providers.h index c421fa8f19..048d8332f9 100644 --- a/onnxruntime/test/util/include/providers.h +++ b/onnxruntime/test/util/include/providers.h @@ -15,4 +15,7 @@ #endif #if USE_BRAINSLICE #include "core/providers/brainslice/brainslice_provider_factory.h" -#endif \ No newline at end of file +#endif +#if USE_TRT +#include "core/providers/trt/trt_provider_factory.h" +#endif diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 7cf81d3312..2e6ff174c5 100755 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -119,6 +119,8 @@ Use the individual flags to only run the specified stages. 
parser.add_argument("--brain_slice_package_name", help="Name of brain slice packages") parser.add_argument("--brain_slice_client_package_name", help="Name of brainslice client package") parser.add_argument("--use_nuphar", action='store_true', help="Build with nuphar") + parser.add_argument("--use_trt", action='store_true', help="Build with trt") + parser.add_argument("--trt_path", help="Path to trt dir") return parser.parse_args() def resolve_executable_path(command_or_path): @@ -297,6 +299,7 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home "-Donnxruntime_USE_BRAINSLICE=" + ("ON" if args.use_brainslice else "OFF"), "-Donnxruntime_USE_NUPHAR=" + ("ON" if args.use_nuphar else "OFF"), "-Donnxruntime_USE_EIGEN_THREADPOOL=" + ("ON" if args.use_eigenthreadpool else "OFF"), + "-Donnxruntime_USE_TRT=" + ("ON" if args.use_trt else "OFF"), ] if args.use_brainslice: bs_pkg_name = args.brain_slice_package_name.split('.', 1) @@ -306,6 +309,9 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home "-Donnxruntime_BS_CLIENT_PACKAGE=%s/%s" % (args.brain_slice_package_path, args.brain_slice_client_package_name), "-Donnxruntime_BRAINSLICE_dynamic_lib_PATH=%s/%s" % (args.brain_slice_package_path, bs_shared_lib_name)] + if args.use_trt: + cmake_args += ["-DTENSORRT_ROOT=%s" % args.trt_path] + if args.use_llvm: cmake_args += ["-DLLVM_DIR=%s" % args.llvm_path]