Use PadAndUnflatten to replace GatherGrad for restore (#16429)

### Use PadAndUnflatten to replace GatherGrad for restore




### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
pengwa 2023-06-27 15:07:20 +08:00 committed by GitHub
parent ae6da03438
commit 403bebfb51
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 490 additions and 198 deletions

View file

@ -752,14 +752,14 @@ IMPLEMENT_GRADIENT_BUILDER(GetGatherGradient) {
SrcNodeAttributes())};
}
IMPLEMENT_GRADIENT_BUILDER(GetGatherGradGradient) {
// TODO: Strictly speaking, GatherGrad's gradient is not alway Gather when the indices have repeated values.
// Since GatherGrad in foward path is only used by embed sparsity feature in which case the indices are unique,
// we can safely use Gather here. But we will adress this issue as soon as possible.
IMPLEMENT_GRADIENT_BUILDER(GetPadAndUnflattenGradient) {
return std::vector<NodeDef>{
NodeDef(OpDef("Reshape"),
{GO(0), O(1)},
{IA("GO_reshaped")}),
NodeDef(OpDef{"Gather", kOnnxDomain, 1},
{GO(0), I(1)},
{GI(2)},
{IA("GO_reshaped"), I(1)},
{GI(0)},
SrcNodeAttributes())};
}

View file

@ -39,7 +39,7 @@ DECLARE_GRADIENT_BUILDER(GetPoolGradient)
DECLARE_GRADIENT_BUILDER(GetAveragePoolGradient)
DECLARE_GRADIENT_BUILDER(GetMaxPoolGradient)
DECLARE_GRADIENT_BUILDER(GetGatherGradient)
DECLARE_GRADIENT_BUILDER(GetGatherGradGradient)
DECLARE_GRADIENT_BUILDER(GetPadAndUnflattenGradient)
DECLARE_GRADIENT_BUILDER(GetShrunkenGatherGradient)
DECLARE_GRADIENT_BUILDER(GetConvGradient)
DECLARE_GRADIENT_BUILDER(GetUnsqueezeGradient)

View file

@ -70,7 +70,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Gemm", GetGemmGradient);
REGISTER_GRADIENT_BUILDER("MaxPool", GetMaxPoolGradient);
REGISTER_GRADIENT_BUILDER("Gather", GetGatherGradient);
REGISTER_GRADIENT_BUILDER("GatherGrad", GetGatherGradGradient);
REGISTER_GRADIENT_BUILDER("PadAndUnflatten", GetPadAndUnflattenGradient);
REGISTER_GRADIENT_BUILDER("ShrunkenGather", GetShrunkenGatherGradient);
REGISTER_GRADIENT_BUILDER("Conv", GetConvGradient);
REGISTER_GRADIENT_BUILDER("Squeeze", GetSqueezeGradient);

View file

@ -4571,6 +4571,45 @@ Return true if all elements are true and false otherwise.
updateOutputShape(ctx, 6, {num_directions, three * hidden_size});
}
});
ONNX_CONTRIB_OPERATOR_SCHEMA(PadAndUnflatten)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetDoc(
"PadAndUnflatten operator pads zero on the first axis, and unflatten the axis into two axes according"
"to given unflatten_dims. This is used by padding elimination graph transformers."
"For each index in indices, the corresponding value in output comes from input."
"For other indices, the corresponding value in output will be padded to zero."
"The indices don't allow duplicated index values, otherwise, though there is no runtime check"
"(in case of performance concern), the behaviour of output is undefined."
"An example:"
" input: [[1, 2, 3, 4], [5, 6, 7, 8]], shape is [2, 4]"
" indices: [0, 5], shape is [2]"
" unflatten_dims: [2, 3], shape is [2]"
" output: [[[1, 2, 3, 4], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [5, 6, 7, 8]]],"
" shape is [2, 3, 4]"
" flatten_output_shape: [6, 4], shape is [2]")
.Input(0, "input", "input data of rank N, shape is [d1, d2, ..., dN]", "T")
.Input(1, "indices", "1D Tensor of int32/int64 indices, shape is [d1], each element's value ranges in [0, M1*M2).",
"T_INDEX")
.Input(2, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT")
.Output(0, "output", "output data of rank N+1, [M1, M2, d2, ..., dN]", "T")
.Output(1, "flatten_output_shape", "1D tensor with output shape, [M1*M2, d2, ..., dN]", "T_INT")
.TypeConstraint(
"T_INT",
{"tensor(int32)", "tensor(int64)"},
"Constrain shape to integer tensors.")
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
"Constrain input and output types to float tensors.")
.TypeConstraint(
"T_INDEX",
{"tensor(int32)", "tensor(int64)"},
"Constrain indices to integer types");
}
} // namespace training

View file

@ -61,34 +61,6 @@ NodeArg* GetDimsValue(Graph& graph, NodeArg* input, NodeArg* indices_arg, Node&
return gather_out_args[0];
}
// Insert Shape + ScatterElements to get an updated shape of input with index of indices_arg updated to
// the value of update_value.
// Such as, if the indices_arg is a initializer of [0] and the original shape of input is [valid_token_count, a, b, c],
// this function will return a shape of [update_value, a, b, c]
NodeArg* UpdateShape(Graph& graph, NodeArg* input, NodeArg* update_value, NodeArg* indices_arg, Node& node) {
InlinedVector<NodeArg*> shape_output_args{&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("shape_result"),
nullptr)};
Node& shape_node = graph.AddNode(graph.GenerateNodeName("shape"), "Shape", "", {input},
shape_output_args, nullptr, kOnnxDomain);
ORT_ENFORCE(graph.SetOpSchemaFromRegistryForNode(shape_node), "Failed to get shape for " + shape_node.Name());
shape_node.SetExecutionProviderType(node.GetExecutionProviderType());
InlinedVector<NodeArg*> scatter_input_args;
scatter_input_args.push_back(shape_output_args[0]);
scatter_input_args.push_back(indices_arg);
scatter_input_args.push_back(update_value);
InlinedVector<NodeArg*> scatter_out_args{&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("scatter_result"),
nullptr)};
Node& scatter_node = graph.AddNode(graph.GenerateNodeName("update_dim"), "ScatterElements", "", scatter_input_args,
scatter_out_args, nullptr, kOnnxDomain);
ORT_ENFORCE(graph.SetOpSchemaFromRegistryForNode(scatter_node), "Failed to update shape for " + scatter_node.Name());
scatter_node.SetExecutionProviderType(node.GetExecutionProviderType());
return scatter_out_args[0];
}
// Insert Reshape + ShrunkenGather to flatten the in_index-th input of node.
// The gather_index_arg is the indices of the elements that are not padding.
NodeArg* InsertNodesForInput(Graph& graph,
@ -125,7 +97,8 @@ NodeArg* InsertNodesForInput(Graph& graph,
InlinedVector<NodeArg*> reshape_output_args;
reshape_output_args.push_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("inputs_reshape_result"), node.MutableInputDefs()[in_index]->TypeAsProto()));
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("inputs_reshape_result"),
node.MutableInputDefs()[in_index]->TypeAsProto()));
Node* new_reshape_node = InsertIntermediateNodeOnDestInput(
graph, node,
@ -175,7 +148,7 @@ NodeArg* InsertNodesForInput(Graph& graph,
return gather_out_arg;
}
// Insert GatherGrad + Reshape to unflatten the shape of the in_index-th input of node.
// Insert PadAndUnflatten to unflatten the shape of the in_index-th input of node.
// The gathergrad_index_arg is the indices of the elements that are not padding.
// The new_shape_arg is the shape of [batch_size * seqlen, ...]
// gathergrad_index_arg and new_shape_arg are the arguments needed by GatherGrad.
@ -183,100 +156,42 @@ NodeArg* InsertNodesForOutput(Graph& graph,
Node& node,
uint32_t in_index,
NodeArg* gathergrad_index_arg,
NodeArg* new_shape_arg,
NodeArg* first_two_dims_arg,
const logging::Logger& logger) {
std::vector<int64_t> other_indices;
auto input_shape = node.InputDefs()[in_index]->Shape();
for (int k = 2; k < input_shape->dim_size(); k++) {
// When executing, Shape of node here has been flattened, so the indices should be k-1.
other_indices.push_back(int64_t(k) - 1);
}
InlinedVector<NodeArg*> pad_node_input_args;
pad_node_input_args.reserve(3);
pad_node_input_args.push_back(node.MutableInputDefs()[in_index]);
pad_node_input_args.push_back(gathergrad_index_arg);
pad_node_input_args.push_back(first_two_dims_arg);
// Construct the unflattened_shape_arg of [batch_size, seqlen, ...]
NodeArg* unflattened_shape_arg = nullptr;
if (other_indices.empty()) {
unflattened_shape_arg = first_two_dims_arg;
} else {
// If the shape size of the in_index-th input of node is larger than 2 dims, we need to concat the first two dims
// of [batch_size, seqlen] and the other dims together.
ONNX_NAMESPACE::TensorProto other_indices_const_tensor;
other_indices_const_tensor.set_name(graph.GenerateNodeArgName("other_shape"));
other_indices_const_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
other_indices_const_tensor.add_dims(other_indices.size());
other_indices_const_tensor.set_raw_data(other_indices.data(), other_indices.size() * sizeof(int64_t));
NodeArg* other_indices_arg = &graph_utils::AddInitializer(graph, other_indices_const_tensor);
NodeArg* other_dims_arg = GetDimsValue(graph, node.MutableInputDefs()[in_index], other_indices_arg, node);
InlinedVector<NodeArg*> concat_input_args;
concat_input_args.push_back(first_two_dims_arg);
concat_input_args.push_back(other_dims_arg);
InlinedVector<NodeArg*> concat_output_args{&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("concat_shape_result"),
nullptr)};
onnxruntime::NodeAttributes attributes;
attributes["axis"] = ONNX_NAMESPACE::MakeAttribute("axis", int64_t(0));
Node& concat_node = graph.AddNode(graph.GenerateNodeName("concat_shape"), "Concat", "", concat_input_args,
concat_output_args, &attributes, kOnnxDomain);
ORT_ENFORCE(graph.SetOpSchemaFromRegistryForNode(concat_node), "Failed to concat shape for " + concat_node.Name());
concat_node.SetExecutionProviderType(node.GetExecutionProviderType());
unflattened_shape_arg = concat_output_args[0];
}
InlinedVector<NodeArg*> gathergrad_input_args;
gathergrad_input_args.reserve(3);
gathergrad_input_args.push_back(new_shape_arg);
gathergrad_input_args.push_back(gathergrad_index_arg);
gathergrad_input_args.push_back(node.MutableInputDefs()[in_index]);
InlinedVector<NodeArg*> gathergrad_output_args;
gathergrad_output_args.push_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padding_recover_result"),
InlinedVector<NodeArg*> pad_node_output_args;
pad_node_output_args.push_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_result"),
nullptr));
pad_node_output_args.push_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_d1xd2_shape"),
nullptr));
Node* new_gathergrad_node = InsertIntermediateNodeOnDestInput(
graph, node,
in_index,
2,
0,
0 /* new_node_input_index*/,
0 /* new_node_output_index*/,
graph.GenerateNodeName("PaddingRecover"),
"GatherGrad",
"GatherGrad node to recover invalid tokens.",
gathergrad_input_args,
gathergrad_output_args,
"PadAndUnflatten",
"PadAndUnflatten node to recover invalid tokens.",
pad_node_input_args,
pad_node_output_args,
{},
kMSDomain,
logger);
new_gathergrad_node->SetExecutionProviderType(node.GetExecutionProviderType());
auto gathergrad_out_arg = new_gathergrad_node->MutableOutputDefs()[0];
InlinedVector<NodeArg*> reshape_input_args;
reshape_input_args.push_back(gathergrad_out_arg);
reshape_input_args.push_back(unflattened_shape_arg);
InlinedVector<NodeArg*> reshape_output_args{&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("reshape_result"),
nullptr)};
Node* new_reshape_node = InsertIntermediateNodeOnDestInput(
graph, node,
in_index,
0,
0,
graph.GenerateNodeName("RecoverShape"),
"Reshape",
"Reshape node to recover invalid tokens.",
reshape_input_args,
reshape_output_args,
{},
kOnnxDomain,
logger);
new_reshape_node->SetExecutionProviderType(node.GetExecutionProviderType());
return new_reshape_node->MutableOutputDefs()[0];
return new_gathergrad_node->MutableOutputDefs()[0];
}
// Iterate the subgraph beginning from the start_node, and put all node args into 'subgraph'
// Also put all candidate input nodes and cantidate output nodes of the subgraph into candidate_inputs and
// Also put all candidate input nodes and candidate output nodes of the subgraph into candidate_inputs and
// candidate_outputs respectively.
void IterateSubgraphFromNode(Graph& graph,
Node* start_node,
@ -368,7 +283,7 @@ void IterateSubgraphFromNode(Graph& graph,
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "MatMul", {1, 9, 13})) {
if (subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end()) {
// If shape of [batch_size, seqlen, ...] is propagated from the first argument of MatMul.
// The dim size of the first argument must larger than 2 to propagete the first two dims to the output.
// The dim size of the first argument must be larger than 2 to propagate the first two dims to the output.
// Or else the first two dims of the output will not be [batch_size, seqlen] and this MatMul will be added
// to candidate_outputs as the output of the subgraph.
if (cur->InputDefs()[0]->Shape()->dim_size() > 2) {
@ -376,17 +291,17 @@ void IterateSubgraphFromNode(Graph& graph,
PushAllOutputNode(graph, to_visit, cur, visited);
} else {
LOG_DEBUG_INFO(logger,
"PaddingElimination::dim size of left input of matmul smaller than 3 and \
this matmul would be output of subgraph.");
"PaddingElimination::dim size of left input of MatMul smaller than 3 and \
this MatMul would be the output of the subgraph.");
candidate_outputs.insert(cur);
continue;
}
} else if (subgraph.find(cur->MutableInputDefs()[1]) != subgraph.end()) {
LOG_DEBUG_INFO(logger, "PaddingElimination::right edge of matmul would not included.");
LOG_DEBUG_INFO(logger, "PaddingElimination::right edge of MatMul would not included.");
candidate_outputs.insert(cur);
continue;
} else {
ORT_THROW("PaddingElimination::found matmul node without input in subgraph.");
ORT_THROW("PaddingElimination::found MatMul node without input in subgraph.");
}
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "PythonOp", {1}, kMSDomain)) {
if (subgraph.find(cur->MutableInputDefs()[0]) == subgraph.end()) {
@ -451,7 +366,8 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
") due to embedding input is not in the sparse embedding input list.");
continue;
}
const ONNX_NAMESPACE::TensorProto* padding_initializer = graph_utils::GetConstantInitializer(graph, node.InputDefs()[2]->Name());
const ONNX_NAMESPACE::TensorProto* padding_initializer =
graph_utils::GetConstantInitializer(graph, node.InputDefs()[2]->Name());
if (padding_initializer != nullptr &&
padding_initializer->dims_size() == 0 &&
((padding_initializer->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT32) ||
@ -526,18 +442,6 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
}
}
std::vector<int64_t> first_indices;
first_indices.push_back(0);
ONNX_NAMESPACE::TensorProto first_indice_const_tensor;
first_indice_const_tensor.set_name(graph.GenerateNodeArgName("indices"));
first_indice_const_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
first_indice_const_tensor.add_dims(first_indices.size());
first_indice_const_tensor.set_raw_data(first_indices.data(), first_indices.size() * sizeof(int64_t));
NodeArg* first_index_arg = &graph_utils::AddInitializer(graph, first_indice_const_tensor);
// Get the first dim value of flattened input_ids which is batch_size * seq_len
NodeArg* first_dim = GetDimsValue(graph, reshape_output_args[0], first_index_arg, *embedding_node);
std::vector<int64_t> first_two_indices{0, 1};
ONNX_NAMESPACE::TensorProto first_two_indices_const_tensor;
first_two_indices_const_tensor.set_name(graph.GenerateNodeArgName("first_two_indices"));
@ -553,11 +457,7 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
for (const auto& node : candidate_outputs) {
for (uint32_t i = 0; i < node->InputDefs().size(); ++i) {
if (subgraph.find(node->MutableInputDefs()[i]) != subgraph.end()) {
// Get a shape of the i-th input of the node with first index updated to value of first_dim
// which is batch_size * seq_len. This shape arg will be used as the shape input of GatherGrad
NodeArg* shape_arg_for_gather_grad = UpdateShape(
graph, node->MutableInputDefs()[i], first_dim, first_index_arg, *node);
InsertNodesForOutput(graph, *node, i, squeeze_out_arg, shape_arg_for_gather_grad, first_two_dims_arg, logger);
InsertNodesForOutput(graph, *node, i, squeeze_out_arg, first_two_dims_arg, logger);
handled_output_count++;
}
}

View file

@ -16,15 +16,18 @@ namespace onnxruntime {
* @Class PaddingElimination
*
* @brief Graph transformer that eliminates unnecessary padding computation caused by embedding sparsity.
*
* In transformer trainings, input_ids are usually padded to the same length, which is the max sequence length,
* so its shape is [batch_size, sequence_length] or [sequence_length, batch_size]. This graph transformer
* tries to MERGE the leading two dimensions and REMOVE the padding on the merged
* dimension, i.e, [batch_size, sequence_length, ...] -> [batch_size * sequence_length, ...] ->
*
* This transformer is implemented in the following steps:
* 1. Iterate the graph and find the Embedding node that matches these requirements:
* (1) Its 2nd input is a graph input and its rank > 2 with the first two dimensions are dim_params which are
* actually batch_size and sequence_length.
* Note: Now only support the case of the first two dimensions to merged and remove the padding on the merged
* dimension, i.e, [batch_size, sequence_length, ...] -> [batch_size * sequence_length, ...] ->
* [valid_token, ... ]. In the future, we may support the case of any two consecutive dimensions to merged,
* such as [..., batch_size, sequence_length, ...].
* (2) Its 3nd input is a scalar constant initializer which is the padding idx that should >= 0.
* 1.1 The 2nd input is a graph input and its rank > 2, with the first two dimensions, are:
* [batch_size, sequence_length]. Both dimensions can be symbolic or concrete dim values.
* 1.2 The 3rd input(padding idx) is a scalar constant initializer, and should >= 0.
* 2. Append embedding node in node_to_scan_list.
* Iterate the node_to_scan_list, for each node,
* 2.1 Check if it is supported for pad elimination (from a pre-defined op list). If no, record this node as output
@ -42,6 +45,8 @@ namespace onnxruntime {
* This is needed to ensure not to affect subsequent computations
*
* For example, given the following graph:
* 1. `input_0` is a tensor that is an in-direct output of ATen embedding node.
* 2. `input_1` is a tensor that is NOT a direct or in-direct output of ATen embedding node.
*
* embed.weight input_ids [batch_size, seq_length] padding_idx [1] scale_grad_by_freq sparse
* \ \ / / /
@ -49,11 +54,14 @@ namespace onnxruntime {
* \ \ / / /
* \_________________\_________________________/________________/______________________/
* |
* Aten:embedding
* ATen:embedding
* |
* |
* input |
* \
* - - - - - - - - - - - -|
* | |
* input_0 | input_1
* \ | /
* \__________ | ___________/
* \ | /
* Subgraph
*
* |
@ -83,40 +91,36 @@ namespace onnxruntime {
* \______________________\________________________________/__________________/________________/
* |
* Aten:embedding
* _ _ _ _ _ __ _ _ _ __ _ _|
* / |
* input_node |
* \ [batch_size, seq_length] |
* \ |
* \ [-1] |
* \ / |
* Reshape (valid_token_index) |
* \ / |
* ShrunkenGather | shape:[valid_token, ...]
* \ |
* shape:[valid_token] \ |
* \ |
* candidate_input_node |
* \ |
* \ |
* - - - - - - - - - - - - - - - - - - - - |
* | |
* input_0 | input_1
* \ [batch_size, seq_length, ...] | |
* \ | [batch_size, seq_length, ...]
* \ [-1] | |
* \ / | |
* Reshape (valid_token_index) | Reshape (valid_token_index)
* \ / | \ /
* ShrunkenGather shape:[valid_token, ...] ShrunkenGather
* \ | /
* shape:[valid_token, ...] \ | /
* \ | /
* candidate_input_node | candidate_input_node
* \ | /
* \ | /
*
* Subgraph
*
* |
* | shape:[valid_token]
* | shape:[valid_token, ...]
* |
* | (valid_token_index)
* | / ________________ (unflatten_dims), shape:[2],
* | / / value:[batch_size, seq_length]
* | / /
* PadAndUnflatten
* |
* [batch_size*seq_length] (valid_token_index) |
* \ | /
* \ | /
* \ | /
*
* GatherGrad
* |
* Reshape
* |
* | [batch_size, valid_token]
* candidate_output_node
* | [batch_size, seq_length, ...]
* candidate_output_node
*
*
*

View file

@ -177,8 +177,12 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
transformers.emplace_back(std::make_unique<UpStreamReshapeGraphTransformer>(compatible_eps));
transformers.emplace_back(std::make_unique<InsertGatherBeforeSceLoss>(compatible_eps,
config.sparse_label_input_names));
#if defined(USE_CUDA) || defined(USE_ROCM)
// Put this under CUDA/ROCM guard as it depends on PadAndUnflatten CUDA/ROCM kernel.
// Once we have a CPU kernel for PadAndUnflatten, we can remove the guard.
transformers.emplace_back(std::make_unique<PaddingElimination>(compatible_eps,
config.sparse_embedding_input_names));
#endif
}
} break;

View file

@ -1901,23 +1901,6 @@ TEST(GradientCheckerTest, GatherGrad) {
}
}
TEST(GradientCheckerTest, GatherGradGrad) {
float max_error;
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{"GatherGrad", kMSDomain, 1};
TensorInfo shape_info({2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
TensorInfo indices_info({2, 2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
TensorInfo x_info({2, 2, 3});
std::vector<std::vector<float>> x_datas = {{6, 3}, {3, 5, 0, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}};
TensorShape y_shape{6, 3};
int64_t axis = 0;
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {shape_info, indices_info, x_info}, {y_shape}, &max_error,
x_datas, {MakeAttribute("axis", axis)}));
EXPECT_IS_TINY(max_error);
}
void TestDropoutOp(float ratio, TensorShape& x_shape, bool default_ratio = true) {
OpTester test("Dropout", 12, kOnnxDomain, false);
if (default_ratio) ratio = 0.5f;
@ -3016,6 +2999,34 @@ TEST(GradientCheckerTest, TriluGrad) {
}
}
// TODO (enable once found why it fails on ROCM)
#if defined(USE_CUDA)
TEST(GradientCheckerTest, PadAndUnflattenGrad) {
float max_error;
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{"PadAndUnflatten", kMSDomain, 1};
TensorInfo shape_info({2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
TensorInfo indices_info({4}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
TensorInfo x_info({4, 3});
std::vector<std::vector<float>> x_datas = {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 5, 0, 1}, {5, 2}};
TensorInfo padded_out_info({5, 2, 3}, true);
TensorInfo out_shape_info({2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
#ifdef USE_CUDA
execution_providers.emplace_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
execution_providers.emplace_back(DefaultRocmExecutionProvider());
#endif
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info, indices_info, shape_info},
{padded_out_info, out_shape_info}, &max_error,
x_datas, {}, true, false, &execution_providers));
EXPECT_IS_TINY(max_error);
}
#endif
} // namespace test
} // namespace onnxruntime

View file

@ -5886,12 +5886,12 @@ def test_ops_for_padding_elimination(test_cases):
assert len([node.op_type for node in training_model.graph.node if node.op_type == "Sub"]) == 1
assert len([node.op_type for node in training_model.graph.node if node.op_type == "NonZero"]) == 1
assert len([node.op_type for node in training_model.graph.node if node.op_type == "Squeeze"]) == 1
assert len([node.op_type for node in training_model.graph.node if node.op_type == "GatherGrad"]) == 1
assert len([node.op_type for node in training_model.graph.node if node.op_type == "PadAndUnflatten"]) == 1
if case == 2:
assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 2
else:
assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 1
gathergrad_node = [node for node in training_model.graph.node if node.op_type == "GatherGrad"][0]
gathergrad_node = [node for node in training_model.graph.node if node.op_type == "PadAndUnflatten"][0]
def find_input_node_type(model, arg):
result = []
@ -6057,5 +6057,5 @@ def test_e2e_padding_elimination():
training_model = ort_model._torch_module._execution_manager(True)._onnx_models.optimized_model
assert "ShrunkenGather" in [node.op_type for node in training_model.graph.node]
assert "GatherGrad" in [node.op_type for node in training_model.graph.node]
assert "PadAndUnflatten" in [node.op_type for node in training_model.graph.node]
del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"]

View file

@ -0,0 +1,105 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "test/common/tensor_op_test_utils.h"
#include "test/providers/provider_test_utils.h"
namespace onnxruntime {
namespace test {
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(PadAndUnflattenTest, FloatType1D) {
std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f};
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11};
std::vector<int64_t> unflatten_dims = {5, 3};
std::vector<float> output = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};
std::vector<int64_t> full_flatten_dims = {15};
OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain);
test.AddInput<float>("input", {6}, input);
test.AddInput<int64_t>("indices", {6}, indices);
test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.AddOutput<float>("output", {5, 3}, output);
test.AddOutput<int64_t>("full_flatten_dims", {1}, full_flatten_dims);
test.Run();
}
TEST(PadAndUnflattenTest, FloatType2D) {
std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f, 7.f, 8.f, 9.f};
std::vector<int64_t> indices = {1, 3, 4};
std::vector<int64_t> unflatten_dims = {2, 3};
std::vector<float> output = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
std::vector<int64_t> full_flatten_dims = {6, 3};
OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain);
test.AddInput<float>("input", {3, 3}, input);
test.AddInput<int64_t>("indices", {3}, indices);
test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.AddOutput<float>("output", {2, 3, 3}, output);
test.AddOutput<int64_t>("full_flatten_dims", {2}, full_flatten_dims);
test.Run();
}
TEST(PadAndUnflattenTest, MLFloat16Type1D) {
std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f};
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11};
std::vector<int64_t> unflatten_dims = {5, 3};
std::vector<float> output = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};
std::vector<int64_t> full_flatten_dims = {15};
std::vector<MLFloat16> input_half;
input_half.resize(input.size());
ConvertFloatToMLFloat16(input.data(), input_half.data(), int(input.size()));
std::vector<MLFloat16> output_half;
output_half.resize(output.size());
ConvertFloatToMLFloat16(output.data(), output_half.data(), int(output.size()));
OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain);
test.AddInput<MLFloat16>("input", {6}, input_half);
test.AddInput<int64_t>("indices", {6}, indices);
test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.AddOutput<MLFloat16>("output", {5, 3}, output_half);
test.AddOutput<int64_t>("full_flatten_dims", {1}, full_flatten_dims);
test.Run();
}
TEST(PadAndUnflattenTest, MLFloat16Type2D) {
std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f, 7.f, 8.f, 9.f};
std::vector<int64_t> indices = {1, 3, 4};
std::vector<int64_t> unflatten_dims = {2, 3};
std::vector<float> output = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
std::vector<int64_t> full_flatten_dims = {6, 3};
std::vector<MLFloat16> input_half;
input_half.resize(input.size());
ConvertFloatToMLFloat16(input.data(), input_half.data(), int(input.size()));
std::vector<MLFloat16> output_half;
output_half.resize(output.size());
ConvertFloatToMLFloat16(output.data(), output_half.data(), int(output.size()));
OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain);
test.AddInput<MLFloat16>("input", {3, 3}, input_half);
test.AddInput<int64_t>("indices", {3}, indices);
test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.AddOutput<MLFloat16>("output", {2, 3, 3}, output_half);
test.AddOutput<int64_t>("full_flatten_dims", {2}, full_flatten_dims);
test.Run();
}
#endif
} // namespace test
} // namespace onnxruntime

View file

@ -197,6 +197,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inpl
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuant);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PadAndUnflatten);
// the kernels within the following ifdef are not included in a build with
// --enable_training_ops but without --enable_training
@ -437,7 +438,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
kCudaExecutionProvider, kMSDomain, 1, float, FakeQuant)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PadAndUnflatten)>,
// the kernels within the following ifdef are not included in a build with
// --enable_training_ops but without --enable_training
#ifdef ENABLE_TRAINING

View file

@ -0,0 +1,99 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "orttraining/training_ops/cuda/tensor/pad_and_unflatten.h"
#include "orttraining/training_ops/cuda/tensor/pad_and_unflatten_impl.h"
#include "core/providers/cuda/shared_inc/cuda_utils.h"
namespace onnxruntime {
namespace cuda {
ONNX_OPERATOR_KERNEL_EX(
PadAndUnflatten,
kMSDomain,
1,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", BuildKernelDefConstraints<MLFloat16, float, double, BFloat16>())
.TypeConstraint("T_INT", DataTypeImpl::GetTensorType<int64_t>())
.TypeConstraint("T_INDEX", DataTypeImpl::GetTensorType<int64_t>())
.InputMemoryType(OrtMemTypeCPUInput, 2)
.OutputMemoryType(OrtMemTypeCPUOutput, 1),
PadAndUnflatten);
// Put implementation in the anonymous namespace to avoid name collision in the global namespace.
namespace {
template <typename T>
struct PadAndUnflattenFunctor {
void operator()(cudaStream_t stream,
const int64_t input_element_count,
const fast_divmod output_element_stride_fdm,
const int64_t index_value_upper_bound,
const Tensor& input_tensor,
const Tensor& indices_tensor,
Tensor& output_tensor) const {
typedef typename ToCudaType<T>::MappedType CudaT;
const CudaT* input_data = reinterpret_cast<const CudaT*>(input_tensor.Data<T>());
CUDA_CALL_THROW(cudaMemset(output_tensor.MutableDataRaw(), 0, output_tensor.Shape().Size() * sizeof(CudaT)));
PadAndUnflattenImpl<CudaT>(stream, input_element_count, output_element_stride_fdm, index_value_upper_bound,
input_data, indices_tensor.Data<int64_t>(),
reinterpret_cast<CudaT*>(output_tensor.MutableData<T>()));
}
};
} // namespace
Status PadAndUnflatten::ComputeInternal(OpKernelContext* context) const {
const Tensor* input_tensor = context->Input<Tensor>(0);
const Tensor* indices_tensor = context->Input<Tensor>(1);
const Tensor* unflatten_dims_tensor = context->Input<Tensor>(2); // Parse the 1-D shape tensor.
ORT_ENFORCE(unflatten_dims_tensor->Shape().NumDimensions() == 1,
"unflatten_dims_tensor tensor must be 1-D.", unflatten_dims_tensor->Shape().NumDimensions());
ORT_ENFORCE(unflatten_dims_tensor->Shape().Size() == 2,
"unflatten_dims_tensor tensor must contain 2 values.", unflatten_dims_tensor->Shape().Size());
const int64_t* dims_ptr = unflatten_dims_tensor->Data<int64_t>();
const auto& input_shape = input_tensor->Shape();
ORT_ENFORCE(input_shape[0] == indices_tensor->Shape()[0],
"The first dimension of input and indices must be the same.");
std::vector<int64_t> output_shape_vec;
output_shape_vec.push_back(dims_ptr[0]);
output_shape_vec.push_back(dims_ptr[1]);
std::vector<int64_t> full_size_flatten_shape_vec;
const int64_t flatten_dim_factor = dims_ptr[0] * dims_ptr[1];
full_size_flatten_shape_vec.push_back(flatten_dim_factor);
int64_t element_stride = 1;
for (size_t i = 1; i < input_shape.NumDimensions(); ++i) {
output_shape_vec.push_back(input_shape[i]);
full_size_flatten_shape_vec.push_back(input_shape[i]);
element_stride *= input_shape[i];
}
fast_divmod output_element_stride_fdm(static_cast<int>(element_stride));
auto output_shape = TensorShape(output_shape_vec);
Tensor* output_tensor = context->Output(0, output_shape);
utils::MLTypeCallDispatcher<float, MLFloat16, double, BFloat16> t_disp(input_tensor->GetElementType());
t_disp.Invoke<PadAndUnflattenFunctor>(Stream(context),
input_shape.Size(),
output_element_stride_fdm,
flatten_dim_factor,
*input_tensor,
*indices_tensor,
*output_tensor);
// Set input shape output tensor.
size_t rank = full_size_flatten_shape_vec.size();
Tensor* input_shape_tensor = context->Output(1, {static_cast<int>(rank)});
TensorShape(full_size_flatten_shape_vec).CopyDims(input_shape_tensor->MutableData<int64_t>(), rank);
return Status::OK();
}
} // namespace cuda
} // namespace onnxruntime

View file

@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/cuda/cuda_kernel.h"
#include "core/providers/common.h"
namespace onnxruntime {
namespace cuda {
class PadAndUnflatten final : public CudaKernel {
public:
PadAndUnflatten(const OpKernelInfo& info) : CudaKernel(info) {
}
Status ComputeInternal(OpKernelContext* context) const override;
};
} // namespace cuda
} // namespace onnxruntime

View file

@ -0,0 +1,81 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "orttraining/training_ops/cuda/tensor/pad_and_unflatten_impl.h"
#include "core/providers/cuda/cu_inc/common.cuh"
namespace onnxruntime {
namespace cuda {
constexpr int kBlockSize = 256;
constexpr int kNumUnroll = 4;
template <typename T>
__global__ void FillOutputWithIndexKernel(const CUDA_LONG N,
const fast_divmod output_element_stride_fdm,
const int64_t index_value_upper_bound,
const T* input_data,
const int64_t* indices_data,
T* output_data) {
CUDA_LONG idx = blockDim.x * blockIdx.x + threadIdx.x;
CUDA_LONG id = idx * kNumUnroll;
T input[kNumUnroll];
if (id < N) {
#pragma unroll
for (int i = 0; i < kNumUnroll; ++i) {
CUDA_LONG li = id + i;
if (li < N) {
input[i] = input_data[li];
}
}
}
#pragma unroll
for (int i = 0; i < kNumUnroll; ++i) {
CUDA_LONG li = id + i;
if (li < N) {
int row_index, col_index;
output_element_stride_fdm.divmod(li, row_index, col_index);
assert(indices_data[row_index] < index_value_upper_bound);
output_data[indices_data[row_index] * output_element_stride_fdm.d_ + col_index] = input[i];
}
}
}
template <typename T>
void PadAndUnflattenImpl(cudaStream_t stream,
const int64_t total_element_count,
const fast_divmod output_element_stride_fdm,
const int64_t index_value_upper_bound,
const T* input_data,
const int64_t* indices_data,
T* output_data) {
const int blocksPerGrid = static_cast<int>(CeilDiv(total_element_count, kBlockSize * kNumUnroll));
FillOutputWithIndexKernel<T><<<blocksPerGrid, kBlockSize, 0, stream>>>(
static_cast<CUDA_LONG>(total_element_count),
output_element_stride_fdm,
index_value_upper_bound,
input_data,
indices_data,
output_data);
}
#define SPECIALIZED_RESTORE_FROM_MASK_IMPL(T) \
template void PadAndUnflattenImpl<T>(cudaStream_t stream, \
const int64_t total_element_count, \
const fast_divmod output_element_stride_fdm, \
const int64_t index_value_upper_bound, \
const T* input_data, \
const int64_t* indices_data, \
T* output_data);
SPECIALIZED_RESTORE_FROM_MASK_IMPL(float)
SPECIALIZED_RESTORE_FROM_MASK_IMPL(double)
SPECIALIZED_RESTORE_FROM_MASK_IMPL(half)
SPECIALIZED_RESTORE_FROM_MASK_IMPL(BFloat16)
#undef SPECIALIZED_RESTORE_FROM_MASK_IMPL
} // namespace cuda
} // namespace onnxruntime

View file

@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#ifdef USE_ROCM
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#else
#include "core/providers/cuda/shared_inc/cuda_utils.h"
#endif
namespace onnxruntime {
namespace cuda {
template <typename T>
void PadAndUnflattenImpl(cudaStream_t stream,
const int64_t total_element_count,
const fast_divmod output_element_stride_fdm,
const int64_t index_value_upper_bound,
const T* input_data,
const int64_t* indices_data,
T* output_data);
} // namespace cuda
} // namespace onnxruntime

View file

@ -182,6 +182,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_float, ReduceAllL2);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_BFloat16, ReduceAllL2);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PadAndUnflatten);
#if defined(ORT_USE_NCCL) || defined(USE_MPI)
// P2P communication operators.
@ -378,6 +379,7 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_float, ReduceAllL2)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_BFloat16, ReduceAllL2)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PadAndUnflatten)>,
// P2P communication operators.
#if defined(ORT_USE_NCCL) || defined(USE_MPI)