mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
Fuse Cast + SoftmaxCrossEntropyLossInternal (#20334)
### Description Fuse Cast + SoftmaxCrossEntropyLossInternal to SoftmaxCrossEntropyLossInternal.
This commit is contained in:
parent
923b0ef323
commit
3e4db2c686
10 changed files with 132 additions and 6 deletions
|
|
@ -172,10 +172,7 @@ static bool RemoveNodeWithSingleNodeInSingleUsedOutput(Graph& graph, Node& node)
|
|||
return true;
|
||||
}
|
||||
|
||||
/** Move the input edges that src_node has to target_node.
|
||||
After the move is complete src_node will have no input edges.
|
||||
*/
|
||||
static void MoveAllNodeInputEdges(Graph& graph, Node& src_node, Node& target_node) {
|
||||
void MoveAllNodeInputEdges(Graph& graph, Node& src_node, Node& target_node) {
|
||||
auto target_idx = target_node.Index();
|
||||
auto input_edges = GraphEdge::GetNodeInputEdges(src_node);
|
||||
|
||||
|
|
@ -387,6 +384,18 @@ std::vector<GraphEdge> GraphEdge::GetNodeInputEdges(const Node& node) {
|
|||
return input_edges;
|
||||
}
|
||||
|
||||
/** Returns a vector of the input GraphEdges of a node for the provided input index. */
|
||||
std::vector<GraphEdge> GraphEdge::GetNodeInputEdges(const Node& node, size_t index) {
|
||||
std::vector<GraphEdge> input_edges;
|
||||
for (auto it = node.InputEdgesBegin(), end = node.InputEdgesEnd(); it != end; ++it) {
|
||||
if (static_cast<size_t>(it->GetDstArgIndex()) == index) {
|
||||
input_edges.push_back(GraphEdge::CreateGraphEdge(node, *it, true));
|
||||
}
|
||||
}
|
||||
|
||||
return input_edges;
|
||||
}
|
||||
|
||||
/** Returns a vector of the output GraphEdges of a node. */
|
||||
std::vector<GraphEdge> GraphEdge::GetNodeOutputEdges(const Node& node) {
|
||||
std::vector<GraphEdge> output_edges;
|
||||
|
|
|
|||
|
|
@ -59,6 +59,11 @@ const std::string& GetNodeOutputName(const Node& node, int index);
|
|||
*/
|
||||
const Node::EdgeEnd* GetInputEdge(const Node& node, int arg_index);
|
||||
|
||||
/** Move the input edges that src_node has to target_node.
|
||||
After the move is complete src_node will have no input edges.
|
||||
*/
|
||||
void MoveAllNodeInputEdges(Graph& graph, Node& src_node, Node& target_node);
|
||||
|
||||
/** Removes all output edges from the given Node of the Graph.
|
||||
This should probably be elevated to the Graph API eventually. */
|
||||
size_t RemoveNodeOutputEdges(Graph& graph, Node& node);
|
||||
|
|
@ -89,6 +94,9 @@ struct GraphEdge {
|
|||
/** Returns a vector of the input GraphEdges of a node. */
|
||||
static std::vector<GraphEdge> GetNodeInputEdges(const Node& node);
|
||||
|
||||
/** Returns a vector of the input GraphEdges of a node for the provided input index. */
|
||||
static std::vector<GraphEdge> GetNodeInputEdges(const Node& node, size_t index);
|
||||
|
||||
/** Returns a vector of the output GraphEdges of a node. */
|
||||
static std::vector<GraphEdge> GetNodeOutputEdges(const Node& node);
|
||||
|
||||
|
|
|
|||
|
|
@ -87,6 +87,8 @@ Status GemmTransposeFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& m
|
|||
new_gemm_node.AddAttribute("alpha", gemm_node.GetAttributes().at("alpha").f());
|
||||
new_gemm_node.AddAttribute("beta", gemm_node.GetAttributes().at("beta").f());
|
||||
|
||||
new_gemm_node.SetExecutionProviderType(gemm_node.GetExecutionProviderType());
|
||||
|
||||
graph_utils::FinalizeNodeFusion(graph, nodes_to_remove, new_gemm_node);
|
||||
|
||||
modified = RewriteRuleEffect::kRemovedCurrentNode;
|
||||
|
|
|
|||
|
|
@ -138,6 +138,7 @@ InlinedVector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
|
|||
break;
|
||||
|
||||
case TransformerLevel::Level2:
|
||||
rules.push_back(std::make_unique<GemmTransposeFusion>());
|
||||
// No level2 rules available today
|
||||
break;
|
||||
|
||||
|
|
@ -253,6 +254,11 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
} break;
|
||||
|
||||
case TransformerLevel::Level2: {
|
||||
auto rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {});
|
||||
if (rule_transformer != nullptr) {
|
||||
transformers.emplace_back(std::move(rule_transformer));
|
||||
}
|
||||
|
||||
// we run TransposeOptimizer again in Level2 for some CPU EP specific optimizations that can only be
|
||||
// applied once nodes are assigned to the CPU EP (which happens between level 1 and level 2).
|
||||
transformers.emplace_back(std::make_unique<TransposeOptimizer>(std::move(cpu_allocator), kCpuExecutionProvider));
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ static bool IsFP16Allow(const Node* node, size_t level, const FP16AllowOps& fp16
|
|||
|
||||
using OpsSetType = InlinedHashSet<std::string_view>;
|
||||
static const OpsSetType level1_fp16_allow_set =
|
||||
{"Expand", "Transpose", "Relu", "Reshape", "Split", "Tanh", "Squeeze", "Unsqueeze", "Gelu"};
|
||||
{"Expand", "Transpose", "Relu", "Reshape", "Split", "Tanh", "Squeeze", "Unsqueeze", "Gelu", "Slice", "PadAndUnflatten"};
|
||||
static const OpsSetType level2_fp16_allow_set = {
|
||||
"Add", "BiasGelu", "Dropout", "FastGelu", "Gather", "LayerNormalization", "Where"};
|
||||
|
||||
|
|
|
|||
|
|
@ -281,7 +281,7 @@ constexpr std::array kOnnxDomainNonDeterministicOps{"RandomUniform", "RandomNorm
|
|||
// (plus ShrunkenGather for training) are considered deterministic.
|
||||
#ifdef ENABLE_TRAINING_OPS
|
||||
constexpr std::array kMSDomainDeterministicOps{"ShrunkenGather", "QuantizeLinear", "DequantizeLinear",
|
||||
"ConcatTraining"};
|
||||
"ConcatTraining", "PadAndUnflatten"};
|
||||
#else
|
||||
constexpr std::array kMSDomainDeterministicOps{"QuantizeLinear", "DequantizeLinear"};
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -0,0 +1,59 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/core/optimizer/cast_sce_loss_fusion.h"
|
||||
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
Status CastSceLossFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
||||
const logging::Logger& logger) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
for (auto node_index : node_topology_list) {
|
||||
auto* node_ptr = graph.GetNode(node_index);
|
||||
if (!node_ptr) continue; // Node was removed.
|
||||
|
||||
auto& node = *node_ptr;
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
|
||||
bool is_internal_sce = graph_utils::IsSupportedOptypeVersionAndDomain(node, "SoftmaxCrossEntropyLossInternal", {1},
|
||||
kMSDomain);
|
||||
|
||||
if (!is_internal_sce || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Node* input_node = graph.GetMutableProducerNode(node.MutableInputDefs()[0]->Name());
|
||||
|
||||
if (!(graph_utils::IsSupportedOptypeVersionAndDomain(*input_node, "Cast", {9, 13, 19}))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (input_node->GetOutputEdgesCount() != 1 || graph.IsOutput(input_node->OutputDefs()[0])) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (input_node->MutableInputDefs()[0]->TypeAsProto()->tensor_type().elem_type() == onnx::TensorProto_DataType_FLOAT16 &&
|
||||
input_node->MutableOutputDefs()[0]->TypeAsProto()->tensor_type().elem_type() == onnx::TensorProto_DataType_FLOAT) {
|
||||
std::vector<graph_utils::GraphEdge> input_edges = graph_utils::GraphEdge::GetNodeInputEdges(node, 0);
|
||||
graph_utils::GraphEdge::RemoveGraphEdges(graph, input_edges);
|
||||
node.MutableInputDefs()[0] = input_node->MutableInputDefs()[0];
|
||||
graph_utils::MoveAllNodeInputEdges(graph, *input_node, node);
|
||||
graph.RemoveNode(input_node->Index());
|
||||
|
||||
if (node.GetAttributes().count("output_type") == 0) {
|
||||
node.AddAttribute("output_type", static_cast<int64_t>(onnx::TensorProto_DataType_FLOAT));
|
||||
}
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
@Class CastSceLossFusion
|
||||
Fuse Cast + SoftmaxCrossEntropyLossInternal to SoftmaxCrossEntropyLossInternal.
|
||||
*/
|
||||
class CastSceLossFusion : public GraphTransformer {
|
||||
public:
|
||||
explicit CastSceLossFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
|
||||
: GraphTransformer("CastSceLossFusion", compatible_execution_providers) {
|
||||
}
|
||||
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -52,6 +52,7 @@
|
|||
#include "orttraining/core/framework/distributed_run_context.h"
|
||||
#include "orttraining/core/optimizer/batchnorm_replacement.h"
|
||||
#include "orttraining/core/optimizer/bitmask_dropout_replacement.h"
|
||||
#include "orttraining/core/optimizer/cast_sce_loss_fusion.h"
|
||||
#include "orttraining/core/optimizer/concat_replacement.h"
|
||||
#include "orttraining/core/optimizer/graph_transformer_registry.h"
|
||||
#include "orttraining/core/optimizer/gru_replacement.h"
|
||||
|
|
@ -188,6 +189,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
config.propagate_cast_ops_config.allow,
|
||||
cuda_execution_provider));
|
||||
}
|
||||
transformers.emplace_back(std::make_unique<CastSceLossFusion>(compatible_eps));
|
||||
|
||||
if (config.enable_compute_optimizer) {
|
||||
transformers.emplace_back(std::make_unique<UpStreamGatherGraphTransformer>(compatible_eps));
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@
|
|||
#include "test/util/include/asserts.h"
|
||||
#include "orttraining/test/optimizer/horizontal_parallel_test_utils.h"
|
||||
#include "orttraining/core/session/training_session.h"
|
||||
#include "orttraining/core/optimizer/cast_sce_loss_fusion.h"
|
||||
#include "orttraining/core/optimizer/loss_rewriter.h"
|
||||
#include "orttraining/core/optimizer/bias_softmax_dropout_fusion.h"
|
||||
#include "orttraining/core/optimizer/qdq_fusion.h"
|
||||
|
|
@ -518,6 +519,22 @@ TEST_F(GraphTransformationTests, SceLossGradBiasFusion_Invalid) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, CastSceLossFusion) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "computation_reduction/reshape/mlm_bert_e2e.onnx";
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_TRUE(Model::Load(model_uri, model, nullptr, *logger_).IsOK());
|
||||
Graph& graph = model->MainGraph();
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_EQ(op_to_count["Cast"], 10);
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{1};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<CastSceLossFusion>(), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_EQ(op_to_count["Cast"], 9);
|
||||
}
|
||||
|
||||
Node* GetNodeByName(Graph& graph, std::string node_name) {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
|
|
|||
Loading…
Reference in a new issue