diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index 3093d8d43b..58b64511d4 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -10,6 +10,7 @@
 #include "core/optimizer/nhwc_transformer.h"
 #include "core/optimizer/qdq_transformer/qdq_final_cleanup.h"
 #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
+#include "core/optimizer/rocm_blas_alt_impl.h"
 #include "core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h"
 #include "core/session/onnxruntime_session_options_config_keys.h"
 #include "core/optimizer/conv_add_act_fusion.h"
@@ -208,6 +209,10 @@ InlinedVector> GenerateTransformers(
       // shouldn't affect the end result - just easier to debug any issue if it's last.
       auto cpu_allocator = cpu_execution_provider.GetAllocator(0, OrtMemTypeDefault);
       transformers.emplace_back(std::make_unique(std::move(cpu_allocator)));
+
+      // add __backwardpass attribute to nodes after YieldOp, ROCm-only
+      const InlinedHashSet<std::string_view> rocm_ep = {onnxruntime::kRocmExecutionProvider};
+      transformers.emplace_back(std::make_unique<RocmBlasAltImpl>(rocm_ep));
     } break;
 
     case TransformerLevel::Level2: {
diff --git a/onnxruntime/core/optimizer/matmul_scale_fusion.cc b/onnxruntime/core/optimizer/matmul_scale_fusion.cc
index 2c43f5ab12..b944b5536d 100644
--- a/onnxruntime/core/optimizer/matmul_scale_fusion.cc
+++ b/onnxruntime/core/optimizer/matmul_scale_fusion.cc
@@ -254,6 +254,14 @@ Status ProcessNode(
       kMSDomain);
 
   matmul_scale_node.SetExecutionProviderType(node.GetExecutionProviderType());
+#ifdef USE_ROCM
+  // Forward __backwardpass, if present, so the fused node keeps the marker of the node it replaces.
+  const auto& attrs = node.GetAttributes();
+  const auto backward_it = attrs.find("__backwardpass");
+  if (backward_it != attrs.end()) {
+    matmul_scale_node.AddAttribute("__backwardpass", static_cast<int64_t>(backward_it->second.i()));
+  }
+#endif
 
   {
     InlinedVector> nodes_to_remove{node};
diff --git 
a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc
index d47538640b..642805c93b 100644
--- a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc
+++ b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc
@@ -404,6 +404,14 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
     matmul_node.AddAttribute("alpha", alpha);
     // Assign provider to this new node. Provider should be same as the provider for old node.
     matmul_node.SetExecutionProviderType(node.GetExecutionProviderType());
+#ifdef USE_ROCM
+    // Forward __backwardpass, if present, so the fused node keeps the marker of the node it replaces.
+    const auto& attrs = node.GetAttributes();
+    const auto backward_it = attrs.find("__backwardpass");
+    if (backward_it != attrs.end()) {
+      matmul_node.AddAttribute("__backwardpass", static_cast<int64_t>(backward_it->second.i()));
+    }
+#endif
 
     graph_utils::FinalizeNodeFusion(graph, matmul_node, node);
 
diff --git a/onnxruntime/core/optimizer/rocm_blas_alt_impl.cc b/onnxruntime/core/optimizer/rocm_blas_alt_impl.cc
new file mode 100644
index 0000000000..decb25f565
--- /dev/null
+++ b/onnxruntime/core/optimizer/rocm_blas_alt_impl.cc
@@ -0,0 +1,53 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include <string>
+
+#include "core/optimizer/initializer.h"
+#include "core/optimizer/rocm_blas_alt_impl.h"
+#include "core/graph/graph_utils.h"
+
+using namespace ONNX_NAMESPACE;
+using namespace ::onnxruntime::common;
+namespace onnxruntime {
+
+// Visits the nodes of `graph` in topological order. From the first node whose
+// op type is "YieldOp" onwards (that node included), every visited node gets
+// the integer attribute "__backwardpass" = 1. Recurse() is called on each node
+// before it is tagged. `modified` is raised only when an attribute is actually
+// added or changed, so a pass over an already tagged graph reports no change.
+Status RocmBlasAltImpl::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
+  GraphViewer graph_viewer(graph);
+  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
+
+  bool is_backward_pass = false;
+
+  for (auto node_index : node_topology_list) {
+    auto* node_ptr = graph.GetNode(node_index);
+    if (node_ptr == nullptr) {
+      continue;  // nothing to tag for an index that yields no node
+    }
+    auto& node = *node_ptr;
+
+    if (node.OpType() == "YieldOp") {
+      is_backward_pass = true;
+    }
+
+    ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
+
+    if (!is_backward_pass) {
+      continue;
+    }
+
+    // Skip nodes that already carry the marker so the pass is idempotent.
+    const auto& attrs = node.GetAttributes();
+    const auto tagged = attrs.find("__backwardpass");
+    if (tagged != attrs.end() && tagged->second.i() == 1) {
+      continue;
+    }
+    node.AddAttribute(std::string("__backwardpass"), static_cast<int64_t>(1));
+    modified = true;
+  }
+
+  return Status::OK();
+}
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/rocm_blas_alt_impl.h b/onnxruntime/core/optimizer/rocm_blas_alt_impl.h
new file mode 100644
index 0000000000..11744d0dac
--- /dev/null
+++ b/onnxruntime/core/optimizer/rocm_blas_alt_impl.h
@@ -0,0 +1,19 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/optimizer/graph_transformer.h"
+#include "core/graph/graph_utils.h"
+
+namespace onnxruntime {
+// Graph transformer named "RocmBlasAltImpl"; forwards the compatible-EP set to GraphTransformer.
+class RocmBlasAltImpl : public GraphTransformer {
+ public:
+  explicit RocmBlasAltImpl(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
+      : GraphTransformer("RocmBlasAltImpl", compatible_execution_providers) {}
+
+  Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
+};
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/rocm/backward_guard.cc b/onnxruntime/core/providers/rocm/backward_guard.cc
new file mode 100644
index 0000000000..1695da092b
--- /dev/null
+++ b/onnxruntime/core/providers/rocm/backward_guard.cc
@@ -0,0 +1,23 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include "core/providers/rocm/backward_guard.h"
+
+namespace onnxruntime {
+
+thread_local bool BackwardPassGuard::is_backward_pass_ = false;
+
+BackwardPassGuard::BackwardPassGuard() : previous_(is_backward_pass_) {
+  is_backward_pass_ = true;
+}
+
+BackwardPassGuard::~BackwardPassGuard() {
+  // Restore instead of clearing, so destroying an inner guard does not switch
+  // the flag off while an outer guard on the same thread is still alive.
+  is_backward_pass_ = previous_;
+}
+
+bool BackwardPassGuard::is_backward_pass() {
+  return is_backward_pass_;
+}
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/rocm/backward_guard.h b/onnxruntime/core/providers/rocm/backward_guard.h
new file mode 100644
index 0000000000..e36785af37
--- /dev/null
+++ b/onnxruntime/core/providers/rocm/backward_guard.h
@@ -0,0 +1,25 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+
+namespace onnxruntime {
+
+// Scoped marker for the calling thread. Constructing a guard sets a
+// thread-local flag to true; destroying it puts the flag back to the value it
+// had when the guard was constructed.
+struct BackwardPassGuard {
+  BackwardPassGuard();
+  ~BackwardPassGuard();
+  // Not copyable: a copy's destructor would reset the flag a second time.
+  BackwardPassGuard(const BackwardPassGuard&) = delete;
+  BackwardPassGuard& operator=(const BackwardPassGuard&) = delete;
+  // Returns the current value of the calling thread's flag.
+  static bool is_backward_pass();
+
+ private:
+  static thread_local bool is_backward_pass_;
+  // Flag value seen by the constructor; the destructor writes it back.
+  bool previous_;
+};
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/rocm/rocm_kernel.h b/onnxruntime/core/providers/rocm/rocm_kernel.h
index 6737ae8582..a29537e9f1 100644
--- a/onnxruntime/core/providers/rocm/rocm_kernel.h
+++ b/onnxruntime/core/providers/rocm/rocm_kernel.h
@@ -3,6 +3,7 @@
 
 #pragma once
 
+#include "core/providers/rocm/backward_guard.h"
 #include "core/providers/rocm/rocm_common.h"
 #include "core/providers/rocm/rocm_execution_provider.h"
 #include "core/providers/rocm/rocm_fwd.h"
@@ -22,7 +23,16 @@ class RocmKernel : public OpKernel {
   }
 
   Status Compute(OpKernelContext* p_op_kernel_context) const override {
-    auto s = ComputeInternal(p_op_kernel_context);
+    Status s;
+    // A missing "__backwardpass" attribute reads as 0, i.e. no guard is created.
+    const auto is_backward_pass = Info().GetAttrOrDefault<int64_t>("__backwardpass", int64_t{0});
+    if (is_backward_pass != 0) {
+      // The guard lives only for the duration of ComputeInternal.
+      BackwardPassGuard guard;
+      s = ComputeInternal(p_op_kernel_context);
+    } else {
+      s = ComputeInternal(p_op_kernel_context);
+    }
     // use this to precisely locate the node where ROCM failure comes from
     //  if (hipSuccess != hipDeviceSynchronize())
     //    __debugbreak();
diff --git a/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h b/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h
index 030275c52d..657a11ccb8 100644
--- a/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h
+++ b/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h
@@ -3,10 +3,24 @@
 
 #pragma once
 
+#include "core/providers/rocm/backward_guard.h"
 #include "core/providers/rocm/rocm_common.h"
 
+#define ORT_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
+
 using namespace onnxruntime;
 
+// Value passed as the last argument of the f16 GEMM calls in this header.
+// Yields rocblas_gemm_flags_fp16_alt_impl while BackwardPassGuard reports a backward
+// pass on the calling thread and rocBLAS is version 2.42 or newer; otherwise 0.
+inline int get_flag() {
+#if ORT_ROCBLAS_VERSION_DECIMAL >= 242
+  return BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#else
+  return 0;
+#endif
+}
+
 // Generalize library calls to be use in template functions
 
 // gemm
@@ -67,7 +81,7 @@ inline rocblas_status rocblasGemmHelper(rocblas_handle handle,
                          C, rocblas_datatype_f16_r, ldc,
                          C, rocblas_datatype_f16_r, ldc,
                          rocblas_datatype_f32_r,
-                         rocblas_gemm_algo_standard, 0, 0);
+                         rocblas_gemm_algo_standard, 0, get_flag());
 }
 
 inline rocblas_status rocblasGemmHelper(rocblas_handle handle,
@@ -90,7 +104,7 @@ inline rocblas_status rocblasGemmHelper(rocblas_handle handle,
                          C, rocblas_datatype_f16_r, ldc,
                          C, rocblas_datatype_f16_r, ldc,
                          rocblas_datatype_f32_r,
-                         rocblas_gemm_algo_standard, 0, 0);
+                         rocblas_gemm_algo_standard, 0, get_flag());
 }
 
 inline rocblas_status rocblasGemmHelper(rocblas_handle handle,
@@ -225,7 +239,7 @@ inline rocblas_status rocblasGemmBatchedHelper(rocblas_handle handle,
                                  (void**)Carray, rocblas_datatype_f16_r, ldc,
                                  batchCount,
                                  rocblas_datatype_f32_r,
-                                 rocblas_gemm_algo_standard, 0, 0);
+                                 rocblas_gemm_algo_standard, 0, get_flag());
 }
 
 inline rocblas_status rocblasGemmBatchedHelper(rocblas_handle handle,
@@ -329,7 +343,7 @@ inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle,
                                         C, rocblas_datatype_f16_r, ldc, strideC,
                                         batchCount,
                                         rocblas_datatype_f32_r,
-                                        rocblas_gemm_algo_standard, 0, 0);
+                                        rocblas_gemm_algo_standard, 0, get_flag());
 }
 
 inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle,
@@ -357,7 +371,7 @@ inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle,
                                         C, rocblas_datatype_f16_r, ldc, strideC,
                                         batchCount,
                                         rocblas_datatype_f32_r,
-                                        rocblas_gemm_algo_standard, 0, 0);
+                                        rocblas_gemm_algo_standard, 0, get_flag());
 }
 
 inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle,