diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index b2d0b9d9d0..b912ed30cf 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -45,6 +45,7 @@ #include "core/util/math.h" #include "test/capturing_sink.h" #include "test/framework/test_utils.h" +#include "test/optimizer/graph_transform_test_fixture.h" #include "test/providers/provider_test_utils.h" #include "test/test_environment.h" #include "asserts.h" @@ -58,15 +59,6 @@ namespace test { #define MODEL_FOLDER ORT_TSTR("testdata/transform/") -class GraphTransformationTests : public ::testing::Test { - protected: - GraphTransformationTests() { - logger_ = DefaultLoggingManager().CreateLogger("GraphTransformationTests"); - } - - std::unique_ptr logger_; -}; - TEST_F(GraphTransformationTests, IdentityElimination) { auto model_uri = MODEL_FOLDER "abs-id-max.onnx"; std::shared_ptr model; diff --git a/onnxruntime/test/optimizer/graph_transform_test_fixture.h b/onnxruntime/test/optimizer/graph_transform_test_fixture.h new file mode 100644 index 0000000000..5ec89028a6 --- /dev/null +++ b/onnxruntime/test/optimizer/graph_transform_test_fixture.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "gtest/gtest.h" + +#include "test/test_environment.h" + +namespace onnxruntime { +namespace test { + +class GraphTransformationTests : public ::testing::Test { + protected: + GraphTransformationTests() { + logger_ = DefaultLoggingManager().CreateLogger("GraphTransformationTests"); + } + + std::unique_ptr logger_; +}; + +} // namespace test +} // namespace onnxruntime diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index 2c07a65f8f..f92a526697 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -13,6 +13,7 @@ #include "core/optimizer/utils.h" #include "orttraining/core/optimizer/gist_encode_decode.h" #include "orttraining/core/optimizer/megatron_transformer.h" +#include "test/optimizer/graph_transform_test_fixture.h" #include "test/util/include/default_providers.h" #include "orttraining/test/optimizer/horizontal_parallel_test_utils.h" @@ -26,10 +27,10 @@ namespace test { #define MODEL_FOLDER ORT_TSTR("testdata/transform/") -TEST(GraphTransformationTests, GistEncodeDecode) { +TEST_F(GraphTransformationTests, GistEncodeDecode) { auto model_uri = MODEL_FOLDER "../test_training_model.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK()); Graph& graph = p_model->MainGraph(); auto rule_transformer_L1 = onnxruntime::make_unique("RuleGistTransformer1"); @@ -37,7 +38,7 @@ TEST(GraphTransformationTests, GistEncodeDecode) { onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -61,15 +62,15 @@ Node* GetNodeByName(Graph& graph, std::string node_name) { // MegatronF/G is defined only for training, and in msdomain. #ifndef DISABLE_CONTRIB_OPS -TEST(GraphTransformationTests, MegatronMLPPartitionRank0) { +TEST_F(GraphTransformationTests, MegatronMLPPartitionRank0) { auto model_uri = MODEL_FOLDER "model_parallel/mlp_megatron_basic_test.onnx"; std::shared_ptr p_model; - auto ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); + auto ret = Model::Load(model_uri, p_model, nullptr, *logger_); ASSERT_TRUE(ret.IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(0, 2), TransformerLevel::Level1); - ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()); + ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); ASSERT_TRUE(ret.IsOK()); auto model_uri2 = "mlp_megatron_basic_test_partition_rank0.onnx"; @@ -128,16 +129,16 @@ TEST(GraphTransformationTests, MegatronMLPPartitionRank0) { } } -TEST(GraphTransformationTests, MegatronMLPPartitionRank1) { +TEST_F(GraphTransformationTests, MegatronMLPPartitionRank1) { auto model_uri = MODEL_FOLDER "model_parallel/mlp_megatron_basic_test.onnx"; std::shared_ptr p_model; - auto ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); + auto ret = Model::Load(model_uri, p_model, nullptr, *logger_); ASSERT_TRUE(ret.IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(1, 2), TransformerLevel::Level1); - ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()); + ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); ASSERT_TRUE(ret.IsOK()); auto model_uri2 = "mlp_megatron_basic_test_partition_rank1.onnx"; @@ -196,15 +197,15 @@ TEST(GraphTransformationTests, MegatronMLPPartitionRank1) { } } -TEST(GraphTransformationTests, MegatronSelfAttentionPartitionRank0) { +TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank0) { auto model_uri = MODEL_FOLDER "model_parallel/self_attention_megatron_basic_test.onnx"; std::shared_ptr p_model; - auto ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); + auto ret = Model::Load(model_uri, p_model, nullptr, *logger_); ASSERT_TRUE(ret.IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(0, 2), TransformerLevel::Level1); - ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()); + ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); ASSERT_TRUE(ret.IsOK()); auto model_uri2 = "self_attention_megatron_basic_test_partition_rank0.onnx"; @@ -260,16 +261,16 @@ TEST(GraphTransformationTests, MegatronSelfAttentionPartitionRank0) { } } -TEST(GraphTransformationTests, MegatronSelfAttentionPartitionRank1) { +TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank1) { auto model_uri = MODEL_FOLDER "model_parallel/self_attention_megatron_basic_test.onnx"; std::shared_ptr p_model; - auto ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); + auto ret = Model::Load(model_uri, p_model, nullptr, *logger_); ASSERT_TRUE(ret.IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(1, 2), TransformerLevel::Level1); - ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()); + ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); ASSERT_TRUE(ret.IsOK()); auto model_uri2 = "self_attention_megatron_basic_test_partition_rank1.onnx"; @@ -327,23 +328,23 @@ TEST(GraphTransformationTests, MegatronSelfAttentionPartitionRank1) { // We only tested on CUDA run. #if defined(USE_CUDA) -TEST(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) { +TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) { auto model_uri = MODEL_FOLDER "model_parallel/mlp_megatron_basic_test.onnx"; const int total_rank = 4; std::vector graphs; std::vector> p_models(total_rank); for (auto i = 0; i < total_rank; i++) { - auto ret = Model::Load(model_uri, p_models[i], nullptr, DefaultLoggingManager().DefaultLogger()); + auto ret = Model::Load(model_uri, p_models[i], nullptr, *logger_); ASSERT_TRUE(ret.IsOK()); Graph& graph = p_models[i]->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(i, total_rank), TransformerLevel::Level1); - ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()); + ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); ASSERT_TRUE(ret.IsOK()); graphs.push_back(&graph); } - onnxruntime::Model combine_model("combine_graph", false, DefaultLoggingManager().DefaultLogger()); + onnxruntime::Model combine_model("combine_graph", false, *logger_); auto& combine_graph = combine_model.MainGraph(); auto ret = horizontal_parallel_test_utils::MergeGraphsOnAllWorkers(graphs, combine_graph); ORT_ENFORCE(ret.IsOK()); @@ -430,18 +431,18 @@ TEST(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) { } } -TEST(GraphTransformationTests, MegatronSelfAttentionPartitionCorrectnessTest) { +TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionCorrectnessTest) { auto model_uri = MODEL_FOLDER "model_parallel/self_attention_megatron_basic_test.onnx"; const int total_rank = 2; // The test graph is too small to partition to 4, so use 2 instead here. std::vector graphs; std::vector> p_models(total_rank); for (auto i = 0; i < total_rank; i++) { - auto ret = Model::Load(model_uri, p_models[i], nullptr, DefaultLoggingManager().DefaultLogger()); + auto ret = Model::Load(model_uri, p_models[i], nullptr, *logger_); ASSERT_TRUE(ret.IsOK()); Graph& graph = p_models[i]->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(i, total_rank), TransformerLevel::Level1); - ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()); + ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); ASSERT_TRUE(ret.IsOK()); graphs.push_back(&graph); } @@ -460,7 +461,7 @@ TEST(GraphTransformationTests, MegatronSelfAttentionPartitionCorrectnessTest) { ORT_ENFORCE(attr != nullptr && attr->has_i() && attr->i() == dropout2_rank0_seed); } - onnxruntime::Model combine_model("combine_graph", false, DefaultLoggingManager().DefaultLogger()); + onnxruntime::Model combine_model("combine_graph", false, *logger_); auto& combine_graph = combine_model.MainGraph(); auto ret = horizontal_parallel_test_utils::MergeGraphsOnAllWorkers(graphs, combine_graph); ORT_ENFORCE(ret.IsOK());