Fuse Cast + SoftmaxCrossEntropyLossInternal (#20334)

### Description
Fuse Cast + SoftmaxCrossEntropyLossInternal to
SoftmaxCrossEntropyLossInternal.
This commit is contained in:
guyang3532 2024-04-29 14:12:10 +08:00 committed by GitHub
parent 923b0ef323
commit 3e4db2c686
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 132 additions and 6 deletions

View file

@ -172,10 +172,7 @@ static bool RemoveNodeWithSingleNodeInSingleUsedOutput(Graph& graph, Node& node)
return true;
}
/** Move the input edges that src_node has to target_node.
After the move is complete src_node will have no input edges.
*/
static void MoveAllNodeInputEdges(Graph& graph, Node& src_node, Node& target_node) {
void MoveAllNodeInputEdges(Graph& graph, Node& src_node, Node& target_node) {
auto target_idx = target_node.Index();
auto input_edges = GraphEdge::GetNodeInputEdges(src_node);
@ -387,6 +384,18 @@ std::vector<GraphEdge> GraphEdge::GetNodeInputEdges(const Node& node) {
return input_edges;
}
/** Returns a vector of the input GraphEdges of a node for the provided input index. */
std::vector<GraphEdge> GraphEdge::GetNodeInputEdges(const Node& node, size_t index) {
std::vector<GraphEdge> input_edges;
for (auto it = node.InputEdgesBegin(), end = node.InputEdgesEnd(); it != end; ++it) {
if (static_cast<size_t>(it->GetDstArgIndex()) == index) {
input_edges.push_back(GraphEdge::CreateGraphEdge(node, *it, true));
}
}
return input_edges;
}
/** Returns a vector of the output GraphEdges of a node. */
std::vector<GraphEdge> GraphEdge::GetNodeOutputEdges(const Node& node) {
std::vector<GraphEdge> output_edges;

View file

@ -59,6 +59,11 @@ const std::string& GetNodeOutputName(const Node& node, int index);
*/
const Node::EdgeEnd* GetInputEdge(const Node& node, int arg_index);
/** Move the input edges that src_node has to target_node.
After the move is complete src_node will have no input edges.
*/
void MoveAllNodeInputEdges(Graph& graph, Node& src_node, Node& target_node);
/** Removes all output edges from the given Node of the Graph.
This should probably be elevated to the Graph API eventually. */
size_t RemoveNodeOutputEdges(Graph& graph, Node& node);
@ -89,6 +94,9 @@ struct GraphEdge {
/** Returns a vector of the input GraphEdges of a node. */
static std::vector<GraphEdge> GetNodeInputEdges(const Node& node);
/** Returns a vector of the input GraphEdges of a node for the provided input index. */
static std::vector<GraphEdge> GetNodeInputEdges(const Node& node, size_t index);
/** Returns a vector of the output GraphEdges of a node. */
static std::vector<GraphEdge> GetNodeOutputEdges(const Node& node);

View file

@ -87,6 +87,8 @@ Status GemmTransposeFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& m
new_gemm_node.AddAttribute("alpha", gemm_node.GetAttributes().at("alpha").f());
new_gemm_node.AddAttribute("beta", gemm_node.GetAttributes().at("beta").f());
new_gemm_node.SetExecutionProviderType(gemm_node.GetExecutionProviderType());
graph_utils::FinalizeNodeFusion(graph, nodes_to_remove, new_gemm_node);
modified = RewriteRuleEffect::kRemovedCurrentNode;

View file

@ -138,6 +138,7 @@ InlinedVector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
break;
case TransformerLevel::Level2:
rules.push_back(std::make_unique<GemmTransposeFusion>());
// No level2 rules available today
break;
@ -253,6 +254,11 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
} break;
case TransformerLevel::Level2: {
auto rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {});
if (rule_transformer != nullptr) {
transformers.emplace_back(std::move(rule_transformer));
}
// we run TransposeOptimizer again in Level2 for some CPU EP specific optimizations that can only be
// applied once nodes are assigned to the CPU EP (which happens between level 1 and level 2).
transformers.emplace_back(std::make_unique<TransposeOptimizer>(std::move(cpu_allocator), kCpuExecutionProvider));

View file

@ -171,7 +171,7 @@ static bool IsFP16Allow(const Node* node, size_t level, const FP16AllowOps& fp16
using OpsSetType = InlinedHashSet<std::string_view>;
static const OpsSetType level1_fp16_allow_set =
{"Expand", "Transpose", "Relu", "Reshape", "Split", "Tanh", "Squeeze", "Unsqueeze", "Gelu"};
{"Expand", "Transpose", "Relu", "Reshape", "Split", "Tanh", "Squeeze", "Unsqueeze", "Gelu", "Slice", "PadAndUnflatten"};
static const OpsSetType level2_fp16_allow_set = {
"Add", "BiasGelu", "Dropout", "FastGelu", "Gather", "LayerNormalization", "Where"};

View file

@ -281,7 +281,7 @@ constexpr std::array kOnnxDomainNonDeterministicOps{"RandomUniform", "RandomNorm
// (plus ShrunkenGather for training) are considered deterministic.
#ifdef ENABLE_TRAINING_OPS
constexpr std::array kMSDomainDeterministicOps{"ShrunkenGather", "QuantizeLinear", "DequantizeLinear",
"ConcatTraining"};
"ConcatTraining", "PadAndUnflatten"};
#else
constexpr std::array kMSDomainDeterministicOps{"QuantizeLinear", "DequantizeLinear"};
#endif

View file

@ -0,0 +1,59 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "orttraining/core/optimizer/cast_sce_loss_fusion.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/utils.h"
namespace onnxruntime {
Status CastSceLossFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
for (auto node_index : node_topology_list) {
auto* node_ptr = graph.GetNode(node_index);
if (!node_ptr) continue; // Node was removed.
auto& node = *node_ptr;
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
bool is_internal_sce = graph_utils::IsSupportedOptypeVersionAndDomain(node, "SoftmaxCrossEntropyLossInternal", {1},
kMSDomain);
if (!is_internal_sce || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
continue;
}
Node* input_node = graph.GetMutableProducerNode(node.MutableInputDefs()[0]->Name());
if (!(graph_utils::IsSupportedOptypeVersionAndDomain(*input_node, "Cast", {9, 13, 19}))) {
continue;
}
if (input_node->GetOutputEdgesCount() != 1 || graph.IsOutput(input_node->OutputDefs()[0])) {
continue;
}
if (input_node->MutableInputDefs()[0]->TypeAsProto()->tensor_type().elem_type() == onnx::TensorProto_DataType_FLOAT16 &&
input_node->MutableOutputDefs()[0]->TypeAsProto()->tensor_type().elem_type() == onnx::TensorProto_DataType_FLOAT) {
std::vector<graph_utils::GraphEdge> input_edges = graph_utils::GraphEdge::GetNodeInputEdges(node, 0);
graph_utils::GraphEdge::RemoveGraphEdges(graph, input_edges);
node.MutableInputDefs()[0] = input_node->MutableInputDefs()[0];
graph_utils::MoveAllNodeInputEdges(graph, *input_node, node);
graph.RemoveNode(input_node->Index());
if (node.GetAttributes().count("output_type") == 0) {
node.AddAttribute("output_type", static_cast<int64_t>(onnx::TensorProto_DataType_FLOAT));
}
modified = true;
}
}
return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/graph_transformer.h"
namespace onnxruntime {
/**
@Class CastSceLossFusion
Fuse Cast + SoftmaxCrossEntropyLossInternal to SoftmaxCrossEntropyLossInternal.
*/
class CastSceLossFusion : public GraphTransformer {
public:
explicit CastSceLossFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
: GraphTransformer("CastSceLossFusion", compatible_execution_providers) {
}
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -52,6 +52,7 @@
#include "orttraining/core/framework/distributed_run_context.h"
#include "orttraining/core/optimizer/batchnorm_replacement.h"
#include "orttraining/core/optimizer/bitmask_dropout_replacement.h"
#include "orttraining/core/optimizer/cast_sce_loss_fusion.h"
#include "orttraining/core/optimizer/concat_replacement.h"
#include "orttraining/core/optimizer/graph_transformer_registry.h"
#include "orttraining/core/optimizer/gru_replacement.h"
@ -188,6 +189,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
config.propagate_cast_ops_config.allow,
cuda_execution_provider));
}
transformers.emplace_back(std::make_unique<CastSceLossFusion>(compatible_eps));
if (config.enable_compute_optimizer) {
transformers.emplace_back(std::make_unique<UpStreamGatherGraphTransformer>(compatible_eps));

View file

@ -25,6 +25,7 @@
#include "test/util/include/asserts.h"
#include "orttraining/test/optimizer/horizontal_parallel_test_utils.h"
#include "orttraining/core/session/training_session.h"
#include "orttraining/core/optimizer/cast_sce_loss_fusion.h"
#include "orttraining/core/optimizer/loss_rewriter.h"
#include "orttraining/core/optimizer/bias_softmax_dropout_fusion.h"
#include "orttraining/core/optimizer/qdq_fusion.h"
@ -518,6 +519,22 @@ TEST_F(GraphTransformationTests, SceLossGradBiasFusion_Invalid) {
}
}
TEST_F(GraphTransformationTests, CastSceLossFusion) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "computation_reduction/reshape/mlm_bert_e2e.onnx";
std::shared_ptr<Model> model;
ASSERT_TRUE(Model::Load(model_uri, model, nullptr, *logger_).IsOK());
Graph& graph = model->MainGraph();
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_EQ(op_to_count["Cast"], 10);
onnxruntime::GraphTransformerManager graph_transformation_mgr{1};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<CastSceLossFusion>(), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
op_to_count = CountOpsInGraph(graph);
ASSERT_EQ(op_to_count["Cast"], 9);
}
Node* GetNodeByName(Graph& graph, std::string node_name) {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();