diff --git a/onnxruntime/core/providers/rocm/miopen_common.h b/onnxruntime/core/providers/rocm/miopen_common.h index 30af360881..8d73ce3f98 100644 --- a/onnxruntime/core/providers/rocm/miopen_common.h +++ b/onnxruntime/core/providers/rocm/miopen_common.h @@ -15,6 +15,8 @@ namespace onnxruntime { namespace rocm { #define MIOPEN_CONVOLUTION_FWD_ALGO_COUNT 6 +#define MIOPEN_CONVOLUTION_BWD_FILTER_ALGO_COUNT 4 +#define MIOPEN_CONVOLUTION_BWD_DATA_ALGO_COUNT 6 class MiopenTensor final { public: diff --git a/orttraining/orttraining/test/training_ops/cuda/conv_grad_test.cc b/orttraining/orttraining/test/training_ops/cuda/conv_grad_test.cc index fed8d1a0ab..6ed7cb5a81 100644 --- a/orttraining/orttraining/test/training_ops/cuda/conv_grad_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/conv_grad_test.cc @@ -5,13 +5,13 @@ #include "test/providers/provider_test_utils.h" namespace onnxruntime { -namespace cuda { +namespace contrib { namespace test { using namespace std; using namespace onnxruntime::test; -#ifdef USE_CUDA +#if USE_CUDA || USE_ROCM namespace { struct ConvGradOpAttributes { @@ -315,8 +315,8 @@ TEST(ConvTest, Conv3D_Bias) { TestConvGradOp(attrs, {dY, X, W}, {dY_shape, X_shape, W_shape}, {dX, dW, dB}, {dX_shape, dW_shape, dB_shape}); TestConvGradOp(attrs, {dY, X, W}, {dY_shape, X_shape, W_shape}, {dX, dW, dB}, {dX_shape, dW_shape, dB_shape}, true); } -#endif // USE_CUDA +#endif // USE_CUDA || USE_ROCM } // namespace test -} // namespace cuda +} // namespace contrib } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/rocm/nn/conv_grad.cc b/orttraining/orttraining/training_ops/rocm/nn/conv_grad.cc new file mode 100644 index 0000000000..2e002c6afa --- /dev/null +++ b/orttraining/orttraining/training_ops/rocm/nn/conv_grad.cc @@ -0,0 +1,384 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//TODO Add exhaustive and default cases for algo. 
+
+#include "orttraining/training_ops/rocm/nn/conv_grad.h"
+
+#include <cstdint>
+#include <cstring>
+#include <functional>
+#include <mutex>
+#include <type_traits>
+#include <unordered_map>
+#include <vector>
+
+#include "core/providers/common.h"
+#include "core/providers/rocm/shared_inc/fpgeneric.h"
+#include "core/platform/ort_mutex.h"
+
+namespace onnxruntime {
+namespace rocm {
+
+// Registers the ROCm ConvGrad kernel (kMSDomain, version 1) for element type T.
+#define REGISTER_GRADIENT_KERNEL_TYPED(T)                                                                            \
+  ONNX_OPERATOR_TYPED_KERNEL_EX(ConvGrad, kMSDomain, 1, T, kRocmExecutionProvider,                                   \
+                                (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
+                                ConvGrad<T>);
+
+REGISTER_GRADIENT_KERNEL_TYPED(float)
+// MIOpen double support not currently implemented.
+//REGISTER_GRADIENT_KERNEL_TYPED(double)
+REGISTER_GRADIENT_KERNEL_TYPED(MLFloat16)
+
+// MIOpen reports backward-data and backward-weights results through the same perf struct;
+// only the algorithm enum differs, so the enum type is what selects an AlgoSearch below.
+using T_BwdDataPerf = miopenConvAlgoPerf_t;
+using T_BwdDataAlgo = miopenConvBwdDataAlgorithm_t;
+using T_BwdFilterPerf = miopenConvAlgoPerf_t;
+using T_BwdFilterAlgo = miopenConvBwdWeightsAlgorithm_t;
+
+// Queries the workspace needed for the backward-data pass.
+// The algorithm argument only selects the overload; the MIOpen query does not take it.
+// Descriptor order is (dy, w, conv, dx), matching the miopenFindConvolutionBackwardDataAlgorithm
+// call in AlgoSearch<T_BwdDataAlgo>::FindAlgorithms.
+miopenStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdDataAlgo /*algo*/, size_t* workspace_size) {
+  return miopenConvolutionBackwardDataGetWorkSpaceSize(args.handle, args.y_tensor, args.w_desc, args.conv_desc,
+                                                       args.x_tensor, workspace_size);
+}
+
+// Queries the workspace needed for the backward-weights pass.
+// Descriptor order is (dy, x, conv, dw); the algorithm argument only selects the overload.
+miopenStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdFilterAlgo /*algo*/, size_t* workspace_size) {
+  return miopenConvolutionBackwardWeightsGetWorkSpaceSize(args.handle, args.y_tensor, args.x_tensor, args.conv_desc,
+                                                          args.w_desc, workspace_size);
+}
+
+// Returns the largest workspace size reported for the given algorithms that still fits into
+// 90% of the currently free device memory, or 0 when no query yields a usable size.
+// Throws (HIP_CALL_THROW) if the free-memory query fails.
+template <typename T_Algo>
+size_t GetMaxWorkspaceSize(const ConvArgs& args, const T_Algo* algo, int n_algo) {
+  // Calling hipMemGetInfo is not ideal, but our rocm allocator doesn't have a way to get this info.
+  size_t free = 0;
+  size_t total = 0;
+  HIP_CALL_THROW(hipMemGetInfo(&free, &total));
+  // Assuming 10% of fragmentation.
+  free = static_cast<size_t>(static_cast<double>(free) * 0.9);
+  size_t max_workspace_size = 0;
+  for (int i = 0; i < n_algo; i++) {
+    size_t workspace_size = 0;
+    const miopenStatus_t status = GetWorkspaceSize(args, algo[i], &workspace_size);
+    if (miopenStatusSuccess != status || workspace_size == 0 || workspace_size < max_workspace_size ||
+        workspace_size > free)
+      continue;
+    max_workspace_size = workspace_size;
+  }
+
+  return max_workspace_size;
+}
+
+// Copies the first n_algo perf results returned by MIOpen into a vector.
+// Throws (ORT_ENFORCE) when there is nothing to copy.
+template <typename T_Perf>
+std::vector<T_Perf> GetValidAlgorithms(const T_Perf* perf_results, int n_algo) {
+  std::vector<T_Perf> result;
+  if (n_algo > 0) {
+    result.assign(perf_results, perf_results + n_algo);
+  }
+  ORT_ENFORCE(result.size() > 0, "No valid convolution algorithms available in MIOpen");
+  return result;
+}
+
+// Hash functor for ConvParams: 32-bit FNV-1a over the raw bytes of the struct.
+struct ConvParamsHash {
+  // ConvParams must be a POD because we read out its memory constant as char* when hashing.
+  static_assert(std::is_pod<ConvParams>::value, "ConvParams is not POD");
+  size_t operator()(const ConvParams& conv_params) const {
+    auto ptr = reinterpret_cast<const uint8_t*>(&conv_params);
+    uint32_t value = 0x811C9DC5;
+    for (size_t i = 0; i < sizeof(ConvParams); ++i) {
+      value ^= ptr[i];
+      value *= 0x01000193;
+    }
+    return static_cast<size_t>(value);
+  }
+};
+
+// Equality functor for ConvParams: bytewise comparison, consistent with ConvParamsHash.
+struct ConvParamsEqual {
+  // ConvParams must be a POD because we compare its memory bytewise.
+  static_assert(std::is_pod<ConvParams>::value, "ConvParams is not POD");
+  bool operator()(const ConvParams& a, const ConvParams& b) const {
+    return memcmp(&a, &b, sizeof(ConvParams)) == 0;
+  }
+};
+
+// Mutex-protected map from a convolution configuration to the perf result chosen for it.
+template <typename T_Perf>
+struct AlgoPerfCache {
+  mutable OrtMutex mutex;
+  std::unordered_map<ConvParams, T_Perf, ConvParamsHash, ConvParamsEqual> map;
+
+  // Copies the cached entry for params into *result; returns false when there is none.
+  bool Find(const ConvParams& params, T_Perf* result) {
+    std::lock_guard<OrtMutex> guard(mutex);
+    auto it = map.find(params);
+    if (it == map.end()) {
+      return false;
+    }
+    *result = it->second;
+    return true;
+  }
+
+  // Inserts or overwrites the entry for params.
+  void Insert(const ConvParams& params, const T_Perf& algo_perf) {
+    std::lock_guard<OrtMutex> guard(mutex);
+    map[params] = algo_perf;
+  }
+};
+
+// TODO: Currently we use global AlgoPerfCache for ConvGrad only. Conv's perf cache is still per node.
+// Need to apply such global cache for Conv, and move some shared code from here to conv.h/cc.
+AlgoPerfCache<T_BwdDataPerf> bwd_data_algos;
+AlgoPerfCache<T_BwdFilterPerf> bwd_filter_algos;
+
+// Per-direction search policy, specialized on the MIOpen algorithm enum type.
+template <typename T_Algo>
+struct AlgoSearch {};
+
+template <>
+struct AlgoSearch<T_BwdDataAlgo> {
+  static constexpr auto DEFAULT_ALGO = miopenConvolutionBwdDataAlgoGEMM;
+  static AlgoPerfCache<T_BwdDataPerf>& Cache() { return bwd_data_algos; }
+
+  // Runs MIOpen's backward-data algorithm search and stores the returned results in perf_results.
+  static Status FindAlgorithms(const ConvArgs& args, const ROCMExecutionProvider* provider,
+                               std::vector<T_BwdDataPerf>& perf_results) {
+    static const T_BwdDataAlgo algos[] = {
+        miopenConvolutionBwdDataAlgoGEMM,
+        miopenConvolutionBwdDataAlgoDirect,
+        miopenConvolutionBwdDataAlgoFFT,
+        miopenConvolutionBwdDataAlgoWinograd,
+        miopenTransposeBwdDataAlgoGEMM,
+        miopenConvolutionBwdDataAlgoImplicitGEMM};
+    static constexpr int num_algos = MIOPEN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
+    static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos,
+                  "Missing MIOpen convolution backward data algorithms.");
+    int perf_count = 0;
+    std::vector<T_BwdDataPerf> candidates(num_algos);
+    size_t max_workspace_size = provider->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos)
+                                                                         : AlgoSearchWorkspaceSize;
+    // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached.
+    // Because the benchmarking uses a huge amount of memory, e.g. a few GBs.
+    IAllocatorUniquePtr<void> workspace = provider->GetTransientScratchBuffer<void>(max_workspace_size);
+    MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionBackwardDataAlgorithm(
+        args.handle, args.y_tensor, args.dy_data, args.w_desc, args.w_data, args.conv_desc, args.x_tensor,
+        args.dx_data, 1, &perf_count, candidates.data(), workspace.get(), max_workspace_size, false));
+    perf_results = GetValidAlgorithms<T_BwdDataPerf>(candidates.data(), perf_count);
+    return Status::OK();
+  }
+};
+
+template <>
+struct AlgoSearch<T_BwdFilterAlgo> {
+  static constexpr auto DEFAULT_ALGO = miopenConvolutionBwdWeightsAlgoGEMM;
+  static AlgoPerfCache<T_BwdFilterPerf>& Cache() { return bwd_filter_algos; }
+
+  // Runs MIOpen's backward-weights algorithm search and stores the returned results in perf_results.
+  static Status FindAlgorithms(const ConvArgs& args, const ROCMExecutionProvider* provider,
+                               std::vector<T_BwdFilterPerf>& perf_results) {
+    static const T_BwdFilterAlgo algos[] = {
+        miopenConvolutionBwdWeightsAlgoGEMM,
+        miopenConvolutionBwdWeightsAlgoDirect,
+        miopenConvolutionBwdWeightsAlgoWinograd,
+        miopenConvolutionBwdWeightsAlgoImplicitGEMM};
+
+    static constexpr int num_algos = MIOPEN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
+    static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos,
+                  "Missing MIOpen convolution backward filter algorithms.");
+    std::vector<T_BwdFilterPerf> candidates(num_algos);
+    int perf_count = 0;
+    size_t max_workspace_size = provider->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos)
+                                                                         : AlgoSearchWorkspaceSize;
+    // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached.
+    // Because the benchmarking uses a huge amount of memory, e.g. a few GBs.
+    IAllocatorUniquePtr<void> workspace = provider->GetTransientScratchBuffer<void>(max_workspace_size);
+    MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionBackwardWeightsAlgorithm(
+        args.handle, args.y_tensor, args.dy_data, args.x_tensor, args.x_data, args.conv_desc, args.w_desc,
+        args.dw_data, 1, &perf_count, candidates.data(), workspace.get(), max_workspace_size, false));
+    perf_results = GetValidAlgorithms<T_BwdFilterPerf>(candidates.data(), perf_count);
+    return Status::OK();
+  }
+};
+
+// Tries candidate algorithms for one gradient direction until one of them runs successfully.
+// Holds a reference to the caller's ConvArgs, so it must not outlive them.
+template <typename T_Algo>
+class AlgoIterator {
+ public:
+  explicit AlgoIterator(const ConvArgs& args) : args_(args) {}
+
+  // Fills perf_results with the single default algorithm and its workspace size.
+  Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<miopenConvAlgoPerf_t>& perf_results);
+
+  // Invokes f with the cached algorithm for args_.params if there is one; when that is missing
+  // or fails, runs the search and invokes f on each result in order, caching the first success.
+  // Throws (ORT_ENFORCE) when no candidate succeeds.
+  Status TryAll(const ROCMExecutionProvider* provider, std::function<Status(const miopenConvAlgoPerf_t& perf)> f) {
+    auto& cache = AlgoSearch<T_Algo>::Cache();
+    miopenConvAlgoPerf_t cached_perf;
+    if (cache.Find(args_.params, &cached_perf) && f(cached_perf).IsOK()) {
+      return Status::OK();
+    }
+
+    std::vector<miopenConvAlgoPerf_t> perf_results;
+    ORT_RETURN_IF_ERROR(AlgoSearch<T_Algo>::FindAlgorithms(args_, provider, perf_results));
+    for (const auto& candidate : perf_results) {
+      if (f(candidate).IsOK()) {
+        cache.Insert(args_.params, candidate);
+        return Status::OK();
+      }
+    }
+    ORT_ENFORCE(false, "Unable to find a valid MIOpen algorithm to run convolution.");
+    // Not reached; keeps every control path returning a value.
+    return Status::OK();
+  }
+
+ private:
+  const ConvArgs& args_;
+};
+
+template <>
+Status AlgoIterator<T_BwdDataAlgo>::OnlyDefaultAlgorithm(const ConvArgs& args,
+                                                         std::vector<T_BwdDataPerf>& perf_results) {
+  perf_results.resize(1);
+  perf_results[0].bwd_data_algo = AlgoSearch<T_BwdDataAlgo>::DEFAULT_ALGO;
+  MIOPEN_RETURN_IF_ERROR(GetWorkspaceSize(args, perf_results[0].bwd_data_algo, &(perf_results[0].memory)));
+  return Status::OK();
+}
+
+template <>
+Status AlgoIterator<T_BwdFilterAlgo>::OnlyDefaultAlgorithm(const ConvArgs& args,
+                                                           std::vector<T_BwdFilterPerf>& perf_results) {
+  perf_results.resize(1);
+  perf_results[0].bwd_weights_algo = AlgoSearch<T_BwdFilterAlgo>::DEFAULT_ALGO;
+  MIOPEN_RETURN_IF_ERROR(GetWorkspaceSize(args, perf_results[0].bwd_weights_algo, &(perf_results[0].memory)));
+  return Status::OK();
+}
+
+// Refreshes args_ for the current call: the data pointers always, and the cache key plus the
+// MIOpen descriptors whenever the shape of x or w differs from the previous successful setup.
+// dB, dX and dW may be null when the corresponding output is not produced.
+template <typename T>
+Status ConvGrad<T>::PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor& w, Tensor* dB, Tensor* dX,
+                                Tensor* dW) const {
+  const TensorShape& x_shape = x.Shape();
+  std::vector<int64_t> x_dims = x_shape.GetDims();
+  args_.x_data = reinterpret_cast<const HipT*>(x.template Data<T>());
+
+  const TensorShape& dy_shape = dY.Shape();
+  std::vector<int64_t> dy_dims = dy_shape.GetDims();
+  args_.dy_data = reinterpret_cast<const HipT*>(dY.template Data<T>());
+
+  const TensorShape& w_shape = w.Shape();
+  std::vector<int64_t> w_dims = w_shape.GetDims();
+  args_.w_data = reinterpret_cast<const HipT*>(w.template Data<T>());
+
+  args_.db_data = dB ? reinterpret_cast<HipT*>(dB->template MutableData<T>()) : nullptr;
+  args_.dx_data = dX ? reinterpret_cast<HipT*>(dX->template MutableData<T>()) : nullptr;
+  args_.dw_data = dW ? reinterpret_cast<HipT*>(dW->template MutableData<T>()) : nullptr;
+
+  if (args_.last_x_dims == x_dims && args_.last_w_dims == w_dims) {
+    return Status::OK();
+  }
+
+  // Forget the previous shapes until every descriptor below has been rebuilt. If any step fails,
+  // a later call with the same shapes then redoes the setup instead of reusing stale state.
+  args_.last_x_dims.clear();
+  args_.last_w_dims.clear();
+
+  // Update Attributes
+  ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(&x, &w));
+
+  std::vector<int64_t> kernel_shape;
+  ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(w_shape, kernel_shape));
+  const size_t rank = kernel_shape.size();
+
+  std::vector<int64_t> pads(conv_attrs_.pads);
+  if (pads.empty()) {
+    pads.resize(rank * 2, 0);
+  }
+
+  std::vector<int64_t> dilations(conv_attrs_.dilations);
+  if (dilations.empty()) {
+    dilations.resize(rank, 1);
+  }
+
+  std::vector<int64_t> strides(conv_attrs_.strides);
+  if (strides.empty()) {
+    strides.resize(rank, 1);
+  }
+
+  // The loops below index these vectors by spatial axis, so their lengths must agree with rank.
+  ORT_RETURN_IF_NOT(pads.size() == rank * 2 && strides.size() == rank && dilations.size() == rank,
+                    "pads, strides and dilations must match the kernel rank");
+
+  // MIOpen only takes 4D or 5D x tensor, so pad dimensions if needed.
+  if (rank < 2) {
+    x_dims.push_back(1);
+    dy_dims.push_back(1);
+    w_dims.push_back(1);
+    pads.insert(pads.begin() + rank, 0);
+    pads.insert(pads.end(), 0);
+    kernel_shape.push_back(1);
+    strides.push_back(1);
+    dilations.push_back(1);
+  }
+
+  // Rank after the optional padding above; this is the layout the descriptors are built with,
+  // so the cache key is filled with the same layout.
+  const size_t spatial_rank = kernel_shape.size();
+
+  // ConvParams stores shapes in fixed-size arrays; reject anything that would not fit.
+  ORT_RETURN_IF_NOT(spatial_rank <= static_cast<size_t>(MAX_DIM) &&
+                        x_dims.size() <= static_cast<size_t>(2 + MAX_DIM) && w_dims.size() == x_dims.size(),
+                    "ConvGrad supports at most ", MAX_DIM, " spatial dimensions");
+
+  const ROCMExecutionProvider* rocm_ep =
+      static_cast<const ROCMExecutionProvider*>(this->Info().GetExecutionProvider());
+  // Zero everything first: the key is hashed and compared bytewise, padding bytes included.
+  memset(&args_.params, 0, sizeof(ConvParams));
+  args_.params.device_id = static_cast<int8_t>(rocm_ep->GetDeviceId());
+  args_.params.data_type = MiopenTensor::GetDataType<HipT>();
+  args_.params.input_dim = static_cast<uint8_t>(x_dims.size());
+  for (size_t i = 0; i < x_dims.size(); i++) {
+    args_.params.input_size[i] = static_cast<int>(x_dims[i]);
+    args_.params.weight_size[i] = static_cast<int>(w_dims[i]);
+  }
+  for (size_t i = 0; i < spatial_rank; i++) {
+    args_.params.padding[i] = static_cast<int>(pads[i]);
+    args_.params.padding[i + spatial_rank] = static_cast<int>(pads[i + spatial_rank]);
+    args_.params.stride[i] = static_cast<int>(strides[i]);
+    args_.params.dilation[i] = static_cast<int>(dilations[i]);
+  }
+  args_.params.groups = conv_attrs_.group;
+  args_.handle = MiopenHandle();
+  ORT_RETURN_IF_ERROR(args_.w_desc.Set(w_dims, args_.params.data_type));
+  ORT_RETURN_IF_ERROR(args_.x_tensor.Set(x_dims, args_.params.data_type));
+  ORT_RETURN_IF_ERROR(args_.y_tensor.Set(dy_dims, args_.params.data_type));
+  ORT_RETURN_IF_ERROR(args_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
+                                          gsl::narrow_cast<int>(conv_attrs_.group), miopenConvolution,
+                                          args_.params.data_type));
+
+  if (dB) {
+    const TensorShape& db_shape = dB->Shape();
+    ORT_RETURN_IF_NOT(db_shape.NumDimensions() == 1, "bias should be 1D");
+    // Bias descriptor has the same number of dims as x, with only the channel axis non-unit.
+    std::vector<int64_t> db_dims(2 + kernel_shape.size(), 1);
+    db_dims[1] = db_shape[0];
+    ORT_RETURN_IF_ERROR(args_.b_tensor.Set(db_dims, MiopenTensor::GetDataType<HipT>()));
+  }
+
+  // Everything was rebuilt successfully: remember the (unpadded) shapes this state belongs to.
+  args_.last_x_dims = x_shape.GetDims();
+  args_.last_w_dims = w_shape.GetDims();
+
+  return Status::OK();
+}
+
+// Inputs: 0 = dY, 1 = X, 2 = W. Outputs: 0 = dX, 1 = dW, 2 = dB.
+// A gradient is computed only when its output tensor is non-null.
+template <typename T>
+Status ConvGrad<T>::ComputeInternal(OpKernelContext* context) const {
+  const Tensor* dY = context->Input<Tensor>(0);
+  const Tensor* X = context->Input<Tensor>(1);
+  const Tensor* W = context->Input<Tensor>(2);
+  Tensor* dX = context->Output(0, X->Shape());
+  Tensor* dW = context->Output(1, W->Shape());
+  Tensor* dB = context->Output(2, {W->Shape()[0]});
+  ORT_RETURN_IF_ERROR(PrepareArgs(*X, *dY, *W, dB, dX, dW));
+  if (dX) ORT_RETURN_IF_ERROR(ComputeInputGradient());
+  if (dW) ORT_RETURN_IF_ERROR(ComputeWeightGradient());
+  if (dB) ORT_RETURN_IF_ERROR(ComputeBiasGradient());
+  return Status::OK();
+}
+
+// Computes dX with the first backward-data algorithm that runs successfully.
+template <typename T>
+Status ConvGrad<T>::ComputeInputGradient() const {
+  return AlgoIterator<T_BwdDataAlgo>(args_).TryAll(
+      static_cast<const ROCMExecutionProvider*>(Info().GetExecutionProvider()),
+      [&](const T_BwdDataPerf& algo_perf) -> Status {
+        const auto one = Consts<HipT>::One;
+        const auto zero = Consts<HipT>::Zero;
+        IAllocatorUniquePtr<void> workspace = GetScratchBuffer<void>(algo_perf.memory);
+        MIOPEN_RETURN_IF_ERROR(miopenConvolutionBackwardData(
+            args_.handle, &one, args_.y_tensor, args_.dy_data, args_.w_desc, args_.w_data, args_.conv_desc,
+            algo_perf.bwd_data_algo, &zero, args_.x_tensor, args_.dx_data, workspace.get(), algo_perf.memory));
+        return Status::OK();
+      });
+}
+
+// Computes dW with the first backward-weights algorithm that runs successfully.
+template <typename T>
+Status ConvGrad<T>::ComputeWeightGradient() const {
+  return AlgoIterator<T_BwdFilterAlgo>(args_).TryAll(
+      static_cast<const ROCMExecutionProvider*>(Info().GetExecutionProvider()),
+      [&](const T_BwdFilterPerf& algo_perf) -> Status {
+        const auto one = Consts<HipT>::One;
+        const auto zero = Consts<HipT>::Zero;
+        IAllocatorUniquePtr<void> workspace = GetScratchBuffer<void>(algo_perf.memory);
+        MIOPEN_RETURN_IF_ERROR(miopenConvolutionBackwardWeights(
+            args_.handle, &one, args_.y_tensor, args_.dy_data, args_.x_tensor, args_.x_data, args_.conv_desc,
+            algo_perf.bwd_weights_algo, &zero, args_.w_desc, args_.dw_data, workspace.get(), algo_perf.memory));
+        return Status::OK();
+      });
+}
+
+// Computes dB from dY via miopenConvolutionBackwardBias.
+template <typename T>
+Status ConvGrad<T>::ComputeBiasGradient() const {
+  const auto one = Consts<HipT>::One;
+  const auto zero = Consts<HipT>::Zero;
+  MIOPEN_RETURN_IF_ERROR(miopenConvolutionBackwardBias(
+      args_.handle, &one, args_.y_tensor, args_.dy_data, &zero,
+      args_.b_tensor, args_.db_data));
+  return Status::OK();
+}
+
+}  // namespace rocm
+}  // namespace onnxruntime
diff --git
a/orttraining/orttraining/training_ops/rocm/nn/conv_grad.h b/orttraining/orttraining/training_ops/rocm/nn/conv_grad.h
new file mode 100644
index 0000000000..97428eab42
--- /dev/null
+++ b/orttraining/orttraining/training_ops/rocm/nn/conv_grad.h
@@ -0,0 +1,82 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <cstdint>
+#include <vector>
+
+#include "core/providers/rocm/miopen_common.h"
+#include "core/providers/cpu/nn/conv_attributes.h"
+#include "core/providers/rocm/nn/conv.h"
+
+namespace onnxruntime {
+namespace rocm {
+
+// Upper bound on the number of spatial dimensions stored in ConvParams.
+constexpr int MAX_DIM = 3;
+
+// Plain-data description of one convolution configuration. conv_grad.cc uses it as the
+// key of its algorithm caches and hashes/compares it as raw bytes, so it must remain a
+// POD and be zero-filled before its fields are assigned (padding bytes are hashed too).
+struct ConvParams {
+  int8_t device_id;
+  miopenDataType_t data_type;
+  int input_size[2 + MAX_DIM];   // dims of X
+  uint8_t input_dim;             // number of dims of X
+  int weight_size[2 + MAX_DIM];  // dims of W
+  int padding[MAX_DIM * 2];
+  int stride[MAX_DIM];
+  int dilation[MAX_DIM];
+  int64_t groups;
+};
+
+// Per-kernel state shared by the gradient computations: cached shapes, MIOpen
+// descriptors and the raw data pointers of the current call.
+struct ConvArgs {
+  // Update needed if x or w's dims changed.
+  std::vector<int64_t> last_x_dims;
+  std::vector<int64_t> last_w_dims;
+
+  miopenHandle_t handle{};
+  ConvParams params{};
+  MiopenTensor x_tensor, y_tensor, b_tensor;
+  MiopenTensorDescriptor w_desc;
+  MiopenConvolutionDescriptor conv_desc;
+  // Borrowed pointers into the tensors of the current call; never owned here.
+  const void* x_data{nullptr};
+  const void* w_data{nullptr};
+  const void* dy_data{nullptr};
+  void* dx_data{nullptr};
+  void* dw_data{nullptr};
+  void* db_data{nullptr};
+};
+
+// ROCm kernel for ConvGrad; the member functions are defined in conv_grad.cc.
+template <typename T>
+class ConvGrad final : public RocmKernel {
+ public:
+  using HipT = typename ToHipType<T>::MappedType;
+
+  explicit ConvGrad(const OpKernelInfo& info) : RocmKernel(info), conv_attrs_(info) {
+    // The pads attribute must hold an even number of entries.
+    const auto pads_size = conv_attrs_.pads.size();
+    ORT_ENFORCE(pads_size % 2 == 0, "ConvGrad: 'pads' must contain an even number of values.");
+  }
+
+  Status ComputeInternal(OpKernelContext* context) const override;
+
+ protected:
+  Status PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor& w, Tensor* dB, Tensor* dX, Tensor* dW) const;
+  // Mutable because it is written by the const member PrepareArgs().
+  mutable ConvArgs args_;
+  ConvAttributes conv_attrs_;
+
+ private:
+  Status ComputeWeightGradient() const;
+  Status ComputeInputGradient() const;
+  Status ComputeBiasGradient() const;
+};
+
+}  // namespace rocm
+}  // namespace onnxruntime
diff --git
a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc index 0fcb36f161..322a35c566 100644 --- a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc +++ b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc @@ -75,7 +75,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_float, BatchNormalizationGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_float, BatchNormalizationGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ConvGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ConvGrad); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ConvGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ConvGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GatherGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, DropoutGrad); @@ -262,9 +262,9 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo,