mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
ORTModule - FastGeluFusion/fp16 fix and minor LayerNormFusion cleanup (#6734)
* fastgelu fix * assert cast outputs Co-authored-by: Ethan Tao <ettao@OrtTrainingDev4.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
parent
fb3f1f5cc1
commit
39d182f7fc
6 changed files with 194 additions and 5 deletions
|
|
@ -141,6 +141,25 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
|
|||
}
|
||||
nodes_to_fuse.push_back(add1_node);
|
||||
|
||||
// check if pow node has Cast parent, expect Add has same Cast parent as well
|
||||
const Node* p_cast1_node = graph_utils::FirstParentByType(pow1_node, "Cast");
|
||||
if (p_cast1_node != nullptr) {
|
||||
Node& cast1_node = *graph.GetNode(p_cast1_node->Index());
|
||||
// this is fused Cast node, so expect 2 output edges
|
||||
if (!CheckNode(graph, cast1_node, "Cast", {9, 13}, pow1_node.GetExecutionProviderType(), false) ||
|
||||
cast1_node.GetOutputEdgesCount() != 2){
|
||||
return matchResult;
|
||||
}
|
||||
const Node* p_pow_node = graph_utils::FirstChildByType(cast1_node, "Pow");
|
||||
if (p_pow_node == nullptr || p_pow_node->Index() != pow1_node.Index()) {
|
||||
return matchResult;
|
||||
}
|
||||
const Node* p_add_node = graph_utils::FirstChildByType(cast1_node, "Add");
|
||||
if (p_add_node == nullptr || p_add_node->Index() != add1_node.Index()) {
|
||||
return matchResult;
|
||||
}
|
||||
}
|
||||
|
||||
Node& mul2_node = *graph.GetNode(add1_node.OutputNodesBegin()->Index());
|
||||
input_index = optimizer_utils::IndexOfNodeInput(mul2_node, *add1_node.MutableOutputDefs()[0]);
|
||||
if (!CheckNode(graph, mul2_node, "Mul", {7, 13}, pow1_node.GetExecutionProviderType(), true) ||
|
||||
|
|
@ -156,6 +175,22 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
|
|||
return matchResult;
|
||||
}
|
||||
|
||||
/**
|
||||
In case of ORTModule, there are extra Cast nodes exported for fp16. They should be fused into two nodes:
|
||||
|
||||
x --> Cast --> FastGelu
|
||||
|
||||
The first Cast should have been fused in CommonSubexpressionElimination transformer, thus it has 2 output edges.
|
||||
|
||||
+--------------------------------------------> Mul ---> Cast ----+
|
||||
| |
|
||||
| v
|
||||
X --> Cast --> Pow --> Mul --> Add --> Mul --> Tanh --> Add --> Mul
|
||||
| ^
|
||||
| |
|
||||
+------------------------+
|
||||
|
||||
*/
|
||||
Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
|
@ -169,12 +204,14 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
|
||||
std::vector<std::reference_wrapper<Node>> nodes_to_fuse;
|
||||
bool second_formula = false;
|
||||
MatchResult matchRet = CheckFirstFormula(graph, node, nodes_to_fuse);
|
||||
if (!matchRet.matched) {
|
||||
nodes_to_fuse.clear();
|
||||
matchRet = CheckSecondFormula(graph, node, nodes_to_fuse);
|
||||
|
||||
if (!matchRet.matched) continue;
|
||||
second_formula = true;
|
||||
};
|
||||
|
||||
Node& tanh_node = *graph.GetNode(matchRet.tanh_input_node->OutputNodesBegin()->Index());
|
||||
|
|
@ -201,6 +238,30 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
input_index = optimizer_utils::IndexOfNodeInput(mul5_node, *add2_node.MutableOutputDefs()[0]);
|
||||
const Node* p_mul5_input_node = graph_utils::GetInputNode(mul5_node, (input_index + 1) % 2);
|
||||
if (p_mul5_input_node == nullptr) continue;
|
||||
|
||||
// if this is second formula and if pow node has Cast parent, expect mul5_node has Cast parent as well
|
||||
NodeArg* cast_input_arg = nullptr;
|
||||
if (second_formula) {
|
||||
const Node* p_cast1_node = graph_utils::FirstParentByType(node, "Cast");
|
||||
if (p_cast1_node != nullptr) {
|
||||
// we've done the node check in second formula for pow node
|
||||
Node& cast1_node = *graph.GetNode(p_cast1_node->Index());
|
||||
cast_input_arg = cast1_node.MutableInputDefs()[0];
|
||||
|
||||
const Node* p_cast3_node = graph_utils::FirstParentByType(mul5_node, "Cast");
|
||||
if (p_cast3_node == nullptr) continue;
|
||||
|
||||
Node& cast3_node = *graph.GetNode(p_cast3_node->Index());
|
||||
if (!CheckNode(graph, cast3_node, "Cast", {9, 13}, node.GetExecutionProviderType(), true)) {
|
||||
continue;
|
||||
}
|
||||
// overwrite and continue as usual
|
||||
p_mul5_input_node = graph_utils::FirstParentByType(cast3_node, "Mul");
|
||||
nodes_to_fuse.push_back(cast3_node);
|
||||
// keep cast1_node for reuse, its output edges will be adjusted in FinalizeNodeFusion()
|
||||
}
|
||||
}
|
||||
|
||||
Node& mul6_node = const_cast<Node&>(*p_mul5_input_node);
|
||||
if (!CheckNode(graph, mul6_node, "Mul", {7, 13}, node.GetExecutionProviderType(), false)) {
|
||||
continue;
|
||||
|
|
@ -214,8 +275,15 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
}
|
||||
}
|
||||
|
||||
if (input_index == -1 || mul6_node.InputDefs()[(input_index + 1) % 2]->Name() != matchRet.gelu_without_bias_input_arg->Name())
|
||||
continue;
|
||||
if (input_index == -1) continue;
|
||||
// check same parent for both mul6 and pow, with or without cast
|
||||
if (cast_input_arg != nullptr) {
|
||||
if (mul6_node.InputDefs()[(input_index + 1) % 2]->Name() != cast_input_arg->Name())
|
||||
continue;
|
||||
} else {
|
||||
if (mul6_node.InputDefs()[(input_index + 1) % 2]->Name() != matchRet.gelu_without_bias_input_arg->Name())
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<NodeArg*> gelu_input_defs{matchRet.gelu_without_bias_input_arg};
|
||||
nodes_to_fuse.insert(nodes_to_fuse.end(), {tanh_node, add2_node, mul6_node, mul5_node});
|
||||
|
|
|
|||
|
|
@ -206,7 +206,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
if (p_cast_node != nullptr) {
|
||||
Node& cast_node = *graph.GetNode(p_cast_node->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9, 13}) ||
|
||||
cast_node.GetExecutionProviderType() != cast_node.GetExecutionProviderType() ||
|
||||
cast_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, cast_node, 1)) {
|
||||
continue;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
#include "core/optimizer/bias_softmax_fusion.h"
|
||||
#include "core/optimizer/computation_reduction.h"
|
||||
#include "core/optimizer/cast_elimination.h"
|
||||
#include "core/optimizer/common_subexpression_elimination.h"
|
||||
#include "core/optimizer/concat_slice_elimination.h"
|
||||
#include "core/optimizer/constant_folding.h"
|
||||
#include "core/optimizer/conv_activation_fusion.h"
|
||||
|
|
@ -2479,6 +2480,34 @@ TEST_F(GraphTransformationTests, FastGeluWithBiasUseGraphInputFusionTest2) {
|
|||
ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 1);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, FastGeluFusionWithCastsTest3) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/fast_gelu3_with_casts.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
auto load_ret = Model::Load(model_uri, p_model, nullptr, *logger_);
|
||||
ASSERT_TRUE(load_ret.IsOK());
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
// ORTModule for gpt2 model has two casts fused into one before FastGeluFusion
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<CommonSubexpressionElimination>(), TransformerLevel::Level1);
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Cast"] == 2);
|
||||
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<FastGeluFusion>(), TransformerLevel::Level2);
|
||||
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Tanh"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Mul"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Cast"] == 1);
|
||||
ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 1);
|
||||
}
|
||||
|
||||
|
||||
struct BiasSoftmaxFusionTester {
|
||||
std::shared_ptr<Model> p_model_;
|
||||
Status model_load_;
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/fast_gelu3_with_casts.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/fast_gelu3_with_casts.onnx
vendored
Normal file
Binary file not shown.
92
onnxruntime/test/testdata/transform/fusion/fast_gelu3_with_casts.py
vendored
Normal file
92
onnxruntime/test/testdata/transform/fusion/fast_gelu3_with_casts.py
vendored
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto
|
||||
from onnx import numpy_helper
|
||||
import numpy as np
|
||||
|
||||
# Gelu formula: x * 0.5 * (1.0 + tanh((sqrt(2 / pi) * (x + 0.044715 * pow(x, 3)))))
|
||||
|
||||
X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", 64])
|
||||
Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "seqlen", 64])
|
||||
|
||||
pow_np_vals = np.asarray([3]).astype(np.float32).reshape(())
|
||||
pow_initializer = numpy_helper.from_array(pow_np_vals, "pow_init")
|
||||
|
||||
a_weight_np_vals = np.asarray([0.044714998453855515]).astype(np.float32).reshape(())
|
||||
a_weight_initializer = numpy_helper.from_array(a_weight_np_vals, "mul1_init")
|
||||
|
||||
b_weight_np_vals = np.asarray([0.7978845834732056]).astype(np.float32).reshape(())
|
||||
b_weight_initializer = numpy_helper.from_array(b_weight_np_vals, "mul2_init")
|
||||
|
||||
c_weight_np_vals = np.asarray([0.5]).astype(np.float32).reshape(())
|
||||
c_weight_initializer = numpy_helper.from_array(c_weight_np_vals, "mul3_init")
|
||||
|
||||
b_bias_np_vals = np.asarray([1.0]).astype(np.float32).reshape(())
|
||||
b_bias_initializer = numpy_helper.from_array(b_bias_np_vals, "add2_init")
|
||||
|
||||
nodes = []
|
||||
gelu_input = "input"
|
||||
leading_identity = helper.make_node('Identity', [gelu_input], ['identity_leading'], name="identity_leading")
|
||||
gelu_input = "identity_leading"
|
||||
nodes.append(leading_identity)
|
||||
|
||||
mul_input_name = gelu_input
|
||||
|
||||
cast1 = helper.make_node('Cast', [mul_input_name], ['cast1'], name='cast1', to=1)
|
||||
nodes.append(cast1)
|
||||
|
||||
pow1 = helper.make_node('Pow', ['cast1', pow_initializer.name], ['pow1'], name="pow1")
|
||||
nodes.append(pow1)
|
||||
|
||||
mul1 = helper.make_node('Mul', ['pow1', a_weight_initializer.name], ['mul1'], name="mul1")
|
||||
nodes.append(mul1)
|
||||
|
||||
cast2 = helper.make_node('Cast', [mul_input_name], ['cast2'], name='cast2', to=1)
|
||||
nodes.append(cast2)
|
||||
|
||||
add1 = helper.make_node('Add', ['mul1', 'cast2'], ['add1'], name="add1")
|
||||
nodes.append(add1)
|
||||
|
||||
mul2 = helper.make_node('Mul', ['add1', b_weight_initializer.name], ['mul2'], name="mul2")
|
||||
nodes.append(mul2)
|
||||
|
||||
tanh = helper.make_node('Tanh', ['mul2'], ['tanh'], name="tanh")
|
||||
nodes.append(tanh)
|
||||
|
||||
add2 = helper.make_node('Add', ['tanh', b_bias_initializer.name], ['add2'], name="add2")
|
||||
nodes.append(add2)
|
||||
|
||||
mul5 = helper.make_node('Mul', [mul_input_name, c_weight_initializer.name], ['mul5'], name="mul5")
|
||||
nodes.append(mul5)
|
||||
|
||||
cast3 = helper.make_node('Cast', ['mul5'], ['cast3'], name='cast3', to=1)
|
||||
nodes.append(cast3)
|
||||
|
||||
mul6 = helper.make_node('Mul', ['cast3', 'add2'], ['mul6'], name="mul6")
|
||||
ending_identity = helper.make_node('Identity', ['mul6'], ['output'], name="ending_identity")
|
||||
nodes.extend([mul6, ending_identity])
|
||||
|
||||
initializers = []
|
||||
|
||||
initializers.extend(
|
||||
[pow_initializer, a_weight_initializer, b_weight_initializer, b_bias_initializer, c_weight_initializer])
|
||||
# Create the graph (GraphProto)
|
||||
graph_def = helper.make_graph(nodes, 'test-model', [X], [Y], initializers)
|
||||
|
||||
opsets = []
|
||||
onnxdomain = OperatorSetIdProto()
|
||||
onnxdomain.version = 13
|
||||
onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
|
||||
opsets.append(onnxdomain)
|
||||
|
||||
msdomain = OperatorSetIdProto()
|
||||
msdomain.version = 1
|
||||
msdomain.domain = "com.microsoft"
|
||||
|
||||
opsets.append(msdomain)
|
||||
kwargs = {}
|
||||
kwargs["opset_imports"] = opsets
|
||||
|
||||
model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs)
|
||||
|
||||
onnx.save(model_def, "fast_gelu3_with_casts.onnx")
|
||||
|
|
@ -7,7 +7,7 @@ from enum import Enum
|
|||
|
||||
|
||||
def GenerateModel(model_name):
|
||||
nodes = [ # SimplifiedLayerNorm subgraph
|
||||
nodes = [ # LayerNormWithCast2 subgraph
|
||||
helper.make_node("ReduceMean", ["A"], ["rd1_out"], "reduce", axes=[-1]),
|
||||
helper.make_node("Sub", ["A", "rd1_out"], ["sub1_out"], "sub"),
|
||||
helper.make_node("Cast", ["pow_in_2"], ["cast_out"], "cast", to=10),
|
||||
|
|
@ -29,7 +29,7 @@ def GenerateModel(model_name):
|
|||
|
||||
graph = helper.make_graph(
|
||||
nodes,
|
||||
"SimplifiedLayerNorm", #name
|
||||
"LayerNormWithCast2", #name
|
||||
[ # inputs
|
||||
helper.make_tensor_value_info('A', TensorProto.FLOAT16, [16, 32, 4]),
|
||||
],
|
||||
|
|
|
|||
Loading…
Reference in a new issue