Shape independent gradient builder for Concat (#4675)

* Add gradient for ConcatTraining

* Graph rewriter changes for concat

* Add generated onnx graph, minor fixes

* Revert unintended change

* Fix for MaxPoolGradTest

* Fix UT

* Review comments, windows tests

* Review comments
This commit is contained in:
ashbhandare 2020-08-06 14:39:33 -07:00 committed by GitHub
parent 8507bc1f48
commit fc2f36c608
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 331 additions and 65 deletions

View file

@ -0,0 +1,55 @@
import onnx
from onnx import helper
from onnx import TensorProto
import numpy as np
def GenerateModel(model_name):
    """Build the three-input Concat test model and save it to ``model_name``.

    Graph layout: ``Gather(embed_weights, input_1)`` feeds three ``Add`` nodes
    (q / k / v biases). Their outputs are concatenated along axis 0, added to
    ``add_qkv_weight`` and reduced with ``ReduceSum`` over axis 0 into the
    ``predictions`` output.
    """
    # Per-branch bias vectors, eight values each.
    q_bias = [-0.23681640625, -0.16552734375, 0.2191162109375, -0.1756591796875,
              -0.03460693359375, -0.05316162109375, -0.336181640625, -0.253662109375]
    k_bias = [0.0246734619140625, 0.011993408203125, 0.0178375244140625, 0.00998687744140625,
              0.0255126953125, 0.076416015625, -0.040771484375, 0.0107879638671875]
    v_bias = [-0.005893707275390625, -0.00916290283203125, 0.04541015625, 0.0159454345703125,
              -0.0029163360595703125, -0.03472900390625, 0.0535888671875, 0.0091094970703125]

    # Embedding table values are drawn at random; no seed is set in this function,
    # so the values differ between runs unless the caller seeds numpy beforehand.
    embedding_values = np.random.uniform(-1,1,8000).tolist()

    weight_tensors = [
        helper.make_tensor('embed_weights', TensorProto.FLOAT, [1000, 8], embedding_values),
        helper.make_tensor('add_q_weight', TensorProto.FLOAT, [8], q_bias),
        helper.make_tensor('add_k_weight', TensorProto.FLOAT, [8], k_bias),
        helper.make_tensor('add_v_weight', TensorProto.FLOAT, [8], v_bias),
        helper.make_tensor('add_qkv_weight', TensorProto.FLOAT, [1], [1.0]),
    ]

    # Forward graph, listed in execution order.
    gather_node = helper.make_node("Gather", ["embed_weights","input_1"], ["gather_out"], "gather")
    branch_nodes = [
        helper.make_node("Add", ["gather_out", "add_q_weight"], ["add_q_out"], "add_q"),
        helper.make_node("Add", ["gather_out", "add_k_weight"], ["add_k_out"], "add_k"),
        helper.make_node("Add", ["gather_out", "add_v_weight"], ["add_v_out"], "add_v"),
    ]
    concat_node = helper.make_node("Concat", ["add_q_out", "add_k_out", "add_v_out"],
                                   ["concat_out"], "concat", axis=0)
    tail_nodes = [
        helper.make_node("Add", ["add_qkv_weight", "concat_out"], ["add_out"], "add"),
        helper.make_node("ReduceSum",["add_out"],["predictions"],"reduce_sum_1", axes=[0], keepdims=1),
    ]
    graph_nodes = [gather_node] + branch_nodes + [concat_node] + tail_nodes

    graph_inputs = [
        helper.make_tensor_value_info('input_1', TensorProto.INT64, ['batch', 'seq_len'])
    ]
    graph_outputs = [
        helper.make_tensor_value_info('predictions', TensorProto.FLOAT, [1,1,8]),
    ]

    graph = helper.make_graph(
        graph_nodes,
        "ConcatThreeInputs",
        graph_inputs,
        graph_outputs,
        weight_tensors)

    onnx.save(helper.make_model(graph), model_name)
GenerateModel('concat_trainable.onnx')

Binary file not shown.

View file

@ -9,7 +9,9 @@
#include "onnx/defs/attr_proto_util.h"
#include "onnx/defs/tensor_proto_util.h"
#include "core/framework/tensorprotoutils.h"
#include "core/providers/common.h"
#include "orttraining/core/framework/distributed_run_context.h"
#include "orttraining/core/graph/gradient_builder_registry.h"
#include "orttraining/core/graph/graph_augmenter.h"
@ -534,6 +536,49 @@ IMPLEMENT_GRADIENT_BUILDER(GetConcatGradient) {
new_attributes)};
}
// Gradient of ConcatTraining: the incoming gradient GO(0) is split back into one
// gradient per forward input along the concat axis.
//  - When the extent along the axis is statically known for every input, a plain
//    "Split" node with an explicit "split" attribute is emitted.
//  - Otherwise a "SplitTraining" node (kMSDomain) is emitted, which additionally
//    takes the forward node's second output O(1) as input.
IMPLEMENT_GRADIENT_BUILDER(GetConcatTrainingGradient) {
  const auto& src_attributes = SrcNodeAttributes();
  ORT_ENFORCE(utils::HasInt(src_attributes.at("axis")));
  const auto concat_axis = src_attributes.at("axis").i();

  const auto input_count = GetSrcNodeInputSize();
  std::vector<int64_t> split_sizes(input_count);
  std::vector<ArgDef> grad_outputs;
  bool all_extents_static = true;

  for (int input_idx = 0; input_idx < input_count; ++input_idx) {
    // An input contributes a static split size only if its shape is available
    // and the dimension along the concat axis carries a concrete value.
    bool extent_is_static = false;
    std::vector<Dimension> input_dims;
    if (GetShape(I(input_idx), input_dims).IsOK()) {
      const int64_t rank = static_cast<int64_t>(input_dims.size());
      const int64_t resolved_axis = HandleNegativeAxis(concat_axis, rank);
      const auto& axis_dim = input_dims[resolved_axis];
      if (axis_dim.has_dim_value()) {
        split_sizes[input_idx] = axis_dim.dim_value();
        extent_is_static = true;
      }
    }
    all_extents_static = all_extents_static && extent_is_static;

    grad_outputs.push_back(GI(input_idx));
  }

  std::vector<AttributeProto> grad_attributes;
  grad_attributes.push_back(MakeAttribute("axis", concat_axis));

  if (!all_extents_static) {
    // Shape-independent path: no "split" attribute is set on this node.
    return std::vector<NodeDef>{
        NodeDef(OpDef{"SplitTraining", kMSDomain, 1},
                {GO(0), O(1)},
                grad_outputs,
                grad_attributes)};
  }

  grad_attributes.push_back(MakeAttribute("split", split_sizes));
  return std::vector<NodeDef>{
      NodeDef("Split",
              {GO(0)},
              grad_outputs,
              grad_attributes)};
}
IMPLEMENT_GRADIENT_BUILDER(GetGatherNDGradient) {
auto attributes = SrcNodeAttributes();
ORT_ENFORCE(attributes.at("batch_dims").has_i());

View file

@ -28,6 +28,7 @@ DECLARE_GRADIENT_BUILDER(GetReduceSumGradient)
DECLARE_GRADIENT_BUILDER(GetReduceLogSumExpGradient)
DECLARE_GRADIENT_BUILDER(GetPowGradient)
DECLARE_GRADIENT_BUILDER(GetConcatGradient)
DECLARE_GRADIENT_BUILDER(GetConcatTrainingGradient)
DECLARE_GRADIENT_BUILDER(GetReshapeGradient)
DECLARE_GRADIENT_BUILDER(GetTransposeGradient)
DECLARE_GRADIENT_BUILDER(GetPoolGradient)

View file

@ -60,6 +60,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Mul", GetMulGradient);
REGISTER_GRADIENT_BUILDER("Div", GetDivGradient);
REGISTER_GRADIENT_BUILDER("Concat", GetConcatGradient);
REGISTER_GRADIENT_BUILDER("ConcatTraining", GetConcatTrainingGradient);
REGISTER_GRADIENT_BUILDER("Reshape", GetReshapeGradient);
REGISTER_GRADIENT_BUILDER("Transpose", GetTransposeGradient);
REGISTER_GRADIENT_BUILDER("Gemm", GetGemmGradient);

View file

@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "orttraining/core/optimizer/concat_replacement.h"
#include "core/common/logging/logging.h"
#include "core/optimizer/rewrite_rule.h"
#include "core/optimizer/utils.h"
#include "core/graph/graph.h"
#include "core/graph/graph_utils.h"
namespace onnxruntime {
// Replaces `concat_node` with a new "ConcatTraining" node (kMSDomain) that has the
// same inputs and attributes plus one extra INT64 output named "per_input_length",
// a 1-D tensor with one element per Concat input.
Status ConcatReplacement::Apply(Graph& graph, Node& concat_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
  const auto& input_defs = concat_node.MutableInputDefs();
  auto& output_defs = concat_node.MutableOutputDefs();

  // Type of the extra output: INT64 tensor of shape [number of concat inputs].
  ONNX_NAMESPACE::TypeProto per_input_length_type;
  auto* tensor_type = per_input_length_type.mutable_tensor_type();
  tensor_type->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
  tensor_type->mutable_shape()->add_dim()->set_dim_value(input_defs.size());

  // Append the new output to the original node's output list; that list (now
  // including the extra output) is what the replacement node is created with.
  const auto per_input_length_name = graph.GenerateNodeArgName("per_input_length");
  NodeArg& per_input_length_arg = graph.GetOrCreateNodeArg(per_input_length_name, &per_input_length_type);
  output_defs.push_back(&per_input_length_arg);

  const auto replacement_name = graph.GenerateNodeName("ConcatTraining");
  Node& replacement_node = graph.AddNode(replacement_name,
                                         "ConcatTraining",
                                         "Concat with extra output",
                                         input_defs,
                                         output_defs,
                                         &concat_node.GetAttributes(),
                                         kMSDomain);

  // The replacement runs on the same execution provider as the node it replaces.
  replacement_node.SetExecutionProviderType(concat_node.GetExecutionProviderType());

  // NOTE(review): FinalizeNodeFusion presumably rewires edges onto the new node and
  // removes the original one, consistent with the effect reported below; confirm.
  graph_utils::FinalizeNodeFusion(graph, replacement_node, concat_node);

  rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
  return Status::OK();
}
// This rule places no extra condition on the node: every node it is offered is
// accepted, so all three arguments are ignored.
bool ConcatReplacement::SatisfyCondition(const Graph&, const Node&, const logging::Logger&) const {
return true;
}
} // namespace onnxruntime

View file

@ -0,0 +1,32 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/rewrite_rule.h"
namespace onnxruntime {
/**
@Class ConcatReplacement
Rewrite rule that replaces a Concat node with a ConcatTraining node. ConcatTraining has an
additional output that is used when building the gradient for the Concat node.
The rule is only attempted on nodes with op type "Concat".
*/
class ConcatReplacement : public RewriteRule {
public:
// Registers the rule under the name "ConcatReplacement".
ConcatReplacement() noexcept : RewriteRule("ConcatReplacement") {}
// Op types this rule targets: only "Concat".
std::vector<std::string> TargetOpTypes() const noexcept override {
return {"Concat"};
}
private:
// Decides whether the rule should be applied to `node`.
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
// Performs the replacement on `node` and reports the outcome through `rule_effect`.
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -37,6 +37,7 @@
#include "core/session/inference_session.h"
#include "orttraining/core/framework/distributed_run_context.h"
#include "orttraining/core/optimizer/bias_dropout_fusion.h"
#include "orttraining/core/optimizer/concat_replacement.h"
#include "orttraining/core/optimizer/insert_output_rewriter.h"
#include "orttraining/core/optimizer/localized_recompute.h"
#include "orttraining/core/optimizer/megatron_transformer.h"
@ -101,6 +102,10 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
case TransformerLevel::Level2: {
// Put ReshapeFusion as level-2 optimization after all level-1 graph rewriters are run.
transformers.emplace_back(onnxruntime::make_unique<ReshapeFusion>(compatible_eps));
rule_transformer =
onnxruntime::make_unique<RuleBasedGraphTransformer>(optimizer_utils::GenerateRuleBasedTransformerName(level),
compatible_eps);
rule_transformer->Register(onnxruntime::make_unique<ConcatReplacement>());
} break;
case TransformerLevel::Level3: {

View file

@ -248,7 +248,17 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::InitOpTesterWithGraph(
for (size_t data_index = 0; data_index < y_infos.size(); data_index++) {
std::string name = "output" + std::to_string(data_index);
op_session.AddOutput<Y_T>(name.c_str(), y_infos[data_index].shape.GetDims(), (*y_datas)[data_index]);
const std::vector<Y_T>& data = (*y_datas)[data_index];
if (y_infos[data_index].data_type == DataTypeImpl::GetTensorType<int64_t>()) {
std::vector<int64_t> int64_data(data.size());
std::transform(data.begin(), data.end(), int64_data.begin(), [](Y_T x) { return static_cast<int64_t>(x); });
op_session.AddOutput<int64_t>(name.c_str(),
y_infos[data_index].shape.GetDims(),
int64_data);
} else {
op_session.AddOutput<Y_T>(name.c_str(), y_infos[data_index].shape.GetDims(), data);
}
}
// Currently only allows setting int attributes to zero. TODO: Expand this
for (auto attr : attributes) {
@ -568,7 +578,7 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::ComputeGradientError(
}
// Compute gradient error.
return ComputeGradientErrorInternal(op_def, x_infos, y_infos, &x_datas, &y_datas, max_error,
return ComputeGradientErrorInternal(op_def, x_infos, y_infos, &x_datas, &y_datas, max_error,
attributes, check_not_have_gradient, check_not_have_shape_inferencing);
}

View file

@ -39,55 +39,52 @@ static bool IsErrorWithinTolerance(float error, float tolerance) {
EXPECT_IS_TINIER_THAN(max_error, 1.5e-2f)
static void RunReductionTests(const OpDef& op_def) {
TestDataVector test_data(
// Input X
{
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
},
// Input Y
{
{{1, 1, 1}},
{{}},
{{1, 3, 1}},
{{2}},
{{4, 1, 2}},
{{4, 3}},
{{4, 1, 2}},
{{4}}
},
// Attributes
{
// default
{},
// axes = [0, 1, 2], keepdims = 0
{MakeAttribute("axes", std::vector<int64_t>{0, 1, 2}),
MakeAttribute("keepdims", int64_t(0))},
// axes = [0, 2], keepdims = 1
{MakeAttribute("axes", std::vector<int64_t>{0, 2})},
// axes = [0, 1], keepdims = 0
{MakeAttribute("axes", std::vector<int64_t>{0, 1}),
MakeAttribute("keepdims", int64_t(0))},
// axes = [1], keepdims = 1
{MakeAttribute("axes", std::vector<int64_t>{1}),
MakeAttribute("keepdims", int64_t(1))},
// axes = [2], keepdims = 0
{MakeAttribute("axes", std::vector<int64_t>{2}),
MakeAttribute("keepdims", int64_t(0))},
// axes = [-2], keepdims = 1
{MakeAttribute("axes", std::vector<int64_t>{-2}),
MakeAttribute("keepdims", int64_t(1))},
// axes = [-2, -1], keepdims = 0
{MakeAttribute("axes", std::vector<int64_t>{-2, -1}),
MakeAttribute("keepdims", int64_t(0))}
});
// Input X
{
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
},
// Input Y
{
{{1, 1, 1}},
{{}},
{{1, 3, 1}},
{{2}},
{{4, 1, 2}},
{{4, 3}},
{{4, 1, 2}},
{{4}}},
// Attributes
{
// default
{},
// axes = [0, 1, 2], keepdims = 0
{MakeAttribute("axes", std::vector<int64_t>{0, 1, 2}),
MakeAttribute("keepdims", int64_t(0))},
// axes = [0, 2], keepdims = 1
{MakeAttribute("axes", std::vector<int64_t>{0, 2})},
// axes = [0, 1], keepdims = 0
{MakeAttribute("axes", std::vector<int64_t>{0, 1}),
MakeAttribute("keepdims", int64_t(0))},
// axes = [1], keepdims = 1
{MakeAttribute("axes", std::vector<int64_t>{1}),
MakeAttribute("keepdims", int64_t(1))},
// axes = [2], keepdims = 0
{MakeAttribute("axes", std::vector<int64_t>{2}),
MakeAttribute("keepdims", int64_t(0))},
// axes = [-2], keepdims = 1
{MakeAttribute("axes", std::vector<int64_t>{-2}),
MakeAttribute("keepdims", int64_t(1))},
// axes = [-2, -1], keepdims = 0
{MakeAttribute("axes", std::vector<int64_t>{-2, -1}),
MakeAttribute("keepdims", int64_t(0))}});
GradientChecker<float, float, float> gradient_checker;
@ -670,17 +667,25 @@ TEST(GradientCheckerTest, ConvGrad) {
}
}
TEST(GradientCheckerTest, ConcatGrad) {
static void TestConcatOpGrad(const std::string& op_type,
const std::string& domain = kOnnxDomain,
int opset_version = 9,
bool check_not_have_shape_inferencing = false) {
float max_error;
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{"Concat"};
const bool extra_input = op_type == "ConcatTraining";
OpDef op_def{op_type, domain, opset_version};
//concat_1d
{
TensorShape x_shape({2});
TensorShape y_shape({6});
gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape}, {y_shape}, &max_error,
{MakeAttribute("axis", int64_t(0))});
std::vector<TensorInfo> output = {y_shape};
if (extra_input) output.push_back(TensorInfo({3}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>()));
gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape},
output, &max_error,
{MakeAttribute("axis", int64_t(0))}, true,
check_not_have_shape_inferencing);
EXPECT_IS_TINY(max_error);
}
@ -688,8 +693,12 @@ TEST(GradientCheckerTest, ConcatGrad) {
{
TensorShape x_shape({2, 2});
TensorShape y_shape({2, 6});
gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape}, {y_shape}, &max_error,
{MakeAttribute("axis", int64_t(1))});
std::vector<TensorInfo> output = {y_shape};
if (extra_input) output.push_back(TensorInfo({3}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>()));
gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape},
output, &max_error,
{MakeAttribute("axis", int64_t(1))}, true,
check_not_have_shape_inferencing);
EXPECT_IS_TINY(max_error);
}
@ -697,8 +706,12 @@ TEST(GradientCheckerTest, ConcatGrad) {
{
TensorShape x_shape({1, 2, 3});
TensorShape y_shape({1, 2, 9});
gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape}, {y_shape}, &max_error,
{MakeAttribute("axis", int64_t(2))});
std::vector<TensorInfo> output = {y_shape};
if (extra_input) output.push_back(TensorInfo({3}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>()));
gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape},
output, &max_error,
{MakeAttribute("axis", int64_t(2))}, true,
check_not_have_shape_inferencing);
EXPECT_IS_TINY(max_error);
}
@ -707,8 +720,12 @@ TEST(GradientCheckerTest, ConcatGrad) {
TensorShape x1_shape({2, 2});
TensorShape x2_shape({2, 4});
TensorShape y_shape({2, 6});
gradient_checker.ComputeGradientError(op_def, {x1_shape, x2_shape}, {y_shape}, &max_error,
{MakeAttribute("axis", int64_t(1))});
std::vector<TensorInfo> output = {y_shape};
if (extra_input) output.push_back(TensorInfo({2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>()));
gradient_checker.ComputeGradientError(op_def, {x1_shape, x2_shape},
output, &max_error,
{MakeAttribute("axis", int64_t(1))}, true,
check_not_have_shape_inferencing);
EXPECT_IS_TINY(max_error);
}
@ -717,12 +734,24 @@ TEST(GradientCheckerTest, ConcatGrad) {
TensorShape x1_shape({2, 2});
TensorShape x2_shape({2, 4});
TensorShape y_shape({2, 6});
gradient_checker.ComputeGradientError(op_def, {x1_shape, x2_shape}, {y_shape}, &max_error,
{MakeAttribute("axis", int64_t(-1))});
std::vector<TensorInfo> output = {y_shape};
if (extra_input) output.push_back(TensorInfo({2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>()));
gradient_checker.ComputeGradientError(op_def, {x1_shape, x2_shape},
output, &max_error,
{MakeAttribute("axis", int64_t(-1))}, true,
check_not_have_shape_inferencing);
EXPECT_IS_TINY(max_error);
}
}
// Checks the gradient of the "Concat" op, relying on TestConcatOpGrad's default
// domain, opset version and shape-inferencing arguments.
TEST(GradientCheckerTest, ConcatGrad) {
TestConcatOpGrad("Concat");
}
// Checks the gradient of the "ConcatTraining" op in kMSDomain at opset version 1.
TEST(GradientCheckerTest, ConcatTrainingGrad) { /* also test without shape inferencing */
// The trailing `true` is passed as check_not_have_shape_inferencing.
TestConcatOpGrad("ConcatTraining", kMSDomain, 1, true);
}
TEST(GradientCheckerTest, AveragePoolGrad) {
float max_error;
GradientChecker<float, float, float> gradient_checker;
@ -1909,4 +1938,3 @@ TEST(GradientCheckerTest, ExpandGrad) {
} // namespace onnxruntime
#endif // NDEBUG

View file

@ -5,6 +5,7 @@
#include "gtest/gtest.h"
#include "orttraining/core/optimizer/gist_encode_decode.h"
#include "test/providers/provider_test_utils.h"
#include "test/framework/test_utils.h"
#include "core/common/path_utils.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "core/session/environment.h"
@ -28,6 +29,7 @@ namespace test {
namespace {
constexpr auto ORIGINAL_MODEL_PATH = ORT_TSTR("testdata/test_training_model.onnx");
constexpr auto BACKWARD_MODEL_PATH = ORT_TSTR("testdata/temp_backward_model.onnx");
constexpr auto CONCAT_MODEL_PATH = ORT_TSTR("testdata/transform/concat_trainable.onnx");
std::unordered_set<std::string> GetModelOutputNames(const InferenceSession& session) {
const auto outputs_result = session.GetModelOutputs();
@ -167,6 +169,27 @@ TEST(GradientGraphBuilderTest, BuildGradientGraphTest) {
}
}
// Builds the backward graph for the Concat test model and verifies that the saved
// graph contains ConcatTraining / SplitTraining instead of Concat / Split.
TEST(GradientGraphBuilderTest, BuildConcatGradientGraphTest) {
  const auto training_config = MakeBasicTrainingConfig();
  PathString backward_graph_path;
  ASSERT_STATUS_OK(BuildBackPropGraph(CONCAT_MODEL_PATH, training_config, backward_graph_path));

  // Reload the generated backward model from disk.
  std::shared_ptr<Model> backward_model;
  ASSERT_STATUS_OK(Model::Load(backward_graph_path, backward_model, nullptr, DefaultLoggingManager().DefaultLogger()));

  Graph& backward_graph = backward_model->MainGraph();
  EXPECT_FALSE(backward_graph.GraphResolveNeeded());
  EXPECT_TRUE(backward_graph.NumberOfNodes() > 0);
  EXPECT_TRUE(backward_graph.MaxNodeIndex() > 0);

  // The original ops must be gone, each replaced by exactly one training variant.
  std::map<std::string, int> op_counts = CountOpsInGraph(backward_graph);
  ASSERT_EQ(op_counts["Concat"], 0);
  ASSERT_EQ(op_counts["Split"], 0);
  ASSERT_EQ(op_counts["ConcatTraining"], 1);
  ASSERT_EQ(op_counts["SplitTraining"], 1);
}
TEST(GradientGraphBuilderTest, TrainingSession_Basic) {
const auto config = MakeBasicTrainingConfig();
PathString backprop_model_file;

View file

@ -14,6 +14,7 @@
#include "orttraining/core/optimizer/gist_encode_decode.h"
#include "orttraining/core/optimizer/nonzero_shape_setter.h"
#include "orttraining/core/optimizer/megatron_transformer.h"
#include "orttraining/core/optimizer/concat_replacement.h"
#include "test/optimizer/graph_transform_test_fixture.h"
#include "test/util/include/default_providers.h"
#include "test/util/include/asserts.h"
@ -108,8 +109,28 @@ TEST_F(GraphTransformationTests, NonZeroShapeSetter) {
ASSERT_TRUE(nonzero_shape->dim(1).dim_param() == "nonzero_nonzero_count");
}
// MegatronF/G is defined only for training, and in msdomain.
// MegatronF/G and ConcatTraining are defined only for training, and in msdomain.
#ifndef DISABLE_CONTRIB_OPS
// Applies the ConcatReplacement rewrite rule to the Concat test model and verifies
// that the single Concat node has been turned into a ConcatTraining node.
TEST_F(GraphTransformationTests, ConcatReplacement) {
  const auto model_path = MODEL_FOLDER "concat_trainable.onnx";
  std::shared_ptr<Model> model;
  ASSERT_TRUE(Model::Load(model_path, model, nullptr, *logger_).IsOK());
  Graph& graph = model->MainGraph();

  // Wrap the rule in a rule-based transformer and register it at level 1.
  auto concat_rule_transformer = onnxruntime::make_unique<RuleBasedGraphTransformer>("ConcatReplacement");
  concat_rule_transformer->Register(onnxruntime::make_unique<ConcatReplacement>());
  onnxruntime::GraphTransformerManager transformer_manager{1};
  transformer_manager.Register(std::move(concat_rule_transformer), TransformerLevel::Level1);

  ASSERT_TRUE(transformer_manager.ApplyTransformers(graph, TransformerLevel::Level1, *logger_).IsOK());

  std::map<std::string, int> op_counts = CountOpsInGraph(graph);
  ASSERT_EQ(op_counts["Concat"], 0);
  ASSERT_EQ(op_counts["ConcatTraining"], 1);
}
TEST_F(GraphTransformationTests, MegatronMLPPartitionRank0) {
auto model_uri = MODEL_FOLDER "model_parallel/mlp_megatron_basic_test.onnx";
std::shared_ptr<Model> p_model;
@ -483,7 +504,7 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {
TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionCorrectnessTest) {
auto model_uri = MODEL_FOLDER "model_parallel/self_attention_megatron_basic_test.onnx";
const int total_rank = 2; // The test graph is too small to partition to 4, so use 2 instead here.
const int total_rank = 2; // The test graph is too small to partition to 4, so use 2 instead here.
std::vector<Graph*> graphs;
std::vector<std::shared_ptr<Model>> p_models(total_rank);
for (auto i = 0; i < total_rank; i++) {