mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Use PadAndUnflatten to replace GatherGrad for restore (#16429)
### Use PadAndUnflatten to replace GatherGrad for restore ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
ae6da03438
commit
403bebfb51
16 changed files with 490 additions and 198 deletions
|
|
@ -752,14 +752,14 @@ IMPLEMENT_GRADIENT_BUILDER(GetGatherGradient) {
|
|||
SrcNodeAttributes())};
|
||||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetGatherGradGradient) {
|
||||
// TODO: Strictly speaking, GatherGrad's gradient is not alway Gather when the indices have repeated values.
|
||||
// Since GatherGrad in foward path is only used by embed sparsity feature in which case the indices are unique,
|
||||
// we can safely use Gather here. But we will adress this issue as soon as possible.
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetPadAndUnflattenGradient) {
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef(OpDef("Reshape"),
|
||||
{GO(0), O(1)},
|
||||
{IA("GO_reshaped")}),
|
||||
NodeDef(OpDef{"Gather", kOnnxDomain, 1},
|
||||
{GO(0), I(1)},
|
||||
{GI(2)},
|
||||
{IA("GO_reshaped"), I(1)},
|
||||
{GI(0)},
|
||||
SrcNodeAttributes())};
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ DECLARE_GRADIENT_BUILDER(GetPoolGradient)
|
|||
DECLARE_GRADIENT_BUILDER(GetAveragePoolGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetMaxPoolGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetGatherGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetGatherGradGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetPadAndUnflattenGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetShrunkenGatherGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetConvGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetUnsqueezeGradient)
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
|
|||
REGISTER_GRADIENT_BUILDER("Gemm", GetGemmGradient);
|
||||
REGISTER_GRADIENT_BUILDER("MaxPool", GetMaxPoolGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Gather", GetGatherGradient);
|
||||
REGISTER_GRADIENT_BUILDER("GatherGrad", GetGatherGradGradient);
|
||||
REGISTER_GRADIENT_BUILDER("PadAndUnflatten", GetPadAndUnflattenGradient);
|
||||
REGISTER_GRADIENT_BUILDER("ShrunkenGather", GetShrunkenGatherGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Conv", GetConvGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Squeeze", GetSqueezeGradient);
|
||||
|
|
|
|||
|
|
@ -4571,6 +4571,45 @@ Return true if all elements are true and false otherwise.
|
|||
updateOutputShape(ctx, 6, {num_directions, three * hidden_size});
|
||||
}
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(PadAndUnflatten)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.SetDoc(
|
||||
"PadAndUnflatten operator pads zero on the first axis, and unflatten the axis into two axes according"
|
||||
"to given unflatten_dims. This is used by padding elimination graph transformers."
|
||||
"For each index in indices, the corresponding value in output comes from input."
|
||||
"For other indices, the corresponding value in output will be padded to zero."
|
||||
|
||||
"The indices don't allow duplicated index values, otherwise, though there is no runtime check"
|
||||
"(in case of performance concern), the behaviour of output is undefined."
|
||||
|
||||
"An example:"
|
||||
" input: [[1, 2, 3, 4], [5, 6, 7, 8]], shape is [2, 4]"
|
||||
" indices: [0, 5], shape is [2]"
|
||||
" unflatten_dims: [2, 3], shape is [2]"
|
||||
|
||||
" output: [[[1, 2, 3, 4], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [5, 6, 7, 8]]],"
|
||||
" shape is [2, 3, 4]"
|
||||
" flatten_output_shape: [6, 4], shape is [2]")
|
||||
.Input(0, "input", "input data of rank N, shape is [d1, d2, ..., dN]", "T")
|
||||
.Input(1, "indices", "1D Tensor of int32/int64 indices, shape is [d1], each element's value ranges in [0, M1*M2).",
|
||||
"T_INDEX")
|
||||
.Input(2, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT")
|
||||
.Output(0, "output", "output data of rank N+1, [M1, M2, d2, ..., dN]", "T")
|
||||
.Output(1, "flatten_output_shape", "1D tensor with output shape, [M1*M2, d2, ..., dN]", "T_INT")
|
||||
.TypeConstraint(
|
||||
"T_INT",
|
||||
{"tensor(int32)", "tensor(int64)"},
|
||||
"Constrain shape to integer tensors.")
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
|
||||
"Constrain input and output types to float tensors.")
|
||||
.TypeConstraint(
|
||||
"T_INDEX",
|
||||
{"tensor(int32)", "tensor(int64)"},
|
||||
"Constrain indices to integer types");
|
||||
}
|
||||
|
||||
} // namespace training
|
||||
|
|
|
|||
|
|
@ -61,34 +61,6 @@ NodeArg* GetDimsValue(Graph& graph, NodeArg* input, NodeArg* indices_arg, Node&
|
|||
return gather_out_args[0];
|
||||
}
|
||||
|
||||
// Insert Shape + ScatterElements to get an updated shape of input with index of indices_arg updated to
|
||||
// the value of update_value.
|
||||
// Such as, if the indices_arg is a initializer of [0] and the original shape of input is [valid_token_count, a, b, c],
|
||||
// this function will return a shape of [update_value, a, b, c]
|
||||
NodeArg* UpdateShape(Graph& graph, NodeArg* input, NodeArg* update_value, NodeArg* indices_arg, Node& node) {
|
||||
InlinedVector<NodeArg*> shape_output_args{&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("shape_result"),
|
||||
nullptr)};
|
||||
Node& shape_node = graph.AddNode(graph.GenerateNodeName("shape"), "Shape", "", {input},
|
||||
shape_output_args, nullptr, kOnnxDomain);
|
||||
ORT_ENFORCE(graph.SetOpSchemaFromRegistryForNode(shape_node), "Failed to get shape for " + shape_node.Name());
|
||||
shape_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
|
||||
InlinedVector<NodeArg*> scatter_input_args;
|
||||
scatter_input_args.push_back(shape_output_args[0]);
|
||||
scatter_input_args.push_back(indices_arg);
|
||||
scatter_input_args.push_back(update_value);
|
||||
|
||||
InlinedVector<NodeArg*> scatter_out_args{&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("scatter_result"),
|
||||
nullptr)};
|
||||
|
||||
Node& scatter_node = graph.AddNode(graph.GenerateNodeName("update_dim"), "ScatterElements", "", scatter_input_args,
|
||||
scatter_out_args, nullptr, kOnnxDomain);
|
||||
ORT_ENFORCE(graph.SetOpSchemaFromRegistryForNode(scatter_node), "Failed to update shape for " + scatter_node.Name());
|
||||
scatter_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
|
||||
return scatter_out_args[0];
|
||||
}
|
||||
|
||||
// Insert Reshape + ShrunkenGather to flatten the in_index-th input of node.
|
||||
// The gather_index_arg is the indices of the elements that are not padding.
|
||||
NodeArg* InsertNodesForInput(Graph& graph,
|
||||
|
|
@ -125,7 +97,8 @@ NodeArg* InsertNodesForInput(Graph& graph,
|
|||
|
||||
InlinedVector<NodeArg*> reshape_output_args;
|
||||
reshape_output_args.push_back(
|
||||
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("inputs_reshape_result"), node.MutableInputDefs()[in_index]->TypeAsProto()));
|
||||
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("inputs_reshape_result"),
|
||||
node.MutableInputDefs()[in_index]->TypeAsProto()));
|
||||
|
||||
Node* new_reshape_node = InsertIntermediateNodeOnDestInput(
|
||||
graph, node,
|
||||
|
|
@ -175,7 +148,7 @@ NodeArg* InsertNodesForInput(Graph& graph,
|
|||
return gather_out_arg;
|
||||
}
|
||||
|
||||
// Insert GatherGrad + Reshape to unflatten the shape of the in_index-th input of node.
|
||||
// Insert PadAndUnflatten to unflatten the shape of the in_index-th input of node.
|
||||
// The gathergrad_index_arg is the indices of the elements that are not padding.
|
||||
// The new_shape_arg is the shape of [batch_size * seqlen, ...]
|
||||
// gathergrad_index_arg and new_shape_arg are the arguments needed by GatherGrad.
|
||||
|
|
@ -183,100 +156,42 @@ NodeArg* InsertNodesForOutput(Graph& graph,
|
|||
Node& node,
|
||||
uint32_t in_index,
|
||||
NodeArg* gathergrad_index_arg,
|
||||
NodeArg* new_shape_arg,
|
||||
NodeArg* first_two_dims_arg,
|
||||
const logging::Logger& logger) {
|
||||
std::vector<int64_t> other_indices;
|
||||
auto input_shape = node.InputDefs()[in_index]->Shape();
|
||||
for (int k = 2; k < input_shape->dim_size(); k++) {
|
||||
// When executing, Shape of node here has been flattened, so the indices should be k-1.
|
||||
other_indices.push_back(int64_t(k) - 1);
|
||||
}
|
||||
InlinedVector<NodeArg*> pad_node_input_args;
|
||||
pad_node_input_args.reserve(3);
|
||||
pad_node_input_args.push_back(node.MutableInputDefs()[in_index]);
|
||||
pad_node_input_args.push_back(gathergrad_index_arg);
|
||||
pad_node_input_args.push_back(first_two_dims_arg);
|
||||
|
||||
// Construct the unflattened_shape_arg of [batch_size, seqlen, ...]
|
||||
NodeArg* unflattened_shape_arg = nullptr;
|
||||
if (other_indices.empty()) {
|
||||
unflattened_shape_arg = first_two_dims_arg;
|
||||
} else {
|
||||
// If the shape size of the in_index-th input of node is larger than 2 dims, we need to concat the first two dims
|
||||
// of [batch_size, seqlen] and the other dims together.
|
||||
ONNX_NAMESPACE::TensorProto other_indices_const_tensor;
|
||||
other_indices_const_tensor.set_name(graph.GenerateNodeArgName("other_shape"));
|
||||
other_indices_const_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
other_indices_const_tensor.add_dims(other_indices.size());
|
||||
other_indices_const_tensor.set_raw_data(other_indices.data(), other_indices.size() * sizeof(int64_t));
|
||||
NodeArg* other_indices_arg = &graph_utils::AddInitializer(graph, other_indices_const_tensor);
|
||||
NodeArg* other_dims_arg = GetDimsValue(graph, node.MutableInputDefs()[in_index], other_indices_arg, node);
|
||||
|
||||
InlinedVector<NodeArg*> concat_input_args;
|
||||
concat_input_args.push_back(first_two_dims_arg);
|
||||
concat_input_args.push_back(other_dims_arg);
|
||||
|
||||
InlinedVector<NodeArg*> concat_output_args{&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("concat_shape_result"),
|
||||
nullptr)};
|
||||
|
||||
onnxruntime::NodeAttributes attributes;
|
||||
attributes["axis"] = ONNX_NAMESPACE::MakeAttribute("axis", int64_t(0));
|
||||
|
||||
Node& concat_node = graph.AddNode(graph.GenerateNodeName("concat_shape"), "Concat", "", concat_input_args,
|
||||
concat_output_args, &attributes, kOnnxDomain);
|
||||
ORT_ENFORCE(graph.SetOpSchemaFromRegistryForNode(concat_node), "Failed to concat shape for " + concat_node.Name());
|
||||
concat_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
unflattened_shape_arg = concat_output_args[0];
|
||||
}
|
||||
|
||||
InlinedVector<NodeArg*> gathergrad_input_args;
|
||||
gathergrad_input_args.reserve(3);
|
||||
gathergrad_input_args.push_back(new_shape_arg);
|
||||
gathergrad_input_args.push_back(gathergrad_index_arg);
|
||||
gathergrad_input_args.push_back(node.MutableInputDefs()[in_index]);
|
||||
|
||||
InlinedVector<NodeArg*> gathergrad_output_args;
|
||||
gathergrad_output_args.push_back(
|
||||
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padding_recover_result"),
|
||||
InlinedVector<NodeArg*> pad_node_output_args;
|
||||
pad_node_output_args.push_back(
|
||||
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_result"),
|
||||
nullptr));
|
||||
pad_node_output_args.push_back(
|
||||
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_d1xd2_shape"),
|
||||
nullptr));
|
||||
|
||||
Node* new_gathergrad_node = InsertIntermediateNodeOnDestInput(
|
||||
graph, node,
|
||||
in_index,
|
||||
2,
|
||||
0,
|
||||
0 /* new_node_input_index*/,
|
||||
0 /* new_node_output_index*/,
|
||||
graph.GenerateNodeName("PaddingRecover"),
|
||||
"GatherGrad",
|
||||
"GatherGrad node to recover invalid tokens.",
|
||||
gathergrad_input_args,
|
||||
gathergrad_output_args,
|
||||
"PadAndUnflatten",
|
||||
"PadAndUnflatten node to recover invalid tokens.",
|
||||
pad_node_input_args,
|
||||
pad_node_output_args,
|
||||
{},
|
||||
kMSDomain,
|
||||
logger);
|
||||
|
||||
new_gathergrad_node->SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
auto gathergrad_out_arg = new_gathergrad_node->MutableOutputDefs()[0];
|
||||
|
||||
InlinedVector<NodeArg*> reshape_input_args;
|
||||
reshape_input_args.push_back(gathergrad_out_arg);
|
||||
reshape_input_args.push_back(unflattened_shape_arg);
|
||||
InlinedVector<NodeArg*> reshape_output_args{&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("reshape_result"),
|
||||
nullptr)};
|
||||
Node* new_reshape_node = InsertIntermediateNodeOnDestInput(
|
||||
graph, node,
|
||||
in_index,
|
||||
0,
|
||||
0,
|
||||
graph.GenerateNodeName("RecoverShape"),
|
||||
"Reshape",
|
||||
"Reshape node to recover invalid tokens.",
|
||||
reshape_input_args,
|
||||
reshape_output_args,
|
||||
{},
|
||||
kOnnxDomain,
|
||||
logger);
|
||||
new_reshape_node->SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
return new_reshape_node->MutableOutputDefs()[0];
|
||||
return new_gathergrad_node->MutableOutputDefs()[0];
|
||||
}
|
||||
|
||||
// Iterate the subgraph beginning from the start_node, and put all node args into 'subgraph'
|
||||
// Also put all candidate input nodes and cantidate output nodes of the subgraph into candidate_inputs and
|
||||
// Also put all candidate input nodes and candidate output nodes of the subgraph into candidate_inputs and
|
||||
// candidate_outputs respectively.
|
||||
void IterateSubgraphFromNode(Graph& graph,
|
||||
Node* start_node,
|
||||
|
|
@ -368,7 +283,7 @@ void IterateSubgraphFromNode(Graph& graph,
|
|||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "MatMul", {1, 9, 13})) {
|
||||
if (subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end()) {
|
||||
// If shape of [batch_size, seqlen, ...] is propagated from the first argument of MatMul.
|
||||
// The dim size of the first argument must larger than 2 to propagete the first two dims to the output.
|
||||
// The dim size of the first argument must be larger than 2 to propagate the first two dims to the output.
|
||||
// Or else the first two dims of the output will not be [batch_size, seqlen] and this MatMul will be added
|
||||
// to candidate_outputs as the output of the subgraph.
|
||||
if (cur->InputDefs()[0]->Shape()->dim_size() > 2) {
|
||||
|
|
@ -376,17 +291,17 @@ void IterateSubgraphFromNode(Graph& graph,
|
|||
PushAllOutputNode(graph, to_visit, cur, visited);
|
||||
} else {
|
||||
LOG_DEBUG_INFO(logger,
|
||||
"PaddingElimination::dim size of left input of matmul smaller than 3 and \
|
||||
this matmul would be output of subgraph.");
|
||||
"PaddingElimination::dim size of left input of MatMul smaller than 3 and \
|
||||
this MatMul would be the output of the subgraph.");
|
||||
candidate_outputs.insert(cur);
|
||||
continue;
|
||||
}
|
||||
} else if (subgraph.find(cur->MutableInputDefs()[1]) != subgraph.end()) {
|
||||
LOG_DEBUG_INFO(logger, "PaddingElimination::right edge of matmul would not included.");
|
||||
LOG_DEBUG_INFO(logger, "PaddingElimination::right edge of MatMul would not included.");
|
||||
candidate_outputs.insert(cur);
|
||||
continue;
|
||||
} else {
|
||||
ORT_THROW("PaddingElimination::found matmul node without input in subgraph.");
|
||||
ORT_THROW("PaddingElimination::found MatMul node without input in subgraph.");
|
||||
}
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "PythonOp", {1}, kMSDomain)) {
|
||||
if (subgraph.find(cur->MutableInputDefs()[0]) == subgraph.end()) {
|
||||
|
|
@ -451,7 +366,8 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
|
|||
") due to embedding input is not in the sparse embedding input list.");
|
||||
continue;
|
||||
}
|
||||
const ONNX_NAMESPACE::TensorProto* padding_initializer = graph_utils::GetConstantInitializer(graph, node.InputDefs()[2]->Name());
|
||||
const ONNX_NAMESPACE::TensorProto* padding_initializer =
|
||||
graph_utils::GetConstantInitializer(graph, node.InputDefs()[2]->Name());
|
||||
if (padding_initializer != nullptr &&
|
||||
padding_initializer->dims_size() == 0 &&
|
||||
((padding_initializer->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT32) ||
|
||||
|
|
@ -526,18 +442,6 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
|
|||
}
|
||||
}
|
||||
|
||||
std::vector<int64_t> first_indices;
|
||||
first_indices.push_back(0);
|
||||
ONNX_NAMESPACE::TensorProto first_indice_const_tensor;
|
||||
first_indice_const_tensor.set_name(graph.GenerateNodeArgName("indices"));
|
||||
first_indice_const_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
first_indice_const_tensor.add_dims(first_indices.size());
|
||||
first_indice_const_tensor.set_raw_data(first_indices.data(), first_indices.size() * sizeof(int64_t));
|
||||
NodeArg* first_index_arg = &graph_utils::AddInitializer(graph, first_indice_const_tensor);
|
||||
|
||||
// Get the first dim value of flattened input_ids which is batch_size * seq_len
|
||||
NodeArg* first_dim = GetDimsValue(graph, reshape_output_args[0], first_index_arg, *embedding_node);
|
||||
|
||||
std::vector<int64_t> first_two_indices{0, 1};
|
||||
ONNX_NAMESPACE::TensorProto first_two_indices_const_tensor;
|
||||
first_two_indices_const_tensor.set_name(graph.GenerateNodeArgName("first_two_indices"));
|
||||
|
|
@ -553,11 +457,7 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
|
|||
for (const auto& node : candidate_outputs) {
|
||||
for (uint32_t i = 0; i < node->InputDefs().size(); ++i) {
|
||||
if (subgraph.find(node->MutableInputDefs()[i]) != subgraph.end()) {
|
||||
// Get a shape of the i-th input of the node with first index updated to value of first_dim
|
||||
// which is batch_size * seq_len. This shape arg will be used as the shape input of GatherGrad
|
||||
NodeArg* shape_arg_for_gather_grad = UpdateShape(
|
||||
graph, node->MutableInputDefs()[i], first_dim, first_index_arg, *node);
|
||||
InsertNodesForOutput(graph, *node, i, squeeze_out_arg, shape_arg_for_gather_grad, first_two_dims_arg, logger);
|
||||
InsertNodesForOutput(graph, *node, i, squeeze_out_arg, first_two_dims_arg, logger);
|
||||
handled_output_count++;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,15 +16,18 @@ namespace onnxruntime {
|
|||
* @Class PaddingElimination
|
||||
*
|
||||
* @brief Graph transformer that eliminates unnecessary padding computation caused by embedding sparsity.
|
||||
*
|
||||
* In transformer trainings, input_ids are usually padded to the same length, which is the max sequence length,
|
||||
* so its shape is [batch_size, sequence_length] or [sequence_length, batch_size]. This graph transformer
|
||||
* tries to MERGE the leading two dimensions and REMOVE the padding on the merged
|
||||
* dimension, i.e, [batch_size, sequence_length, ...] -> [batch_size * sequence_length, ...] ->
|
||||
|
||||
*
|
||||
* This transformer is implemented in the following steps:
|
||||
* 1. Iterate the graph and find the Embedding node that matches these requirements:
|
||||
* (1) Its 2nd input is a graph input and its rank > 2 with the first two dimensions are dim_params which are
|
||||
* actually batch_size and sequence_length.
|
||||
* Note: Now only support the case of the first two dimensions to merged and remove the padding on the merged
|
||||
* dimension, i.e, [batch_size, sequence_length, ...] -> [batch_size * sequence_length, ...] ->
|
||||
* [valid_token, ... ]. In the future, we may support the case of any two consecutive dimensions to merged,
|
||||
* such as [..., batch_size, sequence_length, ...].
|
||||
* (2) Its 3nd input is a scalar constant initializer which is the padding idx that should >= 0.
|
||||
* 1.1 The 2nd input is a graph input and its rank > 2, with the first two dimensions, are:
|
||||
* [batch_size, sequence_length]. Both dimensions can be symbolic or concrete dim values.
|
||||
* 1.2 The 3rd input(padding idx) is a scalar constant initializer, and should >= 0.
|
||||
* 2. Append embedding node in node_to_scan_list.
|
||||
* Iterate the node_to_scan_list, for each node,
|
||||
* 2.1 Check if it is supported for pad elimination (from a pre-defined op list). If no, record this node as output
|
||||
|
|
@ -42,6 +45,8 @@ namespace onnxruntime {
|
|||
* This is needed to ensure not to affect subsequent computations
|
||||
*
|
||||
* For example, given the following graph:
|
||||
* 1. `input_0` is a tensor that is an in-direct output of ATen embedding node.
|
||||
* 2. `input_1` is a tensor that is NOT a direct or in-direct output of ATen embedding node.
|
||||
*
|
||||
* embed.weight input_ids [batch_size, seq_length] padding_idx [1] scale_grad_by_freq sparse
|
||||
* \ \ / / /
|
||||
|
|
@ -49,11 +54,14 @@ namespace onnxruntime {
|
|||
* \ \ / / /
|
||||
* \_________________\_________________________/________________/______________________/
|
||||
* |
|
||||
* Aten:embedding
|
||||
* ATen:embedding
|
||||
* |
|
||||
* |
|
||||
* input |
|
||||
* \
|
||||
* - - - - - - - - - - - -|
|
||||
* | |
|
||||
* input_0 | input_1
|
||||
* \ | /
|
||||
* \__________ | ___________/
|
||||
* \ | /
|
||||
* Subgraph
|
||||
*
|
||||
* |
|
||||
|
|
@ -83,40 +91,36 @@ namespace onnxruntime {
|
|||
* \______________________\________________________________/__________________/________________/
|
||||
* |
|
||||
* Aten:embedding
|
||||
* _ _ _ _ _ __ _ _ _ __ _ _|
|
||||
* / |
|
||||
* input_node |
|
||||
* \ [batch_size, seq_length] |
|
||||
* \ |
|
||||
* \ [-1] |
|
||||
* \ / |
|
||||
* Reshape (valid_token_index) |
|
||||
* \ / |
|
||||
* ShrunkenGather | shape:[valid_token, ...]
|
||||
* \ |
|
||||
* shape:[valid_token] \ |
|
||||
* \ |
|
||||
* candidate_input_node |
|
||||
* \ |
|
||||
* \ |
|
||||
* - - - - - - - - - - - - - - - - - - - - |
|
||||
* | |
|
||||
* input_0 | input_1
|
||||
* \ [batch_size, seq_length, ...] | |
|
||||
* \ | [batch_size, seq_length, ...]
|
||||
* \ [-1] | |
|
||||
* \ / | |
|
||||
* Reshape (valid_token_index) | Reshape (valid_token_index)
|
||||
* \ / | \ /
|
||||
* ShrunkenGather shape:[valid_token, ...] ShrunkenGather
|
||||
* \ | /
|
||||
* shape:[valid_token, ...] \ | /
|
||||
* \ | /
|
||||
* candidate_input_node | candidate_input_node
|
||||
* \ | /
|
||||
* \ | /
|
||||
*
|
||||
* Subgraph
|
||||
*
|
||||
* |
|
||||
* | shape:[valid_token]
|
||||
* | shape:[valid_token, ...]
|
||||
* |
|
||||
* | (valid_token_index)
|
||||
* | / ________________ (unflatten_dims), shape:[2],
|
||||
* | / / value:[batch_size, seq_length]
|
||||
* | / /
|
||||
* PadAndUnflatten
|
||||
* |
|
||||
* [batch_size*seq_length] (valid_token_index) |
|
||||
* \ | /
|
||||
* \ | /
|
||||
* \ | /
|
||||
*
|
||||
* GatherGrad
|
||||
* |
|
||||
* Reshape
|
||||
* |
|
||||
* | [batch_size, valid_token]
|
||||
* candidate_output_node
|
||||
* | [batch_size, seq_length, ...]
|
||||
* candidate_output_node
|
||||
*
|
||||
*
|
||||
*
|
||||
|
|
|
|||
|
|
@ -177,8 +177,12 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
transformers.emplace_back(std::make_unique<UpStreamReshapeGraphTransformer>(compatible_eps));
|
||||
transformers.emplace_back(std::make_unique<InsertGatherBeforeSceLoss>(compatible_eps,
|
||||
config.sparse_label_input_names));
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
// Put this under CUDA/ROCM guard as it depends on PadAndUnflatten CUDA/ROCM kernel.
|
||||
// Once we have a CPU kernel for PadAndUnflatten, we can remove the guard.
|
||||
transformers.emplace_back(std::make_unique<PaddingElimination>(compatible_eps,
|
||||
config.sparse_embedding_input_names));
|
||||
#endif
|
||||
}
|
||||
|
||||
} break;
|
||||
|
|
|
|||
|
|
@ -1901,23 +1901,6 @@ TEST(GradientCheckerTest, GatherGrad) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, GatherGradGrad) {
|
||||
float max_error;
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
OpDef op_def{"GatherGrad", kMSDomain, 1};
|
||||
TensorInfo shape_info({2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
|
||||
TensorInfo indices_info({2, 2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
|
||||
TensorInfo x_info({2, 2, 3});
|
||||
std::vector<std::vector<float>> x_datas = {{6, 3}, {3, 5, 0, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}};
|
||||
|
||||
TensorShape y_shape{6, 3};
|
||||
int64_t axis = 0;
|
||||
|
||||
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {shape_info, indices_info, x_info}, {y_shape}, &max_error,
|
||||
x_datas, {MakeAttribute("axis", axis)}));
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
void TestDropoutOp(float ratio, TensorShape& x_shape, bool default_ratio = true) {
|
||||
OpTester test("Dropout", 12, kOnnxDomain, false);
|
||||
if (default_ratio) ratio = 0.5f;
|
||||
|
|
@ -3016,6 +2999,34 @@ TEST(GradientCheckerTest, TriluGrad) {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO (enable once found why it fails on ROCM)
|
||||
#if defined(USE_CUDA)
|
||||
TEST(GradientCheckerTest, PadAndUnflattenGrad) {
|
||||
float max_error;
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
OpDef op_def{"PadAndUnflatten", kMSDomain, 1};
|
||||
TensorInfo shape_info({2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
|
||||
TensorInfo indices_info({4}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
|
||||
TensorInfo x_info({4, 3});
|
||||
std::vector<std::vector<float>> x_datas = {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 5, 0, 1}, {5, 2}};
|
||||
|
||||
TensorInfo padded_out_info({5, 2, 3}, true);
|
||||
TensorInfo out_shape_info({2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
|
||||
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
#ifdef USE_CUDA
|
||||
execution_providers.emplace_back(DefaultCudaExecutionProvider());
|
||||
#elif USE_ROCM
|
||||
execution_providers.emplace_back(DefaultRocmExecutionProvider());
|
||||
#endif
|
||||
|
||||
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info, indices_info, shape_info},
|
||||
{padded_out_info, out_shape_info}, &max_error,
|
||||
x_datas, {}, true, false, &execution_providers));
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
||||
|
|
|
|||
|
|
@ -5886,12 +5886,12 @@ def test_ops_for_padding_elimination(test_cases):
|
|||
assert len([node.op_type for node in training_model.graph.node if node.op_type == "Sub"]) == 1
|
||||
assert len([node.op_type for node in training_model.graph.node if node.op_type == "NonZero"]) == 1
|
||||
assert len([node.op_type for node in training_model.graph.node if node.op_type == "Squeeze"]) == 1
|
||||
assert len([node.op_type for node in training_model.graph.node if node.op_type == "GatherGrad"]) == 1
|
||||
assert len([node.op_type for node in training_model.graph.node if node.op_type == "PadAndUnflatten"]) == 1
|
||||
if case == 2:
|
||||
assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 2
|
||||
else:
|
||||
assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 1
|
||||
gathergrad_node = [node for node in training_model.graph.node if node.op_type == "GatherGrad"][0]
|
||||
gathergrad_node = [node for node in training_model.graph.node if node.op_type == "PadAndUnflatten"][0]
|
||||
|
||||
def find_input_node_type(model, arg):
|
||||
result = []
|
||||
|
|
@ -6057,5 +6057,5 @@ def test_e2e_padding_elimination():
|
|||
|
||||
training_model = ort_model._torch_module._execution_manager(True)._onnx_models.optimized_model
|
||||
assert "ShrunkenGather" in [node.op_type for node in training_model.graph.node]
|
||||
assert "GatherGrad" in [node.op_type for node in training_model.graph.node]
|
||||
assert "PadAndUnflatten" in [node.op_type for node in training_model.graph.node]
|
||||
del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,105 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "test/common/tensor_op_test_utils.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
|
||||
TEST(PadAndUnflattenTest, FloatType1D) {
|
||||
std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f};
|
||||
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11};
|
||||
std::vector<int64_t> unflatten_dims = {5, 3};
|
||||
|
||||
std::vector<float> output = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
|
||||
0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};
|
||||
|
||||
std::vector<int64_t> full_flatten_dims = {15};
|
||||
|
||||
OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain);
|
||||
test.AddInput<float>("input", {6}, input);
|
||||
test.AddInput<int64_t>("indices", {6}, indices);
|
||||
test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
|
||||
test.AddOutput<float>("output", {5, 3}, output);
|
||||
test.AddOutput<int64_t>("full_flatten_dims", {1}, full_flatten_dims);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(PadAndUnflattenTest, FloatType2D) {
|
||||
std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f, 7.f, 8.f, 9.f};
|
||||
std::vector<int64_t> indices = {1, 3, 4};
|
||||
std::vector<int64_t> unflatten_dims = {2, 3};
|
||||
|
||||
std::vector<float> output = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
|
||||
4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
|
||||
|
||||
std::vector<int64_t> full_flatten_dims = {6, 3};
|
||||
|
||||
OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain);
|
||||
test.AddInput<float>("input", {3, 3}, input);
|
||||
test.AddInput<int64_t>("indices", {3}, indices);
|
||||
test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
|
||||
test.AddOutput<float>("output", {2, 3, 3}, output);
|
||||
test.AddOutput<int64_t>("full_flatten_dims", {2}, full_flatten_dims);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(PadAndUnflattenTest, MLFloat16Type1D) {
|
||||
std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f};
|
||||
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11};
|
||||
std::vector<int64_t> unflatten_dims = {5, 3};
|
||||
|
||||
std::vector<float> output = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
|
||||
0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};
|
||||
|
||||
std::vector<int64_t> full_flatten_dims = {15};
|
||||
|
||||
std::vector<MLFloat16> input_half;
|
||||
input_half.resize(input.size());
|
||||
ConvertFloatToMLFloat16(input.data(), input_half.data(), int(input.size()));
|
||||
std::vector<MLFloat16> output_half;
|
||||
output_half.resize(output.size());
|
||||
ConvertFloatToMLFloat16(output.data(), output_half.data(), int(output.size()));
|
||||
|
||||
OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain);
|
||||
test.AddInput<MLFloat16>("input", {6}, input_half);
|
||||
test.AddInput<int64_t>("indices", {6}, indices);
|
||||
test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
|
||||
test.AddOutput<MLFloat16>("output", {5, 3}, output_half);
|
||||
test.AddOutput<int64_t>("full_flatten_dims", {1}, full_flatten_dims);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(PadAndUnflattenTest, MLFloat16Type2D) {
|
||||
std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f, 7.f, 8.f, 9.f};
|
||||
std::vector<int64_t> indices = {1, 3, 4};
|
||||
std::vector<int64_t> unflatten_dims = {2, 3};
|
||||
|
||||
std::vector<float> output = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
|
||||
4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
|
||||
|
||||
std::vector<int64_t> full_flatten_dims = {6, 3};
|
||||
|
||||
std::vector<MLFloat16> input_half;
|
||||
input_half.resize(input.size());
|
||||
ConvertFloatToMLFloat16(input.data(), input_half.data(), int(input.size()));
|
||||
std::vector<MLFloat16> output_half;
|
||||
output_half.resize(output.size());
|
||||
ConvertFloatToMLFloat16(output.data(), output_half.data(), int(output.size()));
|
||||
|
||||
OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain);
|
||||
test.AddInput<MLFloat16>("input", {3, 3}, input_half);
|
||||
test.AddInput<int64_t>("indices", {3}, indices);
|
||||
test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
|
||||
test.AddOutput<MLFloat16>("output", {2, 3, 3}, output_half);
|
||||
test.AddOutput<int64_t>("full_flatten_dims", {2}, full_flatten_dims);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -197,6 +197,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inpl
|
|||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuant);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PadAndUnflatten);
|
||||
|
||||
// the kernels within the following ifdef are not included in a build with
|
||||
// --enable_training_ops but without --enable_training
|
||||
|
|
@ -437,7 +438,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
kCudaExecutionProvider, kMSDomain, 1, float, FakeQuant)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PadAndUnflatten)>,
|
||||
// the kernels within the following ifdef are not included in a build with
|
||||
// --enable_training_ops but without --enable_training
|
||||
#ifdef ENABLE_TRAINING
|
||||
|
|
|
|||
|
|
@ -0,0 +1,99 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/training_ops/cuda/tensor/pad_and_unflatten.h"
|
||||
#include "orttraining/training_ops/cuda/tensor/pad_and_unflatten_impl.h"
|
||||
#include "core/providers/cuda/shared_inc/cuda_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
PadAndUnflatten,
|
||||
kMSDomain,
|
||||
1,
|
||||
kCudaExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", BuildKernelDefConstraints<MLFloat16, float, double, BFloat16>())
|
||||
.TypeConstraint("T_INT", DataTypeImpl::GetTensorType<int64_t>())
|
||||
.TypeConstraint("T_INDEX", DataTypeImpl::GetTensorType<int64_t>())
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 2)
|
||||
.OutputMemoryType(OrtMemTypeCPUOutput, 1),
|
||||
PadAndUnflatten);
|
||||
|
||||
// Put implementation in the anonymous namespace to avoid name collision in the global namespace.
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
struct PadAndUnflattenFunctor {
|
||||
void operator()(cudaStream_t stream,
|
||||
const int64_t input_element_count,
|
||||
const fast_divmod output_element_stride_fdm,
|
||||
const int64_t index_value_upper_bound,
|
||||
const Tensor& input_tensor,
|
||||
const Tensor& indices_tensor,
|
||||
Tensor& output_tensor) const {
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
const CudaT* input_data = reinterpret_cast<const CudaT*>(input_tensor.Data<T>());
|
||||
|
||||
CUDA_CALL_THROW(cudaMemset(output_tensor.MutableDataRaw(), 0, output_tensor.Shape().Size() * sizeof(CudaT)));
|
||||
PadAndUnflattenImpl<CudaT>(stream, input_element_count, output_element_stride_fdm, index_value_upper_bound,
|
||||
input_data, indices_tensor.Data<int64_t>(),
|
||||
reinterpret_cast<CudaT*>(output_tensor.MutableData<T>()));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
Status PadAndUnflatten::ComputeInternal(OpKernelContext* context) const {
|
||||
const Tensor* input_tensor = context->Input<Tensor>(0);
|
||||
const Tensor* indices_tensor = context->Input<Tensor>(1);
|
||||
const Tensor* unflatten_dims_tensor = context->Input<Tensor>(2); // Parse the 1-D shape tensor.
|
||||
ORT_ENFORCE(unflatten_dims_tensor->Shape().NumDimensions() == 1,
|
||||
"unflatten_dims_tensor tensor must be 1-D.", unflatten_dims_tensor->Shape().NumDimensions());
|
||||
ORT_ENFORCE(unflatten_dims_tensor->Shape().Size() == 2,
|
||||
"unflatten_dims_tensor tensor must contain 2 values.", unflatten_dims_tensor->Shape().Size());
|
||||
|
||||
const int64_t* dims_ptr = unflatten_dims_tensor->Data<int64_t>();
|
||||
const auto& input_shape = input_tensor->Shape();
|
||||
ORT_ENFORCE(input_shape[0] == indices_tensor->Shape()[0],
|
||||
"The first dimension of input and indices must be the same.");
|
||||
|
||||
std::vector<int64_t> output_shape_vec;
|
||||
output_shape_vec.push_back(dims_ptr[0]);
|
||||
output_shape_vec.push_back(dims_ptr[1]);
|
||||
|
||||
std::vector<int64_t> full_size_flatten_shape_vec;
|
||||
const int64_t flatten_dim_factor = dims_ptr[0] * dims_ptr[1];
|
||||
full_size_flatten_shape_vec.push_back(flatten_dim_factor);
|
||||
|
||||
int64_t element_stride = 1;
|
||||
for (size_t i = 1; i < input_shape.NumDimensions(); ++i) {
|
||||
output_shape_vec.push_back(input_shape[i]);
|
||||
full_size_flatten_shape_vec.push_back(input_shape[i]);
|
||||
element_stride *= input_shape[i];
|
||||
}
|
||||
|
||||
fast_divmod output_element_stride_fdm(static_cast<int>(element_stride));
|
||||
auto output_shape = TensorShape(output_shape_vec);
|
||||
Tensor* output_tensor = context->Output(0, output_shape);
|
||||
|
||||
utils::MLTypeCallDispatcher<float, MLFloat16, double, BFloat16> t_disp(input_tensor->GetElementType());
|
||||
t_disp.Invoke<PadAndUnflattenFunctor>(Stream(context),
|
||||
input_shape.Size(),
|
||||
output_element_stride_fdm,
|
||||
flatten_dim_factor,
|
||||
*input_tensor,
|
||||
*indices_tensor,
|
||||
*output_tensor);
|
||||
|
||||
// Set input shape output tensor.
|
||||
size_t rank = full_size_flatten_shape_vec.size();
|
||||
Tensor* input_shape_tensor = context->Output(1, {static_cast<int>(rank)});
|
||||
TensorShape(full_size_flatten_shape_vec).CopyDims(input_shape_tensor->MutableData<int64_t>(), rank);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/cuda/cuda_kernel.h"
|
||||
#include "core/providers/common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
class PadAndUnflatten final : public CudaKernel {
|
||||
public:
|
||||
PadAndUnflatten(const OpKernelInfo& info) : CudaKernel(info) {
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/training_ops/cuda/tensor/pad_and_unflatten_impl.h"
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
constexpr int kBlockSize = 256;
|
||||
constexpr int kNumUnroll = 4;
|
||||
|
||||
template <typename T>
|
||||
__global__ void FillOutputWithIndexKernel(const CUDA_LONG N,
|
||||
const fast_divmod output_element_stride_fdm,
|
||||
const int64_t index_value_upper_bound,
|
||||
const T* input_data,
|
||||
const int64_t* indices_data,
|
||||
T* output_data) {
|
||||
CUDA_LONG idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
CUDA_LONG id = idx * kNumUnroll;
|
||||
|
||||
T input[kNumUnroll];
|
||||
if (id < N) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumUnroll; ++i) {
|
||||
CUDA_LONG li = id + i;
|
||||
if (li < N) {
|
||||
input[i] = input_data[li];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumUnroll; ++i) {
|
||||
CUDA_LONG li = id + i;
|
||||
if (li < N) {
|
||||
int row_index, col_index;
|
||||
output_element_stride_fdm.divmod(li, row_index, col_index);
|
||||
assert(indices_data[row_index] < index_value_upper_bound);
|
||||
output_data[indices_data[row_index] * output_element_stride_fdm.d_ + col_index] = input[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void PadAndUnflattenImpl(cudaStream_t stream,
|
||||
const int64_t total_element_count,
|
||||
const fast_divmod output_element_stride_fdm,
|
||||
const int64_t index_value_upper_bound,
|
||||
const T* input_data,
|
||||
const int64_t* indices_data,
|
||||
T* output_data) {
|
||||
const int blocksPerGrid = static_cast<int>(CeilDiv(total_element_count, kBlockSize * kNumUnroll));
|
||||
FillOutputWithIndexKernel<T><<<blocksPerGrid, kBlockSize, 0, stream>>>(
|
||||
static_cast<CUDA_LONG>(total_element_count),
|
||||
output_element_stride_fdm,
|
||||
index_value_upper_bound,
|
||||
input_data,
|
||||
indices_data,
|
||||
output_data);
|
||||
}
|
||||
|
||||
#define SPECIALIZED_RESTORE_FROM_MASK_IMPL(T) \
|
||||
template void PadAndUnflattenImpl<T>(cudaStream_t stream, \
|
||||
const int64_t total_element_count, \
|
||||
const fast_divmod output_element_stride_fdm, \
|
||||
const int64_t index_value_upper_bound, \
|
||||
const T* input_data, \
|
||||
const int64_t* indices_data, \
|
||||
T* output_data);
|
||||
|
||||
SPECIALIZED_RESTORE_FROM_MASK_IMPL(float)
|
||||
SPECIALIZED_RESTORE_FROM_MASK_IMPL(double)
|
||||
SPECIALIZED_RESTORE_FROM_MASK_IMPL(half)
|
||||
SPECIALIZED_RESTORE_FROM_MASK_IMPL(BFloat16)
|
||||
|
||||
#undef SPECIALIZED_RESTORE_FROM_MASK_IMPL
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include "core/providers/rocm/shared_inc/rocm_utils.h"
|
||||
#else
|
||||
#include "core/providers/cuda/shared_inc/cuda_utils.h"
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template <typename T>
|
||||
void PadAndUnflattenImpl(cudaStream_t stream,
|
||||
const int64_t total_element_count,
|
||||
const fast_divmod output_element_stride_fdm,
|
||||
const int64_t index_value_upper_bound,
|
||||
const T* input_data,
|
||||
const int64_t* indices_data,
|
||||
T* output_data);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -182,6 +182,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_float, ReduceAllL2);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_BFloat16, ReduceAllL2);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PadAndUnflatten);
|
||||
|
||||
#if defined(ORT_USE_NCCL) || defined(USE_MPI)
|
||||
// P2P communication operators.
|
||||
|
|
@ -378,6 +379,7 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_float, ReduceAllL2)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_BFloat16, ReduceAllL2)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PadAndUnflatten)>,
|
||||
|
||||
// P2P communication operators.
|
||||
#if defined(ORT_USE_NCCL) || defined(USE_MPI)
|
||||
|
|
|
|||
Loading…
Reference in a new issue