From c343f7cb43debe1d36aa745b79a18eed860dfa81 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Fri, 3 Sep 2021 11:25:17 +0800 Subject: [PATCH] Add Algorithm Search for ConvGrad (#8613) * algo search for conv grad * global cache, bigger workspace size * fix build error * refactor * refactor * resolve comments * fix rocm * change lock places * rename variable * remove setting for inference * resolve comments --- .../onnxruntime/core/framework/allocator.h | 17 +- .../core/providers/cuda/cuda_allocator.cc | 15 + .../core/providers/cuda/cuda_allocator.h | 10 +- .../providers/cuda/cuda_execution_provider.cc | 2 +- .../providers/cuda/cuda_execution_provider.h | 9 + .../cuda/cuda_execution_provider_info.cc | 16 +- .../cuda/cuda_execution_provider_info.h | 8 +- onnxruntime/core/providers/cuda/cuda_kernel.h | 9 + onnxruntime/core/providers/cuda/nn/conv.cc | 61 ++- onnxruntime/core/providers/cuda/nn/conv.h | 4 + .../core/providers/rocm/rocm_allocator.cc | 15 + .../core/providers/rocm/rocm_allocator.h | 10 +- .../providers/rocm/rocm_execution_provider.cc | 2 +- .../rocm/rocm_execution_provider_info.cc | 16 +- .../rocm/rocm_execution_provider_info.h | 5 +- .../test/python/onnxruntime_test_python.py | 2 + .../ortmodule/_graph_execution_manager.py | 15 +- .../torch_gpu_allocator.cc | 16 +- .../python/orttraining_test_ortmodule_api.py | 9 +- .../test/training_ops/cuda/conv_grad_test.cc | 322 +++++++++++ .../training_ops/cuda/nn/conv_grad.cc | 509 ++++++++++++------ .../training_ops/cuda/nn/conv_grad.h | 52 +- 22 files changed, 898 insertions(+), 226 deletions(-) create mode 100644 orttraining/orttraining/test/training_ops/cuda/conv_grad_test.cc diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 689ab2fb05..2f235f5f38 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -130,14 +130,16 @@ class IAllocator { Create a std::unique_ptr that is allocated and 
freed by the provided IAllocator. @param allocator The allocator. @param count_or_bytes The exact bytes to allocate if T is void, otherwise the number of elements to allocate. + @param use_reserve If true, call Reserve() instead of Alloc() to allocate memory. @returns std::unique_ptr with allocated memory and deleter. */ template - static IAllocatorUniquePtr MakeUniquePtr(std::shared_ptr allocator, size_t count_or_bytes) { + static IAllocatorUniquePtr MakeUniquePtr(std::shared_ptr allocator, size_t count_or_bytes, + bool use_reserve = false) { if (allocator == nullptr) return nullptr; // for now limit to fundamental types. we could support others, but to do so either we or the caller // needs to call the dtor for the objects, for buffers allocated on device we don't have destructor - //static_assert(std::is_fundamental::value, "Fundamental type required as no destructors are called."); + // static_assert(std::is_fundamental::value, "Fundamental type required as no destructors are called."); size_t alloc_size = count_or_bytes; @@ -145,14 +147,15 @@ class IAllocator { if (!std::is_void::value) { // sizeof(void) isn't valid, but the compiler isn't smart enough to ignore that this line isn't // reachable if T is void. use std::conditional to 'use' void* in the sizeof call - if (!CalcMemSizeForArray(count_or_bytes, - sizeof(typename std::conditional::value, void*, T>::type), - &alloc_size)) return nullptr; + if (!CalcMemSizeForArray( + count_or_bytes, sizeof(typename std::conditional::value, void*, T>::type), &alloc_size)) { + return nullptr; + } } return IAllocatorUniquePtr{ - static_cast(allocator->Alloc(alloc_size)), // allocate - [=](T* ptr) { // capture 'allocator' by value so it's always valid + static_cast(use_reserve ? 
allocator->Reserve(alloc_size) : allocator->Alloc(alloc_size)), // allocate + [=](T* ptr) { // capture 'allocator' by value so it's always valid allocator->Free(ptr); }}; } diff --git a/onnxruntime/core/providers/cuda/cuda_allocator.cc b/onnxruntime/core/providers/cuda/cuda_allocator.cc index 9dad792efc..1ba6fc3214 100644 --- a/onnxruntime/core/providers/cuda/cuda_allocator.cc +++ b/onnxruntime/core/providers/cuda/cuda_allocator.cc @@ -77,6 +77,21 @@ void* CUDAExternalAllocator::Alloc(size_t size) { void CUDAExternalAllocator::Free(void* p) { free_(p); + std::lock_guard lock(lock_); + auto it = reserved_.find(p); + if (it != reserved_.end()) { + reserved_.erase(it); + if (empty_cache_) empty_cache_(); + } +} + +void* CUDAExternalAllocator::Reserve(size_t size) { + void* p = Alloc(size); + if (!p) return nullptr; + std::lock_guard lock(lock_); + ORT_ENFORCE(reserved_.find(p) == reserved_.end()); + reserved_.insert(p); + return p; } FencePtr CUDAAllocator::CreateFence(const SessionState* session_state) { diff --git a/onnxruntime/core/providers/cuda/cuda_allocator.h b/onnxruntime/core/providers/cuda/cuda_allocator.h index 0dd6d1e300..2c56a9444b 100644 --- a/onnxruntime/core/providers/cuda/cuda_allocator.h +++ b/onnxruntime/core/providers/cuda/cuda_allocator.h @@ -3,7 +3,9 @@ #pragma once +#include #include "core/framework/allocator.h" +#include "core/platform/ort_mutex.h" namespace onnxruntime { @@ -26,20 +28,26 @@ class CUDAAllocator : public IAllocator { class CUDAExternalAllocator : public CUDAAllocator { typedef void* (*ExternalAlloc)(size_t size); typedef void (*ExternalFree)(void* p); + typedef void (*ExternalEmptyCache)(); public: - CUDAExternalAllocator(OrtDevice::DeviceId device_id, const char* name, void* alloc, void* free) + CUDAExternalAllocator(OrtDevice::DeviceId device_id, const char* name, void* alloc, void* free, void* empty_cache) : CUDAAllocator(device_id, name) { alloc_ = reinterpret_cast(alloc); free_ = reinterpret_cast(free); + empty_cache_ = 
reinterpret_cast(empty_cache); } void* Alloc(size_t size) override; void Free(void* p) override; + void* Reserve(size_t size) override; private: + mutable OrtMutex lock_; ExternalAlloc alloc_; ExternalFree free_; + ExternalEmptyCache empty_cache_; + std::unordered_set reserved_; }; //TODO: add a default constructor diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 726cf1d133..13fd229f25 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -98,7 +98,7 @@ AllocatorPtr CUDAExecutionProvider::CreateCudaAllocator(OrtDevice::DeviceId devi if (external_allocator_info.UseExternalAllocator()) { AllocatorCreationInfo default_memory_info( [external_allocator_info](OrtDevice::DeviceId id) { - return std::make_unique(id, CUDA, external_allocator_info.alloc, external_allocator_info.free); + return std::make_unique(id, CUDA, external_allocator_info.alloc, external_allocator_info.free, external_allocator_info.empty_cache); }, device_id, false); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 51f0781e18..200beed8fb 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -63,6 +63,14 @@ class CUDAExecutionProvider : public IExecutionProvider { return IAllocator::MakeUniquePtr(GetAllocator(info_.device_id, OrtMemTypeDefault), count_or_bytes); } + template + IAllocatorUniquePtr GetTransientScratchBuffer(size_t count_or_bytes) const { + if (count_or_bytes == 0) + return nullptr; + + return IAllocator::MakeUniquePtr(GetAllocator(info_.device_id, OrtMemTypeDefault), count_or_bytes, true); + } + std::shared_ptr GetKernelRegistry() const override; std::unique_ptr GetDataTransfer() const override; @@ -74,6 +82,7 @@ class CUDAExecutionProvider : 
public IExecutionProvider { const cudaDeviceProp& GetDeviceProp() const { return device_prop_; }; int GetCudnnConvAlgo() const { return info_.cudnn_conv_algo_search; } bool DoCopyOnDefaultStream() const { return info_.do_copy_in_default_stream; } + bool GetCudnnConvUseMaxWorkspace() const { return info_.cudnn_conv_use_max_workspace; } ProviderOptions GetProviderOptions() const override { return CUDAExecutionProviderInfo::ToProviderOptions(info_); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc index 62faa0151b..88cee40839 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc @@ -19,6 +19,8 @@ constexpr const char* kCudnnConvAlgoSearch = "cudnn_conv_algo_search"; constexpr const char* kDoCopyInDefaultStream = "do_copy_in_default_stream"; constexpr const char* kGpuExternalAlloc = "gpu_external_alloc"; constexpr const char* kGpuExternalFree = "gpu_external_free"; +constexpr const char* kGpuExternalEmptyCache = "gpu_external_empty_cache"; +constexpr const char* kCudnnConvUseMaxWorkspace = "cudnn_conv_use_max_workspace"; } // namespace provider_option_names } // namespace cuda @@ -39,6 +41,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P CUDAExecutionProviderInfo info{}; void* alloc = nullptr; void* free = nullptr; + void* empty_cache = nullptr; ORT_THROW_IF_ERROR( ProviderOptionsParser{} .AddValueParser( @@ -71,6 +74,14 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P free = reinterpret_cast(address); return Status::OK(); }) + .AddValueParser( + cuda::provider_option_names::kGpuExternalEmptyCache, + [&empty_cache](const std::string& value_str) -> Status { + size_t address; + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); + empty_cache = reinterpret_cast(address); + return Status::OK(); + }) 
.AddAssignmentToReference(cuda::provider_option_names::kMemLimit, info.gpu_mem_limit) .AddAssignmentToEnumReference( cuda::provider_option_names::kArenaExtendStrategy, @@ -79,9 +90,10 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P cuda::provider_option_names::kCudnnConvAlgoSearch, *ort_cudnn_conv_algo_search_mapping, info.cudnn_conv_algo_search) .AddAssignmentToReference(cuda::provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream) + .AddAssignmentToReference(cuda::provider_option_names::kCudnnConvUseMaxWorkspace, info.cudnn_conv_use_max_workspace) .Parse(options)); - CUDAExecutionProviderExternalAllocatorInfo alloc_info{alloc, free}; + CUDAExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache}; info.external_allocator_info = alloc_info; return info; } @@ -92,11 +104,13 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution {cuda::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)}, {cuda::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.alloc))}, {cuda::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.free))}, + {cuda::provider_option_names::kGpuExternalEmptyCache, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.empty_cache))}, {cuda::provider_option_names::kArenaExtendStrategy, EnumToName(*arena_extend_strategy_mapping, info.arena_extend_strategy)}, {cuda::provider_option_names::kCudnnConvAlgoSearch, EnumToName(*ort_cudnn_conv_algo_search_mapping, info.cudnn_conv_algo_search)}, {cuda::provider_option_names::kDoCopyInDefaultStream, MakeStringWithClassicLocale(info.do_copy_in_default_stream)}, + {cuda::provider_option_names::kCudnnConvUseMaxWorkspace, MakeStringWithClassicLocale(info.cudnn_conv_use_max_workspace)}, }; return options; diff --git 
a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h index 199aaec114..22bbd62faf 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h @@ -15,15 +15,18 @@ namespace onnxruntime { struct CUDAExecutionProviderExternalAllocatorInfo { void* alloc{nullptr}; void* free{nullptr}; + void* empty_cache{nullptr}; CUDAExecutionProviderExternalAllocatorInfo() { alloc = nullptr; free = nullptr; + empty_cache = nullptr; } - CUDAExecutionProviderExternalAllocatorInfo(void* a, void* f) { + CUDAExecutionProviderExternalAllocatorInfo(void* a, void* f, void* e) { alloc = a; free = f; + empty_cache = e; } bool UseExternalAllocator() const { @@ -45,6 +48,9 @@ struct CUDAExecutionProviderInfo { // arena config. OrtArenaCfg* default_memory_arena_cfg{nullptr}; CUDAExecutionProviderExternalAllocatorInfo external_allocator_info{}; + // By default, use a fixed workspace size (32M) for Conv algo search; the final algo might not be the best. + // If set to true, try to use as much memory as possible for algo search. + bool cudnn_conv_use_max_workspace{false}; static CUDAExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const CUDAExecutionProviderInfo& info); diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index f6fe4807b7..8268e40292 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -52,6 +52,15 @@ class CudaKernel : public OpKernel { return provider_->GetScratchBuffer(count_or_bytes); } + // Different from GetScratchBuffer which uses IAllocator::Alloc() to allocate memory, + // this GetTransientScratchBuffer will call IAllocator::Reserve() to allocate memory.
+ // IAllocator::Reserve() optionally implements some allocation logic that by-passes any arena-based + // logic (or similar for different allocator) that may be housed in the Alloc() implementation. + template + inline IAllocatorUniquePtr GetTransientScratchBuffer(size_t count_or_bytes) const { + return provider_->GetTransientScratchBuffer(count_or_bytes); + } + inline void AddDeferredReleaseCPUPtr(void* p) const { provider_->AddDeferredReleaseCPUPtr(p); } diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 00f5e0bb54..a89c7458a4 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -33,6 +33,41 @@ REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(double) REGISTER_KERNEL_TYPED(MLFloat16) +template +const cudnnConvolutionFwdAlgo_t Conv::kAllAlgos[] = { + CUDNN_CONVOLUTION_FWD_ALGO_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_FFT, + CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, +}; + +cudnnStatus_t GetWorkspaceSize(const CudnnConvState& s, cudnnConvolutionFwdAlgo_t algo, + size_t* sz) { + return cudnnGetConvolutionForwardWorkspaceSize(s.handle, s.x_tensor, s.w_desc, s.conv_desc, s.y_tensor, algo, sz); +} + +size_t GetMaxWorkspaceSize(const CudnnConvState& s, + const cudnnConvolutionFwdAlgo_t* algo, int n_algo) { + // TODO: get maximum available size from memory arena + size_t free, total; + CUDA_CALL_THROW(cudaMemGetInfo(&free, &total)); + // Assuming 10% of fragmentation + free = static_cast(static_cast(free) * 0.9); + size_t max_ws_size = 0; + for (int i = 0; i < n_algo; i++) { + cudnnStatus_t err; + size_t sz; + err = GetWorkspaceSize(s, algo[i], &sz); + if (CUDNN_STATUS_SUCCESS != err || sz == 0 || sz < max_ws_size || sz > free) continue; + max_ws_size =
sz; + } + return max_ws_size; +} + Status SliceOutUnwantedOutputSection(cudaStream_t stream, const void* input_data, const std::vector& input_dims, void* output_data, @@ -213,9 +248,14 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const ORT_ENFORCE(cudnn_conv_algo > -1 && cudnn_conv_algo < 3, "cudnn_conv_algo should be 0, 1 or 2, but got ", cudnn_conv_algo); switch (cudnn_conv_algo) { case 0: { - IAllocatorUniquePtr algo_search_workspace = GetScratchBuffer(AlgoSearchWorkspaceSize); + static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; + size_t max_ws_size = cuda_ep->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(s_, kAllAlgos, num_algos) + : AlgoSearchWorkspaceSize; + // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. + // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. + IAllocatorUniquePtr algo_search_workspace = GetTransientScratchBuffer(max_ws_size); CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionForwardAlgorithmEx( - CudnnHandle(), + s_.handle, s_.x_tensor, s_.x_data, s_.w_desc, @@ -227,12 +267,12 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const &algo_count, // returnedAlgoCount &perf, algo_search_workspace.get(), - AlgoSearchWorkspaceSize)); + max_ws_size)); break; } case 1: CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardAlgorithm_v7( - CudnnHandle(), + s_.handle, s_.x_tensor, s_.w_desc, s_.conv_desc, @@ -244,14 +284,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const default: perf.algo = kDefaultConvAlgo; - CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardWorkspaceSize( - CudnnHandle(), - s_.x_tensor, - s_.w_desc, - s_.conv_desc, - s_.y_tensor, - perf.algo, - &perf.memory)); + CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(s_, perf.algo, &perf.memory)); if (std::is_same::value) { perf.mathType = CUDNN_TENSOR_OP_MATH; } else { @@ -290,7 +323,7 @@ Status Conv::ComputeInternal(OpKernelContext* context) const { 
const auto alpha = Consts::One; const auto beta = Consts::Zero; IAllocatorUniquePtr workspace = GetWorkSpace(); - CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward(CudnnHandle(), + CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward(s_.handle, &alpha, s_.x_tensor, s_.x_data, @@ -304,7 +337,7 @@ Status Conv::ComputeInternal(OpKernelContext* context) const { s_.y_tensor, s_.y_data)); if (nullptr != s_.b_data) { - CUDNN_RETURN_IF_ERROR(cudnnAddTensor(CudnnHandle(), &alpha, s_.b_tensor, s_.b_data, + CUDNN_RETURN_IF_ERROR(cudnnAddTensor(s_.handle, &alpha, s_.b_tensor, s_.b_data, &alpha, s_.y_tensor, s_.y_data)); } // To deal with asymmetric padding, we may have over-padded on one or both sides of the spatial dimensions diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index 5abaa1f595..1d037c20d8 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -111,6 +111,8 @@ constexpr size_t MAX_CACHED_ALGO_PERF_RESULTS = 10000; template struct CudnnConvState { + cudnnHandle_t handle; + // if x/w dims changed, update algo and cudnnTensors std::vector last_x_dims; std::vector last_w_dims; @@ -173,6 +175,7 @@ class Conv : public CudaKernel { Conv(const OpKernelInfo& info) : CudaKernel(info), conv_attrs_(info) { auto pads_size = conv_attrs_.pads.size(); ORT_ENFORCE(pads_size % 2 == 0); + s_.handle = CudnnHandle(); } Status ComputeInternal(OpKernelContext* context) const override; @@ -186,6 +189,7 @@ class Conv : public CudaKernel { ConvAttributes conv_attrs_; mutable CudnnConvState s_; constexpr static auto kDefaultConvAlgo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + static const cudnnConvolutionFwdAlgo_t kAllAlgos[]; }; Status SliceOutUnwantedOutputSection(cudaStream_t stream, diff --git a/onnxruntime/core/providers/rocm/rocm_allocator.cc b/onnxruntime/core/providers/rocm/rocm_allocator.cc index 7e8cfb51f5..76026470bf 100644 --- a/onnxruntime/core/providers/rocm/rocm_allocator.cc 
+++ b/onnxruntime/core/providers/rocm/rocm_allocator.cc @@ -66,6 +66,21 @@ void* ROCMExternalAllocator::Alloc(size_t size) { void ROCMExternalAllocator::Free(void* p) { free_(p); + std::lock_guard lock(lock_); + auto it = reserved_.find(p); + if (it != reserved_.end()) { + reserved_.erase(it); + if (empty_cache_) empty_cache_(); + } +} + +void* ROCMExternalAllocator::Reserve(size_t size) { + void* p = Alloc(size); + if (!p) return nullptr; + std::lock_guard lock(lock_); + ORT_ENFORCE(reserved_.find(p) == reserved_.end()); + reserved_.insert(p); + return p; } void* ROCMPinnedAllocator::Alloc(size_t size) { diff --git a/onnxruntime/core/providers/rocm/rocm_allocator.h b/onnxruntime/core/providers/rocm/rocm_allocator.h index 063a9bf5a0..89077bb2d3 100644 --- a/onnxruntime/core/providers/rocm/rocm_allocator.h +++ b/onnxruntime/core/providers/rocm/rocm_allocator.h @@ -3,7 +3,9 @@ #pragma once +#include #include "core/framework/allocator.h" +#include "core/platform/ort_mutex.h" namespace onnxruntime { @@ -25,20 +27,26 @@ class ROCMAllocator : public IAllocator { class ROCMExternalAllocator : public ROCMAllocator { typedef void* (*ExternalAlloc)(size_t size); typedef void (*ExternalFree)(void* p); + typedef void (*ExternalEmptyCache)(); public: - ROCMExternalAllocator(OrtDevice::DeviceId device_id, const char* name, const void* alloc, const void* free) + ROCMExternalAllocator(OrtDevice::DeviceId device_id, const char* name, const void* alloc, const void* free, void* empty_cache) : ROCMAllocator(device_id, name) { alloc_ = reinterpret_cast(const_cast(alloc)); free_ = reinterpret_cast(const_cast(free)); + empty_cache_ = reinterpret_cast(empty_cache); } void* Alloc(size_t size) override; void Free(void* p) override; + void* Reserve(size_t size) override; private: + mutable OrtMutex lock_; ExternalAlloc alloc_; ExternalFree free_; + ExternalEmptyCache empty_cache_; + std::unordered_set reserved_; }; //TODO: add a default constructor diff --git 
a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 1c2574cdc7..644d648127 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -63,7 +63,7 @@ AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId devi if (external_allocator_info.UseExternalAllocator()) { AllocatorCreationInfo default_memory_info( [external_allocator_info](OrtDevice::DeviceId id) { - return std::make_unique(id, CUDA, external_allocator_info.alloc, external_allocator_info.free); + return std::make_unique(id, CUDA, external_allocator_info.alloc, external_allocator_info.free, external_allocator_info.empty_cache); }, device_id, false); diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc index 44b2a94a51..705daf2356 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc @@ -16,6 +16,7 @@ constexpr const char* kArenaExtendStrategy = "arena_extend_strategy"; constexpr const char* kConvExhaustiveSearch = "conv_exhaustive_search"; constexpr const char* kGpuExternalAlloc = "gpu_external_alloc"; constexpr const char* kGpuExternalFree = "gpu_external_free"; +constexpr const char* kGpuExternalEmptyCache = "gpu_external_empty_cache"; } // namespace provider_option_names } // namespace rocm @@ -30,7 +31,7 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P ROCMExecutionProviderInfo info{}; void* alloc = nullptr; void* free = nullptr; - + void* empty_cache = nullptr; ORT_THROW_IF_ERROR( ProviderOptionsParser{} .AddValueParser( @@ -49,7 +50,15 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P free = reinterpret_cast(address); return Status::OK(); }) - .AddValueParser( + .AddValueParser( + 
rocm::provider_option_names::kGpuExternalEmptyCache, + [&empty_cache](const std::string& value_str) -> Status { + size_t address; + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); + empty_cache = reinterpret_cast(address); + return Status::OK(); + }) + .AddValueParser( rocm::provider_option_names::kDeviceId, [&info](const std::string& value_str) -> Status { ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id)); @@ -70,7 +79,7 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P arena_extend_strategy_mapping, info.arena_extend_strategy) .Parse(options)); - ROCMExecutionProviderExternalAllocatorInfo alloc_info{alloc, free}; + ROCMExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache}; info.external_allocator_info = alloc_info; return info; @@ -82,6 +91,7 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution {rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)}, {rocm::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.alloc))}, {rocm::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.free))}, + {rocm::provider_option_names::kGpuExternalEmptyCache, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.empty_cache))}, {rocm::provider_option_names::kConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)}, {rocm::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h index 55cabd2531..be9d233eaf 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h +++ 
b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h @@ -14,15 +14,18 @@ namespace onnxruntime { struct ROCMExecutionProviderExternalAllocatorInfo { const void* alloc{nullptr}; const void* free{nullptr}; + void* empty_cache{nullptr}; ROCMExecutionProviderExternalAllocatorInfo() { alloc = nullptr; free = nullptr; + empty_cache = nullptr; } - ROCMExecutionProviderExternalAllocatorInfo(void* a, void* f) { + ROCMExecutionProviderExternalAllocatorInfo(void* a, void* f, void* e) { alloc = a; free = f; + empty_cache = e; } bool UseExternalAllocator() const { diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 52e34e08fd..a79b735c3f 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -184,10 +184,12 @@ class TestInferenceSession(unittest.TestCase): option['gpu_external_alloc'] = '0' option['gpu_external_free'] = '0' + option['gpu_external_empty_cache'] = '0' sess.set_providers(['CUDAExecutionProvider'], [option]) options = sess.get_provider_options() self.assertEqual(options['CUDAExecutionProvider']['gpu_external_alloc'], '0') self.assertEqual(options['CUDAExecutionProvider']['gpu_external_free'], '0') + self.assertEqual(options['CUDAExecutionProvider']['gpu_external_empty_cache'], '0') # # Note: Tests that throw an exception leave an empty session due to how set_providers currently works, # so run them last. 
Each set_providers call will attempt to re-create a session, so it's diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 5ad652a10b..7ddd033cb3 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -164,6 +164,7 @@ class GraphExecutionManager(GraphExecutionInterface): from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_gpu_allocator self._torch_alloc = torch_gpu_allocator.gpu_caching_allocator_raw_alloc_address() self._torch_free = torch_gpu_allocator.gpu_caching_allocator_raw_delete_address() + self._torch_empty_cache = torch_gpu_allocator.gpu_caching_allocator_empty_cache_address() def _validate_module_type(self, module): """Raises ORTModuleTorchModelException if the module is not a torch.nn.Module""" @@ -230,12 +231,16 @@ class GraphExecutionManager(GraphExecutionInterface): providers = (["ROCMExecutionProvider"] if self.is_rocm_pytorch else [ "CUDAExecutionProvider"]) providers.append("CPUExecutionProvider") + provider_option_map = {"device_id": str(self._device.index)} + if not self.is_rocm_pytorch: + # Set Conv algo search mode to HEURISTIC, which is same as PyTorch's default setting. 
+ provider_option_map["cudnn_conv_algo_search"] = "HEURISTIC" + provider_option_map["cudnn_conv_use_max_workspace"] = "1" if self._use_external_gpu_allocator: - provider_options = [{"device_id": str(self._device.index), - "gpu_external_alloc": str(self._torch_alloc), - "gpu_external_free": str(self._torch_free)}, {}] - else: - provider_options = [{"device_id": str(self._device.index)}, {}] + provider_option_map["gpu_external_alloc"] = str(self._torch_alloc) + provider_option_map["gpu_external_free"] = str(self._torch_free) + provider_option_map["gpu_external_empty_cache"] = str(self._torch_empty_cache) + provider_options = [provider_option_map, {}] elif self._device.type == 'cpu': providers = ["CPUExecutionProvider"] provider_options = [{}] diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator/torch_gpu_allocator.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator/torch_gpu_allocator.cc index f180744130..3799eb09b4 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator/torch_gpu_allocator.cc +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator/torch_gpu_allocator.cc @@ -5,14 +5,22 @@ #include size_t gpu_caching_allocator_raw_alloc_address() { - return reinterpret_cast(&c10::___gpu_identifier___::___gpu_allocator_header___::raw_alloc); + return reinterpret_cast(&c10::___gpu_identifier___::___gpu_allocator_header___::raw_alloc); } size_t gpu_caching_allocator_raw_delete_address() { - return reinterpret_cast(&c10::___gpu_identifier___::___gpu_allocator_header___::raw_delete); + return reinterpret_cast(&c10::___gpu_identifier___::___gpu_allocator_header___::raw_delete); +} + +size_t gpu_caching_allocator_empty_cache_address() { + return reinterpret_cast(&c10::___gpu_identifier___::___gpu_allocator_header___::emptyCache); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - 
m.def("gpu_caching_allocator_raw_alloc_address", &gpu_caching_allocator_raw_alloc_address, "Address of PyTorch GPU allocator"); - m.def("gpu_caching_allocator_raw_delete_address", &gpu_caching_allocator_raw_delete_address, "Address of PyTorch GPU deallocator"); + m.def("gpu_caching_allocator_raw_alloc_address", &gpu_caching_allocator_raw_alloc_address, + "Address of PyTorch GPU allocator"); + m.def("gpu_caching_allocator_raw_delete_address", &gpu_caching_allocator_raw_delete_address, + "Address of PyTorch GPU deallocator"); + m.def("gpu_caching_allocator_empty_cache_address", &gpu_caching_allocator_empty_cache_address, + "Address of PyTorch GPU empty cache"); } diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index d08c3924b9..7fc2583c4e 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -639,17 +639,20 @@ def test_gradient_correctness_conv1d(use_fp16, input_requires_grad): loss.backward() return prediction - for step in range(10): + for _ in range(10): x = torch.randn(N, seq_len, C_in, device=device, requires_grad=input_requires_grad) pt_prediction = run_step(pt_model, x) ort_prediction = run_step(ort_model, x) + # PyTorch's Conv/ConvGrad uses HEURISTIC mode to search algo while ORT uses EXHAUSTIVE mode by default. + # Different algo types generate slightly different results, especially for FP16, + # so relax the tolerance for comparison, especially for FP16 runs and gradient comparison.
if use_fp16: _test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-3, rtol=1e-3) - _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, rtol=1e-2, atol=2e-2) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, rtol=5e-1, atol=4e-1) else: _test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-5) - _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, rtol=5e-3, atol=4e-3) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, rtol=5e-2, atol=4e-2) @pytest.mark.parametrize("device", ['cuda', 'cpu']) @pytest.mark.parametrize("padding_idx", [None, 1]) diff --git a/orttraining/orttraining/test/training_ops/cuda/conv_grad_test.cc b/orttraining/orttraining/test/training_ops/cuda/conv_grad_test.cc new file mode 100644 index 0000000000..fed8d1a0ab --- /dev/null +++ b/orttraining/orttraining/test/training_ops/cuda/conv_grad_test.cc @@ -0,0 +1,322 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace cuda { +namespace test { + +using namespace std; +using namespace onnxruntime::test; + +#ifdef USE_CUDA +namespace { + +struct ConvGradOpAttributes { + vector dilations; + int64_t group; + vector kernel_shape; + vector pads; + vector strides; +}; + +void TestConvGradOp(const ConvGradOpAttributes& attributes, const vector>& inputs, + const vector>& input_shapes, const vector>& outputs, + const vector>& output_shapes, bool is_half = false) { + OpTester test("ConvGrad", 1, kMSDomain); + test.AddAttribute("group", attributes.group); + test.AddAttribute("kernel_shape", attributes.kernel_shape); + test.AddAttribute("pads", attributes.pads); + + if (!attributes.dilations.empty()) { + test.AddAttribute("dilations", attributes.dilations); + } + + if (!attributes.strides.empty()) { + test.AddAttribute("strides", attributes.strides); + } + + if (is_half) { + std::vector dY_half(inputs[0].size()); + ConvertFloatToMLFloat16(inputs[0].data(), dY_half.data(), static_cast(inputs[0].size())); + test.AddInput("dY", input_shapes[0], dY_half); + + std::vector X_half(inputs[1].size()); + ConvertFloatToMLFloat16(inputs[1].data(), X_half.data(), static_cast(inputs[1].size())); + test.AddInput("X", input_shapes[1], X_half); + + std::vector W_half(inputs[2].size()); + ConvertFloatToMLFloat16(inputs[2].data(), W_half.data(), static_cast(inputs[2].size())); + test.AddInput("W", input_shapes[2], W_half); + + std::vector dX_half(outputs[0].size()); + ConvertFloatToMLFloat16(outputs[0].data(), dX_half.data(), static_cast(outputs[0].size())); + test.AddOutput("dX", output_shapes[0], dX_half); + + std::vector dW_half(outputs[1].size()); + ConvertFloatToMLFloat16(outputs[1].data(), dW_half.data(), static_cast(outputs[1].size())); + test.AddOutput("dW", output_shapes[1], dW_half); + + if (outputs.size() >= 3) { + std::vector dB_half(outputs[2].size()); + 
ConvertFloatToMLFloat16(outputs[2].data(), dB_half.data(), static_cast(outputs[2].size())); + test.AddOutput("dB", output_shapes[2], dB_half); + } + } else { + test.AddInput("dY", input_shapes[0], inputs[0]); + test.AddInput("X", input_shapes[1], inputs[1]); + test.AddInput("W", input_shapes[2], inputs[2]); + + test.AddOutput("dX", output_shapes[0], outputs[0]); + test.AddOutput("dW", output_shapes[1], outputs[1]); + + if (outputs.size() >= 3) { + test.AddOutput("dB", output_shapes[2], outputs[2]); + } + } + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +} // namespace + +TEST(ConvGradTest, Conv1D_1) { + ConvGradOpAttributes attrs = { + vector{1}, // dilations + 1, // group + vector{1}, // kernel_shape + vector{0, 0}, // pads + vector{1}, // strides + }; + + vector dY(7, 1.0f); + vector dY_shape = {1, 1, 7}; + vector X = {2.0349f, -1.8088f, -0.1171f, 1.1849f, -0.6590f, -2.0404f, -1.2810f}; + vector X_shape = {1, 1, 7}; + vector W = {0.5081f}; + vector W_shape = {1, 1, 1}; + vector dX = {0.5081f, 0.5081f, 0.5081f, 0.5081f, 0.5081f, 0.5081f, 0.5081f}; + vector dX_shape = {1, 1, 7}; + vector dW = {-2.6865f}; + vector dW_shape = {1, 1, 1}; + + TestConvGradOp(attrs, {dY, X, W}, {dY_shape, X_shape, W_shape}, {dX, dW}, {dX_shape, dW_shape}); + TestConvGradOp(attrs, {dY, X, W}, {dY_shape, X_shape, W_shape}, {dX, dW}, {dX_shape, dW_shape}, true); +} + +TEST(ConvGradTest, Conv1D_2) { + ConvGradOpAttributes attrs = { + vector{2}, // dilations + 1, // group + vector{2}, // kernel_shape + vector{2, 2}, // pads + vector{2}, // strides + }; + + vector dY(30, 1.0f); + vector dY_shape = {3, 2, 5}; + vector X = {-0.9303f, 0.3717f, 0.4961f, 0.5068f, -0.7506f, -0.7609f, -1.8795f, 0.0536f, + 1.5201f, -0.9580f, -1.7678f, 0.4683f, -0.3142f, 0.2097f, -1.3819f, -0.1070f, + -1.7558f, -0.0278f, 1.5378f, 2.6415f, 1.0004f, 1.3604f, 1.2819f, -0.4629f}; + vector X_shape = {3, 1, 8}; + vector W = {1.6664f, -0.1582f, 
-0.8984f, 0.0849f}; + vector W_shape = {2, 1, 2}; + vector dX = {0.6948f, 0.0000f, 0.6948f, 0.0000f, 0.6948f, 0.0000f, 0.6948f, 0.0000f, + 0.6948f, 0.0000f, 0.6948f, 0.0000f, 0.6948f, 0.0000f, 0.6948f, 0.0000f, + 0.6948f, 0.0000f, 0.6948f, 0.0000f, 0.6948f, 0.0000f, 0.6948f, 0.0000f}; + vector dX_shape = {3, 1, 8}; + vector dW = {-2.9440f, -2.9440f, -2.9440f, -2.9440f}; + vector dW_shape = {2, 1, 2}; + + TestConvGradOp(attrs, {dY, X, W}, {dY_shape, X_shape, W_shape}, {dX, dW}, {dX_shape, dW_shape}); + TestConvGradOp(attrs, {dY, X, W}, {dY_shape, X_shape, W_shape}, {dX, dW}, {dX_shape, dW_shape}, true); +} + +TEST(ConvGradTest, Conv1D_Bias) { + ConvGradOpAttributes attrs = { + vector{2}, // dilations + 1, // group + vector{1}, // kernel_shape + vector{1, 1}, // pads + vector{3}, // strides + }; + + vector dY(8, 1.0f); + vector dY_shape = {2, 1, 4}; + vector X = {0.3305f, 2.6170f, -0.8102f, -1.1348f, -0.0850f, 0.2033f, 0.7295f, -0.2826f, -0.5977f, + 0.5505f, 0.3895f, -1.3394f, 0.6413f, -0.4744f, -0.9943f, 0.7560f, 0.1355f, -1.3931f, + 1.2644f, 0.0240f, 0.7571f, 0.6851f, -0.3362f, -1.1230f, 0.6475f, -0.4596f, 1.1648f, + 0.8991f, 0.0440f, 1.5056f, 0.9504f, -0.5266f, 0.0437f, -0.3006f, 0.8489f, 0.0960f}; + vector X_shape = {2, 2, 9}; + vector W = {0.0398f, 0.1392f}; + vector W_shape = {1, 2, 1}; + vector dX = {0.0000f, 0.0000f, 0.0398f, 0.0000f, 0.0000f, 0.0398f, 0.0000f, 0.0000f, 0.0398f, + 0.0000f, 0.0000f, 0.1392f, 0.0000f, 0.0000f, 0.1392f, 0.0000f, 0.0000f, 0.1392f, + 0.0000f, 0.0000f, 0.0398f, 0.0000f, 0.0000f, 0.0398f, 0.0000f, 0.0000f, 0.0398f, + 0.0000f, 0.0000f, 0.1392f, 0.0000f, 0.0000f, 0.1392f, 0.0000f, 0.0000f, 0.1392f}; + vector dX_shape = {2, 2, 9}; + vector dW = {-0.4057f, -2.0815f}; + vector dW_shape = {1, 2, 1}; + vector dB = {8.f}; + vector dB_shape = {1}; + + TestConvGradOp(attrs, {dY, X, W}, {dY_shape, X_shape, W_shape}, {dX, dW, dB}, {dX_shape, dW_shape, dB_shape}); + TestConvGradOp(attrs, {dY, X, W}, {dY_shape, X_shape, W_shape}, {dX, dW, dB}, 
{dX_shape, dW_shape, dB_shape}, true); +} + +TEST(ConvGradTest, Conv2D) { + ConvGradOpAttributes attrs = { + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + }; + + vector dY(48, 1.0f); + vector dY_shape = {1, 3, 4, 4}; + vector X = {0.8374f, -2.0758f, 1.8918f, -1.0625f, -1.2747f, -0.1561f, 0.4573f, 2.5314f, -0.0089f, -1.0412f, + 0.7690f, 0.2320f, 0.6535f, -0.4921f, -0.6051f, 0.5580f, 1.5682f, -1.0309f, -0.9379f, -0.1834f, + -1.2162f, 1.4167f, -0.2849f, -0.1625f, 0.3380f, -0.1393f, -1.1557f, 0.9718f, -0.4656f, -0.9046f, + 1.5710f, -1.3963f, 1.2470f, 0.7327f, 0.8045f, 0.8071f, 1.1703f, -1.3566f, -0.2030f, -0.1227f, + -0.5881f, 2.4159f, -0.2768f, 0.5567f, -0.2805f, 0.1618f, -0.7256f, -0.1053f}; + vector X_shape = {1, 3, 4, 4}; + vector W = {0.1094f, 1.1541f, 0.0486f, 0.5668f, 1.0372f, -0.3792f, 1.4979f, 0.1757f, 0.1733f}; + vector W_shape = {3, 3, 1, 1}; + vector dX = {2.1741f, 2.1741f, 2.1741f, 2.1741f, 2.1741f, 2.1741f, 2.1741f, 2.1741f, + 2.1741f, 2.1741f, 2.1741f, 2.1741f, 2.1741f, 2.1741f, 2.1741f, 2.1741f, + 2.3670f, 2.3670f, 2.3670f, 2.3670f, 2.3670f, 2.3670f, 2.3670f, 2.3670f, + 2.3670f, 2.3670f, 2.3670f, 2.3670f, 2.3670f, 2.3670f, 2.3670f, 2.3670f, + -0.1572f, -0.1572f, -0.1572f, -0.1572f, -0.1572f, -0.1572f, -0.1572f, -0.1572f, + -0.1572f, -0.1572f, -0.1572f, -0.1572f, -0.1572f, -0.1572f, -0.1572f, -0.1572f}; + vector dX_shape = {1, 3, 4, 4}; + vector dW = {1.2142f, -2.0115f, 4.2372f, 1.2142f, -2.0115f, 4.2372f, 1.2142f, -2.0115f, 4.2372f}; + vector dW_shape = {3, 3, 1, 1}; + + TestConvGradOp(attrs, {dY, X, W}, {dY_shape, X_shape, W_shape}, {dX, dW}, {dX_shape, dW_shape}); + TestConvGradOp(attrs, {dY, X, W}, {dY_shape, X_shape, W_shape}, {dX, dW}, {dX_shape, dW_shape}, true); +} + +TEST(ConvGradTest, Conv2D_Bias) { + ConvGradOpAttributes attrs = { + vector{1, 1}, // dilations + 1, // group + vector{2, 2}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // 
strides + }; + + vector dY(8, 1.0f); + vector dY_shape = {1, 2, 2, 2}; + vector X = {-0.4406f, 0.3064f, 0.0794f, -0.2795f, 0.8228f, 0.4751f, 1.3114f, 0.9522f, 0.9082f}; + vector X_shape = {1, 1, 3, 3}; + vector W = {-0.0820f, -0.4214f, -2.2745f, -1.5834f, 0.9746f, 0.6936f, -0.5140f, 0.9900f}; + vector W_shape = {2, 1, 2, 2}; + vector dX = {0.8926f, 1.1648f, 0.2723f, -1.8959f, -2.2169f, -0.3211f, -2.7884f, -3.3818f, -0.5933f}; + vector dX_shape = {1, 1, 3, 3}; + vector dW = {0.4092f, 1.6838f, 2.8069f, 3.1583f, 0.4092f, 1.6838f, 2.8069f, 3.1583f}; + vector dW_shape = {2, 1, 2, 2}; + vector dB = {4.f, 4.f}; + vector dB_shape = {2}; + + TestConvGradOp(attrs, {dY, X, W}, {dY_shape, X_shape, W_shape}, {dX, dW, dB}, {dX_shape, dW_shape, dB_shape}); + TestConvGradOp(attrs, {dY, X, W}, {dY_shape, X_shape, W_shape}, {dX, dW, dB}, {dX_shape, dW_shape, dB_shape}, true); +} + +TEST(ConvGradTest, Conv3D) { + ConvGradOpAttributes attrs = { + vector{1, 1, 1}, // dilations + 1, // group + vector{2, 2, 1}, // kernel_shape + vector{2, 2, 2, 2, 2, 2}, // pads + vector{2, 2, 2}, // strides + }; + + vector dY(81, 1.0f); + vector dY_shape = {1, 3, 3, 3, 3}; + vector X = {0.5598f, -1.9201f, -0.7435f, 2.0217f, 1.2615f, -1.9540f, 0.3119f, 0.0106f, -0.1752f, + 0.1553f, 1.3088f, 0.8588f, -0.6396f, -1.1059f, -0.7768f, 0.3251f, -0.3116f, -0.1495f, + 0.8923f, -0.1832f, -0.3995f, -0.9641f, 1.9743f, -0.3098f, -0.3029f, -0.3453f, -0.7708f, + -0.3267f, -0.5051f, -0.9330f, -0.2421f, -0.0874f, 0.3225f, -0.8572f, -0.2019f, -0.7069f, + 0.4333f, 0.5562f, -1.5587f, -1.0665f, -0.6832f, -0.4320f, -0.0225f, 1.4662f, 0.4808f, + 0.0282f, 0.6967f, 0.5708f, -1.3258f, -0.6925f, 0.1217f, 1.3211f, 0.5877f, 0.7335f}; + vector X_shape = {1, 3, 3, 3, 2}; + vector W = {-0.1911f, 0.6604f, -1.0283f, -0.9381f, -0.3449f, 1.1152f, -1.0256f, -0.3494f, 0.4504f, + 0.2418f, 0.2258f, -1.5920f, 1.0468f, 0.2045f, 0.8264f, -0.5797f, -0.0254f, 0.6934f, + -1.7728f, 0.8619f, -0.2013f, -0.1045f, -0.4713f, 1.2544f, 1.7090f, -0.7133f, 
-0.6160f, + -1.2325f, -1.2152f, 0.0935f, -0.4929f, 1.3772f, 0.3125f, -0.7773f, 1.0350f, 3.2168f}; + vector W_shape = {3, 3, 2, 2, 1}; + vector dX = {2.5646f, 0.0000f, 0.1516f, 0.0000f, 2.5646f, 0.0000f, -0.8178f, 0.0000f, -2.7503f, + 0.0000f, -0.8178f, 0.0000f, 2.5646f, 0.0000f, 0.1516f, 0.0000f, 2.5646f, 0.0000f, + -1.5855f, 0.0000f, 1.9021f, 0.0000f, -1.5855f, 0.0000f, -3.2913f, 0.0000f, 1.8898f, + 0.0000f, -3.2913f, 0.0000f, -1.5855f, 0.0000f, 1.9021f, 0.0000f, -1.5855f, 0.0000f, + 0.5616f, 0.0000f, -0.6399f, 0.0000f, 0.5616f, 0.0000f, 0.7895f, 0.0000f, 2.8792f, + 0.0000f, 0.7895f, 0.0000f, 0.5616f, 0.0000f, -0.6399f, 0.0000f, 0.5616f, 0.0000f}; + vector dX_shape = {1, 3, 3, 3, 2}; + vector dW = {0.8701f, -1.5203f, 1.6207f, -0.1752f, 2.4225f, -0.0770f, -0.8080f, -0.7708f, -0.9880f, + -1.4370f, 0.6742f, 0.4808f, 0.8701f, -1.5203f, 1.6207f, -0.1752f, 2.4225f, -0.0770f, + -0.8080f, -0.7708f, -0.9880f, -1.4370f, 0.6742f, 0.4808f, 0.8701f, -1.5203f, 1.6207f, + -0.1752f, 2.4225f, -0.0770f, -0.8080f, -0.7708f, -0.9880f, -1.4370f, 0.6742f, 0.4808f}; + vector dW_shape = {3, 3, 2, 2, 1}; + + TestConvGradOp(attrs, {dY, X, W}, {dY_shape, X_shape, W_shape}, {dX, dW}, {dX_shape, dW_shape}); + TestConvGradOp(attrs, {dY, X, W}, {dY_shape, X_shape, W_shape}, {dX, dW}, {dX_shape, dW_shape}, true); +} + +TEST(ConvTest, Conv3D_Bias) { + ConvGradOpAttributes attrs = { + vector{2, 2, 2}, // dilations + 1, // group + vector{2, 2, 2}, // kernel_shape + vector{2, 2, 2, 2, 2, 2}, // pads + vector{2, 2, 2}, // strides + }; + + vector dY(108, 1.0f); + vector dY_shape = {2, 2, 3, 3, 3}; + vector X = {0.5189f, -0.8970f, 1.6637f, -0.0882f, 0.0562f, 1.5707f, 0.5424f, 0.4293f, -0.7949f, 2.1479f, + 0.8582f, -0.5160f, -0.2406f, 0.0915f, -0.6440f, -0.3800f, 1.7989f, -1.6479f, 0.9071f, -0.4087f, + 1.3729f, -1.5582f, 1.9312f, 1.0753f, 0.8313f, 0.5097f, -0.0664f, -1.2774f, 1.7208f, -1.7777f, + -2.0758f, -0.4782f, 2.0207f, 1.4053f, 0.3488f, -0.0871f, 1.4151f, 2.5243f, 0.2891f, 0.8317f, + 1.8934f, 
0.3911f, 0.2915f, -0.8505f, 1.0430f, 0.5391f, 0.8347f, 0.0633f, -0.3250f, 1.3358f, + -0.3121f, 0.4587f, -0.4955f, 1.8411f, 0.9877f, 1.0809f, 0.0119f, -1.2706f, 1.8457f, -0.1520f, + -0.4535f, -0.5325f, -0.8921f, -0.3127f, 0.5746f, -1.2514f, 0.4638f, 0.8440f, -0.6113f, 0.6936f, + 0.0998f, 0.9767f, 0.2785f, -0.3068f, -0.4619f, 0.4801f, -2.1590f, -1.7342f, 0.7354f, 0.0234f, + 1.8095f, 0.1252f, -0.5841f, 0.0738f, 1.4252f, 1.4222f, -0.1192f, -2.9955f, 0.8287f, 0.6252f, + -1.5834f, -0.1388f, 0.5532f, 0.4044f, 1.0432f, -2.3991f, 0.4339f, 0.1083f, -0.7726f, 2.0629f, + 0.7136f, -0.0978f, -0.7905f, 0.9585f, -0.3205f, 1.3750f, 0.4137f, -0.4552f, 2.7165f, -1.6367f, + -0.6286f, -0.4656f, 0.6219f, -1.7275f, 1.7599f, -1.0443f, 1.3212f, 0.1621f, -0.5357f, 0.0957f, + 0.4524f, -0.3814f, -0.0744f, 0.8301f, -0.9539f, 0.0867f, -0.7864f, 1.4918f}; + vector X_shape = {2, 1, 4, 4, 4}; + vector W = {-0.6570f, 0.1637f, 1.7824f, 0.7986f, -0.2703f, -0.7447f, 1.2674f, 0.2019f, + 0.1045f, -0.7279f, 0.9658f, 0.4698f, -0.6699f, -1.5259f, 2.1664f, -0.3859f}; + vector W_shape = {2, 1, 2, 2, 2}; + vector dX = {2.9389f, 0.0000f, 2.9389f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 2.9389f, 0.0000f, 2.9389f, + 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, + 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 2.9389f, + 0.0000f, 2.9389f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 2.9389f, 0.0000f, 2.9389f, 0.0000f, + 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, + 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 2.9389f, 0.0000f, + 2.9389f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 2.9389f, 0.0000f, 2.9389f, 0.0000f, 0.0000f, + 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, + 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 2.9389f, 0.0000f, 2.9389f, 
+ 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 2.9389f, 0.0000f, 2.9389f, 0.0000f, 0.0000f, 0.0000f, + 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, + 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f}; + vector dX_shape = {2, 1, 4, 4, 4}; + vector dW = {7.4097f, 7.4097f, 7.4097f, 7.4097f, 7.4097f, 7.4097f, 7.4097f, 7.4097f, + 7.4097f, 7.4097f, 7.4097f, 7.4097f, 7.4097f, 7.4097f, 7.4097f, 7.4097f}; + vector dW_shape = {2, 1, 2, 2, 2}; + vector dB = {54.f, 54.f}; + vector dB_shape = {2}; + + TestConvGradOp(attrs, {dY, X, W}, {dY_shape, X_shape, W_shape}, {dX, dW, dB}, {dX_shape, dW_shape, dB_shape}); + TestConvGradOp(attrs, {dY, X, W}, {dY_shape, X_shape, W_shape}, {dX, dW, dB}, {dX_shape, dW_shape, dB_shape}, true); +} +#endif // USE_CUDA + +} // namespace test +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc index bff551e548..788d33b5b9 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc @@ -4,114 +4,343 @@ #include "orttraining/training_ops/cuda/nn/conv_grad.h" #include "core/providers/common.h" -#include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/platform/ort_mutex.h" + +// The AlgoPerfCache and AlgoSearch here for Conv/ConvGrad is referenced on PyTorch's implementation +// from aten/src/ATen/native/cudnn/Conv_v7.cpp. 
namespace onnxruntime { namespace cuda { -#define REGISTER_GRADIENT_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - ConvGrad, \ - kMSDomain, \ - 1, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ConvGrad); +#define REGISTER_GRADIENT_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX(ConvGrad, kMSDomain, 1, T, kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + ConvGrad); REGISTER_GRADIENT_KERNEL_TYPED(float) REGISTER_GRADIENT_KERNEL_TYPED(double) REGISTER_GRADIENT_KERNEL_TYPED(MLFloat16) -cudnnStatus_t getWorkspaceSize( - const ConvolutionArgs& args, - cudnnConvolutionBwdDataAlgo_t algo, size_t* sz) { - return cudnnGetConvolutionBackwardDataWorkspaceSize( - args.handle, - args.w_desc, - args.o_desc, - args.c_desc, - args.i_desc, - algo, - sz); +using T_BwdDataPerf = cudnnConvolutionBwdDataAlgoPerf_t; +using T_BwdDataAlgo = cudnnConvolutionBwdDataAlgo_t; +using T_BwdFilterPerf = cudnnConvolutionBwdFilterAlgoPerf_t; +using T_BwdFilterAlgo = cudnnConvolutionBwdFilterAlgo_t; + +cudnnStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdDataAlgo algo, size_t* workspace_size) { + return cudnnGetConvolutionBackwardDataWorkspaceSize(args.handle, args.w_desc, args.y_tensor, args.conv_desc, + args.x_tensor, algo, workspace_size); } -cudnnStatus_t getWorkspaceSize( - const ConvolutionArgs& args, - cudnnConvolutionBwdFilterAlgo_t algo, size_t* sz) { - return cudnnGetConvolutionBackwardFilterWorkspaceSize( - args.handle, - args.i_desc, - args.o_desc, - args.c_desc, - args.w_desc, - algo, - sz); +cudnnStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdFilterAlgo algo, size_t* workspace_size) { + return cudnnGetConvolutionBackwardFilterWorkspaceSize(args.handle, args.x_tensor, args.y_tensor, args.conv_desc, + args.w_desc, algo, workspace_size); } -// TODO: we can cache the descriptors, and only update if the input shape changes 
+template +size_t GetMaxWorkspaceSize(const ConvArgs& args, const T_Algo* algo, int n_algo) { + // Calling cudaMemGetInfo is not ideal, but our cuda allocator doesn't have a way to get this info. + size_t free, total; + CUDA_CALL_THROW(cudaMemGetInfo(&free, &total)); + // Assuming 10% of fragmentation. + free = static_cast(static_cast(free) * 0.9); + size_t max_workspace_size = 0; + for (int i = 0; i < n_algo; i++) { + cudnnStatus_t status; + size_t workspace_size; + status = GetWorkspaceSize(args, algo[i], &workspace_size); + if (CUDNN_STATUS_SUCCESS != status || workspace_size == 0 || workspace_size < max_workspace_size || + workspace_size > free) + continue; + max_workspace_size = workspace_size; + } + + return max_workspace_size; +} + +template +std::vector GetValidAlgorithms(const T_Perf* perf_results, int n_algo) { + std::vector result; + result.reserve(n_algo); + for (int i = 0; i < n_algo; i++) { + T_Perf perf = perf_results[i]; + if (perf.status == CUDNN_STATUS_SUCCESS) { + result.emplace_back(perf); + } + } + ORT_ENFORCE(result.size() > 0, "No valid convolution algorithms available in CuDNN"); + // TODO: There is a cuDNN bug that gave wrong results in certain strided convolution gradient setups + // when cuDNN version < 7.5. Need to add handling for such special case. + return result; +} + +struct ConvParamsHash { + // ConvParams must be a POD because we read out its memory contents as char* when hashing. + static_assert(std::is_pod::value, "ConvParams is not POD"); + size_t operator()(const ConvParams& conv_params) const { + auto ptr = reinterpret_cast(&conv_params); + uint32_t value = 0x811C9DC5; + for (int i = 0; i < static_cast(sizeof(ConvParams)); ++i) { + value ^= ptr[i]; + value *= 0x01000193; + } + return static_cast(value); + } +}; + +struct ConvParamsEqual { + // ConvParams must be a POD because we read out its memory contents as char* when comparing.
+ static_assert(std::is_pod::value, "ConvParams is not POD"); + bool operator()(const ConvParams& a, const ConvParams& b) const { + auto ptr1 = reinterpret_cast(&a); + auto ptr2 = reinterpret_cast(&b); + return memcmp(ptr1, ptr2, sizeof(ConvParams)) == 0; + } +}; + +template +struct AlgoPerfCache { + mutable OrtMutex mutex; + std::unordered_map map; + + bool Find(const ConvParams& params, T_Perf* result) { + std::lock_guard guard(mutex); + auto it = map.find(params); + if (it == map.end()) { + return false; + } + *result = it->second; + return true; + } + + void Insert(const ConvParams& params, const T_Perf& algo_perf) { + std::lock_guard guard(mutex); + map[params] = algo_perf; + } +}; + +// TODO: Currently we use global AlgoPerfCache for ConvGrad only. Conv's perf cache is still per node. +// Need to apply such global cache for Conv, and move some shared code from here to conv.h/cc. +AlgoPerfCache bwd_data_algos; +AlgoPerfCache bwd_filter_algos; + +template +struct AlgoSearch {}; + +template <> +struct AlgoSearch { + static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; + static AlgoPerfCache& Cache() { return bwd_data_algos; } + static Status FindAlgorithms(const ConvArgs& args, const CUDAExecutionProvider* provider, + std::vector& perf_results) { + static const T_BwdDataAlgo algos[] = { + CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED}; + static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; + ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward data algorithms."); + int perf_count; + std::unique_ptr candidates(new T_BwdDataPerf[num_algos]); + if (args.params.algo_mode == OrtCudnnConvAlgoSearch::HEURISTIC) { + CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionBackwardDataAlgorithm_v7(args.handle,
args.w_desc, args.y_tensor, + args.conv_desc, args.x_tensor, num_algos, + &perf_count, candidates.get())); + } else if (args.params.algo_mode == OrtCudnnConvAlgoSearch::EXHAUSTIVE) { + size_t max_workspace_size = provider->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos) + : AlgoSearchWorkspaceSize; + // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. + // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. + IAllocatorUniquePtr workspace = provider->GetTransientScratchBuffer(max_workspace_size); + CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardDataAlgorithmEx( + args.handle, args.w_desc, args.w_data, args.y_tensor, args.dy_data, args.conv_desc, args.x_tensor, + args.dx_data, num_algos, &perf_count, candidates.get(), workspace.get(), max_workspace_size)); + } else { + ORT_ENFORCE(false, "Algo mode should be EXHAUSTIVE (0) or HEURISTIC (1), but got ", args.params.algo_mode); + } + perf_results = GetValidAlgorithms(candidates.get(), perf_count); + return Status::OK(); + } +}; + +template <> +struct AlgoSearch { + static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; + static AlgoPerfCache& Cache() { return bwd_filter_algos; } + static Status FindAlgorithms(const ConvArgs& args, const CUDAExecutionProvider* provider, + std::vector& perf_results) { + static const T_BwdFilterAlgo algos[] = { + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, + }; + + // NOTE: - 1 because ALGO_WINOGRAD is not implemented. 
+ static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT - 1; + ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward filter algorithms."); + std::unique_ptr candidates(new T_BwdFilterPerf[num_algos]); + int perf_count; + if (args.params.algo_mode == OrtCudnnConvAlgoSearch::HEURISTIC) { + CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm_v7(args.handle, args.x_tensor, args.y_tensor, + args.conv_desc, args.w_desc, num_algos, + &perf_count, candidates.get())); + } else if (args.params.algo_mode == OrtCudnnConvAlgoSearch::EXHAUSTIVE) { + size_t max_workspace_size = provider->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos) + : AlgoSearchWorkspaceSize; + // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. + // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. + IAllocatorUniquePtr workspace = provider->GetTransientScratchBuffer(max_workspace_size); + CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardFilterAlgorithmEx( + args.handle, args.x_tensor, args.x_data, args.y_tensor, args.dy_data, args.conv_desc, args.w_desc, + args.dw_data, num_algos, &perf_count, candidates.get(), workspace.get(), max_workspace_size)); + } else { + ORT_ENFORCE(false, "Algo mode should be EXHAUSTIVE (0) or HEURISTIC (1), but got ", args.params.algo_mode); + } + perf_results = GetValidAlgorithms(candidates.get(), perf_count); + return Status::OK(); + } +}; + +template +class AlgoIterator { + public: + AlgoIterator(const ConvArgs& args) : args_(args) {} + + static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results) { + perf_results.resize(1); + perf_results[0].algo = AlgoSearch::DEFAULT_ALGO; + if (args.params.data_type == CUDNN_DATA_HALF) { + perf_results[0].mathType = CUDNN_TENSOR_OP_MATH; + } else { + perf_results[0].mathType = CUDNN_DEFAULT_MATH; + } + CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(args, 
perf_results[0].algo, &(perf_results[0].memory))); + return Status::OK(); + } + + Status TryAll(const CUDAExecutionProvider* provider, std::function f) { + auto& cache = AlgoSearch::Cache(); + T_Perf algo_perf; + if (cache.Find(args_.params, &algo_perf) && f(algo_perf) == Status::OK()) { + return Status::OK(); + } + + std::vector perf_results; + ORT_RETURN_IF_ERROR(args_.params.algo_mode == OrtCudnnConvAlgoSearch::DEFAULT + ? OnlyDefaultAlgorithm(args_, perf_results) + : AlgoSearch::FindAlgorithms(args_, provider, perf_results)); + for (auto& algo_perf : perf_results) { + if (f(algo_perf) == Status::OK()) { + cache.Insert(args_.params, algo_perf); + return Status::OK(); + } + } + ORT_ENFORCE(false, "Unable to find a valid cuDNN algorithm to run convolution."); + return Status::OK(); + } + + private: + const ConvArgs& args_; +}; + template -Status ConvGrad::PrepareArgs(const Tensor& input, const Tensor& output, const Tensor& weight, const Tensor* bias) const { - const TensorShape& i_shape = input.Shape(); - std::vector i_dims = i_shape.GetDims(); +Status ConvGrad::PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor& w, Tensor* dB, Tensor* dX, + Tensor* dW) const { + const TensorShape& x_shape = x.Shape(); + std::vector x_dims = x_shape.GetDims(); + args_.x_data = reinterpret_cast(x.template Data()); - const TensorShape& o_shape = output.Shape(); - std::vector o_dims = o_shape.GetDims(); + const TensorShape& dy_shape = dY.Shape(); + std::vector dy_dims = dy_shape.GetDims(); + args_.dy_data = reinterpret_cast(dY.template Data()); - const TensorShape& w_shape = weight.Shape(); + const TensorShape& w_shape = w.Shape(); std::vector w_dims = w_shape.GetDims(); + args_.w_data = reinterpret_cast(w.template Data()); - // Update Attributes - ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(&input, &weight)); + args_.db_data = dB ? reinterpret_cast(dB->template MutableData()) : nullptr; + args_.dx_data = dX ? 
reinterpret_cast(dX->template MutableData()) : nullptr; + args_.dw_data = dW ? reinterpret_cast(dW->template MutableData()) : nullptr; - std::vector kernel_shape; - ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(w_shape, kernel_shape)); - auto rank = kernel_shape.size(); + bool x_dims_changed = (args_.last_x_dims != x_dims); + bool w_dims_changed = (args_.last_w_dims != w_dims); + if (x_dims_changed || w_dims_changed) { + if (x_dims_changed) args_.last_x_dims = x_dims; + if (w_dims_changed) args_.last_w_dims = w_dims; - std::vector pads(conv_attrs_.pads); - if (pads.empty()) { - pads.resize(rank * 2, 0); - } + // Update Attributes + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(&x, &w)); - std::vector dilations(conv_attrs_.dilations); - if (dilations.empty()) { - dilations.resize(rank, 1); - } + std::vector kernel_shape; + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(w_shape, kernel_shape)); + auto rank = kernel_shape.size(); - std::vector strides(conv_attrs_.strides); - if (strides.empty()) { - strides.resize(rank, 1); - } + std::vector pads(conv_attrs_.pads); + if (pads.empty()) { + pads.resize(rank * 2, 0); + } - // cudnn only takes 4D or 5D input, so pad dimensions if needed - if (rank < 2) { - i_dims.push_back(1); - o_dims.push_back(1); - w_dims.push_back(1); + std::vector dilations(conv_attrs_.dilations); + if (dilations.empty()) { + dilations.resize(rank, 1); + } - pads.insert(pads.begin() + rank, 0); - pads.insert(pads.end(), 0); - kernel_shape.push_back(1); - strides.push_back(1); - dilations.push_back(1); - } + std::vector strides(conv_attrs_.strides); + if (strides.empty()) { + strides.resize(rank, 1); + } - args_.handle = CudnnHandle(); - args_.data_type = CudnnTensor::GetDataType(); - ORT_RETURN_IF_ERROR(args_.i_desc.Set(i_dims, args_.data_type)); - ORT_RETURN_IF_ERROR(args_.o_desc.Set(o_dims, args_.data_type)); - ORT_RETURN_IF_ERROR(args_.w_desc.Set(w_dims, args_.data_type)); - ORT_RETURN_IF_ERROR(args_.c_desc.Set(kernel_shape.size(), 
pads, strides, dilations, - gsl::narrow_cast(conv_attrs_.group), - CUDNN_CROSS_CORRELATION, args_.data_type)); + // cuDNN only takes 4D or 5D x tensor, so pad dimensions if needed. + if (rank < 2) { + x_dims.push_back(1); + dy_dims.push_back(1); + w_dims.push_back(1); + pads.insert(pads.begin() + rank, 0); + pads.insert(pads.end(), 0); + kernel_shape.push_back(1); + strides.push_back(1); + dilations.push_back(1); + } - if (bias) { - const TensorShape& b_shape = bias->Shape(); - ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D"); - std::vector b_dims(2 + kernel_shape.size(), 1); - b_dims[1] = b_shape[0]; - ORT_RETURN_IF_ERROR(args_.b_desc.Set(b_dims, args_.data_type)); + const CUDAExecutionProvider* cuda_ep = + static_cast(this->Info().GetExecutionProvider()); + memset(&args_.params, 0, sizeof(ConvParams)); + args_.params.device_id = static_cast(cuda_ep->GetDeviceId()); + args_.params.data_type = CudnnTensor::GetDataType(); + args_.params.input_dim = static_cast(x_dims.size()); + for (size_t i = 0; i < x_dims.size(); i++) { + args_.params.input_size[i] = static_cast(x_dims[i]); + args_.params.weight_size[i] = static_cast(w_dims[i]); + } + for (size_t i = 0; i < rank; i++) { + args_.params.padding[i] = static_cast(pads[i]); + args_.params.padding[i + rank] = static_cast(pads[i + rank]); + args_.params.stride[i] = static_cast(strides[i]); + args_.params.dilation[i] = static_cast(dilations[i]); + } + args_.params.groups = conv_attrs_.group; + int algo_mode = cuda_ep->GetCudnnConvAlgo(); + ORT_ENFORCE(algo_mode > -1 && algo_mode < 3, + "Algo mode should be EXHAUSTIVE (0), HEURISTIC (1) or DEFAULT (2), but got ", algo_mode); + args_.params.algo_mode = algo_mode; + + args_.handle = CudnnHandle(); + ORT_RETURN_IF_ERROR(args_.w_desc.Set(w_dims, args_.params.data_type)); + ORT_RETURN_IF_ERROR(args_.x_tensor.Set(x_dims, args_.params.data_type)); + ORT_RETURN_IF_ERROR(args_.y_tensor.Set(dy_dims, args_.params.data_type)); + 
ORT_RETURN_IF_ERROR(args_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, + gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, + args_.params.data_type)); + + if (dB) { + const TensorShape& db_shape = dB->Shape(); + ORT_RETURN_IF_NOT(db_shape.NumDimensions() == 1, "bias should be 1D"); + std::vector db_dims(2 + kernel_shape.size(), 1); + db_dims[1] = db_shape[0]; + ORT_RETURN_IF_ERROR(args_.b_tensor.Set(db_dims, CudnnTensor::GetDataType())); + } } return Status::OK(); @@ -122,106 +351,56 @@ Status ConvGrad::ComputeInternal(OpKernelContext* context) const { const Tensor* dY = context->Input(0); const Tensor* X = context->Input(1); const Tensor* W = context->Input(2); - - const int64_t M = W->Shape()[0]; - Tensor* dX = context->Output(0, X->Shape()); Tensor* dW = context->Output(1, W->Shape()); - Tensor* dB = context->Output(2, {M}); - - ORT_RETURN_IF_ERROR(PrepareArgs(*X, *dY, *W, dB)); - - ORT_RETURN_IF_ERROR(ComputeWeightGradient(dW, dY, X)); - ORT_RETURN_IF_ERROR(ComputeInputGradient(dX, dY, W)); - ORT_RETURN_IF_ERROR(ComputeBiasGradient(dB, dY)); - + Tensor* dB = context->Output(2, {W->Shape()[0]}); + ORT_RETURN_IF_ERROR(PrepareArgs(*X, *dY, *W, dB, dX, dW)); + if (dX) ORT_RETURN_IF_ERROR(ComputeInputGradient()); + if (dW) ORT_RETURN_IF_ERROR(ComputeWeightGradient()); + if (dB) ORT_RETURN_IF_ERROR(ComputeBiasGradient()); return Status::OK(); } template -Status ConvGrad::ComputeWeightGradient(Tensor* dW, const Tensor* dY, const Tensor* X) const { - if (dW == nullptr) return Status::OK(); - - // TODO: implement the algoritm search - cudnnConvolutionBwdFilterAlgoPerf_t perf; - perf.algo = kDefaultConvBwdFilterAlgo; - if (args_.data_type == CUDNN_DATA_HALF) { - perf.mathType = CUDNN_TENSOR_OP_MATH; - } else { - perf.mathType = CUDNN_DEFAULT_MATH; - } - CUDNN_RETURN_IF_ERROR(getWorkspaceSize(args_, perf.algo, &perf.memory)); - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(args_.c_desc, perf.mathType)); - - void* dw_data = dW->template 
MutableData(); - const void* dy_data = dY->template Data(); - const void* x_data = X->template Data(); - IAllocatorUniquePtr workspace = GetScratchBuffer(perf.memory); - - const auto one = Consts::One; - const auto zero = Consts::Zero; - - CUDNN_RETURN_IF_ERROR( - cudnnConvolutionBackwardFilter( - args_.handle, - &one, args_.i_desc, x_data, - args_.o_desc, dy_data, - args_.c_desc, perf.algo, workspace.get(), perf.memory, - &zero, args_.w_desc, dw_data)); - +Status ConvGrad::ComputeInputGradient() const { + AlgoIterator(args_).TryAll( + static_cast(Info().GetExecutionProvider()), + [&](const T_BwdDataPerf& algo_perf) -> Status { + const auto one = Consts::One; + const auto zero = Consts::Zero; + IAllocatorUniquePtr workspace = GetScratchBuffer(algo_perf.memory); + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(args_.conv_desc, algo_perf.mathType)); + CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardData( + args_.handle, &one, args_.w_desc, args_.w_data, args_.y_tensor, args_.dy_data, args_.conv_desc, + algo_perf.algo, workspace.get(), algo_perf.memory, &zero, args_.x_tensor, args_.dx_data)); + return Status::OK(); + }); return Status::OK(); } template -Status ConvGrad::ComputeInputGradient(Tensor* dX, const Tensor* dY, const Tensor* W) const { - if (dX == nullptr) return Status::OK(); - - // TODO: implement the algoritm search - cudnnConvolutionBwdDataAlgoPerf_t perf; - perf.algo = kDefaultConvBwdDataAlgo; - if (args_.data_type == CUDNN_DATA_HALF) { - perf.mathType = CUDNN_TENSOR_OP_MATH; - } else { - perf.mathType = CUDNN_DEFAULT_MATH; - } - CUDNN_RETURN_IF_ERROR(getWorkspaceSize(args_, perf.algo, &perf.memory)); - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(args_.c_desc, perf.mathType)); - - void* dx_data = dX->template MutableData(); - const void* dy_data = dY->template Data(); - const void* w_data = W->template Data(); - IAllocatorUniquePtr workspace = GetScratchBuffer(perf.memory); - - const auto one = Consts::One; - const auto zero = Consts::Zero; - - 
CUDNN_RETURN_IF_ERROR( - cudnnConvolutionBackwardData( - args_.handle, - &one, args_.w_desc, w_data, - args_.o_desc, dy_data, - args_.c_desc, perf.algo, workspace.get(), perf.memory, - &zero, args_.i_desc, dx_data)); - +Status ConvGrad::ComputeWeightGradient() const { + AlgoIterator(args_).TryAll( + static_cast(Info().GetExecutionProvider()), + [&](const T_BwdFilterPerf& algo_perf) -> Status { + const auto one = Consts::One; + const auto zero = Consts::Zero; + IAllocatorUniquePtr workspace = GetScratchBuffer(algo_perf.memory); + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(args_.conv_desc, algo_perf.mathType)); + CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardFilter( + args_.handle, &one, args_.x_tensor, args_.x_data, args_.y_tensor, args_.dy_data, args_.conv_desc, + algo_perf.algo, workspace.get(), algo_perf.memory, &zero, args_.w_desc, args_.dw_data)); + return Status::OK(); + }); return Status::OK(); } template -Status ConvGrad::ComputeBiasGradient(Tensor* dB, const Tensor* dY) const { - if (dB == nullptr) return Status::OK(); - +Status ConvGrad::ComputeBiasGradient() const { const auto one = Consts::One; const auto zero = Consts::Zero; - - void* db_data = dB->template MutableData(); - const void* dy_data = dY->template Data(); - - CUDNN_RETURN_IF_ERROR( - cudnnConvolutionBackwardBias( - args_.handle, - &one, args_.o_desc, dy_data, - &zero, args_.b_desc, db_data)); - + CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardBias(args_.handle, &one, args_.y_tensor, args_.dy_data, &zero, + args_.b_tensor, args_.db_data)); return Status::OK(); } diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.h b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.h index 79770cb382..c8eaf64e1c 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.h +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.h @@ -3,8 +3,6 @@ #pragma once -#include "core/common/common.h" -#include "core/providers/cuda/cuda_kernel.h" #include 
"core/providers/cuda/cudnn_common.h" #include "core/providers/cpu/nn/conv_attributes.h" #include "core/providers/cuda/nn/conv.h" @@ -12,15 +10,38 @@ namespace onnxruntime { namespace cuda { -struct ConvolutionArgs { - cudnnHandle_t handle; +// cuDNN only takes 4D or 5D x tensor. +constexpr int MAX_DIM = 3; + +struct ConvParams { + int8_t device_id; cudnnDataType_t data_type; + int input_size[2 + MAX_DIM]; + uint8_t input_dim; + int weight_size[2 + MAX_DIM]; + int padding[MAX_DIM * 2]; + int stride[MAX_DIM]; + int dilation[MAX_DIM]; + int64_t groups; + int algo_mode; +}; - CudnnTensor i_desc, o_desc, b_desc; +struct ConvArgs { + // Update needed if x or w's dims changed. + std::vector last_x_dims; + std::vector last_w_dims; + + cudnnHandle_t handle; + ConvParams params; + CudnnTensor x_tensor, y_tensor, b_tensor; CudnnFilterDescriptor w_desc; - CudnnConvolutionDescriptor c_desc; - - ConvolutionArgs() {} + CudnnConvolutionDescriptor conv_desc; + const void* x_data; + const void* w_data; + const void* dy_data; + void* dx_data; + void* dw_data; + void* db_data; }; template @@ -39,19 +60,14 @@ class ConvGrad final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; protected: - mutable ConvolutionArgs args_; - Status PrepareArgs(const Tensor& input, const Tensor& output, const Tensor& weight, const Tensor* bias) const; - + Status PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor& w, Tensor* dB, Tensor* dX, Tensor* dW) const; + mutable ConvArgs args_; ConvAttributes conv_attrs_; - // https://docs.nvidia.com/deeplearning/cudnn/archives/cudnn_742/cudnn-developer-guide/index.html#tensor_ops - static constexpr auto kDefaultConvBwdDataAlgo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; - static constexpr auto kDefaultConvBwdFilterAlgo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; - private: - Status ComputeWeightGradient(Tensor* dW, const Tensor* dY, const Tensor* X) const; - Status ComputeInputGradient(Tensor* dX, const Tensor* dY, const Tensor* 
W) const; - Status ComputeBiasGradient(Tensor* dB, const Tensor* dY) const; + Status ComputeWeightGradient() const; + Status ComputeInputGradient() const; + Status ComputeBiasGradient() const; }; } // namespace cuda