support opset13 in embednorm (#6866)

This commit is contained in:
Ye Wang 2021-03-02 12:33:40 -08:00 committed by GitHub
parent 0d0eb2c85c
commit 9073f7a5c3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 117 additions and 46 deletions

View file

@ -112,10 +112,10 @@ static bool MatchInputToConcatSubgraph(
const logging::Logger& logger,
const NodeIndex expected_gather_node_1_index) {
std::vector<graph_utils::EdgeEndToMatch> expand_parent_path1{
{0, index, "Concat", {4, 11}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Gather", {1, 11}, kOnnxDomain},
{0, 0, "Shape", {1}, kOnnxDomain},
{0, index, "Concat", {4, 11, 13}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain},
{0, 0, "Shape", {1, 13}, kOnnxDomain},
};
std::vector<const Node::EdgeEnd*> edges;
@ -145,9 +145,9 @@ static bool MatchInputToConcatSubgraph(
}
std::vector<graph_utils::EdgeEndToMatch> concat_parent_path{
{0, 1, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Gather", {1, 11}, kOnnxDomain},
{0, 0, "Shape", {1}, kOnnxDomain}};
{0, 1, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain},
{0, 0, "Shape", {1, 13}, kOnnxDomain}};
if (!graph_utils::FindPath(concat_node, true, concat_parent_path, edges, logger)) {
DEBUG_LOG("Failed to find path 2 of position shape.");
@ -316,7 +316,7 @@ static bool MatchPositionEmbeddingSubgraphsFromGather(
// Match Shape --> Expand path.
std::vector<const Node::EdgeEnd*> pg_edges_2;
if (!graph_utils::FindPath(expand_node, true, {{0, 1, "Shape", {1}, kOnnxDomain}}, pg_edges_2, logger)) {
if (!graph_utils::FindPath(expand_node, true, {{0, 1, "Shape", {1, 13}, kOnnxDomain}}, pg_edges_2, logger)) {
DEBUG_LOG("Failed to match Shape node. ");
return false;
}

View file

@ -2897,15 +2897,47 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat2) {
ASSERT_TRUE(op_to_count["com.microsoft.EmbedLayerNormalization"] == 1);
}
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3) {
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format3.onnx";
static void EmbedLayerNormFusionFormat3(const std::basic_string<ORTCHAR_T>& file_path, logging::Logger* logger) {
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK());
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<EmbedLayerNormFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger);
ASSERT_TRUE(ret.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
EXPECT_EQ(op_to_count["Shape"], 0);
EXPECT_EQ(op_to_count["Expand"], 0);
EXPECT_EQ(op_to_count["Gather"], 0);
EXPECT_EQ(op_to_count["Unsqueeze"], 0);
EXPECT_EQ(op_to_count["LayerNormalization"], 0);
EXPECT_EQ(op_to_count["com.microsoft.SkipLayerNormalization"], 0);
EXPECT_EQ(op_to_count["ReduceSum"], 1);
EXPECT_EQ(op_to_count["MatMul"], 1);
EXPECT_EQ(op_to_count["Add"], 2);
EXPECT_EQ(op_to_count["Cast"], 3);
EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1);
EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1);
}
// Runs the shared format-3 fusion check (EmbedLayerNormFusionFormat3 helper)
// against the original format-3 test model.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3) {
  const auto* model_path = MODEL_FOLDER "fusion/embed_layer_norm_format3.onnx";
  EmbedLayerNormFusionFormat3(model_path, logger_.get());
}
// Same format-3 fusion check as above, but fed the "_opset13" variant of the
// model file.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3_OpSet13) {
  const auto* model_path = MODEL_FOLDER "fusion/embed_layer_norm_format3_opset13.onnx";
  EmbedLayerNormFusionFormat3(model_path, logger_.get());
}
static void EmbedLayerNormFusionFormat3NoCast(const std::basic_string<ORTCHAR_T>& file_path, logging::Logger* logger) {
std::shared_ptr<Model> p_model;
ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK());
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<EmbedLayerNormFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger);
ASSERT_TRUE(ret.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@ -2924,29 +2956,11 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3) {
}
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3NoCast) {
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format3_no_cast.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();
EmbedLayerNormFusionFormat3NoCast(MODEL_FOLDER "fusion/embed_layer_norm_format3_no_cast.onnx", logger_.get());
}
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<EmbedLayerNormFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
ASSERT_TRUE(ret.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
EXPECT_EQ(op_to_count["Shape"], 0);
EXPECT_EQ(op_to_count["Expand"], 0);
EXPECT_EQ(op_to_count["Gather"], 0);
EXPECT_EQ(op_to_count["Unsqueeze"], 0);
EXPECT_EQ(op_to_count["LayerNormalization"], 0);
EXPECT_EQ(op_to_count["com.microsoft.SkipLayerNormalization"], 0);
EXPECT_EQ(op_to_count["ReduceSum"], 1);
EXPECT_EQ(op_to_count["MatMul"], 1);
EXPECT_EQ(op_to_count["Add"], 2);
EXPECT_EQ(op_to_count["Cast"], 3);
EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1);
EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1);
// Runs the shared format-3 "no cast" fusion check against the "_opset13"
// variant of the model file.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3NoCast_OpSet13) {
  const auto* model_path = MODEL_FOLDER "fusion/embed_layer_norm_format3_no_cast_opset13.onnx";
  EmbedLayerNormFusionFormat3NoCast(model_path, logger_.get());
}
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat4) {
@ -2977,15 +2991,14 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat4) {
ASSERT_TRUE(op_to_count["com.microsoft.EmbedLayerNormalization"] == 1);
}
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat5) {
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format5.onnx";
static void EmbedLayerNormFusionFormat5(const std::basic_string<ORTCHAR_T>& file_path, logging::Logger* logger) {
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK());
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<EmbedLayerNormFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger);
ASSERT_TRUE(ret.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@ -3019,15 +3032,22 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat5) {
}
}
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6) {
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format6.onnx";
// Runs the shared format-5 fusion check (EmbedLayerNormFusionFormat5 helper)
// against the original format-5 test model.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat5) {
  const auto* model_path = MODEL_FOLDER "fusion/embed_layer_norm_format5.onnx";
  EmbedLayerNormFusionFormat5(model_path, logger_.get());
}
// Same format-5 fusion check, fed the "_opset13" variant of the model file.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat5_OpSet13) {
  const auto* model_path = MODEL_FOLDER "fusion/embed_layer_norm_format5_opset13.onnx";
  EmbedLayerNormFusionFormat5(model_path, logger_.get());
}
static void EmbedLayerNormFusionFormat6(const std::basic_string<ORTCHAR_T>& file_path, logging::Logger* logger) {
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK());
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<EmbedLayerNormFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger);
ASSERT_TRUE(ret.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@ -3048,6 +3068,14 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6) {
EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1);
}
// Runs the shared format-6 fusion check (EmbedLayerNormFusionFormat6 helper)
// against the original format-6 test model.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6) {
  const auto* model_path = MODEL_FOLDER "fusion/embed_layer_norm_format6.onnx";
  EmbedLayerNormFusionFormat6(model_path, logger_.get());
}
// Same format-6 fusion check, fed the "_opset13" variant of the model file.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6_OpSet13) {
  const auto* model_path = MODEL_FOLDER "fusion/embed_layer_norm_format6_opset13.onnx";
  EmbedLayerNormFusionFormat6(model_path, logger_.get());
}
static void TestEmbedLayerNormFusionDistilBert(const std::basic_string<ORTCHAR_T>& model_uri,
std::map<std::string, int>& op_to_count,
logging::Logger* logger) {
@ -3076,6 +3104,18 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7) {
EXPECT_EQ(op_to_count["ReduceSum"], 1);
}
// Checks the op counts reported for the "_opset13" variant of the format-7
// model after it goes through TestEmbedLayerNormFusionDistilBert.
// NOTE(review): the helper's body is not visible here; it presumably applies
// the fusion and fills the map with per-op node counts -- confirm.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7_OpSet13) {
  std::map<std::string, int> counts;
  TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format7_opset13.onnx", counts, logger_.get());

  // Exactly one fused embedding node and one attention node are expected.
  EXPECT_EQ(counts["com.microsoft.EmbedLayerNormalization"], 1);
  EXPECT_EQ(counts["com.microsoft.Attention"], 1);
  EXPECT_EQ(counts["Cast"], 2);

  // No Shape/Gather/Unsqueeze nodes are expected to remain.
  EXPECT_EQ(counts["Shape"], 0);
  EXPECT_EQ(counts["Gather"], 0);
  EXPECT_EQ(counts["Unsqueeze"], 0);

  EXPECT_EQ(counts["ReduceSum"], 1);
}
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat8) {
std::map<std::string, int> op_to_count;
TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format8.onnx", op_to_count, logger_.get());
@ -3088,6 +3128,18 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat8) {
EXPECT_EQ(op_to_count["ReduceSum"], 1);
}
// Checks the op counts reported for the "_opset13" variant of the format-8
// model after it goes through TestEmbedLayerNormFusionDistilBert.
// NOTE(review): the helper's body is not visible here; it presumably applies
// the fusion and fills the map with per-op node counts -- confirm.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat8_OpSet13) {
  std::map<std::string, int> counts;
  TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format8_opset13.onnx", counts, logger_.get());

  // Exactly one fused embedding node and one attention node are expected.
  EXPECT_EQ(counts["com.microsoft.EmbedLayerNormalization"], 1);
  EXPECT_EQ(counts["com.microsoft.Attention"], 1);
  EXPECT_EQ(counts["Cast"], 2);

  // No Shape/Gather/Unsqueeze nodes are expected to remain.
  EXPECT_EQ(counts["Shape"], 0);
  EXPECT_EQ(counts["Gather"], 0);
  EXPECT_EQ(counts["Unsqueeze"], 0);

  EXPECT_EQ(counts["ReduceSum"], 1);
}
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat9) {
std::map<std::string, int> op_to_count;
TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format9.onnx", op_to_count, logger_.get());
@ -3100,15 +3152,26 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat9) {
EXPECT_EQ(op_to_count["ReduceSum"], 1);
}
TEST_F(GraphTransformationTests, EmbedLayerNormFusionMultiple) {
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_multiple.onnx";
// Checks the op counts reported for the "_opset13" variant of the format-9
// model after it goes through TestEmbedLayerNormFusionDistilBert.
// NOTE(review): the helper's body is not visible here; it presumably applies
// the fusion and fills the map with per-op node counts -- confirm.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat9_OpSet13) {
  std::map<std::string, int> counts;
  TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format9_opset13.onnx", counts, logger_.get());

  // Exactly one fused embedding node and one attention node are expected.
  EXPECT_EQ(counts["com.microsoft.EmbedLayerNormalization"], 1);
  EXPECT_EQ(counts["com.microsoft.Attention"], 1);
  EXPECT_EQ(counts["Cast"], 2);

  // For this model a Shape node, two Gather nodes and two Unsqueeze nodes are
  // expected to be present in the resulting graph.
  EXPECT_EQ(counts["Shape"], 1);
  EXPECT_EQ(counts["Gather"], 2);
  EXPECT_EQ(counts["Unsqueeze"], 2);

  EXPECT_EQ(counts["ReduceSum"], 1);
}
static void EmbedLayerNormFusionFormatMultiple(const std::basic_string<ORTCHAR_T>& file_path, logging::Logger* logger) {
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK());
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<EmbedLayerNormFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger);
ASSERT_TRUE(ret.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@ -3126,6 +3189,14 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionMultiple) {
EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 2);
}
// Runs the shared "multiple" fusion check (EmbedLayerNormFusionFormatMultiple
// helper) against the original multi-embedding test model.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionMultiple) {
  const auto* model_path = MODEL_FOLDER "fusion/embed_layer_norm_multiple.onnx";
  EmbedLayerNormFusionFormatMultiple(model_path, logger_.get());
}
// Same "multiple" fusion check, fed the "_opset13" variant of the model file.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionMultiple_OpSet13) {
  const auto* model_path = MODEL_FOLDER "fusion/embed_layer_norm_multiple_opset13.onnx";
  EmbedLayerNormFusionFormatMultiple(model_path, logger_.get());
}
TEST_F(GraphTransformationTests, DynamicQuantizeMatMulTest) {
auto model_uri = MODEL_FOLDER "fusion/dynamic_quantize_matmul.onnx";
std::shared_ptr<Model> p_model;