Optimize computation orders (#13672)

### Optimize computation orders

In `Roberta/Electra`, when `ClassificationHead` is used, there is
slicing operation on features on sequence_length dimensions, then loss
calculations only depend on this sliced data. This is a slicing at axis
1. Before slicing the shape is [batch, sequence_length, hidden], after
slicing, it becomes [batch , hidden_stage]

We had opportunities to bring this slicing earlier as much as possible,
by passing through simple elementwise ops (like Add/Div), or
Layernorm/Softmax(if their reduce axis is after the slicing axis), or
even MatMul's the left operand (if only it did not affect the last
dims).

For operators like Reshape/Transpose, it is special since they have
either data specified (after slicing we need update), or they have perm
specified, which requires the input rank remain unchanged. So for those
kinds of operators, we can remain the original rank, but just leave the
sliced dim to be 1, after the compute completed, we do a Squeeze.

```
class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x
```

src\transformers\models\roberta\modeling_roberta.py
src\transformers\models\electra\modeling_electra.py

#### Benchmark

A simple benchmark shows Robeta training latency dropped from 208ms ~
199ms. 4.5+% reduction.
More comprehensive tests are on the way.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
pengwa 2022-12-22 15:12:52 +08:00 committed by GitHub
parent 7ed8bd4f95
commit 2f5bf75e51
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
51 changed files with 3802 additions and 547 deletions

View file

@ -57,6 +57,8 @@ else()
"${ONNXRUNTIME_INCLUDE_DIR}/core/optimizer/*.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/*.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/*.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/compute_optimizer/*.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/compute_optimizer/*.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/*.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/*.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/*.h"

View file

@ -124,6 +124,13 @@ Before full qualified name can be got from exporter, this environment variables
export ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS="megatron.fp16.fp16.fused_kernels.GELUFunction"
```
#### ORTMODULE_ENABLE_COMPUTE_OPTIMIZER
- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, this is enabled then some computation can be saved. This env var can be used for disabling
the optimization to guarantee exactly same compute with baseline (for example PyTorch, when doing convergence parity
debugging).
### 2.2 Memory Optimization
Q: *Want to run a bigger batch size?*

View file

@ -1,305 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/common/safeint.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/utils.h"
#include "core/optimizer/computation_reduction.h"
using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
namespace onnxruntime {
typedef std::function<Status(Graph&, Node&, Node&)> Handler;
constexpr int GATHERND_BATCH_DIM = 1;
static bool IsLeadingDimsEqual(const TensorShapeProto* input_shape, const TensorShapeProto* output_shape,
const int num_dim_to_check) {
ORT_ENFORCE(output_shape->dim_size() >= num_dim_to_check && input_shape->dim_size() >= num_dim_to_check);
for (int i = 0; i < num_dim_to_check; ++i) {
auto& output_dim = output_shape->dim(i);
auto& input_dim = input_shape->dim(i);
if (output_dim.has_dim_value() && input_dim.has_dim_value()) {
if (output_dim.dim_value() != input_dim.dim_value()) {
return false;
}
} else if (output_dim.has_dim_param() && input_dim.has_dim_param()) {
if (output_dim.dim_param() != input_dim.dim_param()) {
return false;
}
} else {
return false;
}
}
return true;
}
static int GetValidInputForGatherND(const Node& target_node) {
// target_node is the producer of GatherND's input.
// If target_node's some input tensors have exactly same shape with
// target_node output tensor shape, then it is safe to gather using
// original slice ranges.
int candidate_input_index = -1;
auto output_shape = target_node.OutputDefs()[0]->Shape();
const int output_rank = output_shape->dim_size();
for (size_t i = 0; i < target_node.InputDefs().size(); ++i) {
auto input_shape = target_node.InputDefs()[i]->Shape();
const int input_rank = input_shape->dim_size();
if (input_rank != output_rank) {
continue;
}
if (IsLeadingDimsEqual(input_shape, output_shape, GATHERND_BATCH_DIM + 1)) {
candidate_input_index = SafeInt<int>(i);
break;
}
}
return candidate_input_index;
}
static TensorShapeProto ReplaceSymbolicDimValue(const TensorShapeProto* shape, const int replacement_axis,
const std::string& replacement_dim_value) {
ORT_ENFORCE(replacement_axis >= 0 && replacement_axis < shape->dim_size());
TensorShapeProto output_shape;
for (int i = 0; i < shape->dim_size(); ++i) {
auto& dim = shape->dim(i);
if (i == replacement_axis) {
output_shape.add_dim()->set_dim_param(replacement_dim_value);
continue;
}
if (dim.has_dim_value()) {
output_shape.add_dim()->set_dim_value(dim.dim_value());
} else if (dim.has_dim_param()) {
output_shape.add_dim()->set_dim_param(dim.dim_param());
} else {
ORT_THROW("Invalid dim found in ReplaceSymbolicDimValue");
}
}
return output_shape;
}
static Status SwapGatherNDWithTargetNode(Graph& graph, Node& gathernd_node, Node& target_node,
const int target_node_input_index = 0) {
auto new_input_arg_for_gathernd = target_node.MutableInputDefs()[target_node_input_index];
auto target_node_out_arg = target_node.MutableOutputDefs()[0];
auto gathernd_out_arg = gathernd_node.MutableOutputDefs()[0];
auto gathernd_old_consumers = graph.GetConsumerNodes(gathernd_out_arg->Name());
const auto& graph_outputs = graph.GetOutputs();
bool need_update_graph_output = false;
if (std::find(graph_outputs.begin(), graph_outputs.end(), gathernd_out_arg) != graph_outputs.end()) {
need_update_graph_output = true;
}
const std::string& gathered_dim_param = gathernd_out_arg->Shape()->dim(GATHERND_BATCH_DIM).dim_param();
TensorShapeProto new_output_shape_for_gathernd =
ReplaceSymbolicDimValue(new_input_arg_for_gathernd->Shape(), GATHERND_BATCH_DIM, gathered_dim_param);
TensorShapeProto new_output_shape_for_target_node =
ReplaceSymbolicDimValue(target_node_out_arg->Shape(), GATHERND_BATCH_DIM, gathered_dim_param);
// update input/output definitions.
int output_index = optimizer_utils::IndexOfNodeOutput(target_node, *gathernd_node.MutableInputDefs()[0]);
graph.RemoveEdge(target_node.Index(), gathernd_node.Index(), output_index, 0);
const Node* target_node_input_node = graph.GetProducerNode(new_input_arg_for_gathernd->Name());
if (target_node_input_node != nullptr) {
output_index = optimizer_utils::IndexOfNodeOutput(*target_node_input_node, *new_input_arg_for_gathernd);
graph.AddEdge(target_node_input_node->Index(), gathernd_node.Index(), output_index, 0);
} else {
// new_input_arg_for_gathernd is graph input
graph_utils::ReplaceNodeInput(gathernd_node, 0, *new_input_arg_for_gathernd);
}
graph_utils::ReplaceDownstreamNodeInput(graph, gathernd_node, 0 /*output_idx*/,
target_node, 0 /*replacement_output_idx*/);
if (target_node_input_node != nullptr) {
graph.RemoveEdge(target_node_input_node->Index(), target_node.Index(), output_index, target_node_input_index);
}
graph.AddEdge(gathernd_node.Index(), target_node.Index(), 0, target_node_input_index);
// update consumer relation ship
if (!gathernd_old_consumers.empty()) {
graph.UpdateConsumerNodes(target_node_out_arg->Name(), {const_cast<Node*>(gathernd_old_consumers[0])});
}
graph.UpdateConsumerNodes(gathernd_out_arg->Name(), {&target_node});
// update shapes
gathernd_out_arg->SetShape(new_output_shape_for_gathernd);
target_node_out_arg->SetShape(new_output_shape_for_target_node);
if (need_update_graph_output) {
InlinedVector<const NodeArg*> graph_new_outputs;
graph_new_outputs.reserve(graph_outputs.size());
for (auto out_arg : graph_outputs) {
if (out_arg->Name().compare(gathernd_out_arg->Name()) == 0) {
graph_new_outputs.push_back(target_node_out_arg);
} else {
graph_new_outputs.push_back(out_arg);
}
}
graph.SetOutputs(graph_new_outputs);
graph.SetGraphResolveNeeded();
graph.SetGraphProtoSyncNeeded();
}
return Status::OK();
}
static Status SimpleHandler(Graph& graph, Node& gathernd_node, Node& target_node) {
return SwapGatherNDWithTargetNode(graph, gathernd_node, target_node, 0);
}
/*
This handler change the graphs this way:
Before:
input_1[b,s,h] weight_2[h]
\ /
Add[b,s,h] indices[b,p_s,1]
| /
GatherND[b,p_s,h]
|
<subsquent graphs>
After :
input_1[b,s,h] indices[b,p_s,1]
| /
GatherND[b,p_s,h] weight_2[h]
\ /
Add[b,p_s,h]
|
<subsquent graphs>
Note: b: batch, s: sequence_length, h: hidden_size, p_s: dynamic_prediction_count
*/
static Status BinaryElementwiseHandler(Graph& graph, Node& gathernd_node, Node& target_node) {
int target_node_input_index = GetValidInputForGatherND(target_node);
ORT_RETURN_IF(target_node_input_index == -1, "Invalid target node index");
return SwapGatherNDWithTargetNode(graph, gathernd_node, target_node, target_node_input_index);
}
/*
This handler change the graphs this way:
Before:
input_1[b,s,h] weight_2[h, 2h]
\ /
MatMul[b,s,2h] indices[b,p_s,1]
| /
GatherND[b,p_s,2h]
|
<subsquent graphs>
After :
input_1[b,s,h] indices[b,p_s,1]
| /
GatherND[b,p_s,h] weight_2[h,2h]
\ /
MatMul[b,p_s,2h]
|
<subsquent graphs>
Note: b: batch, s: sequence_length, h: hidden_size, p_s: dynamic_prediction_count
*/
static Status MatMulHandler(Graph& graph, Node& gathernd_node, Node& target_node) {
int target_node_input_index = GetValidInputForGatherND(target_node);
ORT_RETURN_IF_NOT(target_node_input_index == 0, "target_node_input_index != 0");
return SwapGatherNDWithTargetNode(graph, gathernd_node, target_node, target_node_input_index);
}
static std::unordered_map<std::string, Handler> handlers = {
{"Add", BinaryElementwiseHandler},
{"Div", BinaryElementwiseHandler},
{"Gelu", SimpleHandler},
{"LayerNormalization", SimpleHandler},
{"MatMul", MatMulHandler}};
static Status Delegate(Graph& graph, Node& gathernd_node, Node& target_node) {
const std::string& op_type = target_node.OpType();
if (handlers.count(op_type)) {
return handlers[op_type](graph, gathernd_node, target_node);
} else {
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, op_type + " handler is not implemented");
}
}
Status ComputationReductionTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level,
const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
for (auto index : order) {
auto* node_ptr = graph.GetNode(index);
if (!node_ptr)
// node was removed, this should not happen since we are not removing nodes.
continue;
auto& node = *node_ptr;
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
// Same ideas might apply for Gather, GatherElements, Slice, Split, etc.
// TODO: let's review the real cases to make the logic more generic.
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "GatherND", {1, 12, 13}, kOnnxDomain) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
node.GetOutputEdgesCount() > 1) { // allow GatherND have no out edges in case it is graph output.
continue;
}
auto batch_dims = static_cast<int64_t>(node.GetAttributes().at("batch_dims").i());
if (batch_dims != GATHERND_BATCH_DIM) {
continue;
}
auto indices_shape = node.MutableInputDefs()[1]->Shape();
if (indices_shape == nullptr) {
continue;
}
const auto indices_rank = indices_shape->dim_size();
auto& indices_last_dim = indices_shape->dim(indices_rank - 1);
// Since GatherND is assumed to have batch_dims=1, if the input data's shape is [batch, sequence, ..., ... ],
// limiting indices_rank=3 will make sure produced output is in shape [batch, sliced_sequence, ..., ...]
// and the rank did not change.
if (!(indices_last_dim.has_dim_value() && indices_last_dim.dim_value() == 1 && indices_rank == 3)) {
continue;
}
// Todo: check whether we want to move GatherND up, for example, if GatherND's outputs are larger
// than inputs, we should NOT probably bring it ahead.
bool stop = false;
while (!stop) {
const Node* gathernd_data_producer = graph.GetProducerNode(node.MutableInputDefs()[0]->Name());
if (gathernd_data_producer == nullptr) {
break;
}
Node* input_node = const_cast<Node*>(gathernd_data_producer);
if (graph.GetConsumerNodes(input_node->MutableOutputDefs()[0]->Name()).size() > 1) {
LOGS_DEFAULT(WARNING) << "node " << node.Name() << " stopped at node "
<< input_node->Name();
break;
}
auto ret = Delegate(graph, node, *input_node);
if (ret.IsOK()) {
LOGS_DEFAULT(WARNING) << "node " << node.Name() << " up across node "
<< input_node->Name() << std::endl;
modified = true;
} else if (ret.Code() == common::NOT_IMPLEMENTED) {
LOGS_DEFAULT(WARNING) << "node " << node.Name() << " stopped at node "
<< input_node->Name();
break;
} else {
LOGS_DEFAULT(WARNING) << " terminate due to unexpected error, node names:" << node.Name()
<< ", " << input_node->Name() << ", error " << ret.ErrorMessage() << std::endl;
stop = true;
}
}
}
if (modified) {
graph.SetGraphResolveNeeded();
}
return Status::OK();
}
} // namespace onnxruntime

View file

@ -1,18 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/graph_transformer.h"
namespace onnxruntime {
class ComputationReductionTransformer : public GraphTransformer {
public:
ComputationReductionTransformer(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
: GraphTransformer("ComputationReductionTransformer", compatible_execution_providers) {}
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -0,0 +1,541 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef ENABLE_TRAINING
#include <onnx/defs/attr_proto_util.h>
#include "core/common/safeint.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/utils.h"
#include "core/optimizer/compute_optimizer/passthrough_actors.h"
#include "core/optimizer/compute_optimizer/compute_optimizer.h"
using SliceInfo = onnxruntime::optimizer::compute_optimizer::SliceInfo;
using namespace onnxruntime::optimizer::compute_optimizer;
namespace onnxruntime {
namespace {
bool EnforceNodeAllInputOutputHaveShapes(const Node& node) {
for (const auto* input_def : node.InputDefs()) {
if (!input_def->Shape()) {
return false;
}
}
for (const auto* output_def : node.OutputDefs()) {
if (!output_def->Shape()) {
return false;
}
}
return true;
}
using OPSET_VERSION_LIST = std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion>;
const OPSET_VERSION_LIST opset_1{1};
const OPSET_VERSION_LIST opset_13_1{13, 1};
const OPSET_VERSION_LIST opset_13_9_1{13, 9, 1};
const OPSET_VERSION_LIST opset_13_11_1{13, 11, 1};
const OPSET_VERSION_LIST opset_13_9_6_1{13, 9, 6, 1};
const OPSET_VERSION_LIST opset_14_13_5_1{14, 13, 5, 1};
const OPSET_VERSION_LIST opset_14_13_7_6_1{14, 13, 7, 6, 1};
const OPSET_VERSION_LIST opset_13_12_10_7_6_1{13, 12, 10, 7, 6, 1};
/**
* @brief Functor to trigger the optimization search for a given slicing node
* (for example Gather/GatherND node).
*/
struct SliceOperationReorderHandle {
/**
* @brief Pass through configuration for specific operator.
*
* For each operator:
* > `input_indices` can be used to explicitly specify the input indices that Slicing op can be passed through.
* This could be helpful if some inputs are not applicable for pass through. If not specified, all inputs
* are considered (but there will be checks to ignore those inputs that are not affected by the slicing axis).
* > `actor` will be used to perform the actual pass through, including both pre-check stage and post process
* stage.
*/
struct OpPassThroughConfig {
OpPassThroughConfig(const std::vector<int>& input_indices,
std::shared_ptr<OperatorPassThroughActorBase> actor,
const OPSET_VERSION_LIST& opset_list)
: input_indices(input_indices), actor(actor), opsets(opset_list) {
}
std::vector<int> input_indices;
std::shared_ptr<OperatorPassThroughActorBase> actor;
const OPSET_VERSION_LIST& opsets;
};
static std::string GetFullQualifiedOpName(const std::string& op_type, const std::string& domain) {
return domain + "::" + op_type;
}
static std::unordered_map<std::string, OpPassThroughConfig>& GetOpPassThroughConfigMap() {
static std::unordered_map<std::string, OpPassThroughConfig> allowed_passthrough_ops;
static std::once_flag allowed_ops_init;
std::call_once(allowed_ops_init, []() {
allowed_passthrough_ops.insert({
// Things to consider when more operators are added here:
// 1. Whether the operator is safe to pass through in term of compute equivalence.
// If optype is not enough to guarantee the equivalence, we need to add a customized pre-check function
// (as LayerNormalization did).
// 2. Whether the outputs have the same dim changes if Gather node moves before that operator.
// 3. Should all inputs be allowed when track back further (bottom-up);
// if not, add the input index restriction as MatMul did.
{GetFullQualifiedOpName("Add", kOnnxDomain),
OpPassThroughConfig({}, std::make_shared<SimplePassThroughActor>(), opset_14_13_7_6_1)},
{GetFullQualifiedOpName("BiasGelu", kMSDomain),
OpPassThroughConfig({}, std::make_shared<SimplePassThroughActor>(), opset_1)},
{GetFullQualifiedOpName("BitmaskBiasDropout", kMSDomain),
OpPassThroughConfig({}, std::make_shared<SimplePassThroughActor>(), opset_1)},
{GetFullQualifiedOpName("Cast", kOnnxDomain),
OpPassThroughConfig({}, std::make_shared<SimplePassThroughActor>(), opset_13_9_6_1)},
{GetFullQualifiedOpName("Div", kOnnxDomain),
OpPassThroughConfig({}, std::make_shared<SimplePassThroughActor>(), opset_14_13_7_6_1)},
{GetFullQualifiedOpName("Dropout", kOnnxDomain),
OpPassThroughConfig({}, std::make_shared<SimplePassThroughActor>(), opset_13_12_10_7_6_1)},
{GetFullQualifiedOpName("Gelu", kMSDomain),
OpPassThroughConfig({}, std::make_shared<SimplePassThroughActor>(), opset_1)},
{// Be noted, this is our own implementation of ONNX domain op.
GetFullQualifiedOpName("LayerNormalization", kOnnxDomain),
OpPassThroughConfig({0}, std::make_shared<ReductionOpPassThroughActor>(), opset_1)},
{GetFullQualifiedOpName("MatMul", kOnnxDomain),
OpPassThroughConfig({}, std::make_shared<MatMulPassThroughActor>(), opset_13_9_1)},
{GetFullQualifiedOpName("Reshape", kOnnxDomain),
OpPassThroughConfig({0}, std::make_shared<ReshapePassThroughActor>(), opset_14_13_5_1)},
{GetFullQualifiedOpName("Softmax", kOnnxDomain),
OpPassThroughConfig({0}, std::make_shared<ReductionOpPassThroughActor>(), opset_13_11_1)},
{GetFullQualifiedOpName("Transpose", kOnnxDomain),
OpPassThroughConfig({}, std::make_shared<TransposePassThroughActor>(), opset_13_1)},
});
});
return allowed_passthrough_ops;
}
SliceOperationReorderHandle(const std::string& node_name) : entry_node_name_(node_name) {
}
bool operator()(Graph& graph, Node& current_node, SliceInfo& info, const logging::Logger& logger,
std::deque<SliceInfo>& queue);
private:
/**
* @brief Pass through Slicing op from current_node's output to its specific input.
*
* Propagate the slicing operation into current_node's current_input_index-th input, e.g. a slicing op is inserted
* between current_node's current_input_index-th input and current_node. For example, if current_node is Add,
* and slice_node is a Gather(axis=1, indices=[1]):
*
* input_0 [M, N, K] input_1 [M, N, K]
* \ /
* Add [M, N, K]
* |
* Gather0(axis=1, indices=[1])
* |
* output [M, 1, K]
*
* After the pass through, the graph will be:
*
* input_0 [M, N, K] input_1 [M, N, K]
* \ /
* Gather1(axis=1, indices=[1]) Gather2(axis=1, indices=[1])
* \ /
* \ /
* \ /
* Add [M, N, K]
* |
* Gather0(axis=1, indices=[1])
* |
* output [M, 1, K]
*
* Be noted: Gather1 and Gather2 are inserted on Add's two inputs.
* Gather0's removal and Add's output shape update is done in RemoveOriginSlicingOp.
*
*
* @param graph Graph to iterate.
* @param slice_node Slicing op node the takes current_node's output as input.
* @param current_node Current node.
* @param current_node_input_index The current_node_input_index-th input to propagate the Slice op pass through.
* @param info slice_node's SliceInfo.
* @param logger Logger.
* @param new_axis The new axis (for the new Slice op) upon current_node's original current_node_input_index-th input.
* @return SliceInfo for new created slicing op.
*/
SliceInfo PropagateSlicingForInput(Graph& graph, Node& slice_node, Node& current_node, int current_node_input_index,
SliceInfo& info, int new_axis, const logging::Logger& logger);
/**
* @brief Remove the origin slicing op (for example Gather/GatherND) and update shapes.
*
* In the above example, the graph will be cleaned up to:
* input_0 [M, N, K] input_1 [M, N, K]
* \ /
* Gather1(axis=1, indices=[1]) Gather2(axis=1, indices=[1])
* \ /
* \ /
* \ /
* Add [M, 1, K]
* |
* |
* output [M, 1, K]
*
* Be noted: Gather0 is removed, Add's output shape is updated.
*
* @param graph Graph to iterate.
* @param slice_node Slicing op node the takes current_node's output as input.
* @param current_node Current node.
* @param logger Logger.
* @param info slice_node's SliceInfo.
* @return
*/
Status RemoveOriginSlicingOp(Graph& graph, Node& slice_node, Node& current_node,
const logging::Logger& logger, SliceInfo& info);
std::string entry_node_name_;
};
bool SliceOperationReorderHandle::operator()(Graph& graph, Node& current_node,
SliceInfo& info,
const logging::Logger& logger,
std::deque<SliceInfo>& queue) {
Node& slice_node = *info.node_ptr;
const std::string& op_type = GetFullQualifiedOpName(current_node.OpType(), current_node.Domain());
if (GetOpPassThroughConfigMap().count(op_type)) {
auto& pass_through_config = GetOpPassThroughConfigMap().at(op_type);
LOG_DEBUG_INFO(logger, "Enter reorder handle for node " + current_node.Name() + "(" + op_type + ")");
if (!graph_utils::IsSupportedOptypeVersionAndDomain(current_node, current_node.OpType(),
pass_through_config.opsets, current_node.Domain())) {
LOG_DEBUG_INFO(logger, "Unsupported opset for " + current_node.Name() + "(" + op_type + ") since version: " +
std::to_string(current_node.SinceVersion()));
return false;
}
if (!EnforceNodeAllInputOutputHaveShapes(current_node)) {
LOG_DEBUG_INFO(logger, "Some inputs/outputs' shape not found for node " + current_node.Name() + "(" +
op_type + ")");
return false;
}
std::unordered_map<int, int> candidate_input_indices;
bool input_has_dim_1_for_axis = false;
if (!pass_through_config.actor->PreCheck(graph, current_node, info, pass_through_config.input_indices, logger,
candidate_input_indices, input_has_dim_1_for_axis)) {
LOG_DEBUG_INFO(logger, "Pre-check failed for " + current_node.Name() + "(" + op_type + ")");
return false;
}
if (candidate_input_indices.empty()) {
LOG_DEBUG_INFO(logger, "Skip handling current node " + current_node.Name() + "(" + op_type +
") because the requirement is not met.");
return false;
}
// Be noted, once we reach this point after PreCheck, graph modification started, any failure after this should
// be reported as ERROR.
std::vector<SliceInfo> populated_slicing_infos; // Slicing infos that are populated into current_node's inputs.
populated_slicing_infos.reserve(candidate_input_indices.size());
std::unordered_map<int, SliceInfo> new_gather_infos;
for (auto pair : candidate_input_indices) {
auto input_index = pair.first; // input index of current_node
int new_axis = pair.second; // new axis of current_node's input to be sliced
SliceInfo gather_info = PropagateSlicingForInput(graph, slice_node, current_node, input_index, info, new_axis,
logger);
ORT_ENFORCE(gather_info.node_ptr, "New added gather node should not be null.");
populated_slicing_infos.push_back(gather_info);
new_gather_infos.insert({{input_index, gather_info}});
}
int index_of_output =
optimizer_utils::IndexOfNodeOutput(current_node, *slice_node.InputDefs()[info.GetDataInputIndex()]);
ORT_ENFORCE(RemoveOriginSlicingOp(graph, slice_node, current_node, logger, info).IsOK());
if (!pass_through_config.actor->PostProcess(graph, current_node, index_of_output, info.non_negative_axis,
info.is_scalar_slice, input_has_dim_1_for_axis,
info.output_dim_on_axis,
entry_node_name_, new_gather_infos,
logger)) {
ORT_THROW("Post-process failed for " + current_node.Name() + "(" + op_type + ")");
}
queue.insert(queue.end(), populated_slicing_infos.begin(), populated_slicing_infos.end());
return true;
} else {
LOG_DEBUG_INFO(logger, "op_type not supported for " + current_node.Name() + "(" + op_type + ")");
return false;
}
}
SliceInfo SliceOperationReorderHandle::PropagateSlicingForInput(Graph& graph,
Node& slice_node,
Node& current_node,
int current_node_input_index,
SliceInfo& info,
int new_axis,
const logging::Logger& logger) {
LOG_DEBUG_INFO(logger, "PropagateSlicingForInput for Node " + slice_node.Name() + "(" + slice_node.OpType() +
") with input index " + std::to_string(current_node_input_index) + ", keep_dim = " +
std::to_string(!info.is_scalar_slice));
InlinedVector<NodeArg*> input_args;
input_args.reserve(slice_node.InputDefs().size());
// The first slice op's data input should be current_node's current_node_input_index-th input.
// For some cases when rank changes, slice op's slice input should also be adapted.
input_args.push_back(current_node.MutableInputDefs()[current_node_input_index]);
for (size_t i = 1; i < slice_node.InputDefs().size(); ++i) {
input_args.push_back(slice_node.MutableInputDefs()[i]);
}
// Update the axis attribute if new_axis is not same with the original slicing axis (which happens when data
// layout got changed by Transpose or Reshape ops)
onnxruntime::NodeAttributes attributes = slice_node.GetAttributes();
if (info.non_negative_axis != new_axis) {
attributes[info.axis_attr_name] =
ONNX_NAMESPACE::MakeAttribute(info.axis_attr_name, static_cast<int64_t>(new_axis));
}
InlinedVector<NodeArg*> output_args;
output_args.push_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(info.entry_slice_arg_name),
current_node.MutableInputDefs()[current_node_input_index]->TypeAsProto()));
/* new node input index to connect to current_node's input node*/
int new_slice_input_index_to_connect = info.GetDataInputIndex();
/* new node output index to connect to current_node*/
int new_slice_output_index_to_connect = info.GetOutputIndex();
Node* new_slice_node = InsertIntermediateNodeOnDestInput(graph, current_node,
current_node_input_index,
new_slice_input_index_to_connect,
new_slice_output_index_to_connect,
graph.GenerateNodeName(info.entry_slice_arg_name),
slice_node.OpType(),
"Duplicated Gather node",
input_args,
output_args,
attributes,
slice_node.Domain(),
logger);
new_slice_node->SetExecutionProviderType(slice_node.GetExecutionProviderType());
// Set correct shape for new created node.
auto new_slice_out_arg = new_slice_node->MutableOutputDefs()[new_slice_output_index_to_connect];
int reversed_axis = new_axis - new_slice_out_arg->Shape()->dim_size();
UpdateSliceOutputShape(*new_slice_out_arg, reversed_axis, info.output_dim_on_axis);
auto new_slice_info = SliceInfo(new_slice_node, info.is_scalar_slice, info.axis_attr_name, new_axis);
new_slice_info.entry_slice_arg_name = info.entry_slice_arg_name;
return new_slice_info;
}
Status SliceOperationReorderHandle::RemoveOriginSlicingOp(Graph& graph, Node& slice_node, Node& current_node,
const logging::Logger& logger, SliceInfo& info) {
LOG_DEBUG_INFO(logger, "RemoveOriginSlicingOp target_node " + current_node.Name() + "(" + current_node.OpType() +
") slice_node " + slice_node.Name() + "(" + slice_node.OpType() + "), keep_dim = " +
std::to_string(!(info.is_scalar_slice)));
auto slice_input_arg = slice_node.MutableInputDefs()[info.GetDataInputIndex()];
int slice_input_rank = slice_input_arg->Shape()->dim_size();
int output_index = optimizer_utils::IndexOfNodeOutput(current_node, *slice_input_arg);
auto slice_op_output_arg = slice_node.MutableOutputDefs()[info.GetOutputIndex()];
// Loop all outputs of target node, update the shape accordingly.
// For elementwise ops like (LayerNorm/Dropout/Add), we should handle all outputs.
// If some output rank is lower than sliced axis, we should just ignore it (the correctness is guaranteed by devs
// who adds more operator coverage in the pass through).
for (size_t i = 0; i < current_node.MutableOutputDefs().size(); ++i) {
UpdateSliceOutputShape(*current_node.MutableOutputDefs()[i], info.non_negative_axis - slice_input_rank,
info.output_dim_on_axis);
}
LOG_DEBUG_INFO(logger, "RemoveOriginSlicingOp Replace all usage of output " + slice_op_output_arg->Name() + ":0" +
" with " + current_node.MutableOutputDefs()[output_index]->Name() + ":" +
std::to_string(output_index));
graph_utils::ReplaceDownstreamNodeInput(graph, slice_node, info.GetOutputIndex() /*output_idx*/, current_node,
output_index /*replacement_output_idx*/);
auto gather_origin_consumer_nodes = graph.GetConsumerNodes(slice_op_output_arg->Name());
std::vector<Node*> slice_op_consumers;
slice_op_consumers.reserve(gather_origin_consumer_nodes.size());
for (auto& consumer_node : gather_origin_consumer_nodes) {
slice_op_consumers.push_back(graph.GetNode(consumer_node->Index()));
LOG_DEBUG_INFO(logger, "RemoveOriginSlicingOp Gather's consumer node " + consumer_node->Name() + "(" +
consumer_node->OpType() + ")");
}
graph.UpdateConsumerNodes(current_node.OutputDefs()[output_index]->Name(), slice_op_consumers);
graph.UpdateConsumerNodes(slice_op_output_arg->Name(), {});
graph.RemoveNode(slice_node.Index());
return Status::OK();
}
} // namespace
std::optional<SliceInfo> ComputeOptimizer::IsSupportedGatherND(Graph& /*graph*/, Node& node,
const logging::Logger& logger) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "GatherND", {1, 12, 13}, kOnnxDomain) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
return std::nullopt;
}
auto data_shape = node.MutableInputDefs()[0]->Shape();
auto indices_shape = node.MutableInputDefs()[1]->Shape();
auto gather_out_shape = node.MutableOutputDefs()[0]->Shape();
if (data_shape == nullptr || indices_shape == nullptr || gather_out_shape == nullptr) {
LOG_DEBUG_INFO(logger, "Skip GatherND node " + node.Name() + " due to undefined shape.");
return std::nullopt;
}
const auto data_rank = data_shape->dim_size();
const auto indices_rank = indices_shape->dim_size();
// batch_dims is an integer indicating the number of batch dimensions,
// i.e the leading b number of dimensions of data tensor and indices are representing the batches,
// and the gather starts from the b+1 dimension.
auto batch_dims = static_cast<int64_t>(node.GetAttributes().at("batch_dims").i());
ORT_ENFORCE(batch_dims >= 0 && batch_dims < indices_rank && batch_dims < data_rank,
"batch_dims must be in the range [0, min(indices_rank, data_rank)):" + std::to_string(batch_dims) +
" indices_rank:" + std::to_string(indices_rank) + " data_rank:" + std::to_string(data_rank));
// Since GatherND is assumed to have batch_dims=1, if the input data's shape is [batch, sequence, ..., ... ],
// limiting indices_rank=3 will make sure produced output is in shape [batch, sliced_sequence, ..., ...]
// and the rank did not change.
// TODO: release the constraint here.
if (data_rank != 3 || indices_rank != 3 || batch_dims != 1) {
return std::nullopt;
}
auto& indices_last_dim = indices_shape->dim(indices_rank - 1);
if (!(indices_last_dim.has_dim_value() && indices_last_dim.dim_value() == 1)) {
return std::nullopt;
}
return SliceInfo(&node, false, "batch_dims", static_cast<int>(batch_dims), true);
}
std::optional<SliceInfo> ComputeOptimizer::IsSupportedGather(Graph& /*graph*/, Node& node,
const logging::Logger& logger) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}, kOnnxDomain) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
return std::nullopt;
}
auto data_shape = node.MutableInputDefs()[0]->Shape();
auto indices_shape = node.MutableInputDefs()[1]->Shape();
auto gather_out_shape = node.MutableOutputDefs()[0]->Shape();
if (data_shape == nullptr || indices_shape == nullptr || gather_out_shape == nullptr) {
LOG_DEBUG_INFO(logger, "Skip Gather node " + node.Name() + " due to undefined shape.");
return std::nullopt;
}
const auto data_rank = data_shape->dim_size();
if (data_rank <= 1) {
LOG_DEBUG_INFO(logger, "Skip Gather node " + node.Name() + " due to rank <= 1.");
return std::nullopt;
}
auto axis = static_cast<int>(node.GetAttributes().at("axis").i());
axis = axis < 0 ? axis + data_rank : axis;
size_t dim_size = static_cast<size_t>(indices_shape->dim_size());
bool is_single_value_1d_tensor = dim_size != 0 && (dim_size == 1 && utils::HasDimValue(indices_shape->dim(0)) &&
indices_shape->dim(0).dim_value() == 1);
if (dim_size != 0 && !is_single_value_1d_tensor) {
if (dim_size == 1 && utils::HasDimValue(data_shape->dim(axis)) &&
data_shape->dim(axis).dim_value() > indices_shape->dim(0).dim_value()) {
// Can support.
} else {
LOG_DEBUG_INFO(logger, "Skip Gather node " + node.Name() + " due to unsupported dim size: " +
std::to_string(dim_size));
return std::nullopt;
}
}
return SliceInfo(&node, dim_size == 0, "axis", axis, true);
}
Status ComputeOptimizer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger)
const {
LOG_DEBUG_INFO(logger, "Enter ComputeOptimizer");
bool reordered = false;
GraphViewer graph_viewer(graph);
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
const auto& graph_outputs = graph.GetOutputs();
size_t reordered_node_count = 0; // For summary
for (auto index : order) {
auto* node_ptr = graph.GetNode(index);
if (!node_ptr)
// node was removed.
continue;
auto& node = *node_ptr;
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
std::optional<SliceInfo> gather_info;
// Same ideas might apply for GatherElements, Slice, Split, etc..
gather_info = IsSupportedGatherND(graph, node, logger);
if (!gather_info.has_value()) {
gather_info = IsSupportedGather(graph, node, logger);
}
if (!gather_info.has_value()) {
continue;
}
auto& output_arg = node.MutableOutputDefs()[0];
if (std::find(graph_outputs.begin(), graph_outputs.end(), output_arg) != graph_outputs.end()) {
continue;
}
std::string node_name = node.Name();
std::string node_type = node.OpType();
std::deque<SliceInfo> gather_queue;
gather_queue.push_back(gather_info.value());
std::string log_prefix = "Entry node " + node_name + " (" + node_type + ") with axis " +
std::to_string(gather_info.value().non_negative_axis);
LOG_DEBUG_INFO(logger, log_prefix + " starts re-ordering check");
SliceOperationReorderHandle handle(node_name);
// DON'T operate on `node` once this loop starts, as it may be removed from the graph.
while (!gather_queue.empty()) {
SliceInfo info = gather_queue.front();
Node* gather_node = info.node_ptr;
gather_queue.pop_front();
Node* slice_input_data_producer =
graph.GetMutableProducerNode(gather_node->MutableInputDefs()[0]->Name());
if (slice_input_data_producer == nullptr) {
break;
}
Node* input_node = slice_input_data_producer;
if (graph.GetConsumerNodes(input_node->MutableOutputDefs()[0]->Name()).size() > 1) {
LOG_DEBUG_INFO(logger, log_prefix + " stops at node " + input_node->Name() + " since multiple consumer found");
continue;
}
auto ret = handle(graph, *input_node, info, logger, gather_queue);
if (ret) {
LOG_DEBUG_INFO(logger, log_prefix + " moves up across node " + input_node->Name());
modified = true;
reordered = true;
} else {
LOG_DEBUG_INFO(logger, log_prefix + " stops when handling " + input_node->Name());
}
}
if (reordered) {
++reordered_node_count;
}
}
LOGS(logger, INFO) << "Exit ComputeOptimizer with summary - reorderd_node_count:" << reordered_node_count
<< " nodes.";
return Status::OK();
}
} // namespace onnxruntime
#endif

View file

@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef ENABLE_TRAINING
#pragma once
#include "core/optimizer/compute_optimizer/passthrough_actors.h"
#include "core/optimizer/graph_transformer.h"
#include "core/optimizer/utils.h"
namespace onnxruntime {
/**
* @brief Graph transformer that helps reduce compute FLOP while maintaining mathematically equivalent result.
*
* This graph transformation tries to identify opportunities to reduce unnecessary computations on the graph level.
* Currently, the major optimization is to bring some slice operators ahead as much as possible, to leave more ops
* operate on sliced input data. Gather and GatherND are the entry operators that trigger the optimization search.
*
* In terms of file dependency, compute_optimizer.h/cc reference structs and utilities defined in
* passthrough_actors.h/cc.
*/
class ComputeOptimizer : public GraphTransformer {
public:
using SliceInfo = onnxruntime::optimizer::compute_optimizer::SliceInfo;
ComputeOptimizer(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
: GraphTransformer("ComputeOptimizer", compatible_execution_providers) {}
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
private:
std::optional<SliceInfo> IsSupportedGatherND(Graph& graph, Node& node, const logging::Logger& logger) const;
std::optional<SliceInfo> IsSupportedGather(Graph& graph, Node& node, const logging::Logger& logger) const;
};
} // namespace onnxruntime
#endif

View file

@ -0,0 +1,844 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef ENABLE_TRAINING
#include <onnx/defs/attr_proto_util.h>
#include "core/common/safeint.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/utils.h"
#include "core/optimizer/compute_optimizer/passthrough_actors.h"
#include "core/optimizer/compute_optimizer/compute_optimizer.h"
using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
namespace onnxruntime::optimizer::compute_optimizer {
enum class DimCompareRet {
ExactEqual = 0,
BroadcastableEqual = 1,
RankTooLow = 2,
NotEqual = 3,
DimCompareRetMax = 4,
};
/**
* @brief Check dimensions are equal or broadcastable before axis.
*
* @param full_broadcasted_shape Full broadcasted shape as a baseline to compare.
* @param axis The axis (inclusive, of full_broadcasted_shape) where we end the comparison.
* @param target_shape Shape to compare, can have dim value be 1 for broadcastable dimension.
* @return A pair of bool, bool. The first bool is true if the dimensions are exactly same before and include axis.
* The second bool is true if the dimension of target_shape has dim value be 1 on axis.
*/
std::pair<DimCompareRet, bool> AreDimsCompatibleBeforeAxisInternal(
const TensorShapeProto* full_broadcasted_shape, const int axis,
const TensorShapeProto* target_shape) {
int full_rank = full_broadcasted_shape->dim_size();
int target_rank = target_shape->dim_size();
ORT_ENFORCE(full_rank >= axis && target_rank <= full_rank, "full_rank should bigger than axis and target_rank ",
axis, " full_rank: ", full_rank, " target_rank: ", target_rank);
int minimum_rank_to_handle = full_rank - axis;
if (target_rank < minimum_rank_to_handle) {
// Skip if target node's input rank is less than minimum rank to handle.
// Essentially this means the input did not affect the Gather axis.
return std::make_pair(DimCompareRet::RankTooLow, false);
}
bool exact_equal = true;
bool broadcastable_equal = true;
bool dim_be_1_on_axis = false;
int axis_iter = axis;
int negative_axis = axis < 0 ? axis : axis - full_rank;
int target_axis_iter = target_rank + negative_axis;
for (; axis_iter >= 0 && target_axis_iter >= 0; --axis_iter, --target_axis_iter) {
auto& dim = full_broadcasted_shape->dim(axis_iter);
auto& target_dim = target_shape->dim(target_axis_iter);
if (dim.has_dim_value() && target_dim.has_dim_value()) {
if (dim.dim_value() != target_dim.dim_value()) {
exact_equal = false;
if (target_dim.dim_value() == 1) {
if (axis_iter == axis) dim_be_1_on_axis = true;
} else {
broadcastable_equal = false;
}
}
} else if (dim.has_dim_param() && target_dim.has_dim_param()) {
if (dim.dim_param() != target_dim.dim_param()) {
exact_equal = false;
}
} else {
exact_equal = false;
if (target_dim.has_dim_value() && target_dim.dim_value() == 1) {
if (axis_iter == axis) dim_be_1_on_axis = true;
} else {
broadcastable_equal = false;
}
}
}
if (exact_equal) {
return std::make_pair(DimCompareRet::ExactEqual, dim_be_1_on_axis);
} else if (broadcastable_equal) {
return std::make_pair(DimCompareRet::BroadcastableEqual, dim_be_1_on_axis);
} else {
return std::make_pair(DimCompareRet::NotEqual, dim_be_1_on_axis);
}
}
/**
* @brief Check input meet pass through requirement.
*
* @param current_node_output_arg_to_gather The output arg of current node that consumed by slice node.
* @param arg_to_compare The input/output arg to check.
* @param info Slice info.
* @param logger The logger.
* @param fatal_error_found Used as return value. If fatal error found, set to true. Fatal error means,
* we cannot pass through this input arg.
* @param dim_1_for_axis_found Used as return value. If dim value is 1 for axis, set to true.
* @return a int represent the new slice axis for the input arg, if pass through needed to be done for
* this input arg, otherwise, return nullptr.
*
* For each input of current_node, using this function to check if the input can be passed through.
* If the input has dim on negative_axis and
* 1). either the dimension (if exists) including and before negative_axis is same as target node's output shape.
* 2). or the dimension (if exists) including and before negative_axis is 1.
* Otherwise, we will skip the optimization.
*
* Example 1: [Can be passed through]
* input_0 [M, N, K] input_1 [K]
* \ /
* Add [M, N, K] (current_node)
* |
* Gather0(axis=1, indices=[1])
* |
* output [M, 1, K]
* In this case, we can propagate Gather to input_0 branch, input_1 is skipped because it did not has dim on
* slicing axis.
*
* Example 2: [Can be passed through]
* input_0 [M, N, K] input_1 [N, K]
* \ /
* Add [M, N, K] (current_node)
* |
* Gather0(axis=1, indices=[1])
* |
* output [M, 1, K]
* In this case, we can propagate Gather to input_0 and input-1 branch, because including and before slicing axis 1,
* all dims are equal.
*
* Example 3: [Can be passed through]
* input_0 [M, N, K] input_1 [1, K]
* \ /
* Add [M, N, K] (current_node)
* |
* Gather0(axis=1, indices=[1])
* |
* output [M, 1, K]
* In this case, we can propagate Gather to input_0 branch, input_1 branch is skipped because it has dim 1 on slicing
* axis.
*
* Example 4: [Can be passed through]
* input_0 [M, N, K] input_1 [1, N, K]
* \ /
* Add [M, N, K] (current_node)
* |
* Gather0(axis=1, indices=[1])
* |
* output [M, 1, K]
* In this case, we can propagate Gather to input_0 and input_1 branch.
*
* Example 5: [Can be passed through]
* input_0 [M, N, K] input_1 [M, 1, K]
* \ /
* Add [M, N, K] (current_node)
* |
* Gather0(axis=1, indices=[1])
* |
* output [M, 1, K]
* In this case, we can propagate Gather to input_0 branch, input_1 branch is skipped because it has dim 1 on slicing.
*
* Example 6: [CANNOT be passed through]
* input_0 [M, N, K] input_1 [L, N, K]
* \ /
* Add [M, N, K] (current_node)
* |
* Gather0(axis=1, indices=[1])
* |
* output [M, 1, K]
*
*/
std::optional<int> CheckInputForPassThrough(const NodeArg* current_node_output_arg_to_gather,
const NodeArg* arg_to_compare,
const SliceInfo& info,
const logging::Logger& logger,
bool& fatal_error_found,
bool& dim_1_for_axis_found) {
fatal_error_found = false;
auto ret_pair = AreDimsCompatibleBeforeAxisInternal(current_node_output_arg_to_gather->Shape(),
info.non_negative_axis,
arg_to_compare->Shape());
if (ret_pair.first == DimCompareRet::ExactEqual) {
return info.non_negative_axis;
} else if (ret_pair.first == DimCompareRet::RankTooLow) {
LOG_DEBUG_INFO(logger, "Skip " + arg_to_compare->Name() + " because its rank is too low.");
return std::nullopt;
} else if (ret_pair.first == DimCompareRet::NotEqual) {
fatal_error_found = true;
return std::nullopt;
} else if (ret_pair.first == DimCompareRet::BroadcastableEqual) {
if (ret_pair.second) {
LOG_DEBUG_INFO(logger, "Skip " + arg_to_compare->Name() +
", whose dim on axis is 1, no need to Gather from.");
dim_1_for_axis_found = true;
return std::nullopt;
}
return info.non_negative_axis;
}
ORT_THROW("Unexpected return value from CheckInputForPassThrough.");
}
/**
* @brief From given TensorShape, update specified dimension with given value.
* If no new_dim is provided, the dimension will be removed.
*
* @param shape TensorShape used as base shape to modify.
* @param axis The dimension to be replaced/removed.
* @param new_dim The new dimension value. If not provided, the dimension will be removed.
* @return TensorShapeProto A copy of "shape" after modification.
*/
TensorShapeProto CreateNewShapeWithUpdatedDim(const TensorShapeProto* shape, const int axis,
const TensorShapeProto_Dimension& new_dim) {
ORT_ENFORCE(axis >= 0 && axis < shape->dim_size());
TensorShapeProto output_shape;
for (int i = 0; i < shape->dim_size(); ++i) {
auto& dim = shape->dim(i);
if (i == axis) {
if (new_dim.has_dim_value()) {
output_shape.add_dim()->set_dim_value(new_dim.dim_value());
} else if (new_dim.has_dim_param()) {
output_shape.add_dim()->set_dim_param(new_dim.dim_param());
} else {
// do nothing, unassigned dim will be removed.
}
continue;
}
if (dim.has_dim_value()) {
output_shape.add_dim()->set_dim_value(dim.dim_value());
} else if (dim.has_dim_param()) {
output_shape.add_dim()->set_dim_param(dim.dim_param());
} else {
ORT_THROW("Invalid dim found in CreateNewShapeWithUpdatedDim");
}
}
return output_shape;
}
bool UpdateSliceOutputShape(NodeArg& arg_to_update, int reverse_axis, const TensorShapeProto_Dimension& output_dim_on_axis) {
ORT_ENFORCE(reverse_axis < 0, " reverse_axis should be negative, representing the index from right to left.");
const TensorShapeProto* shape = arg_to_update.Shape();
int rank = shape->dim_size();
if (rank < -reverse_axis) {
return false;
}
int axis_to_update = rank + reverse_axis;
TensorShapeProto new_output_shape = CreateNewShapeWithUpdatedDim(shape, axis_to_update, output_dim_on_axis);
arg_to_update.SetShape(new_output_shape);
return true;
}
Node* InsertIntermediateNodeOnDestInput(Graph& graph,
Node& dest_node, int dest_in_index,
int new_node_input_index,
int new_node_output_index,
const std::string& name, const std::string& op_type,
const std::string& description,
const InlinedVector<NodeArg*>& input_args,
const InlinedVector<NodeArg*>& output_args,
const onnxruntime::NodeAttributes& attributes,
const std::string& domain,
const logging::Logger& logger) {
LOG_DEBUG_INFO(logger, "Inserting " + op_type + " node on " + dest_node.Name() + " 's " +
std::to_string(dest_in_index) + "th input " +
dest_node.InputDefs()[dest_in_index]->Name() + ", and connect inserted node's " +
std::to_string(new_node_output_index) + "th output to " + dest_node.Name() + " 's " +
std::to_string(dest_in_index) + "th input.");
ORT_ENFORCE(dest_in_index < static_cast<int>(dest_node.InputDefs().size()));
ORT_ENFORCE(new_node_input_index < static_cast<int>(input_args.size()), "new_node_input_index is out of range.");
ORT_ENFORCE(new_node_output_index < static_cast<int>(output_args.size()), "new_node_output_index is out of range.");
ORT_ENFORCE(dest_node.MutableInputDefs()[dest_in_index] == input_args[new_node_input_index],
"input_args[new_node_input_index] is not the same as dest_node.MutableInputDefs()[dest_in_index].",
dest_node.MutableInputDefs()[dest_in_index]->Name(), " vs ", input_args[new_node_input_index]->Name());
// Prepare Input and Outputs for the duplicated Gather/GatherND node.
NodeArg* src_node_arg = dest_node.MutableInputDefs()[dest_in_index];
// Create the duplicated Gather/GatherND node.
Node& new_node = graph.AddNode(name, op_type, description, input_args, output_args, &attributes, domain);
ORT_ENFORCE(graph.SetOpSchemaFromRegistryForNode(new_node), "Failed to set op schema for " + new_node.Name());
// Connect dest_node's input node to duplicated node.
// Update new node producer and consumer map.
for (size_t j = 0; j < new_node.MutableOutputDefs().size(); ++j) {
graph.UpdateProducerNode(new_node.MutableOutputDefs()[j]->Name(), new_node.Index());
}
graph.AddConsumerNode(src_node_arg->Name(), &new_node);
const Node* src_node = graph.GetProducerNode(src_node_arg->Name());
if (src_node) {
int src_out_index = optimizer_utils::IndexOfNodeOutput(*src_node, *src_node_arg);
graph.AddEdge(src_node->Index(), new_node.Index(), src_out_index, new_node_input_index);
}
// Remove edge between dest_node and src_node.
// Be noted, this will remove dest_node's input edges to src_node
// (and also the src_node's output edges to dest_node).
std::vector<graph_utils::GraphEdge> input_edge_to_remove;
input_edge_to_remove.reserve(1);
for (auto it = dest_node.InputEdgesBegin(), end = dest_node.InputEdgesEnd(); it != end; ++it) {
LOG_DEBUG_INFO(logger, "dest_node " + dest_node.Name() + " input edge: " + it->GetNode().Name() +
" output index: " + std::to_string(it->GetSrcArgIndex()) + " input index: " +
std::to_string(it->GetDstArgIndex()));
if (it->GetDstArgIndex() == dest_in_index) {
input_edge_to_remove.push_back(graph_utils::GraphEdge::CreateGraphEdge(dest_node, *it, true));
break;
}
}
// If the input is graph input or initializer, no edge will be removed.
if (input_edge_to_remove.size() > 0) {
graph_utils::GraphEdge::RemoveGraphEdges(graph, input_edge_to_remove);
// Remove target node from target input arg's consumer list.
const std::string& src_node_arg_name = src_node_arg->Name();
int input_use_count_by_dest_node = 0;
for (size_t i = 0; i < dest_node.InputDefs().size(); ++i) {
if (dest_node.InputDefs()[i]->Name().compare(src_node_arg_name) == 0) {
++input_use_count_by_dest_node;
}
}
if (input_use_count_by_dest_node == 1) {
graph.RemoveConsumerNode(src_node_arg_name, &dest_node);
}
}
// Connect duplicated gather node to target node's input.
dest_node.MutableInputDefs()[dest_in_index] = new_node.MutableOutputDefs()[new_node_output_index];
// Add new edge connecting the duplicated gather with the target node directly.
// This also updates the destination node's input node args
graph.AddEdge(new_node.Index(), dest_node.Index(), new_node_output_index, dest_in_index);
graph.AddConsumerNode(new_node.MutableOutputDefs()[new_node_output_index]->Name(), &dest_node);
LOG_DEBUG_INFO(logger, "Inserted " + op_type + " node on " + dest_node.Name() + " 's " +
std::to_string(dest_in_index) + "th input " +
dest_node.InputDefs()[dest_in_index]->Name());
return &new_node;
}
TensorShapeProto CreateTensorShapeInsertDimAtAxis(const TensorShapeProto* src_shape, int axis, int64_t dim_value) {
ORT_ENFORCE(axis <= src_shape->dim_size(), "axis is out of range.", axis, " vs ", src_shape->dim_size());
TensorShapeProto updated_shape;
int j = 0;
for (j = 0; j < axis; ++j) {
auto dim = src_shape->dim(j);
if (dim.has_dim_value()) {
updated_shape.add_dim()->set_dim_value(dim.dim_value());
} else if (dim.has_dim_param()) {
updated_shape.add_dim()->set_dim_param(dim.dim_param());
} else {
ORT_THROW("Invalid dim found in CreateTensorShapeInsertDimAtAxis");
}
}
updated_shape.add_dim()->set_dim_value(dim_value);
for (; j < src_shape->dim_size(); ++j) {
auto dim = src_shape->dim(j);
if (dim.has_dim_value()) {
updated_shape.add_dim()->set_dim_value(dim.dim_value());
} else if (dim.has_dim_param()) {
updated_shape.add_dim()->set_dim_param(dim.dim_param());
} else {
ORT_THROW("Invalid dim found in CreateTensorShapeInsertDimAtAxis");
}
}
return updated_shape;
}
NodeArg* CreateUnsqueezeAxesInitializer(Graph& graph, const std::vector<int64_t>& values) {
ONNX_NAMESPACE::TensorProto axes_const_tensor;
axes_const_tensor.set_name(graph.GenerateNodeArgName("axes"));
axes_const_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
axes_const_tensor.add_dims(values.size());
axes_const_tensor.set_raw_data(values.data(), values.size() * sizeof(int64_t));
return &graph_utils::AddInitializer(graph, axes_const_tensor);
}
int GetONNXOpSetVersion(const Graph& graph) {
int onnx_opset = -1;
auto onnx_domain_it = graph.DomainToVersionMap().find(kOnnxDomain);
if (onnx_domain_it != graph.DomainToVersionMap().end()) {
onnx_opset = onnx_domain_it->second;
} else {
auto onnx_domain_alias_it = graph.DomainToVersionMap().find(kOnnxDomainAlias);
if (onnx_domain_alias_it != graph.DomainToVersionMap().end())
onnx_opset = onnx_domain_alias_it->second;
else
ORT_THROW("ONNX domain not found in this model");
}
return onnx_opset;
}
void AdaptInputAndOutputForScalarSlice(Graph& graph, Node& current_node, int current_node_output_index,
int slice_axis, const std::string& entry_node_name,
const std::unordered_map<int, SliceInfo>& new_gather_infos,
const logging::Logger& logger) {
LOG_DEBUG_INFO(logger, "AdaptInputAndOutputForScalarSlice for Node " + current_node.Name() + "(" +
current_node.OpType() + ")");
// For each handled inputs, insert Unsqueeze node to get the removed dim back at slice_axis.
for (auto pair : new_gather_infos) {
int input_index = pair.first;
Node* new_node = nullptr;
// Be noted, the Unsqueeze should happens on the axis of new slice node.
if (GetONNXOpSetVersion(graph) < 13) {
onnxruntime::NodeAttributes attributes;
attributes["axes"] = ONNX_NAMESPACE::MakeAttribute("axes", std::vector<int64_t>{pair.second.non_negative_axis});
new_node =
InsertIntermediateNodeOnDestInput(
graph,
current_node, input_index,
0 /* new node input index to connect to current_node's input node*/,
0 /* new node output index to connect to current_node*/,
graph.GenerateNodeName(entry_node_name + "_adapt_input"),
"Unsqueeze",
"Unsqueeze node",
{current_node.MutableInputDefs()[input_index]},
{&graph.GetOrCreateNodeArg(
graph.GenerateNodeArgName("unsqueeze_adaptor"),
current_node.MutableInputDefs()[input_index]->TypeAsProto())},
attributes, kOnnxDomain,
logger);
} else {
new_node =
InsertIntermediateNodeOnDestInput(
graph,
current_node, input_index,
0 /* new node input index to connect to current_node's input node*/,
0 /* new node output index to connect to current_node*/,
graph.GenerateNodeName(entry_node_name + "_adapt_input"),
"Unsqueeze",
"Unsqueeze node",
{current_node.MutableInputDefs()[input_index],
CreateUnsqueezeAxesInitializer(graph, {pair.second.non_negative_axis})},
{&graph.GetOrCreateNodeArg(
graph.GenerateNodeArgName("unsqueeze_adaptor"),
current_node.MutableInputDefs()[input_index]->TypeAsProto())},
{}, kOnnxDomain,
logger);
}
new_node->SetExecutionProviderType(current_node.GetExecutionProviderType());
// Set correct shape for Unsqueeze node
const TensorShapeProto* unsqueeze_input_shape = new_node->MutableInputDefs()[0]->Shape();
new_node->MutableOutputDefs()[0]->SetShape(
CreateTensorShapeInsertDimAtAxis(unsqueeze_input_shape, pair.second.non_negative_axis, 1));
}
// Find the consumer node of MatMul, and the input index of that node connect to MatMul.
std::vector<const Node*> consumers =
graph.GetConsumerNodes(current_node.MutableOutputDefs()[current_node_output_index]->Name());
ORT_ENFORCE(consumers.size() >= 1, "MatMul should have at least one consumer at this point. " +
std::to_string(consumers.size()) + " consumers found.");
Node& consumer = *graph.GetNode(consumers[0]->Index());
int index = -1;
for (size_t i = 0; i < consumer.InputDefs().size(); ++i) {
auto input_arg = consumer.InputDefs()[i];
if (input_arg->Name().compare(current_node.MutableOutputDefs()[current_node_output_index]->Name()) == 0) {
index = static_cast<int>(i);
break;
}
}
// Create Squeeze node connecting MatMul output to consumer node.
Node* matmul_out_adaptor_node = nullptr;
if (GetONNXOpSetVersion(graph) < 13) {
onnxruntime::NodeAttributes attributes;
attributes["axes"] = ONNX_NAMESPACE::MakeAttribute("axes", std::vector<int64_t>{slice_axis});
matmul_out_adaptor_node =
InsertIntermediateNodeOnDestInput(
graph, consumer, index,
0,
0 /* new node output index*/,
graph.GenerateNodeName(current_node.OpType() + "_output"),
"Squeeze",
"Squeeze node",
{consumer.MutableInputDefs()[index]},
{&graph.GetOrCreateNodeArg(
graph.GenerateNodeArgName("squeeze_adaptor"),
consumer.MutableInputDefs()[index]->TypeAsProto())},
attributes, kOnnxDomain, logger);
} else {
matmul_out_adaptor_node =
InsertIntermediateNodeOnDestInput(
graph, consumer, index,
0,
0 /* new node output index*/,
graph.GenerateNodeName(current_node.OpType() + "_output"),
"Squeeze",
"Squeeze node",
{consumer.MutableInputDefs()[index],
CreateUnsqueezeAxesInitializer(graph, {slice_axis})},
{&graph.GetOrCreateNodeArg(
graph.GenerateNodeArgName("squeeze_adaptor"),
consumer.MutableInputDefs()[index]->TypeAsProto())},
{}, kOnnxDomain, logger);
}
matmul_out_adaptor_node->SetExecutionProviderType(current_node.GetExecutionProviderType());
// Don't need set shape for Squeeze because original MatMul output is used as its output type.
// Set correct shape for MatMul node
const TensorShapeProto* matmul_out_shape = matmul_out_adaptor_node->MutableOutputDefs()[0]->Shape();
current_node.MutableOutputDefs()[0]->SetShape(CreateTensorShapeInsertDimAtAxis(matmul_out_shape, slice_axis, 1));
}
bool DefaultOperatorPassThroughActorBase::PostProcess(
Graph& graph, Node& current_node, int current_node_output_index,
int slice_axis, bool is_slice_scalar, bool input_has_dim_1_for_axis,
const ONNX_NAMESPACE::TensorShapeProto_Dimension& /*output_dim_on_axis*/,
const std::string& entry_node_name,
const std::unordered_map<int, SliceInfo>& new_gather_infos,
const logging::Logger& logger) {
LOG_DEBUG_INFO(logger, "Enter DefaultOperatorPassThroughActorBase::PostProcess for Node " + current_node.Name() +
"(" + current_node.OpType() + ")");
if (is_slice_scalar && input_has_dim_1_for_axis) {
AdaptInputAndOutputForScalarSlice(graph, current_node, current_node_output_index, slice_axis,
entry_node_name, new_gather_infos, logger);
}
return true;
}
bool SimplePassThroughActor::PreCheck(const Graph& /*graph*/, const Node& current_node, const SliceInfo& info,
const std::vector<int>& allowed_input_indices,
const logging::Logger& logger,
std::unordered_map<int, int>& propagate_input_config,
bool& input_has_dim_1_for_axis) {
LOG_DEBUG_INFO(logger, "Enter SimplePassThroughActor::PreCheck for node " + current_node.Name());
Node* slice_node = info.node_ptr;
int current_node_output_index = optimizer_utils::IndexOfNodeOutput(current_node, *slice_node->InputDefs()[0]);
const NodeArg* gather_data_input_arg = current_node.OutputDefs()[current_node_output_index];
propagate_input_config.clear();
input_has_dim_1_for_axis = false;
for (size_t i = 0; i < current_node.InputDefs().size(); ++i) {
if (allowed_input_indices.size() > 0 &&
std::find(allowed_input_indices.begin(), allowed_input_indices.end(), i) == allowed_input_indices.end()) {
continue;
}
bool fatal_error_found = false;
auto ret = CheckInputForPassThrough(gather_data_input_arg, current_node.InputDefs()[i], info, logger,
fatal_error_found, input_has_dim_1_for_axis);
if (fatal_error_found) {
LOG_DEBUG_INFO(logger, "Skip for node " + current_node.Name() + " due to input check failure at index " +
std::to_string(i));
return false;
} else if (ret.has_value()) {
propagate_input_config[static_cast<int>(i)] = ret.value();
}
}
// Make sure once Gather is moved before target node, all its outputs can be correctly be sliced.
std::unordered_map<int, int> output_indices;
for (size_t i = 0; i < current_node.OutputDefs().size(); ++i) {
if (static_cast<int>(i) == current_node_output_index) {
continue;
}
bool fatal_error_found = false;
bool dim_1_for_axis_found = false;
auto ret = CheckInputForPassThrough(gather_data_input_arg, current_node.OutputDefs()[i], info, logger,
fatal_error_found, dim_1_for_axis_found);
if (fatal_error_found) {
LOG_DEBUG_INFO(logger, "Skip for node " + current_node.Name() + " due to output check failure at index " +
std::to_string(i));
return false;
} else if (ret.has_value()) {
output_indices[static_cast<int>(i)] = ret.value();
}
}
bool output_check_success = output_indices.size() == current_node.OutputDefs().size() - 1;
return output_check_success;
}
bool ReductionOpPassThroughActor::PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info,
const std::vector<int>& allowed_input_indices,
const logging::Logger& logger,
std::unordered_map<int, int>& propagate_input_config,
bool& input_has_dim_1_for_axis) {
auto axis = static_cast<int64_t>(current_node.GetAttributes().at("axis").i());
axis = axis < 0 ? axis + current_node.InputDefs()[0]->Shape()->dim_size() : axis;
// Make sure layernorm/softmax's reduction happens after the axis we want to slice.
if (axis <= info.non_negative_axis) {
return false;
}
return SimplePassThroughActor::PreCheck(graph, current_node, info, allowed_input_indices, logger,
propagate_input_config, input_has_dim_1_for_axis);
}
bool ReshapePassThroughActor::PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info,
const std::vector<int>& /*allowed_input_indices*/,
const logging::Logger& logger,
std::unordered_map<int, int>& propagate_input_config,
bool& /*input_has_dim_1_for_axis*/) {
auto data_input_shape = current_node.InputDefs()[0]->Shape();
auto shape_input_shape = current_node.InputDefs()[1]->Shape();
auto output_shape = current_node.OutputDefs()[0]->Shape();
if (data_input_shape == nullptr || shape_input_shape == nullptr || shape_input_shape->dim_size() != 1 ||
output_shape == nullptr) {
LOG_DEBUG_INFO(logger, "Reshape input/output node arg shape is not valid.");
return false;
}
if (!graph_utils::IsConstantInitializer(graph, current_node.InputDefs()[1]->Name())) {
LOG_DEBUG_INFO(logger, "Skip handle the Reshape, because the new shape is not constant.");
return false;
}
propagate_input_config.clear();
InlinedVector<int64_t> new_shape_const_values;
optimizer_utils::AppendTensorFromInitializer(graph, *current_node.InputDefs()[1], new_shape_const_values, true);
// Only below two cases are supported for easier updating shape data after propagate slice ops.
// 1). If the shape data on slicing axis is zero (e.g. remain the same after slicing), we support it.
// 2). Or if the sliced dim value is a constant, we also support it, and can update the shape data directly.
// For other cases, it is feasible to support but we don't support for now.
if (new_shape_const_values[info.non_negative_axis] == 0 || info.output_dim_on_axis.has_dim_value()) {
auto in_dims = data_input_shape->dim();
auto out_dims = output_shape->dim();
int in_rank = in_dims.size();
int out_rank = out_dims.size();
int reshape_input_axis = -1;
// Match from left to right.
for (int i = 0; i < std::min(in_rank, out_rank); ++i) {
bool dim_value_eq = in_dims[i].has_dim_value() && out_dims[i].has_dim_value() &&
in_dims[i].dim_value() == out_dims[i].dim_value();
bool dim_param_eq = in_dims[i].has_dim_param() && out_dims[i].has_dim_param() &&
in_dims[i].dim_param() == out_dims[i].dim_param();
if (dim_value_eq || dim_param_eq) {
if (i == info.non_negative_axis) {
reshape_input_axis = i;
break;
}
continue;
}
}
if (reshape_input_axis == -1) {
// Match from right to left.
for (int i = 0; i < std::min(in_rank, out_rank); ++i) {
int in_index = in_rank - 1 - i;
int out_index = out_rank - 1 - i;
bool dim_value_eq = in_dims[in_index].has_dim_value() && out_dims[out_index].has_dim_value() &&
in_dims[in_index].dim_value() == out_dims[out_index].dim_value();
bool dim_param_eq = in_dims[in_index].has_dim_param() && out_dims[out_index].has_dim_param() &&
in_dims[in_index].dim_param() == out_dims[out_index].dim_param();
if (dim_value_eq || dim_param_eq) {
if (out_index == info.non_negative_axis) {
reshape_input_axis = in_index;
break;
}
continue;
}
}
}
if (reshape_input_axis == -1) {
LOG_DEBUG_INFO(logger, "Cannot find Reshape's input axis for Gather.");
return false;
}
propagate_input_config[0] = reshape_input_axis;
return true;
}
return false;
}
bool ReshapePassThroughActor::PostProcess(Graph& graph, Node& current_node, int /*current_node_output_index*/,
int slice_axis, bool is_slice_scalar, bool /*input_has_dim_1_for_axis*/,
const ONNX_NAMESPACE::TensorShapeProto_Dimension& output_dim_on_axis,
const std::string& /*entry_node_name*/,
const std::unordered_map<int, SliceInfo>& /*new_gather_infos*/,
const logging::Logger& logger) {
LOG_DEBUG_INFO(logger, "ReshapePostProcess for Node " + current_node.Name() + "(" + current_node.OpType() + ")");
InlinedVector<int64_t> new_shape_const_values;
optimizer_utils::AppendTensorFromInitializer(graph, *current_node.InputDefs()[1], new_shape_const_values, true);
auto create_new_initializer_from_vector = [&graph](NodeArg* arg_to_be_replaced,
const InlinedVector<int64_t>& new_values) -> NodeArg* {
// Create new TensorProto.
ONNX_NAMESPACE::TensorProto constant_tensor_proto;
constant_tensor_proto.set_name(graph.GenerateNodeArgName(arg_to_be_replaced->Name()));
constant_tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
auto length = new_values.size();
constant_tensor_proto.add_dims(length);
constant_tensor_proto.set_raw_data(new_values.data(), length * sizeof(int64_t));
// Add initializer into Graph.
NodeArg* new_shape_arg = &graph_utils::AddInitializer(graph, constant_tensor_proto);
// Update the output arg shape.
ONNX_NAMESPACE::TensorShapeProto new_shape;
new_shape.add_dim()->set_dim_value(length);
new_shape_arg->SetShape(new_shape);
return new_shape_arg;
};
// If the shape constant on slice_axis is 0, then it keeps the original dim of input.
// If it is scalar slice, then we just remove that dim. Otherwise, we don't need to update the dim value.
if (new_shape_const_values[slice_axis] == 0) {
if (is_slice_scalar) {
LOG_DEBUG_INFO(logger, "Removing axis " + std::to_string(slice_axis) + " from shape tensor.");
NodeArg* arg_to_be_replaced = current_node.MutableInputDefs()[1];
InlinedVector<int64_t> new_values;
for (int i = 0; i < static_cast<int>(new_shape_const_values.size()); ++i) {
if (i != slice_axis) {
new_values.push_back(new_shape_const_values[i]);
}
}
auto new_shape_arg = create_new_initializer_from_vector(arg_to_be_replaced, new_values);
graph_utils::ReplaceNodeInput(current_node, 1, *new_shape_arg);
} else {
LOG_DEBUG_INFO(logger, "Reshape's shape has 0 specified for aixs: " + std::to_string(slice_axis) +
", not need update.");
}
return true;
}
// If it selected shape is dim value, we can update the shape tensor directory.
if (output_dim_on_axis.has_dim_value()) {
new_shape_const_values[slice_axis] = output_dim_on_axis.dim_value();
auto new_shape_arg = create_new_initializer_from_vector(current_node.MutableInputDefs()[1], new_shape_const_values);
graph_utils::ReplaceNodeInput(current_node, 1, *new_shape_arg);
return true;
}
ORT_THROW("Fail to update shape data in ReshapePassThroughActor::PostProcess, but this should not be called.");
}
bool TransposePassThroughActor::PreCheck(const Graph& /*graph*/, const Node& current_node, const SliceInfo& info,
const std::vector<int>& /*allowed_input_indices*/,
const logging::Logger& logger,
std::unordered_map<int, int>& propagate_input_config,
bool& input_has_dim_1_for_axis) {
InlinedVector<int64_t> perm;
if (!graph_utils::GetRepeatedNodeAttributeValues(current_node, "perm", perm)) {
LOG_DEBUG_INFO(logger, "perm attribute is not set for node " + current_node.Name());
return false;
}
propagate_input_config.clear();
propagate_input_config[0] = static_cast<int>(perm[info.non_negative_axis]);
input_has_dim_1_for_axis = false;
return true;
}
bool TransposePassThroughActor::PostProcess(Graph& graph, Node& current_node, int current_node_output_index,
int slice_axis, bool is_slice_scalar, bool /*input_has_dim_1_for_axis*/,
const ONNX_NAMESPACE::TensorShapeProto_Dimension& /*output_dim_on_axis*/,
const std::string& entry_node_name,
const std::unordered_map<int, SliceInfo>& new_gather_infos,
const logging::Logger& logger) {
LOG_DEBUG_INFO(logger, "Enter TransposePassThroughActor::PostProcess for Node " + current_node.Name() + "(" +
current_node.OpType() + ")");
// We need keep the original dimension to align with original perm.
if (is_slice_scalar) {
AdaptInputAndOutputForScalarSlice(graph, current_node, current_node_output_index, slice_axis,
entry_node_name, new_gather_infos, logger);
}
return true;
}
bool MatMulPassThroughActor::PreCheck(const Graph& /*graph*/, const Node& current_node, const SliceInfo& info,
const std::vector<int>& allowed_input_indices,
const logging::Logger& logger,
std::unordered_map<int, int>& propagate_input_config,
bool& input_has_dim_1_for_axis) {
LOG_DEBUG_INFO(logger, "Enter MatMulPassThroughActor::PreCheck for node " + current_node.Name());
auto lhs_rank = current_node.InputDefs()[0]->Shape()->dim_size();
auto rhs_rank = current_node.InputDefs()[1]->Shape()->dim_size();
if (!(lhs_rank >= 2 && rhs_rank >= 2)) {
LOG_DEBUG_INFO(logger, "MatMul input rank lower than 2, skip.");
return false;
}
propagate_input_config.clear();
if (info.non_negative_axis == info.input_rank - 1) {
propagate_input_config[1] = rhs_rank - 1;
return true;
} else if (info.non_negative_axis == info.input_rank - 2) {
propagate_input_config[0] = lhs_rank - 2;
return true;
}
int target_node_output_index = optimizer_utils::IndexOfNodeOutput(current_node, *info.node_ptr->InputDefs()[0]);
const NodeArg* gather_data_input_arg = current_node.OutputDefs()[target_node_output_index];
input_has_dim_1_for_axis = false;
for (size_t i = 0; i < current_node.InputDefs().size(); ++i) {
if (allowed_input_indices.size() > 0 &&
std::find(allowed_input_indices.begin(), allowed_input_indices.end(), i) == allowed_input_indices.end()) {
continue;
}
bool fatal_error_found = false;
auto ret = CheckInputForPassThrough(gather_data_input_arg, current_node.InputDefs()[i], info, logger,
fatal_error_found, input_has_dim_1_for_axis);
if (fatal_error_found) {
LOG_DEBUG_INFO(logger, "Skip for node " + current_node.Name() + " due to input check failure at index " +
std::to_string(i));
return false;
} else if (ret.has_value()) {
LOG_DEBUG_INFO(logger, "Add new input candidate for node " + current_node.Name() + " at index " +
std::to_string(i) + " with axis " + std::to_string(ret.value()));
propagate_input_config[static_cast<int>(i)] = ret.value();
}
}
return propagate_input_config.size() > 0;
}
bool MatMulPassThroughActor::PostProcess(Graph& graph, Node& current_node, int current_node_output_index,
int slice_axis, bool is_slice_scalar, bool /*input_has_dim_1_for_axis*/,
const ONNX_NAMESPACE::TensorShapeProto_Dimension& /*output_dim_on_axis*/,
const std::string& entry_node_name,
const std::unordered_map<int, SliceInfo>& new_gather_infos,
const logging::Logger& logger) {
LOG_DEBUG_INFO(logger, "Enter MatMulPassThroughActor::PostProcess for Node " + current_node.Name() + "(" +
current_node.OpType() + ")");
// We need keep the original dimension to avoid the matmul inputs cannot be compatible to compute.
if (is_slice_scalar) {
AdaptInputAndOutputForScalarSlice(graph, current_node, current_node_output_index, slice_axis,
entry_node_name, new_gather_infos, logger);
}
return true;
}
} // namespace onnxruntime::optimizer::compute_optimizer
#endif

View file

@ -0,0 +1,333 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef ENABLE_TRAINING
#pragma once
// Uncomment for debugging
// #define NEED_LOG_DEBUG_INFO 1
#ifdef NEED_LOG_DEBUG_INFO
#define LOG_DEBUG_INFO(logger, message) LOGS(logger, WARNING) << message
#else
#define LOG_DEBUG_INFO(logger, message) \
ORT_UNUSED_PARAMETER(logger); \
do { \
} while (0)
#endif
namespace onnxruntime::optimizer::compute_optimizer {
/**
* @brief Struct to hold the information of the slicing operations.
*
* Initially, an instance of this class for entry node is created, as the slice op propagates to entry node's inputs,
* more instances of this class are created. The propagation stops when the all inputs are not supported to be sliced.
*/
struct SliceInfo {
static constexpr int kSliceDataInputIndex = 0;
static constexpr int kSliceOutputIndex = 0;
SliceInfo(Node* slice_node,
bool is_slice_scalar,
const std::string& slice_axis_attr_name,
int slice_axis,
bool is_entry_node_ptr = false)
: node_ptr(slice_node), is_scalar_slice(is_slice_scalar) {
axis_attr_name = slice_axis_attr_name;
const NodeArg* input = node_ptr->InputDefs()[kSliceDataInputIndex];
const NodeArg* output = node_ptr->OutputDefs()[kSliceOutputIndex];
input_rank = input->Shape()->dim_size();
non_negative_axis = slice_axis < 0 ? input_rank + slice_axis : slice_axis;
if (!is_scalar_slice) {
output_dim_on_axis = output->Shape()->dim(non_negative_axis);
}
if (is_entry_node_ptr) {
entry_slice_arg_name = node_ptr->OutputDefs()[kSliceOutputIndex]->Name();
}
}
int GetDataInputIndex() const {
return kSliceDataInputIndex;
}
int GetOutputIndex() const {
return kSliceOutputIndex;
}
Node* node_ptr; // The Gather/GatherND node that triggers the optimization search.
bool is_scalar_slice; // whether the slice is a scalar, if it is, after Gather, rank will be reduced by 1.
std::string axis_attr_name;
int non_negative_axis; // The axis to slice on
std::string entry_slice_arg_name;
int input_rank; // rank of the Gather data input tensor
// The dimension of the output tensor on the slicing axis
// Be noted: if it is a scalar slicing, this dim will not be set, which means, afterward when use it to update
// shapes, that dim at axis will be removed.
ONNX_NAMESPACE::TensorShapeProto_Dimension output_dim_on_axis;
};
/**
* @brief Base class for all pass through actors.
*
* Each actors defines rules to determine whether a node can be passed through, and how to do the pass through.
* PreCheck is the interface to check whether a node can be passed through.
* The pass through is done transparently, without any interface required to implemented.
* PostProcess is the interface to do some adaptor work after the pass through.
*/
class OperatorPassThroughActorBase {
public:
OperatorPassThroughActorBase() = default;
virtual ~OperatorPassThroughActorBase() = default;
/**
* @brief Check whether a node can be passed through.
* At this point, graph modification is not started, once we see any clues that this node cannot be passed through,
* We should return false immediately.
*
* @param graph The graph that the node belongs to.
* @param current_node The node to be checked.
* @param info The slicing info of the Gather/GatherND node.
* @param allowed_input_indices The input indices explicitly specified of the current_node that are allowed to do pass
* through.
* @param propagate_input_config: Used as a return value - a map of input index to new slice axis.
* The key is an integer, which is the index of the input of the current_node.
* The value is an integer, which is the new axis index after the pass through on the corresponding input.
* For example:
* > if the current_node is a Add node, and the slice axe is 1, then the corresponding input should
* also have axis 1 when we move the slice to the input.
* > if the current_node is a Transpose (perm=[1, 0, 2]) node, and the slice
* axis is 1, then the new axis for the input should be 0.
* @param input_has_dim_1_for_axis: Used as a return value - a bool indicates whether any of current_node' input
* has dim 1 on the slice axis.
*/
virtual bool PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info,
const std::vector<int>& allowed_input_indices,
const logging::Logger& logger,
std::unordered_map<int, int>& propagate_input_config,
bool& input_has_dim_1_for_axis) = 0;
/**
* @brief After slice op pass through all inputs, do some post process work.
*
* Be noted: at this point, slice op is already removed, so we cannot access SliceInfo any more, instead,
* we pass important infos including slice_axis, input_rank, is_scalar_slice, etc as parameters of this function.
*
* @param graph The graph that the node belongs to.
* @param current_node The node that has been passed through.
* @param current_node_output_index The output index of the current_node connecting to slice op.
* @param slice_axis slice axis of the slice op.
* @param is_slice_scalar whether the slice is a scalar.
* @param input_has_dim_1_for_axis whether any of current_node's inputs has dim 1 on the slice axis.
* @param output_dim_on_axis dimension of the slice op's output tensor on the slice axis.
* @param entry_node_name name of entry node that trigger the pass through search, for naming only.
* @param new_gather_infos new gather infos that are generated during the pass through for current_node's inputs.
* @param logger
* @return
*/
virtual bool PostProcess(Graph& graph, Node& current_node, int current_node_output_index,
int slice_axis, bool is_slice_scalar, bool input_has_dim_1_for_axis,
const ONNX_NAMESPACE::TensorShapeProto_Dimension& output_dim_on_axis,
const std::string& entry_node_name,
const std::unordered_map<int, SliceInfo>& new_gather_infos,
const logging::Logger& logger) = 0;
};
class DefaultOperatorPassThroughActorBase : public OperatorPassThroughActorBase {
public:
DefaultOperatorPassThroughActorBase() = default;
~DefaultOperatorPassThroughActorBase() = default;
bool PreCheck(const Graph&, const Node&, const SliceInfo&, const std::vector<int>&, const logging::Logger&,
std::unordered_map<int, int>&, bool&) override {
return true;
};
bool PostProcess(Graph& graph, Node& current_node, int current_node_output_index,
int slice_axis, bool is_slice_scalar, bool input_has_dim_1_for_axis,
const ONNX_NAMESPACE::TensorShapeProto_Dimension& output_dim_on_axis,
const std::string& entry_node_name,
const std::unordered_map<int, SliceInfo>& new_gather_infos,
const logging::Logger& logger) override;
};
class SimplePassThroughActor : public DefaultOperatorPassThroughActorBase {
public:
SimplePassThroughActor() = default;
~SimplePassThroughActor() = default;
bool PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info,
const std::vector<int>& allowed_input_indices,
const logging::Logger& logger,
std::unordered_map<int, int>& propagate_input_config,
bool& input_has_dim_1_for_axis) override;
};
class ReductionOpPassThroughActor : public SimplePassThroughActor {
public:
ReductionOpPassThroughActor() = default;
~ReductionOpPassThroughActor() = default;
bool PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info,
const std::vector<int>& allowed_input_indices,
const logging::Logger& logger,
std::unordered_map<int, int>& propagate_input_config,
bool& input_has_dim_1_for_axis) override;
};
class ReshapePassThroughActor : public DefaultOperatorPassThroughActorBase {
public:
ReshapePassThroughActor() = default;
~ReshapePassThroughActor() = default;
bool PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info,
const std::vector<int>& allowed_input_indices,
const logging::Logger& logger,
std::unordered_map<int, int>& propagate_input_config,
bool& input_has_dim_1_for_axis) override;
// Once slice node is passed through, we need to update the shape accordingly.
bool PostProcess(Graph& graph, Node& current_node, int current_node_output_index,
int slice_axis, bool is_slice_scalar, bool input_has_dim_1_for_axis,
const ONNX_NAMESPACE::TensorShapeProto_Dimension& output_dim_on_axis,
const std::string& entry_node_name,
const std::unordered_map<int, SliceInfo>& new_gather_infos,
const logging::Logger& logger) override;
};
class TransposePassThroughActor : public DefaultOperatorPassThroughActorBase {
public:
TransposePassThroughActor() = default;
~TransposePassThroughActor() = default;
bool PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info,
const std::vector<int>& allowed_input_indices,
const logging::Logger& logger,
std::unordered_map<int, int>& propagate_input_config,
bool& input_has_dim_1_for_axis) override;
// If scalar slice happens, we need adapt the input, otherwise the perm cannot be matched.
bool PostProcess(Graph& graph, Node& current_node, int current_node_output_index,
int slice_axis, bool is_slice_scalar, bool input_has_dim_1_for_axis,
const ONNX_NAMESPACE::TensorShapeProto_Dimension& output_dim_on_axis,
const std::string& entry_node_name,
const std::unordered_map<int, SliceInfo>& new_gather_infos,
const logging::Logger& logger) override;
};
class MatMulPassThroughActor : public DefaultOperatorPassThroughActorBase {
public:
MatMulPassThroughActor() = default;
~MatMulPassThroughActor() = default;
// Check which inputs can be propagated according to the slice axis.
bool PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info,
const std::vector<int>& allowed_input_indices,
const logging::Logger& logger,
std::unordered_map<int, int>& propagate_input_config,
bool& input_has_dim_1_for_axis) override;
// If scalar slice happens in the second last dimension, we need to adapt the input.
bool PostProcess(Graph& graph, Node& current_node, int current_node_output_index,
int slice_axis, bool is_slice_scalar, bool input_has_dim_1_for_axis,
const ONNX_NAMESPACE::TensorShapeProto_Dimension& output_dim_on_axis,
const std::string& entry_node_name,
const std::unordered_map<int, SliceInfo>& new_gather_infos,
const logging::Logger& logger) override;
};
/**
* @brief Update the dim value using given new dim value at specified axis.
*
* @param arg_to_update The NodeArg to be updated.
* @param reverse_axis A negative axis MUST be given here. This is to make sure if arg_to_update has less rank
* than expected value, the update will be ignored.
* @param output_dim_on_axis New dim value to be updated.
* @return true if the update is done.
*/
bool UpdateSliceOutputShape(NodeArg& arg_to_update, int reverse_axis,
const ONNX_NAMESPACE::TensorShapeProto_Dimension& new_dim_value);
/**
* @brief Insert a new node to the graph,
* 1. taking dest_node.input[dest_input_index] as the input of the new node.
* 2. remove connection of dest_node and it's dest_input_index-th input producer node.
* 3. connect the new node and dest_node.
*
* Original graph:
* Node A
* / \
* A-output-0 A-output-1
* \ B-input-1
* \ /
* Node B
* |
*
* dest_node = Node B
* dest_input_index = 0
* op_type = C
*
* After inserting the new node:
* Node A
* / \
* A-output-0 A-output-1
* \
* Node C
* / \
* c-output-0 C-output-1 B-input-1
* \ /
* Node B
* |
* @param graph Graph to insert the new node.
* @param dest_node The node to insert the new node before.
* @param dest_in_index The input index of the dest_node to insert the new node before.
* @param new_node_output_index The output index of the new node to connect to the dest_node.
* @param name The name of the new node.
* @param op_type The op_type of the new node.
* @param description The description of the new node.
* @param input_args The input args of the new node. At least one of the input args should be the
* dest_node's dest_in_index-th input arg.
* @param attributes The attributes of the new node.
* @param domain The domain of the new node.
* @param logger The logger.
* @return
*/
Node* InsertIntermediateNodeOnDestInput(Graph& graph,
Node& dest_node, int dest_in_index,
int new_node_input_index,
int new_node_output_index,
const std::string& name, const std::string& op_type,
const std::string& description,
const InlinedVector<NodeArg*>& input_args,
const InlinedVector<NodeArg*>& output_args,
const onnxruntime::NodeAttributes& attributes,
const std::string& domain,
const logging::Logger& logger);
/**
* @brief Insert adaptor nodes for the inputs and output, to make sure they remain the same rank, when scalar slicing
* is done.
*
* Be noted: at this point, slice node already been removed.
*
* @param graph Graph to insert the adaptor nodes.
* @param current_node For whom to insert the adaptor nodes.
* @param slice_axis The axis of the slice node.
* @param entry_node_name Then name of the entry slice node, used for naming only.
* @param new_gather_infos Populated slicing infos for current_node's inputs.
* @param target_node_output_index output_index of current_node's output, connecting to the slice node.
* @param logger Logger.
*/
void AdaptInputAndOutputForScalarSlice(Graph& graph, Node& current_node, int current_node_output_index,
int slice_axis, const std::string& entry_node_name,
const std::unordered_map<int, SliceInfo>& new_gather_infos,
const logging::Logger& logger);
} // namespace onnxruntime::optimizer::compute_optimizer
#endif

View file

@ -307,7 +307,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
#ifdef ENABLE_TRAINING
// Put memory optimization transformer at last (which is done after most of fusions are done) by intention.
// Known issue: after mmeory optimization is completed, if some fusion happens, it is possible that the
// Known issue: after memory optimization is completed, if some fusion happens, it is possible that the
// node priority got changed. This may disorder the execution order of nodes to recompute.
// TODO(pengwa): need to fix this issue.
const std::string enable_memory_optimizer =
@ -329,7 +329,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<NhwcTransformer>(std::move(cpu_allocator)));
// NCHWCtransformer should have a higher priority versus this. Because NCHWCtransformer also do the similar things
// of fusion patterns and target on CPU. However, NCHWCtransformer will reorder the layout to nchwc which is only available for
// x86-64 cpu, not edge cpu like arm. But This tranformer could be used by opencl-ep/cpu-ep. So
// x86-64 cpu, not edge cpu like arm. But This transformer could be used by opencl-ep/cpu-ep. So
// we will prefer NhwcTransformer once ort runs on x86-64 CPU, otherwise ConvAddActivationFusion is enabled.
// PR #6351 implemented similar fusion-pattern for CUDA only, and can only fuse conv-add-relu,
// while we can fuse more activation.

File diff suppressed because it is too large Load diff

View file

@ -24,7 +24,6 @@
#include "core/optimizer/bias_softmax_fusion.h"
#include "core/optimizer/cast_elimination.h"
#include "core/optimizer/common_subexpression_elimination.h"
#include "core/optimizer/computation_reduction.h"
#include "core/optimizer/concat_slice_elimination.h"
#include "core/optimizer/constant_folding.h"
#include "core/optimizer/constant_sharing.h"
@ -5069,199 +5068,6 @@ TEST_F(GraphTransformationTests, MatMulIntegerToFloatTest) {
#endif
// LayerNormalization implementation is in contrib namespace (OnnxDomain 1), so
// Without contib_ops enabled, we cannot parse the graph correctly.
#ifndef DISABLE_CONTRIB_OPS
// We used Opset 12 for testing to make sure we are not using GatherND OnnxDomain Opset 1.
static void GatherNDComputationReductionTest(const std::string op_type, logging::Logger& logger) {
std::string op_type_lower = op_type;
std::transform(op_type_lower.begin(), op_type_lower.end(), op_type_lower.begin(), [](unsigned char c) { return std::tolower(c); });
std::string file_path = std::string("testdata/transform/computation_reduction/gathernd_") + op_type_lower + std::string(".onnx");
std::shared_ptr<Model> model;
ASSERT_STATUS_OK(Model::Load(ToPathString(file_path), model, nullptr, logger));
Graph& graph = model->MainGraph();
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
onnxruntime::GraphTransformerManager graph_transformation_mgr{1};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<ComputationReductionTransformer>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, logger));
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
Node* gathernd_node = nullptr;
for (auto node_index : node_topology_list) {
Node* p_node = graph.GetNode(node_index);
ASSERT_FALSE(p_node == nullptr);
if (p_node->OpType().compare("GatherND") == 0) {
gathernd_node = p_node;
EXPECT_EQ(gathernd_node->MutableInputDefs()[0]->Name(), "input");
const auto& consumers = graph.GetConsumerNodes(gathernd_node->MutableOutputDefs()[0]->Name());
EXPECT_EQ(consumers[0]->OpType(), op_type);
}
}
ASSERT_FALSE(gathernd_node == nullptr);
}
TEST_F(GraphTransformationTests, ComputationReductionTransformer_GatherND_Gelu) {
GatherNDComputationReductionTest("Gelu", *logger_);
}
TEST_F(GraphTransformationTests, ComputationReductionTransformer_GatherND_Add) {
GatherNDComputationReductionTest("Add", *logger_);
}
TEST_F(GraphTransformationTests, ComputationReductionTransformer_GatherND_LayerNormalization) {
GatherNDComputationReductionTest("LayerNormalization", *logger_);
}
TEST_F(GraphTransformationTests, ComputationReductionTransformer_GatherND_MatMul) {
GatherNDComputationReductionTest("MatMul", *logger_);
}
static void RunGatherNDE2EGraph(std::vector<OrtValue>& run_results, const PathString& model_uri,
const std::string session_log_id, const std::string& provider_type,
const std::vector<int64_t>& dims_input,
const std::vector<float>& input_values,
const std::vector<int64_t>& dims_unsqueezed_masked_lm_positions,
const std::vector<int64_t>& values_unsqueezed_masked_lm_positions) {
SessionOptions so;
// we don't want any transformation here.
so.graph_optimization_level = TransformerLevel::Default;
so.session_logid = session_log_id;
InferenceSession session_object{so, GetEnvironment()};
std::unique_ptr<IExecutionProvider> execution_provider;
if (provider_type == onnxruntime::kCpuExecutionProvider)
execution_provider = DefaultCpuExecutionProvider();
else if (provider_type == onnxruntime::kCudaExecutionProvider)
execution_provider = DefaultCudaExecutionProvider();
else if (provider_type == onnxruntime::kRocmExecutionProvider)
execution_provider = DefaultRocmExecutionProvider();
EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
Status st;
ASSERT_TRUE((st = session_object.Load(model_uri)).IsOK()) << st;
ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st;
OrtValue input1;
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_input, input_values, &input1);
OrtValue input2;
CreateMLValue<int64_t>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_unsqueezed_masked_lm_positions,
values_unsqueezed_masked_lm_positions, &input2);
NameMLValMap feeds;
feeds.insert(std::make_pair("input", input1));
feeds.insert(std::make_pair("unsqueezed_masked_lm_positions", input2));
// prepare outputs
std::vector<std::string> output_names;
output_names.push_back("output");
output_names.push_back("gather_output");
// Now run
RunOptions run_options;
st = session_object.Run(run_options, feeds, output_names, &run_results);
EXPECT_TRUE(st.IsOK());
}
TEST_F(GraphTransformationTests, ComputationReductionTransformer_GatherND_E2E) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "computation_reduction/e2e.onnx";
std::shared_ptr<Model> model;
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
Graph& graph = model->MainGraph();
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<ComputationReductionTransformer>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
// check the expected node orders.
{
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
Node* gathernd_node = nullptr;
for (auto node_index : node_topology_list) {
Node* p_node = graph.GetNode(node_index);
ASSERT_FALSE(p_node == nullptr);
if (p_node->OpType().compare("GatherND") == 0) {
gathernd_node = p_node;
const Node* layer_norm_node = graph.GetProducerNode(gathernd_node->MutableInputDefs()[0]->Name());
EXPECT_EQ(layer_norm_node->OpType(), "LayerNormalization");
EXPECT_EQ(layer_norm_node->Name(), "layer_norm_1");
const auto& consumers = graph.GetConsumerNodes(gathernd_node->MutableOutputDefs()[0]->Name());
EXPECT_EQ(consumers[0]->OpType(), "MatMul");
EXPECT_EQ(consumers[0]->Name(), "matmul_1");
break;
}
}
ASSERT_FALSE(gathernd_node == nullptr);
}
// check result diff after the re-order
auto new_model_uri = "computation_reduction_transformer_after.onnx";
ASSERT_STATUS_OK(Model::Save(*model, new_model_uri));
float scale = 1.f;
float mean = 0.f;
float seed = 123.f;
std::default_random_engine generator_float{gsl::narrow_cast<uint32_t>(seed)};
std::normal_distribution<float> distribution_float{mean, scale};
int batch_size = 8;
int sequence = 128;
int hidden_size = 128;
int dynamic_predict_count = 20;
const std::vector<int64_t> dims_input = {batch_size, sequence, hidden_size};
std::vector<float> input_values(TensorShape(dims_input).Size());
std::for_each(input_values.begin(), input_values.end(),
[&generator_float, &distribution_float](float& value) { value = distribution_float(generator_float); });
const std::vector<int64_t> dims_unsqueezed_masked_lm_positions = {batch_size, dynamic_predict_count, 1};
std::vector<int64_t> values_unsqueezed_masked_lm_positions(TensorShape(dims_unsqueezed_masked_lm_positions).Size());
std::random_device rd; // obtain a random number from hardware
std::mt19937 eng(rd()); // seed the generator
std::uniform_int_distribution<> distr(0, sequence - 1); // define the range
std::for_each(values_unsqueezed_masked_lm_positions.begin(), values_unsqueezed_masked_lm_positions.end(),
[&distr, &eng](int64_t& value) { value = distr(eng); });
static const std::string all_provider_types[] = {
onnxruntime::kCpuExecutionProvider,
#ifdef USE_CUDA
onnxruntime::kCudaExecutionProvider,
#elif USE_ROCM
onnxruntime::kRocmExecutionProvider,
#endif
};
for (auto& provider_type : all_provider_types) {
std::vector<OrtValue> expected_ort_values;
RunGatherNDE2EGraph(expected_ort_values, model_uri, std::string("RawGraphRun"), provider_type,
dims_input, input_values, dims_unsqueezed_masked_lm_positions,
values_unsqueezed_masked_lm_positions);
std::vector<OrtValue> actual_ort_values;
RunGatherNDE2EGraph(actual_ort_values, ToPathString(new_model_uri), std::string("OptimizedGraphRun"), provider_type,
dims_input, input_values, dims_unsqueezed_masked_lm_positions,
values_unsqueezed_masked_lm_positions);
ASSERT_TRUE(expected_ort_values.size() == actual_ort_values.size());
constexpr double per_sample_tolerance = 1e-4;
constexpr double relative_per_sample_tolerance = 1e-4;
for (size_t i = 0; i < expected_ort_values.size(); i++) {
auto ret = CompareOrtValue(actual_ort_values[i], expected_ort_values[i],
per_sample_tolerance, relative_per_sample_tolerance, false);
EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second;
}
}
}
#endif
#ifndef DISABLE_CONTRIB_OPS
template <typename GraphTransformationCheckFn, typename GraphPreprocessFn>
static void TestMatMulScaleFusion(

View file

@ -0,0 +1,65 @@
import onnx
from onnx import OperatorSetIdProto, TensorProto, helper
def _create_model_proto(output_shapes, axis_to_gather, slice_dims, slices_values, model_name):
# inputs and outputs
hidden = 1024
inputs = [
helper.make_tensor_value_info("input1", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden]),
helper.make_tensor_value_info("input2", TensorProto.FLOAT, ["batch_size", hidden, "sequence_length"]),
]
outputs = [
helper.make_tensor_value_info("final_output", TensorProto.FLOAT, output_shapes),
]
# initializers
initializers = [
helper.make_tensor("slices", TensorProto.INT64, slice_dims, slices_values),
]
# nodes
nodes = [
helper.make_node("MatMul", ["input1", "input2"], ["m1_out"], "m1"),
helper.make_node("Gather", ["m1_out", "slices"], ["gather_out"], "gather", axis=axis_to_gather),
helper.make_node("Identity", ["gather_out"], ["final_output"], "identity1"),
]
# Create the graph (GraphProto)
graph_def = helper.make_graph(
nodes,
"test-model",
inputs,
outputs,
initializers,
"doc string",
)
opsets = []
onnxdomain = OperatorSetIdProto()
onnxdomain.version = 14
onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator
# set that is defined as part of the ONNX specification.
opsets.append(onnxdomain)
msdomain = OperatorSetIdProto()
msdomain.version = 1
msdomain.domain = "com.microsoft"
opsets.append(msdomain)
kwargs = {}
kwargs["opset_imports"] = opsets
model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs)
final_model = onnx.shape_inference.infer_shapes(model_def)
onnx.save(final_model, model_name + ".onnx")
_create_model_proto(["sequence_length", "sequence_length"], 0, [], [1], "gather_matmul_scalar_batch_dim")
_create_model_proto([1, "sequence_length", "sequence_length"], 0, [1], [1], "gather_matmul_batch_dim")
_create_model_proto(["batch_size", "sequence_length"], 1, [], [1], "gather_matmul_scalar_second_last_dim")
_create_model_proto(["batch_size", 1, "sequence_length"], 1, [1], [1], "gather_matmul_second_last_dim")
_create_model_proto(["batch_size", "sequence_length"], 2, [], [1], "gather_matmul_scalar_last_dim")
_create_model_proto(["batch_size", "sequence_length", 1], 2, [1], [1], "gather_matmul_last_dim")

View file

@ -0,0 +1,89 @@
import onnx
from onnx import OperatorSetIdProto, TensorProto, helper
hidden = 1024
head = 16
def _create_model_proto(
input_shapes, output_shapes, axis_to_gather, slice_dims, slices_values, shape_dims, shape_values, model_name
):
# inputs and outputs
inputs = [
helper.make_tensor_value_info("input1", TensorProto.FLOAT, input_shapes),
]
outputs = [
helper.make_tensor_value_info("final_output", TensorProto.FLOAT, output_shapes),
]
# initializers
initializers = [
helper.make_tensor("shape", TensorProto.INT64, shape_dims, shape_values),
helper.make_tensor("slices", TensorProto.INT64, slice_dims, slices_values),
]
# nodes
nodes = [
helper.make_node("Reshape", ["input1", "shape"], ["reshape_out"], "reshape1"),
helper.make_node("Gather", ["reshape_out", "slices"], ["gather_out"], "gather", axis=axis_to_gather),
helper.make_node("Identity", ["gather_out"], ["final_output"], "identity1"),
]
# Create the graph (GraphProto)
graph_def = helper.make_graph(
nodes,
"test-model",
inputs,
outputs,
initializers,
"doc string",
)
opsets = []
onnxdomain = OperatorSetIdProto()
onnxdomain.version = 14
onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator
# set that is defined as part of the ONNX specification.
opsets.append(onnxdomain)
msdomain = OperatorSetIdProto()
msdomain.version = 1
msdomain.domain = "com.microsoft"
opsets.append(msdomain)
kwargs = {}
kwargs["opset_imports"] = opsets
model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs)
final_model = onnx.shape_inference.infer_shapes(model_def)
onnx.save(final_model, model_name + ".onnx")
input_shapes1 = ["batch_size", "sequence_length", hidden]
_create_model_proto(
input_shapes1, ["sequence_length", 16, 64], 0, [], [1], [4], [0, 0, 16, 64], "gather_reshape_scalar_batch_dim"
)
_create_model_proto(
input_shapes1, [1, "sequence_length", 16, 64], 0, [1], [1], [4], [0, 0, 16, 64], "gather_reshape_batch_dim"
)
_create_model_proto(
input_shapes1, ["batch_size", 16, 64], 1, [], [1], [4], [0, 0, 16, 64], "gather_reshape_scalar_seqlen_dim"
)
_create_model_proto(
input_shapes1, ["batch_size", 1, 16, 64], 1, [1], [1], [4], [0, 0, 16, 64], "gather_reshape_seqlen_dim"
)
input_shapes2 = ["batch_size", 128, hidden]
_create_model_proto(
input_shapes2,
["batch_size", 31, 16, 64],
1,
[31],
[i for i in range(31)],
[4],
[0, 128, 16, 64],
"gather_reshape_seqlen_dim2",
)

View file

@ -0,0 +1,211 @@
import onnx
from onnx import OperatorSetIdProto, TensorProto, helper
# inputs and outputs
hidden = 1024
head = 16
inputs = [
helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden]),
helper.make_tensor_value_info("attention_mask", TensorProto.INT64, ["batch_size", "sequence_length"]),
helper.make_tensor_value_info("matmul1.weight", TensorProto.FLOAT16, [hidden, 1024]),
helper.make_tensor_value_info("add1.bias", TensorProto.FLOAT16, [hidden]),
helper.make_tensor_value_info("matmul2.weight", TensorProto.FLOAT16, [hidden, 1024]),
helper.make_tensor_value_info("add2.bias", TensorProto.FLOAT16, [hidden]),
helper.make_tensor_value_info("matmul3.weight", TensorProto.FLOAT16, [hidden, 1024]),
helper.make_tensor_value_info("add3.bias", TensorProto.FLOAT16, [hidden]),
helper.make_tensor_value_info("matmul4.weight", TensorProto.FLOAT16, [hidden, 1024]),
helper.make_tensor_value_info("add4.bias", TensorProto.FLOAT16, [hidden]),
helper.make_tensor_value_info("layer_norm1.weight", TensorProto.FLOAT, [hidden]),
helper.make_tensor_value_info("layer_norm1.bias", TensorProto.FLOAT, [hidden]),
helper.make_tensor_value_info("matmul7.weight", TensorProto.FLOAT16, [hidden, hidden * 4]),
helper.make_tensor_value_info("add7.bias", TensorProto.FLOAT16, [hidden * 4]),
helper.make_tensor_value_info("matmul8.weight", TensorProto.FLOAT16, [hidden * 4, hidden]),
helper.make_tensor_value_info("add8.bias", TensorProto.FLOAT16, [hidden]),
helper.make_tensor_value_info("layer_norm2.weight", TensorProto.FLOAT, [hidden]),
helper.make_tensor_value_info("layer_norm2.bias", TensorProto.FLOAT, [hidden]),
]
outputs = [
helper.make_tensor_value_info("final_output", TensorProto.FLOAT, ["batch_size", hidden]),
]
# initializers
initializers = [
helper.make_tensor("scalar_float_0.1", TensorProto.FLOAT, [], [0.1]),
helper.make_tensor("scalar_float_0", TensorProto.FLOAT, [], [0.0]),
helper.make_tensor("scalar_float16_8", TensorProto.FLOAT16, [], [8]),
helper.make_tensor("scalar_bool_true", TensorProto.BOOL, [], [1]),
helper.make_tensor("scalar_float_1", TensorProto.FLOAT, [], [1]),
helper.make_tensor("scalar_float_big_num", TensorProto.FLOAT, [], [-3.4028234663852886e38]),
helper.make_tensor("scalar_int_0", TensorProto.INT64, [], [0]),
helper.make_tensor("scalar_int_1", TensorProto.INT64, [], [1]),
helper.make_tensor("single_value_1d_int_0", TensorProto.INT64, [1], [0]),
helper.make_tensor("single_value_1d_int_1", TensorProto.INT64, [1], [1]),
helper.make_tensor("single_value_1d_int_2", TensorProto.INT64, [1], [2]),
helper.make_tensor("single_value_1d_int_16", TensorProto.INT64, [1], [head]),
helper.make_tensor("single_value_1d_int_64", TensorProto.INT64, [1], [hidden // head]),
helper.make_tensor("single_value_1d_int_1024", TensorProto.INT64, [1], [hidden]),
helper.make_tensor("shape1", TensorProto.INT64, [4], [0, 0, 16, 64]),
helper.make_tensor("shape2", TensorProto.INT64, [4], [0, 0, 16, 64]),
helper.make_tensor("shape3", TensorProto.INT64, [4], [0, 0, 16, 64]),
helper.make_tensor("shape4", TensorProto.INT64, [3], [0, 0, 1024]),
]
# nodes
nodes = [
helper.make_node("Dropout", ["input", "scalar_float_0", "scalar_bool_true"], ["d1_out", "d1_mask"], "d1"),
helper.make_node("Cast", ["d1_out"], ["c1_out"], name="c1", to=10),
# attention
## left branch
helper.make_node("MatMul", ["c1_out", "matmul1.weight"], ["m1_out"], "m1"),
helper.make_node("Add", ["add1.bias", "m1_out"], ["a1_out"], "a1"),
helper.make_node("Reshape", ["a1_out", "shape1"], ["reshape1_out"], "reshape1"),
helper.make_node("Transpose", ["reshape1_out"], ["transpose1_out"], name="transpose1", perm=[0, 2, 1, 3]),
## middle branch
helper.make_node("MatMul", ["c1_out", "matmul2.weight"], ["m2_out"], "m2"),
helper.make_node("Add", ["add2.bias", "m2_out"], ["a2_out"], "a2"),
helper.make_node("Reshape", ["a2_out", "shape2"], ["reshape2_out"], "reshape2"),
helper.make_node("Transpose", ["reshape2_out"], ["transpose2_out"], name="transpose2", perm=[0, 2, 1, 3]),
## right banch
helper.make_node("MatMul", ["c1_out", "matmul3.weight"], ["m3_out"], "m3"),
helper.make_node("Add", ["add3.bias", "m3_out"], ["a3_out"], "a3"),
helper.make_node("Reshape", ["a3_out", "shape3"], ["reshape3_out"], "reshape3"),
helper.make_node("Transpose", ["reshape3_out"], ["transpose3_out"], name="transpose3", perm=[0, 2, 3, 1]),
## middle branch result computes with right branch result
helper.make_node("MatMul", ["transpose2_out", "transpose3_out"], ["m4_out"], "m4"),
helper.make_node("Div", ["m4_out", "scalar_float16_8"], ["div1_out"], "div1"),
helper.make_node("Cast", ["div1_out"], ["c2_out"], name="c2", to=1),
helper.make_node("Unsqueeze", ["attention_mask", "single_value_1d_int_1"], ["unsqueeze7_out"], "unsqueeze7"),
helper.make_node("Unsqueeze", ["unsqueeze7_out", "single_value_1d_int_2"], ["unsqueeze8_out"], "unsqueeze8"),
helper.make_node("Cast", ["unsqueeze8_out"], ["c3_out"], name="c3", to=1),
helper.make_node("Sub", ["scalar_float_1", "c3_out"], ["sub1_out"], "sub1"),
helper.make_node("Mul", ["sub1_out", "scalar_float_big_num"], ["mul1_out"], "mul1"),
helper.make_node("Add", ["mul1_out", "c2_out"], ["a4_out"], "a4"),
helper.make_node("Softmax", ["a4_out"], ["softmax1_out"], "softmax1", axis=-1),
helper.make_node("Dropout", ["softmax1_out", "scalar_float_0", "scalar_bool_true"], ["d2_out", "d2_mask"], "d2"),
helper.make_node("Cast", ["d2_out"], ["c4_out"], name="c4", to=10),
## left branch result computes with result of `middle branch result computes with right branch result``
helper.make_node("MatMul", ["c4_out", "transpose1_out"], ["m5_out"], "m5"),
helper.make_node("Transpose", ["m5_out"], ["tranpose4_out"], name="tranpose4", perm=[0, 2, 1, 3]),
helper.make_node("Reshape", ["tranpose4_out", "shape4"], ["reshape4_out"], "reshape4"),
## attention output
helper.make_node("MatMul", ["reshape4_out", "matmul4.weight"], ["m6_out"], "m6"),
helper.make_node("Add", ["add4.bias", "m6_out"], ["a5_out"], "a5"),
helper.make_node("Dropout", ["a5_out", "scalar_float_0", "scalar_bool_true"], ["d4_out", 'd4_mask"'], "d4"),
helper.make_node("Cast", ["d4_out"], ["c5_out"], name="c5", to=1),
helper.make_node("Add", ["d1_out", "c5_out"], ["a6_out"], "a6"),
# MLP
helper.make_node(
"LayerNormalization",
["a6_out", "layer_norm1.weight", "layer_norm1.bias"],
["layernorm1_out", "layernorm1_mean", "layernorm1_var"],
"layernorm1",
axis=-1,
epsion=0.000009999999747378752,
),
helper.make_node("Cast", ["layernorm1_out"], ["c6_out"], name="c6", to=10),
helper.make_node("MatMul", ["c6_out", "matmul7.weight"], ["m7_out"], "m7"),
helper.make_node("BiasGelu", ["m7_out", "add7.bias"], ["biasgelu1_out"], "biasgelu1", domain="com.microsoft"),
helper.make_node("MatMul", ["biasgelu1_out", "matmul8.weight"], ["m8_out"], "m8"),
helper.make_node("Add", ["add8.bias", "m8_out"], ["a7_out"], "a7"),
helper.make_node("Dropout", ["a7_out", "scalar_float_0", "scalar_bool_true"], ["d5_out", "d5_mask"], "d5"),
helper.make_node("Cast", ["d5_out"], ["c7_out"], name="c7", to=1),
helper.make_node("Add", ["layernorm1_out", "c7_out"], ["a8_out"], "a8"),
helper.make_node(
"LayerNormalization",
["a8_out", "layer_norm2.weight", "layer_norm2.bias"],
["layernorm2_out", "layernorm2_mean", "layernorm2_var"],
"layernorm2",
axis=-1,
epsion=0.000009999999747378752,
),
helper.make_node("Gather", ["layernorm2_out", "scalar_int_0"], ["final_gather_out"], "final_gather", axis=1),
helper.make_node(
"Dropout", ["final_gather_out", "scalar_float_0", "scalar_bool_true"], ["final_output", "d6_mask"], "d6"
),
]
# Shapes that cannot be inferred by onnx shape inference
value_infos = [
helper.make_value_info(
name="reshape1_out",
type_proto=helper.make_tensor_type_proto(
elem_type=TensorProto.FLOAT16, shape=["batch_size", "sequence_length", head, hidden // head]
),
),
helper.make_value_info(
name="reshape2_out",
type_proto=helper.make_tensor_type_proto(
elem_type=TensorProto.FLOAT16, shape=["batch_size", "sequence_length", head, hidden // head]
),
),
helper.make_value_info(
name="reshape3_out",
type_proto=helper.make_tensor_type_proto(
elem_type=TensorProto.FLOAT16, shape=["batch_size", "sequence_length", head, hidden // head]
),
),
helper.make_value_info(
name="reshape4_out",
type_proto=helper.make_tensor_type_proto(
elem_type=TensorProto.FLOAT16, shape=["batch_size", "sequence_length", hidden]
),
),
helper.make_value_info(
name="layernorm1_out",
type_proto=helper.make_tensor_type_proto(
elem_type=TensorProto.FLOAT, shape=["batch_size", "sequence_length", hidden]
),
),
helper.make_value_info(
name="layernorm2_out",
type_proto=helper.make_tensor_type_proto(
elem_type=TensorProto.FLOAT, shape=["batch_size", "sequence_length", hidden]
),
),
helper.make_value_info(
name="concattraining4_out",
type_proto=helper.make_tensor_type_proto(elem_type=TensorProto.INT64, shape=[3]),
),
helper.make_value_info(
name="biasgelu1_out",
type_proto=helper.make_tensor_type_proto(
elem_type=TensorProto.FLOAT16, shape=["batch_size", "sequence_length", hidden * 4]
),
),
]
# Create the graph (GraphProto)
graph_def = helper.make_graph(
nodes,
"test-model",
inputs,
outputs,
initializers,
"doc string",
value_infos,
)
opsets = []
onnxdomain = OperatorSetIdProto()
onnxdomain.version = 14
onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator
# set that is defined as part of the ONNX specification.
opsets.append(onnxdomain)
msdomain = OperatorSetIdProto()
msdomain.version = 1
msdomain.domain = "com.microsoft"
opsets.append(msdomain)
kwargs = {}
kwargs["opset_imports"] = opsets
model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs)
final_model = onnx.shape_inference.infer_shapes(model_def)
onnx.save(final_model, "gather_roberta_e2e.onnx")

View file

@ -1,8 +1,8 @@
import numpy as np
import onnx
from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper
from onnx import OperatorSetIdProto, TensorProto, helper, numpy_helper
vocab_size = 256 # 30258
vocab_size = 256
X = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", "seqlen", 128])
unsqueezed_masked_lm_positions = helper.make_tensor_value_info(
@ -127,8 +127,9 @@ graph_def = helper.make_graph(
opsets = []
onnxdomain = OperatorSetIdProto()
onnxdomain.version = 12
onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
onnxdomain.version = 14
onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator
# set that is defined as part of the ONNX specification.
opsets.append(onnxdomain)
msdomain = OperatorSetIdProto()
@ -141,4 +142,18 @@ kwargs["opset_imports"] = opsets
model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs)
onnx.save(model_def, "e2e.onnx")
ln1_value_info = model_def.graph.value_info.add()
ln1_value_info.CopyFrom(X)
ln1_value_info.name = "layer_norm1"
ln2_value_info = model_def.graph.value_info.add()
ln2_value_info.CopyFrom(X)
ln2_value_info.name = "layer_norm2"
gelu1_value_info = model_def.graph.value_info.add()
gelu1_value_info.CopyFrom(X)
gelu1_value_info.name = "gelu1"
final_model = onnx.shape_inference.infer_shapes(model_def)
onnx.save(final_model, "e2e.onnx")

View file

@ -36,12 +36,15 @@ nodes.append(add2)
gathernd2 = helper.make_node(
"GatherND",
["add_2", "unsqueezed_masked_lm_positions"],
["output2"],
["gathernd_out"],
name="gathernd_2",
batch_dims=1,
)
nodes.append(gathernd2)
identity = helper.make_node("Identity", ["gathernd_out"], ["output2"], name="identity")
nodes.append(identity)
graph_def = helper.make_graph(
nodes,
"test-model",

View file

@ -36,12 +36,15 @@ nodes.append(div2)
gathernd2 = helper.make_node(
"GatherND",
["div_2", "unsqueezed_masked_lm_positions"],
["output2"],
["gathernd_out"],
name="gathernd_2",
batch_dims=1,
)
nodes.append(gathernd2)
identity = helper.make_node("Identity", ["gathernd_out"], ["output2"], name="identity")
nodes.append(identity)
graph_def = helper.make_graph(
nodes,
"test-model",

View file

@ -18,12 +18,15 @@ nodes.append(gelu1)
gathernd1 = helper.make_node(
"GatherND",
["gelu_1", "unsqueezed_masked_lm_positions"],
["output"],
["gathernd_out"],
name="gathernd_1",
batch_dims=1,
)
nodes.append(gathernd1)
identity = helper.make_node("Identity", ["gathernd_out"], ["output"], name="identity")
nodes.append(identity)
graph_def = helper.make_graph(nodes, "test-model", [X, unsqueezed_masked_lm_positions], [Y])
opsets = []

View file

@ -34,12 +34,15 @@ nodes.append(layer_norm1)
gathernd1 = helper.make_node(
"GatherND",
["layer_norm1", "unsqueezed_masked_lm_positions"],
["output"],
["gathernd_out"],
name="gathernd_1",
batch_dims=1,
)
nodes.append(gathernd1)
identity = helper.make_node("Identity", ["gathernd_out"], ["output"], name="identity")
nodes.append(identity)
graph_def = helper.make_graph(
nodes,
"test-model",

View file

@ -20,12 +20,15 @@ nodes.append(matmul1)
gathernd1 = helper.make_node(
"GatherND",
["matmul1", "unsqueezed_masked_lm_positions"],
["output"],
["gathernd_out"],
name="gathernd_1",
batch_dims=1,
)
nodes.append(gathernd1)
identity = helper.make_node("Identity", ["gathernd_out"], ["output"], name="identity")
nodes.append(identity)
initializers = [matmul1_initializer]
graph_def = helper.make_graph(nodes, "test-model", [X, unsqueezed_masked_lm_positions], [Y], initializers)

View file

@ -21,6 +21,9 @@ struct TrainingGraphTransformerConfiguration : public GraphTransformerConfigurat
bool transformer_layer_recompute{false};
// Number of layers to apply recompute
int number_recompute_layers{0};
// Enable compute optimizer.
bool enable_compute_optimizer{false};
};
} // namespace training

View file

@ -10,7 +10,7 @@
#include "core/optimizer/bias_softmax_fusion.h"
#include "core/optimizer/cast_elimination.h"
#include "core/optimizer/common_subexpression_elimination.h"
#include "core/optimizer/computation_reduction.h"
#include "core/optimizer/compute_optimizer/compute_optimizer.h"
#include "core/optimizer/concat_slice_elimination.h"
#include "core/optimizer/constant_folding.h"
#include "core/optimizer/conv_activation_fusion.h"
@ -94,7 +94,11 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique<InsertSoftmaxCrossEntropyLossOutput>()));
// Remove duplicate nodes. Must be applied before any recompute transformations.
transformers.emplace_back(std::make_unique<CommonSubexpressionEliminationApplyOnce>(compatible_eps));
if (config.gelu_recompute || config.attn_dropout_recompute || config.transformer_layer_recompute) {
transformers.emplace_back(std::make_unique<CommonSubexpressionEliminationApplyOnce>(compatible_eps));
} else {
transformers.emplace_back(std::make_unique<CommonSubexpressionElimination>(compatible_eps));
}
transformers.emplace_back(std::make_unique<GeluFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<LayerNormFusion>(compatible_eps));
@ -120,9 +124,10 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
execution_provider, false /*skip_dequantize_linear*/, compatible_eps, excluded_initializers));
transformers.emplace_back(std::make_unique<ReshapeFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<ConcatSliceElimination>(compatible_eps));
#if defined(USE_CUDA) || defined(USE_ROCM)
transformers.emplace_back(std::make_unique<ComputationReductionTransformer>(compatible_eps));
#endif
if (config.enable_compute_optimizer) {
transformers.emplace_back(std::make_unique<ComputeOptimizer>(compatible_eps));
}
if (config.gelu_recompute) {
transformers.emplace_back(std::make_unique<GeluRecompute>());
}
@ -195,11 +200,11 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
case TransformerLevel::Level1: {
InlinedHashSet<std::string_view> l1_execution_providers = {};
InlinedHashSet<std::string_view> cuda_rocm_execution_providers = {onnxruntime::kCudaExecutionProvider,
onnxruntime::kRocmExecutionProvider};
onnxruntime::kRocmExecutionProvider};
// TODO hack - constant folding currently doesn't work after mixed precision transformation so it's disabled for now
// ORT uses CPU kernels to evaluate constant values but some of them don't support fp16
//transformers.emplace_back(std::make_unique<ConstantFolding>(l1_execution_providers));
// transformers.emplace_back(std::make_unique<ConstantFolding>(l1_execution_providers));
transformers.emplace_back(std::make_unique<MatMulAddFusion>(l1_execution_providers));
transformers.emplace_back(std::make_unique<FreeDimensionOverrideTransformer>(free_dimension_overrides));
transformers.emplace_back(std::make_unique<MatmulTransposeFusion>(cuda_rocm_execution_providers));

View file

@ -192,6 +192,7 @@ Status TrainingRunner::Initialize() {
gt_config.gelu_recompute = params_.gelu_recompute;
gt_config.transformer_layer_recompute = params_.transformer_layer_recompute;
gt_config.number_recompute_layers = params_.number_recompute_layers;
gt_config.enable_compute_optimizer = true;
config.graph_transformer_config = gt_config;
}
@ -580,7 +581,7 @@ void TrainingRunner::RunWithUpdate(VectorString& feed_names,
ORT_THROW_IF_ERROR(status);
} catch (std::exception&) {
// If exception happens during worker execution, propogate the exception to main thread.
// If exception happens during worker execution, propagate the exception to main thread.
pipeline_worker_pool_.worker_states[worker_id].execution_exception = std::current_exception();
}
},

View file

@ -711,6 +711,7 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
.def_readwrite("gelu_recompute", &TrainingGraphTransformerConfiguration::gelu_recompute)
.def_readwrite("transformer_layer_recompute", &TrainingGraphTransformerConfiguration::transformer_layer_recompute)
.def_readwrite("number_recompute_layers", &TrainingGraphTransformerConfiguration::number_recompute_layers)
.def_readwrite("enable_compute_optimizer", &TrainingGraphTransformerConfiguration::enable_compute_optimizer)
.def_readwrite("propagate_cast_ops_config", &TrainingGraphTransformerConfiguration::GraphTransformerConfiguration::propagate_cast_ops_config);
py::class_<OrtModuleGraphBuilderConfiguration> module_graph_builder_config(

View file

@ -177,6 +177,12 @@ class GraphExecutionManager(GraphExecutionInterface):
# Memory aware gradient builder.
self._use_memory_efficient_gradient = False
# Enable compute optimizer by default. Allowed to be disabled via environment variable for
# convergence parity investigation.
self._enable_compute_optimizer = (
ortmodule._defined_from_envvar("ORTMODULE_ENABLE_COMPUTE_OPTIMIZER", 1, warn=True) == 1
)
# Flag to re-export the model due to attribute change on original module.
# Re-export will be avoided if _skip_check is enabled.
self._original_model_has_changed = False
@ -445,6 +451,7 @@ class GraphExecutionManager(GraphExecutionInterface):
graph_transformer_config.propagate_cast_ops_config.level = self._propagate_cast_ops_level
graph_transformer_config.propagate_cast_ops_config.allow = self._propagate_cast_ops_allow
graph_transformer_config.propagate_cast_ops_config.strategy = self._propagate_cast_ops_strategy
graph_transformer_config.enable_compute_optimizer = self._enable_compute_optimizer
return graph_transformer_config
def _initialize_graph_builder(self):

View file

@ -48,12 +48,12 @@ namespace {
* Then load it into ORT, compare with the initial parameter values.
*/
TEST(CheckpointApiTest, SaveOnnxModelAsCheckpoint_ThenLoad_CPU) {
/// Phase 1 - Test Preparison
/// Phase 1 - Test Preparation
/// Prepare the data and dest folder for saving checkpoint.
/// Also cooked the data for test result comparison.
// Model path and trainable parameter name definitions.
auto model_uri = MODEL_FOLDER "transform/computation_reduction/e2e.onnx";
auto model_uri = MODEL_FOLDER "transform/computation_reduction/gathernd/e2e.onnx";
std::vector<std::string> expected_trainable_param_names{
"bert.encoder.layer.2.output.LayerNorm.weight",
"bert.encoder.layer.2.output.LayerNorm.bias",
@ -88,7 +88,7 @@ TEST(CheckpointApiTest, SaveOnnxModelAsCheckpoint_ThenLoad_CPU) {
ORT_ENFORCE(CreateOrtValuesFromTensorProtos(trainable_param_values, expected_trainable_param_name_to_ort_value)
.IsOK());
// Remove the tempoprary directory if it already exists.
// Remove the temporary directory if it already exists.
auto ckpt_test_root_dir = ORT_TSTR("checkpointing_api_test_dir");
if (Env::Default().FolderExists(ckpt_test_root_dir)) {
ORT_ENFORCE(Env::Default().DeleteFolder(ckpt_test_root_dir).IsOK());
@ -120,7 +120,7 @@ TEST(CheckpointApiTest, SaveOnnxModelAsCheckpoint_ThenLoad_CPU) {
ASSERT_EQ(expected_file_names, valid_file_names);
/// Phase 3 - Run load checkpoint APIs.
/// And check the result comparible with initial parameter values.
/// And check the result comparable with initial parameter values.
// Call Load APIs
CheckpointState checkpoint_state_to_load;
@ -199,7 +199,7 @@ TEST(CheckpointApiTest, LoadCheckpointToModel) {
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CUDA) {
/// Phase 1 - Test Preparison
/// Phase 1 - Test Preparation
/// Prepare the data and dest folder for saving checkpoint.
/// Also cooked the data for test result comparison.
auto model_uri = MODEL_FOLDER "training_api/training_model.onnx";
@ -252,7 +252,7 @@ TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CUDA) {
CheckpointState checkpoint_state;
ORT_ENFORCE(optimizer->GetStateDict(checkpoint_state.optimizer_checkpoint_state).IsOK());
// Remove the tempoprary directory if it already exists.
// Remove the temporary directory if it already exists.
auto ckpt_test_root_dir = ORT_TSTR("checkpointing_api_test_dir");
if (Env::Default().FolderExists(ckpt_test_root_dir)) {
ORT_ENFORCE(Env::Default().DeleteFolder(ckpt_test_root_dir).IsOK());
@ -287,7 +287,7 @@ TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CUDA) {
ASSERT_EQ(expected_file_names, valid_file_names);
/// Phase 3 - Run load checkpoint APIs.
/// And check the result comparible with initial optimizer state values.
/// And check the result comparable with initial optimizer state values.
// Call Load APIs
CheckpointState checkpoint_state_to_load;
@ -334,7 +334,7 @@ TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CUDA) {
* Then load it into ORT, compare with the initial properties' values.
*/
TEST(CheckpointApiTest, SaveCustomPropertyAsCheckpoint_ThenLoad_CPU) {
/// Phase 1 - Test Preparison
/// Phase 1 - Test Preparation
/// Prepare the data and dest folder for saving checkpoint.
CheckpointState checkpoint_state;
@ -352,7 +352,7 @@ TEST(CheckpointApiTest, SaveCustomPropertyAsCheckpoint_ThenLoad_CPU) {
std::string s_property_name("train_data_path");
property_bag.AddProperty(s_property_name, s_data);
// Remove the tempoprary directory if it already exists.
// Remove the temporary directory if it already exists.
auto ckpt_test_root_dir = ORT_TSTR("checkpointing_api_test_dir");
if (Env::Default().FolderExists(ckpt_test_root_dir)) {
ORT_ENFORCE(Env::Default().DeleteFolder(ckpt_test_root_dir).IsOK());

View file

@ -187,7 +187,7 @@ class NcclService final : public INcclService {
// Search the next unfinished communication group to work on.
int FindNextCommunicationTime() const;
// Mutex to gurantee thread-safe access to this class.
// Mutex to guarantee thread-safe access to this class.
std::mutex mutex_;
// Conditional variable used to wait for the mutex.
std::condition_variable cv_;