Add FlattenAndUnpad Op (#17845)

### Description
Add an op named `FlattenAndUnpad`.
This op implements functions:
1. Flatten the first two dims of input tensor.
2. Gather valid value from input tensor with index tensor,.


### Motivation and Context
The grad op of `PadAndUnflatten` was `GatherGrad` which is inefficient
in performance.
I implement this `FlattenAndUnpad` just to replace the `GatherGrad` as
grad of `PadAndUnflatten`.
With this op, we also can simplify the "Reshape + ShrunkenGather"
pattern to `PadAndUnflatten` in padding elimination optimizer, which
will also improve performance.
This commit is contained in:
guyang3532 2023-11-09 09:52:48 +08:00 committed by GitHub
parent 885bf3561d
commit 4dc63692f8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 448 additions and 118 deletions

View file

@ -791,13 +791,16 @@ IMPLEMENT_GRADIENT_BUILDER(GetGatherGradient) {
IMPLEMENT_GRADIENT_BUILDER(GetPadAndUnflattenGradient) {
return std::vector<NodeDef>{
NodeDef(OpDef("Reshape"),
{GO(0), O(1)},
{IA("GO_reshaped")}),
NodeDef(OpDef{"Gather", kOnnxDomain, 1},
{IA("GO_reshaped"), I(1)},
{GI(0)},
SrcNodeAttributes())};
NodeDef(OpDef{"FlattenAndUnpad", kMSDomain, 1},
{GO(0), I(1)},
{GI(0), IA("Unflatten_dims")})};
}
IMPLEMENT_GRADIENT_BUILDER(GetFlattenAndUnpadGradient) {
return std::vector<NodeDef>{
NodeDef(OpDef{"PadAndUnflatten", kMSDomain, 1},
{GO(0), I(1), O(1)},
{GI(0)})};
}
IMPLEMENT_GRADIENT_BUILDER(GetShrunkenGatherGradient) {

View file

@ -40,6 +40,7 @@ DECLARE_GRADIENT_BUILDER(GetAveragePoolGradient)
DECLARE_GRADIENT_BUILDER(GetMaxPoolGradient)
DECLARE_GRADIENT_BUILDER(GetGatherGradient)
DECLARE_GRADIENT_BUILDER(GetPadAndUnflattenGradient)
DECLARE_GRADIENT_BUILDER(GetFlattenAndUnpadGradient)
DECLARE_GRADIENT_BUILDER(GetShrunkenGatherGradient)
DECLARE_GRADIENT_BUILDER(GetConvGradient)
DECLARE_GRADIENT_BUILDER(GetUnsqueezeGradient)

View file

@ -72,6 +72,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("MaxPool", GetMaxPoolGradient);
REGISTER_GRADIENT_BUILDER("Gather", GetGatherGradient);
REGISTER_GRADIENT_BUILDER("PadAndUnflatten", GetPadAndUnflattenGradient);
REGISTER_GRADIENT_BUILDER("FlattenAndUnpad", GetFlattenAndUnpadGradient);
REGISTER_GRADIENT_BUILDER("ShrunkenGather", GetShrunkenGatherGradient);
REGISTER_GRADIENT_BUILDER("Conv", GetConvGradient);
REGISTER_GRADIENT_BUILDER("Squeeze", GetSqueezeGradient);

View file

@ -4741,7 +4741,7 @@ Return true if all elements are true and false otherwise.
"For other indices, the corresponding value in output will be padded to zero."
"The indices don't allow duplicated index values, otherwise, though there is no runtime check"
"(in case of performance concern), the behaviour of output is undefined."
"(in case of performance concern), the behavior of output is undefined."
"An example:"
" input: [[1, 2, 3, 4], [5, 6, 7, 8]], shape is [2, 4]"
@ -4749,14 +4749,12 @@ Return true if all elements are true and false otherwise.
" unflatten_dims: [2, 3], shape is [2]"
" output: [[[1, 2, 3, 4], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [5, 6, 7, 8]]],"
" shape is [2, 3, 4]"
" flatten_output_shape: [6, 4], shape is [2]")
" shape is [2, 3, 4]")
.Input(0, "input", "input data of rank N, shape is [d1, d2, ..., dN]", "T")
.Input(1, "indices", "1D Tensor of int32/int64 indices, shape is [d1], each element's value ranges in [0, M1*M2).",
"T_INDEX")
.Input(2, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT")
.Output(0, "output", "output data of rank N+1, [M1, M2, d2, ..., dN]", "T")
.Output(1, "flatten_output_shape", "1D tensor with output shape, [M1*M2, d2, ..., dN]", "T_INT")
.TypeConstraint(
"T_INT",
{"tensor(int32)", "tensor(int64)"},
@ -4770,6 +4768,26 @@ Return true if all elements are true and false otherwise.
{"tensor(int32)", "tensor(int64)"},
"Constrain indices to integer types");
ONNX_CONTRIB_OPERATOR_SCHEMA(FlattenAndUnpad)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetDoc(
"FlattenAndUnpad operator flattens the first two dims of input tensor, and unpad according to given indices."
"This is used by padding elimination graph transformer.")
.Input(0, "input", "input data of rank N + 1, shape is [M1, M2, d2, ..., dN]", "T")
.Input(1, "indices", "1D Tensor of int32/int64 indices, shape is [d1], each element's value ranges in [0, M1*M2).",
"T_INT")
.Output(0, "output", "output data of rank N, [d1, d2, ..., dN]", "T")
.Output(1, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT")
.TypeConstraint(
"T_INT",
{"tensor(int32)", "tensor(int64)"},
"Constrain indices and shape to integer tensors.")
.TypeConstraint(
"T",
{"tensor(int32)", "tensor(int64)", "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
"Constrain input and output types to float tensors.");
ONNX_CONTRIB_OPERATOR_SCHEMA(GRUTraining)
.SetDomain(kMSDomain)
.SinceVersion(1)

View file

@ -129,91 +129,43 @@ NodeArg* InsertExpandForNodeInput(Graph& graph,
return new_expand_node->MutableOutputDefs()[0];
}
// Insert Reshape + ShrunkenGather to flatten the in_index-th input of node.
// Insert FlattenAndUnpad to flatten and unpad the in_index-th input of node.
// The gather_index_arg is the indices of the elements that are not padding.
NodeArg* InsertFlattenPatternForInput(Graph& graph,
Node& node,
uint32_t in_index,
NodeArg* gather_index_arg,
const logging::Logger& logger) {
InlinedVector<NodeArg*> reshape_input_args;
reshape_input_args.reserve(2);
reshape_input_args.push_back(node.MutableInputDefs()[in_index]);
std::vector<int64_t> new_shape;
new_shape.push_back(-1); // only support flatten 0 and 1 dims
auto input_shape = node.InputDefs()[in_index]->Shape();
ORT_ENFORCE(input_shape->dim_size() >= 2);
ONNX_NAMESPACE::TensorShapeProto flattened_shape;
if (input_shape->dim(0).has_dim_value() && input_shape->dim(1).has_dim_value()) {
flattened_shape.add_dim()->set_dim_value(input_shape->dim(0).dim_value() * input_shape->dim(1).dim_value());
} else {
std::string token_dim_name = MakeString("total_token_count_", utils::GetRandomSeed());
flattened_shape.add_dim()->set_dim_param(token_dim_name);
}
for (int k = 2; k < input_shape->dim_size(); k++) {
ORT_ENFORCE(input_shape->dim(k).has_dim_value());
new_shape.push_back(input_shape->dim(k).dim_value());
flattened_shape.add_dim()->set_dim_value(input_shape->dim(k).dim_value());
}
ONNX_NAMESPACE::TensorProto new_shape_const_tensor;
new_shape_const_tensor.set_name(graph.GenerateNodeArgName("new_shape"));
new_shape_const_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
new_shape_const_tensor.add_dims(new_shape.size());
new_shape_const_tensor.set_raw_data(new_shape.data(), new_shape.size() * sizeof(int64_t));
NodeArg* new_shape_arg = &graph_utils::AddInitializer(graph, new_shape_const_tensor);
reshape_input_args.push_back(new_shape_arg);
InlinedVector<NodeArg*> unpad_input_args;
unpad_input_args.reserve(2);
unpad_input_args.push_back(node.MutableInputDefs()[in_index]);
unpad_input_args.push_back(gather_index_arg);
InlinedVector<NodeArg*> reshape_output_args;
reshape_output_args.push_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("inputs_reshape_result"),
node.MutableInputDefs()[in_index]->TypeAsProto()));
Node* new_reshape_node = InsertIntermediateNodeOnDestInput(
graph, node,
in_index,
0,
0,
graph.GenerateNodeName("Reshape"),
"Reshape",
"Reshape node to filter invalid tokens.",
reshape_input_args,
reshape_output_args,
{},
"",
logger);
new_reshape_node->SetExecutionProviderType(node.GetExecutionProviderType());
auto reshape_out_arg = new_reshape_node->MutableOutputDefs()[0];
reshape_out_arg->SetShape(flattened_shape);
InlinedVector<NodeArg*> gather_input_args;
gather_input_args.reserve(2);
gather_input_args.push_back(reshape_output_args[0]);
gather_input_args.push_back(gather_index_arg);
InlinedVector<NodeArg*> gather_output_args;
gather_output_args.push_back(
InlinedVector<NodeArg*> unpad_output_args;
unpad_output_args.push_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padding_filter_result"),
reshape_out_arg->TypeAsProto()));
nullptr));
unpad_output_args.push_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("d1_d2_shape"),
nullptr));
Node* new_gather_node = InsertIntermediateNodeOnDestInput(
Node* unpad_node = InsertIntermediateNodeOnDestInput(
graph, node,
in_index,
0,
0,
graph.GenerateNodeName("PaddingFilter"),
"ShrunkenGather",
"ShrunkenGather node to filter invalid tokens.",
gather_input_args,
gather_output_args,
"FlattenAndUnpad",
"FlattenAndUnpad node to filter invalid tokens.",
unpad_input_args,
unpad_output_args,
{},
kMSDomain,
logger);
new_gather_node->SetExecutionProviderType(node.GetExecutionProviderType());
auto gather_out_arg = new_gather_node->MutableOutputDefs()[0];
return gather_out_arg;
unpad_node->SetExecutionProviderType(node.GetExecutionProviderType());
auto unpad_out_arg = unpad_node->MutableOutputDefs()[0];
return unpad_out_arg;
}
// Insert PadAndUnflatten to unflatten the shape of the in_index-th input of node.
@ -236,10 +188,6 @@ NodeArg* InsertNodesForOutput(Graph& graph,
pad_node_output_args.push_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_result"),
nullptr));
pad_node_output_args.push_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_d1xd2_shape"),
nullptr));
Node* new_gathergrad_node = InsertIntermediateNodeOnDestInput(
graph, node,
in_index,

View file

@ -3011,7 +3011,6 @@ TEST(GradientCheckerTest, PadAndUnflattenGrad) {
std::vector<std::vector<float>> x_datas = {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 5, 0, 1}, {5, 2}};
TensorInfo padded_out_info({5, 2, 3}, true);
TensorInfo out_shape_info({2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
#ifdef USE_CUDA
@ -3021,7 +3020,7 @@ TEST(GradientCheckerTest, PadAndUnflattenGrad) {
#endif
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info, indices_info, shape_info},
{padded_out_info, out_shape_info}, &max_error,
{padded_out_info}, &max_error,
x_datas, {}, true, false, &execution_providers));
EXPECT_IS_TINY(max_error);
}

View file

@ -5786,14 +5786,14 @@ def test_ops_for_padding_elimination(test_cases):
# the test_op should be included in padding elimination subgraph and the PadAndUnflatten should be
# added to output of test_op.
# in case 2, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, 1, hidden_size],
# the test_op should be included in padding elimination subgraph and a 'Expand + Reshape + ShrunkenGather'
# the test_op should be included in padding elimination subgraph and a 'Expand + FlattenAndUnpad'
# pattern should be insert to the arg of [batch_size, 1, hidden_size].
# in case 3, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [1, hidden_size],
# the test_op should be included in padding elimination subgraph and a 'Expand + Reshape + ShrunkenGather'
# the test_op should be included in padding elimination subgraph and a 'Expand + FlattenAndUnpad'
# pattern should be insert to the arg of [batch_size, 1, hidden_size].
# in case 4, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, seqlen, hidden_size],
# the test_op should be included in padding elimination subgraph and the PadAndUnflatten should be added to
# output of test_op. Besides, the other input of Add should be added 'Reshape + ShrunkenGather' to
# output of test_op. Besides, the other input of Add should be added 'FlattenAndUnpad' to
# flatten and elimination padding.
def test_elementwise(self, input_ids):
input_shape = input_ids.size()
@ -5905,9 +5905,9 @@ def test_ops_for_padding_elimination(test_cases):
assert len([node.op_type for node in training_model.graph.node if node.op_type == "Squeeze"]) == 1
assert len([node.op_type for node in training_model.graph.node if node.op_type == "PadAndUnflatten"]) == 1
if case >= 2:
assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 2
assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 3
else:
assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 1
assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 2
gathergrad_node = next(node for node in training_model.graph.node if node.op_type == "PadAndUnflatten")
def find_input_node_type(model, arg):
@ -6071,7 +6071,7 @@ def test_e2e_padding_elimination():
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-3, rtol=1e-4)
training_model = ort_model._torch_module._execution_manager(True)._onnx_models.optimized_model
assert "ShrunkenGather" in [node.op_type for node in training_model.graph.node]
assert "FlattenAndUnpad" in [node.op_type for node in training_model.graph.node]
assert "PadAndUnflatten" in [node.op_type for node in training_model.graph.node]
del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"]

View file

@ -0,0 +1,157 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "test/common/tensor_op_test_utils.h"
#include "test/providers/provider_test_utils.h"
namespace onnxruntime {
namespace test {
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(FlattenAndUnpadTest, Int32Type2D) {
std::vector<int32_t> input = {1, 1, 3, 2, 0, 3, 0, 4,
0, 5, 0, 6, 0, 0, 0};
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11};
std::vector<int32_t> output = {1, 2, 3, 4, 5, 6};
std::vector<int64_t> unflatten_dims = {5, 3};
OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
test.AddInput<int32_t>("input", {5, 3}, input);
test.AddInput<int64_t>("indices", {6}, indices);
test.AddOutput<int32_t>("output", {6}, output);
test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.Run();
}
TEST(FlattenAndUnpadTest, Int32Type3D) {
std::vector<int32_t> input = {0, 0, 0, 1, 2, 3, 0, 0, 0,
4, 5, 6, 7, 8, 9, 0, 0, 0};
std::vector<int64_t> indices = {1, 3, 4};
std::vector<int32_t> output = {1, 2, 3, 4, 5, 6, 7, 8, 9};
std::vector<int64_t> unflatten_dims = {2, 3};
OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
test.AddInput<int32_t>("input", {2, 3, 3}, input);
test.AddInput<int64_t>("indices", {3}, indices);
test.AddOutput<int32_t>("output", {3, 3}, output);
test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.Run();
}
TEST(FlattenAndUnpadTest, Int64Type2D) {
std::vector<int64_t> input = {1, 1, 3, 2, 0, 3, 0, 4,
0, 5, 0, 6, 0, 0, 0};
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11};
std::vector<int64_t> output = {1, 2, 3, 4, 5, 6};
std::vector<int64_t> unflatten_dims = {5, 3};
OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
test.AddInput<int64_t>("input", {5, 3}, input);
test.AddInput<int64_t>("indices", {6}, indices);
test.AddOutput<int64_t>("output", {6}, output);
test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.Run();
}
TEST(FlattenAndUnpadTest, Int64Type3D) {
std::vector<int64_t> input = {0, 0, 0, 1, 2, 3, 0, 0, 0,
4, 5, 6, 7, 8, 9, 0, 0, 0};
std::vector<int64_t> indices = {1, 3, 4};
std::vector<int64_t> output = {1, 2, 3, 4, 5, 6, 7, 8, 9};
std::vector<int64_t> unflatten_dims = {2, 3};
OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
test.AddInput<int64_t>("input", {2, 3, 3}, input);
test.AddInput<int64_t>("indices", {3}, indices);
test.AddOutput<int64_t>("output", {3, 3}, output);
test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.Run();
}
TEST(FlattenAndUnpadTest, FloatType2D) {
std::vector<float> input = {1.0f, 1.0f, 3.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11};
std::vector<float> output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f};
std::vector<int64_t> unflatten_dims = {5, 3};
OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
test.AddInput<float>("input", {5, 3}, input);
test.AddInput<int64_t>("indices", {6}, indices);
test.AddOutput<float>("output", {6}, output);
test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.Run();
}
TEST(FlattenAndUnpadTest, FloatType3D) {
std::vector<float> input = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
std::vector<int64_t> indices = {1, 3, 4};
std::vector<float> output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f, 7.f, 8.f, 9.f};
std::vector<int64_t> unflatten_dims = {2, 3};
OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
test.AddInput<float>("input", {2, 3, 3}, input);
test.AddInput<int64_t>("indices", {3}, indices);
test.AddOutput<float>("output", {3, 3}, output);
test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.Run();
}
TEST(FlattenAndUnpadTest, MLFloat16Type2D) {
std::vector<float> input = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11};
std::vector<float> output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f};
std::vector<int64_t> unflatten_dims = {5, 3};
std::vector<MLFloat16> input_half;
input_half.resize(input.size());
ConvertFloatToMLFloat16(input.data(), input_half.data(), static_cast<int>(input.size()));
std::vector<MLFloat16> output_half;
output_half.resize(output.size());
ConvertFloatToMLFloat16(output.data(), output_half.data(), static_cast<int>(output.size()));
OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
test.AddInput<MLFloat16>("input", {5, 3}, input_half);
test.AddInput<int64_t>("indices", {6}, indices);
test.AddOutput<MLFloat16>("output", {6}, output_half);
test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.Run();
}
TEST(FlattenAndUnpadTest, MLFloat16Type3D) {
std::vector<float> input = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
std::vector<int64_t> indices = {1, 3, 4};
std::vector<float> output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f, 7.f, 8.f, 9.f};
std::vector<int64_t> unflatten_dims = {2, 3};
std::vector<MLFloat16> input_half;
input_half.resize(input.size());
ConvertFloatToMLFloat16(input.data(), input_half.data(), static_cast<int>(input.size()));
std::vector<MLFloat16> output_half;
output_half.resize(output.size());
ConvertFloatToMLFloat16(output.data(), output_half.data(), static_cast<int>(output.size()));
OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
test.AddInput<MLFloat16>("input", {2, 3, 3}, input_half);
test.AddInput<int64_t>("indices", {3}, indices);
test.AddOutput<MLFloat16>("output", {3, 3}, output_half);
test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.Run();
}
#endif
} // namespace test
} // namespace onnxruntime

View file

@ -17,14 +17,11 @@ TEST(PadAndUnflattenTest, FloatType1D) {
std::vector<float> output = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};
std::vector<int64_t> full_flatten_dims = {15};
OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain);
test.AddInput<float>("input", {6}, input);
test.AddInput<int64_t>("indices", {6}, indices);
test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.AddOutput<float>("output", {5, 3}, output);
test.AddOutput<int64_t>("full_flatten_dims", {1}, full_flatten_dims);
test.Run();
}
@ -36,14 +33,11 @@ TEST(PadAndUnflattenTest, FloatType2D) {
std::vector<float> output = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
std::vector<int64_t> full_flatten_dims = {6, 3};
OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain);
test.AddInput<float>("input", {3, 3}, input);
test.AddInput<int64_t>("indices", {3}, indices);
test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.AddOutput<float>("output", {2, 3, 3}, output);
test.AddOutput<int64_t>("full_flatten_dims", {2}, full_flatten_dims);
test.Run();
}
@ -55,8 +49,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type1D) {
std::vector<float> output = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};
std::vector<int64_t> full_flatten_dims = {15};
std::vector<MLFloat16> input_half;
input_half.resize(input.size());
ConvertFloatToMLFloat16(input.data(), input_half.data(), int(input.size()));
@ -69,7 +61,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type1D) {
test.AddInput<int64_t>("indices", {6}, indices);
test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.AddOutput<MLFloat16>("output", {5, 3}, output_half);
test.AddOutput<int64_t>("full_flatten_dims", {1}, full_flatten_dims);
test.Run();
}
@ -81,8 +72,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type2D) {
std::vector<float> output = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
std::vector<int64_t> full_flatten_dims = {6, 3};
std::vector<MLFloat16> input_half;
input_half.resize(input.size());
ConvertFloatToMLFloat16(input.data(), input_half.data(), int(input.size()));
@ -95,7 +84,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type2D) {
test.AddInput<int64_t>("indices", {3}, indices);
test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.AddOutput<MLFloat16>("output", {2, 3, 3}, output_half);
test.AddOutput<int64_t>("full_flatten_dims", {2}, full_flatten_dims);
test.Run();
}

View file

@ -207,6 +207,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BatchScale);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PadAndUnflatten);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, FlattenAndUnpad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ScaledSum);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ResizeGrad);
@ -462,6 +463,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BatchScale)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PadAndUnflatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, FlattenAndUnpad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ScaledSum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ResizeGrad)>,

View file

@ -0,0 +1,91 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "orttraining/training_ops/cuda/tensor/flatten_and_unpad.h"
#include "orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h"
#include "core/providers/cuda/shared_inc/cuda_utils.h"
namespace onnxruntime {
namespace cuda {
ONNX_OPERATOR_KERNEL_EX(
FlattenAndUnpad,
kMSDomain,
1,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, MLFloat16, float, double, BFloat16>())
.TypeConstraint("T_INT", DataTypeImpl::GetTensorType<int64_t>())
.OutputMemoryType(OrtMemTypeCPUOutput, 1),
FlattenAndUnpad);
// Put implementation in the anonymous namespace to avoid name collision in the global namespace.
namespace {
template <typename T>
struct FlattenAndUnpadFunctor {
void operator()(cudaStream_t stream,
const int64_t output_element_count,
const fast_divmod output_element_stride_fdm,
const int64_t index_value_upper_bound,
const Tensor& input_tensor,
const Tensor& indices_tensor,
Tensor& output_tensor) const {
typedef typename ToCudaType<T>::MappedType CudaT;
const CudaT* input_data = reinterpret_cast<const CudaT*>(input_tensor.Data<T>());
FlattenAndUnpadImpl<CudaT>(stream, output_element_count, output_element_stride_fdm, index_value_upper_bound,
input_data, indices_tensor.Data<int64_t>(),
reinterpret_cast<CudaT*>(output_tensor.MutableData<T>()));
}
};
} // namespace
Status FlattenAndUnpad::ComputeInternal(OpKernelContext* context) const {
const Tensor* input_tensor = context->Input<Tensor>(0);
const Tensor* indices_tensor = context->Input<Tensor>(1);
ORT_ENFORCE(input_tensor->Shape().NumDimensions() >= 2,
"input_tensor tensor must have at least 2 dimensions.", input_tensor->Shape().NumDimensions());
ORT_ENFORCE(indices_tensor->Shape().NumDimensions() == 1,
"indices_tensor tensor must be 1-D.", indices_tensor->Shape().NumDimensions());
const auto& input_shape = input_tensor->Shape();
std::vector<int64_t> output_shape_vec;
output_shape_vec.reserve(input_shape.NumDimensions() - 1);
output_shape_vec.push_back(indices_tensor->Shape()[0]);
int64_t element_stride = 1;
for (size_t i = 2; i < input_shape.NumDimensions(); ++i) {
output_shape_vec.push_back(input_shape[i]);
element_stride *= input_shape[i];
}
fast_divmod output_element_stride_fdm(static_cast<int>(element_stride));
auto output_shape = TensorShape(output_shape_vec);
Tensor* output_tensor = context->Output(0, output_shape);
std::vector<int64_t> unflatten_dims_vec;
unflatten_dims_vec.reserve(2);
unflatten_dims_vec.push_back(input_shape[0]);
unflatten_dims_vec.push_back(input_shape[1]);
const int64_t index_value_upper_bound = input_shape[0] * input_shape[1];
utils::MLTypeCallDispatcher<int32_t, int64_t, float, MLFloat16, double, BFloat16>
t_disp(input_tensor->GetElementType());
t_disp.Invoke<FlattenAndUnpadFunctor>(Stream(context),
output_shape.Size(),
output_element_stride_fdm,
index_value_upper_bound,
*input_tensor,
*indices_tensor,
*output_tensor);
size_t rank = unflatten_dims_vec.size();
Tensor* unflatten_dims_tensor = context->Output(1, {static_cast<int>(rank)});
TensorShape(unflatten_dims_vec).CopyDims(unflatten_dims_tensor->MutableData<int64_t>(), rank);
return Status::OK();
}
} // namespace cuda
} // namespace onnxruntime

View file

@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/cuda/cuda_kernel.h"
#include "core/providers/common.h"
namespace onnxruntime {
namespace cuda {
class FlattenAndUnpad final : public CudaKernel {
public:
FlattenAndUnpad(const OpKernelInfo& info) : CudaKernel(info) {
}
Status ComputeInternal(OpKernelContext* context) const override;
};
} // namespace cuda
} // namespace onnxruntime

View file

@ -0,0 +1,83 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h"
#include "core/providers/cuda/cu_inc/common.cuh"
namespace onnxruntime {
namespace cuda {
constexpr int kBlockSize = 256;
constexpr int kNumUnroll = 4;
template <typename T>
__global__ void ExtractIputWithIndexKernel(const CUDA_LONG N,
const fast_divmod output_element_stride_fdm,
const int64_t index_value_upper_bound,
const T* input_data,
const int64_t* indices_data,
T* output_data) {
CUDA_LONG idx = blockDim.x * blockIdx.x + threadIdx.x;
CUDA_LONG id = idx * kNumUnroll;
T input[kNumUnroll];
if (id < N) {
#pragma unroll
for (int i = 0; i < kNumUnroll; ++i) {
CUDA_LONG li = id + i;
if (li < N) {
int row_index, col_index;
output_element_stride_fdm.divmod(li, row_index, col_index);
assert(indices_data[row_index] < index_value_upper_bound);
input[i] = input_data[indices_data[row_index] * output_element_stride_fdm.d_ + col_index];
}
}
}
#pragma unroll
for (int i = 0; i < kNumUnroll; ++i) {
CUDA_LONG li = id + i;
if (li < N) {
output_data[li] = input[i];
}
}
}
template <typename T>
void FlattenAndUnpadImpl(cudaStream_t stream,
const int64_t total_element_count,
const fast_divmod output_element_stride_fdm,
const int64_t index_value_upper_bound,
const T* input_data,
const int64_t* indices_data,
T* output_data) {
const int blocksPerGrid = static_cast<int>(CeilDiv(total_element_count, kBlockSize * kNumUnroll));
ExtractIputWithIndexKernel<T><<<blocksPerGrid, kBlockSize, 0, stream>>>(
static_cast<CUDA_LONG>(total_element_count),
output_element_stride_fdm,
index_value_upper_bound,
input_data,
indices_data,
output_data);
}
#define FLATTEN_AND_UNPAD_IMPL(T) \
template void FlattenAndUnpadImpl<T>(cudaStream_t stream, \
const int64_t total_element_count, \
const fast_divmod output_element_stride_fdm, \
const int64_t index_value_upper_bound, \
const T* input_data, \
const int64_t* indices_data, \
T* output_data);
FLATTEN_AND_UNPAD_IMPL(float)
FLATTEN_AND_UNPAD_IMPL(double)
FLATTEN_AND_UNPAD_IMPL(half)
FLATTEN_AND_UNPAD_IMPL(BFloat16)
FLATTEN_AND_UNPAD_IMPL(int32_t)
FLATTEN_AND_UNPAD_IMPL(int64_t)
#undef FLATTEN_AND_UNPAD_FROM_MASK_IMPL
} // namespace cuda
} // namespace onnxruntime

View file

@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#ifdef USE_ROCM
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#else
#include "core/providers/cuda/shared_inc/cuda_utils.h"
#endif
namespace onnxruntime {
namespace cuda {
template <typename T>
void FlattenAndUnpadImpl(cudaStream_t stream,
const int64_t total_element_count,
const fast_divmod output_element_stride_fdm,
const int64_t index_value_upper_bound,
const T* input_data,
const int64_t* indices_data,
T* output_data);
} // namespace cuda
} // namespace onnxruntime

View file

@ -17,8 +17,7 @@ ONNX_OPERATOR_KERNEL_EX(
.TypeConstraint("T", BuildKernelDefConstraints<MLFloat16, float, double, BFloat16>())
.TypeConstraint("T_INT", DataTypeImpl::GetTensorType<int64_t>())
.TypeConstraint("T_INDEX", DataTypeImpl::GetTensorType<int64_t>())
.InputMemoryType(OrtMemTypeCPUInput, 2)
.OutputMemoryType(OrtMemTypeCPUOutput, 1),
.InputMemoryType(OrtMemTypeCPUInput, 2),
PadAndUnflatten);
// Put implementation in the anonymous namespace to avoid name collision in the global namespace.
@ -63,14 +62,11 @@ Status PadAndUnflatten::ComputeInternal(OpKernelContext* context) const {
output_shape_vec.push_back(dims_ptr[0]);
output_shape_vec.push_back(dims_ptr[1]);
std::vector<int64_t> full_size_flatten_shape_vec;
const int64_t flatten_dim_factor = dims_ptr[0] * dims_ptr[1];
full_size_flatten_shape_vec.push_back(flatten_dim_factor);
int64_t element_stride = 1;
for (size_t i = 1; i < input_shape.NumDimensions(); ++i) {
output_shape_vec.push_back(input_shape[i]);
full_size_flatten_shape_vec.push_back(input_shape[i]);
element_stride *= input_shape[i];
}
@ -87,11 +83,6 @@ Status PadAndUnflatten::ComputeInternal(OpKernelContext* context) const {
*indices_tensor,
*output_tensor);
// Set input shape output tensor.
size_t rank = full_size_flatten_shape_vec.size();
Tensor* input_shape_tensor = context->Output(1, {static_cast<int>(rank)});
TensorShape(full_size_flatten_shape_vec).CopyDims(input_shape_tensor->MutableData<int64_t>(), rank);
return Status::OK();
}

View file

@ -61,7 +61,7 @@ void PadAndUnflattenImpl(cudaStream_t stream,
output_data);
}
#define SPECIALIZED_RESTORE_FROM_MASK_IMPL(T) \
#define PAD_AND_UNFLATTEN_IMPL(T) \
template void PadAndUnflattenImpl<T>(cudaStream_t stream, \
const int64_t total_element_count, \
const fast_divmod output_element_stride_fdm, \
@ -70,12 +70,12 @@ void PadAndUnflattenImpl(cudaStream_t stream,
const int64_t* indices_data, \
T* output_data);
SPECIALIZED_RESTORE_FROM_MASK_IMPL(float)
SPECIALIZED_RESTORE_FROM_MASK_IMPL(double)
SPECIALIZED_RESTORE_FROM_MASK_IMPL(half)
SPECIALIZED_RESTORE_FROM_MASK_IMPL(BFloat16)
PAD_AND_UNFLATTEN_IMPL(float)
PAD_AND_UNFLATTEN_IMPL(double)
PAD_AND_UNFLATTEN_IMPL(half)
PAD_AND_UNFLATTEN_IMPL(BFloat16)
#undef SPECIALIZED_RESTORE_FROM_MASK_IMPL
#undef PAD_AND_UNFLATTEN_FROM_MASK_IMPL
} // namespace cuda
} // namespace onnxruntime

View file

@ -187,6 +187,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_BFloat16, ReduceAllL2);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PadAndUnflatten);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, FlattenAndUnpad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ResizeGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ResizeGrad);
@ -390,6 +391,7 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_BFloat16, ReduceAllL2)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PadAndUnflatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, FlattenAndUnpad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ResizeGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ResizeGrad)>,