mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Shape independent gradient builder for Concat (#4675)
* Add gradient for ConcatTraining * Graph rewriter changes for concat * Add generated onnx graph, minor fixes * Revert unintended change * Fix for MaxPoolGradTest * Fix UT * Review comments, windows tests * Review comments
This commit is contained in:
parent
8507bc1f48
commit
fc2f36c608
12 changed files with 331 additions and 65 deletions
55
onnxruntime/test/testdata/transform/concat_graph_gen.py
vendored
Normal file
55
onnxruntime/test/testdata/transform/concat_graph_gen.py
vendored
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import TensorProto
|
||||
import numpy as np
|
||||
|
||||
def GenerateModel(model_name):
    """Build the Concat test graph and save it as an ONNX model at `model_name`.

    Topology: Gather(embed_weights, input_1) feeds three Add nodes (q, k, v);
    their outputs are concatenated on axis 0, passed through one more Add and
    finally reduced with ReduceSum over axis 0 into `predictions`.

    The embedding initializer is drawn from np.random.uniform, so the saved
    file is not bit-identical between runs.
    """
    # Embedding lookup shared by the three branches.
    gather_node = helper.make_node("Gather", ["embed_weights","input_1"], ["gather_out"], "gather")

    # One Add per branch; each adds its own weight to the gathered rows.
    branch_nodes = [
        helper.make_node("Add", ["gather_out", "add_q_weight"], ["add_q_out"], "add_q"),
        helper.make_node("Add", ["gather_out", "add_k_weight"], ["add_k_out"], "add_k"),
        helper.make_node("Add", ["gather_out", "add_v_weight"], ["add_v_out"], "add_v"),
    ]

    # The Concat under test joins the three branch outputs along axis 0.
    concat_node = helper.make_node("Concat", ["add_q_out", "add_k_out", "add_v_out"],
                                   ["concat_out"], "concat", axis=0)

    tail_nodes = [
        helper.make_node("Add", ["add_qkv_weight", "concat_out"], ["add_out"], "add"),
        helper.make_node("ReduceSum",["add_out"],["predictions"],"reduce_sum_1", axes=[0], keepdims=1),
    ]

    nodes = [gather_node] + branch_nodes + [concat_node] + tail_nodes

    # 1000 x 8 embedding table, flattened (8000 values in [-1, 1)).
    embed_weights = np.random.uniform(-1,1,8000).tolist()

    # Fixed per-branch weights (8 values each).
    add_q_weight = [
        -0.23681640625, -0.16552734375, 0.2191162109375, -0.1756591796875,
        -0.03460693359375, -0.05316162109375, -0.336181640625, -0.253662109375,
    ]
    add_k_weight = [
        0.0246734619140625, 0.011993408203125, 0.0178375244140625, 0.00998687744140625,
        0.0255126953125, 0.076416015625, -0.040771484375, 0.0107879638671875,
    ]
    add_v_weight = [
        -0.005893707275390625, -0.00916290283203125, 0.04541015625, 0.0159454345703125,
        -0.0029163360595703125, -0.03472900390625, 0.0535888671875, 0.0091094970703125,
    ]

    initializers = [
        helper.make_tensor('embed_weights', TensorProto.FLOAT, [1000, 8], embed_weights),
        helper.make_tensor('add_q_weight', TensorProto.FLOAT, [8], add_q_weight),
        helper.make_tensor('add_k_weight', TensorProto.FLOAT, [8], add_k_weight),
        helper.make_tensor('add_v_weight', TensorProto.FLOAT, [8], add_v_weight),
        helper.make_tensor('add_qkv_weight', TensorProto.FLOAT, [1], [1.0]),
    ]

    # The only graph input has symbolic dimensions ('batch', 'seq_len').
    graph_inputs = [
        helper.make_tensor_value_info('input_1', TensorProto.INT64, ['batch', 'seq_len'])
    ]
    graph_outputs = [
        helper.make_tensor_value_info('predictions', TensorProto.FLOAT, [1,1,8]),
    ]

    graph = helper.make_graph(nodes, "ConcatThreeInputs", graph_inputs, graph_outputs, initializers)

    model = helper.make_model(graph)
    onnx.save(model, model_name)
|
||||
|
||||
# Running this script writes the model file that the tests load as concat_trainable.onnx.
GenerateModel('concat_trainable.onnx')
|
||||
|
||||
BIN
onnxruntime/test/testdata/transform/concat_trainable.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/concat_trainable.onnx
vendored
Normal file
Binary file not shown.
|
|
@ -9,7 +9,9 @@
|
|||
|
||||
#include "onnx/defs/attr_proto_util.h"
|
||||
#include "onnx/defs/tensor_proto_util.h"
|
||||
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "orttraining/core/framework/distributed_run_context.h"
|
||||
#include "orttraining/core/graph/gradient_builder_registry.h"
|
||||
#include "orttraining/core/graph/graph_augmenter.h"
|
||||
|
|
@ -534,6 +536,49 @@ IMPLEMENT_GRADIENT_BUILDER(GetConcatGradient) {
|
|||
new_attributes)};
|
||||
}
|
||||
|
||||
// Gradient of ConcatTraining: the output gradient is split back into one
// gradient per input along the concat axis.
//  - If every input has a concrete dimension on the concat axis, emit a plain
//    "Split" node with an explicit "split" attribute.
//  - Otherwise emit "SplitTraining" (MS domain), which takes the forward
//    node's second output, O(1), as an extra input instead of the attribute.
IMPLEMENT_GRADIENT_BUILDER(GetConcatTrainingGradient) {
  const auto& attributes = SrcNodeAttributes();
  const auto& axis_attribute = attributes.at("axis");
  ORT_ENFORCE(utils::HasInt(axis_attribute));
  auto axis = axis_attribute.i();

  const auto input_count = GetSrcNodeInputSize();
  std::vector<int64_t> split_sizes(input_count);
  std::vector<ArgDef> input_grads;
  input_grads.reserve(input_count);

  // Stays true only while every input exposes a concrete size on the concat axis.
  bool all_sizes_known = true;
  for (int i = 0; i < input_count; ++i) {
    bool size_known = false;

    std::vector<Dimension> input_shape;
    if (GetShape(I(i), input_shape).IsOK()) {
      const int64_t rank = static_cast<int64_t>(input_shape.size());
      const int64_t axis_index = HandleNegativeAxis(axis, rank);
      const auto& axis_dim = input_shape[axis_index];
      if (axis_dim.has_dim_value()) {
        split_sizes[i] = axis_dim.dim_value();
        size_known = true;
      }
    }

    if (!size_known) {
      all_sizes_known = false;
    }

    input_grads.push_back(GI(i));
  }

  std::vector<AttributeProto> new_attributes;
  new_attributes.push_back(MakeAttribute("axis", axis));

  if (!all_sizes_known) {
    // Shape-independent path.
    return std::vector<NodeDef>{
        NodeDef(OpDef{"SplitTraining", kMSDomain, 1},
                {GO(0), O(1)},
                input_grads,
                new_attributes)};
  }

  // Static path: every split size was read from the input shapes.
  new_attributes.push_back(MakeAttribute("split", split_sizes));
  return std::vector<NodeDef>{
      NodeDef("Split",
              {GO(0)},
              input_grads,
              new_attributes)};
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetGatherNDGradient) {
|
||||
auto attributes = SrcNodeAttributes();
|
||||
ORT_ENFORCE(attributes.at("batch_dims").has_i());
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ DECLARE_GRADIENT_BUILDER(GetReduceSumGradient)
|
|||
DECLARE_GRADIENT_BUILDER(GetReduceLogSumExpGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetPowGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetConcatGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetConcatTrainingGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetReshapeGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetTransposeGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetPoolGradient)
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
|
|||
REGISTER_GRADIENT_BUILDER("Mul", GetMulGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Div", GetDivGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Concat", GetConcatGradient);
|
||||
REGISTER_GRADIENT_BUILDER("ConcatTraining", GetConcatTrainingGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Reshape", GetReshapeGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Transpose", GetTransposeGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Gemm", GetGemmGradient);
|
||||
|
|
|
|||
45
orttraining/orttraining/core/optimizer/concat_replacement.cc
Normal file
45
orttraining/orttraining/core/optimizer/concat_replacement.cc
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/core/optimizer/concat_replacement.h"
|
||||
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/optimizer/rewrite_rule.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/graph/graph.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Replaces `concat_node` with a ConcatTraining node (MS domain) that has the
// same inputs and attributes plus one extra output named "per_input_length".
// Always reports RewriteRuleEffect::kRemovedCurrentNode and returns OK.
Status ConcatReplacement::Apply(Graph& graph, Node& concat_node, RewriteRuleEffect& rule_effect,
                                const logging::Logger&) const {
  const auto& input_defs = concat_node.MutableInputDefs();
  auto& output_defs = concat_node.MutableOutputDefs();

  // Type of the extra output: a 1-D INT64 tensor with one element per Concat input.
  ONNX_NAMESPACE::TypeProto per_input_length_type;
  auto* tensor_type = per_input_length_type.mutable_tensor_type();
  tensor_type->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
  tensor_type->mutable_shape()->add_dim()->set_dim_value(input_defs.size());

  NodeArg& per_input_length_arg =
      graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("per_input_length"), &per_input_length_type);

  // Appended to the original node's output list, which the new node is created from below.
  output_defs.push_back(&per_input_length_arg);

  Node& concat_training_node = graph.AddNode(graph.GenerateNodeName("ConcatTraining"),
                                             "ConcatTraining",
                                             "Concat with extra output",
                                             input_defs,
                                             output_defs,
                                             &concat_node.GetAttributes(),
                                             kMSDomain);

  // The replacement runs on the same execution provider as the node it replaces.
  concat_training_node.SetExecutionProviderType(concat_node.GetExecutionProviderType());

  // NOTE(review): presumably rewires edges to the new node and removes the old one; confirm in graph_utils.
  graph_utils::FinalizeNodeFusion(graph, concat_training_node, concat_node);

  rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
  return Status::OK();
}
|
||||
|
||||
// No extra precondition is checked: every node offered to this rule is accepted.
bool ConcatReplacement::SatisfyCondition(const Graph&, const Node&, const logging::Logger&) const {
  constexpr bool kAlwaysApplicable = true;
  return kAlwaysApplicable;
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
32
orttraining/orttraining/core/optimizer/concat_replacement.h
Normal file
32
orttraining/orttraining/core/optimizer/concat_replacement.h
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/rewrite_rule.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
@Class ConcatReplacement
|
||||
|
||||
Rewrite rule that replaces Concat with ConcatTraining, that has an additional output
|
||||
used in building the gradient for Concat node.
|
||||
|
||||
It is attempted to be triggered only on nodes with op type "Concat".
|
||||
*/
|
||||
class ConcatReplacement : public RewriteRule {
|
||||
public:
|
||||
ConcatReplacement() noexcept : RewriteRule("ConcatReplacement") {}
|
||||
|
||||
std::vector<std::string> TargetOpTypes() const noexcept override {
|
||||
return {"Concat"};
|
||||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -37,6 +37,7 @@
|
|||
#include "core/session/inference_session.h"
|
||||
#include "orttraining/core/framework/distributed_run_context.h"
|
||||
#include "orttraining/core/optimizer/bias_dropout_fusion.h"
|
||||
#include "orttraining/core/optimizer/concat_replacement.h"
|
||||
#include "orttraining/core/optimizer/insert_output_rewriter.h"
|
||||
#include "orttraining/core/optimizer/localized_recompute.h"
|
||||
#include "orttraining/core/optimizer/megatron_transformer.h"
|
||||
|
|
@ -101,6 +102,10 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
case TransformerLevel::Level2: {
|
||||
// Put ReshapeFusion as level-2 optimization after all level-1 graph rewriters are run.
|
||||
transformers.emplace_back(onnxruntime::make_unique<ReshapeFusion>(compatible_eps));
|
||||
rule_transformer =
|
||||
onnxruntime::make_unique<RuleBasedGraphTransformer>(optimizer_utils::GenerateRuleBasedTransformerName(level),
|
||||
compatible_eps);
|
||||
rule_transformer->Register(onnxruntime::make_unique<ConcatReplacement>());
|
||||
} break;
|
||||
|
||||
case TransformerLevel::Level3: {
|
||||
|
|
|
|||
|
|
@ -248,7 +248,17 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::InitOpTesterWithGraph(
|
|||
|
||||
for (size_t data_index = 0; data_index < y_infos.size(); data_index++) {
|
||||
std::string name = "output" + std::to_string(data_index);
|
||||
op_session.AddOutput<Y_T>(name.c_str(), y_infos[data_index].shape.GetDims(), (*y_datas)[data_index]);
|
||||
const std::vector<Y_T>& data = (*y_datas)[data_index];
|
||||
|
||||
if (y_infos[data_index].data_type == DataTypeImpl::GetTensorType<int64_t>()) {
|
||||
std::vector<int64_t> int64_data(data.size());
|
||||
std::transform(data.begin(), data.end(), int64_data.begin(), [](Y_T x) { return static_cast<int64_t>(x); });
|
||||
op_session.AddOutput<int64_t>(name.c_str(),
|
||||
y_infos[data_index].shape.GetDims(),
|
||||
int64_data);
|
||||
} else {
|
||||
op_session.AddOutput<Y_T>(name.c_str(), y_infos[data_index].shape.GetDims(), data);
|
||||
}
|
||||
}
|
||||
// Currently only allows setting int attributes to zero. TODO: Expand this
|
||||
for (auto attr : attributes) {
|
||||
|
|
@ -568,7 +578,7 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::ComputeGradientError(
|
|||
}
|
||||
|
||||
// Compute gradient error.
|
||||
return ComputeGradientErrorInternal(op_def, x_infos, y_infos, &x_datas, &y_datas, max_error,
|
||||
return ComputeGradientErrorInternal(op_def, x_infos, y_infos, &x_datas, &y_datas, max_error,
|
||||
attributes, check_not_have_gradient, check_not_have_shape_inferencing);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -39,55 +39,52 @@ static bool IsErrorWithinTolerance(float error, float tolerance) {
|
|||
EXPECT_IS_TINIER_THAN(max_error, 1.5e-2f)
|
||||
|
||||
static void RunReductionTests(const OpDef& op_def) {
|
||||
|
||||
TestDataVector test_data(
|
||||
// Input X
|
||||
{
|
||||
{{4, 3, 2}},
|
||||
{{4, 3, 2}},
|
||||
{{4, 3, 2}},
|
||||
{{4, 3, 2}},
|
||||
{{4, 3, 2}},
|
||||
{{4, 3, 2}},
|
||||
{{4, 3, 2}},
|
||||
{{4, 3, 2}},
|
||||
},
|
||||
// Input Y
|
||||
{
|
||||
{{1, 1, 1}},
|
||||
{{}},
|
||||
{{1, 3, 1}},
|
||||
{{2}},
|
||||
{{4, 1, 2}},
|
||||
{{4, 3}},
|
||||
{{4, 1, 2}},
|
||||
{{4}}
|
||||
},
|
||||
// Attributes
|
||||
{
|
||||
// default
|
||||
{},
|
||||
// axes = [0, 1, 2], keepdims = 0
|
||||
{MakeAttribute("axes", std::vector<int64_t>{0, 1, 2}),
|
||||
MakeAttribute("keepdims", int64_t(0))},
|
||||
// axes = [0, 2], keepdims = 1
|
||||
{MakeAttribute("axes", std::vector<int64_t>{0, 2})},
|
||||
// axes = [0, 1], keepdims = 0
|
||||
{MakeAttribute("axes", std::vector<int64_t>{0, 1}),
|
||||
MakeAttribute("keepdims", int64_t(0))},
|
||||
// axes = [1], keepdims = 1
|
||||
{MakeAttribute("axes", std::vector<int64_t>{1}),
|
||||
MakeAttribute("keepdims", int64_t(1))},
|
||||
// axes = [2], keepdims = 0
|
||||
{MakeAttribute("axes", std::vector<int64_t>{2}),
|
||||
MakeAttribute("keepdims", int64_t(0))},
|
||||
// axes = [-2], keepdims = 1
|
||||
{MakeAttribute("axes", std::vector<int64_t>{-2}),
|
||||
MakeAttribute("keepdims", int64_t(1))},
|
||||
// axes = [-2, -1], keepdims = 0
|
||||
{MakeAttribute("axes", std::vector<int64_t>{-2, -1}),
|
||||
MakeAttribute("keepdims", int64_t(0))}
|
||||
});
|
||||
// Input X
|
||||
{
|
||||
{{4, 3, 2}},
|
||||
{{4, 3, 2}},
|
||||
{{4, 3, 2}},
|
||||
{{4, 3, 2}},
|
||||
{{4, 3, 2}},
|
||||
{{4, 3, 2}},
|
||||
{{4, 3, 2}},
|
||||
{{4, 3, 2}},
|
||||
},
|
||||
// Input Y
|
||||
{
|
||||
{{1, 1, 1}},
|
||||
{{}},
|
||||
{{1, 3, 1}},
|
||||
{{2}},
|
||||
{{4, 1, 2}},
|
||||
{{4, 3}},
|
||||
{{4, 1, 2}},
|
||||
{{4}}},
|
||||
// Attributes
|
||||
{
|
||||
// default
|
||||
{},
|
||||
// axes = [0, 1, 2], keepdims = 0
|
||||
{MakeAttribute("axes", std::vector<int64_t>{0, 1, 2}),
|
||||
MakeAttribute("keepdims", int64_t(0))},
|
||||
// axes = [0, 2], keepdims = 1
|
||||
{MakeAttribute("axes", std::vector<int64_t>{0, 2})},
|
||||
// axes = [0, 1], keepdims = 0
|
||||
{MakeAttribute("axes", std::vector<int64_t>{0, 1}),
|
||||
MakeAttribute("keepdims", int64_t(0))},
|
||||
// axes = [1], keepdims = 1
|
||||
{MakeAttribute("axes", std::vector<int64_t>{1}),
|
||||
MakeAttribute("keepdims", int64_t(1))},
|
||||
// axes = [2], keepdims = 0
|
||||
{MakeAttribute("axes", std::vector<int64_t>{2}),
|
||||
MakeAttribute("keepdims", int64_t(0))},
|
||||
// axes = [-2], keepdims = 1
|
||||
{MakeAttribute("axes", std::vector<int64_t>{-2}),
|
||||
MakeAttribute("keepdims", int64_t(1))},
|
||||
// axes = [-2, -1], keepdims = 0
|
||||
{MakeAttribute("axes", std::vector<int64_t>{-2, -1}),
|
||||
MakeAttribute("keepdims", int64_t(0))}});
|
||||
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
|
||||
|
|
@ -670,17 +667,25 @@ TEST(GradientCheckerTest, ConvGrad) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, ConcatGrad) {
|
||||
static void TestConcatOpGrad(const std::string& op_type,
|
||||
const std::string& domain = kOnnxDomain,
|
||||
int opset_version = 9,
|
||||
bool check_not_have_shape_inferencing = false) {
|
||||
float max_error;
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
OpDef op_def{"Concat"};
|
||||
const bool extra_input = op_type == "ConcatTraining";
|
||||
OpDef op_def{op_type, domain, opset_version};
|
||||
|
||||
//concat_1d
|
||||
{
|
||||
TensorShape x_shape({2});
|
||||
TensorShape y_shape({6});
|
||||
gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape}, {y_shape}, &max_error,
|
||||
{MakeAttribute("axis", int64_t(0))});
|
||||
std::vector<TensorInfo> output = {y_shape};
|
||||
if (extra_input) output.push_back(TensorInfo({3}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>()));
|
||||
gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape},
|
||||
output, &max_error,
|
||||
{MakeAttribute("axis", int64_t(0))}, true,
|
||||
check_not_have_shape_inferencing);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
|
|
@ -688,8 +693,12 @@ TEST(GradientCheckerTest, ConcatGrad) {
|
|||
{
|
||||
TensorShape x_shape({2, 2});
|
||||
TensorShape y_shape({2, 6});
|
||||
gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape}, {y_shape}, &max_error,
|
||||
{MakeAttribute("axis", int64_t(1))});
|
||||
std::vector<TensorInfo> output = {y_shape};
|
||||
if (extra_input) output.push_back(TensorInfo({3}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>()));
|
||||
gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape},
|
||||
output, &max_error,
|
||||
{MakeAttribute("axis", int64_t(1))}, true,
|
||||
check_not_have_shape_inferencing);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
|
|
@ -697,8 +706,12 @@ TEST(GradientCheckerTest, ConcatGrad) {
|
|||
{
|
||||
TensorShape x_shape({1, 2, 3});
|
||||
TensorShape y_shape({1, 2, 9});
|
||||
gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape}, {y_shape}, &max_error,
|
||||
{MakeAttribute("axis", int64_t(2))});
|
||||
std::vector<TensorInfo> output = {y_shape};
|
||||
if (extra_input) output.push_back(TensorInfo({3}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>()));
|
||||
gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape},
|
||||
output, &max_error,
|
||||
{MakeAttribute("axis", int64_t(2))}, true,
|
||||
check_not_have_shape_inferencing);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
|
|
@ -707,8 +720,12 @@ TEST(GradientCheckerTest, ConcatGrad) {
|
|||
TensorShape x1_shape({2, 2});
|
||||
TensorShape x2_shape({2, 4});
|
||||
TensorShape y_shape({2, 6});
|
||||
gradient_checker.ComputeGradientError(op_def, {x1_shape, x2_shape}, {y_shape}, &max_error,
|
||||
{MakeAttribute("axis", int64_t(1))});
|
||||
std::vector<TensorInfo> output = {y_shape};
|
||||
if (extra_input) output.push_back(TensorInfo({2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>()));
|
||||
gradient_checker.ComputeGradientError(op_def, {x1_shape, x2_shape},
|
||||
output, &max_error,
|
||||
{MakeAttribute("axis", int64_t(1))}, true,
|
||||
check_not_have_shape_inferencing);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
|
|
@ -717,12 +734,24 @@ TEST(GradientCheckerTest, ConcatGrad) {
|
|||
TensorShape x1_shape({2, 2});
|
||||
TensorShape x2_shape({2, 4});
|
||||
TensorShape y_shape({2, 6});
|
||||
gradient_checker.ComputeGradientError(op_def, {x1_shape, x2_shape}, {y_shape}, &max_error,
|
||||
{MakeAttribute("axis", int64_t(-1))});
|
||||
std::vector<TensorInfo> output = {y_shape};
|
||||
if (extra_input) output.push_back(TensorInfo({2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>()));
|
||||
gradient_checker.ComputeGradientError(op_def, {x1_shape, x2_shape},
|
||||
output, &max_error,
|
||||
{MakeAttribute("axis", int64_t(-1))}, true,
|
||||
check_not_have_shape_inferencing);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
}
|
||||
|
||||
// Gradient check for the standard Concat op, using the helper's default arguments.
TEST(GradientCheckerTest, ConcatGrad) {
  const std::string op_type = "Concat";
  TestConcatOpGrad(op_type);
}
|
||||
|
||||
// Gradient check for ConcatTraining (MS domain, opset 1); also tests without shape inferencing.
TEST(GradientCheckerTest, ConcatTrainingGrad) {
  constexpr int kOpsetVersion = 1;
  constexpr bool kCheckNotHaveShapeInferencing = true;
  TestConcatOpGrad("ConcatTraining", kMSDomain, kOpsetVersion, kCheckNotHaveShapeInferencing);
}
|
||||
|
||||
TEST(GradientCheckerTest, AveragePoolGrad) {
|
||||
float max_error;
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
|
|
@ -1909,4 +1938,3 @@ TEST(GradientCheckerTest, ExpandGrad) {
|
|||
} // namespace onnxruntime
|
||||
|
||||
#endif // NDEBUG
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
#include "gtest/gtest.h"
|
||||
#include "orttraining/core/optimizer/gist_encode_decode.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "test/framework/test_utils.h"
|
||||
#include "core/common/path_utils.h"
|
||||
#include "core/providers/cpu/cpu_execution_provider.h"
|
||||
#include "core/session/environment.h"
|
||||
|
|
@ -28,6 +29,7 @@ namespace test {
|
|||
namespace {
|
||||
constexpr auto ORIGINAL_MODEL_PATH = ORT_TSTR("testdata/test_training_model.onnx");
|
||||
constexpr auto BACKWARD_MODEL_PATH = ORT_TSTR("testdata/temp_backward_model.onnx");
|
||||
constexpr auto CONCAT_MODEL_PATH = ORT_TSTR("testdata/transform/concat_trainable.onnx");
|
||||
|
||||
std::unordered_set<std::string> GetModelOutputNames(const InferenceSession& session) {
|
||||
const auto outputs_result = session.GetModelOutputs();
|
||||
|
|
@ -167,6 +169,27 @@ TEST(GradientGraphBuilderTest, BuildGradientGraphTest) {
|
|||
}
|
||||
}
|
||||
|
||||
// Builds the backward graph for the Concat test model and checks which
// Concat/Split variants appear in the result.
TEST(GradientGraphBuilderTest, BuildConcatGradientGraphTest) {
  const auto training_config = MakeBasicTrainingConfig();
  PathString backprop_model_path;
  ASSERT_STATUS_OK(BuildBackPropGraph(CONCAT_MODEL_PATH, training_config, backprop_model_path));

  // Reload the generated backward model from disk.
  std::shared_ptr<Model> backprop_model;
  ASSERT_STATUS_OK(
      Model::Load(backprop_model_path, backprop_model, nullptr, DefaultLoggingManager().DefaultLogger()));

  Graph& backprop_graph = backprop_model->MainGraph();
  EXPECT_FALSE(backprop_graph.GraphResolveNeeded());
  EXPECT_TRUE(backprop_graph.NumberOfNodes() > 0);
  EXPECT_TRUE(backprop_graph.MaxNodeIndex() > 0);

  std::map<std::string, int> op_counts = CountOpsInGraph(backprop_graph);

  // Forward Concat must have been rewritten to ConcatTraining.
  ASSERT_EQ(op_counts["Concat"], 0);
  // Gradient is expected as SplitTraining, not Split (presumably because the
  // model's input shape is symbolic -- confirm against the model generator).
  ASSERT_EQ(op_counts["Split"], 0);
  ASSERT_EQ(op_counts["ConcatTraining"], 1);
  ASSERT_EQ(op_counts["SplitTraining"], 1);
}
|
||||
|
||||
TEST(GradientGraphBuilderTest, TrainingSession_Basic) {
|
||||
const auto config = MakeBasicTrainingConfig();
|
||||
PathString backprop_model_file;
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
#include "orttraining/core/optimizer/gist_encode_decode.h"
|
||||
#include "orttraining/core/optimizer/nonzero_shape_setter.h"
|
||||
#include "orttraining/core/optimizer/megatron_transformer.h"
|
||||
#include "orttraining/core/optimizer/concat_replacement.h"
|
||||
#include "test/optimizer/graph_transform_test_fixture.h"
|
||||
#include "test/util/include/default_providers.h"
|
||||
#include "test/util/include/asserts.h"
|
||||
|
|
@ -108,8 +109,28 @@ TEST_F(GraphTransformationTests, NonZeroShapeSetter) {
|
|||
ASSERT_TRUE(nonzero_shape->dim(1).dim_param() == "nonzero_nonzero_count");
|
||||
}
|
||||
|
||||
// MegatronF/G is defined only for training, and in msdomain.
|
||||
// MegatronF/G and ConcatTraining are defined only for training, and in msdomain.
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
// Applies the ConcatReplacement rewrite rule to the Concat test model and checks
// that the single Concat node was replaced by a ConcatTraining node.
TEST_F(GraphTransformationTests, ConcatReplacement) {
  auto model_uri = MODEL_FOLDER "concat_trainable.onnx";
  std::shared_ptr<Model> model;
  ASSERT_TRUE(Model::Load(model_uri, model, nullptr, *logger_).IsOK());
  Graph& graph = model->MainGraph();

  // Register the rule on its own in a level-1 rule-based transformer.
  auto rule_transformer = onnxruntime::make_unique<RuleBasedGraphTransformer>("ConcatReplacement");
  rule_transformer->Register(onnxruntime::make_unique<ConcatReplacement>());

  onnxruntime::GraphTransformerManager transformer_manager{1};
  transformer_manager.Register(std::move(rule_transformer), TransformerLevel::Level1);

  const auto status = transformer_manager.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
  ASSERT_TRUE(status.IsOK());

  std::map<std::string, int> op_counts = CountOpsInGraph(graph);

  ASSERT_EQ(op_counts["Concat"], 0);
  ASSERT_EQ(op_counts["ConcatTraining"], 1);
}
|
||||
|
||||
TEST_F(GraphTransformationTests, MegatronMLPPartitionRank0) {
|
||||
auto model_uri = MODEL_FOLDER "model_parallel/mlp_megatron_basic_test.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
|
|
@ -483,7 +504,7 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {
|
|||
|
||||
TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionCorrectnessTest) {
|
||||
auto model_uri = MODEL_FOLDER "model_parallel/self_attention_megatron_basic_test.onnx";
|
||||
const int total_rank = 2; // The test graph is too small to partition to 4, so use 2 instead here.
|
||||
const int total_rank = 2; // The test graph is too small to partition to 4, so use 2 instead here.
|
||||
std::vector<Graph*> graphs;
|
||||
std::vector<std::shared_ptr<Model>> p_models(total_rank);
|
||||
for (auto i = 0; i < total_rank; i++) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue