mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
GPT2 Gelu Fusion & Test (#3009)
* GPT2 Gelu Fusion & Test * change header path * Refine code & add missing test onnx file * Fix builds & refine float/double/fp16 compare. * Fix builds * Add Bias Check and UTs * Fix build and uts * Fuse with second formula & test * minor change * disable FastGelu to see whether the builds can pass * Verify where is wrong * disable for debugging * Revert "disable for debugging" This reverts commit 535c0817fb36fb95a75773a7f00c8b969dd5362c. * Revert "Verify where is wrong" This reverts commit ffc43ec1d136636ba2cee30df49f563a75e84676. * disable the transformer for inference currently * Enable FastGeluFusion and fix segmentation fault when running bertsquad10.onnx test * Add more Unit tests covering Gelu subgraphs that use graph input/output (cherry picked from commit 0739ab985240c6d9acdb8f0afd40c5fb316166af) * Mode Bias Fusion in BiasGelu.cc Co-authored-by: Changming Sun <chasun@microsoft.com>
This commit is contained in:
parent
932ecaea34
commit
92b8a7a2be
17 changed files with 837 additions and 9 deletions
|
|
@ -64,20 +64,28 @@ Status BiasGelu::ApplyImpl(Graph& graph, bool& modified, int graph_level, const
|
|||
}
|
||||
|
||||
const Node& next_node = (*next_node_itr);
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {1}, kMSDomain) ||
|
||||
if (!(graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {1}, kMSDomain) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "FastGelu", {1}, kMSDomain)) ||
|
||||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool is_fast_gelu = next_node.OpType().compare("FastGelu") == 0;
|
||||
if (is_fast_gelu && next_node.InputDefs().size() > 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Node& add_node = node;
|
||||
Node& gelu_node = const_cast<Node&>(next_node);
|
||||
std::string op_type = "BiasGelu";
|
||||
if (is_fast_gelu) op_type = "FastGelu";
|
||||
|
||||
Node& gelu_add_fusion_node = graph.AddNode(graph.GenerateNodeName("BiasGelu"),
|
||||
"BiasGelu",
|
||||
Node& gelu_add_fusion_node = graph.AddNode(graph.GenerateNodeName(op_type),
|
||||
op_type,
|
||||
"fused Add and Gelu",
|
||||
gelu_input,
|
||||
{},
|
||||
|
|
|
|||
244
onnxruntime/core/optimizer/fast_gelu_fusion.cc
Normal file
244
onnxruntime/core/optimizer/fast_gelu_fusion.cc
Normal file
|
|
@ -0,0 +1,244 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/fast_gelu_fusion.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "float.h"
|
||||
#include <deque>
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
// FastGelu supports limited data types.
|
||||
static std::vector<std::string> supported_data_types{"tensor(float16)", "tensor(float)"};
|
||||
|
||||
static bool CheckNode(const Node& node, const std::string& op_name, int32_t opset_version, ProviderType provider,
|
||||
bool require_single_output=false){
|
||||
return graph_utils::IsSupportedOptypeVersionAndDomain(node, op_name, {opset_version}) &&
|
||||
node.GetExecutionProviderType() == provider &&
|
||||
optimizer_utils::IsSupportedDataType(node, supported_data_types) &&
|
||||
(!require_single_output || node.GetOutputEdgesCount() == 1);
|
||||
}
|
||||
|
||||
// Tries to match the argument of Tanh in the first Gelu formulation,
//   0.7978845608028654 * x * (1.0 + 0.044715 * x * x),
// starting at `mul1_node`, the Mul that scales x by 0.044715.
//
// Expected chain (every node after mul1 must be on mul1's execution provider and have one output edge):
//   mul1 = 0.044715 * x
//   mul2 = mul1 * x
//   add1 = mul2 + 1.0
//   mul4 = 0.7978845834732056 * x
//   mul3 = add1 * mul4            -> reported as MatchResult::tanh_input_node
//
// For mul1 and mul4 the constant may be either input. Nodes are appended to `nodes_to_fuse` as soon as
// they are accepted (order: mul1, mul2, add1, mul3, mul4), so the vector may hold a partial match when
// the returned `matched` flag is false; the caller is expected to discard it in that case.
MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
                                              std::vector<std::reference_wrapper<Node>>& nodes_to_fuse) const {
  MatchResult result{false, nullptr, nullptr};

  const bool is_candidate = graph_utils::IsSupportedOptypeVersionAndDomain(mul1_node, "Mul", {7}) &&
                            graph_utils::IsSupportedProvider(mul1_node, GetCompatibleExecutionProviders()) &&
                            mul1_node.GetOutputEdgesCount() == 1 &&
                            optimizer_utils::IsSupportedDataType(mul1_node, supported_data_types);
  if (!is_candidate) {
    return result;
  }

  // Index (0 or 1) of the input of `mul` that is a constant initializer equal to `value`, -1 if there is none.
  const auto find_constant_input = [&graph](const Node& mul, float value) -> int32_t {
    for (int32_t i = 0; i < 2; ++i) {
      if (optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul.InputDefs()[i]), value, true)) {
        return i;
      }
    }
    return -1;
  };

  const int32_t scale_idx = find_constant_input(mul1_node, 0.044715f);
  if (scale_idx == -1) {
    return result;
  }

  // The operand of mul1 that is not the constant is x, the input of the whole Gelu subgraph.
  NodeArg* x_arg = mul1_node.MutableInputDefs()[(scale_idx + 1) % 2];
  nodes_to_fuse.push_back(mul1_node);

  const auto& provider = mul1_node.GetExecutionProviderType();

  // mul2 = mul1 * x
  Node& mul2_node = *graph.GetNode(mul1_node.OutputNodesBegin()->Index());
  if (!CheckNode(mul2_node, "Mul", 7, provider, true)) {
    return result;
  }
  int32_t linked_idx = optimizer_utils::IndexOfNodeInput(mul2_node, *mul1_node.MutableOutputDefs()[0]);
  if (mul2_node.MutableInputDefs()[(linked_idx + 1) % 2]->Name() != x_arg->Name()) {
    return result;
  }
  nodes_to_fuse.push_back(mul2_node);

  // add1 = mul2 + 1.0
  Node& add1_node = *graph.GetNode(mul2_node.OutputNodesBegin()->Index());
  if (!CheckNode(add1_node, "Add", 7, provider, true)) {
    return result;
  }
  linked_idx = optimizer_utils::IndexOfNodeInput(add1_node, *mul2_node.MutableOutputDefs()[0]);
  if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(add1_node.InputDefs()[(linked_idx + 1) % 2]),
                                                       1.0f, true)) {
    return result;
  }
  nodes_to_fuse.push_back(add1_node);

  // mul3 = add1 * mul4
  Node& mul3_node = *graph.GetNode(add1_node.OutputNodesBegin()->Index());
  if (!CheckNode(mul3_node, "Mul", 7, provider, true)) {
    return result;
  }
  nodes_to_fuse.push_back(mul3_node);

  // The other operand of mul3 has to be produced by mul4 = 0.7978845834732056 * x.
  linked_idx = optimizer_utils::IndexOfNodeInput(mul3_node, *add1_node.MutableOutputDefs()[0]);
  const Node* mul3_other_producer = graph_utils::GetInputNode(mul3_node, (linked_idx + 1) % 2);
  if (mul3_other_producer == nullptr) {
    return result;
  }
  Node& mul4_node = const_cast<Node&>(*mul3_other_producer);
  if (!CheckNode(mul4_node, "Mul", 7, provider, true)) {
    return result;
  }

  const int32_t mul4_scale_idx = find_constant_input(mul4_node, 0.7978845834732056f);
  if (mul4_scale_idx == -1) {
    return result;
  }
  if (mul4_node.InputDefs()[(mul4_scale_idx + 1) % 2]->Name() != x_arg->Name()) {
    return result;
  }
  nodes_to_fuse.push_back(mul4_node);

  result.matched = true;
  result.gelu_without_bias_input_arg = x_arg;
  result.tanh_input_node = &mul3_node;
  return result;
}
|
||||
|
||||
// Tries to match the argument of Tanh in the second Gelu formulation,
//   sqrt(2 / pi) * (x + 0.044715 * pow(x, 3)),
// starting at `pow1_node`, the Pow that cubes x.
//
// Expected chain (every node after pow1 must be on pow1's execution provider and have one output edge):
//   pow1 = Pow(x, 3)
//   mul1 = pow1 * 0.044714998453855515
//   add1 = mul1 + x
//   mul2 = add1 * 0.7978845834732056   -> reported as MatchResult::tanh_input_node
//
// Nodes are appended to `nodes_to_fuse` as soon as they are accepted (order: pow1, mul1, add1, mul2),
// so the vector may hold a partial match when the returned `matched` flag is false.
MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
                                               std::vector<std::reference_wrapper<Node>>& nodes_to_fuse) const {
  MatchResult result{false, nullptr, nullptr};

  const bool is_candidate = graph_utils::IsSupportedOptypeVersionAndDomain(pow1_node, "Pow", {7}) &&
                            graph_utils::IsSupportedProvider(pow1_node, GetCompatibleExecutionProviders()) &&
                            pow1_node.GetOutputEdgesCount() == 1 &&
                            optimizer_utils::IsSupportedDataType(pow1_node, supported_data_types);
  if (!is_candidate) {
    return result;
  }

  // The exponent (second input of Pow) has to be the constant 3.
  if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(pow1_node.InputDefs()[1]), 3.0f, true)) {
    return result;
  }

  // The base of Pow is x, the input of the whole Gelu subgraph.
  NodeArg* x_arg = pow1_node.MutableInputDefs()[0];
  nodes_to_fuse.push_back(pow1_node);

  const auto& provider = pow1_node.GetExecutionProviderType();

  // mul1 = pow1 * 0.044715
  Node& mul1_node = *graph.GetNode(pow1_node.OutputNodesBegin()->Index());
  if (!CheckNode(mul1_node, "Mul", 7, provider, true)) {
    return result;
  }
  int32_t linked_idx = optimizer_utils::IndexOfNodeInput(mul1_node, *pow1_node.MutableOutputDefs()[0]);
  if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul1_node.InputDefs()[(linked_idx + 1) % 2]),
                                                       0.044714998453855515f, true)) {
    return result;
  }
  nodes_to_fuse.push_back(mul1_node);

  // add1 = mul1 + x
  Node& add1_node = *graph.GetNode(mul1_node.OutputNodesBegin()->Index());
  if (!CheckNode(add1_node, "Add", 7, provider, true)) {
    return result;
  }
  linked_idx = optimizer_utils::IndexOfNodeInput(add1_node, *mul1_node.MutableOutputDefs()[0]);
  if (add1_node.MutableInputDefs()[(linked_idx + 1) % 2]->Name() != x_arg->Name()) {
    return result;
  }
  nodes_to_fuse.push_back(add1_node);

  // mul2 = add1 * 0.7978845834732056
  Node& mul2_node = *graph.GetNode(add1_node.OutputNodesBegin()->Index());
  if (!CheckNode(mul2_node, "Mul", 7, provider, true)) {
    return result;
  }
  linked_idx = optimizer_utils::IndexOfNodeInput(mul2_node, *add1_node.MutableOutputDefs()[0]);
  if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul2_node.InputDefs()[(linked_idx + 1) % 2]),
                                                       0.7978845834732056f, true)) {
    return result;
  }
  nodes_to_fuse.push_back(mul2_node);

  result.matched = true;
  result.gelu_without_bias_input_arg = x_arg;
  result.tanh_input_node = &mul2_node;
  return result;
}
|
||||
|
||||
// Visits the nodes of `graph` in topological order and replaces every Gelu subgraph that matches one of
// the two supported formulas by a single FastGelu node (kMSDomain).
//
// CheckFirstFormula / CheckSecondFormula match the part of the subgraph that feeds Tanh. The remainder
// is matched here:
//   tanh = Tanh(tanh_input)
//   add2 = tanh + 1.0
//   mul6 = 0.5 * x
//   mul5 = add2 * mul6            -> output of the Gelu subgraph
//
// `modified` is set to true when at least one subgraph has been fused. Subgraphs of each visited node
// are processed first through Recurse(); its error, if any, is returned.
Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
  GraphViewer graph_viewer(graph);
  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();

  for (auto node_index : node_topology_list) {
    auto* p_node = graph.GetNode(node_index);
    // The node may already have been removed, e.g. as part of a subgraph fused earlier in this loop.
    if (p_node == nullptr)
      continue;

    Node& node = *p_node;
    ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));

    std::vector<std::reference_wrapper<Node>> nodes_to_fuse;
    MatchResult matchRet = CheckFirstFormula(graph, node, nodes_to_fuse);
    if (!matchRet.matched) {
      // A failed attempt may have left a partial match in the vector.
      nodes_to_fuse.clear();
      matchRet = CheckSecondFormula(graph, node, nodes_to_fuse);

      if (!matchRet.matched) continue;
    }

    Node& tanh_node = *graph.GetNode(matchRet.tanh_input_node->OutputNodesBegin()->Index());
    if (!CheckNode(tanh_node, "Tanh", 6, node.GetExecutionProviderType(), true)) {
      continue;
    }

    Node& add2_node = *graph.GetNode(tanh_node.OutputNodesBegin()->Index());
    if (!CheckNode(add2_node, "Add", 7, node.GetExecutionProviderType(), true)) {
      continue;
    }

    auto input_index = optimizer_utils::IndexOfNodeInput(add2_node, *tanh_node.MutableOutputDefs()[0]);
    if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(add2_node.InputDefs()[(input_index + 1) % 2]), 1.0f, true)) {
      continue;
    }

    Node& mul5_node = *graph.GetNode(add2_node.OutputNodesBegin()->Index());
    // This is the output of the Gelu subgraph, so it does not need to have a single output edge.
    if (!CheckNode(mul5_node, "Mul", 7, node.GetExecutionProviderType(), false)) {
      continue;
    }

    // Ignore the subgraph if Gelu's output is also a graph output.
    if (!graph.GetNodeOutputsInGraphOutputs(mul5_node).empty()) {
      continue;
    }

    input_index = optimizer_utils::IndexOfNodeInput(mul5_node, *add2_node.MutableOutputDefs()[0]);
    const Node* p_mul5_input_node = graph_utils::GetInputNode(mul5_node, (input_index + 1) % 2);
    if (p_mul5_input_node == nullptr) continue;
    Node& mul6_node = const_cast<Node&>(*p_mul5_input_node);
    // mul6 is an interior node that gets removed together with the rest of the subgraph, so mul5 has to
    // be its only consumer; otherwise the other consumers would lose their producer.
    if (!CheckNode(mul6_node, "Mul", 7, node.GetExecutionProviderType(), true)) {
      continue;
    }

    // mul6 has to be 0.5 * x (constant on either side).
    input_index = -1;
    for (auto i = 0; i < 2; i++) {
      if (optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul6_node.InputDefs()[i]), 0.5f, true)) {
        input_index = i;
        break;
      }
    }

    if (input_index == -1 || mul6_node.InputDefs()[(input_index + 1) % 2]->Name() != matchRet.gelu_without_bias_input_arg->Name())
      continue;

    std::vector<NodeArg*> gelu_input_defs{matchRet.gelu_without_bias_input_arg};
    nodes_to_fuse.insert(nodes_to_fuse.end(), {tanh_node, add2_node, mul6_node, mul5_node});

    // The temporary output arg is created with the type of the first matched node's output.
    auto type_info = *node.MutableOutputDefs()[0]->TypeAsProto();
    auto& shape_output = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("fast_gelu_output"), &type_info);
    Node& fast_gelu_node = graph.AddNode(graph.GenerateNodeName("GPT2Gelu"),
                                         "FastGelu",
                                         "fused GPT2Gelu subgraphs ",
                                         gelu_input_defs,
                                         {&shape_output}, {}, kMSDomain);

    // assign provider to this new node, provider should be same as the provider for old node.
    fast_gelu_node.SetExecutionProviderType(node.GetExecutionProviderType());

    // move input edges to node (first in list) across to the fast_gelu_node.
    // move output definitions and output edges from mul5_node (last in list) to fast_gelu_node.
    // remove all nodes.
    graph_utils::FinalizeNodeFusion(graph, nodes_to_fuse, fast_gelu_node);

    modified = true;
  }

  return Status::OK();
}
|
||||
} // namespace onnxruntime
|
||||
39
onnxruntime/core/optimizer/fast_gelu_fusion.h
Normal file
39
onnxruntime/core/optimizer/fast_gelu_fusion.h
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
struct MatchResult {
|
||||
public:
|
||||
bool matched;
|
||||
NodeArg* gelu_without_bias_input_arg; // The Gelu input arg if not considering bias node.
|
||||
Node* tanh_input_node;
|
||||
};
|
||||
|
||||
/**
|
||||
@Class FastGeluFusion
|
||||
|
||||
Rewrite graph fusing Gelu activation subgraph to a single Gelu node.
|
||||
|
||||
The formula corresponding to Gelu activation subgraph:
|
||||
x * 0.5 * (1.0 + tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) or
|
||||
x * 0.5 * (1.0 + tanh((sqrt(2 / pi) * (x + 0.044715 * pow(x, 3))))), where x is the input.
|
||||
|
||||
*/
|
||||
class FastGeluFusion : public GraphTransformer {
|
||||
public:
|
||||
FastGeluFusion(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
|
||||
: GraphTransformer("FastGeluFusion", compatible_execution_providers) {}
|
||||
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
|
||||
MatchResult CheckFirstFormula(Graph& graph, Node& node, std::vector<std::reference_wrapper<Node>>& nodes_to_fuse) const;
|
||||
|
||||
MatchResult CheckSecondFormula(Graph& graph, Node& nodes, std::vector<std::reference_wrapper<Node>>& nodes_to_fuse) const;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -21,6 +21,7 @@
|
|||
#include "core/optimizer/bias_gelu_fusion.h"
|
||||
#include "core/optimizer/gelu_fusion.h"
|
||||
#include "core/optimizer/gelu_approximation.h"
|
||||
#include "core/optimizer/fast_gelu_fusion.h"
|
||||
#include "core/optimizer/layer_norm_fusion.h"
|
||||
#include "core/optimizer/skip_layer_norm_fusion.h"
|
||||
#include "core/optimizer/embed_layer_norm_fusion.h"
|
||||
|
|
@ -135,6 +136,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
|
|||
|
||||
std::unordered_set<std::string> cuda_execution_providers = {onnxruntime::kCudaExecutionProvider};
|
||||
transformers.emplace_back(onnxruntime::make_unique<GeluApproximation>(cuda_execution_providers));
|
||||
transformers.emplace_back(onnxruntime::make_unique<FastGeluFusion>(cuda_execution_providers));
|
||||
#endif
|
||||
} break;
|
||||
|
||||
|
|
|
|||
|
|
@ -36,6 +36,8 @@ bool IsInitializerWithExpectedValue(const Graph& graph, const NodeArg& input_arg
|
|||
return false;
|
||||
}
|
||||
|
||||
const float atol = 1e-8f;
|
||||
const float rtol = 1e-5f;
|
||||
const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr;
|
||||
if (is_constant) {
|
||||
tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name());
|
||||
|
|
@ -51,20 +53,28 @@ bool IsInitializerWithExpectedValue(const Graph& graph, const NodeArg& input_arg
|
|||
const auto data_type = tensor_proto->data_type();
|
||||
if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
const float* val = init_const->data<float>();
|
||||
float diff = std::abs(val[0] - static_cast<float>(expected_value));
|
||||
if (diff > FLT_EPSILON) {
|
||||
if (std::isnan(val[0]) || std::isinf(val[0])) return false;
|
||||
|
||||
float diff = std::abs(val[0] - expected_value);
|
||||
if (diff > (atol + rtol * std::abs(expected_value))) {
|
||||
return false;
|
||||
}
|
||||
} else if (data_type == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE) {
|
||||
const double* val = init_const->data<double>();
|
||||
double diff = std::abs(val[0] - static_cast<double>(expected_value));
|
||||
if (diff > DBL_EPSILON) {
|
||||
if (std::isnan(val[0]) || std::isinf(val[0])) return false;
|
||||
|
||||
const double expected_val = static_cast<double>(expected_value);
|
||||
double diff = std::abs(val[0] - expected_val);
|
||||
if (diff > (atol + rtol * std::abs(expected_value))) {
|
||||
return false;
|
||||
}
|
||||
} else if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
|
||||
const MLFloat16* val = init_const->data<MLFloat16>();
|
||||
float diff = std::abs(math::halfToFloat(val[0].val) - math::halfToFloat(math::floatToHalf(expected_value)));
|
||||
if (diff > FLT_EPSILON) {
|
||||
const float flt_val = math::halfToFloat(val[0].val);
|
||||
if (std::isnan(flt_val) || std::isinf(flt_val)) return false;
|
||||
const float expected_val = math::halfToFloat(math::floatToHalf(expected_value));
|
||||
float diff = std::abs(flt_val - expected_val);
|
||||
if (diff > (atol + rtol * std::abs(expected_value))) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
|
|
@ -176,5 +186,27 @@ bool IsShapeKnownOnAllDims(const NodeArg& node_arg, int expected_dim_size) {
|
|||
return true;
|
||||
}
|
||||
|
||||
int32_t IndexOfNodeInput(const Node& node, const NodeArg& node_arg) {
|
||||
int32_t index = 0;
|
||||
for (auto& input_arg : node.InputDefs()) {
|
||||
if (input_arg->Name().compare(node_arg.Name()) == 0) {
|
||||
return index;
|
||||
}
|
||||
index++;
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
bool IsSupportedDataType(const Node& node, const std::vector<std::string>& supported_data_types) {
|
||||
for (const auto& input_arg : node.InputDefs()) {
|
||||
if (std::find(supported_data_types.begin(), supported_data_types.end(),
|
||||
*(input_arg->Type())) == supported_data_types.end()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace optimizer_utils
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -52,5 +52,15 @@ bool ValidateShape(const NodeArg& node_arg, const std::initializer_list<int64_t>
|
|||
*/
|
||||
bool IsShapeKnownOnAllDims(const NodeArg& node_arg, int expected_dim_size);
|
||||
|
||||
/** Get the index of node_arg among the node's all inputs.
|
||||
@remarks Returns -1 when node_arg is not among the node's inputs.
|
||||
*/
|
||||
int32_t IndexOfNodeInput(const Node& node, const NodeArg& node_arg);
|
||||
|
||||
/** Check whether node's input data types are in supported data type list.
|
||||
@param supported_data_types specify the supported data types.
|
||||
*/
|
||||
bool IsSupportedDataType(const Node& node, const std::vector<std::string>& supported_data_types);
|
||||
|
||||
} // namespace optimizer_utils
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@
|
|||
#include "core/optimizer/unsqueeze_elimination.h"
|
||||
#include "core/optimizer/reshape_fusion.h"
|
||||
#include "core/optimizer/attention_fusion.h"
|
||||
#include "core/optimizer/fast_gelu_fusion.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/platform/env.h"
|
||||
#include "core/util/math.h"
|
||||
|
|
@ -1201,6 +1202,163 @@ TEST(GraphTransformationTests, GeluApproximation_Gelu_Add_MatMul) {
|
|||
EXPECT_EQ(op_to_count["FastGelu"], 1);
|
||||
}
|
||||
|
||||
// First formula, Gelu fed by an intermediate node: after FastGeluFusion all Add/Tanh/Mul nodes must be
// replaced by one FastGelu node while the two Identity nodes stay.
TEST(GraphTransformationTests, FastGeluFusionTest) {
  auto model_uri = MODEL_FOLDER "fusion/fast_gelu.onnx";
  std::shared_ptr<Model> p_model;
  auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
  ASSERT_TRUE(load_ret.IsOK());
  Graph& graph = p_model->MainGraph();

  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  graph_transformation_mgr.Register(onnxruntime::make_unique<FastGeluFusion>(), TransformerLevel::Level2);
  auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
  ASSERT_TRUE(ret.IsOK());

  // ASSERT_EQ reports the actual count on failure.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Identity"], 2);
  ASSERT_EQ(op_to_count["Add"], 0);
  ASSERT_EQ(op_to_count["Tanh"], 0);
  ASSERT_EQ(op_to_count["Mul"], 0);
  ASSERT_EQ(op_to_count["FastGelu"], 1);
}
|
||||
|
||||
// First formula, model fusion/fast_gelu_use_graph_input.onnx: after FastGeluFusion all Add/Tanh/Mul
// nodes must be replaced by one FastGelu node.
TEST(GraphTransformationTests, FastGeluUseGraphInputFusionTest) {
  auto model_uri = MODEL_FOLDER "fusion/fast_gelu_use_graph_input.onnx";
  std::shared_ptr<Model> p_model;
  auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
  ASSERT_TRUE(load_ret.IsOK());
  Graph& graph = p_model->MainGraph();

  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  graph_transformation_mgr.Register(onnxruntime::make_unique<FastGeluFusion>(), TransformerLevel::Level2);
  auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
  ASSERT_TRUE(ret.IsOK());

  // ASSERT_EQ reports the actual count on failure.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Add"], 0);
  ASSERT_EQ(op_to_count["Tanh"], 0);
  ASSERT_EQ(op_to_count["Mul"], 0);
  ASSERT_EQ(op_to_count["FastGelu"], 1);
}
|
||||
|
||||
// First formula with a bias Add, model fusion/fast_gelu_with_bias.onnx: FastGeluFusion and BiasGelu are
// both registered, and afterwards no Add/Tanh/Mul node may remain next to the single FastGelu node.
TEST(GraphTransformationTests, FastGeluWithBiasFusionTest) {
  auto model_uri = MODEL_FOLDER "fusion/fast_gelu_with_bias.onnx";
  std::shared_ptr<Model> p_model;
  auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
  ASSERT_TRUE(load_ret.IsOK());
  Graph& graph = p_model->MainGraph();

  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  graph_transformation_mgr.Register(onnxruntime::make_unique<FastGeluFusion>(), TransformerLevel::Level2);
  graph_transformation_mgr.Register(onnxruntime::make_unique<BiasGelu>(), TransformerLevel::Level2);
  auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
  ASSERT_TRUE(ret.IsOK());

  // ASSERT_EQ reports the actual count on failure.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Add"], 0);
  ASSERT_EQ(op_to_count["Tanh"], 0);
  ASSERT_EQ(op_to_count["Mul"], 0);
  ASSERT_EQ(op_to_count["FastGelu"], 1);
}
|
||||
|
||||
// First formula with a bias Add, model fusion/fast_gelu_with_bias_use_graph_input.onnx: FastGeluFusion
// and BiasGelu are both registered, and afterwards no Add/Tanh/Mul node may remain next to the single
// FastGelu node.
TEST(GraphTransformationTests, FastGeluWithBiasUseGraphInputFusionTest) {
  auto model_uri = MODEL_FOLDER "fusion/fast_gelu_with_bias_use_graph_input.onnx";
  std::shared_ptr<Model> p_model;
  auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
  ASSERT_TRUE(load_ret.IsOK());
  Graph& graph = p_model->MainGraph();

  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  graph_transformation_mgr.Register(onnxruntime::make_unique<FastGeluFusion>(), TransformerLevel::Level2);
  graph_transformation_mgr.Register(onnxruntime::make_unique<BiasGelu>(), TransformerLevel::Level2);
  auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
  ASSERT_TRUE(ret.IsOK());

  // ASSERT_EQ reports the actual count on failure.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Add"], 0);
  ASSERT_EQ(op_to_count["Tanh"], 0);
  ASSERT_EQ(op_to_count["Mul"], 0);
  ASSERT_EQ(op_to_count["FastGelu"], 1);
}
|
||||
|
||||
// Model fusion/fast_gelu2.onnx: after FastGeluFusion all Add/Tanh/Mul nodes must be replaced by one
// FastGelu node.
TEST(GraphTransformationTests, FastGeluFusionTest2) {
  auto model_uri = MODEL_FOLDER "fusion/fast_gelu2.onnx";
  std::shared_ptr<Model> p_model;
  auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
  ASSERT_TRUE(load_ret.IsOK());
  Graph& graph = p_model->MainGraph();

  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  graph_transformation_mgr.Register(onnxruntime::make_unique<FastGeluFusion>(), TransformerLevel::Level2);
  auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
  ASSERT_TRUE(ret.IsOK());

  // ASSERT_EQ reports the actual count on failure.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Add"], 0);
  ASSERT_EQ(op_to_count["Tanh"], 0);
  ASSERT_EQ(op_to_count["Mul"], 0);
  ASSERT_EQ(op_to_count["FastGelu"], 1);
}
|
||||
|
||||
// Model fusion/fast_gelu2_use_graph_input.onnx: after FastGeluFusion all Add/Tanh/Mul nodes must be
// replaced by one FastGelu node.
TEST(GraphTransformationTests, FastGeluUseGraphInputFusionTest2) {
  auto model_uri = MODEL_FOLDER "fusion/fast_gelu2_use_graph_input.onnx";
  std::shared_ptr<Model> p_model;
  auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
  ASSERT_TRUE(load_ret.IsOK());
  Graph& graph = p_model->MainGraph();

  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  graph_transformation_mgr.Register(onnxruntime::make_unique<FastGeluFusion>(), TransformerLevel::Level2);
  auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
  ASSERT_TRUE(ret.IsOK());

  // ASSERT_EQ reports the actual count on failure.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Add"], 0);
  ASSERT_EQ(op_to_count["Tanh"], 0);
  ASSERT_EQ(op_to_count["Mul"], 0);
  ASSERT_EQ(op_to_count["FastGelu"], 1);
}
|
||||
|
||||
// Model fusion/fast_gelu2_with_bias.onnx: FastGeluFusion and BiasGelu are both registered, and
// afterwards no Add/Tanh/Mul node may remain next to the single FastGelu node.
TEST(GraphTransformationTests, FastGeluWithBiasFusionTest2) {
  auto model_uri = MODEL_FOLDER "fusion/fast_gelu2_with_bias.onnx";
  std::shared_ptr<Model> p_model;
  auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
  ASSERT_TRUE(load_ret.IsOK());
  Graph& graph = p_model->MainGraph();

  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  graph_transformation_mgr.Register(onnxruntime::make_unique<FastGeluFusion>(), TransformerLevel::Level2);
  graph_transformation_mgr.Register(onnxruntime::make_unique<BiasGelu>(), TransformerLevel::Level2);
  auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
  ASSERT_TRUE(ret.IsOK());

  // ASSERT_EQ reports the actual count on failure.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Add"], 0);
  ASSERT_EQ(op_to_count["Tanh"], 0);
  ASSERT_EQ(op_to_count["Mul"], 0);
  ASSERT_EQ(op_to_count["FastGelu"], 1);
}
|
||||
|
||||
// Model fusion/fast_gelu2_with_bias_use_graph_input.onnx: FastGeluFusion and BiasGelu are both
// registered, and afterwards no Add/Tanh/Mul node may remain next to the single FastGelu node.
TEST(GraphTransformationTests, FastGeluWithBiasUseGraphInputFusionTest2) {
  auto model_uri = MODEL_FOLDER "fusion/fast_gelu2_with_bias_use_graph_input.onnx";
  std::shared_ptr<Model> p_model;
  auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
  ASSERT_TRUE(load_ret.IsOK());
  Graph& graph = p_model->MainGraph();

  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  graph_transformation_mgr.Register(onnxruntime::make_unique<FastGeluFusion>(), TransformerLevel::Level2);
  graph_transformation_mgr.Register(onnxruntime::make_unique<BiasGelu>(), TransformerLevel::Level2);
  auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
  ASSERT_TRUE(ret.IsOK());

  // ASSERT_EQ reports the actual count on failure.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Add"], 0);
  ASSERT_EQ(op_to_count["Tanh"], 0);
  ASSERT_EQ(op_to_count["Mul"], 0);
  ASSERT_EQ(op_to_count["FastGelu"], 1);
}
|
||||
|
||||
TEST(GraphTransformationTests, LayerNormFusionTest) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/layer_norm.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/fast_gelu.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/fast_gelu.onnx
vendored
Normal file
Binary file not shown.
172
onnxruntime/test/testdata/transform/fusion/fast_gelu.py
vendored
Normal file
172
onnxruntime/test/testdata/transform/fusion/fast_gelu.py
vendored
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto
|
||||
from onnx import numpy_helper
|
||||
import numpy as np
|
||||
|
||||
# Generates a test model that spells out the Gelu approximation
#   x * 0.5 * (1.0 + tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
# with primitive Mul/Add/Tanh nodes, and saves it next to this script.
#
# The output file name is "fast_gelu" + ["_with_bias"] + ["_use_graph_input"] + ".onnx",
# where each optional suffix is added when the matching switch below is True.

has_bias = True  # True: an Add of the "input_bias" initializer is placed in front of the pattern
gelu_use_graph_input = True  # False: a leading Identity node is inserted between the graph input and the pattern


def scalar_initializer(value, name):
    """Return a rank-0 float32 initializer named `name` that holds `value`."""
    return numpy_helper.from_array(np.asarray([value]).astype(np.float32).reshape(()), name)


def append_node(op_type, inputs, output, name=None):
    """Append an `op_type` node to `nodes` and return the name of its single output.

    The node is named after its output unless `name` is given.
    """
    nodes.append(helper.make_node(op_type, inputs, [output], name=name or output))
    return output


def make_opset(domain, version):
    """Return an OperatorSetIdProto for `domain` at `version`."""
    opset = OperatorSetIdProto()
    opset.version = version
    opset.domain = domain
    return opset


graph_input = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", 64])
graph_output = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "seqlen", 64])

# Initializers. The scalar constants are stored as float32.
bias_init = numpy_helper.from_array((0.01 * np.arange(64)).astype(np.float32).reshape(64), "input_bias")
cubic_coeff_init = scalar_initializer(0.044714998453855515, "mul1_init")
tanh_scale_init = scalar_initializer(0.7978845834732056, "mul2_init")
half_init = scalar_initializer(0.5, "mul3_init")
inner_one_init = scalar_initializer(1.0, "add1_init")
outer_one_init = scalar_initializer(1.0, "add2_init")

nodes = []

# `x` names the tensor that feeds the Gelu pattern.
x = "input"
if not gelu_use_graph_input:
    x = append_node('Identity', [x], 'identity_leading')
if has_bias:
    x = append_node('Add', [x, bias_init.name], 'add0')

append_node('Mul', [x, cubic_coeff_init.name], 'mul1')      # 0.044715 * x
append_node('Mul', [x, 'mul1'], 'mul2')                     # 0.044715 * x * x
append_node('Add', ['mul2', inner_one_init.name], 'add1')   # 1.0 + 0.044715 * x * x
append_node('Mul', [x, tanh_scale_init.name], 'mul3')       # 0.79788456 * x
append_node('Mul', ['mul3', 'add1'], 'mul4')                # argument of tanh
append_node('Tanh', ['mul4'], 'tanh')
append_node('Add', ['tanh', outer_one_init.name], 'add2')   # 1.0 + tanh(...)
append_node('Mul', [x, half_init.name], 'mul5')             # 0.5 * x
append_node('Mul', ['mul5', 'add2'], 'mul6')                # final Gelu value
# The pattern result always reaches the graph output through a trailing Identity.
append_node('Identity', ['mul6'], 'output', name="identity_ending")

initializers = [bias_init] if has_bias else []
initializers += [cubic_coeff_init, inner_one_init, tanh_scale_init, outer_one_init, half_init]

# Create the graph (GraphProto)
graph_def = helper.make_graph(nodes, 'test-model', [graph_input], [graph_output], initializers)

# Opset imports: the default ONNX domain (empty string) at version 10 and the
# com.microsoft domain at version 1.
opset_imports = [make_opset("", 10), make_opset("com.microsoft", 1)]

model_def = helper.make_model(graph_def, producer_name='onnx-example', opset_imports=opset_imports)

# Compose the file name from the enabled switches and write the model.
suffixes = [("_with_bias", has_bias), ("_use_graph_input", gelu_use_graph_input)]
file_name = "fast_gelu" + "".join(suffix for suffix, enabled in suffixes if enabled)
onnx.save(model_def, file_name + ".onnx")
|
||||
|
||||
BIN
onnxruntime/test/testdata/transform/fusion/fast_gelu2.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/fast_gelu2.onnx
vendored
Normal file
Binary file not shown.
163
onnxruntime/test/testdata/transform/fusion/fast_gelu2.py
vendored
Normal file
163
onnxruntime/test/testdata/transform/fusion/fast_gelu2.py
vendored
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto
|
||||
from onnx import numpy_helper
|
||||
import numpy as np
|
||||
|
||||
# Generates a test model that spells out the Gelu approximation
#   x * 0.5 * (1.0 + tanh((sqrt(2 / pi) * (x + 0.044715 * pow(x, 3)))))
# with primitive Pow/Mul/Add/Tanh nodes, and saves it next to this script.
#
# The output file name is "fast_gelu2" + ["_with_bias"] + ["_use_graph_input"] + ".onnx",
# where each optional suffix is added when the matching switch below is True.

has_bias = False  # True: an Add of the "input_bias" initializer is placed in front of the pattern
gelu_use_graph_input = True  # False: a leading Identity node is inserted between the graph input and the pattern


def scalar_initializer(value, name):
    """Return a rank-0 float32 initializer named `name` that holds `value`."""
    return numpy_helper.from_array(np.asarray([value]).astype(np.float32).reshape(()), name)


def append_node(op_type, inputs, output, name=None):
    """Append an `op_type` node to `nodes` and return the name of its single output.

    The node is named after its output unless `name` is given.
    """
    nodes.append(helper.make_node(op_type, inputs, [output], name=name or output))
    return output


def make_opset(domain, version):
    """Return an OperatorSetIdProto for `domain` at `version`."""
    opset = OperatorSetIdProto()
    opset.version = version
    opset.domain = domain
    return opset


graph_input = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", 64])
graph_output = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "seqlen", 64])

# Initializers. The scalar constants are stored as float32.
bias_init = numpy_helper.from_array((0.01 * np.arange(64)).astype(np.float32).reshape(64), "input_bias")
exponent_init = scalar_initializer(3, "pow_init")
cubic_coeff_init = scalar_initializer(0.044714998453855515, "mul1_init")
tanh_scale_init = scalar_initializer(0.7978845834732056, "mul2_init")
half_init = scalar_initializer(0.5, "mul3_init")
one_init = scalar_initializer(1.0, "add2_init")

nodes = []

# `x` names the tensor that feeds the Gelu pattern.
x = "input"
if not gelu_use_graph_input:
    x = append_node('Identity', [x], 'identity_leading')
if has_bias:
    x = append_node('Add', [x, bias_init.name], 'add0')

append_node('Pow', [x, exponent_init.name], 'pow1')          # pow(x, 3)
append_node('Mul', ['pow1', cubic_coeff_init.name], 'mul1')  # 0.044715 * pow(x, 3)
append_node('Add', [x, "mul1"], 'add1')                      # x + 0.044715 * pow(x, 3)
append_node('Mul', ['add1', tanh_scale_init.name], 'mul2')   # argument of tanh
append_node('Tanh', ['mul2'], 'tanh')
append_node('Add', ['tanh', one_init.name], 'add2')          # 1.0 + tanh(...)
append_node('Mul', [x, half_init.name], 'mul5')              # 0.5 * x
append_node('Mul', ['mul5', 'add2'], 'mul6')                 # final Gelu value
# The pattern result always reaches the graph output through a trailing Identity.
append_node('Identity', ['mul6'], 'output', name="ending_identity")

initializers = [bias_init] if has_bias else []
initializers += [exponent_init, cubic_coeff_init, tanh_scale_init, one_init, half_init]

# Create the graph (GraphProto)
graph_def = helper.make_graph(nodes, 'test-model', [graph_input], [graph_output], initializers)

# Opset imports: the default ONNX domain (empty string) at version 10 and the
# com.microsoft domain at version 1.
opset_imports = [make_opset("", 10), make_opset("com.microsoft", 1)]

model_def = helper.make_model(graph_def, producer_name='onnx-example', opset_imports=opset_imports)

# Compose the file name from the enabled switches and write the model.
suffixes = [("_with_bias", has_bias), ("_use_graph_input", gelu_use_graph_input)]
file_name = "fast_gelu2" + "".join(suffix for suffix, enabled in suffixes if enabled)
onnx.save(model_def, file_name + ".onnx")
|
||||
|
||||
BIN
onnxruntime/test/testdata/transform/fusion/fast_gelu2_use_graph_input.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/fast_gelu2_use_graph_input.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/fast_gelu2_with_bias.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/fast_gelu2_with_bias.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/fast_gelu2_with_bias_use_graph_input.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/fast_gelu2_with_bias_use_graph_input.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/fast_gelu_use_graph_input.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/fast_gelu_use_graph_input.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/fast_gelu_with_bias.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/fast_gelu_with_bias.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/fast_gelu_with_bias_use_graph_input.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/fast_gelu_with_bias_use_graph_input.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue