mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Fix GraphTransformationTests tests.
This commit is contained in:
parent
87fad09c7b
commit
d50c3e7a71
3 changed files with 50 additions and 32 deletions
|
|
@ -45,6 +45,7 @@
|
|||
#include "core/util/math.h"
|
||||
#include "test/capturing_sink.h"
|
||||
#include "test/framework/test_utils.h"
|
||||
#include "test/optimizer/graph_transform_test_fixture.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "test/test_environment.h"
|
||||
#include "asserts.h"
|
||||
|
|
@ -58,15 +59,6 @@ namespace test {
|
|||
|
||||
#define MODEL_FOLDER ORT_TSTR("testdata/transform/")
|
||||
|
||||
class GraphTransformationTests : public ::testing::Test {
|
||||
protected:
|
||||
GraphTransformationTests() {
|
||||
logger_ = DefaultLoggingManager().CreateLogger("GraphTransformationTests");
|
||||
}
|
||||
|
||||
std::unique_ptr<logging::Logger> logger_;
|
||||
};
|
||||
|
||||
TEST_F(GraphTransformationTests, IdentityElimination) {
|
||||
auto model_uri = MODEL_FOLDER "abs-id-max.onnx";
|
||||
std::shared_ptr<Model> model;
|
||||
|
|
|
|||
25
onnxruntime/test/optimizer/graph_transform_test_fixture.h
Normal file
25
onnxruntime/test/optimizer/graph_transform_test_fixture.h
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "test/test_environment.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
class GraphTransformationTests : public ::testing::Test {
|
||||
protected:
|
||||
GraphTransformationTests() {
|
||||
logger_ = DefaultLoggingManager().CreateLogger("GraphTransformationTests");
|
||||
}
|
||||
|
||||
std::unique_ptr<logging::Logger> logger_;
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -13,6 +13,7 @@
|
|||
#include "core/optimizer/utils.h"
|
||||
#include "orttraining/core/optimizer/gist_encode_decode.h"
|
||||
#include "orttraining/core/optimizer/megatron_transformer.h"
|
||||
#include "test/optimizer/graph_transform_test_fixture.h"
|
||||
#include "test/util/include/default_providers.h"
|
||||
#include "orttraining/test/optimizer/horizontal_parallel_test_utils.h"
|
||||
|
||||
|
|
@ -26,10 +27,10 @@ namespace test {
|
|||
|
||||
#define MODEL_FOLDER ORT_TSTR("testdata/transform/")
|
||||
|
||||
TEST(GraphTransformationTests, GistEncodeDecode) {
|
||||
TEST_F(GraphTransformationTests, GistEncodeDecode) {
|
||||
auto model_uri = MODEL_FOLDER "../test_training_model.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
|
||||
ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK());
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
auto rule_transformer_L1 = onnxruntime::make_unique<RuleBasedGraphTransformer>("RuleGistTransformer1");
|
||||
|
|
@ -37,7 +38,7 @@ TEST(GraphTransformationTests, GistEncodeDecode) {
|
|||
onnxruntime::GraphTransformerManager graph_transformation_mgr{1};
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
|
||||
|
||||
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger());
|
||||
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
|
|
@ -61,15 +62,15 @@ Node* GetNodeByName(Graph& graph, std::string node_name) {
|
|||
|
||||
// MegatronF/G is defined only for training, and in msdomain.
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
TEST(GraphTransformationTests, MegatronMLPPartitionRank0) {
|
||||
TEST_F(GraphTransformationTests, MegatronMLPPartitionRank0) {
|
||||
auto model_uri = MODEL_FOLDER "model_parallel/mlp_megatron_basic_test.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
auto ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
|
||||
auto ret = Model::Load(model_uri, p_model, nullptr, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
Graph& graph = p_model->MainGraph();
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(0, 2), TransformerLevel::Level1);
|
||||
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger());
|
||||
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
|
||||
auto model_uri2 = "mlp_megatron_basic_test_partition_rank0.onnx";
|
||||
|
|
@ -128,16 +129,16 @@ TEST(GraphTransformationTests, MegatronMLPPartitionRank0) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST(GraphTransformationTests, MegatronMLPPartitionRank1) {
|
||||
TEST_F(GraphTransformationTests, MegatronMLPPartitionRank1) {
|
||||
auto model_uri = MODEL_FOLDER "model_parallel/mlp_megatron_basic_test.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
auto ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
|
||||
auto ret = Model::Load(model_uri, p_model, nullptr, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(1, 2), TransformerLevel::Level1);
|
||||
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger());
|
||||
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
|
||||
auto model_uri2 = "mlp_megatron_basic_test_partition_rank1.onnx";
|
||||
|
|
@ -196,15 +197,15 @@ TEST(GraphTransformationTests, MegatronMLPPartitionRank1) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST(GraphTransformationTests, MegatronSelfAttentionPartitionRank0) {
|
||||
TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank0) {
|
||||
auto model_uri = MODEL_FOLDER "model_parallel/self_attention_megatron_basic_test.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
auto ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
|
||||
auto ret = Model::Load(model_uri, p_model, nullptr, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
Graph& graph = p_model->MainGraph();
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(0, 2), TransformerLevel::Level1);
|
||||
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger());
|
||||
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
|
||||
auto model_uri2 = "self_attention_megatron_basic_test_partition_rank0.onnx";
|
||||
|
|
@ -260,16 +261,16 @@ TEST(GraphTransformationTests, MegatronSelfAttentionPartitionRank0) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST(GraphTransformationTests, MegatronSelfAttentionPartitionRank1) {
|
||||
TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank1) {
|
||||
auto model_uri = MODEL_FOLDER "model_parallel/self_attention_megatron_basic_test.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
auto ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger());
|
||||
auto ret = Model::Load(model_uri, p_model, nullptr, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(1, 2), TransformerLevel::Level1);
|
||||
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger());
|
||||
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
|
||||
auto model_uri2 = "self_attention_megatron_basic_test_partition_rank1.onnx";
|
||||
|
|
@ -327,23 +328,23 @@ TEST(GraphTransformationTests, MegatronSelfAttentionPartitionRank1) {
|
|||
|
||||
// We only tested on CUDA run.
|
||||
#if defined(USE_CUDA)
|
||||
TEST(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {
|
||||
TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {
|
||||
auto model_uri = MODEL_FOLDER "model_parallel/mlp_megatron_basic_test.onnx";
|
||||
const int total_rank = 4;
|
||||
std::vector<Graph*> graphs;
|
||||
std::vector<std::shared_ptr<Model>> p_models(total_rank);
|
||||
for (auto i = 0; i < total_rank; i++) {
|
||||
auto ret = Model::Load(model_uri, p_models[i], nullptr, DefaultLoggingManager().DefaultLogger());
|
||||
auto ret = Model::Load(model_uri, p_models[i], nullptr, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
Graph& graph = p_models[i]->MainGraph();
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(i, total_rank), TransformerLevel::Level1);
|
||||
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger());
|
||||
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
graphs.push_back(&graph);
|
||||
}
|
||||
|
||||
onnxruntime::Model combine_model("combine_graph", false, DefaultLoggingManager().DefaultLogger());
|
||||
onnxruntime::Model combine_model("combine_graph", false, *logger_);
|
||||
auto& combine_graph = combine_model.MainGraph();
|
||||
auto ret = horizontal_parallel_test_utils::MergeGraphsOnAllWorkers(graphs, combine_graph);
|
||||
ORT_ENFORCE(ret.IsOK());
|
||||
|
|
@ -430,18 +431,18 @@ TEST(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST(GraphTransformationTests, MegatronSelfAttentionPartitionCorrectnessTest) {
|
||||
TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionCorrectnessTest) {
|
||||
auto model_uri = MODEL_FOLDER "model_parallel/self_attention_megatron_basic_test.onnx";
|
||||
const int total_rank = 2; // The test graph is too small to partition to 4, so use 2 instead here.
|
||||
std::vector<Graph*> graphs;
|
||||
std::vector<std::shared_ptr<Model>> p_models(total_rank);
|
||||
for (auto i = 0; i < total_rank; i++) {
|
||||
auto ret = Model::Load(model_uri, p_models[i], nullptr, DefaultLoggingManager().DefaultLogger());
|
||||
auto ret = Model::Load(model_uri, p_models[i], nullptr, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
Graph& graph = p_models[i]->MainGraph();
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(i, total_rank), TransformerLevel::Level1);
|
||||
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger());
|
||||
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
graphs.push_back(&graph);
|
||||
}
|
||||
|
|
@ -460,7 +461,7 @@ TEST(GraphTransformationTests, MegatronSelfAttentionPartitionCorrectnessTest) {
|
|||
ORT_ENFORCE(attr != nullptr && attr->has_i() && attr->i() == dropout2_rank0_seed);
|
||||
}
|
||||
|
||||
onnxruntime::Model combine_model("combine_graph", false, DefaultLoggingManager().DefaultLogger());
|
||||
onnxruntime::Model combine_model("combine_graph", false, *logger_);
|
||||
auto& combine_graph = combine_model.MainGraph();
|
||||
auto ret = horizontal_parallel_test_utils::MergeGraphsOnAllWorkers(graphs, combine_graph);
|
||||
ORT_ENFORCE(ret.IsOK());
|
||||
|
|
|
|||
Loading…
Reference in a new issue