mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
rocblas alt impl during backward pass only (#13352)
On AMD Instinct MI200 GPUs, the FP16 and BF16 V_DOT2 and MFMA matrix instructions flush input and output denormal values to zero. When training using FP16 precision, some models may fail to converge with FP16 denorms flushed to zero. The affected instructions are only used by rocBLAS (GEMM) and MIOpen (convolution) kernels; all other onnxruntime operations will not encounter this behavior. All other supported AMD GPUs will not encounter this behavior. rocBLAS and MIOpen provide alternate implementations for affected FP16 operations. Alternate implementations for BF16 operations are not provided; BF16 numbers have a larger dynamic range than FP16 numbers and are less likely to encounter denormal values. For the FP16 alternate implementations, FP16 input values are cast to an intermediate BF16 value and then cast back to FP16 output after the accumulate FP32 operations. In this way, the input and output types are unchanged. Denormal values more frequently occur in the backward pass of training during gradient calculation. Therefore, it is necessary to track when the backward pass of training is executing. For the ROCm EP only, the `__backwardpass` attribute is added to all Nodes after the YieldOp is detected. This takes place in a level1 graph optimization pass. The attribute is forwarded to any newly created FusedMatMul Nodes. In addition, the scope-based helper class `BackwardPassGuard` is provided to toggle state for rocblas. This behavior of using the alternate implementations during the backward pass is made automatic with this PR. This default behavior can be overridden using environment variables, ROCBLAS_INTERNAL_FP16_ALT_IMPL and MIOPEN_DEBUG_CONVOLUTION_ATTRIB_FP16_ALT_IMPL. The behavior of these environment variables is as follows: | | forward | backward | |--------------|-----------|-----------| | Env unset | original | alternate | | Env set to 1 | alternate | alternate | | Env set to 0 | original | original | See also: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
This commit is contained in:
parent
d10d66cc84
commit
d5d6924688
9 changed files with 138 additions and 6 deletions
|
|
@ -10,6 +10,7 @@
|
|||
#include "core/optimizer/nhwc_transformer.h"
|
||||
#include "core/optimizer/qdq_transformer/qdq_final_cleanup.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
|
||||
#include "core/optimizer/rocm_blas_alt_impl.h"
|
||||
#include "core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h"
|
||||
#include "core/session/onnxruntime_session_options_config_keys.h"
|
||||
#include "core/optimizer/conv_add_act_fusion.h"
|
||||
|
|
@ -208,6 +209,10 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
// shouldn't affect the end result - just easier to debug any issue if it's last.
|
||||
auto cpu_allocator = cpu_execution_provider.GetAllocator(0, OrtMemTypeDefault);
|
||||
transformers.emplace_back(std::make_unique<TransposeOptimizer>(std::move(cpu_allocator)));
|
||||
|
||||
// add __backwardpass attribute to nodes after YieldOp, ROCm-only
|
||||
const InlinedHashSet<std::string_view> rocm_ep = {onnxruntime::kRocmExecutionProvider};
|
||||
transformers.emplace_back(std::make_unique<RocmBlasAltImpl>(rocm_ep));
|
||||
} break;
|
||||
|
||||
case TransformerLevel::Level2: {
|
||||
|
|
|
|||
|
|
@ -254,6 +254,13 @@ Status ProcessNode(
|
|||
kMSDomain);
|
||||
|
||||
matmul_scale_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
#ifdef USE_ROCM
|
||||
// forward the __backwardpass, if present
|
||||
auto& attrs = node.GetAttributes();
|
||||
if (attrs.count("__backwardpass")) {
|
||||
matmul_scale_node.AddAttribute("__backwardpass", static_cast<int64_t>(attrs.at("__backwardpass").i()));
|
||||
}
|
||||
#endif
|
||||
|
||||
{
|
||||
InlinedVector<std::reference_wrapper<Node>> nodes_to_remove{node};
|
||||
|
|
|
|||
|
|
@ -404,6 +404,13 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
|
|||
matmul_node.AddAttribute("alpha", alpha);
|
||||
// Assign provider to this new node. Provider should be same as the provider for old node.
|
||||
matmul_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
#ifdef USE_ROCM
|
||||
// forward the __backwardpass, if present
|
||||
auto& attrs = node.GetAttributes();
|
||||
if (attrs.count("__backwardpass")) {
|
||||
matmul_node.AddAttribute("__backwardpass", static_cast<int64_t>(attrs.at("__backwardpass").i()));
|
||||
}
|
||||
#endif
|
||||
|
||||
graph_utils::FinalizeNodeFusion(graph, matmul_node, node);
|
||||
|
||||
|
|
|
|||
36
onnxruntime/core/optimizer/rocm_blas_alt_impl.cc
Normal file
36
onnxruntime/core/optimizer/rocm_blas_alt_impl.cc
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#include <thread>
|
||||
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/rocm_blas_alt_impl.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace ::onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
Status RocmBlasAltImpl::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
bool is_backward_pass = false;
|
||||
|
||||
for (auto node_index : node_topology_list) {
|
||||
auto& node = *graph.GetNode(node_index);
|
||||
|
||||
if (node.OpType() == "YieldOp") {
|
||||
is_backward_pass = true;
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
|
||||
if (is_backward_pass) {
|
||||
node.AddAttribute(std::string("__backwardpass"), static_cast<int64_t>(1));
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
19
onnxruntime/core/optimizer/rocm_blas_alt_impl.h
Normal file
19
onnxruntime/core/optimizer/rocm_blas_alt_impl.h
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
class RocmBlasAltImpl : public GraphTransformer {
|
||||
public:
|
||||
RocmBlasAltImpl(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
|
||||
: GraphTransformer("RocmBlasAltImpl", compatible_execution_providers) {}
|
||||
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
21
onnxruntime/core/providers/rocm/backward_guard.cc
Normal file
21
onnxruntime/core/providers/rocm/backward_guard.cc
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#include "core/providers/rocm/backward_guard.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
thread_local bool BackwardPassGuard::is_backward_pass_;
|
||||
|
||||
BackwardPassGuard::BackwardPassGuard() {
|
||||
is_backward_pass_ = true;
|
||||
}
|
||||
|
||||
BackwardPassGuard::~BackwardPassGuard() {
|
||||
is_backward_pass_ = false;
|
||||
}
|
||||
|
||||
bool BackwardPassGuard::is_backward_pass() {
|
||||
return is_backward_pass_;
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
15
onnxruntime/core/providers/rocm/backward_guard.h
Normal file
15
onnxruntime/core/providers/rocm/backward_guard.h
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#pragma once
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
struct BackwardPassGuard {
|
||||
BackwardPassGuard();
|
||||
~BackwardPassGuard();
|
||||
static bool is_backward_pass();
|
||||
private:
|
||||
static thread_local bool is_backward_pass_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/rocm/backward_guard.h"
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
#include "core/providers/rocm/rocm_execution_provider.h"
|
||||
#include "core/providers/rocm/rocm_fwd.h"
|
||||
|
|
@ -22,7 +23,15 @@ class RocmKernel : public OpKernel {
|
|||
}
|
||||
|
||||
Status Compute(OpKernelContext* p_op_kernel_context) const override {
|
||||
auto s = ComputeInternal(p_op_kernel_context);
|
||||
Status s;
|
||||
auto is_backward_pass = Info().GetAttrOrDefault<int64_t>("__backwardpass", 0);
|
||||
if (is_backward_pass) {
|
||||
BackwardPassGuard guard;
|
||||
s = ComputeInternal(p_op_kernel_context);
|
||||
}
|
||||
else {
|
||||
s = ComputeInternal(p_op_kernel_context);
|
||||
}
|
||||
// use this to precisely locate the node where ROCM failure comes from
|
||||
// if (hipSuccess != hipDeviceSynchronize())
|
||||
// __debugbreak();
|
||||
|
|
|
|||
|
|
@ -3,10 +3,23 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/rocm/backward_guard.h"
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
|
||||
#define ORT_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
|
||||
#if ORT_ROCBLAS_VERSION_DECIMAL >= 242
|
||||
#define FLAG rocblas_gemm_flags_fp16_alt_impl
|
||||
#else
|
||||
#define FLAG 0
|
||||
#endif
|
||||
|
||||
using namespace onnxruntime;
|
||||
|
||||
inline int get_flag() {
|
||||
int result = BackwardPassGuard::is_backward_pass() ? FLAG : 0;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Generalize library calls to be use in template functions
|
||||
|
||||
// gemm
|
||||
|
|
@ -67,7 +80,7 @@ inline rocblas_status rocblasGemmHelper(rocblas_handle handle,
|
|||
C, rocblas_datatype_f16_r, ldc,
|
||||
C, rocblas_datatype_f16_r, ldc,
|
||||
rocblas_datatype_f32_r,
|
||||
rocblas_gemm_algo_standard, 0, 0);
|
||||
rocblas_gemm_algo_standard, 0, get_flag());
|
||||
}
|
||||
|
||||
inline rocblas_status rocblasGemmHelper(rocblas_handle handle,
|
||||
|
|
@ -90,7 +103,7 @@ inline rocblas_status rocblasGemmHelper(rocblas_handle handle,
|
|||
C, rocblas_datatype_f16_r, ldc,
|
||||
C, rocblas_datatype_f16_r, ldc,
|
||||
rocblas_datatype_f32_r,
|
||||
rocblas_gemm_algo_standard, 0, 0);
|
||||
rocblas_gemm_algo_standard, 0, get_flag());
|
||||
}
|
||||
|
||||
inline rocblas_status rocblasGemmHelper(rocblas_handle handle,
|
||||
|
|
@ -225,7 +238,7 @@ inline rocblas_status rocblasGemmBatchedHelper(rocblas_handle handle,
|
|||
(void**)Carray, rocblas_datatype_f16_r, ldc,
|
||||
batchCount,
|
||||
rocblas_datatype_f32_r,
|
||||
rocblas_gemm_algo_standard, 0, 0);
|
||||
rocblas_gemm_algo_standard, 0, get_flag());
|
||||
}
|
||||
|
||||
inline rocblas_status rocblasGemmBatchedHelper(rocblas_handle handle,
|
||||
|
|
@ -329,7 +342,7 @@ inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle,
|
|||
C, rocblas_datatype_f16_r, ldc, strideC,
|
||||
batchCount,
|
||||
rocblas_datatype_f32_r,
|
||||
rocblas_gemm_algo_standard, 0, 0);
|
||||
rocblas_gemm_algo_standard, 0, get_flag());
|
||||
}
|
||||
|
||||
inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle,
|
||||
|
|
@ -357,7 +370,7 @@ inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle,
|
|||
C, rocblas_datatype_f16_r, ldc, strideC,
|
||||
batchCount,
|
||||
rocblas_datatype_f32_r,
|
||||
rocblas_gemm_algo_standard, 0, 0);
|
||||
rocblas_gemm_algo_standard, 0, get_flag());
|
||||
}
|
||||
|
||||
inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle,
|
||||
|
|
|
|||
Loading…
Reference in a new issue