mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
Fix GeluRecompute for 2 inputs case. (#5573)
* Add test for FastGelu + GeluRecompute. * Fix GeluRecompute for 2 inputs case. * Fix test for BiasGelu + GeluRecompute. * Copy all inputs to Gelu, not just 2. * Move GeluRecompute test to training-specific file.
This commit is contained in:
parent
b85e7a19ea
commit
2e1fa3ccb7
5 changed files with 68 additions and 2 deletions
|
|
@ -108,7 +108,7 @@ class Node {
|
|||
/** Gets the domain of the OperatorSet that specifies the operator returned by #OpType. */
|
||||
const std::string& Domain() const noexcept { return domain_; }
|
||||
|
||||
/** Gets the Node's exection priority.
|
||||
/** Gets the Node's execution priority.
|
||||
@remarks Lower value means higher priority */
|
||||
int Priority() const noexcept { return priority_; };
|
||||
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/bias_gelu_fusion_recompute.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/bias_gelu_fusion_recompute.onnx
vendored
Normal file
Binary file not shown.
36
onnxruntime/test/testdata/transform/fusion/bias_gelu_matmul_gen.py
vendored
Normal file
36
onnxruntime/test/testdata/transform/fusion/bias_gelu_matmul_gen.py
vendored
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import TensorProto
|
||||
|
||||
graph = helper.make_graph(
|
||||
[ # nodes
|
||||
# Add node before Gelu
|
||||
helper.make_node("Add", ["A", "B"], ["add0_out"], "add0"),
|
||||
|
||||
# Gelu subgraph
|
||||
helper.make_node("Div", ["add0_out", "div_const"], ["div_out"], "div"),
|
||||
helper.make_node("Mul", ["add0_out", "mul_const"], ["mul_out"], "mul0"),
|
||||
helper.make_node("Erf", ["div_out"], ["erf_out"], "erf"),
|
||||
helper.make_node("Add", ["erf_out", "add_const"], ["add1_out"], "add1"),
|
||||
helper.make_node("Mul", ["mul_out", "add1_out"], ["C"], "mul1"),
|
||||
|
||||
# MatMul node after Gelu for recompute
|
||||
helper.make_node("MatMul", ["X", "C"], ["D"], "matmul"),
|
||||
],
|
||||
"Gelu_Add_Fusion_Recompute", #name
|
||||
[ # inputs
|
||||
helper.make_tensor_value_info('A', TensorProto.FLOAT, ['unk_1', 'unk_2', 3072]),
|
||||
helper.make_tensor_value_info('B', TensorProto.FLOAT, [3072]),
|
||||
helper.make_tensor_value_info('X', TensorProto.FLOAT, ['unk_5', 'unk_6', 3072]),
|
||||
],
|
||||
[ # outputs
|
||||
helper.make_tensor_value_info('D', TensorProto.FLOAT, ['unk_3', 'unk_4', 'unk_5']),
|
||||
],
|
||||
[ # initializers
|
||||
helper.make_tensor('div_const', TensorProto.FLOAT, [], [1.4142135381698608]),
|
||||
helper.make_tensor('mul_const', TensorProto.FLOAT, [], [0.5]),
|
||||
helper.make_tensor('add_const', TensorProto.FLOAT, [], [1]),
|
||||
])
|
||||
|
||||
model = helper.make_model(graph)
|
||||
onnx.save(model, r'bias_gelu_fusion_recompute.onnx')
|
||||
|
|
@ -44,7 +44,7 @@ Status GeluRecompute::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*
|
|||
Node& recompute_node = graph.AddNode(node.Name() + "_recompute",
|
||||
node.OpType(),
|
||||
"Recompute of " + node.Name(),
|
||||
{node.MutableInputDefs()[0]},
|
||||
node.MutableInputDefs(),
|
||||
{&recomputed_output},
|
||||
&node.GetAttributes(),
|
||||
node.Domain());
|
||||
|
|
|
|||
|
|
@ -10,12 +10,15 @@
|
|||
#include "gtest/gtest.h"
|
||||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/optimizer/bias_gelu_fusion.h"
|
||||
#include "core/optimizer/gelu_fusion.h"
|
||||
#include "core/optimizer/dropout_elimination.h"
|
||||
#include "orttraining/core/optimizer/bias_dropout_fusion.h"
|
||||
#include "orttraining/core/optimizer/gist_encode_decode.h"
|
||||
#include "orttraining/core/optimizer/nonzero_shape_setter.h"
|
||||
#include "orttraining/core/optimizer/megatron_transformer.h"
|
||||
#include "orttraining/core/optimizer/concat_replacement.h"
|
||||
#include "orttraining/core/optimizer/localized_recompute.h"
|
||||
#include "test/optimizer/graph_transform_test_fixture.h"
|
||||
#include "test/util/include/default_providers.h"
|
||||
#include "test/util/include/asserts.h"
|
||||
|
|
@ -417,6 +420,33 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank1) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, BiasGeluRecomputeTest) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/bias_gelu_fusion_recompute.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<GeluFusion>(), TransformerLevel::Level2);
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<BiasGeluFusion>(), TransformerLevel::Level2);
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<GeluRecompute>(), TransformerLevel::Level2);
|
||||
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Div"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Erf"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Mul"] == 0);
|
||||
ASSERT_TRUE(op_to_count["com.microsoft.Gelu"] == 0);
|
||||
ASSERT_TRUE(op_to_count["com.microsoft.BiasGelu"] == 2);
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "com.microsoft.BiasGelu") {
|
||||
ASSERT_TRUE(node.InputDefs().size() == 2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We only tested on CUDA run.
|
||||
#if defined(USE_CUDA)
|
||||
TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue