Fix GeluRecompute for 2 inputs case. (#5573)

* Add test for FastGelu + GeluRecompute.

* Fix GeluRecompute for 2 inputs case.

* Fix test for BiasGelu + GeluRecompute.

* Copy all inputs to Gelu, not just 2.

* Move GeluRecompute test to training-specific file.
This commit is contained in:
Sergii Dymchenko 2020-10-29 00:07:13 -07:00 committed by GitHub
parent b85e7a19ea
commit 2e1fa3ccb7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 68 additions and 2 deletions

View file

@ -108,7 +108,7 @@ class Node {
/** Gets the domain of the OperatorSet that specifies the operator returned by #OpType. */
const std::string& Domain() const noexcept { return domain_; }
/** Gets the Node's exection priority.
/** Gets the Node's execution priority.
@remarks Lower value means higher priority */
int Priority() const noexcept { return priority_; };

View file

@ -0,0 +1,36 @@
import onnx
from onnx import helper
from onnx import TensorProto
graph = helper.make_graph(
[ # nodes
# Add node before Gelu
helper.make_node("Add", ["A", "B"], ["add0_out"], "add0"),
# Gelu subgraph
helper.make_node("Div", ["add0_out", "div_const"], ["div_out"], "div"),
helper.make_node("Mul", ["add0_out", "mul_const"], ["mul_out"], "mul0"),
helper.make_node("Erf", ["div_out"], ["erf_out"], "erf"),
helper.make_node("Add", ["erf_out", "add_const"], ["add1_out"], "add1"),
helper.make_node("Mul", ["mul_out", "add1_out"], ["C"], "mul1"),
# MatMul node after Gelu for recompute
helper.make_node("MatMul", ["X", "C"], ["D"], "matmul"),
],
"Gelu_Add_Fusion_Recompute", #name
[ # inputs
helper.make_tensor_value_info('A', TensorProto.FLOAT, ['unk_1', 'unk_2', 3072]),
helper.make_tensor_value_info('B', TensorProto.FLOAT, [3072]),
helper.make_tensor_value_info('X', TensorProto.FLOAT, ['unk_5', 'unk_6', 3072]),
],
[ # outputs
helper.make_tensor_value_info('D', TensorProto.FLOAT, ['unk_3', 'unk_4', 'unk_5']),
],
[ # initializers
helper.make_tensor('div_const', TensorProto.FLOAT, [], [1.4142135381698608]),
helper.make_tensor('mul_const', TensorProto.FLOAT, [], [0.5]),
helper.make_tensor('add_const', TensorProto.FLOAT, [], [1]),
])
model = helper.make_model(graph)
onnx.save(model, r'bias_gelu_fusion_recompute.onnx')

View file

@ -44,7 +44,7 @@ Status GeluRecompute::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*
Node& recompute_node = graph.AddNode(node.Name() + "_recompute",
node.OpType(),
"Recompute of " + node.Name(),
{node.MutableInputDefs()[0]},
node.MutableInputDefs(),
{&recomputed_output},
&node.GetAttributes(),
node.Domain());

View file

@ -10,12 +10,15 @@
#include "gtest/gtest.h"
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/utils.h"
#include "core/optimizer/bias_gelu_fusion.h"
#include "core/optimizer/gelu_fusion.h"
#include "core/optimizer/dropout_elimination.h"
#include "orttraining/core/optimizer/bias_dropout_fusion.h"
#include "orttraining/core/optimizer/gist_encode_decode.h"
#include "orttraining/core/optimizer/nonzero_shape_setter.h"
#include "orttraining/core/optimizer/megatron_transformer.h"
#include "orttraining/core/optimizer/concat_replacement.h"
#include "orttraining/core/optimizer/localized_recompute.h"
#include "test/optimizer/graph_transform_test_fixture.h"
#include "test/util/include/default_providers.h"
#include "test/util/include/asserts.h"
@ -417,6 +420,33 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank1) {
}
}
TEST_F(GraphTransformationTests, BiasGeluRecomputeTest) {
auto model_uri = MODEL_FOLDER "fusion/bias_gelu_fusion_recompute.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<GeluFusion>(), TransformerLevel::Level2);
graph_transformation_mgr.Register(onnxruntime::make_unique<BiasGeluFusion>(), TransformerLevel::Level2);
graph_transformation_mgr.Register(onnxruntime::make_unique<GeluRecompute>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
ASSERT_TRUE(ret.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Div"] == 0);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Erf"] == 0);
ASSERT_TRUE(op_to_count["Mul"] == 0);
ASSERT_TRUE(op_to_count["com.microsoft.Gelu"] == 0);
ASSERT_TRUE(op_to_count["com.microsoft.BiasGelu"] == 2);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "com.microsoft.BiasGelu") {
ASSERT_TRUE(node.InputDefs().size() == 2);
}
}
}
// We only tested on CUDA run.
#if defined(USE_CUDA)
TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {