mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
Add FlattenAndUnpad Op (#17845)
### Description Add an op named `FlattenAndUnpad`. This op does two things: 1. Flattens the first two dims of the input tensor. 2. Gathers the valid values from the flattened tensor using an index tensor. ### Motivation and Context The grad op of `PadAndUnflatten` was `GatherGrad`, which performs poorly. `FlattenAndUnpad` is implemented to replace `GatherGrad` as the grad of `PadAndUnflatten`. With this op, we can also simplify the "Reshape + ShrunkenGather" pattern to `FlattenAndUnpad` in the padding elimination optimizer, which also improves performance.
This commit is contained in:
parent
885bf3561d
commit
4dc63692f8
17 changed files with 448 additions and 118 deletions
|
|
@ -791,13 +791,16 @@ IMPLEMENT_GRADIENT_BUILDER(GetGatherGradient) {
|
|||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetPadAndUnflattenGradient) {
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef(OpDef("Reshape"),
|
||||
{GO(0), O(1)},
|
||||
{IA("GO_reshaped")}),
|
||||
NodeDef(OpDef{"Gather", kOnnxDomain, 1},
|
||||
{IA("GO_reshaped"), I(1)},
|
||||
{GI(0)},
|
||||
SrcNodeAttributes())};
|
||||
NodeDef(OpDef{"FlattenAndUnpad", kMSDomain, 1},
|
||||
{GO(0), I(1)},
|
||||
{GI(0), IA("Unflatten_dims")})};
|
||||
}
|
||||
|
||||
// Gradient builder for FlattenAndUnpad.
// Emits a single PadAndUnflatten node (MS domain, version 1) that consumes
// GO(0), I(1) and O(1) and produces GI(0).
// NOTE(review): per the FlattenAndUnpad schema, I(1) is "indices" and O(1) is
// "unflatten_dims" -- confirm against the op definition.
IMPLEMENT_GRADIENT_BUILDER(GetFlattenAndUnpadGradient) {
  std::vector<NodeDef> grad_nodes;
  grad_nodes.push_back(
      NodeDef(OpDef{"PadAndUnflatten", kMSDomain, 1},
              {GO(0), I(1), O(1)},
              {GI(0)}));
  return grad_nodes;
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetShrunkenGatherGradient) {
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ DECLARE_GRADIENT_BUILDER(GetAveragePoolGradient)
|
|||
DECLARE_GRADIENT_BUILDER(GetMaxPoolGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetGatherGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetPadAndUnflattenGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetFlattenAndUnpadGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetShrunkenGatherGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetConvGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetUnsqueezeGradient)
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
|
|||
REGISTER_GRADIENT_BUILDER("MaxPool", GetMaxPoolGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Gather", GetGatherGradient);
|
||||
REGISTER_GRADIENT_BUILDER("PadAndUnflatten", GetPadAndUnflattenGradient);
|
||||
REGISTER_GRADIENT_BUILDER("FlattenAndUnpad", GetFlattenAndUnpadGradient);
|
||||
REGISTER_GRADIENT_BUILDER("ShrunkenGather", GetShrunkenGatherGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Conv", GetConvGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Squeeze", GetSqueezeGradient);
|
||||
|
|
|
|||
|
|
@ -4741,7 +4741,7 @@ Return true if all elements are true and false otherwise.
|
|||
"For other indices, the corresponding value in output will be padded to zero."
|
||||
|
||||
"The indices don't allow duplicated index values, otherwise, though there is no runtime check"
|
||||
"(in case of performance concern), the behaviour of output is undefined."
|
||||
"(in case of performance concern), the behavior of output is undefined."
|
||||
|
||||
"An example:"
|
||||
" input: [[1, 2, 3, 4], [5, 6, 7, 8]], shape is [2, 4]"
|
||||
|
|
@ -4749,14 +4749,12 @@ Return true if all elements are true and false otherwise.
|
|||
" unflatten_dims: [2, 3], shape is [2]"
|
||||
|
||||
" output: [[[1, 2, 3, 4], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [5, 6, 7, 8]]],"
|
||||
" shape is [2, 3, 4]"
|
||||
" flatten_output_shape: [6, 4], shape is [2]")
|
||||
" shape is [2, 3, 4]")
|
||||
.Input(0, "input", "input data of rank N, shape is [d1, d2, ..., dN]", "T")
|
||||
.Input(1, "indices", "1D Tensor of int32/int64 indices, shape is [d1], each element's value ranges in [0, M1*M2).",
|
||||
"T_INDEX")
|
||||
.Input(2, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT")
|
||||
.Output(0, "output", "output data of rank N+1, [M1, M2, d2, ..., dN]", "T")
|
||||
.Output(1, "flatten_output_shape", "1D tensor with output shape, [M1*M2, d2, ..., dN]", "T_INT")
|
||||
.TypeConstraint(
|
||||
"T_INT",
|
||||
{"tensor(int32)", "tensor(int64)"},
|
||||
|
|
@ -4770,6 +4768,26 @@ Return true if all elements are true and false otherwise.
|
|||
{"tensor(int32)", "tensor(int64)"},
|
||||
"Constrain indices to integer types");
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(FlattenAndUnpad)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.SetDoc(
|
||||
"FlattenAndUnpad operator flattens the first two dims of input tensor, and unpad according to given indices."
|
||||
"This is used by padding elimination graph transformer.")
|
||||
.Input(0, "input", "input data of rank N + 1, shape is [M1, M2, d2, ..., dN]", "T")
|
||||
.Input(1, "indices", "1D Tensor of int32/int64 indices, shape is [d1], each element's value ranges in [0, M1*M2).",
|
||||
"T_INT")
|
||||
.Output(0, "output", "output data of rank N, [d1, d2, ..., dN]", "T")
|
||||
.Output(1, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT")
|
||||
.TypeConstraint(
|
||||
"T_INT",
|
||||
{"tensor(int32)", "tensor(int64)"},
|
||||
"Constrain indices and shape to integer tensors.")
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
{"tensor(int32)", "tensor(int64)", "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
|
||||
"Constrain input and output types to float tensors.");
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(GRUTraining)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
|
|
|
|||
|
|
@ -129,91 +129,43 @@ NodeArg* InsertExpandForNodeInput(Graph& graph,
|
|||
return new_expand_node->MutableOutputDefs()[0];
|
||||
}
|
||||
|
||||
// Insert Reshape + ShrunkenGather to flatten the in_index-th input of node.
|
||||
// Insert FlattenAndUnpad to flatten and unpad the in_index-th input of node.
|
||||
// The gather_index_arg is the indices of the elements that are not padding.
|
||||
NodeArg* InsertFlattenPatternForInput(Graph& graph,
|
||||
Node& node,
|
||||
uint32_t in_index,
|
||||
NodeArg* gather_index_arg,
|
||||
const logging::Logger& logger) {
|
||||
InlinedVector<NodeArg*> reshape_input_args;
|
||||
reshape_input_args.reserve(2);
|
||||
reshape_input_args.push_back(node.MutableInputDefs()[in_index]);
|
||||
std::vector<int64_t> new_shape;
|
||||
new_shape.push_back(-1); // only support flatten 0 and 1 dims
|
||||
auto input_shape = node.InputDefs()[in_index]->Shape();
|
||||
ORT_ENFORCE(input_shape->dim_size() >= 2);
|
||||
ONNX_NAMESPACE::TensorShapeProto flattened_shape;
|
||||
if (input_shape->dim(0).has_dim_value() && input_shape->dim(1).has_dim_value()) {
|
||||
flattened_shape.add_dim()->set_dim_value(input_shape->dim(0).dim_value() * input_shape->dim(1).dim_value());
|
||||
} else {
|
||||
std::string token_dim_name = MakeString("total_token_count_", utils::GetRandomSeed());
|
||||
flattened_shape.add_dim()->set_dim_param(token_dim_name);
|
||||
}
|
||||
for (int k = 2; k < input_shape->dim_size(); k++) {
|
||||
ORT_ENFORCE(input_shape->dim(k).has_dim_value());
|
||||
new_shape.push_back(input_shape->dim(k).dim_value());
|
||||
flattened_shape.add_dim()->set_dim_value(input_shape->dim(k).dim_value());
|
||||
}
|
||||
ONNX_NAMESPACE::TensorProto new_shape_const_tensor;
|
||||
new_shape_const_tensor.set_name(graph.GenerateNodeArgName("new_shape"));
|
||||
new_shape_const_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
new_shape_const_tensor.add_dims(new_shape.size());
|
||||
new_shape_const_tensor.set_raw_data(new_shape.data(), new_shape.size() * sizeof(int64_t));
|
||||
NodeArg* new_shape_arg = &graph_utils::AddInitializer(graph, new_shape_const_tensor);
|
||||
reshape_input_args.push_back(new_shape_arg);
|
||||
InlinedVector<NodeArg*> unpad_input_args;
|
||||
unpad_input_args.reserve(2);
|
||||
unpad_input_args.push_back(node.MutableInputDefs()[in_index]);
|
||||
unpad_input_args.push_back(gather_index_arg);
|
||||
|
||||
InlinedVector<NodeArg*> reshape_output_args;
|
||||
reshape_output_args.push_back(
|
||||
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("inputs_reshape_result"),
|
||||
node.MutableInputDefs()[in_index]->TypeAsProto()));
|
||||
|
||||
Node* new_reshape_node = InsertIntermediateNodeOnDestInput(
|
||||
graph, node,
|
||||
in_index,
|
||||
0,
|
||||
0,
|
||||
graph.GenerateNodeName("Reshape"),
|
||||
"Reshape",
|
||||
"Reshape node to filter invalid tokens.",
|
||||
reshape_input_args,
|
||||
reshape_output_args,
|
||||
{},
|
||||
"",
|
||||
logger);
|
||||
|
||||
new_reshape_node->SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
auto reshape_out_arg = new_reshape_node->MutableOutputDefs()[0];
|
||||
|
||||
reshape_out_arg->SetShape(flattened_shape);
|
||||
|
||||
InlinedVector<NodeArg*> gather_input_args;
|
||||
gather_input_args.reserve(2);
|
||||
gather_input_args.push_back(reshape_output_args[0]);
|
||||
gather_input_args.push_back(gather_index_arg);
|
||||
|
||||
InlinedVector<NodeArg*> gather_output_args;
|
||||
gather_output_args.push_back(
|
||||
InlinedVector<NodeArg*> unpad_output_args;
|
||||
unpad_output_args.push_back(
|
||||
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padding_filter_result"),
|
||||
reshape_out_arg->TypeAsProto()));
|
||||
nullptr));
|
||||
unpad_output_args.push_back(
|
||||
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("d1_d2_shape"),
|
||||
nullptr));
|
||||
|
||||
Node* new_gather_node = InsertIntermediateNodeOnDestInput(
|
||||
Node* unpad_node = InsertIntermediateNodeOnDestInput(
|
||||
graph, node,
|
||||
in_index,
|
||||
0,
|
||||
0,
|
||||
graph.GenerateNodeName("PaddingFilter"),
|
||||
"ShrunkenGather",
|
||||
"ShrunkenGather node to filter invalid tokens.",
|
||||
gather_input_args,
|
||||
gather_output_args,
|
||||
"FlattenAndUnpad",
|
||||
"FlattenAndUnpad node to filter invalid tokens.",
|
||||
unpad_input_args,
|
||||
unpad_output_args,
|
||||
{},
|
||||
kMSDomain,
|
||||
logger);
|
||||
|
||||
new_gather_node->SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
auto gather_out_arg = new_gather_node->MutableOutputDefs()[0];
|
||||
return gather_out_arg;
|
||||
unpad_node->SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
auto unpad_out_arg = unpad_node->MutableOutputDefs()[0];
|
||||
return unpad_out_arg;
|
||||
}
|
||||
|
||||
// Insert PadAndUnflatten to unflatten the shape of the in_index-th input of node.
|
||||
|
|
@ -236,10 +188,6 @@ NodeArg* InsertNodesForOutput(Graph& graph,
|
|||
pad_node_output_args.push_back(
|
||||
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_result"),
|
||||
nullptr));
|
||||
pad_node_output_args.push_back(
|
||||
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_d1xd2_shape"),
|
||||
nullptr));
|
||||
|
||||
Node* new_gathergrad_node = InsertIntermediateNodeOnDestInput(
|
||||
graph, node,
|
||||
in_index,
|
||||
|
|
|
|||
|
|
@ -3011,7 +3011,6 @@ TEST(GradientCheckerTest, PadAndUnflattenGrad) {
|
|||
std::vector<std::vector<float>> x_datas = {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 5, 0, 1}, {5, 2}};
|
||||
|
||||
TensorInfo padded_out_info({5, 2, 3}, true);
|
||||
TensorInfo out_shape_info({2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
|
||||
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
#ifdef USE_CUDA
|
||||
|
|
@ -3021,7 +3020,7 @@ TEST(GradientCheckerTest, PadAndUnflattenGrad) {
|
|||
#endif
|
||||
|
||||
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info, indices_info, shape_info},
|
||||
{padded_out_info, out_shape_info}, &max_error,
|
||||
{padded_out_info}, &max_error,
|
||||
x_datas, {}, true, false, &execution_providers));
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5786,14 +5786,14 @@ def test_ops_for_padding_elimination(test_cases):
|
|||
# the test_op should be included in padding elimination subgraph and the PadAndUnflatten should be
|
||||
# added to output of test_op.
|
||||
# in case 2, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, 1, hidden_size],
|
||||
# the test_op should be included in padding elimination subgraph and a 'Expand + Reshape + ShrunkenGather'
|
||||
# the test_op should be included in padding elimination subgraph and a 'Expand + FlattenAndUnpad'
|
||||
# pattern should be inserted for the arg of [batch_size, 1, hidden_size].
|
||||
# in case 3, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [1, hidden_size],
|
||||
# the test_op should be included in padding elimination subgraph and a 'Expand + Reshape + ShrunkenGather'
|
||||
# the test_op should be included in padding elimination subgraph and a 'Expand + FlattenAndUnpad'
|
||||
# pattern should be inserted for the arg of [1, hidden_size].
|
||||
# in case 4, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, seqlen, hidden_size],
|
||||
# the test_op should be included in padding elimination subgraph and the PadAndUnflatten should be added to
|
||||
# output of test_op. Besides, the other input of Add should be added 'Reshape + ShrunkenGather' to
|
||||
# output of test_op. Besides, the other input of Add should be added 'FlattenAndUnpad' to
|
||||
# flatten and eliminate padding.
|
||||
def test_elementwise(self, input_ids):
|
||||
input_shape = input_ids.size()
|
||||
|
|
@ -5905,9 +5905,9 @@ def test_ops_for_padding_elimination(test_cases):
|
|||
assert len([node.op_type for node in training_model.graph.node if node.op_type == "Squeeze"]) == 1
|
||||
assert len([node.op_type for node in training_model.graph.node if node.op_type == "PadAndUnflatten"]) == 1
|
||||
if case >= 2:
|
||||
assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 2
|
||||
assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 3
|
||||
else:
|
||||
assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 1
|
||||
assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 2
|
||||
gathergrad_node = next(node for node in training_model.graph.node if node.op_type == "PadAndUnflatten")
|
||||
|
||||
def find_input_node_type(model, arg):
|
||||
|
|
@ -6071,7 +6071,7 @@ def test_e2e_padding_elimination():
|
|||
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-3, rtol=1e-4)
|
||||
|
||||
training_model = ort_model._torch_module._execution_manager(True)._onnx_models.optimized_model
|
||||
assert "ShrunkenGather" in [node.op_type for node in training_model.graph.node]
|
||||
assert "FlattenAndUnpad" in [node.op_type for node in training_model.graph.node]
|
||||
assert "PadAndUnflatten" in [node.op_type for node in training_model.graph.node]
|
||||
del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"]
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,157 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "test/common/tensor_op_test_utils.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
|
||||
// int32 data, rank-2 input: a [5, 3] tensor is flattened to 15 scalars and the
// six entries at the odd flat positions 1..11 are gathered.
TEST(FlattenAndUnpadTest, Int32Type2D) {
  const std::vector<int32_t> padded_values{1, 1, 3, 2, 0, 3, 0, 4,
                                           0, 5, 0, 6, 0, 0, 0};
  const std::vector<int64_t> kept_positions{1, 3, 5, 7, 9, 11};

  const std::vector<int32_t> expected_values{1, 2, 3, 4, 5, 6};
  const std::vector<int64_t> expected_unflatten_dims{5, 3};

  OpTester tester("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
  tester.AddInput<int32_t>("input", {5, 3}, padded_values);
  tester.AddInput<int64_t>("indices", {6}, kept_positions);
  tester.AddOutput<int32_t>("output", {6}, expected_values);
  tester.AddOutput<int64_t>("unflatten_dims", {2}, expected_unflatten_dims);
  tester.Run();
}
|
||||
|
||||
// int32 data, rank-3 input: a [2, 3, 3] tensor is viewed as six rows of three
// values and rows 1, 3 and 4 are gathered into a [3, 3] output.
TEST(FlattenAndUnpadTest, Int32Type3D) {
  const std::vector<int32_t> padded_values{0, 0, 0, 1, 2, 3, 0, 0, 0,
                                           4, 5, 6, 7, 8, 9, 0, 0, 0};
  const std::vector<int64_t> kept_rows{1, 3, 4};

  const std::vector<int32_t> expected_values{1, 2, 3, 4, 5, 6, 7, 8, 9};
  const std::vector<int64_t> expected_unflatten_dims{2, 3};

  OpTester tester("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
  tester.AddInput<int32_t>("input", {2, 3, 3}, padded_values);
  tester.AddInput<int64_t>("indices", {3}, kept_rows);
  tester.AddOutput<int32_t>("output", {3, 3}, expected_values);
  tester.AddOutput<int64_t>("unflatten_dims", {2}, expected_unflatten_dims);
  tester.Run();
}
|
||||
|
||||
// int64 data, rank-2 input: same layout as the int32 rank-2 case, gathering the
// six entries at the odd flat positions 1..11 of a [5, 3] tensor.
TEST(FlattenAndUnpadTest, Int64Type2D) {
  const std::vector<int64_t> padded_values{1, 1, 3, 2, 0, 3, 0, 4,
                                           0, 5, 0, 6, 0, 0, 0};
  const std::vector<int64_t> kept_positions{1, 3, 5, 7, 9, 11};

  const std::vector<int64_t> expected_values{1, 2, 3, 4, 5, 6};
  const std::vector<int64_t> expected_unflatten_dims{5, 3};

  OpTester tester("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
  tester.AddInput<int64_t>("input", {5, 3}, padded_values);
  tester.AddInput<int64_t>("indices", {6}, kept_positions);
  tester.AddOutput<int64_t>("output", {6}, expected_values);
  tester.AddOutput<int64_t>("unflatten_dims", {2}, expected_unflatten_dims);
  tester.Run();
}
|
||||
|
||||
// int64 data, rank-3 input: rows 1, 3 and 4 of the flattened [2, 3, 3] tensor
// are gathered into a [3, 3] output.
TEST(FlattenAndUnpadTest, Int64Type3D) {
  const std::vector<int64_t> padded_values{0, 0, 0, 1, 2, 3, 0, 0, 0,
                                           4, 5, 6, 7, 8, 9, 0, 0, 0};
  const std::vector<int64_t> kept_rows{1, 3, 4};

  const std::vector<int64_t> expected_values{1, 2, 3, 4, 5, 6, 7, 8, 9};
  const std::vector<int64_t> expected_unflatten_dims{2, 3};

  OpTester tester("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
  tester.AddInput<int64_t>("input", {2, 3, 3}, padded_values);
  tester.AddInput<int64_t>("indices", {3}, kept_rows);
  tester.AddOutput<int64_t>("output", {3, 3}, expected_values);
  tester.AddOutput<int64_t>("unflatten_dims", {2}, expected_unflatten_dims);
  tester.Run();
}
|
||||
|
||||
// float data, rank-2 input: gathers the six entries at the odd flat positions
// 1..11 of a [5, 3] tensor.
TEST(FlattenAndUnpadTest, FloatType2D) {
  const std::vector<float> padded_values{1.0f, 1.0f, 3.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
                                         0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};
  const std::vector<int64_t> kept_positions{1, 3, 5, 7, 9, 11};

  const std::vector<float> expected_values{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f};
  const std::vector<int64_t> expected_unflatten_dims{5, 3};

  OpTester tester("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
  tester.AddInput<float>("input", {5, 3}, padded_values);
  tester.AddInput<int64_t>("indices", {6}, kept_positions);
  tester.AddOutput<float>("output", {6}, expected_values);
  tester.AddOutput<int64_t>("unflatten_dims", {2}, expected_unflatten_dims);
  tester.Run();
}
|
||||
|
||||
// float data, rank-3 input: rows 1, 3 and 4 of the flattened [2, 3, 3] tensor
// are gathered into a [3, 3] output.
TEST(FlattenAndUnpadTest, FloatType3D) {
  const std::vector<float> padded_values{0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
                                         4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
  const std::vector<int64_t> kept_rows{1, 3, 4};

  const std::vector<float> expected_values{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f, 7.f, 8.f, 9.f};
  const std::vector<int64_t> expected_unflatten_dims{2, 3};

  OpTester tester("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
  tester.AddInput<float>("input", {2, 3, 3}, padded_values);
  tester.AddInput<int64_t>("indices", {3}, kept_rows);
  tester.AddOutput<float>("output", {3, 3}, expected_values);
  tester.AddOutput<int64_t>("unflatten_dims", {2}, expected_unflatten_dims);
  tester.Run();
}
|
||||
|
||||
// MLFloat16 data, rank-2 input: the float fixtures are converted to half
// precision before being fed to the op; the gather pattern matches the other
// rank-2 cases (odd flat positions 1..11 of a [5, 3] tensor).
TEST(FlattenAndUnpadTest, MLFloat16Type2D) {
  const std::vector<float> padded_values{0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
                                         0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};
  const std::vector<int64_t> kept_positions{1, 3, 5, 7, 9, 11};

  const std::vector<float> expected_values{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f};
  const std::vector<int64_t> expected_unflatten_dims{5, 3};

  // Half-precision copies of the float fixtures.
  std::vector<MLFloat16> padded_values_half(padded_values.size());
  ConvertFloatToMLFloat16(padded_values.data(), padded_values_half.data(),
                          static_cast<int>(padded_values.size()));
  std::vector<MLFloat16> expected_values_half(expected_values.size());
  ConvertFloatToMLFloat16(expected_values.data(), expected_values_half.data(),
                          static_cast<int>(expected_values.size()));

  OpTester tester("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
  tester.AddInput<MLFloat16>("input", {5, 3}, padded_values_half);
  tester.AddInput<int64_t>("indices", {6}, kept_positions);
  tester.AddOutput<MLFloat16>("output", {6}, expected_values_half);
  tester.AddOutput<int64_t>("unflatten_dims", {2}, expected_unflatten_dims);
  tester.Run();
}
|
||||
|
||||
// MLFloat16 data, rank-3 input: the float fixtures are converted to half
// precision; rows 1, 3 and 4 of the flattened [2, 3, 3] tensor are gathered
// into a [3, 3] output.
TEST(FlattenAndUnpadTest, MLFloat16Type3D) {
  const std::vector<float> padded_values{0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
                                         4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
  const std::vector<int64_t> kept_rows{1, 3, 4};

  const std::vector<float> expected_values{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f, 7.f, 8.f, 9.f};
  const std::vector<int64_t> expected_unflatten_dims{2, 3};

  // Half-precision copies of the float fixtures.
  std::vector<MLFloat16> padded_values_half(padded_values.size());
  ConvertFloatToMLFloat16(padded_values.data(), padded_values_half.data(),
                          static_cast<int>(padded_values.size()));
  std::vector<MLFloat16> expected_values_half(expected_values.size());
  ConvertFloatToMLFloat16(expected_values.data(), expected_values_half.data(),
                          static_cast<int>(expected_values.size()));

  OpTester tester("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
  tester.AddInput<MLFloat16>("input", {2, 3, 3}, padded_values_half);
  tester.AddInput<int64_t>("indices", {3}, kept_rows);
  tester.AddOutput<MLFloat16>("output", {3, 3}, expected_values_half);
  tester.AddOutput<int64_t>("unflatten_dims", {2}, expected_unflatten_dims);
  tester.Run();
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -17,14 +17,11 @@ TEST(PadAndUnflattenTest, FloatType1D) {
|
|||
std::vector<float> output = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
|
||||
0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};
|
||||
|
||||
std::vector<int64_t> full_flatten_dims = {15};
|
||||
|
||||
OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain);
|
||||
test.AddInput<float>("input", {6}, input);
|
||||
test.AddInput<int64_t>("indices", {6}, indices);
|
||||
test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
|
||||
test.AddOutput<float>("output", {5, 3}, output);
|
||||
test.AddOutput<int64_t>("full_flatten_dims", {1}, full_flatten_dims);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
|
|
@ -36,14 +33,11 @@ TEST(PadAndUnflattenTest, FloatType2D) {
|
|||
std::vector<float> output = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
|
||||
4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
|
||||
|
||||
std::vector<int64_t> full_flatten_dims = {6, 3};
|
||||
|
||||
OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain);
|
||||
test.AddInput<float>("input", {3, 3}, input);
|
||||
test.AddInput<int64_t>("indices", {3}, indices);
|
||||
test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
|
||||
test.AddOutput<float>("output", {2, 3, 3}, output);
|
||||
test.AddOutput<int64_t>("full_flatten_dims", {2}, full_flatten_dims);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
|
|
@ -55,8 +49,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type1D) {
|
|||
std::vector<float> output = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
|
||||
0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};
|
||||
|
||||
std::vector<int64_t> full_flatten_dims = {15};
|
||||
|
||||
std::vector<MLFloat16> input_half;
|
||||
input_half.resize(input.size());
|
||||
ConvertFloatToMLFloat16(input.data(), input_half.data(), int(input.size()));
|
||||
|
|
@ -69,7 +61,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type1D) {
|
|||
test.AddInput<int64_t>("indices", {6}, indices);
|
||||
test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
|
||||
test.AddOutput<MLFloat16>("output", {5, 3}, output_half);
|
||||
test.AddOutput<int64_t>("full_flatten_dims", {1}, full_flatten_dims);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
|
|
@ -81,8 +72,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type2D) {
|
|||
std::vector<float> output = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
|
||||
4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
|
||||
|
||||
std::vector<int64_t> full_flatten_dims = {6, 3};
|
||||
|
||||
std::vector<MLFloat16> input_half;
|
||||
input_half.resize(input.size());
|
||||
ConvertFloatToMLFloat16(input.data(), input_half.data(), int(input.size()));
|
||||
|
|
@ -95,7 +84,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type2D) {
|
|||
test.AddInput<int64_t>("indices", {3}, indices);
|
||||
test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
|
||||
test.AddOutput<MLFloat16>("output", {2, 3, 3}, output_half);
|
||||
test.AddOutput<int64_t>("full_flatten_dims", {2}, full_flatten_dims);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -207,6 +207,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BatchScale);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PadAndUnflatten);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, FlattenAndUnpad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ScaledSum);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ResizeGrad);
|
||||
|
|
@ -462,6 +463,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BatchScale)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PadAndUnflatten)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, FlattenAndUnpad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ScaledSum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ResizeGrad)>,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,91 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/training_ops/cuda/tensor/flatten_and_unpad.h"
|
||||
#include "orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h"
|
||||
#include "core/providers/cuda/shared_inc/cuda_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
// Registers the FlattenAndUnpad kernel for kCudaExecutionProvider in kMSDomain, version 1.
// "T" is constrained to int32_t, int64_t, MLFloat16, float, double and BFloat16 tensors;
// "T_INT" is constrained to int64_t tensors only; output 1 is tagged OrtMemTypeCPUOutput.
// NOTE(review): the op schema lists tensor(int32) for T_INT as well, but only int64_t is
// registered here -- confirm whether int32 indices are meant to be supported.
ONNX_OPERATOR_KERNEL_EX(
|
||||
FlattenAndUnpad,
|
||||
kMSDomain,
|
||||
1,
|
||||
kCudaExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, MLFloat16, float, double, BFloat16>())
|
||||
.TypeConstraint("T_INT", DataTypeImpl::GetTensorType<int64_t>())
|
||||
.OutputMemoryType(OrtMemTypeCPUOutput, 1),
|
||||
FlattenAndUnpad);
|
||||
|
||||
// Put implementation in the anonymous namespace to avoid name collision in the global namespace.
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
struct FlattenAndUnpadFunctor {
|
||||
void operator()(cudaStream_t stream,
|
||||
const int64_t output_element_count,
|
||||
const fast_divmod output_element_stride_fdm,
|
||||
const int64_t index_value_upper_bound,
|
||||
const Tensor& input_tensor,
|
||||
const Tensor& indices_tensor,
|
||||
Tensor& output_tensor) const {
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
const CudaT* input_data = reinterpret_cast<const CudaT*>(input_tensor.Data<T>());
|
||||
|
||||
FlattenAndUnpadImpl<CudaT>(stream, output_element_count, output_element_stride_fdm, index_value_upper_bound,
|
||||
input_data, indices_tensor.Data<int64_t>(),
|
||||
reinterpret_cast<CudaT*>(output_tensor.MutableData<T>()));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Computes FlattenAndUnpad.
//   input 0:  data tensor with at least 2 dimensions, [M1, M2, d2, ..., dN]
//   input 1:  1-D tensor of int64 indices
//   output 0: [indices.size, d2, ..., dN], gathered by the CUDA implementation
//   output 1: 1-D int64 tensor holding the first two input dimensions [M1, M2]
// Throws (via ORT_ENFORCE) when the input ranks do not satisfy the above.
Status FlattenAndUnpad::ComputeInternal(OpKernelContext* context) const {
  const Tensor* input_tensor = context->Input<Tensor>(0);
  const Tensor* indices_tensor = context->Input<Tensor>(1);
  // The trailing argument is appended verbatim to the message, so the text
  // has to end with a separator for the reported value to be readable.
  ORT_ENFORCE(input_tensor->Shape().NumDimensions() >= 2,
              "input_tensor tensor must have at least 2 dimensions, got ",
              input_tensor->Shape().NumDimensions());
  ORT_ENFORCE(indices_tensor->Shape().NumDimensions() == 1,
              "indices_tensor tensor must be 1-D, got rank ",
              indices_tensor->Shape().NumDimensions());

  const auto& input_shape = input_tensor->Shape();

  // Output 0 keeps every dimension after the first two; the leading dimension
  // becomes the number of indices. element_stride is the element count of one
  // gathered row (product of d2..dN, or 1 when there are none).
  std::vector<int64_t> output_shape_vec;
  output_shape_vec.reserve(input_shape.NumDimensions() - 1);
  output_shape_vec.push_back(indices_tensor->Shape()[0]);
  int64_t element_stride = 1;
  for (size_t i = 2; i < input_shape.NumDimensions(); ++i) {
    output_shape_vec.push_back(input_shape[i]);
    element_stride *= input_shape[i];
  }

  fast_divmod output_element_stride_fdm(static_cast<int>(element_stride));
  auto output_shape = TensorShape(output_shape_vec);
  Tensor* output_tensor = context->Output(0, output_shape);

  // Exclusive upper bound for index values: the size of the flattened first two dims.
  const int64_t index_value_upper_bound = input_shape[0] * input_shape[1];

  utils::MLTypeCallDispatcher<int32_t, int64_t, float, MLFloat16, double, BFloat16>
      t_disp(input_tensor->GetElementType());
  t_disp.Invoke<FlattenAndUnpadFunctor>(Stream(context),
                                        output_shape.Size(),
                                        output_element_stride_fdm,
                                        index_value_upper_bound,
                                        *input_tensor,
                                        *indices_tensor,
                                        *output_tensor);

  // Output 1 always has exactly two entries, [M1, M2]; write them directly
  // rather than staging them in a temporary vector and TensorShape.
  constexpr int64_t kUnflattenDimCount = 2;
  Tensor* unflatten_dims_tensor = context->Output(1, {kUnflattenDimCount});
  int64_t* unflatten_dims_data = unflatten_dims_tensor->MutableData<int64_t>();
  unflatten_dims_data[0] = input_shape[0];
  unflatten_dims_data[1] = input_shape[1];

  return Status::OK();
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/cuda/cuda_kernel.h"
|
||||
#include "core/providers/common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
class FlattenAndUnpad final : public CudaKernel {
|
||||
public:
|
||||
FlattenAndUnpad(const OpKernelInfo& info) : CudaKernel(info) {
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h"
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
constexpr int kBlockSize = 256;
|
||||
constexpr int kNumUnroll = 4;
|
||||
|
||||
// Gathers the input rows selected by `indices_data` into `output_data`.
//
// Each output element `li` in [0, N) is split by `output_element_stride_fdm` into
// (row_index, col_index) and copied from
//   input_data[indices_data[row_index] * stride + col_index].
// Every thread handles up to kNumUnroll consecutive output elements: all values are
// first loaded into a local array and then written out in a second pass.
//
// `index_value_upper_bound` is only used by the debug assert that validates the
// gathered index; the assert now also rejects negative indices, which would
// otherwise read before the start of `input_data`.
template <typename T>
__global__ void ExtractIputWithIndexKernel(const CUDA_LONG N,
                                           const fast_divmod output_element_stride_fdm,
                                           const int64_t index_value_upper_bound,
                                           const T* input_data,
                                           const int64_t* indices_data,
                                           T* output_data) {
  CUDA_LONG idx = blockDim.x * blockIdx.x + threadIdx.x;
  CUDA_LONG id = idx * kNumUnroll;

  T input[kNumUnroll];
  if (id < N) {
#pragma unroll
    for (int i = 0; i < kNumUnroll; ++i) {
      CUDA_LONG li = id + i;
      if (li < N) {
        int row_index, col_index;
        output_element_stride_fdm.divmod(li, row_index, col_index);
        // Read the index once and reuse it for both the check and the address.
        const int64_t source_row = indices_data[row_index];
        assert(source_row >= 0 && source_row < index_value_upper_bound);
        input[i] = input_data[source_row * output_element_stride_fdm.d_ + col_index];
      }
    }
  }

#pragma unroll
  for (int i = 0; i < kNumUnroll; ++i) {
    CUDA_LONG li = id + i;
    if (li < N) {
      output_data[li] = input[i];
    }
  }
}
|
||||
|
||||
// Launches ExtractIputWithIndexKernel on `stream` to fill `output_data`.
//
// total_element_count        number of elements to write to `output_data`.
// output_element_stride_fdm  divisor used to split a flat output offset into
//                            (row, column); its `d_` is the row stride.
// index_value_upper_bound    exclusive upper bound checked (debug assert) for
//                            every value read from `indices_data`.
// input_data / indices_data / output_data
//                            buffers dereferenced by the kernel on the device.
//
// Each thread processes kNumUnroll elements and each block has kBlockSize threads.
// NOTE(review): the count is narrowed to CUDA_LONG for the kernel - confirm callers
// never exceed that range.
template <typename T>
void FlattenAndUnpadImpl(cudaStream_t stream,
                         const int64_t total_element_count,
                         const fast_divmod output_element_stride_fdm,
                         const int64_t index_value_upper_bound,
                         const T* input_data,
                         const int64_t* indices_data,
                         T* output_data) {
  // Nothing to copy; a grid of zero blocks is an invalid launch configuration.
  if (total_element_count <= 0) {
    return;
  }

  const int blocksPerGrid = static_cast<int>(CeilDiv(total_element_count, kBlockSize * kNumUnroll));
  ExtractIputWithIndexKernel<T><<<blocksPerGrid, kBlockSize, 0, stream>>>(
      static_cast<CUDA_LONG>(total_element_count),
      output_element_stride_fdm,
      index_value_upper_bound,
      input_data,
      indices_data,
      output_data);
}
|
||||
|
||||
// Explicitly instantiates FlattenAndUnpadImpl for one element type so the
// definition in this translation unit can be linked from callers that only see
// the declaration.
#define FLATTEN_AND_UNPAD_IMPL(T)                                                      \
  template void FlattenAndUnpadImpl<T>(cudaStream_t stream,                            \
                                       const int64_t total_element_count,              \
                                       const fast_divmod output_element_stride_fdm,    \
                                       const int64_t index_value_upper_bound,          \
                                       const T* input_data,                            \
                                       const int64_t* indices_data,                    \
                                       T* output_data);

FLATTEN_AND_UNPAD_IMPL(float)
FLATTEN_AND_UNPAD_IMPL(double)
FLATTEN_AND_UNPAD_IMPL(half)
FLATTEN_AND_UNPAD_IMPL(BFloat16)
FLATTEN_AND_UNPAD_IMPL(int32_t)
FLATTEN_AND_UNPAD_IMPL(int64_t)

// Undefine the macro that was actually defined above so it does not leak further.
#undef FLATTEN_AND_UNPAD_IMPL
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include "core/providers/rocm/shared_inc/rocm_utils.h"
|
||||
#else
|
||||
#include "core/providers/cuda/shared_inc/cuda_utils.h"
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template <typename T>
|
||||
void FlattenAndUnpadImpl(cudaStream_t stream,
|
||||
const int64_t total_element_count,
|
||||
const fast_divmod output_element_stride_fdm,
|
||||
const int64_t index_value_upper_bound,
|
||||
const T* input_data,
|
||||
const int64_t* indices_data,
|
||||
T* output_data);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -17,8 +17,7 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
.TypeConstraint("T", BuildKernelDefConstraints<MLFloat16, float, double, BFloat16>())
|
||||
.TypeConstraint("T_INT", DataTypeImpl::GetTensorType<int64_t>())
|
||||
.TypeConstraint("T_INDEX", DataTypeImpl::GetTensorType<int64_t>())
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 2)
|
||||
.OutputMemoryType(OrtMemTypeCPUOutput, 1),
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 2),
|
||||
PadAndUnflatten);
|
||||
|
||||
// Put implementation in the anonymous namespace to avoid name collision in the global namespace.
|
||||
|
|
@ -63,14 +62,11 @@ Status PadAndUnflatten::ComputeInternal(OpKernelContext* context) const {
|
|||
output_shape_vec.push_back(dims_ptr[0]);
|
||||
output_shape_vec.push_back(dims_ptr[1]);
|
||||
|
||||
std::vector<int64_t> full_size_flatten_shape_vec;
|
||||
const int64_t flatten_dim_factor = dims_ptr[0] * dims_ptr[1];
|
||||
full_size_flatten_shape_vec.push_back(flatten_dim_factor);
|
||||
|
||||
int64_t element_stride = 1;
|
||||
for (size_t i = 1; i < input_shape.NumDimensions(); ++i) {
|
||||
output_shape_vec.push_back(input_shape[i]);
|
||||
full_size_flatten_shape_vec.push_back(input_shape[i]);
|
||||
element_stride *= input_shape[i];
|
||||
}
|
||||
|
||||
|
|
@ -87,11 +83,6 @@ Status PadAndUnflatten::ComputeInternal(OpKernelContext* context) const {
|
|||
*indices_tensor,
|
||||
*output_tensor);
|
||||
|
||||
// Set input shape output tensor.
|
||||
size_t rank = full_size_flatten_shape_vec.size();
|
||||
Tensor* input_shape_tensor = context->Output(1, {static_cast<int>(rank)});
|
||||
TensorShape(full_size_flatten_shape_vec).CopyDims(input_shape_tensor->MutableData<int64_t>(), rank);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ void PadAndUnflattenImpl(cudaStream_t stream,
|
|||
output_data);
|
||||
}
|
||||
|
||||
#define SPECIALIZED_RESTORE_FROM_MASK_IMPL(T) \
|
||||
#define PAD_AND_UNFLATTEN_IMPL(T) \
|
||||
template void PadAndUnflattenImpl<T>(cudaStream_t stream, \
|
||||
const int64_t total_element_count, \
|
||||
const fast_divmod output_element_stride_fdm, \
|
||||
|
|
@ -70,12 +70,12 @@ void PadAndUnflattenImpl(cudaStream_t stream,
|
|||
const int64_t* indices_data, \
|
||||
T* output_data);
|
||||
|
||||
SPECIALIZED_RESTORE_FROM_MASK_IMPL(float)
|
||||
SPECIALIZED_RESTORE_FROM_MASK_IMPL(double)
|
||||
SPECIALIZED_RESTORE_FROM_MASK_IMPL(half)
|
||||
SPECIALIZED_RESTORE_FROM_MASK_IMPL(BFloat16)
|
||||
PAD_AND_UNFLATTEN_IMPL(float)
|
||||
PAD_AND_UNFLATTEN_IMPL(double)
|
||||
PAD_AND_UNFLATTEN_IMPL(half)
|
||||
PAD_AND_UNFLATTEN_IMPL(BFloat16)
|
||||
|
||||
#undef SPECIALIZED_RESTORE_FROM_MASK_IMPL
|
||||
#undef PAD_AND_UNFLATTEN_IMPL
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -187,6 +187,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_BFloat16, ReduceAllL2);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PadAndUnflatten);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, FlattenAndUnpad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ResizeGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ResizeGrad);
|
||||
|
|
@ -390,6 +391,7 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_BFloat16, ReduceAllL2)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PadAndUnflatten)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, FlattenAndUnpad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ResizeGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ResizeGrad)>,
|
||||
|
|
|
|||
Loading…
Reference in a new issue