GPT2 Gelu Fusion & Test (#3009)

* GPT2 Gelu Fusion & Test

* change header path

* Refine code & add missing test onnx file

* Fix builds & refine float/double/fp16 compare.

* Fix builds

* Add Bias Check and UTs

* Fix build and uts

* Fuse with second formula & test

* minor change

* disable FastGelu to see whether the builds can pass

* Verify where is wrong

* disable for debugging

* Revert "disable for debugging"

This reverts commit 535c0817fb36fb95a75773a7f00c8b969dd5362c.

* Revert "Verify where is wrong"

This reverts commit ffc43ec1d136636ba2cee30df49f563a75e84676.

* disable the transformer for inference currently

* Enable FastGeluFusion and fix segement fault when run bertsquad10.onnx test

* Add more Unit tests convering Gelu subgraph use graph input/output

(cherry picked from commit 0739ab985240c6d9acdb8f0afd40c5fb316166af)

* Mode Bias Fusion in BiasGelu.cc

Co-authored-by: Changming Sun <chasun@microsoft.com>
This commit is contained in:
pengwa 2020-02-21 18:25:43 +08:00 committed by GitHub
parent 932ecaea34
commit 92b8a7a2be
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 837 additions and 9 deletions

View file

@ -64,20 +64,28 @@ Status BiasGelu::ApplyImpl(Graph& graph, bool& modified, int graph_level, const
}
const Node& next_node = (*next_node_itr);
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {1}, kMSDomain) ||
if (!(graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {1}, kMSDomain) ||
graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "FastGelu", {1}, kMSDomain)) ||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
continue;
}
bool is_fast_gelu = next_node.OpType().compare("FastGelu") == 0;
if (is_fast_gelu && next_node.InputDefs().size() > 1) {
continue;
}
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
continue;
}
Node& add_node = node;
Node& gelu_node = const_cast<Node&>(next_node);
std::string op_type = "BiasGelu";
if (is_fast_gelu) op_type = "FastGelu";
Node& gelu_add_fusion_node = graph.AddNode(graph.GenerateNodeName("BiasGelu"),
"BiasGelu",
Node& gelu_add_fusion_node = graph.AddNode(graph.GenerateNodeName(op_type),
op_type,
"fused Add and Gelu",
gelu_input,
{},

View file

@ -0,0 +1,244 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/optimizer/initializer.h"
#include "core/optimizer/fast_gelu_fusion.h"
#include "core/optimizer/utils.h"
#include "core/graph/graph_utils.h"
#include "float.h"
#include <deque>
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {
// FastGelu supports limited data types.
static std::vector<std::string> supported_data_types{"tensor(float16)", "tensor(float)"};
static bool CheckNode(const Node& node, const std::string& op_name, int32_t opset_version, ProviderType provider,
bool require_single_output=false){
return graph_utils::IsSupportedOptypeVersionAndDomain(node, op_name, {opset_version}) &&
node.GetExecutionProviderType() == provider &&
optimizer_utils::IsSupportedDataType(node, supported_data_types) &&
(!require_single_output || node.GetOutputEdgesCount() == 1);
}
MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
std::vector<std::reference_wrapper<Node>>& nodes_to_fuse) const {
MatchResult matchResult{false, nullptr, nullptr};
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul1_node, "Mul", {7}) ||
!graph_utils::IsSupportedProvider(mul1_node, GetCompatibleExecutionProviders()) ||
mul1_node.GetOutputEdgesCount() != 1 ||
!optimizer_utils::IsSupportedDataType(mul1_node, supported_data_types)) {
return matchResult;
}
int32_t input_index = -1;
const float mul_val = 0.044715f;
for (auto i = 0; i < 2; i++) {
if (optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul1_node.InputDefs()[i]), mul_val, true)){
input_index = i;
break;
}
}
if (input_index == -1) return matchResult;
NodeArg* gelu_without_bias_input_arg = mul1_node.MutableInputDefs()[(input_index + 1) % 2];
nodes_to_fuse.push_back(mul1_node);
Node& mul2_node = *graph.GetNode(mul1_node.OutputNodesBegin()->Index());
input_index = optimizer_utils::IndexOfNodeInput(mul2_node, *mul1_node.MutableOutputDefs()[0]);
if (!CheckNode(mul2_node, "Mul", 7, mul1_node.GetExecutionProviderType(), true) ||
mul2_node.MutableInputDefs()[(input_index + 1) % 2]->Name() != gelu_without_bias_input_arg->Name()) {
return matchResult;;
}
nodes_to_fuse.push_back(mul2_node);
Node& add1_node = *graph.GetNode(mul2_node.OutputNodesBegin()->Index());
input_index = optimizer_utils::IndexOfNodeInput(add1_node, *mul2_node.MutableOutputDefs()[0]);
if (!CheckNode(add1_node, "Add", 7, mul1_node.GetExecutionProviderType(), true) ||
!optimizer_utils::IsInitializerWithExpectedValue(graph, *(add1_node.InputDefs()[(input_index + 1) % 2]), 1.0f, true)) {
return matchResult;
}
nodes_to_fuse.push_back(add1_node);
Node& mul3_node = *graph.GetNode(add1_node.OutputNodesBegin()->Index());
if (!CheckNode(mul3_node, "Mul", 7, mul1_node.GetExecutionProviderType(), true)) {
return matchResult;
}
nodes_to_fuse.push_back(mul3_node);
input_index = optimizer_utils::IndexOfNodeInput(mul3_node, *add1_node.MutableOutputDefs()[0]);
const Node* p_mul3_input_node = graph_utils::GetInputNode(mul3_node, (input_index + 1) % 2);
if (p_mul3_input_node == nullptr) return matchResult;
Node& mul4_node = const_cast<Node&>(*p_mul3_input_node);
if (!CheckNode(mul4_node, "Mul", 7, mul1_node.GetExecutionProviderType(), true)) {
return matchResult;
}
input_index = -1;
const float mul4_val = 0.7978845834732056f;
for (auto i = 0; i < 2; i++) {
if (optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul4_node.InputDefs()[i]), mul4_val, true)){
input_index = i;
break;
}
}
if (input_index == -1 || mul4_node.InputDefs()[(input_index + 1) % 2]->Name() != gelu_without_bias_input_arg->Name())
return matchResult;
nodes_to_fuse.push_back(mul4_node);
matchResult.matched = true;
matchResult.gelu_without_bias_input_arg = gelu_without_bias_input_arg;
matchResult.tanh_input_node = &mul3_node;
return matchResult;
}
MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
std::vector<std::reference_wrapper<Node>>& nodes_to_fuse) const {
MatchResult matchResult{false, nullptr, nullptr};
if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow1_node, "Pow", {7}) ||
!graph_utils::IsSupportedProvider(pow1_node, GetCompatibleExecutionProviders()) ||
pow1_node.GetOutputEdgesCount() != 1 ||
!optimizer_utils::IsSupportedDataType(pow1_node, supported_data_types)) {
return matchResult;
}
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(pow1_node.InputDefs()[1]), 3.0f, true)){
return matchResult;
}
NodeArg* pow_input_arg = pow1_node.MutableInputDefs()[0];
nodes_to_fuse.push_back(pow1_node);
Node& mul1_node = *graph.GetNode(pow1_node.OutputNodesBegin()->Index());
auto input_index = optimizer_utils::IndexOfNodeInput(mul1_node, *pow1_node.MutableOutputDefs()[0]);
if (!CheckNode(mul1_node, "Mul", 7, pow1_node.GetExecutionProviderType(), true) ||
!optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul1_node.InputDefs()[(input_index + 1) % 2]),
0.044714998453855515f, true)) {
return matchResult;
}
nodes_to_fuse.push_back(mul1_node);
Node& add1_node = *graph.GetNode(mul1_node.OutputNodesBegin()->Index());
input_index = optimizer_utils::IndexOfNodeInput(add1_node, *mul1_node.MutableOutputDefs()[0]);
if (!CheckNode(add1_node, "Add", 7, pow1_node.GetExecutionProviderType(), true) ||
add1_node.MutableInputDefs()[(input_index + 1) % 2]->Name() != pow_input_arg->Name()) {
return matchResult;
}
nodes_to_fuse.push_back(add1_node);
Node& mul2_node = *graph.GetNode(add1_node.OutputNodesBegin()->Index());
input_index = optimizer_utils::IndexOfNodeInput(mul2_node, *add1_node.MutableOutputDefs()[0]);
if (!CheckNode(mul2_node, "Mul", 7, pow1_node.GetExecutionProviderType(), true) ||
!optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul2_node.InputDefs()[(input_index + 1) % 2]),
0.7978845834732056f, true)) {
return matchResult;
}
nodes_to_fuse.push_back(mul2_node);
matchResult.matched = true;
matchResult.gelu_without_bias_input_arg = pow_input_arg;
matchResult.tanh_input_node = &mul2_node;
return matchResult;
}
Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
for (auto node_index : node_topology_list) {
auto* p_node = graph.GetNode(node_index);
if (p_node == nullptr)
continue;
Node& node = *p_node;
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
std::vector<std::reference_wrapper<Node>> nodes_to_fuse;
MatchResult matchRet = CheckFirstFormula(graph, node, nodes_to_fuse);
if (!matchRet.matched) {
nodes_to_fuse.clear();
matchRet = CheckSecondFormula(graph, node, nodes_to_fuse);
if(!matchRet.matched) continue;
};
Node& tanh_node = *graph.GetNode(matchRet.tanh_input_node->OutputNodesBegin()->Index());
if (!CheckNode(tanh_node, "Tanh", 6, node.GetExecutionProviderType(), true)) {
continue;
}
Node& add2_node = *graph.GetNode(tanh_node.OutputNodesBegin()->Index());
if (!CheckNode(add2_node, "Add", 7, node.GetExecutionProviderType(), true)) {
continue;
}
auto input_index = optimizer_utils::IndexOfNodeInput(add2_node, *tanh_node.MutableOutputDefs()[0]);
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(add2_node.InputDefs()[(input_index + 1) % 2]), 1.0f, true)) {
continue;
}
Node& mul5_node = *graph.GetNode(add2_node.OutputNodesBegin()->Index());
// This is the output of the Gelu subgraph, we don't need check it has single edge.
if (!CheckNode(mul5_node, "Mul", 7, node.GetExecutionProviderType(), false)) {
continue;
}
// ingnore the transformer if Gelu's output is the graph's output.
if (!graph.GetNodeOutputsInGraphOutputs(mul5_node).empty()) {
continue;
}
input_index = optimizer_utils::IndexOfNodeInput(mul5_node, *add2_node.MutableOutputDefs()[0]);
const Node* p_mul5_input_node = graph_utils::GetInputNode(mul5_node, (input_index + 1) % 2);
if (p_mul5_input_node == nullptr) continue;
Node& mul6_node = const_cast<Node&>(*p_mul5_input_node);
if (!CheckNode(mul6_node, "Mul", 7, node.GetExecutionProviderType(), false)) {
continue;
}
input_index = -1;
for (auto i = 0; i < 2; i++) {
if (optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul6_node.InputDefs()[i]), 0.5f, true)){
input_index = i;
break;
}
}
if (input_index == -1 || mul6_node.InputDefs()[(input_index + 1) % 2]->Name() != matchRet.gelu_without_bias_input_arg->Name())
continue;
std::vector<NodeArg*> gelu_input_defs{matchRet.gelu_without_bias_input_arg};
nodes_to_fuse.insert(nodes_to_fuse.end(), {tanh_node, add2_node, mul6_node, mul5_node});
auto type_info = *node.MutableOutputDefs()[0]->TypeAsProto();
auto& shape_output = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("fast_gelu_output"), &type_info);
Node& fast_gelu_node = graph.AddNode(graph.GenerateNodeName("GPT2Gelu"),
"FastGelu",
"fused GPT2Gelu subgraphs ",
gelu_input_defs,
{&shape_output}, {}, kMSDomain);
// assign provider to this new node, provider should be same as the provider for old node.
fast_gelu_node.SetExecutionProviderType(node.GetExecutionProviderType());
// move input edges to node (first in list) across to the fast_gelu_node.
// move output definitions and output edges from mul5_node (last in list) to fast_gelu_node.
// remove all nodes.
graph_utils::FinalizeNodeFusion(graph, nodes_to_fuse, fast_gelu_node);
modified = true;
}
return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/graph_transformer.h"
namespace onnxruntime {
struct MatchResult {
public:
bool matched;
NodeArg* gelu_without_bias_input_arg; // The Gelu input arg if not considering bias node.
Node* tanh_input_node;
};
/**
@Class FastGeluFusion
Rewrite graph fusing Gelu activation subgraph to a single Gelu node.
The formula corresponding to Gelu activation subgraph:
x * 0.5 * (1.0 + tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) or
x * 0.5 * (1.0 + tanh((sqrt(2 / pi) * (x + 0.044715 * pow(x, 3))))), where x is the input.
*/
class FastGeluFusion : public GraphTransformer {
public:
FastGeluFusion(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
: GraphTransformer("FastGeluFusion", compatible_execution_providers) {}
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
MatchResult CheckFirstFormula(Graph& graph, Node& node, std::vector<std::reference_wrapper<Node>>& nodes_to_fuse) const;
MatchResult CheckSecondFormula(Graph& graph, Node& nodes, std::vector<std::reference_wrapper<Node>>& nodes_to_fuse) const;
};
} // namespace onnxruntime

View file

@ -21,6 +21,7 @@
#include "core/optimizer/bias_gelu_fusion.h"
#include "core/optimizer/gelu_fusion.h"
#include "core/optimizer/gelu_approximation.h"
#include "core/optimizer/fast_gelu_fusion.h"
#include "core/optimizer/layer_norm_fusion.h"
#include "core/optimizer/skip_layer_norm_fusion.h"
#include "core/optimizer/embed_layer_norm_fusion.h"
@ -135,6 +136,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
std::unordered_set<std::string> cuda_execution_providers = {onnxruntime::kCudaExecutionProvider};
transformers.emplace_back(onnxruntime::make_unique<GeluApproximation>(cuda_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<FastGeluFusion>(cuda_execution_providers));
#endif
} break;

View file

@ -36,6 +36,8 @@ bool IsInitializerWithExpectedValue(const Graph& graph, const NodeArg& input_arg
return false;
}
const float atol = 1e-8f;
const float rtol = 1e-5f;
const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr;
if (is_constant) {
tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name());
@ -51,20 +53,28 @@ bool IsInitializerWithExpectedValue(const Graph& graph, const NodeArg& input_arg
const auto data_type = tensor_proto->data_type();
if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
const float* val = init_const->data<float>();
float diff = std::abs(val[0] - static_cast<float>(expected_value));
if (diff > FLT_EPSILON) {
if (std::isnan(val[0]) || std::isinf(val[0])) return false;
float diff = std::abs(val[0] - expected_value);
if (diff > (atol + rtol * std::abs(expected_value))) {
return false;
}
} else if (data_type == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE) {
const double* val = init_const->data<double>();
double diff = std::abs(val[0] - static_cast<double>(expected_value));
if (diff > DBL_EPSILON) {
if (std::isnan(val[0]) || std::isinf(val[0])) return false;
const double expected_val = static_cast<double>(expected_value);
double diff = std::abs(val[0] - expected_val);
if (diff > (atol + rtol * std::abs(expected_value))) {
return false;
}
} else if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
const MLFloat16* val = init_const->data<MLFloat16>();
float diff = std::abs(math::halfToFloat(val[0].val) - math::halfToFloat(math::floatToHalf(expected_value)));
if (diff > FLT_EPSILON) {
const float flt_val = math::halfToFloat(val[0].val);
if (std::isnan(flt_val) || std::isinf(flt_val)) return false;
const float expected_val = math::halfToFloat(math::floatToHalf(expected_value));
float diff = std::abs(flt_val - expected_val);
if (diff > (atol + rtol * std::abs(expected_value))) {
return false;
}
} else {
@ -176,5 +186,27 @@ bool IsShapeKnownOnAllDims(const NodeArg& node_arg, int expected_dim_size) {
return true;
}
int32_t IndexOfNodeInput(const Node& node, const NodeArg& node_arg) {
int32_t index = 0;
for (auto& input_arg : node.InputDefs()) {
if (input_arg->Name().compare(node_arg.Name()) == 0) {
return index;
}
index++;
}
return -1;
}
bool IsSupportedDataType(const Node& node, const std::vector<std::string>& supported_data_types) {
for (const auto& input_arg : node.InputDefs()) {
if (std::find(supported_data_types.begin(), supported_data_types.end(),
*(input_arg->Type())) == supported_data_types.end()) {
return false;
}
}
return true;
}
} // namespace optimizer_utils
} // namespace onnxruntime

View file

@ -52,5 +52,15 @@ bool ValidateShape(const NodeArg& node_arg, const std::initializer_list<int64_t>
*/
bool IsShapeKnownOnAllDims(const NodeArg& node_arg, int expected_dim_size);
/** Get the index of node_arg among the node's all inputs.
@remarks -1 when node_arg is not in node's inputs..
*/
int32_t IndexOfNodeInput(const Node& node, const NodeArg& node_arg);
/** Check whether node's input data types are in supported data type list.
@param supported_data_types specify the supported data types.
*/
bool IsSupportedDataType(const Node& node, const std::vector<std::string>& supported_data_types);
} // namespace optimizer_utils
} // namespace onnxruntime

View file

@ -33,6 +33,7 @@
#include "core/optimizer/unsqueeze_elimination.h"
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/attention_fusion.h"
#include "core/optimizer/fast_gelu_fusion.h"
#include "core/optimizer/utils.h"
#include "core/platform/env.h"
#include "core/util/math.h"
@ -1201,6 +1202,163 @@ TEST(GraphTransformationTests, GeluApproximation_Gelu_Add_MatMul) {
EXPECT_EQ(op_to_count["FastGelu"], 1);
}
TEST(GraphTransformationTests, FastGeluFusionTest) {
auto model_uri = MODEL_FOLDER "fusion/fast_gelu.onnx";
std::shared_ptr<Model> p_model;
auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(load_ret.IsOK());
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<FastGeluFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(ret.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Identity"] == 2);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Tanh"] == 0);
ASSERT_TRUE(op_to_count["Mul"] == 0);
ASSERT_TRUE(op_to_count["FastGelu"] == 1);
}
TEST(GraphTransformationTests, FastGeluUseGraphInputFusionTest) {
auto model_uri = MODEL_FOLDER "fusion/fast_gelu_use_graph_input.onnx";
std::shared_ptr<Model> p_model;
auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(load_ret.IsOK());
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<FastGeluFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(ret.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Tanh"] == 0);
ASSERT_TRUE(op_to_count["Mul"] == 0);
ASSERT_TRUE(op_to_count["FastGelu"] == 1);
}
TEST(GraphTransformationTests, FastGeluWithBiasFusionTest) {
auto model_uri = MODEL_FOLDER "fusion/fast_gelu_with_bias.onnx";
std::shared_ptr<Model> p_model;
auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(load_ret.IsOK());
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<FastGeluFusion>(), TransformerLevel::Level2);
graph_transformation_mgr.Register(onnxruntime::make_unique<BiasGelu>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(ret.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Tanh"] == 0);
ASSERT_TRUE(op_to_count["Mul"] == 0);
ASSERT_TRUE(op_to_count["FastGelu"] == 1);
}
TEST(GraphTransformationTests, FastGeluWithBiasUseGraphInputFusionTest) {
auto model_uri = MODEL_FOLDER "fusion/fast_gelu_with_bias_use_graph_input.onnx";
std::shared_ptr<Model> p_model;
auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(load_ret.IsOK());
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<FastGeluFusion>(), TransformerLevel::Level2);
graph_transformation_mgr.Register(onnxruntime::make_unique<BiasGelu>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(ret.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Tanh"] == 0);
ASSERT_TRUE(op_to_count["Mul"] == 0);
ASSERT_TRUE(op_to_count["FastGelu"] == 1);
}
TEST(GraphTransformationTests, FastGeluFusionTest2) {
auto model_uri = MODEL_FOLDER "fusion/fast_gelu2.onnx";
std::shared_ptr<Model> p_model;
auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(load_ret.IsOK());
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<FastGeluFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(ret.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Tanh"] == 0);
ASSERT_TRUE(op_to_count["Mul"] == 0);
ASSERT_TRUE(op_to_count["FastGelu"] == 1);
}
TEST(GraphTransformationTests, FastGeluUseGraphInputFusionTest2) {
auto model_uri = MODEL_FOLDER "fusion/fast_gelu2_use_graph_input.onnx";
std::shared_ptr<Model> p_model;
auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(load_ret.IsOK());
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<FastGeluFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(ret.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Tanh"] == 0);
ASSERT_TRUE(op_to_count["Mul"] == 0);
ASSERT_TRUE(op_to_count["FastGelu"] == 1);
}
TEST(GraphTransformationTests, FastGeluWithBiasFusionTest2) {
auto model_uri = MODEL_FOLDER "fusion/fast_gelu2_with_bias.onnx";
std::shared_ptr<Model> p_model;
auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(load_ret.IsOK());
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<FastGeluFusion>(), TransformerLevel::Level2);
graph_transformation_mgr.Register(onnxruntime::make_unique<BiasGelu>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(ret.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Tanh"] == 0);
ASSERT_TRUE(op_to_count["Mul"] == 0);
ASSERT_TRUE(op_to_count["FastGelu"] == 1);
}
TEST(GraphTransformationTests, FastGeluWithBiasUseGraphInputFusionTest2) {
auto model_uri = MODEL_FOLDER "fusion/fast_gelu2_with_bias_use_graph_input.onnx";
std::shared_ptr<Model> p_model;
auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(load_ret.IsOK());
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<FastGeluFusion>(), TransformerLevel::Level2);
graph_transformation_mgr.Register(onnxruntime::make_unique<BiasGelu>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(ret.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Tanh"] == 0);
ASSERT_TRUE(op_to_count["Mul"] == 0);
ASSERT_TRUE(op_to_count["FastGelu"] == 1);
}
TEST(GraphTransformationTests, LayerNormFusionTest) {
auto model_uri = MODEL_FOLDER "fusion/layer_norm.onnx";
std::shared_ptr<Model> p_model;

Binary file not shown.

View file

@ -0,0 +1,172 @@
import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto
from onnx import numpy_helper
import numpy as np
# Gelu formula: x * 0.5 * (1.0 + tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
has_bias = True # change it to True to generate fast_gelu_with_bias.onnx
gelu_use_graph_input = True # change it to False to let Gelu don't have graph inputs as inputs.
X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", 64])
Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "seqlen", 64])
bias_np_vals = (0.01 * np.arange(64)).astype(np.float32).reshape((64))
bias_initializer = numpy_helper.from_array(bias_np_vals, "input_bias")
a_weight_np_vals = np.asarray([0.044714998453855515]).astype(np.float32).reshape(())
a_weight_initializer = numpy_helper.from_array(a_weight_np_vals, "mul1_init")
b_weight_np_vals = np.asarray([0.7978845834732056]).astype(np.float32).reshape(())
b_weight_initializer = numpy_helper.from_array(b_weight_np_vals, "mul2_init")
c_weight_np_vals = np.asarray([0.5]).astype(np.float32).reshape(())
c_weight_initializer = numpy_helper.from_array(c_weight_np_vals, "mul3_init")
a_bias_np_vals = np.asarray([1.0]).astype(np.float32).reshape(())
a_bias_initializer = numpy_helper.from_array(a_bias_np_vals, "add1_init")
b_bias_np_vals = np.asarray([1.0]).astype(np.float32).reshape(())
b_bias_initializer = numpy_helper.from_array(b_bias_np_vals, "add2_init")
nodes = []
gelu_input = "input"
if not gelu_use_graph_input:
leading_identity = helper.make_node(
'Identity',
[gelu_input],
['identity_leading'],
name="identity_leading"
)
gelu_input = "identity_leading"
nodes.append(leading_identity)
mul_input_name = gelu_input
if has_bias:
add0 = helper.make_node(
'Add',
[gelu_input, bias_initializer.name],
['add0'],
name="add0"
)
mul_input_name = "add0"
nodes.append(add0)
mul1 = helper.make_node(
'Mul',
[mul_input_name, a_weight_initializer.name],
['mul1'],
name="mul1"
)
nodes.append(mul1)
mul2 = helper.make_node(
'Mul',
[mul_input_name, 'mul1'],
['mul2'],
name="mul2"
)
nodes.append(mul2)
add1 = helper.make_node(
'Add',
['mul2', a_bias_initializer.name],
['add1'],
name="add1"
)
nodes.append(add1)
mul3 = helper.make_node(
'Mul',
[mul_input_name, b_weight_initializer.name],
['mul3'],
name="mul3"
)
nodes.append(mul3)
mul4 = helper.make_node(
'Mul',
['mul3', 'add1'],
['mul4'],
name="mul4"
)
nodes.append(mul4)
tanh = helper.make_node(
'Tanh',
['mul4'],
['tanh'],
name="tanh"
)
nodes.append(tanh)
add2 = helper.make_node(
'Add',
['tanh', b_bias_initializer.name],
['add2'],
name="add2"
)
nodes.append(add2)
mul5 = helper.make_node(
'Mul',
[mul_input_name, c_weight_initializer.name],
['mul5'],
name="mul5"
)
nodes.append(mul5)
mul6 = helper.make_node(
'Mul',
['mul5', 'add2'],
['mul6'],
name="mul6"
)
ending_identity = helper.make_node(
'Identity',
['mul6'],
['output'],
name="identity_ending"
)
nodes.extend([mul6, ending_identity])
initializers = []
if has_bias:
initializers = [bias_initializer]
initializers.extend([a_weight_initializer, a_bias_initializer, b_weight_initializer, b_bias_initializer, c_weight_initializer])
# Create the graph (GraphProto)
graph_def = helper.make_graph(
nodes,
'test-model',
[X],
[Y],
initializers
)
opsets = []
onnxdomain = OperatorSetIdProto()
onnxdomain.version = 10
onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
opsets.append(onnxdomain)
msdomain = OperatorSetIdProto()
msdomain.version = 1
msdomain.domain = "com.microsoft"
opsets.append(msdomain)
kwargs={}
kwargs["opset_imports"] = opsets
model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs)
file_name = "fast_gelu"
if has_bias:
file_name += "_with_bias"
if gelu_use_graph_input:
file_name += "_use_graph_input"
onnx.save(model_def, file_name + ".onnx")

Binary file not shown.

View file

@ -0,0 +1,163 @@
import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto
from onnx import numpy_helper
import numpy as np
# Gelu formula: x * 0.5 * (1.0 + tanh((sqrt(2 / pi) * (x + 0.044715 * pow(x, 3)))))
has_bias = False # change it to True to generate fast_gelu_openai_with_bias.onnx
gelu_use_graph_input = True # change it to False to let Gelu don't have graph inputs/outputs as inputs/outputs.
X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", 64])
Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "seqlen", 64])
bias_np_vals = (0.01 * np.arange(64)).astype(np.float32).reshape((64))
bias_initializer = numpy_helper.from_array(bias_np_vals, "input_bias")
pow_np_vals = np.asarray([3]).astype(np.float32).reshape(())
pow_initializer = numpy_helper.from_array(pow_np_vals, "pow_init")
a_weight_np_vals = np.asarray([0.044714998453855515]).astype(np.float32).reshape(())
a_weight_initializer = numpy_helper.from_array(a_weight_np_vals, "mul1_init")
b_weight_np_vals = np.asarray([0.7978845834732056]).astype(np.float32).reshape(())
b_weight_initializer = numpy_helper.from_array(b_weight_np_vals, "mul2_init")
c_weight_np_vals = np.asarray([0.5]).astype(np.float32).reshape(())
c_weight_initializer = numpy_helper.from_array(c_weight_np_vals, "mul3_init")
b_bias_np_vals = np.asarray([1.0]).astype(np.float32).reshape(())
b_bias_initializer = numpy_helper.from_array(b_bias_np_vals, "add2_init")
nodes = []
gelu_input = "input"
if not gelu_use_graph_input:
leading_identity = helper.make_node(
'Identity',
[gelu_input],
['identity_leading'],
name="identity_leading"
)
gelu_input = "identity_leading"
nodes.append(leading_identity)
mul_input_name = gelu_input
if has_bias:
add0 = helper.make_node(
'Add',
[gelu_input, bias_initializer.name],
['add0'],
name="add0"
)
mul_input_name = "add0"
nodes.append(add0)
pow1 = helper.make_node(
'Pow',
[mul_input_name, pow_initializer.name],
['pow1'],
name="pow1"
)
nodes.append(pow1)
mul1 = helper.make_node(
'Mul',
['pow1', a_weight_initializer.name],
['mul1'],
name="mul1"
)
nodes.append(mul1)
add1 = helper.make_node(
'Add',
[mul_input_name, "mul1"],
['add1'],
name="add1"
)
nodes.append(add1)
mul2 = helper.make_node(
'Mul',
['add1', b_weight_initializer.name],
['mul2'],
name="mul2"
)
nodes.append(mul2)
tanh = helper.make_node(
'Tanh',
['mul2'],
['tanh'],
name="tanh"
)
nodes.append(tanh)
add2 = helper.make_node(
'Add',
['tanh', b_bias_initializer.name],
['add2'],
name="add2"
)
nodes.append(add2)
mul5 = helper.make_node(
'Mul',
[mul_input_name, c_weight_initializer.name],
['mul5'],
name="mul5"
)
nodes.append(mul5)
mul6 = helper.make_node(
'Mul',
['mul5', 'add2'],
['mul6'],
name="mul6"
)
ending_identity = helper.make_node(
'Identity',
['mul6'],
['output'],
name="ending_identity"
)
nodes.extend([mul6, ending_identity])
initializers = []
if has_bias:
initializers = [bias_initializer]
initializers.extend([pow_initializer, a_weight_initializer, b_weight_initializer, b_bias_initializer, c_weight_initializer])
# Create the graph (GraphProto)
graph_def = helper.make_graph(
nodes,
'test-model',
[X],
[Y],
initializers
)
opsets = []
onnxdomain = OperatorSetIdProto()
onnxdomain.version = 10
onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
opsets.append(onnxdomain)
msdomain = OperatorSetIdProto()
msdomain.version = 1
msdomain.domain = "com.microsoft"
opsets.append(msdomain)
kwargs={}
kwargs["opset_imports"] = opsets
model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs)
file_name = "fast_gelu2"
if has_bias:
file_name += "_with_bias"
if gelu_use_graph_input:
file_name += "_use_graph_input"
onnx.save(model_def, file_name + ".onnx")

Binary file not shown.