diff --git a/onnxruntime/test/testdata/transform/concat_graph_gen.py b/onnxruntime/test/testdata/transform/concat_graph_gen.py new file mode 100644 index 0000000000..599eeec221 --- /dev/null +++ b/onnxruntime/test/testdata/transform/concat_graph_gen.py @@ -0,0 +1,55 @@ +import onnx +from onnx import helper +from onnx import TensorProto +import numpy as np + +def GenerateModel(model_name): + nodes = [ + helper.make_node("Gather", ["embed_weights","input_1"], ["gather_out"], "gather"), + + helper.make_node("Add", ["gather_out", "add_q_weight"], ["add_q_out"], "add_q"), + helper.make_node("Add", ["gather_out", "add_k_weight"], ["add_k_out"], "add_k"), + helper.make_node("Add", ["gather_out", "add_v_weight"], ["add_v_out"], "add_v"), + + helper.make_node("Concat", ["add_q_out", "add_k_out", "add_v_out"], + ["concat_out"], "concat", axis=0), + + helper.make_node("Add", ["add_qkv_weight", "concat_out"], ["add_out"], "add"), + helper.make_node("ReduceSum",["add_out"],["predictions"],"reduce_sum_1", axes=[0], keepdims=1), + ] + + embed_weights = np.random.uniform(-1,1,8000).tolist() + + add_q_weight = [-0.23681640625, -0.16552734375, 0.2191162109375, -0.1756591796875, + -0.03460693359375, -0.05316162109375, -0.336181640625, -0.253662109375] + + add_k_weight = [0.0246734619140625, 0.011993408203125, 0.0178375244140625, 0.00998687744140625, + 0.0255126953125, 0.076416015625, -0.040771484375, 0.0107879638671875] + + add_v_weight = [-0.005893707275390625, -0.00916290283203125, 0.04541015625, 0.0159454345703125, + -0.0029163360595703125, -0.03472900390625, 0.0535888671875, 0.0091094970703125] + + initializers = [ # initializers + helper.make_tensor('embed_weights', TensorProto.FLOAT, [1000, 8], embed_weights), + helper.make_tensor('add_q_weight', TensorProto.FLOAT, [8], add_q_weight), + helper.make_tensor('add_k_weight', TensorProto.FLOAT, [8], add_k_weight), + helper.make_tensor('add_v_weight', TensorProto.FLOAT, [8], add_v_weight), + helper.make_tensor('add_qkv_weight', 
TensorProto.FLOAT, [1], [1.0]), + ] + + graph = helper.make_graph( + nodes, + "ConcatThreeInputs", #name + [ # inputs + helper.make_tensor_value_info('input_1', TensorProto.INT64, ['batch', 'seq_len']) + ], + [ # outputs + helper.make_tensor_value_info('predictions', TensorProto.FLOAT, [1,1,8]), + ], + initializers) + + model = helper.make_model(graph) + onnx.save(model, model_name) + +GenerateModel('concat_trainable.onnx') + diff --git a/onnxruntime/test/testdata/transform/concat_trainable.onnx b/onnxruntime/test/testdata/transform/concat_trainable.onnx new file mode 100644 index 0000000000..1f63c242be Binary files /dev/null and b/onnxruntime/test/testdata/transform/concat_trainable.onnx differ diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 5f83592553..438024873d 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -9,7 +9,9 @@ #include "onnx/defs/attr_proto_util.h" #include "onnx/defs/tensor_proto_util.h" + #include "core/framework/tensorprotoutils.h" +#include "core/providers/common.h" #include "orttraining/core/framework/distributed_run_context.h" #include "orttraining/core/graph/gradient_builder_registry.h" #include "orttraining/core/graph/graph_augmenter.h" @@ -534,6 +536,49 @@ IMPLEMENT_GRADIENT_BUILDER(GetConcatGradient) { new_attributes)}; } +IMPLEMENT_GRADIENT_BUILDER(GetConcatTrainingGradient) { + auto attributes = SrcNodeAttributes(); + ORT_ENFORCE(utils::HasInt(attributes.at("axis"))); + auto axis = attributes.at("axis").i(); + + std::vector split_attribute(GetSrcNodeInputSize()); + std::vector outputs; + bool known_shapes = true; + for (int i = 0; i < GetSrcNodeInputSize(); ++i) { + std::vector data_shape; + if (GetShape(I(i), data_shape).IsOK()) { + int64_t rank = static_cast(data_shape.size()); + int64_t axis_index = HandleNegativeAxis(axis, rank); + if 
(data_shape[axis_index].has_dim_value()) { + split_attribute[i] = data_shape[axis_index].dim_value(); + } else { + known_shapes = false; + } + } else { + known_shapes = false; + } + + outputs.push_back(GI(i)); + } + + std::vector new_attributes; + new_attributes.push_back(MakeAttribute("axis", axis)); + if (known_shapes) { + new_attributes.push_back(MakeAttribute("split", split_attribute)); + return std::vector{ + NodeDef("Split", + {GO(0)}, + outputs, + new_attributes)}; + } else { + return std::vector{ + NodeDef(OpDef{"SplitTraining", kMSDomain, 1}, + {GO(0), O(1)}, + outputs, + new_attributes)}; + } +} + IMPLEMENT_GRADIENT_BUILDER(GetGatherNDGradient) { auto attributes = SrcNodeAttributes(); ORT_ENFORCE(attributes.at("batch_dims").has_i()); diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index 50ee1c27e9..87f007b90c 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -28,6 +28,7 @@ DECLARE_GRADIENT_BUILDER(GetReduceSumGradient) DECLARE_GRADIENT_BUILDER(GetReduceLogSumExpGradient) DECLARE_GRADIENT_BUILDER(GetPowGradient) DECLARE_GRADIENT_BUILDER(GetConcatGradient) +DECLARE_GRADIENT_BUILDER(GetConcatTrainingGradient) DECLARE_GRADIENT_BUILDER(GetReshapeGradient) DECLARE_GRADIENT_BUILDER(GetTransposeGradient) DECLARE_GRADIENT_BUILDER(GetPoolGradient) diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index 7eb39e97d0..5a4146e92c 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -60,6 +60,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("Mul", GetMulGradient); REGISTER_GRADIENT_BUILDER("Div", GetDivGradient); REGISTER_GRADIENT_BUILDER("Concat", GetConcatGradient); + 
REGISTER_GRADIENT_BUILDER("ConcatTraining", GetConcatTrainingGradient); REGISTER_GRADIENT_BUILDER("Reshape", GetReshapeGradient); REGISTER_GRADIENT_BUILDER("Transpose", GetTransposeGradient); REGISTER_GRADIENT_BUILDER("Gemm", GetGemmGradient); diff --git a/orttraining/orttraining/core/optimizer/concat_replacement.cc b/orttraining/orttraining/core/optimizer/concat_replacement.cc new file mode 100644 index 0000000000..37d302765c --- /dev/null +++ b/orttraining/orttraining/core/optimizer/concat_replacement.cc @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/core/optimizer/concat_replacement.h" + +#include "core/common/logging/logging.h" +#include "core/optimizer/rewrite_rule.h" +#include "core/optimizer/utils.h" +#include "core/graph/graph.h" +#include "core/graph/graph_utils.h" + +namespace onnxruntime { + +Status ConcatReplacement::Apply(Graph& graph, Node& concat_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { + const auto& concat_inputs = concat_node.MutableInputDefs(); + auto& concat_outputs = concat_node.MutableOutputDefs(); + + ONNX_NAMESPACE::TypeProto t; + t.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + t.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(concat_inputs.size()); + + NodeArg& ip_shape_op = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("per_input_length"), &t); + + concat_outputs.push_back(&ip_shape_op); + + Node& concat_training_node = graph.AddNode(graph.GenerateNodeName("ConcatTraining"), + "ConcatTraining", + "Concat with extra output", + concat_inputs, + concat_outputs, + &concat_node.GetAttributes(), + kMSDomain); + + // Assign provider to this new node. Provider should be same as the provider for old node. 
+ concat_training_node.SetExecutionProviderType(concat_node.GetExecutionProviderType()); + graph_utils::FinalizeNodeFusion(graph, concat_training_node, concat_node); + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + return Status::OK(); +} + +bool ConcatReplacement::SatisfyCondition(const Graph&, const Node&, const logging::Logger&) const { + return true; +} + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/concat_replacement.h b/orttraining/orttraining/core/optimizer/concat_replacement.h new file mode 100644 index 0000000000..faa8877c56 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/concat_replacement.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { + +/** +@Class ConcatReplacement + +Rewrite rule that replaces Concat with ConcatTraining, which has an additional output +used in building the gradient for the Concat node. + +The rule is only attempted on nodes with op type "Concat". 
+*/ +class ConcatReplacement : public RewriteRule { + public: + ConcatReplacement() noexcept : RewriteRule("ConcatReplacement") {} + + std::vector TargetOpTypes() const noexcept override { + return {"Concat"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index b49e144203..7dc99e7711 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -37,6 +37,7 @@ #include "core/session/inference_session.h" #include "orttraining/core/framework/distributed_run_context.h" #include "orttraining/core/optimizer/bias_dropout_fusion.h" +#include "orttraining/core/optimizer/concat_replacement.h" #include "orttraining/core/optimizer/insert_output_rewriter.h" #include "orttraining/core/optimizer/localized_recompute.h" #include "orttraining/core/optimizer/megatron_transformer.h" @@ -101,6 +102,10 @@ std::vector> GeneratePreTrainingTransformers( case TransformerLevel::Level2: { // Put ReshapeFusion as level-2 optimization after all level-1 graph rewriters are run. 
transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); + rule_transformer = + onnxruntime::make_unique(optimizer_utils::GenerateRuleBasedTransformerName(level), + compatible_eps); + rule_transformer->Register(onnxruntime::make_unique()); } break; case TransformerLevel::Level3: { diff --git a/orttraining/orttraining/test/gradient/gradient_checker.cc b/orttraining/orttraining/test/gradient/gradient_checker.cc index 464ff515a7..f12fcfc802 100644 --- a/orttraining/orttraining/test/gradient/gradient_checker.cc +++ b/orttraining/orttraining/test/gradient/gradient_checker.cc @@ -248,7 +248,17 @@ inline Status GradientChecker::InitOpTesterWithGraph( for (size_t data_index = 0; data_index < y_infos.size(); data_index++) { std::string name = "output" + std::to_string(data_index); - op_session.AddOutput(name.c_str(), y_infos[data_index].shape.GetDims(), (*y_datas)[data_index]); + const std::vector& data = (*y_datas)[data_index]; + + if (y_infos[data_index].data_type == DataTypeImpl::GetTensorType()) { + std::vector int64_data(data.size()); + std::transform(data.begin(), data.end(), int64_data.begin(), [](Y_T x) { return static_cast(x); }); + op_session.AddOutput(name.c_str(), + y_infos[data_index].shape.GetDims(), + int64_data); + } else { + op_session.AddOutput(name.c_str(), y_infos[data_index].shape.GetDims(), data); + } } // Currently only allows setting int attributes to zero. TODO: Expand this for (auto attr : attributes) { @@ -568,7 +578,7 @@ inline Status GradientChecker::ComputeGradientError( } // Compute gradient error. 
- return ComputeGradientErrorInternal(op_def, x_infos, y_infos, &x_datas, &y_datas, max_error, + return ComputeGradientErrorInternal(op_def, x_infos, y_infos, &x_datas, &y_datas, max_error, attributes, check_not_have_gradient, check_not_have_shape_inferencing); } diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 773b4e1a8b..4c5abcfe51 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -39,55 +39,52 @@ static bool IsErrorWithinTolerance(float error, float tolerance) { EXPECT_IS_TINIER_THAN(max_error, 1.5e-2f) static void RunReductionTests(const OpDef& op_def) { - TestDataVector test_data( - // Input X - { - {{4, 3, 2}}, - {{4, 3, 2}}, - {{4, 3, 2}}, - {{4, 3, 2}}, - {{4, 3, 2}}, - {{4, 3, 2}}, - {{4, 3, 2}}, - {{4, 3, 2}}, - }, - // Input Y - { - {{1, 1, 1}}, - {{}}, - {{1, 3, 1}}, - {{2}}, - {{4, 1, 2}}, - {{4, 3}}, - {{4, 1, 2}}, - {{4}} - }, - // Attributes - { - // default - {}, - // axes = [0, 1, 2], keepdims = 0 - {MakeAttribute("axes", std::vector{0, 1, 2}), - MakeAttribute("keepdims", int64_t(0))}, - // axes = [0, 2], keepdims = 1 - {MakeAttribute("axes", std::vector{0, 2})}, - // axes = [0, 1], keepdims = 0 - {MakeAttribute("axes", std::vector{0, 1}), - MakeAttribute("keepdims", int64_t(0))}, - // axes = [1], keepdims = 1 - {MakeAttribute("axes", std::vector{1}), - MakeAttribute("keepdims", int64_t(1))}, - // axes = [2], keepdims = 0 - {MakeAttribute("axes", std::vector{2}), - MakeAttribute("keepdims", int64_t(0))}, - // axes = [-2], keepdims = 1 - {MakeAttribute("axes", std::vector{-2}), - MakeAttribute("keepdims", int64_t(1))}, - // axes = [-2, -1], keepdims = 0 - {MakeAttribute("axes", std::vector{-2, -1}), - MakeAttribute("keepdims", int64_t(0))} - }); + // Input X + { + {{4, 3, 2}}, + {{4, 3, 2}}, + {{4, 3, 2}}, + {{4, 3, 2}}, + {{4, 3, 2}}, + {{4, 3, 2}}, + {{4, 3, 2}}, + {{4, 
3, 2}}, + }, + // Input Y + { + {{1, 1, 1}}, + {{}}, + {{1, 3, 1}}, + {{2}}, + {{4, 1, 2}}, + {{4, 3}}, + {{4, 1, 2}}, + {{4}}}, + // Attributes + { + // default + {}, + // axes = [0, 1, 2], keepdims = 0 + {MakeAttribute("axes", std::vector{0, 1, 2}), + MakeAttribute("keepdims", int64_t(0))}, + // axes = [0, 2], keepdims = 1 + {MakeAttribute("axes", std::vector{0, 2})}, + // axes = [0, 1], keepdims = 0 + {MakeAttribute("axes", std::vector{0, 1}), + MakeAttribute("keepdims", int64_t(0))}, + // axes = [1], keepdims = 1 + {MakeAttribute("axes", std::vector{1}), + MakeAttribute("keepdims", int64_t(1))}, + // axes = [2], keepdims = 0 + {MakeAttribute("axes", std::vector{2}), + MakeAttribute("keepdims", int64_t(0))}, + // axes = [-2], keepdims = 1 + {MakeAttribute("axes", std::vector{-2}), + MakeAttribute("keepdims", int64_t(1))}, + // axes = [-2, -1], keepdims = 0 + {MakeAttribute("axes", std::vector{-2, -1}), + MakeAttribute("keepdims", int64_t(0))}}); GradientChecker gradient_checker; @@ -670,17 +667,25 @@ TEST(GradientCheckerTest, ConvGrad) { } } -TEST(GradientCheckerTest, ConcatGrad) { +static void TestConcatOpGrad(const std::string& op_type, + const std::string& domain = kOnnxDomain, + int opset_version = 9, + bool check_not_have_shape_inferencing = false) { float max_error; GradientChecker gradient_checker; - OpDef op_def{"Concat"}; + const bool extra_input = op_type == "ConcatTraining"; + OpDef op_def{op_type, domain, opset_version}; //concat_1d { TensorShape x_shape({2}); TensorShape y_shape({6}); - gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape}, {y_shape}, &max_error, - {MakeAttribute("axis", int64_t(0))}); + std::vector output = {y_shape}; + if (extra_input) output.push_back(TensorInfo({3}, false, nullptr, DataTypeImpl::GetTensorType())); + gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape}, + output, &max_error, + {MakeAttribute("axis", int64_t(0))}, true, + check_not_have_shape_inferencing); 
EXPECT_IS_TINY(max_error); } @@ -688,8 +693,12 @@ TEST(GradientCheckerTest, ConcatGrad) { { TensorShape x_shape({2, 2}); TensorShape y_shape({2, 6}); - gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape}, {y_shape}, &max_error, - {MakeAttribute("axis", int64_t(1))}); + std::vector output = {y_shape}; + if (extra_input) output.push_back(TensorInfo({3}, false, nullptr, DataTypeImpl::GetTensorType())); + gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape}, + output, &max_error, + {MakeAttribute("axis", int64_t(1))}, true, + check_not_have_shape_inferencing); EXPECT_IS_TINY(max_error); } @@ -697,8 +706,12 @@ TEST(GradientCheckerTest, ConcatGrad) { { TensorShape x_shape({1, 2, 3}); TensorShape y_shape({1, 2, 9}); - gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape}, {y_shape}, &max_error, - {MakeAttribute("axis", int64_t(2))}); + std::vector output = {y_shape}; + if (extra_input) output.push_back(TensorInfo({3}, false, nullptr, DataTypeImpl::GetTensorType())); + gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape}, + output, &max_error, + {MakeAttribute("axis", int64_t(2))}, true, + check_not_have_shape_inferencing); EXPECT_IS_TINY(max_error); } @@ -707,8 +720,12 @@ TEST(GradientCheckerTest, ConcatGrad) { TensorShape x1_shape({2, 2}); TensorShape x2_shape({2, 4}); TensorShape y_shape({2, 6}); - gradient_checker.ComputeGradientError(op_def, {x1_shape, x2_shape}, {y_shape}, &max_error, - {MakeAttribute("axis", int64_t(1))}); + std::vector output = {y_shape}; + if (extra_input) output.push_back(TensorInfo({2}, false, nullptr, DataTypeImpl::GetTensorType())); + gradient_checker.ComputeGradientError(op_def, {x1_shape, x2_shape}, + output, &max_error, + {MakeAttribute("axis", int64_t(1))}, true, + check_not_have_shape_inferencing); EXPECT_IS_TINY(max_error); } @@ -717,12 +734,24 @@ TEST(GradientCheckerTest, ConcatGrad) { TensorShape x1_shape({2, 2}); TensorShape x2_shape({2, 4}); 
TensorShape y_shape({2, 6}); - gradient_checker.ComputeGradientError(op_def, {x1_shape, x2_shape}, {y_shape}, &max_error, - {MakeAttribute("axis", int64_t(-1))}); + std::vector output = {y_shape}; + if (extra_input) output.push_back(TensorInfo({2}, false, nullptr, DataTypeImpl::GetTensorType())); + gradient_checker.ComputeGradientError(op_def, {x1_shape, x2_shape}, + output, &max_error, + {MakeAttribute("axis", int64_t(-1))}, true, + check_not_have_shape_inferencing); EXPECT_IS_TINY(max_error); } } +TEST(GradientCheckerTest, ConcatGrad) { + TestConcatOpGrad("Concat"); +} + +TEST(GradientCheckerTest, ConcatTrainingGrad) { /*also test w/o shape inferencing */ + TestConcatOpGrad("ConcatTraining", kMSDomain, 1, true); +} + TEST(GradientCheckerTest, AveragePoolGrad) { float max_error; GradientChecker gradient_checker; @@ -1909,4 +1938,3 @@ TEST(GradientCheckerTest, ExpandGrad) { } // namespace onnxruntime #endif // NDEBUG - diff --git a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc index 2e7ae8ffe9..bea6fc97e6 100644 --- a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc +++ b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc @@ -5,6 +5,7 @@ #include "gtest/gtest.h" #include "orttraining/core/optimizer/gist_encode_decode.h" #include "test/providers/provider_test_utils.h" +#include "test/framework/test_utils.h" #include "core/common/path_utils.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "core/session/environment.h" @@ -28,6 +29,7 @@ namespace test { namespace { constexpr auto ORIGINAL_MODEL_PATH = ORT_TSTR("testdata/test_training_model.onnx"); constexpr auto BACKWARD_MODEL_PATH = ORT_TSTR("testdata/temp_backward_model.onnx"); +constexpr auto CONCAT_MODEL_PATH = ORT_TSTR("testdata/transform/concat_trainable.onnx"); std::unordered_set GetModelOutputNames(const InferenceSession& session) { const auto outputs_result = 
session.GetModelOutputs(); @@ -167,6 +169,27 @@ TEST(GradientGraphBuilderTest, BuildGradientGraphTest) { } } +TEST(GradientGraphBuilderTest, BuildConcatGradientGraphTest) { + const auto config = MakeBasicTrainingConfig(); + PathString backprop_model_file; + ASSERT_STATUS_OK(BuildBackPropGraph(CONCAT_MODEL_PATH, config, backprop_model_file)); + + std::shared_ptr pModel; + ASSERT_STATUS_OK(Model::Load(backprop_model_file, pModel, nullptr, DefaultLoggingManager().DefaultLogger())); + + Graph& graph = pModel->MainGraph(); + EXPECT_FALSE(graph.GraphResolveNeeded()); + EXPECT_TRUE(graph.NumberOfNodes() > 0); + EXPECT_TRUE(graph.MaxNodeIndex() > 0); + + std::map op_to_count = CountOpsInGraph(graph); + + ASSERT_EQ(op_to_count["Concat"], 0); + ASSERT_EQ(op_to_count["Split"], 0); + ASSERT_EQ(op_to_count["ConcatTraining"], 1); + ASSERT_EQ(op_to_count["SplitTraining"], 1); +} + TEST(GradientGraphBuilderTest, TrainingSession_Basic) { const auto config = MakeBasicTrainingConfig(); PathString backprop_model_file; diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index 873c7ade7e..484a2f2468 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -14,6 +14,7 @@ #include "orttraining/core/optimizer/gist_encode_decode.h" #include "orttraining/core/optimizer/nonzero_shape_setter.h" #include "orttraining/core/optimizer/megatron_transformer.h" +#include "orttraining/core/optimizer/concat_replacement.h" #include "test/optimizer/graph_transform_test_fixture.h" #include "test/util/include/default_providers.h" #include "test/util/include/asserts.h" @@ -108,8 +109,28 @@ TEST_F(GraphTransformationTests, NonZeroShapeSetter) { ASSERT_TRUE(nonzero_shape->dim(1).dim_param() == "nonzero_nonzero_count"); } -// MegatronF/G is defined only for training, and in msdomain. 
+// MegatronF/G and ConcatTraining are defined only for training, and in msdomain. #ifndef DISABLE_CONTRIB_OPS +TEST_F(GraphTransformationTests, ConcatReplacement) { + auto model_uri = MODEL_FOLDER "concat_trainable.onnx"; + std::shared_ptr p_model; + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK()); + Graph& graph = p_model->MainGraph(); + + auto rule_transformer_L1 = onnxruntime::make_unique("ConcatReplacement"); + rule_transformer_L1->Register(onnxruntime::make_unique()); + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; + graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); + + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + + ASSERT_EQ(op_to_count["Concat"], 0); + ASSERT_EQ(op_to_count["ConcatTraining"], 1); +} + TEST_F(GraphTransformationTests, MegatronMLPPartitionRank0) { auto model_uri = MODEL_FOLDER "model_parallel/mlp_megatron_basic_test.onnx"; std::shared_ptr p_model; @@ -483,7 +504,7 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) { TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionCorrectnessTest) { auto model_uri = MODEL_FOLDER "model_parallel/self_attention_megatron_basic_test.onnx"; - const int total_rank = 2; // The test graph is too small to partition to 4, so use 2 instead here. + const int total_rank = 2; // The test graph is too small to partition to 4, so use 2 instead here. std::vector graphs; std::vector> p_models(total_rank); for (auto i = 0; i < total_rank; i++) {