rocblas alt impl during backward pass only (#13352)

On AMD Instinct MI200 GPUs, the FP16 and BF16 V_DOT2 and MFMA matrix
instructions flush input and output denormal values to zero. When
training using FP16 precision, some models may fail to converge with
FP16 denorms flushed to zero. The affected instructions are only used by
rocBLAS (GEMM) and MIOpen (convolution) kernels; all other onnxruntime
operations will not encounter this behavior. All other supported AMD
GPUs will not encounter this behavior.

rocBLAS and MIOpen provide alternate implementations for affected FP16
operations. Alternate implementations for BF16 operations are not
provided; BF16 numbers have a larger dynamic range than FP16 numbers and
are less likely to encounter denormal values. For the FP16 alternate
implementations, FP16 input values are cast to an intermediate BF16
value and then cast back to an FP16 output after the FP32 accumulation.
In this way, the input and output types are unchanged.

Denormal values more frequently occur in the backward pass of training
during gradient calculation. Therefore, it is necessary to track when
the backward pass of training is executing. For the ROCm EP only, the
`__backwardpass` attribute is added to all Nodes after the YieldOp is
detected. This takes place in a level1 graph optimization pass. The
attribute is forwarded to any newly created FusedMatMul Nodes. In
addition, the scope-based helper class `BackwardPassGuard` is provided
to toggle state for rocBLAS. With this PR, using the alternate
implementations during the backward pass becomes the default, automatic behavior.
This default behavior can be overridden using environment variables,
ROCBLAS_INTERNAL_FP16_ALT_IMPL and
MIOPEN_DEBUG_CONVOLUTION_ATTRIB_FP16_ALT_IMPL. The behavior of these
environment variables is as follows:

|              | forward   | backward  |
|--------------|-----------|-----------|
| Env unset    | original  | alternate |
| Env set to 1 | alternate | alternate |
| Env set to 0 | original  | original  |

See also:


https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
This commit is contained in:
Jeff Daily 2022-11-09 08:47:06 -08:00 committed by GitHub
parent d10d66cc84
commit d5d6924688
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 138 additions and 6 deletions

View file

@ -10,6 +10,7 @@
#include "core/optimizer/nhwc_transformer.h"
#include "core/optimizer/qdq_transformer/qdq_final_cleanup.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
#include "core/optimizer/rocm_blas_alt_impl.h"
#include "core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/optimizer/conv_add_act_fusion.h"
@ -208,6 +209,10 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
// shouldn't affect the end result - just easier to debug any issue if it's last.
auto cpu_allocator = cpu_execution_provider.GetAllocator(0, OrtMemTypeDefault);
transformers.emplace_back(std::make_unique<TransposeOptimizer>(std::move(cpu_allocator)));
// add __backwardpass attribute to nodes after YieldOp, ROCm-only
const InlinedHashSet<std::string_view> rocm_ep = {onnxruntime::kRocmExecutionProvider};
transformers.emplace_back(std::make_unique<RocmBlasAltImpl>(rocm_ep));
} break;
case TransformerLevel::Level2: {

View file

@ -254,6 +254,13 @@ Status ProcessNode(
kMSDomain);
matmul_scale_node.SetExecutionProviderType(node.GetExecutionProviderType());
#ifdef USE_ROCM
// forward the __backwardpass, if present
auto& attrs = node.GetAttributes();
if (attrs.count("__backwardpass")) {
matmul_scale_node.AddAttribute("__backwardpass", static_cast<int64_t>(attrs.at("__backwardpass").i()));
}
#endif
{
InlinedVector<std::reference_wrapper<Node>> nodes_to_remove{node};

View file

@ -404,6 +404,13 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
matmul_node.AddAttribute("alpha", alpha);
// Assign provider to this new node. Provider should be same as the provider for old node.
matmul_node.SetExecutionProviderType(node.GetExecutionProviderType());
#ifdef USE_ROCM
// forward the __backwardpass, if present
auto& attrs = node.GetAttributes();
if (attrs.count("__backwardpass")) {
matmul_node.AddAttribute("__backwardpass", static_cast<int64_t>(attrs.at("__backwardpass").i()));
}
#endif
graph_utils::FinalizeNodeFusion(graph, matmul_node, node);

View file

@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <thread>
#include "core/optimizer/initializer.h"
#include "core/optimizer/rocm_blas_alt_impl.h"
#include "core/graph/graph_utils.h"
using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
namespace onnxruntime {
Status RocmBlasAltImpl::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
  // Walk the graph in topological order. Every node at or after the YieldOp
  // (including the YieldOp itself) belongs to the backward pass of training,
  // and is tagged with the __backwardpass attribute so ROCm kernels can
  // enable the rocBLAS/MIOpen FP16 alternate implementations at runtime.
  GraphViewer viewer(graph);
  bool past_yield = false;
  for (const auto node_idx : viewer.GetNodesInTopologicalOrder()) {
    Node& node = *graph.GetNode(node_idx);
    past_yield = past_yield || (node.OpType() == "YieldOp");
    // Recurse into subgraphs before tagging this node.
    ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
    if (past_yield) {
      node.AddAttribute(std::string("__backwardpass"), static_cast<int64_t>(1));
      modified = true;
    }
  }
  return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/graph_transformer.h"
#include "core/graph/graph_utils.h"
namespace onnxruntime {
/**
@Class RocmBlasAltImpl

Graph transformer (ROCm EP only) that adds the __backwardpass attribute to
every node occurring at or after a YieldOp. ROCm kernels read this attribute
to enable the rocBLAS/MIOpen FP16 alternate implementations during the
backward pass of training.
*/
class RocmBlasAltImpl : public GraphTransformer {
 public:
  // explicit: single-argument constructor; implicit conversion from an
  // InlinedHashSet is never intended.
  // compatible_execution_providers restricts which nodes this transformer may
  // modify (expected to contain only the ROCm EP).
  explicit RocmBlasAltImpl(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
      : GraphTransformer("RocmBlasAltImpl", compatible_execution_providers) {}

  Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/backward_guard.h"
namespace onnxruntime {
// Per-thread flag recording whether the current thread is executing the
// backward pass. Namespace-scope thread_local storage is zero-initialized,
// but the initializer makes the intended starting state explicit.
thread_local bool BackwardPassGuard::is_backward_pass_ = false;

// Enter the backward pass for the current thread.
// NOTE(review): guards do not nest — the destructor clears the flag
// unconditionally, so an inner guard's destruction would end the state for an
// enclosing guard. Confirm call sites never nest guards.
BackwardPassGuard::BackwardPassGuard() {
  is_backward_pass_ = true;
}

// Leave the backward pass for the current thread.
BackwardPassGuard::~BackwardPassGuard() {
  is_backward_pass_ = false;
}

// Returns true while a BackwardPassGuard is alive on this thread.
bool BackwardPassGuard::is_backward_pass() {
  return is_backward_pass_;
}
} // namespace onnxruntime

View file

@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace onnxruntime {
// Scope-based (RAII-style) marker for the backward pass of training on the
// current thread: while a guard is alive, is_backward_pass() returns true.
// ROCm GEMM helpers query this to select the rocBLAS FP16 alternate
// implementation.
struct BackwardPassGuard {
  BackwardPassGuard();
  ~BackwardPassGuard();

  // A scope guard must not be copied or moved: a copy's destructor would
  // clear the per-thread flag while the original guard is still in scope.
  BackwardPassGuard(const BackwardPassGuard&) = delete;
  BackwardPassGuard& operator=(const BackwardPassGuard&) = delete;

  // True while any BackwardPassGuard is alive on the calling thread.
  static bool is_backward_pass();

 private:
  static thread_local bool is_backward_pass_;
};
} // namespace onnxruntime

View file

@ -3,6 +3,7 @@
#pragma once
#include "core/providers/rocm/backward_guard.h"
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/rocm_execution_provider.h"
#include "core/providers/rocm/rocm_fwd.h"
@ -22,7 +23,15 @@ class RocmKernel : public OpKernel {
}
Status Compute(OpKernelContext* p_op_kernel_context) const override {
auto s = ComputeInternal(p_op_kernel_context);
Status s;
auto is_backward_pass = Info().GetAttrOrDefault<int64_t>("__backwardpass", 0);
if (is_backward_pass) {
BackwardPassGuard guard;
s = ComputeInternal(p_op_kernel_context);
}
else {
s = ComputeInternal(p_op_kernel_context);
}
// use this to precisely locate the node where ROCM failure comes from
// if (hipSuccess != hipDeviceSynchronize())
// __debugbreak();

View file

@ -3,10 +3,23 @@
#pragma once
#include "core/providers/rocm/backward_guard.h"
#include "core/providers/rocm/rocm_common.h"
// rocBLAS gained rocblas_gemm_flags_fp16_alt_impl (cast FP16 inputs to BF16
// intermediates to avoid MI200 denormal flushing) in version 2.42; on older
// rocBLAS fall back to no flags.
// NOTE(review): "FLAG" is a very generic macro name to define in a header —
// consider a prefixed name (e.g. ORT_ROCBLAS_FP16_ALT_FLAG); confirm no other
// includer relies on it before renaming.
#define ORT_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#if ORT_ROCBLAS_VERSION_DECIMAL >= 242
#define FLAG rocblas_gemm_flags_fp16_alt_impl
#else
#define FLAG 0
#endif

// NOTE(review): a using-directive in a header leaks onnxruntime into every
// translation unit that includes it; prefer qualifying
// onnxruntime::BackwardPassGuard directly — verify nothing else in this
// header depends on the directive before removing it.
using namespace onnxruntime;
// Select the rocBLAS GEMM flags for the current thread: request the FP16
// alternate implementation while the backward pass is active, otherwise use
// the default (no flags).
inline int get_flag() {
  return BackwardPassGuard::is_backward_pass() ? FLAG : 0;
}
// Generalize library calls to be use in template functions
// gemm
@ -67,7 +80,7 @@ inline rocblas_status rocblasGemmHelper(rocblas_handle handle,
C, rocblas_datatype_f16_r, ldc,
C, rocblas_datatype_f16_r, ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard, 0, 0);
rocblas_gemm_algo_standard, 0, get_flag());
}
inline rocblas_status rocblasGemmHelper(rocblas_handle handle,
@ -90,7 +103,7 @@ inline rocblas_status rocblasGemmHelper(rocblas_handle handle,
C, rocblas_datatype_f16_r, ldc,
C, rocblas_datatype_f16_r, ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard, 0, 0);
rocblas_gemm_algo_standard, 0, get_flag());
}
inline rocblas_status rocblasGemmHelper(rocblas_handle handle,
@ -225,7 +238,7 @@ inline rocblas_status rocblasGemmBatchedHelper(rocblas_handle handle,
(void**)Carray, rocblas_datatype_f16_r, ldc,
batchCount,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard, 0, 0);
rocblas_gemm_algo_standard, 0, get_flag());
}
inline rocblas_status rocblasGemmBatchedHelper(rocblas_handle handle,
@ -329,7 +342,7 @@ inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle,
C, rocblas_datatype_f16_r, ldc, strideC,
batchCount,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard, 0, 0);
rocblas_gemm_algo_standard, 0, get_flag());
}
inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle,
@ -357,7 +370,7 @@ inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle,
C, rocblas_datatype_f16_r, ldc, strideC,
batchCount,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard, 0, 0);
rocblas_gemm_algo_standard, 0, get_flag());
}
inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle,