From d5d69246885befb4c52ee9deef8946121834a00b Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Wed, 9 Nov 2022 08:47:06 -0800 Subject: [PATCH] rocblas alt impl during backward pass only (#13352) On AMD Instinct MI200 GPUs, the FP16 and BF16 V_DOT2 and MFMA matrix instructions flush input and output denormal values to zero. When training using FP16 precision, some models may fail to converge with FP16 denorms flushed to zero. The affected instructions are only used by rocBLAS (GEMM) and MIOpen (convolution) kernels; all other onnxruntime operations will not encounter this behavior. All other supported AMD GPUs will not encounter this behavior. rocBLAS and MIOpen provide alternate implementations for affected FP16 operations. Alternate implementations for BF16 operations are not provided; BF16 numbers have a larger dynamic range than FP16 numbers and are less likely to encounter denormal values. For the FP16 alternate implementations, FP16 input values are cast to an intermediate BF16 value and then cast back to FP16 output after the FP32 accumulate operations. In this way, the input and output types are unchanged. Denormal values occur more frequently in the backward pass of training, during gradient calculation. Therefore, it is necessary to track when the backward pass of training is executing. For the ROCm EP only, the `__backwardpass` attribute is added to the YieldOp Node and to every Node that follows it in topological order. This takes place in a level1 graph optimization pass. The attribute is forwarded to any newly created FusedMatMul Nodes. In addition, the scope-based helper class `BackwardPassGuard` is provided to toggle this state for rocBLAS. With this PR, the alternate implementations are used automatically during the backward pass. This default behavior can be overridden using the environment variables ROCBLAS_INTERNAL_FP16_ALT_IMPL and MIOPEN_DEBUG_CONVOLUTION_ATTRIB_FP16_ALT_IMPL. 
The behavior of these environment variables is as follows: | | forward | backward | |--------------|-----------|-----------| | Env unset | original | alternate | | Env set to 1 | alternate | alternate | | Env set to 0 | original | original | See also: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices --- .../core/optimizer/graph_transformer_utils.cc | 5 +++ .../core/optimizer/matmul_scale_fusion.cc | 7 ++++ .../core/optimizer/matmul_transpose_fusion.cc | 7 ++++ .../core/optimizer/rocm_blas_alt_impl.cc | 36 +++++++++++++++++++ .../core/optimizer/rocm_blas_alt_impl.h | 19 ++++++++++ .../core/providers/rocm/backward_guard.cc | 21 +++++++++++ .../core/providers/rocm/backward_guard.h | 15 ++++++++ onnxruntime/core/providers/rocm/rocm_kernel.h | 11 +++++- .../providers/rocm/shared_inc/fpgeneric.h | 23 +++++++++--- 9 files changed, 138 insertions(+), 6 deletions(-) create mode 100644 onnxruntime/core/optimizer/rocm_blas_alt_impl.cc create mode 100644 onnxruntime/core/optimizer/rocm_blas_alt_impl.h create mode 100644 onnxruntime/core/providers/rocm/backward_guard.cc create mode 100644 onnxruntime/core/providers/rocm/backward_guard.h diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 3093d8d43b..58b64511d4 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -10,6 +10,7 @@ #include "core/optimizer/nhwc_transformer.h" #include "core/optimizer/qdq_transformer/qdq_final_cleanup.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" +#include "core/optimizer/rocm_blas_alt_impl.h" #include "core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/optimizer/conv_add_act_fusion.h" @@ -208,6 
+209,10 @@ InlinedVector> GenerateTransformers( // shouldn't affect the end result - just easier to debug any issue if it's last. auto cpu_allocator = cpu_execution_provider.GetAllocator(0, OrtMemTypeDefault); transformers.emplace_back(std::make_unique(std::move(cpu_allocator))); + + // add __backwardpass attribute to nodes after YieldOp, ROCm-only + const InlinedHashSet rocm_ep = {onnxruntime::kRocmExecutionProvider}; + transformers.emplace_back(std::make_unique(rocm_ep)); } break; case TransformerLevel::Level2: { diff --git a/onnxruntime/core/optimizer/matmul_scale_fusion.cc b/onnxruntime/core/optimizer/matmul_scale_fusion.cc index 2c43f5ab12..b944b5536d 100644 --- a/onnxruntime/core/optimizer/matmul_scale_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_scale_fusion.cc @@ -254,6 +254,13 @@ Status ProcessNode( kMSDomain); matmul_scale_node.SetExecutionProviderType(node.GetExecutionProviderType()); +#ifdef USE_ROCM + // forward the __backwardpass, if present + auto& attrs = node.GetAttributes(); + if (attrs.count("__backwardpass")) { + matmul_scale_node.AddAttribute("__backwardpass", static_cast(attrs.at("__backwardpass").i())); + } +#endif { InlinedVector> nodes_to_remove{node}; diff --git a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc index d47538640b..642805c93b 100644 --- a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc @@ -404,6 +404,13 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ matmul_node.AddAttribute("alpha", alpha); // Assign provider to this new node. Provider should be same as the provider for old node. 
matmul_node.SetExecutionProviderType(node.GetExecutionProviderType()); +#ifdef USE_ROCM + // forward the __backwardpass, if present + auto& attrs = node.GetAttributes(); + if (attrs.count("__backwardpass")) { + matmul_node.AddAttribute("__backwardpass", static_cast(attrs.at("__backwardpass").i())); + } +#endif graph_utils::FinalizeNodeFusion(graph, matmul_node, node); diff --git a/onnxruntime/core/optimizer/rocm_blas_alt_impl.cc b/onnxruntime/core/optimizer/rocm_blas_alt_impl.cc new file mode 100644 index 0000000000..decb25f565 --- /dev/null +++ b/onnxruntime/core/optimizer/rocm_blas_alt_impl.cc @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include + +#include "core/optimizer/initializer.h" +#include "core/optimizer/rocm_blas_alt_impl.h" +#include "core/graph/graph_utils.h" + +using namespace ONNX_NAMESPACE; +using namespace ::onnxruntime::common; +namespace onnxruntime { + +Status RocmBlasAltImpl::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + bool is_backward_pass = false; + + for (auto node_index : node_topology_list) { + auto& node = *graph.GetNode(node_index); + + if (node.OpType() == "YieldOp") { + is_backward_pass = true; + } + + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + if (is_backward_pass) { + node.AddAttribute(std::string("__backwardpass"), static_cast(1)); + modified = true; + } + } + + return Status::OK(); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/rocm_blas_alt_impl.h b/onnxruntime/core/optimizer/rocm_blas_alt_impl.h new file mode 100644 index 0000000000..11744d0dac --- /dev/null +++ b/onnxruntime/core/optimizer/rocm_blas_alt_impl.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/optimizer/graph_transformer.h" +#include "core/graph/graph_utils.h" + +namespace onnxruntime { + +class RocmBlasAltImpl : public GraphTransformer { + public: + RocmBlasAltImpl(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("RocmBlasAltImpl", compatible_execution_providers) {} + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/backward_guard.cc b/onnxruntime/core/providers/rocm/backward_guard.cc new file mode 100644 index 0000000000..1695da092b --- /dev/null +++ b/onnxruntime/core/providers/rocm/backward_guard.cc @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "core/providers/rocm/backward_guard.h" + +namespace onnxruntime { + +thread_local bool BackwardPassGuard::is_backward_pass_; + +BackwardPassGuard::BackwardPassGuard() { + is_backward_pass_ = true; +} + +BackwardPassGuard::~BackwardPassGuard() { + is_backward_pass_ = false; +} + +bool BackwardPassGuard::is_backward_pass() { + return is_backward_pass_; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/backward_guard.h b/onnxruntime/core/providers/rocm/backward_guard.h new file mode 100644 index 0000000000..e36785af37 --- /dev/null +++ b/onnxruntime/core/providers/rocm/backward_guard.h @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#pragma once + +namespace onnxruntime { + +struct BackwardPassGuard { + BackwardPassGuard(); + ~BackwardPassGuard(); + static bool is_backward_pass(); +private: + static thread_local bool is_backward_pass_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_kernel.h b/onnxruntime/core/providers/rocm/rocm_kernel.h index 6737ae8582..a29537e9f1 100644 --- a/onnxruntime/core/providers/rocm/rocm_kernel.h +++ b/onnxruntime/core/providers/rocm/rocm_kernel.h @@ -3,6 +3,7 @@ #pragma once +#include "core/providers/rocm/backward_guard.h" #include "core/providers/rocm/rocm_common.h" #include "core/providers/rocm/rocm_execution_provider.h" #include "core/providers/rocm/rocm_fwd.h" @@ -22,7 +23,15 @@ class RocmKernel : public OpKernel { } Status Compute(OpKernelContext* p_op_kernel_context) const override { - auto s = ComputeInternal(p_op_kernel_context); + Status s; + auto is_backward_pass = Info().GetAttrOrDefault("__backwardpass", 0); + if (is_backward_pass) { + BackwardPassGuard guard; + s = ComputeInternal(p_op_kernel_context); + } + else { + s = ComputeInternal(p_op_kernel_context); + } // use this to precisely locate the node where ROCM failure comes from // if (hipSuccess != hipDeviceSynchronize()) // __debugbreak(); diff --git a/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h b/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h index 030275c52d..657a11ccb8 100644 --- a/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h +++ b/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h @@ -3,10 +3,23 @@ #pragma once +#include "core/providers/rocm/backward_guard.h" #include "core/providers/rocm/rocm_common.h" +#define ORT_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) +#if ORT_ROCBLAS_VERSION_DECIMAL >= 242 +#define FLAG rocblas_gemm_flags_fp16_alt_impl +#else +#define FLAG 0 +#endif + using namespace onnxruntime; +inline int get_flag() { + int result = BackwardPassGuard::is_backward_pass() ? 
FLAG : 0; + return result; +} + // Generalize library calls to be use in template functions // gemm @@ -67,7 +80,7 @@ inline rocblas_status rocblasGemmHelper(rocblas_handle handle, C, rocblas_datatype_f16_r, ldc, C, rocblas_datatype_f16_r, ldc, rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, 0); + rocblas_gemm_algo_standard, 0, get_flag()); } inline rocblas_status rocblasGemmHelper(rocblas_handle handle, @@ -90,7 +103,7 @@ inline rocblas_status rocblasGemmHelper(rocblas_handle handle, C, rocblas_datatype_f16_r, ldc, C, rocblas_datatype_f16_r, ldc, rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, 0); + rocblas_gemm_algo_standard, 0, get_flag()); } inline rocblas_status rocblasGemmHelper(rocblas_handle handle, @@ -225,7 +238,7 @@ inline rocblas_status rocblasGemmBatchedHelper(rocblas_handle handle, (void**)Carray, rocblas_datatype_f16_r, ldc, batchCount, rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, 0); + rocblas_gemm_algo_standard, 0, get_flag()); } inline rocblas_status rocblasGemmBatchedHelper(rocblas_handle handle, @@ -329,7 +342,7 @@ inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, C, rocblas_datatype_f16_r, ldc, strideC, batchCount, rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, 0); + rocblas_gemm_algo_standard, 0, get_flag()); } inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, @@ -357,7 +370,7 @@ inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, C, rocblas_datatype_f16_r, ldc, strideC, batchCount, rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, 0); + rocblas_gemm_algo_standard, 0, get_flag()); } inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle,