diff --git a/cmake/onnxruntime_optimizer.cmake b/cmake/onnxruntime_optimizer.cmake index 306c6a9b55..0cb8103630 100644 --- a/cmake/onnxruntime_optimizer.cmake +++ b/cmake/onnxruntime_optimizer.cmake @@ -57,6 +57,8 @@ else() "${ONNXRUNTIME_INCLUDE_DIR}/core/optimizer/*.h" "${ONNXRUNTIME_ROOT}/core/optimizer/*.h" "${ONNXRUNTIME_ROOT}/core/optimizer/*.cc" + "${ONNXRUNTIME_ROOT}/core/optimizer/compute_optimizer/*.h" + "${ONNXRUNTIME_ROOT}/core/optimizer/compute_optimizer/*.cc" "${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/*.h" "${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/*.cc" "${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/*.h" diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index 49375f5891..42dd4bad10 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -124,6 +124,13 @@ Before full qualified name can be got from exporter, this environment variables export ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS="megatron.fp16.fp16.fused_kernels.GELUFunction" ``` +#### ORTMODULE_ENABLE_COMPUTE_OPTIMIZER + +- **Feature Area**: *ORTMODULE/Optimizations* +- **Description**: By default, this is enabled so that some computation can be saved. This env var can be used to disable +the optimization to guarantee exactly the same compute as the baseline (for example PyTorch, when doing convergence parity +debugging). + ### 2.2 Memory Optimization Q: *Want to run a bigger batch size?* diff --git a/onnxruntime/core/optimizer/computation_reduction.cc b/onnxruntime/core/optimizer/computation_reduction.cc deleted file mode 100644 index f1ccc9f877..0000000000 --- a/onnxruntime/core/optimizer/computation_reduction.cc +++ /dev/null @@ -1,305 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/common/safeint.h" -#include "core/graph/graph_utils.h" -#include "core/optimizer/initializer.h" -#include "core/optimizer/utils.h" -#include "core/optimizer/computation_reduction.h" - -using namespace ONNX_NAMESPACE; -using namespace ::onnxruntime::common; -namespace onnxruntime { -typedef std::function Handler; - -constexpr int GATHERND_BATCH_DIM = 1; - -static bool IsLeadingDimsEqual(const TensorShapeProto* input_shape, const TensorShapeProto* output_shape, - const int num_dim_to_check) { - ORT_ENFORCE(output_shape->dim_size() >= num_dim_to_check && input_shape->dim_size() >= num_dim_to_check); - - for (int i = 0; i < num_dim_to_check; ++i) { - auto& output_dim = output_shape->dim(i); - auto& input_dim = input_shape->dim(i); - if (output_dim.has_dim_value() && input_dim.has_dim_value()) { - if (output_dim.dim_value() != input_dim.dim_value()) { - return false; - } - } else if (output_dim.has_dim_param() && input_dim.has_dim_param()) { - if (output_dim.dim_param() != input_dim.dim_param()) { - return false; - } - } else { - return false; - } - } - - return true; -} - -static int GetValidInputForGatherND(const Node& target_node) { - // target_node is the producer of GatherND's input. - // If target_node's some input tensors have exactly same shape with - // target_node output tensor shape, then it is safe to gather using - // original slice ranges. 
- int candidate_input_index = -1; - auto output_shape = target_node.OutputDefs()[0]->Shape(); - const int output_rank = output_shape->dim_size(); - for (size_t i = 0; i < target_node.InputDefs().size(); ++i) { - auto input_shape = target_node.InputDefs()[i]->Shape(); - const int input_rank = input_shape->dim_size(); - if (input_rank != output_rank) { - continue; - } - - if (IsLeadingDimsEqual(input_shape, output_shape, GATHERND_BATCH_DIM + 1)) { - candidate_input_index = SafeInt(i); - break; - } - } - - return candidate_input_index; -} - -static TensorShapeProto ReplaceSymbolicDimValue(const TensorShapeProto* shape, const int replacement_axis, - const std::string& replacement_dim_value) { - ORT_ENFORCE(replacement_axis >= 0 && replacement_axis < shape->dim_size()); - TensorShapeProto output_shape; - for (int i = 0; i < shape->dim_size(); ++i) { - auto& dim = shape->dim(i); - if (i == replacement_axis) { - output_shape.add_dim()->set_dim_param(replacement_dim_value); - continue; - } - - if (dim.has_dim_value()) { - output_shape.add_dim()->set_dim_value(dim.dim_value()); - } else if (dim.has_dim_param()) { - output_shape.add_dim()->set_dim_param(dim.dim_param()); - } else { - ORT_THROW("Invalid dim found in ReplaceSymbolicDimValue"); - } - } - - return output_shape; -} - -static Status SwapGatherNDWithTargetNode(Graph& graph, Node& gathernd_node, Node& target_node, - const int target_node_input_index = 0) { - auto new_input_arg_for_gathernd = target_node.MutableInputDefs()[target_node_input_index]; - auto target_node_out_arg = target_node.MutableOutputDefs()[0]; - auto gathernd_out_arg = gathernd_node.MutableOutputDefs()[0]; - auto gathernd_old_consumers = graph.GetConsumerNodes(gathernd_out_arg->Name()); - const auto& graph_outputs = graph.GetOutputs(); - bool need_update_graph_output = false; - if (std::find(graph_outputs.begin(), graph_outputs.end(), gathernd_out_arg) != graph_outputs.end()) { - need_update_graph_output = true; - } - - const std::string& 
gathered_dim_param = gathernd_out_arg->Shape()->dim(GATHERND_BATCH_DIM).dim_param(); - TensorShapeProto new_output_shape_for_gathernd = - ReplaceSymbolicDimValue(new_input_arg_for_gathernd->Shape(), GATHERND_BATCH_DIM, gathered_dim_param); - - TensorShapeProto new_output_shape_for_target_node = - ReplaceSymbolicDimValue(target_node_out_arg->Shape(), GATHERND_BATCH_DIM, gathered_dim_param); - - // update input/output definitions. - int output_index = optimizer_utils::IndexOfNodeOutput(target_node, *gathernd_node.MutableInputDefs()[0]); - graph.RemoveEdge(target_node.Index(), gathernd_node.Index(), output_index, 0); - const Node* target_node_input_node = graph.GetProducerNode(new_input_arg_for_gathernd->Name()); - if (target_node_input_node != nullptr) { - output_index = optimizer_utils::IndexOfNodeOutput(*target_node_input_node, *new_input_arg_for_gathernd); - graph.AddEdge(target_node_input_node->Index(), gathernd_node.Index(), output_index, 0); - } else { - // new_input_arg_for_gathernd is graph input - graph_utils::ReplaceNodeInput(gathernd_node, 0, *new_input_arg_for_gathernd); - } - - graph_utils::ReplaceDownstreamNodeInput(graph, gathernd_node, 0 /*output_idx*/, - target_node, 0 /*replacement_output_idx*/); - - if (target_node_input_node != nullptr) { - graph.RemoveEdge(target_node_input_node->Index(), target_node.Index(), output_index, target_node_input_index); - } - graph.AddEdge(gathernd_node.Index(), target_node.Index(), 0, target_node_input_index); - - // update consumer relation ship - if (!gathernd_old_consumers.empty()) { - graph.UpdateConsumerNodes(target_node_out_arg->Name(), {const_cast(gathernd_old_consumers[0])}); - } - graph.UpdateConsumerNodes(gathernd_out_arg->Name(), {&target_node}); - - // update shapes - gathernd_out_arg->SetShape(new_output_shape_for_gathernd); - target_node_out_arg->SetShape(new_output_shape_for_target_node); - - if (need_update_graph_output) { - InlinedVector graph_new_outputs; - 
graph_new_outputs.reserve(graph_outputs.size()); - for (auto out_arg : graph_outputs) { - if (out_arg->Name().compare(gathernd_out_arg->Name()) == 0) { - graph_new_outputs.push_back(target_node_out_arg); - } else { - graph_new_outputs.push_back(out_arg); - } - } - graph.SetOutputs(graph_new_outputs); - graph.SetGraphResolveNeeded(); - graph.SetGraphProtoSyncNeeded(); - } - - return Status::OK(); -} - -static Status SimpleHandler(Graph& graph, Node& gathernd_node, Node& target_node) { - return SwapGatherNDWithTargetNode(graph, gathernd_node, target_node, 0); -} - -/* - This handler change the graphs this way: - Before: - input_1[b,s,h] weight_2[h] - \ / - Add[b,s,h] indices[b,p_s,1] - | / - GatherND[b,p_s,h] - | - - After : - input_1[b,s,h] indices[b,p_s,1] - | / - GatherND[b,p_s,h] weight_2[h] - \ / - Add[b,p_s,h] - | - - Note: b: batch, s: sequence_length, h: hidden_size, p_s: dynamic_prediction_count -*/ -static Status BinaryElementwiseHandler(Graph& graph, Node& gathernd_node, Node& target_node) { - int target_node_input_index = GetValidInputForGatherND(target_node); - ORT_RETURN_IF(target_node_input_index == -1, "Invalid target node index"); - return SwapGatherNDWithTargetNode(graph, gathernd_node, target_node, target_node_input_index); -} - -/* - This handler change the graphs this way: - Before: - input_1[b,s,h] weight_2[h, 2h] - \ / - MatMul[b,s,2h] indices[b,p_s,1] - | / - GatherND[b,p_s,2h] - | - - After : - input_1[b,s,h] indices[b,p_s,1] - | / - GatherND[b,p_s,h] weight_2[h,2h] - \ / - MatMul[b,p_s,2h] - | - - Note: b: batch, s: sequence_length, h: hidden_size, p_s: dynamic_prediction_count -*/ -static Status MatMulHandler(Graph& graph, Node& gathernd_node, Node& target_node) { - int target_node_input_index = GetValidInputForGatherND(target_node); - ORT_RETURN_IF_NOT(target_node_input_index == 0, "target_node_input_index != 0"); - return SwapGatherNDWithTargetNode(graph, gathernd_node, target_node, target_node_input_index); -} - -static 
std::unordered_map handlers = { - {"Add", BinaryElementwiseHandler}, - {"Div", BinaryElementwiseHandler}, - {"Gelu", SimpleHandler}, - {"LayerNormalization", SimpleHandler}, - {"MatMul", MatMulHandler}}; - -static Status Delegate(Graph& graph, Node& gathernd_node, Node& target_node) { - const std::string& op_type = target_node.OpType(); - if (handlers.count(op_type)) { - return handlers[op_type](graph, gathernd_node, target_node); - } else { - return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, op_type + " handler is not implemented"); - } -} - -Status ComputationReductionTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, - const logging::Logger& logger) const { - GraphViewer graph_viewer(graph); - const auto& order = graph_viewer.GetNodesInTopologicalOrder(); - - for (auto index : order) { - auto* node_ptr = graph.GetNode(index); - if (!node_ptr) - // node was removed, this should not happen since we are not removing nodes. - continue; - - auto& node = *node_ptr; - ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); - - // Same ideas might apply for Gather, GatherElements, Slice, Split, etc. - // TODO: let's review the real cases to make the logic more generic. - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "GatherND", {1, 12, 13}, kOnnxDomain) || - !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || - node.GetOutputEdgesCount() > 1) { // allow GatherND have no out edges in case it is graph output. 
- continue; - } - - auto batch_dims = static_cast(node.GetAttributes().at("batch_dims").i()); - if (batch_dims != GATHERND_BATCH_DIM) { - continue; - } - - auto indices_shape = node.MutableInputDefs()[1]->Shape(); - if (indices_shape == nullptr) { - continue; - } - - const auto indices_rank = indices_shape->dim_size(); - auto& indices_last_dim = indices_shape->dim(indices_rank - 1); - // Since GatherND is assumed to have batch_dims=1, if the input data's shape is [batch, sequence, ..., ... ], - // limiting indices_rank=3 will make sure produced output is in shape [batch, sliced_sequence, ..., ...] - // and the rank did not change. - if (!(indices_last_dim.has_dim_value() && indices_last_dim.dim_value() == 1 && indices_rank == 3)) { - continue; - } - - // Todo: check whether we want to move GatherND up, for example, if GatherND's outputs are larger - // than inputs, we should NOT probably bring it ahead. - bool stop = false; - while (!stop) { - const Node* gathernd_data_producer = graph.GetProducerNode(node.MutableInputDefs()[0]->Name()); - if (gathernd_data_producer == nullptr) { - break; - } - Node* input_node = const_cast(gathernd_data_producer); - if (graph.GetConsumerNodes(input_node->MutableOutputDefs()[0]->Name()).size() > 1) { - LOGS_DEFAULT(WARNING) << "node " << node.Name() << " stopped at node " - << input_node->Name(); - break; - } - - auto ret = Delegate(graph, node, *input_node); - if (ret.IsOK()) { - LOGS_DEFAULT(WARNING) << "node " << node.Name() << " up across node " - << input_node->Name() << std::endl; - modified = true; - } else if (ret.Code() == common::NOT_IMPLEMENTED) { - LOGS_DEFAULT(WARNING) << "node " << node.Name() << " stopped at node " - << input_node->Name(); - break; - } else { - LOGS_DEFAULT(WARNING) << " terminate due to unexpected error, node names:" << node.Name() - << ", " << input_node->Name() << ", error " << ret.ErrorMessage() << std::endl; - stop = true; - } - } - } - - if (modified) { - graph.SetGraphResolveNeeded(); - } - 
return Status::OK(); -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/computation_reduction.h b/onnxruntime/core/optimizer/computation_reduction.h deleted file mode 100644 index 1f44b7ee4a..0000000000 --- a/onnxruntime/core/optimizer/computation_reduction.h +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/optimizer/graph_transformer.h" - -namespace onnxruntime { - -class ComputationReductionTransformer : public GraphTransformer { - public: - ComputationReductionTransformer(const InlinedHashSet& compatible_execution_providers = {}) noexcept - : GraphTransformer("ComputationReductionTransformer", compatible_execution_providers) {} - - Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; -}; - -} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/compute_optimizer/compute_optimizer.cc b/onnxruntime/core/optimizer/compute_optimizer/compute_optimizer.cc new file mode 100644 index 0000000000..2bf3aa2cc7 --- /dev/null +++ b/onnxruntime/core/optimizer/compute_optimizer/compute_optimizer.cc @@ -0,0 +1,541 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#ifdef ENABLE_TRAINING +#include + +#include "core/common/safeint.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" +#include "core/optimizer/compute_optimizer/passthrough_actors.h" +#include "core/optimizer/compute_optimizer/compute_optimizer.h" + +using SliceInfo = onnxruntime::optimizer::compute_optimizer::SliceInfo; +using namespace onnxruntime::optimizer::compute_optimizer; +namespace onnxruntime { + +namespace { + +bool EnforceNodeAllInputOutputHaveShapes(const Node& node) { + for (const auto* input_def : node.InputDefs()) { + if (!input_def->Shape()) { + return false; + } + } + + for (const auto* output_def : node.OutputDefs()) { + if (!output_def->Shape()) { + return false; + } + } + return true; +} + +using OPSET_VERSION_LIST = std::initializer_list; +const OPSET_VERSION_LIST opset_1{1}; +const OPSET_VERSION_LIST opset_13_1{13, 1}; +const OPSET_VERSION_LIST opset_13_9_1{13, 9, 1}; +const OPSET_VERSION_LIST opset_13_11_1{13, 11, 1}; +const OPSET_VERSION_LIST opset_13_9_6_1{13, 9, 6, 1}; +const OPSET_VERSION_LIST opset_14_13_5_1{14, 13, 5, 1}; +const OPSET_VERSION_LIST opset_14_13_7_6_1{14, 13, 7, 6, 1}; +const OPSET_VERSION_LIST opset_13_12_10_7_6_1{13, 12, 10, 7, 6, 1}; + +/** + * @brief Functor to trigger the optimization search for a given slicing node + * (for example Gather/GatherND node). + */ +struct SliceOperationReorderHandle { + /** + * @brief Pass through configuration for specific operator. + * + * For each operator: + * > `input_indices` can be used to explicitly specify the input indices that Slicing op can be passed through. + * This could be helpful if some inputs are not applicable for pass through. If not specified, all inputs + * are considered (but there will be checks to ignore those inputs that are not affected by the slicing axis). + * > `actor` will be used to perform the actual pass through, including both pre-check stage and post process + * stage. 
+ */ + struct OpPassThroughConfig { + OpPassThroughConfig(const std::vector& input_indices, + std::shared_ptr actor, + const OPSET_VERSION_LIST& opset_list) + : input_indices(input_indices), actor(actor), opsets(opset_list) { + } + + std::vector input_indices; + std::shared_ptr actor; + const OPSET_VERSION_LIST& opsets; + }; + + static std::string GetFullQualifiedOpName(const std::string& op_type, const std::string& domain) { + return domain + "::" + op_type; + } + + static std::unordered_map& GetOpPassThroughConfigMap() { + static std::unordered_map allowed_passthrough_ops; + static std::once_flag allowed_ops_init; + std::call_once(allowed_ops_init, []() { + allowed_passthrough_ops.insert({ + // Things to consider when more operators are added here: + // 1. Whether the operator is safe to pass through in term of compute equivalence. + // If optype is not enough to guarantee the equivalence, we need to add a customized pre-check function + // (as LayerNormalization did). + // 2. Whether the outputs have the same dim changes if Gather node moves before that operator. + // 3. Should all inputs be allowed when track back further (bottom-up); + // if not, add the input index restriction as MatMul did. 
+ {GetFullQualifiedOpName("Add", kOnnxDomain), + OpPassThroughConfig({}, std::make_shared(), opset_14_13_7_6_1)}, + {GetFullQualifiedOpName("BiasGelu", kMSDomain), + OpPassThroughConfig({}, std::make_shared(), opset_1)}, + {GetFullQualifiedOpName("BitmaskBiasDropout", kMSDomain), + OpPassThroughConfig({}, std::make_shared(), opset_1)}, + {GetFullQualifiedOpName("Cast", kOnnxDomain), + OpPassThroughConfig({}, std::make_shared(), opset_13_9_6_1)}, + {GetFullQualifiedOpName("Div", kOnnxDomain), + OpPassThroughConfig({}, std::make_shared(), opset_14_13_7_6_1)}, + {GetFullQualifiedOpName("Dropout", kOnnxDomain), + OpPassThroughConfig({}, std::make_shared(), opset_13_12_10_7_6_1)}, + {GetFullQualifiedOpName("Gelu", kMSDomain), + OpPassThroughConfig({}, std::make_shared(), opset_1)}, + {// Be noted, this is our own implementation of ONNX domain op. + GetFullQualifiedOpName("LayerNormalization", kOnnxDomain), + OpPassThroughConfig({0}, std::make_shared(), opset_1)}, + {GetFullQualifiedOpName("MatMul", kOnnxDomain), + OpPassThroughConfig({}, std::make_shared(), opset_13_9_1)}, + {GetFullQualifiedOpName("Reshape", kOnnxDomain), + OpPassThroughConfig({0}, std::make_shared(), opset_14_13_5_1)}, + {GetFullQualifiedOpName("Softmax", kOnnxDomain), + OpPassThroughConfig({0}, std::make_shared(), opset_13_11_1)}, + {GetFullQualifiedOpName("Transpose", kOnnxDomain), + OpPassThroughConfig({}, std::make_shared(), opset_13_1)}, + }); + }); + + return allowed_passthrough_ops; + } + + SliceOperationReorderHandle(const std::string& node_name) : entry_node_name_(node_name) { + } + + bool operator()(Graph& graph, Node& current_node, SliceInfo& info, const logging::Logger& logger, + std::deque& queue); + + private: + /** + * @brief Pass through Slicing op from current_node's output to its specific input. + * + * Propagate the slicing operation into current_node's current_input_index-th input, e.g. 
a slicing op is inserted + * between current_node's current_input_index-th input and current_node. For example, if current_node is Add, + * and slice_node is a Gather(axis=1, indices=[1]): + * + * input_0 [M, N, K] input_1 [M, N, K] + * \ / + * Add [M, N, K] + * | + * Gather0(axis=1, indices=[1]) + * | + * output [M, 1, K] + * + * After the pass through, the graph will be: + * + * input_0 [M, N, K] input_1 [M, N, K] + * \ / + * Gather1(axis=1, indices=[1]) Gather2(axis=1, indices=[1]) + * \ / + * \ / + * \ / + * Add [M, N, K] + * | + * Gather0(axis=1, indices=[1]) + * | + * output [M, 1, K] + * + * Be noted: Gather1 and Gather2 are inserted on Add's two inputs. + * Gather0's removal and Add's output shape update is done in RemoveOriginSlicingOp. + * + * + * @param graph Graph to iterate. + * @param slice_node Slicing op node the takes current_node's output as input. + * @param current_node Current node. + * @param current_node_input_index The current_node_input_index-th input to propagate the Slice op pass through. + * @param info slice_node's SliceInfo. + * @param logger Logger. + * @param new_axis The new axis (for the new Slice op) upon current_node's original current_node_input_index-th input. + * @return SliceInfo for new created slicing op. + */ + SliceInfo PropagateSlicingForInput(Graph& graph, Node& slice_node, Node& current_node, int current_node_input_index, + SliceInfo& info, int new_axis, const logging::Logger& logger); + + /** + * @brief Remove the origin slicing op (for example Gather/GatherND) and update shapes. + * + * In the above example, the graph will be cleaned up to: + * input_0 [M, N, K] input_1 [M, N, K] + * \ / + * Gather1(axis=1, indices=[1]) Gather2(axis=1, indices=[1]) + * \ / + * \ / + * \ / + * Add [M, 1, K] + * | + * | + * output [M, 1, K] + * + * Be noted: Gather0 is removed, Add's output shape is updated. + * + * @param graph Graph to iterate. + * @param slice_node Slicing op node the takes current_node's output as input. 
+ * @param current_node Current node. + * @param logger Logger. + * @param info slice_node's SliceInfo. + * @return + */ + Status RemoveOriginSlicingOp(Graph& graph, Node& slice_node, Node& current_node, + const logging::Logger& logger, SliceInfo& info); + + std::string entry_node_name_; +}; + +bool SliceOperationReorderHandle::operator()(Graph& graph, Node& current_node, + SliceInfo& info, + const logging::Logger& logger, + std::deque& queue) { + Node& slice_node = *info.node_ptr; + const std::string& op_type = GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); + if (GetOpPassThroughConfigMap().count(op_type)) { + auto& pass_through_config = GetOpPassThroughConfigMap().at(op_type); + LOG_DEBUG_INFO(logger, "Enter reorder handle for node " + current_node.Name() + "(" + op_type + ")"); + + if (!graph_utils::IsSupportedOptypeVersionAndDomain(current_node, current_node.OpType(), + pass_through_config.opsets, current_node.Domain())) { + LOG_DEBUG_INFO(logger, "Unsupported opset for " + current_node.Name() + "(" + op_type + ") since version: " + + std::to_string(current_node.SinceVersion())); + return false; + } + + if (!EnforceNodeAllInputOutputHaveShapes(current_node)) { + LOG_DEBUG_INFO(logger, "Some inputs/outputs' shape not found for node " + current_node.Name() + "(" + + op_type + ")"); + return false; + } + + std::unordered_map candidate_input_indices; + bool input_has_dim_1_for_axis = false; + if (!pass_through_config.actor->PreCheck(graph, current_node, info, pass_through_config.input_indices, logger, + candidate_input_indices, input_has_dim_1_for_axis)) { + LOG_DEBUG_INFO(logger, "Pre-check failed for " + current_node.Name() + "(" + op_type + ")"); + return false; + } + + if (candidate_input_indices.empty()) { + LOG_DEBUG_INFO(logger, "Skip handling current node " + current_node.Name() + "(" + op_type + + ") because the requirement is not met."); + return false; + } + + // Be noted, once we reach this point after PreCheck, graph modification 
started, any failure after this should + // be reported as ERROR. + std::vector populated_slicing_infos; // Slicing infos that are populated into current_node's inputs. + populated_slicing_infos.reserve(candidate_input_indices.size()); + std::unordered_map new_gather_infos; + for (auto pair : candidate_input_indices) { + auto input_index = pair.first; // input index of current_node + int new_axis = pair.second; // new axis of current_node's input to be sliced + SliceInfo gather_info = PropagateSlicingForInput(graph, slice_node, current_node, input_index, info, new_axis, + logger); + + ORT_ENFORCE(gather_info.node_ptr, "New added gather node should not be null."); + populated_slicing_infos.push_back(gather_info); + new_gather_infos.insert({{input_index, gather_info}}); + } + + int index_of_output = + optimizer_utils::IndexOfNodeOutput(current_node, *slice_node.InputDefs()[info.GetDataInputIndex()]); + ORT_ENFORCE(RemoveOriginSlicingOp(graph, slice_node, current_node, logger, info).IsOK()); + if (!pass_through_config.actor->PostProcess(graph, current_node, index_of_output, info.non_negative_axis, + info.is_scalar_slice, input_has_dim_1_for_axis, + info.output_dim_on_axis, + entry_node_name_, new_gather_infos, + logger)) { + ORT_THROW("Post-process failed for " + current_node.Name() + "(" + op_type + ")"); + } + + queue.insert(queue.end(), populated_slicing_infos.begin(), populated_slicing_infos.end()); + return true; + } else { + LOG_DEBUG_INFO(logger, "op_type not supported for " + current_node.Name() + "(" + op_type + ")"); + return false; + } +} + +SliceInfo SliceOperationReorderHandle::PropagateSlicingForInput(Graph& graph, + Node& slice_node, + Node& current_node, + int current_node_input_index, + SliceInfo& info, + int new_axis, + const logging::Logger& logger) { + LOG_DEBUG_INFO(logger, "PropagateSlicingForInput for Node " + slice_node.Name() + "(" + slice_node.OpType() + + ") with input index " + std::to_string(current_node_input_index) + ", keep_dim = " + + 
std::to_string(!info.is_scalar_slice)); + + InlinedVector input_args; + input_args.reserve(slice_node.InputDefs().size()); + // The first slice op's data input should be current_node's current_node_input_index-th input. + // For some cases when rank changes, slice op's slice input should also be adapted. + input_args.push_back(current_node.MutableInputDefs()[current_node_input_index]); + for (size_t i = 1; i < slice_node.InputDefs().size(); ++i) { + input_args.push_back(slice_node.MutableInputDefs()[i]); + } + + // Update the axis attribute if new_axis is not same with the original slicing axis (which happens when data + // layout got changed by Transpose or Reshape ops) + onnxruntime::NodeAttributes attributes = slice_node.GetAttributes(); + if (info.non_negative_axis != new_axis) { + attributes[info.axis_attr_name] = + ONNX_NAMESPACE::MakeAttribute(info.axis_attr_name, static_cast(new_axis)); + } + + InlinedVector output_args; + output_args.push_back( + &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(info.entry_slice_arg_name), + current_node.MutableInputDefs()[current_node_input_index]->TypeAsProto())); + + /* new node input index to connect to current_node's input node*/ + int new_slice_input_index_to_connect = info.GetDataInputIndex(); + /* new node output index to connect to current_node*/ + int new_slice_output_index_to_connect = info.GetOutputIndex(); + Node* new_slice_node = InsertIntermediateNodeOnDestInput(graph, current_node, + current_node_input_index, + new_slice_input_index_to_connect, + new_slice_output_index_to_connect, + graph.GenerateNodeName(info.entry_slice_arg_name), + slice_node.OpType(), + "Duplicated Gather node", + input_args, + output_args, + attributes, + slice_node.Domain(), + logger); + + new_slice_node->SetExecutionProviderType(slice_node.GetExecutionProviderType()); + + // Set correct shape for new created node. 
+ auto new_slice_out_arg = new_slice_node->MutableOutputDefs()[new_slice_output_index_to_connect]; + int reversed_axis = new_axis - new_slice_out_arg->Shape()->dim_size(); + UpdateSliceOutputShape(*new_slice_out_arg, reversed_axis, info.output_dim_on_axis); + auto new_slice_info = SliceInfo(new_slice_node, info.is_scalar_slice, info.axis_attr_name, new_axis); + new_slice_info.entry_slice_arg_name = info.entry_slice_arg_name; + return new_slice_info; +} + +Status SliceOperationReorderHandle::RemoveOriginSlicingOp(Graph& graph, Node& slice_node, Node& current_node, + const logging::Logger& logger, SliceInfo& info) { + LOG_DEBUG_INFO(logger, "RemoveOriginSlicingOp target_node " + current_node.Name() + "(" + current_node.OpType() + + ") slice_node " + slice_node.Name() + "(" + slice_node.OpType() + "), keep_dim = " + + std::to_string(!(info.is_scalar_slice))); + + auto slice_input_arg = slice_node.MutableInputDefs()[info.GetDataInputIndex()]; + int slice_input_rank = slice_input_arg->Shape()->dim_size(); + int output_index = optimizer_utils::IndexOfNodeOutput(current_node, *slice_input_arg); + auto slice_op_output_arg = slice_node.MutableOutputDefs()[info.GetOutputIndex()]; + + // Loop all outputs of target node, update the shape accordingly. + // For elementwise ops like (LayerNorm/Dropout/Add), we should handle all outputs. + // If some output rank is lower than sliced axis, we should just ignore it (the correctness is guaranteed by devs + // who adds more operator coverage in the pass through). 
+ for (size_t i = 0; i < current_node.MutableOutputDefs().size(); ++i) { + UpdateSliceOutputShape(*current_node.MutableOutputDefs()[i], info.non_negative_axis - slice_input_rank, + info.output_dim_on_axis); + } + LOG_DEBUG_INFO(logger, "RemoveOriginSlicingOp Replace all usage of output " + slice_op_output_arg->Name() + ":0" + + " with " + current_node.MutableOutputDefs()[output_index]->Name() + ":" + + std::to_string(output_index)); + + graph_utils::ReplaceDownstreamNodeInput(graph, slice_node, info.GetOutputIndex() /*output_idx*/, current_node, + output_index /*replacement_output_idx*/); + auto gather_origin_consumer_nodes = graph.GetConsumerNodes(slice_op_output_arg->Name()); + std::vector slice_op_consumers; + slice_op_consumers.reserve(gather_origin_consumer_nodes.size()); + for (auto& consumer_node : gather_origin_consumer_nodes) { + slice_op_consumers.push_back(graph.GetNode(consumer_node->Index())); + LOG_DEBUG_INFO(logger, "RemoveOriginSlicingOp Gather's consumer node " + consumer_node->Name() + "(" + + consumer_node->OpType() + ")"); + } + graph.UpdateConsumerNodes(current_node.OutputDefs()[output_index]->Name(), slice_op_consumers); + + graph.UpdateConsumerNodes(slice_op_output_arg->Name(), {}); + graph.RemoveNode(slice_node.Index()); + + return Status::OK(); +} + +} // namespace + +std::optional ComputeOptimizer::IsSupportedGatherND(Graph& /*graph*/, Node& node, + const logging::Logger& logger) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "GatherND", {1, 12, 13}, kOnnxDomain) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + return std::nullopt; + } + + auto data_shape = node.MutableInputDefs()[0]->Shape(); + auto indices_shape = node.MutableInputDefs()[1]->Shape(); + auto gather_out_shape = node.MutableOutputDefs()[0]->Shape(); + if (data_shape == nullptr || indices_shape == nullptr || gather_out_shape == nullptr) { + LOG_DEBUG_INFO(logger, "Skip GatherND node " + node.Name() + " due to undefined 
shape."); + return std::nullopt; + } + + const auto data_rank = data_shape->dim_size(); + const auto indices_rank = indices_shape->dim_size(); + + // batch_dims is an integer indicating the number of batch dimensions, + // i.e the leading b number of dimensions of data tensor and indices are representing the batches, + // and the gather starts from the b+1 dimension. + auto batch_dims = static_cast(node.GetAttributes().at("batch_dims").i()); + ORT_ENFORCE(batch_dims >= 0 && batch_dims < indices_rank && batch_dims < data_rank, + "batch_dims must be in the range [0, min(indices_rank, data_rank)):" + std::to_string(batch_dims) + + " indices_rank:" + std::to_string(indices_rank) + " data_rank:" + std::to_string(data_rank)); + + // Since GatherND is assumed to have batch_dims=1, if the input data's shape is [batch, sequence, ..., ... ], + // limiting indices_rank=3 will make sure produced output is in shape [batch, sliced_sequence, ..., ...] + // and the rank did not change. + // TODO: release the constraint here. 
+ if (data_rank != 3 || indices_rank != 3 || batch_dims != 1) { + return std::nullopt; + } + + auto& indices_last_dim = indices_shape->dim(indices_rank - 1); + if (!(indices_last_dim.has_dim_value() && indices_last_dim.dim_value() == 1)) { + return std::nullopt; + } + + return SliceInfo(&node, false, "batch_dims", static_cast(batch_dims), true); +} + +std::optional ComputeOptimizer::IsSupportedGather(Graph& /*graph*/, Node& node, + const logging::Logger& logger) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}, kOnnxDomain) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + return std::nullopt; + } + + auto data_shape = node.MutableInputDefs()[0]->Shape(); + auto indices_shape = node.MutableInputDefs()[1]->Shape(); + auto gather_out_shape = node.MutableOutputDefs()[0]->Shape(); + if (data_shape == nullptr || indices_shape == nullptr || gather_out_shape == nullptr) { + LOG_DEBUG_INFO(logger, "Skip Gather node " + node.Name() + " due to undefined shape."); + return std::nullopt; + } + + const auto data_rank = data_shape->dim_size(); + if (data_rank <= 1) { + LOG_DEBUG_INFO(logger, "Skip Gather node " + node.Name() + " due to rank <= 1."); + return std::nullopt; + } + + auto axis = static_cast(node.GetAttributes().at("axis").i()); + axis = axis < 0 ? axis + data_rank : axis; + size_t dim_size = static_cast(indices_shape->dim_size()); + bool is_single_value_1d_tensor = dim_size != 0 && (dim_size == 1 && utils::HasDimValue(indices_shape->dim(0)) && + indices_shape->dim(0).dim_value() == 1); + if (dim_size != 0 && !is_single_value_1d_tensor) { + if (dim_size == 1 && utils::HasDimValue(data_shape->dim(axis)) && + data_shape->dim(axis).dim_value() > indices_shape->dim(0).dim_value()) { + // Can support. 
+ } else { + LOG_DEBUG_INFO(logger, "Skip Gather node " + node.Name() + " due to unsupported dim size: " + + std::to_string(dim_size)); + return std::nullopt; + } + } + + return SliceInfo(&node, dim_size == 0, "axis", axis, true); +} + +Status ComputeOptimizer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) + const { + LOG_DEBUG_INFO(logger, "Enter ComputeOptimizer"); + bool reordered = false; + GraphViewer graph_viewer(graph); + const auto& order = graph_viewer.GetNodesInTopologicalOrder(); + const auto& graph_outputs = graph.GetOutputs(); + size_t reordered_node_count = 0; // For summary + for (auto index : order) { + auto* node_ptr = graph.GetNode(index); + if (!node_ptr) + // node was removed. + continue; + + auto& node = *node_ptr; + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + std::optional gather_info; + // Same ideas might apply for GatherElements, Slice, Split, etc.. + gather_info = IsSupportedGatherND(graph, node, logger); + if (!gather_info.has_value()) { + gather_info = IsSupportedGather(graph, node, logger); + } + + if (!gather_info.has_value()) { + continue; + } + + auto& output_arg = node.MutableOutputDefs()[0]; + if (std::find(graph_outputs.begin(), graph_outputs.end(), output_arg) != graph_outputs.end()) { + continue; + } + + std::string node_name = node.Name(); + std::string node_type = node.OpType(); + std::deque gather_queue; + gather_queue.push_back(gather_info.value()); + + std::string log_prefix = "Entry node " + node_name + " (" + node_type + ") with axis " + + std::to_string(gather_info.value().non_negative_axis); + LOG_DEBUG_INFO(logger, log_prefix + " starts re-ordering check"); + + SliceOperationReorderHandle handle(node_name); + + // DON'T operate on `node` once this loop starts, as it may be removed from the graph. 
+ while (!gather_queue.empty()) { + SliceInfo info = gather_queue.front(); + Node* gather_node = info.node_ptr; + gather_queue.pop_front(); + Node* slice_input_data_producer = + graph.GetMutableProducerNode(gather_node->MutableInputDefs()[0]->Name()); + if (slice_input_data_producer == nullptr) { + break; + } + Node* input_node = slice_input_data_producer; + if (graph.GetConsumerNodes(input_node->MutableOutputDefs()[0]->Name()).size() > 1) { + LOG_DEBUG_INFO(logger, log_prefix + " stops at node " + input_node->Name() + " since multiple consumer found"); + continue; + } + + auto ret = handle(graph, *input_node, info, logger, gather_queue); + if (ret) { + LOG_DEBUG_INFO(logger, log_prefix + " moves up across node " + input_node->Name()); + modified = true; + reordered = true; + } else { + LOG_DEBUG_INFO(logger, log_prefix + " stops when handling " + input_node->Name()); + } + } + + if (reordered) { + ++reordered_node_count; + } + } + + LOGS(logger, INFO) << "Exit ComputeOptimizer with summary - reorderd_node_count:" << reordered_node_count + << " nodes."; + return Status::OK(); +} + +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/core/optimizer/compute_optimizer/compute_optimizer.h b/onnxruntime/core/optimizer/compute_optimizer/compute_optimizer.h new file mode 100644 index 0000000000..934d3ea4bb --- /dev/null +++ b/onnxruntime/core/optimizer/compute_optimizer/compute_optimizer.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef ENABLE_TRAINING +#pragma once + +#include "core/optimizer/compute_optimizer/passthrough_actors.h" +#include "core/optimizer/graph_transformer.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { + +/** + * @brief Graph transformer that helps reduce compute FLOP while maintaining mathematically equivalent result. + * + * This graph transformation tries to identify opportunities to reduce unnecessary computations on the graph level. 
+ * Currently, the major optimization is to bring some slice operators ahead as much as possible, to leave more ops + * operate on sliced input data. Gather and GatherND are the entry operators that trigger the optimization search. + * + * In terms of file dependency, compute_optimizer.h/cc reference structs and utilities defined in + * passthrough_actors.h/cc. + */ +class ComputeOptimizer : public GraphTransformer { + public: + using SliceInfo = onnxruntime::optimizer::compute_optimizer::SliceInfo; + ComputeOptimizer(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("ComputeOptimizer", compatible_execution_providers) {} + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; + + private: + std::optional IsSupportedGatherND(Graph& graph, Node& node, const logging::Logger& logger) const; + std::optional IsSupportedGather(Graph& graph, Node& node, const logging::Logger& logger) const; +}; + +} // namespace onnxruntime +#endif diff --git a/onnxruntime/core/optimizer/compute_optimizer/passthrough_actors.cc b/onnxruntime/core/optimizer/compute_optimizer/passthrough_actors.cc new file mode 100644 index 0000000000..82c5084b03 --- /dev/null +++ b/onnxruntime/core/optimizer/compute_optimizer/passthrough_actors.cc @@ -0,0 +1,844 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#ifdef ENABLE_TRAINING +#include + +#include "core/common/safeint.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" +#include "core/optimizer/compute_optimizer/passthrough_actors.h" +#include "core/optimizer/compute_optimizer/compute_optimizer.h" + +using namespace ONNX_NAMESPACE; +using namespace ::onnxruntime::common; +namespace onnxruntime::optimizer::compute_optimizer { + +enum class DimCompareRet { + ExactEqual = 0, + BroadcastableEqual = 1, + RankTooLow = 2, + NotEqual = 3, + DimCompareRetMax = 4, +}; + +/** + * @brief Check dimensions are equal or broadcastable before axis. + * + * @param full_broadcasted_shape Full broadcasted shape as a baseline to compare. + * @param axis The axis (inclusive, of full_broadcasted_shape) where we end the comparison. + * @param target_shape Shape to compare, can have dim value be 1 for broadcastable dimension. + * @return A pair of (DimCompareRet, bool). The first element is the comparison result for the dims up to and + * including axis; the second is true if target_shape has dim value 1 on axis while full_broadcasted_shape differs. + */ +std::pair AreDimsCompatibleBeforeAxisInternal( + const TensorShapeProto* full_broadcasted_shape, const int axis, + const TensorShapeProto* target_shape) { + int full_rank = full_broadcasted_shape->dim_size(); + int target_rank = target_shape->dim_size(); + + ORT_ENFORCE(full_rank >= axis && target_rank <= full_rank, "full_rank should bigger than axis and target_rank ", + axis, " full_rank: ", full_rank, " target_rank: ", target_rank); + + int minimum_rank_to_handle = full_rank - axis; + if (target_rank < minimum_rank_to_handle) { + // Skip if target node's input rank is less than minimum rank to handle. + // Essentially this means the input did not affect the Gather axis. 
+ return std::make_pair(DimCompareRet::RankTooLow, false); + } + + bool exact_equal = true; + bool broadcastable_equal = true; + bool dim_be_1_on_axis = false; + + int axis_iter = axis; + int negative_axis = axis < 0 ? axis : axis - full_rank; + int target_axis_iter = target_rank + negative_axis; + + for (; axis_iter >= 0 && target_axis_iter >= 0; --axis_iter, --target_axis_iter) { + auto& dim = full_broadcasted_shape->dim(axis_iter); + auto& target_dim = target_shape->dim(target_axis_iter); + if (dim.has_dim_value() && target_dim.has_dim_value()) { + if (dim.dim_value() != target_dim.dim_value()) { + exact_equal = false; + if (target_dim.dim_value() == 1) { + if (axis_iter == axis) dim_be_1_on_axis = true; + } else { + broadcastable_equal = false; + } + } + } else if (dim.has_dim_param() && target_dim.has_dim_param()) { + if (dim.dim_param() != target_dim.dim_param()) { + exact_equal = false; + } + } else { + exact_equal = false; + if (target_dim.has_dim_value() && target_dim.dim_value() == 1) { + if (axis_iter == axis) dim_be_1_on_axis = true; + } else { + broadcastable_equal = false; + } + } + } + + if (exact_equal) { + return std::make_pair(DimCompareRet::ExactEqual, dim_be_1_on_axis); + } else if (broadcastable_equal) { + return std::make_pair(DimCompareRet::BroadcastableEqual, dim_be_1_on_axis); + } else { + return std::make_pair(DimCompareRet::NotEqual, dim_be_1_on_axis); + } +} + +/** + * @brief Check input meet pass through requirement. + * + * @param current_node_output_arg_to_gather The output arg of current node that consumed by slice node. + * @param arg_to_compare The input/output arg to check. + * @param info Slice info. + * @param logger The logger. + * @param fatal_error_found Used as return value. If fatal error found, set to true. Fatal error means, + * we cannot pass through this input arg. + * @param dim_1_for_axis_found Used as return value. If dim value is 1 for axis, set to true. 
+ * @return An int representing the new slice axis for the input arg if pass through needs to be done for + * this input arg; otherwise, std::nullopt. + * + * For each input of current_node, use this function to check if the input can be passed through. + * If the input has dim on negative_axis and + * 1). either the dimension (if exists) including and before negative_axis is same as target node's output shape. + * 2). or the dimension (if exists) including and before negative_axis is 1. + * Otherwise, we will skip the optimization. + * + * Example 1: [Can be passed through] + * input_0 [M, N, K] input_1 [K] + * \ / + * Add [M, N, K] (current_node) + * | + * Gather0(axis=1, indices=[1]) + * | + * output [M, 1, K] + * In this case, we can propagate Gather to input_0 branch, input_1 is skipped because it does not have a dim on + * the slicing axis. + * + * Example 2: [Can be passed through] + * input_0 [M, N, K] input_1 [N, K] + * \ / + * Add [M, N, K] (current_node) + * | + * Gather0(axis=1, indices=[1]) + * | + * output [M, 1, K] + * In this case, we can propagate Gather to input_0 and input_1 branches, because including and before slicing axis 1, + * all dims are equal. + * + * Example 3: [Can be passed through] + * input_0 [M, N, K] input_1 [1, K] + * \ / + * Add [M, N, K] (current_node) + * | + * Gather0(axis=1, indices=[1]) + * | + * output [M, 1, K] + * In this case, we can propagate Gather to input_0 branch, input_1 branch is skipped because it has dim 1 on slicing + * axis. + * + * Example 4: [Can be passed through] + * input_0 [M, N, K] input_1 [1, N, K] + * \ / + * Add [M, N, K] (current_node) + * | + * Gather0(axis=1, indices=[1]) + * | + * output [M, 1, K] + * In this case, we can propagate Gather to input_0 and input_1 branches. 
+ * + * Example 5: [Can be passed through] + * input_0 [M, N, K] input_1 [M, 1, K] + * \ / + * Add [M, N, K] (current_node) + * | + * Gather0(axis=1, indices=[1]) + * | + * output [M, 1, K] + * In this case, we can propagate Gather to input_0 branch, input_1 branch is skipped because it has dim 1 on slicing. + * + * Example 6: [CANNOT be passed through] + * input_0 [M, N, K] input_1 [L, N, K] + * \ / + * Add [M, N, K] (current_node) + * | + * Gather0(axis=1, indices=[1]) + * | + * output [M, 1, K] + * + */ +std::optional CheckInputForPassThrough(const NodeArg* current_node_output_arg_to_gather, + const NodeArg* arg_to_compare, + const SliceInfo& info, + const logging::Logger& logger, + bool& fatal_error_found, + bool& dim_1_for_axis_found) { + fatal_error_found = false; + auto ret_pair = AreDimsCompatibleBeforeAxisInternal(current_node_output_arg_to_gather->Shape(), + info.non_negative_axis, + arg_to_compare->Shape()); + if (ret_pair.first == DimCompareRet::ExactEqual) { + return info.non_negative_axis; + } else if (ret_pair.first == DimCompareRet::RankTooLow) { + LOG_DEBUG_INFO(logger, "Skip " + arg_to_compare->Name() + " because its rank is too low."); + return std::nullopt; + } else if (ret_pair.first == DimCompareRet::NotEqual) { + fatal_error_found = true; + return std::nullopt; + } else if (ret_pair.first == DimCompareRet::BroadcastableEqual) { + if (ret_pair.second) { + LOG_DEBUG_INFO(logger, "Skip " + arg_to_compare->Name() + + ", whose dim on axis is 1, no need to Gather from."); + dim_1_for_axis_found = true; + return std::nullopt; + } + return info.non_negative_axis; + } + + ORT_THROW("Unexpected return value from CheckInputForPassThrough."); +} + +/** + * @brief From given TensorShape, update specified dimension with given value. + * If no new_dim is provided, the dimension will be removed. + * + * @param shape TensorShape used as base shape to modify. + * @param axis The dimension to be replaced/removed. + * @param new_dim The new dimension value. 
If not provided, the dimension will be removed. + * @return TensorShapeProto A copy of "shape" after modification. + */ +TensorShapeProto CreateNewShapeWithUpdatedDim(const TensorShapeProto* shape, const int axis, + const TensorShapeProto_Dimension& new_dim) { + ORT_ENFORCE(axis >= 0 && axis < shape->dim_size()); + TensorShapeProto output_shape; + for (int i = 0; i < shape->dim_size(); ++i) { + auto& dim = shape->dim(i); + if (i == axis) { + if (new_dim.has_dim_value()) { + output_shape.add_dim()->set_dim_value(new_dim.dim_value()); + } else if (new_dim.has_dim_param()) { + output_shape.add_dim()->set_dim_param(new_dim.dim_param()); + } else { + // do nothing, unassigned dim will be removed. + } + + continue; + } + + if (dim.has_dim_value()) { + output_shape.add_dim()->set_dim_value(dim.dim_value()); + } else if (dim.has_dim_param()) { + output_shape.add_dim()->set_dim_param(dim.dim_param()); + } else { + ORT_THROW("Invalid dim found in CreateNewShapeWithUpdatedDim"); + } + } + + return output_shape; +} + +bool UpdateSliceOutputShape(NodeArg& arg_to_update, int reverse_axis, const TensorShapeProto_Dimension& output_dim_on_axis) { + ORT_ENFORCE(reverse_axis < 0, " reverse_axis should be negative, representing the index from right to left."); + const TensorShapeProto* shape = arg_to_update.Shape(); + int rank = shape->dim_size(); + if (rank < -reverse_axis) { + return false; + } + + int axis_to_update = rank + reverse_axis; + TensorShapeProto new_output_shape = CreateNewShapeWithUpdatedDim(shape, axis_to_update, output_dim_on_axis); + arg_to_update.SetShape(new_output_shape); + return true; +} + +Node* InsertIntermediateNodeOnDestInput(Graph& graph, + Node& dest_node, int dest_in_index, + int new_node_input_index, + int new_node_output_index, + const std::string& name, const std::string& op_type, + const std::string& description, + const InlinedVector& input_args, + const InlinedVector& output_args, + const onnxruntime::NodeAttributes& attributes, + const 
std::string& domain, + const logging::Logger& logger) { + LOG_DEBUG_INFO(logger, "Inserting " + op_type + " node on " + dest_node.Name() + " 's " + + std::to_string(dest_in_index) + "th input " + + dest_node.InputDefs()[dest_in_index]->Name() + ", and connect inserted node's " + + std::to_string(new_node_output_index) + "th output to " + dest_node.Name() + " 's " + + std::to_string(dest_in_index) + "th input."); + + ORT_ENFORCE(dest_in_index < static_cast(dest_node.InputDefs().size())); + ORT_ENFORCE(new_node_input_index < static_cast(input_args.size()), "new_node_input_index is out of range."); + ORT_ENFORCE(new_node_output_index < static_cast(output_args.size()), "new_node_output_index is out of range."); + ORT_ENFORCE(dest_node.MutableInputDefs()[dest_in_index] == input_args[new_node_input_index], + "input_args[new_node_input_index] is not the same as dest_node.MutableInputDefs()[dest_in_index].", + dest_node.MutableInputDefs()[dest_in_index]->Name(), " vs ", input_args[new_node_input_index]->Name()); + + // Prepare Input and Outputs for the duplicated Gather/GatherND node. + NodeArg* src_node_arg = dest_node.MutableInputDefs()[dest_in_index]; + + // Create the duplicated Gather/GatherND node. + Node& new_node = graph.AddNode(name, op_type, description, input_args, output_args, &attributes, domain); + ORT_ENFORCE(graph.SetOpSchemaFromRegistryForNode(new_node), "Failed to set op schema for " + new_node.Name()); + + // Connect dest_node's input node to duplicated node. + // Update new node producer and consumer map. 
+ for (size_t j = 0; j < new_node.MutableOutputDefs().size(); ++j) { + graph.UpdateProducerNode(new_node.MutableOutputDefs()[j]->Name(), new_node.Index()); + } + graph.AddConsumerNode(src_node_arg->Name(), &new_node); + const Node* src_node = graph.GetProducerNode(src_node_arg->Name()); + if (src_node) { + int src_out_index = optimizer_utils::IndexOfNodeOutput(*src_node, *src_node_arg); + graph.AddEdge(src_node->Index(), new_node.Index(), src_out_index, new_node_input_index); + } + + // Remove edge between dest_node and src_node. + // Be noted, this will remove dest_node's input edges to src_node + // (and also the src_node's output edges to dest_node). + std::vector input_edge_to_remove; + input_edge_to_remove.reserve(1); + for (auto it = dest_node.InputEdgesBegin(), end = dest_node.InputEdgesEnd(); it != end; ++it) { + LOG_DEBUG_INFO(logger, "dest_node " + dest_node.Name() + " input edge: " + it->GetNode().Name() + + " output index: " + std::to_string(it->GetSrcArgIndex()) + " input index: " + + std::to_string(it->GetDstArgIndex())); + if (it->GetDstArgIndex() == dest_in_index) { + input_edge_to_remove.push_back(graph_utils::GraphEdge::CreateGraphEdge(dest_node, *it, true)); + break; + } + } + + // If the input is graph input or initializer, no edge will be removed. + if (input_edge_to_remove.size() > 0) { + graph_utils::GraphEdge::RemoveGraphEdges(graph, input_edge_to_remove); + + // Remove target node from target input arg's consumer list. + const std::string& src_node_arg_name = src_node_arg->Name(); + int input_use_count_by_dest_node = 0; + for (size_t i = 0; i < dest_node.InputDefs().size(); ++i) { + if (dest_node.InputDefs()[i]->Name().compare(src_node_arg_name) == 0) { + ++input_use_count_by_dest_node; + } + } + + if (input_use_count_by_dest_node == 1) { + graph.RemoveConsumerNode(src_node_arg_name, &dest_node); + } + } + + // Connect duplicated gather node to target node's input. 
+ dest_node.MutableInputDefs()[dest_in_index] = new_node.MutableOutputDefs()[new_node_output_index]; + // Add new edge connecting the duplicated gather with the target node directly. + // This also updates the destination node's input node args + graph.AddEdge(new_node.Index(), dest_node.Index(), new_node_output_index, dest_in_index); + graph.AddConsumerNode(new_node.MutableOutputDefs()[new_node_output_index]->Name(), &dest_node); + LOG_DEBUG_INFO(logger, "Inserted " + op_type + " node on " + dest_node.Name() + " 's " + + std::to_string(dest_in_index) + "th input " + + dest_node.InputDefs()[dest_in_index]->Name()); + return &new_node; +} + +TensorShapeProto CreateTensorShapeInsertDimAtAxis(const TensorShapeProto* src_shape, int axis, int64_t dim_value) { + ORT_ENFORCE(axis <= src_shape->dim_size(), "axis is out of range.", axis, " vs ", src_shape->dim_size()); + TensorShapeProto updated_shape; + int j = 0; + for (j = 0; j < axis; ++j) { + auto dim = src_shape->dim(j); + if (dim.has_dim_value()) { + updated_shape.add_dim()->set_dim_value(dim.dim_value()); + } else if (dim.has_dim_param()) { + updated_shape.add_dim()->set_dim_param(dim.dim_param()); + } else { + ORT_THROW("Invalid dim found in CreateTensorShapeInsertDimAtAxis"); + } + } + updated_shape.add_dim()->set_dim_value(dim_value); + for (; j < src_shape->dim_size(); ++j) { + auto dim = src_shape->dim(j); + if (dim.has_dim_value()) { + updated_shape.add_dim()->set_dim_value(dim.dim_value()); + } else if (dim.has_dim_param()) { + updated_shape.add_dim()->set_dim_param(dim.dim_param()); + } else { + ORT_THROW("Invalid dim found in CreateTensorShapeInsertDimAtAxis"); + } + } + return updated_shape; +} + +NodeArg* CreateUnsqueezeAxesInitializer(Graph& graph, const std::vector& values) { + ONNX_NAMESPACE::TensorProto axes_const_tensor; + axes_const_tensor.set_name(graph.GenerateNodeArgName("axes")); + axes_const_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + 
axes_const_tensor.add_dims(values.size()); + axes_const_tensor.set_raw_data(values.data(), values.size() * sizeof(int64_t)); + return &graph_utils::AddInitializer(graph, axes_const_tensor); +} + +int GetONNXOpSetVersion(const Graph& graph) { + int onnx_opset = -1; + auto onnx_domain_it = graph.DomainToVersionMap().find(kOnnxDomain); + if (onnx_domain_it != graph.DomainToVersionMap().end()) { + onnx_opset = onnx_domain_it->second; + } else { + auto onnx_domain_alias_it = graph.DomainToVersionMap().find(kOnnxDomainAlias); + if (onnx_domain_alias_it != graph.DomainToVersionMap().end()) + onnx_opset = onnx_domain_alias_it->second; + else + ORT_THROW("ONNX domain not found in this model"); + } + return onnx_opset; +} + +void AdaptInputAndOutputForScalarSlice(Graph& graph, Node& current_node, int current_node_output_index, + int slice_axis, const std::string& entry_node_name, + const std::unordered_map& new_gather_infos, + const logging::Logger& logger) { + LOG_DEBUG_INFO(logger, "AdaptInputAndOutputForScalarSlice for Node " + current_node.Name() + "(" + + current_node.OpType() + ")"); + + // For each handled input, insert an Unsqueeze node to get the removed dim back at slice_axis. + for (auto pair : new_gather_infos) { + int input_index = pair.first; + Node* new_node = nullptr; + // Note that the Unsqueeze should happen on the axis of the new slice node. 
+ if (GetONNXOpSetVersion(graph) < 13) { + onnxruntime::NodeAttributes attributes; + attributes["axes"] = ONNX_NAMESPACE::MakeAttribute("axes", std::vector{pair.second.non_negative_axis}); + + new_node = + InsertIntermediateNodeOnDestInput( + graph, + current_node, input_index, + 0 /* new node input index to connect to current_node's input node*/, + 0 /* new node output index to connect to current_node*/, + graph.GenerateNodeName(entry_node_name + "_adapt_input"), + "Unsqueeze", + "Unsqueeze node", + {current_node.MutableInputDefs()[input_index]}, + {&graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("unsqueeze_adaptor"), + current_node.MutableInputDefs()[input_index]->TypeAsProto())}, + attributes, kOnnxDomain, + logger); + } else { + new_node = + InsertIntermediateNodeOnDestInput( + graph, + current_node, input_index, + 0 /* new node input index to connect to current_node's input node*/, + 0 /* new node output index to connect to current_node*/, + graph.GenerateNodeName(entry_node_name + "_adapt_input"), + "Unsqueeze", + "Unsqueeze node", + {current_node.MutableInputDefs()[input_index], + CreateUnsqueezeAxesInitializer(graph, {pair.second.non_negative_axis})}, + {&graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("unsqueeze_adaptor"), + current_node.MutableInputDefs()[input_index]->TypeAsProto())}, + {}, kOnnxDomain, + logger); + } + new_node->SetExecutionProviderType(current_node.GetExecutionProviderType()); + // Set correct shape for Unsqueeze node + const TensorShapeProto* unsqueeze_input_shape = new_node->MutableInputDefs()[0]->Shape(); + new_node->MutableOutputDefs()[0]->SetShape( + CreateTensorShapeInsertDimAtAxis(unsqueeze_input_shape, pair.second.non_negative_axis, 1)); + } + + // Find the consumer node of MatMul, and the input index of that node connect to MatMul. 
+ std::vector consumers = + graph.GetConsumerNodes(current_node.MutableOutputDefs()[current_node_output_index]->Name()); + ORT_ENFORCE(consumers.size() >= 1, "MatMul should have at least one consumer at this point. " + + std::to_string(consumers.size()) + " consumers found."); + Node& consumer = *graph.GetNode(consumers[0]->Index()); + int index = -1; + for (size_t i = 0; i < consumer.InputDefs().size(); ++i) { + auto input_arg = consumer.InputDefs()[i]; + if (input_arg->Name().compare(current_node.MutableOutputDefs()[current_node_output_index]->Name()) == 0) { + index = static_cast(i); + break; + } + } + + // Create Squeeze node connecting MatMul output to consumer node. + Node* matmul_out_adaptor_node = nullptr; + if (GetONNXOpSetVersion(graph) < 13) { + onnxruntime::NodeAttributes attributes; + attributes["axes"] = ONNX_NAMESPACE::MakeAttribute("axes", std::vector{slice_axis}); + matmul_out_adaptor_node = + InsertIntermediateNodeOnDestInput( + graph, consumer, index, + 0, + 0 /* new node output index*/, + graph.GenerateNodeName(current_node.OpType() + "_output"), + "Squeeze", + "Squeeze node", + {consumer.MutableInputDefs()[index]}, + {&graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("squeeze_adaptor"), + consumer.MutableInputDefs()[index]->TypeAsProto())}, + attributes, kOnnxDomain, logger); + } else { + matmul_out_adaptor_node = + InsertIntermediateNodeOnDestInput( + graph, consumer, index, + 0, + 0 /* new node output index*/, + graph.GenerateNodeName(current_node.OpType() + "_output"), + "Squeeze", + "Squeeze node", + {consumer.MutableInputDefs()[index], + CreateUnsqueezeAxesInitializer(graph, {slice_axis})}, + {&graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("squeeze_adaptor"), + consumer.MutableInputDefs()[index]->TypeAsProto())}, + {}, kOnnxDomain, logger); + } + + matmul_out_adaptor_node->SetExecutionProviderType(current_node.GetExecutionProviderType()); + + // Don't need set shape for Squeeze because original MatMul output is used as its 
output type. + // Set correct shape for MatMul node + const TensorShapeProto* matmul_out_shape = matmul_out_adaptor_node->MutableOutputDefs()[0]->Shape(); + current_node.MutableOutputDefs()[0]->SetShape(CreateTensorShapeInsertDimAtAxis(matmul_out_shape, slice_axis, 1)); +} + +bool DefaultOperatorPassThroughActorBase::PostProcess( + Graph& graph, Node& current_node, int current_node_output_index, + int slice_axis, bool is_slice_scalar, bool input_has_dim_1_for_axis, + const ONNX_NAMESPACE::TensorShapeProto_Dimension& /*output_dim_on_axis*/, + const std::string& entry_node_name, + const std::unordered_map& new_gather_infos, + const logging::Logger& logger) { + LOG_DEBUG_INFO(logger, "Enter DefaultOperatorPassThroughActorBase::PostProcess for Node " + current_node.Name() + + "(" + current_node.OpType() + ")"); + if (is_slice_scalar && input_has_dim_1_for_axis) { + AdaptInputAndOutputForScalarSlice(graph, current_node, current_node_output_index, slice_axis, + entry_node_name, new_gather_infos, logger); + } + + return true; +} + +bool SimplePassThroughActor::PreCheck(const Graph& /*graph*/, const Node& current_node, const SliceInfo& info, + const std::vector& allowed_input_indices, + const logging::Logger& logger, + std::unordered_map& propagate_input_config, + bool& input_has_dim_1_for_axis) { + LOG_DEBUG_INFO(logger, "Enter SimplePassThroughActor::PreCheck for node " + current_node.Name()); + + Node* slice_node = info.node_ptr; + int current_node_output_index = optimizer_utils::IndexOfNodeOutput(current_node, *slice_node->InputDefs()[0]); + const NodeArg* gather_data_input_arg = current_node.OutputDefs()[current_node_output_index]; + + propagate_input_config.clear(); + input_has_dim_1_for_axis = false; + for (size_t i = 0; i < current_node.InputDefs().size(); ++i) { + if (allowed_input_indices.size() > 0 && + std::find(allowed_input_indices.begin(), allowed_input_indices.end(), i) == allowed_input_indices.end()) { + continue; + } + bool fatal_error_found = false; + 
auto ret = CheckInputForPassThrough(gather_data_input_arg, current_node.InputDefs()[i], info, logger, + fatal_error_found, input_has_dim_1_for_axis); + if (fatal_error_found) { + LOG_DEBUG_INFO(logger, "Skip for node " + current_node.Name() + " due to input check failure at index " + + std::to_string(i)); + return false; + } else if (ret.has_value()) { + propagate_input_config[static_cast(i)] = ret.value(); + } + } + + // Make sure once Gather is moved before target node, all its outputs can be correctly be sliced. + std::unordered_map output_indices; + for (size_t i = 0; i < current_node.OutputDefs().size(); ++i) { + if (static_cast(i) == current_node_output_index) { + continue; + } + + bool fatal_error_found = false; + bool dim_1_for_axis_found = false; + auto ret = CheckInputForPassThrough(gather_data_input_arg, current_node.OutputDefs()[i], info, logger, + fatal_error_found, dim_1_for_axis_found); + if (fatal_error_found) { + LOG_DEBUG_INFO(logger, "Skip for node " + current_node.Name() + " due to output check failure at index " + + std::to_string(i)); + return false; + } else if (ret.has_value()) { + output_indices[static_cast(i)] = ret.value(); + } + } + bool output_check_success = output_indices.size() == current_node.OutputDefs().size() - 1; + + return output_check_success; +} + +bool ReductionOpPassThroughActor::PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info, + const std::vector& allowed_input_indices, + const logging::Logger& logger, + std::unordered_map& propagate_input_config, + bool& input_has_dim_1_for_axis) { + auto axis = static_cast(current_node.GetAttributes().at("axis").i()); + axis = axis < 0 ? axis + current_node.InputDefs()[0]->Shape()->dim_size() : axis; + + // Make sure layernorm/softmax's reduction happens after the axis we want to slice. 
+ if (axis <= info.non_negative_axis) { + return false; + } + + return SimplePassThroughActor::PreCheck(graph, current_node, info, allowed_input_indices, logger, + propagate_input_config, input_has_dim_1_for_axis); +} +bool ReshapePassThroughActor::PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info, + const std::vector& /*allowed_input_indices*/, + const logging::Logger& logger, + std::unordered_map& propagate_input_config, + bool& /*input_has_dim_1_for_axis*/) { + auto data_input_shape = current_node.InputDefs()[0]->Shape(); + auto shape_input_shape = current_node.InputDefs()[1]->Shape(); + auto output_shape = current_node.OutputDefs()[0]->Shape(); + if (data_input_shape == nullptr || shape_input_shape == nullptr || shape_input_shape->dim_size() != 1 || + output_shape == nullptr) { + LOG_DEBUG_INFO(logger, "Reshape input/output node arg shape is not valid."); + return false; + } + + if (!graph_utils::IsConstantInitializer(graph, current_node.InputDefs()[1]->Name())) { + LOG_DEBUG_INFO(logger, "Skip handle the Reshape, because the new shape is not constant."); + return false; + } + + propagate_input_config.clear(); + + InlinedVector new_shape_const_values; + optimizer_utils::AppendTensorFromInitializer(graph, *current_node.InputDefs()[1], new_shape_const_values, true); + // Only below two cases are supported for easier updating shape data after propagate slice ops. + // 1). If the shape data on slicing axis is zero (e.g. remain the same after slicing), we support it. + // 2). Or if the sliced dim value is a constant, we also support it, and can update the shape data directly. + // For other cases, it is feasible to support but we don't support for now. 
+ if (new_shape_const_values[info.non_negative_axis] == 0 || info.output_dim_on_axis.has_dim_value()) { + auto in_dims = data_input_shape->dim(); + auto out_dims = output_shape->dim(); + int in_rank = in_dims.size(); + int out_rank = out_dims.size(); + + int reshape_input_axis = -1; + // Match from left to right. + for (int i = 0; i < std::min(in_rank, out_rank); ++i) { + bool dim_value_eq = in_dims[i].has_dim_value() && out_dims[i].has_dim_value() && + in_dims[i].dim_value() == out_dims[i].dim_value(); + bool dim_param_eq = in_dims[i].has_dim_param() && out_dims[i].has_dim_param() && + in_dims[i].dim_param() == out_dims[i].dim_param(); + if (dim_value_eq || dim_param_eq) { + if (i == info.non_negative_axis) { + reshape_input_axis = i; + break; + } + continue; + } + } + + if (reshape_input_axis == -1) { + // Match from right to left. + for (int i = 0; i < std::min(in_rank, out_rank); ++i) { + int in_index = in_rank - 1 - i; + int out_index = out_rank - 1 - i; + bool dim_value_eq = in_dims[in_index].has_dim_value() && out_dims[out_index].has_dim_value() && + in_dims[in_index].dim_value() == out_dims[out_index].dim_value(); + bool dim_param_eq = in_dims[in_index].has_dim_param() && out_dims[out_index].has_dim_param() && + in_dims[in_index].dim_param() == out_dims[out_index].dim_param(); + if (dim_value_eq || dim_param_eq) { + if (out_index == info.non_negative_axis) { + reshape_input_axis = in_index; + break; + } + continue; + } + } + } + + if (reshape_input_axis == -1) { + LOG_DEBUG_INFO(logger, "Cannot find Reshape's input axis for Gather."); + return false; + } + + propagate_input_config[0] = reshape_input_axis; + return true; + } + + return false; +} + +bool ReshapePassThroughActor::PostProcess(Graph& graph, Node& current_node, int /*current_node_output_index*/, + int slice_axis, bool is_slice_scalar, bool /*input_has_dim_1_for_axis*/, + const ONNX_NAMESPACE::TensorShapeProto_Dimension& output_dim_on_axis, + const std::string& /*entry_node_name*/, + const 
std::unordered_map& /*new_gather_infos*/, + const logging::Logger& logger) { + LOG_DEBUG_INFO(logger, "ReshapePostProcess for Node " + current_node.Name() + "(" + current_node.OpType() + ")"); + InlinedVector new_shape_const_values; + optimizer_utils::AppendTensorFromInitializer(graph, *current_node.InputDefs()[1], new_shape_const_values, true); + + auto create_new_initializer_from_vector = [&graph](NodeArg* arg_to_be_replaced, + const InlinedVector& new_values) -> NodeArg* { + // Create new TensorProto. + ONNX_NAMESPACE::TensorProto constant_tensor_proto; + constant_tensor_proto.set_name(graph.GenerateNodeArgName(arg_to_be_replaced->Name())); + constant_tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + auto length = new_values.size(); + constant_tensor_proto.add_dims(length); + constant_tensor_proto.set_raw_data(new_values.data(), length * sizeof(int64_t)); + + // Add initializer into Graph. + NodeArg* new_shape_arg = &graph_utils::AddInitializer(graph, constant_tensor_proto); + // Update the output arg shape. + ONNX_NAMESPACE::TensorShapeProto new_shape; + new_shape.add_dim()->set_dim_value(length); + new_shape_arg->SetShape(new_shape); + + return new_shape_arg; + }; + + // If the shape constant on slice_axis is 0, then it keeps the original dim of input. + // If it is scalar slice, then we just remove that dim. Otherwise, we don't need to update the dim value. 
+ if (new_shape_const_values[slice_axis] == 0) { + if (is_slice_scalar) { + LOG_DEBUG_INFO(logger, "Removing axis " + std::to_string(slice_axis) + " from shape tensor."); + NodeArg* arg_to_be_replaced = current_node.MutableInputDefs()[1]; + InlinedVector new_values; + for (int i = 0; i < static_cast(new_shape_const_values.size()); ++i) { + if (i != slice_axis) { + new_values.push_back(new_shape_const_values[i]); + } + } + auto new_shape_arg = create_new_initializer_from_vector(arg_to_be_replaced, new_values); + graph_utils::ReplaceNodeInput(current_node, 1, *new_shape_arg); + } else { + LOG_DEBUG_INFO(logger, "Reshape's shape has 0 specified for aixs: " + std::to_string(slice_axis) + + ", not need update."); + } + return true; + } + + // If it selected shape is dim value, we can update the shape tensor directory. + if (output_dim_on_axis.has_dim_value()) { + new_shape_const_values[slice_axis] = output_dim_on_axis.dim_value(); + auto new_shape_arg = create_new_initializer_from_vector(current_node.MutableInputDefs()[1], new_shape_const_values); + graph_utils::ReplaceNodeInput(current_node, 1, *new_shape_arg); + return true; + } + + ORT_THROW("Fail to update shape data in ReshapePassThroughActor::PostProcess, but this should not be called."); +} + +bool TransposePassThroughActor::PreCheck(const Graph& /*graph*/, const Node& current_node, const SliceInfo& info, + const std::vector& /*allowed_input_indices*/, + const logging::Logger& logger, + std::unordered_map& propagate_input_config, + bool& input_has_dim_1_for_axis) { + InlinedVector perm; + if (!graph_utils::GetRepeatedNodeAttributeValues(current_node, "perm", perm)) { + LOG_DEBUG_INFO(logger, "perm attribute is not set for node " + current_node.Name()); + return false; + } + propagate_input_config.clear(); + propagate_input_config[0] = static_cast(perm[info.non_negative_axis]); + input_has_dim_1_for_axis = false; + return true; +} + +bool TransposePassThroughActor::PostProcess(Graph& graph, Node& current_node, int 
current_node_output_index, + int slice_axis, bool is_slice_scalar, bool /*input_has_dim_1_for_axis*/, + const ONNX_NAMESPACE::TensorShapeProto_Dimension& /*output_dim_on_axis*/, + const std::string& entry_node_name, + const std::unordered_map& new_gather_infos, + const logging::Logger& logger) { + LOG_DEBUG_INFO(logger, "Enter TransposePassThroughActor::PostProcess for Node " + current_node.Name() + "(" + + current_node.OpType() + ")"); + + // We need keep the original dimension to align with original perm. + if (is_slice_scalar) { + AdaptInputAndOutputForScalarSlice(graph, current_node, current_node_output_index, slice_axis, + entry_node_name, new_gather_infos, logger); + } + return true; +} + +bool MatMulPassThroughActor::PreCheck(const Graph& /*graph*/, const Node& current_node, const SliceInfo& info, + const std::vector& allowed_input_indices, + const logging::Logger& logger, + std::unordered_map& propagate_input_config, + bool& input_has_dim_1_for_axis) { + LOG_DEBUG_INFO(logger, "Enter MatMulPassThroughActor::PreCheck for node " + current_node.Name()); + auto lhs_rank = current_node.InputDefs()[0]->Shape()->dim_size(); + auto rhs_rank = current_node.InputDefs()[1]->Shape()->dim_size(); + + if (!(lhs_rank >= 2 && rhs_rank >= 2)) { + LOG_DEBUG_INFO(logger, "MatMul input rank lower than 2, skip."); + return false; + } + + propagate_input_config.clear(); + if (info.non_negative_axis == info.input_rank - 1) { + propagate_input_config[1] = rhs_rank - 1; + return true; + } else if (info.non_negative_axis == info.input_rank - 2) { + propagate_input_config[0] = lhs_rank - 2; + return true; + } + + int target_node_output_index = optimizer_utils::IndexOfNodeOutput(current_node, *info.node_ptr->InputDefs()[0]); + const NodeArg* gather_data_input_arg = current_node.OutputDefs()[target_node_output_index]; + + input_has_dim_1_for_axis = false; + for (size_t i = 0; i < current_node.InputDefs().size(); ++i) { + if (allowed_input_indices.size() > 0 && + 
std::find(allowed_input_indices.begin(), allowed_input_indices.end(), i) == allowed_input_indices.end()) { + continue; + } + bool fatal_error_found = false; + auto ret = CheckInputForPassThrough(gather_data_input_arg, current_node.InputDefs()[i], info, logger, + fatal_error_found, input_has_dim_1_for_axis); + if (fatal_error_found) { + LOG_DEBUG_INFO(logger, "Skip for node " + current_node.Name() + " due to input check failure at index " + + std::to_string(i)); + return false; + } else if (ret.has_value()) { + LOG_DEBUG_INFO(logger, "Add new input candidate for node " + current_node.Name() + " at index " + + std::to_string(i) + " with axis " + std::to_string(ret.value())); + propagate_input_config[static_cast(i)] = ret.value(); + } + } + + return propagate_input_config.size() > 0; +} + +bool MatMulPassThroughActor::PostProcess(Graph& graph, Node& current_node, int current_node_output_index, + int slice_axis, bool is_slice_scalar, bool /*input_has_dim_1_for_axis*/, + const ONNX_NAMESPACE::TensorShapeProto_Dimension& /*output_dim_on_axis*/, + const std::string& entry_node_name, + const std::unordered_map& new_gather_infos, + const logging::Logger& logger) { + LOG_DEBUG_INFO(logger, "Enter MatMulPassThroughActor::PostProcess for Node " + current_node.Name() + "(" + + current_node.OpType() + ")"); + + // We need keep the original dimension to avoid the matmul inputs cannot be compatible to compute. 
+ if (is_slice_scalar) { + AdaptInputAndOutputForScalarSlice(graph, current_node, current_node_output_index, slice_axis, + entry_node_name, new_gather_infos, logger); + } + return true; +} + +} // namespace onnxruntime::optimizer::compute_optimizer + +#endif diff --git a/onnxruntime/core/optimizer/compute_optimizer/passthrough_actors.h b/onnxruntime/core/optimizer/compute_optimizer/passthrough_actors.h new file mode 100644 index 0000000000..a2a5719179 --- /dev/null +++ b/onnxruntime/core/optimizer/compute_optimizer/passthrough_actors.h @@ -0,0 +1,333 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef ENABLE_TRAINING +#pragma once + +// Uncomment for debugging +// #define NEED_LOG_DEBUG_INFO 1 + +#ifdef NEED_LOG_DEBUG_INFO +#define LOG_DEBUG_INFO(logger, message) LOGS(logger, WARNING) << message +#else +#define LOG_DEBUG_INFO(logger, message) \ + ORT_UNUSED_PARAMETER(logger); \ + do { \ + } while (0) +#endif + +namespace onnxruntime::optimizer::compute_optimizer { + +/** + * @brief Struct to hold the information of the slicing operations. + * + * Initially, an instance of this class for entry node is created, as the slice op propagates to entry node's inputs, + * more instances of this class are created. The propagation stops when the all inputs are not supported to be sliced. + */ +struct SliceInfo { + static constexpr int kSliceDataInputIndex = 0; + static constexpr int kSliceOutputIndex = 0; + + SliceInfo(Node* slice_node, + bool is_slice_scalar, + const std::string& slice_axis_attr_name, + int slice_axis, + bool is_entry_node_ptr = false) + : node_ptr(slice_node), is_scalar_slice(is_slice_scalar) { + axis_attr_name = slice_axis_attr_name; + + const NodeArg* input = node_ptr->InputDefs()[kSliceDataInputIndex]; + const NodeArg* output = node_ptr->OutputDefs()[kSliceOutputIndex]; + input_rank = input->Shape()->dim_size(); + non_negative_axis = slice_axis < 0 ? 
input_rank + slice_axis : slice_axis; + + if (!is_scalar_slice) { + output_dim_on_axis = output->Shape()->dim(non_negative_axis); + } + + if (is_entry_node_ptr) { + entry_slice_arg_name = node_ptr->OutputDefs()[kSliceOutputIndex]->Name(); + } + } + + int GetDataInputIndex() const { + return kSliceDataInputIndex; + } + + int GetOutputIndex() const { + return kSliceOutputIndex; + } + + Node* node_ptr; // The Gather/GatherND node that triggers the optimization search. + bool is_scalar_slice; // whether the slice is a scalar, if it is, after Gather, rank will be reduced by 1. + std::string axis_attr_name; + int non_negative_axis; // The axis to slice on + std::string entry_slice_arg_name; + + int input_rank; // rank of the Gather data input tensor + + // The dimension of the output tensor on the slicing axis + // Be noted: if it is a scalar slicing, this dim will not be set, which means, afterward when use it to update + // shapes, that dim at axis will be removed. + ONNX_NAMESPACE::TensorShapeProto_Dimension output_dim_on_axis; +}; + +/** + * @brief Base class for all pass through actors. + * + * Each actors defines rules to determine whether a node can be passed through, and how to do the pass through. + * PreCheck is the interface to check whether a node can be passed through. + * The pass through is done transparently, without any interface required to implemented. + * PostProcess is the interface to do some adaptor work after the pass through. + */ +class OperatorPassThroughActorBase { + public: + OperatorPassThroughActorBase() = default; + virtual ~OperatorPassThroughActorBase() = default; + + /** + * @brief Check whether a node can be passed through. + * At this point, graph modification is not started, once we see any clues that this node cannot be passed through, + * We should return false immediately. + * + * @param graph The graph that the node belongs to. + * @param current_node The node to be checked. 
+ * @param info The slicing info of the Gather/GatherND node. + * @param allowed_input_indices The input indices explicitly specified of the current_node that are allowed to do pass + * through. + * @param propagate_input_config: Used as a return value - a map of input index to new slice axis. + * The key is an integer, which is the index of the input of the current_node. + * The value is an integer, which is the new axis index after the pass through on the corresponding input. + * For example: + * > if the current_node is a Add node, and the slice axe is 1, then the corresponding input should + * also have axis 1 when we move the slice to the input. + * > if the current_node is a Transpose (perm=[1, 0, 2]) node, and the slice + * axis is 1, then the new axis for the input should be 0. + * @param input_has_dim_1_for_axis: Used as a return value - a bool indicates whether any of current_node' input + * has dim 1 on the slice axis. + */ + virtual bool PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info, + const std::vector& allowed_input_indices, + const logging::Logger& logger, + std::unordered_map& propagate_input_config, + bool& input_has_dim_1_for_axis) = 0; + + /** + * @brief After slice op pass through all inputs, do some post process work. + * + * Be noted: at this point, slice op is already removed, so we cannot access SliceInfo any more, instead, + * we pass important infos including slice_axis, input_rank, is_scalar_slice, etc as parameters of this function. + * + * @param graph The graph that the node belongs to. + * @param current_node The node that has been passed through. + * @param current_node_output_index The output index of the current_node connecting to slice op. + * @param slice_axis slice axis of the slice op. + * @param is_slice_scalar whether the slice is a scalar. + * @param input_has_dim_1_for_axis whether any of current_node's inputs has dim 1 on the slice axis. 
+ * @param output_dim_on_axis dimension of the slice op's output tensor on the slice axis. + * @param entry_node_name name of entry node that trigger the pass through search, for naming only. + * @param new_gather_infos new gather infos that are generated during the pass through for current_node's inputs. + * @param logger + * @return + */ + virtual bool PostProcess(Graph& graph, Node& current_node, int current_node_output_index, + int slice_axis, bool is_slice_scalar, bool input_has_dim_1_for_axis, + const ONNX_NAMESPACE::TensorShapeProto_Dimension& output_dim_on_axis, + const std::string& entry_node_name, + const std::unordered_map& new_gather_infos, + const logging::Logger& logger) = 0; +}; + +class DefaultOperatorPassThroughActorBase : public OperatorPassThroughActorBase { + public: + DefaultOperatorPassThroughActorBase() = default; + ~DefaultOperatorPassThroughActorBase() = default; + + bool PreCheck(const Graph&, const Node&, const SliceInfo&, const std::vector&, const logging::Logger&, + std::unordered_map&, bool&) override { + return true; + }; + + bool PostProcess(Graph& graph, Node& current_node, int current_node_output_index, + int slice_axis, bool is_slice_scalar, bool input_has_dim_1_for_axis, + const ONNX_NAMESPACE::TensorShapeProto_Dimension& output_dim_on_axis, + const std::string& entry_node_name, + const std::unordered_map& new_gather_infos, + const logging::Logger& logger) override; +}; + +class SimplePassThroughActor : public DefaultOperatorPassThroughActorBase { + public: + SimplePassThroughActor() = default; + ~SimplePassThroughActor() = default; + + bool PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info, + const std::vector& allowed_input_indices, + const logging::Logger& logger, + std::unordered_map& propagate_input_config, + bool& input_has_dim_1_for_axis) override; +}; + +class ReductionOpPassThroughActor : public SimplePassThroughActor { + public: + ReductionOpPassThroughActor() = default; + 
~ReductionOpPassThroughActor() = default; + + bool PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info, + const std::vector& allowed_input_indices, + const logging::Logger& logger, + std::unordered_map& propagate_input_config, + bool& input_has_dim_1_for_axis) override; +}; + +class ReshapePassThroughActor : public DefaultOperatorPassThroughActorBase { + public: + ReshapePassThroughActor() = default; + ~ReshapePassThroughActor() = default; + + bool PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info, + const std::vector& allowed_input_indices, + const logging::Logger& logger, + std::unordered_map& propagate_input_config, + bool& input_has_dim_1_for_axis) override; + + // Once slice node is passed through, we need to update the shape accordingly. + bool PostProcess(Graph& graph, Node& current_node, int current_node_output_index, + int slice_axis, bool is_slice_scalar, bool input_has_dim_1_for_axis, + const ONNX_NAMESPACE::TensorShapeProto_Dimension& output_dim_on_axis, + const std::string& entry_node_name, + const std::unordered_map& new_gather_infos, + const logging::Logger& logger) override; +}; + +class TransposePassThroughActor : public DefaultOperatorPassThroughActorBase { + public: + TransposePassThroughActor() = default; + ~TransposePassThroughActor() = default; + + bool PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info, + const std::vector& allowed_input_indices, + const logging::Logger& logger, + std::unordered_map& propagate_input_config, + bool& input_has_dim_1_for_axis) override; + + // If scalar slice happens, we need adapt the input, otherwise the perm cannot be matched. 
+ bool PostProcess(Graph& graph, Node& current_node, int current_node_output_index, + int slice_axis, bool is_slice_scalar, bool input_has_dim_1_for_axis, + const ONNX_NAMESPACE::TensorShapeProto_Dimension& output_dim_on_axis, + const std::string& entry_node_name, + const std::unordered_map& new_gather_infos, + const logging::Logger& logger) override; +}; + +class MatMulPassThroughActor : public DefaultOperatorPassThroughActorBase { + public: + MatMulPassThroughActor() = default; + ~MatMulPassThroughActor() = default; + + // Check which inputs can be propagated according to the slice axis. + bool PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info, + const std::vector& allowed_input_indices, + const logging::Logger& logger, + std::unordered_map& propagate_input_config, + bool& input_has_dim_1_for_axis) override; + + // If scalar slice happens in the second last dimension, we need to adapt the input. + bool PostProcess(Graph& graph, Node& current_node, int current_node_output_index, + int slice_axis, bool is_slice_scalar, bool input_has_dim_1_for_axis, + const ONNX_NAMESPACE::TensorShapeProto_Dimension& output_dim_on_axis, + const std::string& entry_node_name, + const std::unordered_map& new_gather_infos, + const logging::Logger& logger) override; +}; + +/** + * @brief Update the dim value using given new dim value at specified axis. + * + * @param arg_to_update The NodeArg to be updated. + * @param reverse_axis A negative axis MUST be given here. This is to make sure if arg_to_update has less rank + * than expected value, the update will be ignored. + * @param output_dim_on_axis New dim value to be updated. + * @return true if the update is done. + */ +bool UpdateSliceOutputShape(NodeArg& arg_to_update, int reverse_axis, + const ONNX_NAMESPACE::TensorShapeProto_Dimension& new_dim_value); + +/** + * @brief Insert a new node to the graph, + * 1. taking dest_node.input[dest_input_index] as the input of the new node. + * 2. 
remove connection of dest_node and it's dest_input_index-th input producer node. + * 3. connect the new node and dest_node. + * + * Original graph: + * Node A + * / \ + * A-output-0 A-output-1 + * \ B-input-1 + * \ / + * Node B + * | + * + * dest_node = Node B + * dest_input_index = 0 + * op_type = C + * + * After inserting the new node: + * Node A + * / \ + * A-output-0 A-output-1 + * \ + * Node C + * / \ + * c-output-0 C-output-1 B-input-1 + * \ / + * Node B + * | + * @param graph Graph to insert the new node. + * @param dest_node The node to insert the new node before. + * @param dest_in_index The input index of the dest_node to insert the new node before. + * @param new_node_output_index The output index of the new node to connect to the dest_node. + * @param name The name of the new node. + * @param op_type The op_type of the new node. + * @param description The description of the new node. + * @param input_args The input args of the new node. At least one of the input args should be the + * dest_node's dest_in_index-th input arg. + * @param attributes The attributes of the new node. + * @param domain The domain of the new node. + * @param logger The logger. + * @return + */ +Node* InsertIntermediateNodeOnDestInput(Graph& graph, + Node& dest_node, int dest_in_index, + int new_node_input_index, + int new_node_output_index, + const std::string& name, const std::string& op_type, + const std::string& description, + const InlinedVector& input_args, + const InlinedVector& output_args, + const onnxruntime::NodeAttributes& attributes, + const std::string& domain, + const logging::Logger& logger); + +/** + * @brief Insert adaptor nodes for the inputs and output, to make sure they remain the same rank, when scalar slicing + * is done. + * + * Be noted: at this point, slice node already been removed. + * + * @param graph Graph to insert the adaptor nodes. + * @param current_node For whom to insert the adaptor nodes. + * @param slice_axis The axis of the slice node. 
+ * @param entry_node_name Then name of the entry slice node, used for naming only. + * @param new_gather_infos Populated slicing infos for current_node's inputs. + * @param target_node_output_index output_index of current_node's output, connecting to the slice node. + * @param logger Logger. + */ +void AdaptInputAndOutputForScalarSlice(Graph& graph, Node& current_node, int current_node_output_index, + int slice_axis, const std::string& entry_node_name, + const std::unordered_map& new_gather_infos, + const logging::Logger& logger); + +} // namespace onnxruntime::optimizer::compute_optimizer + +#endif diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 2231eac0bd..8d76a160b7 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -307,7 +307,7 @@ InlinedVector> GenerateTransformers( #ifdef ENABLE_TRAINING // Put memory optimization transformer at last (which is done after most of fusions are done) by intention. - // Known issue: after mmeory optimization is completed, if some fusion happens, it is possible that the + // Known issue: after memory optimization is completed, if some fusion happens, it is possible that the // node priority got changed. This may disorder the execution order of nodes to recompute. // TODO(pengwa): need to fix this issue. const std::string enable_memory_optimizer = @@ -329,7 +329,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(std::move(cpu_allocator))); // NCHWCtransformer should have a higher priority versus this. Because NCHWCtransformer also do the similar things // of fusion patterns and target on CPU. However, NCHWCtransformer will reorder the layout to nchwc which is only available for - // x86-64 cpu, not edge cpu like arm. But This tranformer could be used by opencl-ep/cpu-ep. So + // x86-64 cpu, not edge cpu like arm. 
But This transformer could be used by opencl-ep/cpu-ep. So // we will prefer NhwcTransformer once ort runs on x86-64 CPU, otherwise ConvAddActivationFusion is enabled. // PR #6351 implemented similar fusion-pattern for CUDA only, and can only fuse conv-add-relu, // while we can fuse more activation. diff --git a/onnxruntime/test/optimizer/compute_optimizer_test.cc b/onnxruntime/test/optimizer/compute_optimizer_test.cc new file mode 100644 index 0000000000..a9f1fd705d --- /dev/null +++ b/onnxruntime/test/optimizer/compute_optimizer_test.cc @@ -0,0 +1,1596 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef ENABLE_TRAINING + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4244) +#endif + +#include +#include "core/graph/onnx_protobuf.h" + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +#include "asserts.h" +#include "core/common/span_utils.h" +#include "core/framework/data_types.h" +#include "core/framework/ort_value.h" +#include "core/graph/graph_utils.h" +#include "core/graph/graph_viewer.h" +#include "core/graph/model.h" + +#include "core/optimizer/common_subexpression_elimination.h" +#include "core/optimizer/compute_optimizer/compute_optimizer.h" +#include "core/optimizer/utils.h" +#include "core/platform/env.h" +#include "core/session/inference_session.h" +#include "core/util/math.h" + +#include "test/compare_ortvalue.h" +#include "test/framework/test_utils.h" +#include "test/providers/provider_test_utils.h" +#include "test/test_environment.h" +#include "test/util/include/temp_dir.h" +#include "test/util/include/asserts.h" +#include "test/util/include/default_providers.h" + +namespace onnxruntime { +namespace test { + +#define MODEL_FOLDER ORT_TSTR("testdata/transform/") + +// LayerNormalization/Gelu implementation are in contrib namespace (OnnxDomain 1), so +// Without contib_ops enabled, we cannot parse the graph correctly. 
+#ifndef DISABLE_CONTRIB_OPS +static void GatherNDComputationReductionTest(const std::string& op_type, + const logging::Logger& logger, + std::function validation_func) { + std::string op_type_lower = op_type; + std::transform(op_type_lower.begin(), op_type_lower.end(), op_type_lower.begin(), + [](unsigned char c) { return std::tolower(c); }); + std::string file_path = std::string("testdata/transform/computation_reduction/gathernd/gathernd_") + op_type_lower + + std::string(".onnx"); + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(ToPathString(file_path), model, nullptr, logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, logger)); + + validation_func(graph, op_type); +} + +void SingleOpDefaultValidationFunc(Graph& graph, std::string op_type) { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + Node* gathernd_node = nullptr; + for (auto node_index : node_topology_list) { + Node* p_node = graph.GetNode(node_index); + ASSERT_FALSE(p_node == nullptr); + if (p_node->OpType().compare("GatherND") == 0) { + gathernd_node = p_node; + EXPECT_EQ(gathernd_node->MutableInputDefs()[0]->Name(), "input"); + const auto& consumers = graph.GetConsumerNodes(gathernd_node->MutableOutputDefs()[0]->Name()); + EXPECT_EQ(consumers[0]->OpType(), op_type); + } + } + + ASSERT_FALSE(gathernd_node == nullptr); +} + +TEST(ComputeOptimizerTests, GatherND_Gelu) { + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + GatherNDComputationReductionTest("Gelu", *logger, SingleOpDefaultValidationFunc); +} + +TEST(ComputeOptimizerTests, GatherND_Add) { + const logging::Logger* logger = 
&logging::LoggingManager::DefaultLogger(); + GatherNDComputationReductionTest("Add", *logger, [](Graph& graph, std::string op_type) -> void { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + Node* gathernd_node = nullptr; + bool found_gathernd_around_graph_output = false; + for (auto node_index : node_topology_list) { + Node* p_node = graph.GetNode(node_index); + ASSERT_FALSE(p_node == nullptr); + if (p_node->OpType().compare("GatherND") == 0) { + if (p_node->OutputDefs()[0]->Name().compare("output") != 0) { + gathernd_node = p_node; + EXPECT_EQ(gathernd_node->MutableInputDefs()[0]->Name(), "input"); + const auto& consumers = graph.GetConsumerNodes(gathernd_node->MutableOutputDefs()[0]->Name()); + EXPECT_EQ(consumers[0]->OpType(), op_type); + } else { + found_gathernd_around_graph_output = true; + } + } + } + ASSERT_FALSE(gathernd_node == nullptr); + EXPECT_TRUE(found_gathernd_around_graph_output); }); +} + +TEST(ComputeOptimizerTests, GatherND_LayerNormalization) { + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + GatherNDComputationReductionTest("LayerNormalization", *logger, SingleOpDefaultValidationFunc); +} + +TEST(ComputeOptimizerTests, GatherND_MatMul) { + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + GatherNDComputationReductionTest("MatMul", *logger, SingleOpDefaultValidationFunc); +} + +/** + * @brief Class represent a input data (dimensions, data type and value). 
+ */
+struct TestInputData {
+  template <typename T>
+  TestInputData(const std::string& name, const TensorShapeVector& dims, const std::vector<T>& values)
+      : name_(name), dims_(dims), values_(values) {}
+
+  /// Creates an OrtValue from the stored dims and values via CreateMLValue, using the allocator obtained
+  /// from TestCPUExecutionProvider().
+  OrtValue ToOrtValue() const {
+    OrtValue ortvalue;
+    const std::vector<int64_t> dims(dims_.begin(), dims_.end());
+    std::visit(
+        [&ortvalue, &dims](const auto& values) {
+          using Values = std::decay_t<decltype(values)>;
+          if constexpr (std::is_same_v<Values, std::vector<int64_t>> ||
+                        std::is_same_v<Values, std::vector<float>> ||
+                        std::is_same_v<Values, std::vector<MLFloat16>>) {
+            CreateMLValue<typename Values::value_type>(
+                TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims, values, &ortvalue);
+          } else {
+            // The condition depends on Values, so this only fires when an unhandled alternative is
+            // instantiated. A bare string literal here would always evaluate to true.
+            static_assert(kAlwaysFalse<Values>, "Unsupported types!");
+          }
+        },
+        values_);
+
+    return ortvalue;
+  }
+
+  std::string GetName() const {
+    return name_;
+  }
+
+ private:
+  template <typename>
+  static constexpr bool kAlwaysFalse = false;
+
+  std::string name_;
+  TensorShapeVector dims_;
+  std::variant<std::vector<float>, std::vector<MLFloat16>, std::vector<int64_t>> values_;
+};
+
+/**
+ * @brief Resizes data to the element count of shape and fills it with normally distributed values
+ * (mean 0, standard deviation 1).
+ *
+ * The generator is seeded with a fixed value on every call, so repeated calls with the same shape start from
+ * the same generator state.
+ */
+void RandomFillFloatVector(const TensorShapeVector& shape, std::vector<float>& data) {
+  constexpr float kMean = 0.f;
+  constexpr float kScale = 1.f;
+  constexpr std::default_random_engine::result_type kSeed = 123;
+
+  data.resize(TensorShape(shape).Size());
+  std::default_random_engine generator_float{kSeed};
+  std::normal_distribution<float> distribution_float{kMean, kScale};
+
+  std::generate(data.begin(), data.end(),
+                [&generator_float, &distribution_float]() { return distribution_float(generator_float); });
+}
+
+/**
+ * @brief Half-precision counterpart of RandomFillFloatVector.
+ *
+ * Draws the values as float via RandomFillFloatVector, resizes data to match and converts every element to
+ * MLFloat16.
+ */
+void RandomFillHalfVector(const TensorShapeVector& shape, std::vector<MLFloat16>& data) {
+  std::vector<float> data_float;
+  // Actually draw the random values; converting a zero-initialised buffer would yield all zeros.
+  RandomFillFloatVector(shape, data_float);
+  // Size the destination before writing through data.begin().
+  data.resize(data_float.size());
+  std::transform(data_float.begin(), data_float.end(), data.begin(),
+                 [](float value) { return MLFloat16(math::floatToHalf(value)); });
+}
+
+/**
+ * @brief Ordered collection of TestInputData that can be turned into a feed map for a session run.
+ *
+ * References returned by AddInput point into an internal std::vector and are invalidated by later AddInput
+ * calls that reallocate it.
+ */
+struct InputContainer {
+  InputContainer() = default;
+
+  /// Adds an input with explicitly provided values.
+  template <typename T>
+  TestInputData& AddInput(const std::string& name, const TensorShapeVector dims, const std::vector<T>& values) {
+    inputs_.emplace_back(name, dims, values);
+    return inputs_.back();
+  }
+
+  /// Adds an input with one value-initialised element per entry of dims' shape; when func is set it is called
+  /// to fill the values.
+  template <typename T>
+  TestInputData& AddInput(const std::string& name, TensorShapeVector dims,
+                          std::function<
+                              void(const TensorShapeVector& shape, std::vector<T>& data)>
+                              func = nullptr) {
+    std::vector<T> values(TensorShape(dims).Size());
+    if (func) {
+      func(dims, values);
+    }
+
+    inputs_.emplace_back(name, dims, values);
+    return inputs_.back();
+  }
+
+  /// Inserts one (name, OrtValue) entry per stored input into feeds.
+  void ToInputMap(NameMLValMap& feeds) const {
+    // Iterate by const reference: copying a TestInputData duplicates its whole value buffer.
+    for (const auto& input : inputs_) {
+      feeds.insert({input.GetName(), input.ToOrtValue()});
+    }
+  }
+
+ private:
+  std::vector<TestInputData> inputs_;
+};
+
+/**
+ * @brief Loads model_uri into a fresh InferenceSession on the requested execution provider and runs it once.
+ *
+ * @param model_uri Path of the model to load.
+ * @param session_log_id Log id assigned to the session.
+ * @param provider_type One of kCpuExecutionProvider, kCudaExecutionProvider or kRocmExecutionProvider.
+ * @param input_container Inputs converted to the feed map of the run.
+ * @param output_names Names of the outputs to fetch.
+ * @param run_results Receives the fetched outputs.
+ *
+ * Reports a fatal test failure when no provider instance can be created, or when registering, loading,
+ * initializing or running fails.
+ */
+static void RunModelWithData(const PathString& model_uri, const std::string session_log_id,
+                             const std::string& provider_type, const InputContainer& input_container,
+                             const std::vector<std::string>& output_names,
+                             std::vector<OrtValue>& run_results) {
+  SessionOptions so;
+  // we don't want any transformation here.
+  so.graph_optimization_level = TransformerLevel::Default;
+  so.session_logid = session_log_id;
+
+  InferenceSession session_object{so, GetEnvironment()};
+  std::unique_ptr<IExecutionProvider> execution_provider;
+  if (provider_type == onnxruntime::kCpuExecutionProvider) {
+    execution_provider = DefaultCpuExecutionProvider();
+  } else if (provider_type == onnxruntime::kCudaExecutionProvider) {
+    execution_provider = DefaultCudaExecutionProvider();
+  } else if (provider_type == onnxruntime::kRocmExecutionProvider) {
+    execution_provider = DefaultRocmExecutionProvider();
+  }
+  // Stop here for an unknown or unavailable provider rather than registering a null pointer.
+  ASSERT_TRUE(execution_provider != nullptr) << "No execution provider available for: " << provider_type;
+  ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(execution_provider)));
+
+  Status st;
+  ASSERT_TRUE((st = session_object.Load(model_uri)).IsOK()) << st.ErrorMessage();
+  ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st.ErrorMessage();
+
+  NameMLValMap feeds;
+  input_container.ToInputMap(feeds);
+
+  // Now run
+  RunOptions run_options;
+  st = session_object.Run(run_options, feeds, output_names, &run_results);
+
+  ASSERT_TRUE(st.IsOK()) << "RunModelWithData run graph failed with error: " << st.ErrorMessage();
+}
+
+TEST(ComputeOptimizerTests, GatherND_E2E) {
+  const logging::Logger* logger = &logging::LoggingManager::DefaultLogger();
+  auto model_uri = MODEL_FOLDER 
"computation_reduction/gathernd/e2e.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); + + // check the expected node orders. + { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + Node* gathernd_node = nullptr; + for (auto node_index : node_topology_list) { + Node* p_node = graph.GetNode(node_index); + ASSERT_FALSE(p_node == nullptr); + if (p_node->OpType().compare("GatherND") == 0) { + gathernd_node = p_node; + const Node* layer_norm_node = graph.GetProducerNode(gathernd_node->MutableInputDefs()[0]->Name()); + EXPECT_EQ(layer_norm_node->OpType(), "LayerNormalization"); + EXPECT_EQ(layer_norm_node->Name(), "layer_norm_1"); + const auto& consumers = graph.GetConsumerNodes(gathernd_node->MutableOutputDefs()[0]->Name()); + EXPECT_EQ(consumers[0]->OpType(), "MatMul"); + EXPECT_EQ(consumers[0]->Name(), "matmul_1"); + break; + } + } + + ASSERT_FALSE(gathernd_node == nullptr); + } + + // check result diff after the re-order + onnxruntime::test::TemporaryDirectory tmp_dir{ORT_TSTR("compute_optimizer_test_tmp_dir")}; + PathString new_model_uri{ConcatPathComponent(tmp_dir.Path(), + ORT_TSTR("computation_reduction_transformer_after.onnx"))}; + ASSERT_STATUS_OK(Model::Save(*model, new_model_uri)); + + InputContainer input_container; + + int batch_size = 8; + int sequence = 128; + int hidden_size = 128; + int dynamic_predict_count = 20; + input_container.AddInput("input", {batch_size, sequence, hidden_size}, RandomFillFloatVector); + + const TensorShapeVector dims_unsqueezed_masked_lm_positions{batch_size, 
dynamic_predict_count, 1}; + std::vector values_unsqueezed_masked_lm_positions(TensorShape(dims_unsqueezed_masked_lm_positions).Size()); + + std::random_device rd; // obtain a random number from hardware + std::mt19937 eng(rd()); // seed the generator + std::uniform_int_distribution<> distr(0, sequence - 1); // define the range + std::for_each(values_unsqueezed_masked_lm_positions.begin(), values_unsqueezed_masked_lm_positions.end(), + [&distr, &eng](int64_t& value) { value = distr(eng); }); + + input_container.AddInput("unsqueezed_masked_lm_positions", + dims_unsqueezed_masked_lm_positions, + values_unsqueezed_masked_lm_positions); + + static const std::string all_provider_types[] = { + onnxruntime::kCpuExecutionProvider, +#ifdef USE_CUDA + onnxruntime::kCudaExecutionProvider, +#elif USE_ROCM + onnxruntime::kRocmExecutionProvider, +#endif + }; + const std::vector output_names{"output", "gather_output"}; + + for (auto& provider_type : all_provider_types) { + std::vector expected_ort_values; + RunModelWithData(model_uri, std::string("RawGraphRun"), provider_type, + input_container, output_names, expected_ort_values); + + std::vector actual_ort_values; + RunModelWithData(ToPathString(new_model_uri), std::string("OptimizedGraphRun"), provider_type, + input_container, output_names, actual_ort_values); + + ASSERT_TRUE(expected_ort_values.size() == actual_ort_values.size()); + constexpr double per_sample_tolerance = 1e-4; + constexpr double relative_per_sample_tolerance = 1e-4; + for (size_t i = 0; i < expected_ort_values.size(); i++) { + auto ret = CompareOrtValue(actual_ort_values[i], expected_ort_values[i], + per_sample_tolerance, relative_per_sample_tolerance, false); + EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; + } + } +} + +TEST(ComputeOptimizerTests, GatherMatMul_ScalarSlicingOnBatchDim) { + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + auto model_uri = MODEL_FOLDER 
"computation_reduction/gather/gather_matmul_scalar_batch_dim.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); + + GraphViewer graph_viewer(graph); + // Check the first Gather. + { + const std::vector& consumers = graph.GetConsumerNodes("input1"); + ASSERT_EQ(consumers.size(), 1U); + const Node* gather_node = consumers[0]; + ASSERT_EQ(gather_node->OpType(), "Gather"); + + auto& attrs = gather_node->GetAttributes(); + ASSERT_TRUE(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + ASSERT_EQ(axis_value, 0); + } + + // Check the second Gather. + { + const std::vector& consumers = graph.GetConsumerNodes("input2"); + ASSERT_EQ(consumers.size(), 1U); + const Node* gather_node = consumers[0]; + ASSERT_EQ(gather_node->OpType(), "Gather"); + + auto& attrs = gather_node->GetAttributes(); + ASSERT_TRUE(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + ASSERT_EQ(axis_value, 0); + } + + // Check MatMul's input and output. 
+ { + const Node* m5 = graph.GetProducerNode("m1_out"); + ASSERT_FALSE(m5 == nullptr); + EXPECT_EQ(m5->OpType(), "MatMul"); + EXPECT_EQ(m5->Name(), "m1"); + + const Node* lhs_input = graph.GetProducerNode(m5->InputDefs()[0]->Name()); + const Node* rhs_input = graph.GetProducerNode(m5->InputDefs()[1]->Name()); + + ASSERT_FALSE(lhs_input == nullptr); + EXPECT_EQ(lhs_input->OpType(), "Unsqueeze"); + + ASSERT_FALSE(rhs_input == nullptr); + EXPECT_EQ(rhs_input->OpType(), "Unsqueeze"); + } + + // Check result diff after the re-order + onnxruntime::test::TemporaryDirectory tmp_dir{ORT_TSTR("compute_optimizer_test_tmp_dir")}; + PathString new_model_uri{ConcatPathComponent(tmp_dir.Path(), + ORT_TSTR("gather_matmul_scalar_batch_dim_optimized.onnx"))}; + ASSERT_STATUS_OK(Model::Save(*model, new_model_uri)); + + int64_t batch_size = 8; + int64_t sequence_length = 16; + int64_t hidden_size = 1024; + + InputContainer input_container; + + input_container.AddInput("input1", {batch_size, sequence_length, hidden_size}, RandomFillFloatVector); + input_container.AddInput("input2", {batch_size, hidden_size, sequence_length}, RandomFillFloatVector); + + static const std::string all_provider_types[] = { + onnxruntime::kCpuExecutionProvider, +#ifdef USE_CUDA + onnxruntime::kCudaExecutionProvider, +#elif USE_ROCM + onnxruntime::kRocmExecutionProvider, +#endif + }; + + const std::vector output_names = {"final_output"}; + + for (auto& provider_type : all_provider_types) { + std::vector expected_ort_values; + RunModelWithData(model_uri, std::string("RawGraphRun"), provider_type, + input_container, output_names, expected_ort_values); + + std::vector actual_ort_values; + RunModelWithData(ToPathString(new_model_uri), std::string("OptimizedGraphRun"), + provider_type, input_container, output_names, actual_ort_values); + + ASSERT_TRUE(expected_ort_values.size() == actual_ort_values.size()); + constexpr double per_sample_tolerance = 1e-4; + constexpr double relative_per_sample_tolerance = 1e-4; + 
for (size_t i = 0; i < expected_ort_values.size(); i++) { + auto ret = CompareOrtValue(actual_ort_values[i], expected_ort_values[i], + per_sample_tolerance, relative_per_sample_tolerance, false); + EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; + } + } +} + +TEST(ComputeOptimizerTests, GatherMatMul_SlicingOnBatchDim) { + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + auto model_uri = MODEL_FOLDER "computation_reduction/gather/gather_matmul_batch_dim.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); + + GraphViewer graph_viewer(graph); + // Check the first Gather. + { + const std::vector& consumers = graph.GetConsumerNodes("input1"); + ASSERT_EQ(consumers.size(), 1U); + const Node* gather_node = consumers[0]; + ASSERT_EQ(gather_node->OpType(), "Gather"); + + auto& attrs = gather_node->GetAttributes(); + ASSERT_TRUE(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + ASSERT_EQ(axis_value, 0); + } + + // Check the second Gather. + { + const std::vector& consumers = graph.GetConsumerNodes("input2"); + ASSERT_EQ(consumers.size(), 1U); + const Node* gather_node = consumers[0]; + ASSERT_EQ(gather_node->OpType(), "Gather"); + + auto& attrs = gather_node->GetAttributes(); + ASSERT_TRUE(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + ASSERT_EQ(axis_value, 0); + } + + // Check MatMul's input and output. 
+ { + const Node* m5 = graph.GetProducerNode("m1_out"); + ASSERT_FALSE(m5 == nullptr); + EXPECT_EQ(m5->OpType(), "MatMul"); + EXPECT_EQ(m5->Name(), "m1"); + + const Node* lhs_input = graph.GetProducerNode(m5->InputDefs()[0]->Name()); + const Node* rhs_input = graph.GetProducerNode(m5->InputDefs()[1]->Name()); + + ASSERT_FALSE(lhs_input == nullptr); + EXPECT_EQ(lhs_input->OpType(), "Gather"); + + ASSERT_FALSE(rhs_input == nullptr); + EXPECT_EQ(rhs_input->OpType(), "Gather"); + } + + // Check result diff after the re-order + onnxruntime::test::TemporaryDirectory tmp_dir{ORT_TSTR("compute_optimizer_test_tmp_dir")}; + PathString new_model_uri{ConcatPathComponent(tmp_dir.Path(), + ORT_TSTR("gather_matmul_batch_dim_optimized.onnx"))}; + ASSERT_STATUS_OK(Model::Save(*model, new_model_uri)); + + int64_t batch_size = 8; + int64_t sequence_length = 16; + int64_t hidden_size = 1024; + + InputContainer input_container; + + input_container.AddInput("input1", {batch_size, sequence_length, hidden_size}, RandomFillFloatVector); + input_container.AddInput("input2", {batch_size, hidden_size, sequence_length}, RandomFillFloatVector); + + static const std::string all_provider_types[] = { + onnxruntime::kCpuExecutionProvider, +#ifdef USE_CUDA + onnxruntime::kCudaExecutionProvider, +#elif USE_ROCM + onnxruntime::kRocmExecutionProvider, +#endif + }; + + const std::vector output_names = {"final_output"}; + + for (auto& provider_type : all_provider_types) { + std::vector expected_ort_values; + RunModelWithData(model_uri, std::string("RawGraphRun"), provider_type, + input_container, output_names, expected_ort_values); + + std::vector actual_ort_values; + RunModelWithData(ToPathString(new_model_uri), std::string("OptimizedGraphRun"), + provider_type, input_container, output_names, actual_ort_values); + + ASSERT_TRUE(expected_ort_values.size() == actual_ort_values.size()); + constexpr double per_sample_tolerance = 1e-4; + constexpr double relative_per_sample_tolerance = 1e-4; + for (size_t i 
= 0; i < expected_ort_values.size(); i++) { + auto ret = CompareOrtValue(actual_ort_values[i], expected_ort_values[i], + per_sample_tolerance, relative_per_sample_tolerance, false); + EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; + } + } +} + +TEST(ComputeOptimizerTests, GatherMatMul_ScalarSlicingOnLastDim) { + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + auto model_uri = MODEL_FOLDER "computation_reduction/gather/gather_matmul_scalar_last_dim.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); + + GraphViewer graph_viewer(graph); + // Check the first branch. + { + const std::vector& consumers = graph.GetConsumerNodes("input1"); + ASSERT_EQ(consumers.size(), 1U); + const Node* gather_node = consumers[0]; + ASSERT_EQ(gather_node->OpType(), "MatMul"); + } + + // Check the second Gather. + { + const std::vector& consumers = graph.GetConsumerNodes("input2"); + ASSERT_EQ(consumers.size(), 1U); + const Node* gather_node = consumers[0]; + ASSERT_EQ(gather_node->OpType(), "Gather"); + + auto& attrs = gather_node->GetAttributes(); + ASSERT_TRUE(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + ASSERT_EQ(axis_value, 2); + } + + // Check MatMul's input and output. 
+ { + const Node* m5 = graph.GetProducerNode("m1_out"); + ASSERT_FALSE(m5 == nullptr); + EXPECT_EQ(m5->OpType(), "MatMul"); + EXPECT_EQ(m5->Name(), "m1"); + + const Node* lhs_input = graph.GetProducerNode(m5->InputDefs()[0]->Name()); + const Node* rhs_input = graph.GetProducerNode(m5->InputDefs()[1]->Name()); + + ASSERT_TRUE(lhs_input == nullptr); + + ASSERT_FALSE(rhs_input == nullptr); + EXPECT_EQ(rhs_input->OpType(), "Unsqueeze"); + } + + // Check result diff after the re-order + onnxruntime::test::TemporaryDirectory tmp_dir{ORT_TSTR("compute_optimizer_test_tmp_dir")}; + PathString new_model_uri{ConcatPathComponent(tmp_dir.Path(), + ORT_TSTR("gather_matmul_scalar_last_dim_optimized.onnx"))}; + ASSERT_STATUS_OK(Model::Save(*model, new_model_uri)); + + int64_t batch_size = 8; + int64_t sequence_length = 16; + int64_t hidden_size = 1024; + + InputContainer input_container; + + input_container.AddInput("input1", {batch_size, sequence_length, hidden_size}, RandomFillFloatVector); + input_container.AddInput("input2", {batch_size, hidden_size, sequence_length}, RandomFillFloatVector); + + static const std::string all_provider_types[] = { + onnxruntime::kCpuExecutionProvider, +#ifdef USE_CUDA + onnxruntime::kCudaExecutionProvider, +#elif USE_ROCM + onnxruntime::kRocmExecutionProvider, +#endif + }; + + const std::vector output_names = {"final_output"}; + + for (auto& provider_type : all_provider_types) { + std::vector expected_ort_values; + RunModelWithData(model_uri, std::string("RawGraphRun"), provider_type, + input_container, output_names, expected_ort_values); + + std::vector actual_ort_values; + RunModelWithData(ToPathString(new_model_uri), std::string("OptimizedGraphRun"), + provider_type, input_container, output_names, actual_ort_values); + + ASSERT_TRUE(expected_ort_values.size() == actual_ort_values.size()); + constexpr double per_sample_tolerance = 1e-4; + constexpr double relative_per_sample_tolerance = 1e-4; + for (size_t i = 0; i < expected_ort_values.size(); 
i++) { + auto ret = CompareOrtValue(actual_ort_values[i], expected_ort_values[i], + per_sample_tolerance, relative_per_sample_tolerance, false); + EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; + } + } +} + +TEST(ComputeOptimizerTests, GatherMatMul_SlicingOnLastDim) { + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + auto model_uri = MODEL_FOLDER "computation_reduction/gather/gather_matmul_last_dim.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); + + GraphViewer graph_viewer(graph); + // Check the first branch. + { + const std::vector& consumers = graph.GetConsumerNodes("input1"); + ASSERT_EQ(consumers.size(), 1U); + const Node* gather_node = consumers[0]; + ASSERT_EQ(gather_node->OpType(), "MatMul"); + } + + // Check the second Gather. + { + const std::vector& consumers = graph.GetConsumerNodes("input2"); + ASSERT_EQ(consumers.size(), 1U); + const Node* gather_node = consumers[0]; + ASSERT_EQ(gather_node->OpType(), "Gather"); + + auto& attrs = gather_node->GetAttributes(); + ASSERT_TRUE(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + ASSERT_EQ(axis_value, 2); + } + + // Check MatMul's input and output. 
+ { + const Node* m5 = graph.GetProducerNode("m1_out"); + ASSERT_FALSE(m5 == nullptr); + EXPECT_EQ(m5->OpType(), "MatMul"); + EXPECT_EQ(m5->Name(), "m1"); + + const Node* lhs_input = graph.GetProducerNode(m5->InputDefs()[0]->Name()); + const Node* rhs_input = graph.GetProducerNode(m5->InputDefs()[1]->Name()); + + ASSERT_TRUE(lhs_input == nullptr); + + ASSERT_FALSE(rhs_input == nullptr); + EXPECT_EQ(rhs_input->OpType(), "Gather"); + } + + // Check result diff after the re-order + onnxruntime::test::TemporaryDirectory tmp_dir{ORT_TSTR("compute_optimizer_test_tmp_dir")}; + PathString new_model_uri{ConcatPathComponent(tmp_dir.Path(), + ORT_TSTR("gather_matmul_last_dim_optimized.onnx"))}; + ASSERT_STATUS_OK(Model::Save(*model, new_model_uri)); + + int64_t batch_size = 8; + int64_t sequence_length = 16; + int64_t hidden_size = 1024; + + InputContainer input_container; + + input_container.AddInput("input1", {batch_size, sequence_length, hidden_size}, RandomFillFloatVector); + input_container.AddInput("input2", {batch_size, hidden_size, sequence_length}, RandomFillFloatVector); + + static const std::string all_provider_types[] = { + onnxruntime::kCpuExecutionProvider, +#ifdef USE_CUDA + onnxruntime::kCudaExecutionProvider, +#elif USE_ROCM + onnxruntime::kRocmExecutionProvider, +#endif + }; + + const std::vector output_names = {"final_output"}; + + for (auto& provider_type : all_provider_types) { + std::vector expected_ort_values; + RunModelWithData(model_uri, std::string("RawGraphRun"), provider_type, + input_container, output_names, expected_ort_values); + + std::vector actual_ort_values; + RunModelWithData(ToPathString(new_model_uri), std::string("OptimizedGraphRun"), + provider_type, input_container, output_names, actual_ort_values); + + ASSERT_TRUE(expected_ort_values.size() == actual_ort_values.size()); + constexpr double per_sample_tolerance = 1e-4; + constexpr double relative_per_sample_tolerance = 1e-4; + for (size_t i = 0; i < expected_ort_values.size(); i++) { + 
auto ret = CompareOrtValue(actual_ort_values[i], expected_ort_values[i], + per_sample_tolerance, relative_per_sample_tolerance, false); + EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; + } + } +} + +TEST(ComputeOptimizerTests, GatherMatMul_ScalarSlicingOnSecondLastDim) { + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + auto model_uri = MODEL_FOLDER "computation_reduction/gather/gather_matmul_scalar_second_last_dim.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); + + GraphViewer graph_viewer(graph); + // Check the first Gather. + { + const std::vector& consumers = graph.GetConsumerNodes("input1"); + ASSERT_EQ(consumers.size(), 1U); + const Node* gather_node = consumers[0]; + ASSERT_EQ(gather_node->OpType(), "Gather"); + + auto& attrs = gather_node->GetAttributes(); + ASSERT_TRUE(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + ASSERT_EQ(axis_value, 1); + } + + // Check the second branch. + { + const std::vector& consumers = graph.GetConsumerNodes("input2"); + ASSERT_EQ(consumers.size(), 1U); + const Node* gather_node = consumers[0]; + ASSERT_EQ(gather_node->OpType(), "MatMul"); + } + + // Check MatMul(who gathers on the second last dim)'s input and output. 
+ { + const Node* m5 = graph.GetProducerNode("m1_out"); + ASSERT_FALSE(m5 == nullptr); + EXPECT_EQ(m5->OpType(), "MatMul"); + EXPECT_EQ(m5->Name(), "m1"); + + const Node* lhs_input = graph.GetProducerNode(m5->InputDefs()[0]->Name()); + const Node* rhs_input = graph.GetProducerNode(m5->InputDefs()[1]->Name()); + + ASSERT_FALSE(lhs_input == nullptr); + EXPECT_EQ(lhs_input->OpType(), "Unsqueeze"); + + ASSERT_TRUE(rhs_input == nullptr); + } + + // Check result diff after the re-order + onnxruntime::test::TemporaryDirectory tmp_dir{ORT_TSTR("compute_optimizer_test_tmp_dir")}; + PathString new_model_uri{ConcatPathComponent( + tmp_dir.Path(), + ORT_TSTR("gather_matmul_scalar_second_last_dim_optimized.onnx"))}; + ASSERT_STATUS_OK(Model::Save(*model, new_model_uri)); + + int64_t batch_size = 8; + int64_t sequence_length = 16; + int64_t hidden_size = 1024; + + InputContainer input_container; + + input_container.AddInput("input1", {batch_size, sequence_length, hidden_size}, RandomFillFloatVector); + input_container.AddInput("input2", {batch_size, hidden_size, sequence_length}, RandomFillFloatVector); + + static const std::string all_provider_types[] = { + onnxruntime::kCpuExecutionProvider, +#ifdef USE_CUDA + onnxruntime::kCudaExecutionProvider, +#elif USE_ROCM + onnxruntime::kRocmExecutionProvider, +#endif + }; + + const std::vector output_names = {"final_output"}; + + for (auto& provider_type : all_provider_types) { + std::vector expected_ort_values; + RunModelWithData(model_uri, std::string("RawGraphRun"), provider_type, + input_container, output_names, expected_ort_values); + + std::vector actual_ort_values; + RunModelWithData(ToPathString(new_model_uri), std::string("OptimizedGraphRun"), + provider_type, input_container, output_names, actual_ort_values); + + ASSERT_TRUE(expected_ort_values.size() == actual_ort_values.size()); + constexpr double per_sample_tolerance = 1e-4; + constexpr double relative_per_sample_tolerance = 1e-4; + for (size_t i = 0; i < 
expected_ort_values.size(); i++) { + auto ret = CompareOrtValue(actual_ort_values[i], expected_ort_values[i], + per_sample_tolerance, relative_per_sample_tolerance, false); + EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; + } + } +} + +TEST(ComputeOptimizerTests, GatherMatMul_SlicingOnSecondLastDim) { + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + auto model_uri = MODEL_FOLDER "computation_reduction/gather/gather_matmul_second_last_dim.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); + + GraphViewer graph_viewer(graph); + // Check the first Gather. + { + const std::vector& consumers = graph.GetConsumerNodes("input1"); + ASSERT_EQ(consumers.size(), 1U); + const Node* gather_node = consumers[0]; + ASSERT_EQ(gather_node->OpType(), "Gather"); + + auto& attrs = gather_node->GetAttributes(); + ASSERT_TRUE(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + ASSERT_EQ(axis_value, 1); + } + + // Check the second branch. + { + const std::vector& consumers = graph.GetConsumerNodes("input2"); + ASSERT_EQ(consumers.size(), 1U); + const Node* gather_node = consumers[0]; + ASSERT_EQ(gather_node->OpType(), "MatMul"); + } + + // Check MatMul's input and output. 
+ { + const Node* m5 = graph.GetProducerNode("m1_out"); + ASSERT_FALSE(m5 == nullptr); + EXPECT_EQ(m5->OpType(), "MatMul"); + EXPECT_EQ(m5->Name(), "m1"); + + const Node* lhs_input = graph.GetProducerNode(m5->InputDefs()[0]->Name()); + const Node* rhs_input = graph.GetProducerNode(m5->InputDefs()[1]->Name()); + + ASSERT_FALSE(lhs_input == nullptr); + EXPECT_EQ(lhs_input->OpType(), "Gather"); + + ASSERT_TRUE(rhs_input == nullptr); + } + + // Check result diff after the re-order + onnxruntime::test::TemporaryDirectory tmp_dir{ORT_TSTR("compute_optimizer_test_tmp_dir")}; + PathString new_model_uri{ConcatPathComponent(tmp_dir.Path(), + ORT_TSTR("gather_matmul_second_last_dim_optimized.onnx"))}; + ASSERT_STATUS_OK(Model::Save(*model, new_model_uri)); + + int64_t batch_size = 8; + int64_t sequence_length = 16; + int64_t hidden_size = 1024; + + InputContainer input_container; + + input_container.AddInput("input1", {batch_size, sequence_length, hidden_size}, RandomFillFloatVector); + input_container.AddInput("input2", {batch_size, hidden_size, sequence_length}, RandomFillFloatVector); + + static const std::string all_provider_types[] = { + onnxruntime::kCpuExecutionProvider, +#ifdef USE_CUDA + onnxruntime::kCudaExecutionProvider, +#elif USE_ROCM + onnxruntime::kRocmExecutionProvider, +#endif + }; + + const std::vector output_names = {"final_output"}; + + for (auto& provider_type : all_provider_types) { + std::vector expected_ort_values; + RunModelWithData(model_uri, std::string("RawGraphRun"), provider_type, + input_container, output_names, expected_ort_values); + + std::vector actual_ort_values; + RunModelWithData(ToPathString(new_model_uri), std::string("OptimizedGraphRun"), + provider_type, input_container, output_names, actual_ort_values); + + ASSERT_TRUE(expected_ort_values.size() == actual_ort_values.size()); + constexpr double per_sample_tolerance = 1e-4; + constexpr double relative_per_sample_tolerance = 1e-4; + for (size_t i = 0; i < expected_ort_values.size(); 
i++) { + auto ret = CompareOrtValue(actual_ort_values[i], expected_ort_values[i], + per_sample_tolerance, relative_per_sample_tolerance, false); + EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; + } + } +} + +TEST(ComputeOptimizerTests, GatherReshape_ScalarSlicingOnBatchDim) { + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + auto model_uri = MODEL_FOLDER "computation_reduction/gather/gather_reshape_scalar_batch_dim.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); + + GraphViewer graph_viewer(graph); + // Check the first Gather. + { + const std::vector& consumers = graph.GetConsumerNodes("input1"); + ASSERT_EQ(consumers.size(), 1U); + const Node* gather_node = consumers[0]; + ASSERT_EQ(gather_node->OpType(), "Gather"); + + auto& attrs = gather_node->GetAttributes(); + ASSERT_TRUE(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + ASSERT_EQ(axis_value, 0); + } + + { + const Node* m5 = graph.GetProducerNode("reshape_out"); + ASSERT_FALSE(m5 == nullptr); + EXPECT_EQ(m5->OpType(), "Reshape"); + + const Node* lhs_input = graph.GetProducerNode(m5->InputDefs()[0]->Name()); + const Node* rhs_input = graph.GetProducerNode(m5->InputDefs()[1]->Name()); + + ASSERT_FALSE(lhs_input == nullptr); + EXPECT_EQ(lhs_input->OpType(), "Gather"); + + ASSERT_TRUE(rhs_input == nullptr); + InlinedVector new_shape_const_values; + optimizer_utils::AppendTensorFromInitializer(graph, *m5->InputDefs()[1], new_shape_const_values, true); + ASSERT_EQ(new_shape_const_values.size(), 
3U); + ASSERT_EQ(new_shape_const_values[0], 0); + ASSERT_EQ(new_shape_const_values[1], 16); + ASSERT_EQ(new_shape_const_values[2], 64); + } + + // Check result diff after the re-order + onnxruntime::test::TemporaryDirectory tmp_dir{ORT_TSTR("compute_optimizer_test_tmp_dir")}; + PathString new_model_uri{ConcatPathComponent(tmp_dir.Path(), + ORT_TSTR("gather_reshape_scalar_batch_dim_optimized.onnx"))}; + ASSERT_STATUS_OK(Model::Save(*model, new_model_uri)); + + int64_t batch_size = 8; + int64_t sequence_length = 16; + int64_t hidden_size = 1024; + + InputContainer input_container; + + input_container.AddInput("input1", {batch_size, sequence_length, hidden_size}, RandomFillFloatVector); + + static const std::string all_provider_types[] = { + onnxruntime::kCpuExecutionProvider, +#ifdef USE_CUDA + onnxruntime::kCudaExecutionProvider, +#elif USE_ROCM + onnxruntime::kRocmExecutionProvider, +#endif + }; + + const std::vector output_names = {"final_output"}; + + for (auto& provider_type : all_provider_types) { + std::vector expected_ort_values; + RunModelWithData(model_uri, std::string("RawGraphRun"), provider_type, + input_container, output_names, expected_ort_values); + + std::vector actual_ort_values; + RunModelWithData(ToPathString(new_model_uri), std::string("OptimizedGraphRun"), + provider_type, input_container, output_names, actual_ort_values); + + ASSERT_TRUE(expected_ort_values.size() == actual_ort_values.size()); + constexpr double per_sample_tolerance = 1e-4; + constexpr double relative_per_sample_tolerance = 1e-4; + for (size_t i = 0; i < expected_ort_values.size(); i++) { + auto ret = CompareOrtValue(actual_ort_values[i], expected_ort_values[i], + per_sample_tolerance, relative_per_sample_tolerance, false); + EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; + } + } +} + +TEST(ComputeOptimizerTests, GatherReshape_SlicingOnBatchDim) { + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + auto model_uri = MODEL_FOLDER 
"computation_reduction/gather/gather_reshape_batch_dim.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); + + GraphViewer graph_viewer(graph); + // Check the first Gather. + { + const std::vector& consumers = graph.GetConsumerNodes("input1"); + ASSERT_EQ(consumers.size(), 1U); + const Node* gather_node = consumers[0]; + ASSERT_EQ(gather_node->OpType(), "Gather"); + + auto& attrs = gather_node->GetAttributes(); + ASSERT_TRUE(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + ASSERT_EQ(axis_value, 0); + } + + { + const Node* m5 = graph.GetProducerNode("reshape_out"); + ASSERT_FALSE(m5 == nullptr); + EXPECT_EQ(m5->OpType(), "Reshape"); + + const Node* lhs_input = graph.GetProducerNode(m5->InputDefs()[0]->Name()); + const Node* rhs_input = graph.GetProducerNode(m5->InputDefs()[1]->Name()); + + ASSERT_FALSE(lhs_input == nullptr); + EXPECT_EQ(lhs_input->OpType(), "Gather"); + + ASSERT_TRUE(rhs_input == nullptr); + InlinedVector new_shape_const_values; + optimizer_utils::AppendTensorFromInitializer(graph, *m5->InputDefs()[1], new_shape_const_values, true); + ASSERT_EQ(new_shape_const_values.size(), 4U); + ASSERT_EQ(new_shape_const_values[0], 0); + ASSERT_EQ(new_shape_const_values[1], 0); + ASSERT_EQ(new_shape_const_values[2], 16); + ASSERT_EQ(new_shape_const_values[3], 64); + } + + // Check result diff after the re-order + onnxruntime::test::TemporaryDirectory tmp_dir{ORT_TSTR("compute_optimizer_test_tmp_dir")}; + PathString new_model_uri{ConcatPathComponent(tmp_dir.Path(), + 
ORT_TSTR("gather_reshape_batch_dim_optimized.onnx"))}; + ASSERT_STATUS_OK(Model::Save(*model, new_model_uri)); + + int64_t batch_size = 8; + int64_t sequence_length = 16; + int64_t hidden_size = 1024; + + InputContainer input_container; + + input_container.AddInput("input1", {batch_size, sequence_length, hidden_size}, RandomFillFloatVector); + + static const std::string all_provider_types[] = { + onnxruntime::kCpuExecutionProvider, +#ifdef USE_CUDA + onnxruntime::kCudaExecutionProvider, +#elif USE_ROCM + onnxruntime::kRocmExecutionProvider, +#endif + }; + + const std::vector output_names = {"final_output"}; + + for (auto& provider_type : all_provider_types) { + std::vector expected_ort_values; + RunModelWithData(model_uri, std::string("RawGraphRun"), provider_type, + input_container, output_names, expected_ort_values); + + std::vector actual_ort_values; + RunModelWithData(ToPathString(new_model_uri), std::string("OptimizedGraphRun"), + provider_type, input_container, output_names, actual_ort_values); + + ASSERT_TRUE(expected_ort_values.size() == actual_ort_values.size()); + constexpr double per_sample_tolerance = 1e-4; + constexpr double relative_per_sample_tolerance = 1e-4; + for (size_t i = 0; i < expected_ort_values.size(); i++) { + auto ret = CompareOrtValue(actual_ort_values[i], expected_ort_values[i], + per_sample_tolerance, relative_per_sample_tolerance, false); + EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; + } + } +} + +TEST(ComputeOptimizerTests, GatherReshape_ScalarSlicingOnSeqlenDim) { + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + auto model_uri = MODEL_FOLDER "computation_reduction/gather/gather_reshape_scalar_seqlen_dim.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; + 
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); + + GraphViewer graph_viewer(graph); + // Check the first Gather. + { + const std::vector& consumers = graph.GetConsumerNodes("input1"); + ASSERT_EQ(consumers.size(), 1U); + const Node* gather_node = consumers[0]; + ASSERT_EQ(gather_node->OpType(), "Gather"); + + auto& attrs = gather_node->GetAttributes(); + ASSERT_TRUE(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + ASSERT_EQ(axis_value, 1); + } + + { + const Node* m5 = graph.GetProducerNode("reshape_out"); + ASSERT_FALSE(m5 == nullptr); + EXPECT_EQ(m5->OpType(), "Reshape"); + + const Node* lhs_input = graph.GetProducerNode(m5->InputDefs()[0]->Name()); + const Node* rhs_input = graph.GetProducerNode(m5->InputDefs()[1]->Name()); + + ASSERT_FALSE(lhs_input == nullptr); + EXPECT_EQ(lhs_input->OpType(), "Gather"); + + ASSERT_TRUE(rhs_input == nullptr); + InlinedVector new_shape_const_values; + optimizer_utils::AppendTensorFromInitializer(graph, *m5->InputDefs()[1], new_shape_const_values, true); + ASSERT_EQ(new_shape_const_values.size(), 3U); + ASSERT_EQ(new_shape_const_values[0], 0); + ASSERT_EQ(new_shape_const_values[1], 16); + ASSERT_EQ(new_shape_const_values[2], 64); + } + + // Check result diff after the re-order + onnxruntime::test::TemporaryDirectory tmp_dir{ORT_TSTR("compute_optimizer_test_tmp_dir")}; + PathString new_model_uri{ConcatPathComponent(tmp_dir.Path(), + ORT_TSTR("gather_reshape_scalar_seqlen_dim_optimized.onnx"))}; + ASSERT_STATUS_OK(Model::Save(*model, new_model_uri)); + + int64_t batch_size = 8; + int64_t sequence_length = 16; + int64_t hidden_size = 1024; + + InputContainer input_container; + + input_container.AddInput("input1", {batch_size, sequence_length, hidden_size}, RandomFillFloatVector); + + static const std::string 
all_provider_types[] = { + onnxruntime::kCpuExecutionProvider, +#ifdef USE_CUDA + onnxruntime::kCudaExecutionProvider, +#elif USE_ROCM + onnxruntime::kRocmExecutionProvider, +#endif + }; + + const std::vector output_names = {"final_output"}; + + for (auto& provider_type : all_provider_types) { + std::vector expected_ort_values; + RunModelWithData(model_uri, std::string("RawGraphRun"), provider_type, + input_container, output_names, expected_ort_values); + + std::vector actual_ort_values; + RunModelWithData(ToPathString(new_model_uri), std::string("OptimizedGraphRun"), + provider_type, input_container, output_names, actual_ort_values); + + ASSERT_TRUE(expected_ort_values.size() == actual_ort_values.size()); + constexpr double per_sample_tolerance = 1e-4; + constexpr double relative_per_sample_tolerance = 1e-4; + for (size_t i = 0; i < expected_ort_values.size(); i++) { + auto ret = CompareOrtValue(actual_ort_values[i], expected_ort_values[i], + per_sample_tolerance, relative_per_sample_tolerance, false); + EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; + } + } +} + +TEST(ComputeOptimizerTests, GatherReshape_SlicingOnSeqlenDim) { + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + auto model_uri = MODEL_FOLDER "computation_reduction/gather/gather_reshape_seqlen_dim.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); + + GraphViewer graph_viewer(graph); + // Check the first Gather. 
+ { + const std::vector& consumers = graph.GetConsumerNodes("input1"); + ASSERT_EQ(consumers.size(), 1U); + const Node* gather_node = consumers[0]; + ASSERT_EQ(gather_node->OpType(), "Gather"); + + auto& attrs = gather_node->GetAttributes(); + ASSERT_TRUE(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + ASSERT_EQ(axis_value, 1); + } + + { + const Node* m5 = graph.GetProducerNode("reshape_out"); + ASSERT_FALSE(m5 == nullptr); + EXPECT_EQ(m5->OpType(), "Reshape"); + + const Node* lhs_input = graph.GetProducerNode(m5->InputDefs()[0]->Name()); + const Node* rhs_input = graph.GetProducerNode(m5->InputDefs()[1]->Name()); + + ASSERT_FALSE(lhs_input == nullptr); + EXPECT_EQ(lhs_input->OpType(), "Gather"); + + ASSERT_TRUE(rhs_input == nullptr); + InlinedVector new_shape_const_values; + optimizer_utils::AppendTensorFromInitializer(graph, *m5->InputDefs()[1], new_shape_const_values, true); + ASSERT_EQ(new_shape_const_values.size(), 4U); + ASSERT_EQ(new_shape_const_values[0], 0); + ASSERT_EQ(new_shape_const_values[1], 0); + ASSERT_EQ(new_shape_const_values[2], 16); + ASSERT_EQ(new_shape_const_values[3], 64); + } + + // Check result diff after the re-order + onnxruntime::test::TemporaryDirectory tmp_dir{ORT_TSTR("compute_optimizer_test_tmp_dir")}; + PathString new_model_uri{ConcatPathComponent(tmp_dir.Path(), + ORT_TSTR("gather_reshape_seqlen_dim_optimized.onnx"))}; + ASSERT_STATUS_OK(Model::Save(*model, new_model_uri)); + + int64_t batch_size = 8; + int64_t sequence_length = 16; + int64_t hidden_size = 1024; + + InputContainer input_container; + + input_container.AddInput("input1", {batch_size, sequence_length, hidden_size}, RandomFillFloatVector); + + static const std::string all_provider_types[] = { + onnxruntime::kCpuExecutionProvider, +#ifdef USE_CUDA + onnxruntime::kCudaExecutionProvider, +#elif USE_ROCM + onnxruntime::kRocmExecutionProvider, +#endif + }; + + const std::vector output_names = 
{"final_output"}; + + for (auto& provider_type : all_provider_types) { + std::vector expected_ort_values; + RunModelWithData(model_uri, std::string("RawGraphRun"), provider_type, + input_container, output_names, expected_ort_values); + + std::vector actual_ort_values; + RunModelWithData(ToPathString(new_model_uri), std::string("OptimizedGraphRun"), + provider_type, input_container, output_names, actual_ort_values); + + ASSERT_TRUE(expected_ort_values.size() == actual_ort_values.size()); + constexpr double per_sample_tolerance = 1e-4; + constexpr double relative_per_sample_tolerance = 1e-4; + for (size_t i = 0; i < expected_ort_values.size(); i++) { + auto ret = CompareOrtValue(actual_ort_values[i], expected_ort_values[i], + per_sample_tolerance, relative_per_sample_tolerance, false); + EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; + } + } +} + +TEST(ComputeOptimizerTests, GatherReshape_SlicingOnSeqlenDim2) { + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + auto model_uri = MODEL_FOLDER "computation_reduction/gather/gather_reshape_seqlen_dim2.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); + + GraphViewer graph_viewer(graph); + // Check the first Gather. 
+ { + const std::vector& consumers = graph.GetConsumerNodes("input1"); + ASSERT_EQ(consumers.size(), 1U); + const Node* gather_node = consumers[0]; + ASSERT_EQ(gather_node->OpType(), "Gather"); + + auto& attrs = gather_node->GetAttributes(); + ASSERT_TRUE(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + ASSERT_EQ(axis_value, 1); + } + + { + const Node* m5 = graph.GetProducerNode("reshape_out"); + ASSERT_FALSE(m5 == nullptr); + EXPECT_EQ(m5->OpType(), "Reshape"); + + const Node* lhs_input = graph.GetProducerNode(m5->InputDefs()[0]->Name()); + const Node* rhs_input = graph.GetProducerNode(m5->InputDefs()[1]->Name()); + + ASSERT_FALSE(lhs_input == nullptr); + EXPECT_EQ(lhs_input->OpType(), "Gather"); + + ASSERT_TRUE(rhs_input == nullptr); + InlinedVector new_shape_const_values; + optimizer_utils::AppendTensorFromInitializer(graph, *m5->InputDefs()[1], new_shape_const_values, true); + ASSERT_EQ(new_shape_const_values.size(), 4U); + ASSERT_EQ(new_shape_const_values[0], 0); + ASSERT_EQ(new_shape_const_values[1], 31); + ASSERT_EQ(new_shape_const_values[2], 16); + ASSERT_EQ(new_shape_const_values[3], 64); + } + + // Check result diff after the re-order + onnxruntime::test::TemporaryDirectory tmp_dir{ORT_TSTR("compute_optimizer_test_tmp_dir")}; + PathString new_model_uri{ConcatPathComponent(tmp_dir.Path(), + ORT_TSTR("gather_reshape_seqlen_dim2_optimized.onnx"))}; + ASSERT_STATUS_OK(Model::Save(*model, new_model_uri)); + + int64_t batch_size = 8; + int64_t sequence_length = 128; + int64_t hidden_size = 1024; + + InputContainer input_container; + + input_container.AddInput("input1", {batch_size, sequence_length, hidden_size}, RandomFillFloatVector); + + static const std::string all_provider_types[] = { + onnxruntime::kCpuExecutionProvider, +#ifdef USE_CUDA + onnxruntime::kCudaExecutionProvider, +#elif USE_ROCM + onnxruntime::kRocmExecutionProvider, +#endif + }; + + const std::vector output_names = 
{"final_output"}; + + for (auto& provider_type : all_provider_types) { + std::vector expected_ort_values; + RunModelWithData(model_uri, std::string("RawGraphRun"), provider_type, + input_container, output_names, expected_ort_values); + + std::vector actual_ort_values; + RunModelWithData(ToPathString(new_model_uri), std::string("OptimizedGraphRun"), + provider_type, input_container, output_names, actual_ort_values); + + ASSERT_TRUE(expected_ort_values.size() == actual_ort_values.size()); + constexpr double per_sample_tolerance = 1e-4; + constexpr double relative_per_sample_tolerance = 1e-4; + for (size_t i = 0; i < expected_ort_values.size(); i++) { + auto ret = CompareOrtValue(actual_ort_values[i], expected_ort_values[i], + per_sample_tolerance, relative_per_sample_tolerance, false); + EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; + } + } +} + +TEST(ComputeOptimizerTests, GatherRobertaE2E) { + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + // Be noted, all dropout have ratio be 0.0, to make it easier to compare when running with session. + // This did not affect the transformer tests, because we did not remove the Dropout of ratio 0. in the middle. + auto model_uri = MODEL_FOLDER "computation_reduction/gather/gather_roberta_e2e.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{3}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); + + GraphViewer graph_viewer(graph); + // Check the first Gather. 
+ { + const std::vector& consumers = graph.GetConsumerNodes("c1_out"); + const Node* gather_node = nullptr; + for (auto p_node : consumers) { + ASSERT_FALSE(p_node == nullptr); + if (p_node->OpType().compare("Gather") == 0) { + gather_node = p_node; + const Node* cast_node = graph.GetProducerNode(gather_node->InputDefs()[0]->Name()); + EXPECT_EQ(cast_node->OpType(), "Cast"); + EXPECT_EQ(cast_node->Name(), "c1"); + const auto& gather_consumers = graph.GetConsumerNodes(gather_node->OutputDefs()[0]->Name()); + EXPECT_EQ(gather_consumers[0]->OpType(), "Unsqueeze"); + break; + } + } + + ASSERT_FALSE(gather_node == nullptr); + } + + // Check the second Gather. + { + const std::vector& consumers = graph.GetConsumerNodes("d1_out"); + const Node* gather_node = nullptr; + for (auto p_node : consumers) { + ASSERT_FALSE(p_node == nullptr); + if (p_node->OpType().compare("Gather") == 0) { + gather_node = p_node; + const Node* dropout_node = graph.GetProducerNode(gather_node->InputDefs()[0]->Name()); + EXPECT_EQ(dropout_node->OpType(), "Dropout"); + EXPECT_EQ(dropout_node->Name(), "d1"); + const auto& gather_consumers = graph.GetConsumerNodes(gather_node->OutputDefs()[0]->Name()); + EXPECT_EQ(gather_consumers[0]->OpType(), "Add"); + EXPECT_EQ(gather_consumers[0]->Name(), "a6"); + break; + } + } + + ASSERT_FALSE(gather_node == nullptr); + } + + // Check the input/output of the original Gather node. + { + const std::vector& consumers = graph.GetConsumerNodes("layernorm2_out"); + ASSERT_TRUE(consumers.size() == 1); + ASSERT_FALSE(consumers[0] == nullptr); + EXPECT_EQ(consumers[0]->OpType(), "Dropout"); + EXPECT_EQ(consumers[0]->Name(), "d6"); + } + + // Check MatMul(who gathers on the second last dim)'s input and output. 
+ { + const Node* m5 = graph.GetProducerNode("m5_out"); + ASSERT_FALSE(m5 == nullptr); + EXPECT_EQ(m5->OpType(), "MatMul"); + EXPECT_EQ(m5->Name(), "m5"); + + const Node* lhs_input = graph.GetProducerNode(m5->InputDefs()[0]->Name()); + const Node* rhs_input = graph.GetProducerNode(m5->InputDefs()[1]->Name()); + + ASSERT_FALSE(lhs_input == nullptr); + EXPECT_EQ(lhs_input->OpType(), "Unsqueeze"); + + ASSERT_FALSE(rhs_input == nullptr); + EXPECT_EQ(rhs_input->OpType(), "Transpose"); + EXPECT_EQ(rhs_input->Name(), "transpose1"); + } + + // Check Add(who has broadcastable dim on gather axis)'s input and output. + { + const Node* a4 = graph.GetProducerNode("a4_out"); + ASSERT_FALSE(a4 == nullptr); + EXPECT_EQ(a4->OpType(), "Add"); + EXPECT_EQ(a4->Name(), "a4"); + + const std::vector& consumers = graph.GetConsumerNodes("a4_out"); + ASSERT_TRUE(consumers.size() == 1); + ASSERT_FALSE(consumers[0] == nullptr); + EXPECT_EQ(consumers[0]->OpType(), "Squeeze"); + } + + // Check result diff after the re-order + onnxruntime::test::TemporaryDirectory tmp_dir{ORT_TSTR("compute_optimizer_test_tmp_dir")}; + PathString new_model_uri{ConcatPathComponent(tmp_dir.Path(), + ORT_TSTR("gather_roberta_e2e_optimized.onnx"))}; + ASSERT_STATUS_OK(Model::Save(*model, new_model_uri)); + + int64_t batch_size = 8; + int64_t sequence_length = 16; + int64_t hidden_size = 1024; + + InputContainer input_container; + + input_container.AddInput("input", {batch_size, sequence_length, hidden_size}, RandomFillFloatVector); + + const TensorShapeVector dims_mask = {batch_size, sequence_length}; + std::vector attention_mask(TensorShape(dims_mask).Size(), 1); + input_container.AddInput("attention_mask", dims_mask, attention_mask); + + input_container.AddInput("matmul1.weight", {hidden_size, 1024}, RandomFillHalfVector); + input_container.AddInput("add1.bias", {1024}, RandomFillHalfVector); + + input_container.AddInput("matmul2.weight", {hidden_size, 1024}, RandomFillHalfVector); + 
input_container.AddInput("add2.bias", {1024}, RandomFillHalfVector); + + input_container.AddInput("matmul3.weight", {hidden_size, 1024}, RandomFillHalfVector); + input_container.AddInput("add3.bias", {1024}, RandomFillHalfVector); + + input_container.AddInput("matmul4.weight", {hidden_size, 1024}, RandomFillHalfVector); + input_container.AddInput("add4.bias", {1024}, RandomFillHalfVector); + + input_container.AddInput("layer_norm1.weight", {hidden_size}, RandomFillFloatVector); + input_container.AddInput("layer_norm1.bias", {hidden_size}, RandomFillFloatVector); + + input_container.AddInput("matmul7.weight", {hidden_size, hidden_size * 4}, RandomFillHalfVector); + input_container.AddInput("add7.bias", {hidden_size * 4}, RandomFillHalfVector); + + input_container.AddInput("matmul8.weight", {hidden_size * 4, hidden_size}, RandomFillHalfVector); + input_container.AddInput("add8.bias", {hidden_size}, RandomFillHalfVector); + + input_container.AddInput("layer_norm2.weight", {hidden_size}, RandomFillFloatVector); + input_container.AddInput("layer_norm2.bias", {hidden_size}, RandomFillFloatVector); + + static const std::string all_provider_types[] = { + onnxruntime::kCpuExecutionProvider, +#ifdef USE_CUDA + onnxruntime::kCudaExecutionProvider, +#elif USE_ROCM + onnxruntime::kRocmExecutionProvider, +#endif + }; + + const std::vector output_names = {"final_output"}; + + for (auto& provider_type : all_provider_types) { + std::vector expected_ort_values; + RunModelWithData(model_uri, std::string("RawGraphRun"), provider_type, + input_container, output_names, expected_ort_values); + + std::vector actual_ort_values; + RunModelWithData(ToPathString(new_model_uri), std::string("OptimizedGraphRun"), + provider_type, input_container, output_names, actual_ort_values); + + ASSERT_TRUE(expected_ort_values.size() == actual_ort_values.size()); + + // "expected 0.793675 (3f4b2e44), got 0.79232 (3f4ad584), diff: 0.00135422, tol=0.000179367 idx=4276. 
+ // 1713 of 8192 differ" + // Loose the atol a bit because we see the MatMuls results differs once we move Gather before it. + constexpr double per_sample_tolerance = 2e-3; + constexpr double relative_per_sample_tolerance = 2e-3; + for (size_t i = 0; i < expected_ort_values.size(); i++) { + auto ret = CompareOrtValue(actual_ort_values[i], expected_ort_values[i], + per_sample_tolerance, relative_per_sample_tolerance, false); + EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; + } + } +} +#endif + +} // namespace test +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index f84058845c..eb0d2c9ebd 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -24,7 +24,6 @@ #include "core/optimizer/bias_softmax_fusion.h" #include "core/optimizer/cast_elimination.h" #include "core/optimizer/common_subexpression_elimination.h" -#include "core/optimizer/computation_reduction.h" #include "core/optimizer/concat_slice_elimination.h" #include "core/optimizer/constant_folding.h" #include "core/optimizer/constant_sharing.h" @@ -5069,199 +5068,6 @@ TEST_F(GraphTransformationTests, MatMulIntegerToFloatTest) { #endif -// LayerNormalization implementation is in contrib namespace (OnnxDomain 1), so -// Without contib_ops enabled, we cannot parse the graph correctly. -#ifndef DISABLE_CONTRIB_OPS -// We used Opset 12 for testing to make sure we are not using GatherND OnnxDomain Opset 1. 
-static void GatherNDComputationReductionTest(const std::string op_type, logging::Logger& logger) { - std::string op_type_lower = op_type; - std::transform(op_type_lower.begin(), op_type_lower.end(), op_type_lower.begin(), [](unsigned char c) { return std::tolower(c); }); - std::string file_path = std::string("testdata/transform/computation_reduction/gathernd_") + op_type_lower + std::string(".onnx"); - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(ToPathString(file_path), model, nullptr, logger)); - Graph& graph = model->MainGraph(); - std::map op_to_count = CountOpsInGraph(graph); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, logger)); - - GraphViewer graph_viewer(graph); - const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); - - Node* gathernd_node = nullptr; - for (auto node_index : node_topology_list) { - Node* p_node = graph.GetNode(node_index); - ASSERT_FALSE(p_node == nullptr); - if (p_node->OpType().compare("GatherND") == 0) { - gathernd_node = p_node; - EXPECT_EQ(gathernd_node->MutableInputDefs()[0]->Name(), "input"); - const auto& consumers = graph.GetConsumerNodes(gathernd_node->MutableOutputDefs()[0]->Name()); - EXPECT_EQ(consumers[0]->OpType(), op_type); - } - } - - ASSERT_FALSE(gathernd_node == nullptr); -} - -TEST_F(GraphTransformationTests, ComputationReductionTransformer_GatherND_Gelu) { - GatherNDComputationReductionTest("Gelu", *logger_); -} - -TEST_F(GraphTransformationTests, ComputationReductionTransformer_GatherND_Add) { - GatherNDComputationReductionTest("Add", *logger_); -} - -TEST_F(GraphTransformationTests, ComputationReductionTransformer_GatherND_LayerNormalization) { - GatherNDComputationReductionTest("LayerNormalization", *logger_); -} - -TEST_F(GraphTransformationTests, 
ComputationReductionTransformer_GatherND_MatMul) { - GatherNDComputationReductionTest("MatMul", *logger_); -} - -static void RunGatherNDE2EGraph(std::vector& run_results, const PathString& model_uri, - const std::string session_log_id, const std::string& provider_type, - const std::vector& dims_input, - const std::vector& input_values, - const std::vector& dims_unsqueezed_masked_lm_positions, - const std::vector& values_unsqueezed_masked_lm_positions) { - SessionOptions so; - // we don't want any transformation here. - so.graph_optimization_level = TransformerLevel::Default; - so.session_logid = session_log_id; - - InferenceSession session_object{so, GetEnvironment()}; - std::unique_ptr execution_provider; - if (provider_type == onnxruntime::kCpuExecutionProvider) - execution_provider = DefaultCpuExecutionProvider(); - else if (provider_type == onnxruntime::kCudaExecutionProvider) - execution_provider = DefaultCudaExecutionProvider(); - else if (provider_type == onnxruntime::kRocmExecutionProvider) - execution_provider = DefaultRocmExecutionProvider(); - EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); - - Status st; - ASSERT_TRUE((st = session_object.Load(model_uri)).IsOK()) << st; - ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st; - - OrtValue input1; - CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_input, input_values, &input1); - OrtValue input2; - CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_unsqueezed_masked_lm_positions, - values_unsqueezed_masked_lm_positions, &input2); - - NameMLValMap feeds; - feeds.insert(std::make_pair("input", input1)); - feeds.insert(std::make_pair("unsqueezed_masked_lm_positions", input2)); - - // prepare outputs - std::vector output_names; - output_names.push_back("output"); - output_names.push_back("gather_output"); - - // Now run - RunOptions run_options; - st = session_object.Run(run_options, feeds, 
output_names, &run_results); - - EXPECT_TRUE(st.IsOK()); -} - -TEST_F(GraphTransformationTests, ComputationReductionTransformer_GatherND_E2E) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "computation_reduction/e2e.onnx"; - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); - Graph& graph = model->MainGraph(); - std::map op_to_count = CountOpsInGraph(graph); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); - - // check the expected node orders. - { - GraphViewer graph_viewer(graph); - const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); - - Node* gathernd_node = nullptr; - for (auto node_index : node_topology_list) { - Node* p_node = graph.GetNode(node_index); - ASSERT_FALSE(p_node == nullptr); - if (p_node->OpType().compare("GatherND") == 0) { - gathernd_node = p_node; - const Node* layer_norm_node = graph.GetProducerNode(gathernd_node->MutableInputDefs()[0]->Name()); - EXPECT_EQ(layer_norm_node->OpType(), "LayerNormalization"); - EXPECT_EQ(layer_norm_node->Name(), "layer_norm_1"); - const auto& consumers = graph.GetConsumerNodes(gathernd_node->MutableOutputDefs()[0]->Name()); - EXPECT_EQ(consumers[0]->OpType(), "MatMul"); - EXPECT_EQ(consumers[0]->Name(), "matmul_1"); - break; - } - } - - ASSERT_FALSE(gathernd_node == nullptr); - } - - // check result diff after the re-order - auto new_model_uri = "computation_reduction_transformer_after.onnx"; - ASSERT_STATUS_OK(Model::Save(*model, new_model_uri)); - - float scale = 1.f; - float mean = 0.f; - float seed = 123.f; - std::default_random_engine generator_float{gsl::narrow_cast(seed)}; - std::normal_distribution distribution_float{mean, scale}; - - int batch_size = 8; - int sequence = 128; - int hidden_size = 128; - 
int dynamic_predict_count = 20; - const std::vector dims_input = {batch_size, sequence, hidden_size}; - std::vector input_values(TensorShape(dims_input).Size()); - std::for_each(input_values.begin(), input_values.end(), - [&generator_float, &distribution_float](float& value) { value = distribution_float(generator_float); }); - - const std::vector dims_unsqueezed_masked_lm_positions = {batch_size, dynamic_predict_count, 1}; - std::vector values_unsqueezed_masked_lm_positions(TensorShape(dims_unsqueezed_masked_lm_positions).Size()); - - std::random_device rd; // obtain a random number from hardware - std::mt19937 eng(rd()); // seed the generator - std::uniform_int_distribution<> distr(0, sequence - 1); // define the range - std::for_each(values_unsqueezed_masked_lm_positions.begin(), values_unsqueezed_masked_lm_positions.end(), - [&distr, &eng](int64_t& value) { value = distr(eng); }); - - static const std::string all_provider_types[] = { - onnxruntime::kCpuExecutionProvider, -#ifdef USE_CUDA - onnxruntime::kCudaExecutionProvider, -#elif USE_ROCM - onnxruntime::kRocmExecutionProvider, -#endif - }; - - for (auto& provider_type : all_provider_types) { - std::vector expected_ort_values; - RunGatherNDE2EGraph(expected_ort_values, model_uri, std::string("RawGraphRun"), provider_type, - dims_input, input_values, dims_unsqueezed_masked_lm_positions, - values_unsqueezed_masked_lm_positions); - - std::vector actual_ort_values; - RunGatherNDE2EGraph(actual_ort_values, ToPathString(new_model_uri), std::string("OptimizedGraphRun"), provider_type, - dims_input, input_values, dims_unsqueezed_masked_lm_positions, - values_unsqueezed_masked_lm_positions); - - ASSERT_TRUE(expected_ort_values.size() == actual_ort_values.size()); - constexpr double per_sample_tolerance = 1e-4; - constexpr double relative_per_sample_tolerance = 1e-4; - for (size_t i = 0; i < expected_ort_values.size(); i++) { - auto ret = CompareOrtValue(actual_ort_values[i], expected_ort_values[i], - 
per_sample_tolerance, relative_per_sample_tolerance, false); - EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; - } - } -} -#endif - #ifndef DISABLE_CONTRIB_OPS template static void TestMatMulScaleFusion( diff --git a/onnxruntime/test/testdata/transform/computation_reduction/e2e.onnx b/onnxruntime/test/testdata/transform/computation_reduction/e2e.onnx deleted file mode 100755 index b3961b84af..0000000000 Binary files a/onnxruntime/test/testdata/transform/computation_reduction/e2e.onnx and /dev/null differ diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_matmul.py b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_matmul.py new file mode 100755 index 0000000000..e70cd45970 --- /dev/null +++ b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_matmul.py @@ -0,0 +1,65 @@ +import onnx +from onnx import OperatorSetIdProto, TensorProto, helper + + +def _create_model_proto(output_shapes, axis_to_gather, slice_dims, slices_values, model_name): + # inputs and outputs + hidden = 1024 + inputs = [ + helper.make_tensor_value_info("input1", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden]), + helper.make_tensor_value_info("input2", TensorProto.FLOAT, ["batch_size", hidden, "sequence_length"]), + ] + + outputs = [ + helper.make_tensor_value_info("final_output", TensorProto.FLOAT, output_shapes), + ] + + # initializers + + initializers = [ + helper.make_tensor("slices", TensorProto.INT64, slice_dims, slices_values), + ] + + # nodes + nodes = [ + helper.make_node("MatMul", ["input1", "input2"], ["m1_out"], "m1"), + helper.make_node("Gather", ["m1_out", "slices"], ["gather_out"], "gather", axis=axis_to_gather), + helper.make_node("Identity", ["gather_out"], ["final_output"], "identity1"), + ] + + # Create the graph (GraphProto) + graph_def = helper.make_graph( + nodes, + "test-model", + inputs, + outputs, + initializers, + "doc string", + ) + + opsets = [] + onnxdomain = 
OperatorSetIdProto() + onnxdomain.version = 14 + onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator + # set that is defined as part of the ONNX specification. + opsets.append(onnxdomain) + + msdomain = OperatorSetIdProto() + msdomain.version = 1 + msdomain.domain = "com.microsoft" + + opsets.append(msdomain) + kwargs = {} + kwargs["opset_imports"] = opsets + + model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) + final_model = onnx.shape_inference.infer_shapes(model_def) + onnx.save(final_model, model_name + ".onnx") + + +_create_model_proto(["sequence_length", "sequence_length"], 0, [], [1], "gather_matmul_scalar_batch_dim") +_create_model_proto([1, "sequence_length", "sequence_length"], 0, [1], [1], "gather_matmul_batch_dim") +_create_model_proto(["batch_size", "sequence_length"], 1, [], [1], "gather_matmul_scalar_second_last_dim") +_create_model_proto(["batch_size", 1, "sequence_length"], 1, [1], [1], "gather_matmul_second_last_dim") +_create_model_proto(["batch_size", "sequence_length"], 2, [], [1], "gather_matmul_scalar_last_dim") +_create_model_proto(["batch_size", "sequence_length", 1], 2, [1], [1], "gather_matmul_last_dim") diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_matmul_batch_dim.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_matmul_batch_dim.onnx new file mode 100644 index 0000000000..d99711f69e Binary files /dev/null and b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_matmul_batch_dim.onnx differ diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_matmul_last_dim.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_matmul_last_dim.onnx new file mode 100644 index 0000000000..abd8b59040 Binary files /dev/null and b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_matmul_last_dim.onnx differ diff --git 
a/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_matmul_scalar_batch_dim.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_matmul_scalar_batch_dim.onnx new file mode 100644 index 0000000000..32dda1e228 Binary files /dev/null and b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_matmul_scalar_batch_dim.onnx differ diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_matmul_scalar_last_dim.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_matmul_scalar_last_dim.onnx new file mode 100644 index 0000000000..4e6d1a3fab Binary files /dev/null and b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_matmul_scalar_last_dim.onnx differ diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_matmul_scalar_second_last_dim.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_matmul_scalar_second_last_dim.onnx new file mode 100644 index 0000000000..c1c72ada15 Binary files /dev/null and b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_matmul_scalar_second_last_dim.onnx differ diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_matmul_second_last_dim.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_matmul_second_last_dim.onnx new file mode 100644 index 0000000000..b8296a4dc7 Binary files /dev/null and b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_matmul_second_last_dim.onnx differ diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_reshape.py b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_reshape.py new file mode 100755 index 0000000000..5b3d841e3f --- /dev/null +++ b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_reshape.py @@ -0,0 +1,89 @@ +import onnx +from onnx import 
OperatorSetIdProto, TensorProto, helper + +hidden = 1024 +head = 16 + + +def _create_model_proto( + input_shapes, output_shapes, axis_to_gather, slice_dims, slices_values, shape_dims, shape_values, model_name +): + # inputs and outputs + inputs = [ + helper.make_tensor_value_info("input1", TensorProto.FLOAT, input_shapes), + ] + + outputs = [ + helper.make_tensor_value_info("final_output", TensorProto.FLOAT, output_shapes), + ] + + # initializers + + initializers = [ + helper.make_tensor("shape", TensorProto.INT64, shape_dims, shape_values), + helper.make_tensor("slices", TensorProto.INT64, slice_dims, slices_values), + ] + + # nodes + nodes = [ + helper.make_node("Reshape", ["input1", "shape"], ["reshape_out"], "reshape1"), + helper.make_node("Gather", ["reshape_out", "slices"], ["gather_out"], "gather", axis=axis_to_gather), + helper.make_node("Identity", ["gather_out"], ["final_output"], "identity1"), + ] + + # Create the graph (GraphProto) + graph_def = helper.make_graph( + nodes, + "test-model", + inputs, + outputs, + initializers, + "doc string", + ) + + opsets = [] + onnxdomain = OperatorSetIdProto() + onnxdomain.version = 14 + onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator + # set that is defined as part of the ONNX specification. 
+ opsets.append(onnxdomain) + + msdomain = OperatorSetIdProto() + msdomain.version = 1 + msdomain.domain = "com.microsoft" + + opsets.append(msdomain) + kwargs = {} + kwargs["opset_imports"] = opsets + + model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) + final_model = onnx.shape_inference.infer_shapes(model_def) + onnx.save(final_model, model_name + ".onnx") + + +input_shapes1 = ["batch_size", "sequence_length", hidden] +_create_model_proto( + input_shapes1, ["sequence_length", 16, 64], 0, [], [1], [4], [0, 0, 16, 64], "gather_reshape_scalar_batch_dim" +) +_create_model_proto( + input_shapes1, [1, "sequence_length", 16, 64], 0, [1], [1], [4], [0, 0, 16, 64], "gather_reshape_batch_dim" +) +_create_model_proto( + input_shapes1, ["batch_size", 16, 64], 1, [], [1], [4], [0, 0, 16, 64], "gather_reshape_scalar_seqlen_dim" +) +_create_model_proto( + input_shapes1, ["batch_size", 1, 16, 64], 1, [1], [1], [4], [0, 0, 16, 64], "gather_reshape_seqlen_dim" +) + + +input_shapes2 = ["batch_size", 128, hidden] +_create_model_proto( + input_shapes2, + ["batch_size", 31, 16, 64], + 1, + [31], + [i for i in range(31)], + [4], + [0, 128, 16, 64], + "gather_reshape_seqlen_dim2", +) diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_reshape_batch_dim.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_reshape_batch_dim.onnx new file mode 100644 index 0000000000..cc999b9ce9 Binary files /dev/null and b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_reshape_batch_dim.onnx differ diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_reshape_scalar_batch_dim.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_reshape_scalar_batch_dim.onnx new file mode 100644 index 0000000000..b16a9501e6 Binary files /dev/null and b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_reshape_scalar_batch_dim.onnx 
differ diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_reshape_scalar_seqlen_dim.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_reshape_scalar_seqlen_dim.onnx new file mode 100644 index 0000000000..eafa117c87 Binary files /dev/null and b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_reshape_scalar_seqlen_dim.onnx differ diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_reshape_scalar_seqlen_dim2.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_reshape_scalar_seqlen_dim2.onnx new file mode 100644 index 0000000000..65a33bbbf1 Binary files /dev/null and b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_reshape_scalar_seqlen_dim2.onnx differ diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_reshape_seqlen_dim.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_reshape_seqlen_dim.onnx new file mode 100644 index 0000000000..db9d9a829b Binary files /dev/null and b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_reshape_seqlen_dim.onnx differ diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_reshape_seqlen_dim2.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_reshape_seqlen_dim2.onnx new file mode 100644 index 0000000000..5291a07c51 Binary files /dev/null and b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_reshape_seqlen_dim2.onnx differ diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_roberta_e2e.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_roberta_e2e.onnx new file mode 100644 index 0000000000..32a9e8c100 Binary files /dev/null and b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_roberta_e2e.onnx differ diff --git 
a/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_roberta_e2e.py b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_roberta_e2e.py new file mode 100755 index 0000000000..bf07a8045d --- /dev/null +++ b/onnxruntime/test/testdata/transform/computation_reduction/gather/gather_roberta_e2e.py @@ -0,0 +1,211 @@ +import onnx +from onnx import OperatorSetIdProto, TensorProto, helper + +# inputs and outputs +hidden = 1024 +head = 16 +inputs = [ + helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden]), + helper.make_tensor_value_info("attention_mask", TensorProto.INT64, ["batch_size", "sequence_length"]), + helper.make_tensor_value_info("matmul1.weight", TensorProto.FLOAT16, [hidden, 1024]), + helper.make_tensor_value_info("add1.bias", TensorProto.FLOAT16, [hidden]), + helper.make_tensor_value_info("matmul2.weight", TensorProto.FLOAT16, [hidden, 1024]), + helper.make_tensor_value_info("add2.bias", TensorProto.FLOAT16, [hidden]), + helper.make_tensor_value_info("matmul3.weight", TensorProto.FLOAT16, [hidden, 1024]), + helper.make_tensor_value_info("add3.bias", TensorProto.FLOAT16, [hidden]), + helper.make_tensor_value_info("matmul4.weight", TensorProto.FLOAT16, [hidden, 1024]), + helper.make_tensor_value_info("add4.bias", TensorProto.FLOAT16, [hidden]), + helper.make_tensor_value_info("layer_norm1.weight", TensorProto.FLOAT, [hidden]), + helper.make_tensor_value_info("layer_norm1.bias", TensorProto.FLOAT, [hidden]), + helper.make_tensor_value_info("matmul7.weight", TensorProto.FLOAT16, [hidden, hidden * 4]), + helper.make_tensor_value_info("add7.bias", TensorProto.FLOAT16, [hidden * 4]), + helper.make_tensor_value_info("matmul8.weight", TensorProto.FLOAT16, [hidden * 4, hidden]), + helper.make_tensor_value_info("add8.bias", TensorProto.FLOAT16, [hidden]), + helper.make_tensor_value_info("layer_norm2.weight", TensorProto.FLOAT, [hidden]), + 
helper.make_tensor_value_info("layer_norm2.bias", TensorProto.FLOAT, [hidden]), +] + +outputs = [ + helper.make_tensor_value_info("final_output", TensorProto.FLOAT, ["batch_size", hidden]), +] + +# initializers + +initializers = [ + helper.make_tensor("scalar_float_0.1", TensorProto.FLOAT, [], [0.1]), + helper.make_tensor("scalar_float_0", TensorProto.FLOAT, [], [0.0]), + helper.make_tensor("scalar_float16_8", TensorProto.FLOAT16, [], [8]), + helper.make_tensor("scalar_bool_true", TensorProto.BOOL, [], [1]), + helper.make_tensor("scalar_float_1", TensorProto.FLOAT, [], [1]), + helper.make_tensor("scalar_float_big_num", TensorProto.FLOAT, [], [-3.4028234663852886e38]), + helper.make_tensor("scalar_int_0", TensorProto.INT64, [], [0]), + helper.make_tensor("scalar_int_1", TensorProto.INT64, [], [1]), + helper.make_tensor("single_value_1d_int_0", TensorProto.INT64, [1], [0]), + helper.make_tensor("single_value_1d_int_1", TensorProto.INT64, [1], [1]), + helper.make_tensor("single_value_1d_int_2", TensorProto.INT64, [1], [2]), + helper.make_tensor("single_value_1d_int_16", TensorProto.INT64, [1], [head]), + helper.make_tensor("single_value_1d_int_64", TensorProto.INT64, [1], [hidden // head]), + helper.make_tensor("single_value_1d_int_1024", TensorProto.INT64, [1], [hidden]), + helper.make_tensor("shape1", TensorProto.INT64, [4], [0, 0, 16, 64]), + helper.make_tensor("shape2", TensorProto.INT64, [4], [0, 0, 16, 64]), + helper.make_tensor("shape3", TensorProto.INT64, [4], [0, 0, 16, 64]), + helper.make_tensor("shape4", TensorProto.INT64, [3], [0, 0, 1024]), +] + +# nodes + +nodes = [ + helper.make_node("Dropout", ["input", "scalar_float_0", "scalar_bool_true"], ["d1_out", "d1_mask"], "d1"), + helper.make_node("Cast", ["d1_out"], ["c1_out"], name="c1", to=10), + # attention + ## left branch + helper.make_node("MatMul", ["c1_out", "matmul1.weight"], ["m1_out"], "m1"), + helper.make_node("Add", ["add1.bias", "m1_out"], ["a1_out"], "a1"), + helper.make_node("Reshape", 
["a1_out", "shape1"], ["reshape1_out"], "reshape1"), + helper.make_node("Transpose", ["reshape1_out"], ["transpose1_out"], name="transpose1", perm=[0, 2, 1, 3]), + ## middle branch + helper.make_node("MatMul", ["c1_out", "matmul2.weight"], ["m2_out"], "m2"), + helper.make_node("Add", ["add2.bias", "m2_out"], ["a2_out"], "a2"), + helper.make_node("Reshape", ["a2_out", "shape2"], ["reshape2_out"], "reshape2"), + helper.make_node("Transpose", ["reshape2_out"], ["transpose2_out"], name="transpose2", perm=[0, 2, 1, 3]), + ## right banch + helper.make_node("MatMul", ["c1_out", "matmul3.weight"], ["m3_out"], "m3"), + helper.make_node("Add", ["add3.bias", "m3_out"], ["a3_out"], "a3"), + helper.make_node("Reshape", ["a3_out", "shape3"], ["reshape3_out"], "reshape3"), + helper.make_node("Transpose", ["reshape3_out"], ["transpose3_out"], name="transpose3", perm=[0, 2, 3, 1]), + ## middle branch result computes with right branch result + helper.make_node("MatMul", ["transpose2_out", "transpose3_out"], ["m4_out"], "m4"), + helper.make_node("Div", ["m4_out", "scalar_float16_8"], ["div1_out"], "div1"), + helper.make_node("Cast", ["div1_out"], ["c2_out"], name="c2", to=1), + helper.make_node("Unsqueeze", ["attention_mask", "single_value_1d_int_1"], ["unsqueeze7_out"], "unsqueeze7"), + helper.make_node("Unsqueeze", ["unsqueeze7_out", "single_value_1d_int_2"], ["unsqueeze8_out"], "unsqueeze8"), + helper.make_node("Cast", ["unsqueeze8_out"], ["c3_out"], name="c3", to=1), + helper.make_node("Sub", ["scalar_float_1", "c3_out"], ["sub1_out"], "sub1"), + helper.make_node("Mul", ["sub1_out", "scalar_float_big_num"], ["mul1_out"], "mul1"), + helper.make_node("Add", ["mul1_out", "c2_out"], ["a4_out"], "a4"), + helper.make_node("Softmax", ["a4_out"], ["softmax1_out"], "softmax1", axis=-1), + helper.make_node("Dropout", ["softmax1_out", "scalar_float_0", "scalar_bool_true"], ["d2_out", "d2_mask"], "d2"), + helper.make_node("Cast", ["d2_out"], ["c4_out"], name="c4", to=10), + ## left branch 
result computes with result of `middle branch result computes with right branch result`` + helper.make_node("MatMul", ["c4_out", "transpose1_out"], ["m5_out"], "m5"), + helper.make_node("Transpose", ["m5_out"], ["tranpose4_out"], name="tranpose4", perm=[0, 2, 1, 3]), + helper.make_node("Reshape", ["tranpose4_out", "shape4"], ["reshape4_out"], "reshape4"), + ## attention output + helper.make_node("MatMul", ["reshape4_out", "matmul4.weight"], ["m6_out"], "m6"), + helper.make_node("Add", ["add4.bias", "m6_out"], ["a5_out"], "a5"), + helper.make_node("Dropout", ["a5_out", "scalar_float_0", "scalar_bool_true"], ["d4_out", 'd4_mask"'], "d4"), + helper.make_node("Cast", ["d4_out"], ["c5_out"], name="c5", to=1), + helper.make_node("Add", ["d1_out", "c5_out"], ["a6_out"], "a6"), + # MLP + helper.make_node( + "LayerNormalization", + ["a6_out", "layer_norm1.weight", "layer_norm1.bias"], + ["layernorm1_out", "layernorm1_mean", "layernorm1_var"], + "layernorm1", + axis=-1, + epsion=0.000009999999747378752, + ), + helper.make_node("Cast", ["layernorm1_out"], ["c6_out"], name="c6", to=10), + helper.make_node("MatMul", ["c6_out", "matmul7.weight"], ["m7_out"], "m7"), + helper.make_node("BiasGelu", ["m7_out", "add7.bias"], ["biasgelu1_out"], "biasgelu1", domain="com.microsoft"), + helper.make_node("MatMul", ["biasgelu1_out", "matmul8.weight"], ["m8_out"], "m8"), + helper.make_node("Add", ["add8.bias", "m8_out"], ["a7_out"], "a7"), + helper.make_node("Dropout", ["a7_out", "scalar_float_0", "scalar_bool_true"], ["d5_out", "d5_mask"], "d5"), + helper.make_node("Cast", ["d5_out"], ["c7_out"], name="c7", to=1), + helper.make_node("Add", ["layernorm1_out", "c7_out"], ["a8_out"], "a8"), + helper.make_node( + "LayerNormalization", + ["a8_out", "layer_norm2.weight", "layer_norm2.bias"], + ["layernorm2_out", "layernorm2_mean", "layernorm2_var"], + "layernorm2", + axis=-1, + epsion=0.000009999999747378752, + ), + helper.make_node("Gather", ["layernorm2_out", "scalar_int_0"], 
["final_gather_out"], "final_gather", axis=1), + helper.make_node( + "Dropout", ["final_gather_out", "scalar_float_0", "scalar_bool_true"], ["final_output", "d6_mask"], "d6" + ), +] + + +# Shapes that cannot be inferred by onnx shape inference +value_infos = [ + helper.make_value_info( + name="reshape1_out", + type_proto=helper.make_tensor_type_proto( + elem_type=TensorProto.FLOAT16, shape=["batch_size", "sequence_length", head, hidden // head] + ), + ), + helper.make_value_info( + name="reshape2_out", + type_proto=helper.make_tensor_type_proto( + elem_type=TensorProto.FLOAT16, shape=["batch_size", "sequence_length", head, hidden // head] + ), + ), + helper.make_value_info( + name="reshape3_out", + type_proto=helper.make_tensor_type_proto( + elem_type=TensorProto.FLOAT16, shape=["batch_size", "sequence_length", head, hidden // head] + ), + ), + helper.make_value_info( + name="reshape4_out", + type_proto=helper.make_tensor_type_proto( + elem_type=TensorProto.FLOAT16, shape=["batch_size", "sequence_length", hidden] + ), + ), + helper.make_value_info( + name="layernorm1_out", + type_proto=helper.make_tensor_type_proto( + elem_type=TensorProto.FLOAT, shape=["batch_size", "sequence_length", hidden] + ), + ), + helper.make_value_info( + name="layernorm2_out", + type_proto=helper.make_tensor_type_proto( + elem_type=TensorProto.FLOAT, shape=["batch_size", "sequence_length", hidden] + ), + ), + helper.make_value_info( + name="concattraining4_out", + type_proto=helper.make_tensor_type_proto(elem_type=TensorProto.INT64, shape=[3]), + ), + helper.make_value_info( + name="biasgelu1_out", + type_proto=helper.make_tensor_type_proto( + elem_type=TensorProto.FLOAT16, shape=["batch_size", "sequence_length", hidden * 4] + ), + ), +] + +# Create the graph (GraphProto) +graph_def = helper.make_graph( + nodes, + "test-model", + inputs, + outputs, + initializers, + "doc string", + value_infos, +) + + +opsets = [] +onnxdomain = OperatorSetIdProto() +onnxdomain.version = 14 
+onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator +# set that is defined as part of the ONNX specification. +opsets.append(onnxdomain) + +msdomain = OperatorSetIdProto() +msdomain.version = 1 +msdomain.domain = "com.microsoft" + +opsets.append(msdomain) +kwargs = {} +kwargs["opset_imports"] = opsets + + +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) +final_model = onnx.shape_inference.infer_shapes(model_def) +onnx.save(final_model, "gather_roberta_e2e.onnx") diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd/e2e.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/e2e.onnx new file mode 100755 index 0000000000..420967be64 Binary files /dev/null and b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/e2e.onnx differ diff --git a/onnxruntime/test/testdata/transform/computation_reduction/e2e.py b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/e2e.py similarity index 88% rename from onnxruntime/test/testdata/transform/computation_reduction/e2e.py rename to onnxruntime/test/testdata/transform/computation_reduction/gathernd/e2e.py index 9f7450bb6a..c0e42e8a37 100755 --- a/onnxruntime/test/testdata/transform/computation_reduction/e2e.py +++ b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/e2e.py @@ -1,8 +1,8 @@ import numpy as np import onnx -from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper +from onnx import OperatorSetIdProto, TensorProto, helper, numpy_helper -vocab_size = 256 # 30258 +vocab_size = 256 X = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", "seqlen", 128]) unsqueezed_masked_lm_positions = helper.make_tensor_value_info( @@ -127,8 +127,9 @@ graph_def = helper.make_graph( opsets = [] onnxdomain = OperatorSetIdProto() -onnxdomain.version = 12 -onnxdomain.domain = "" # The empty string ("") or absence of this 
field implies the operator set that is defined as part of the ONNX specification. +onnxdomain.version = 14 +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator +# set that is defined as part of the ONNX specification. opsets.append(onnxdomain) msdomain = OperatorSetIdProto() @@ -141,4 +142,18 @@ kwargs["opset_imports"] = opsets model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) -onnx.save(model_def, "e2e.onnx") + +ln1_value_info = model_def.graph.value_info.add() +ln1_value_info.CopyFrom(X) +ln1_value_info.name = "layer_norm1" + +ln2_value_info = model_def.graph.value_info.add() +ln2_value_info.CopyFrom(X) +ln2_value_info.name = "layer_norm2" + +gelu1_value_info = model_def.graph.value_info.add() +gelu1_value_info.CopyFrom(X) +gelu1_value_info.name = "gelu1" + +final_model = onnx.shape_inference.infer_shapes(model_def) +onnx.save(final_model, "e2e.onnx") diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_add.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_add.onnx new file mode 100644 index 0000000000..5845896edf Binary files /dev/null and b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_add.onnx differ diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_add.py b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_add.py similarity index 94% rename from onnxruntime/test/testdata/transform/computation_reduction/gathernd_add.py rename to onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_add.py index 0d32f081a9..ec0fdc888b 100755 --- a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_add.py +++ b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_add.py @@ -36,12 +36,15 @@ nodes.append(add2) gathernd2 = helper.make_node( "GatherND", ["add_2", "unsqueezed_masked_lm_positions"], - 
["output2"], + ["gathernd_out"], name="gathernd_2", batch_dims=1, ) nodes.append(gathernd2) +identity = helper.make_node("Identity", ["gathernd_out"], ["output2"], name="identity") +nodes.append(identity) + graph_def = helper.make_graph( nodes, "test-model", diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_div.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_div.onnx new file mode 100644 index 0000000000..5bd65f9e3f Binary files /dev/null and b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_div.onnx differ diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_div.py b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_div.py similarity index 94% rename from onnxruntime/test/testdata/transform/computation_reduction/gathernd_div.py rename to onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_div.py index c5814dd4d8..d14f8a71ad 100755 --- a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_div.py +++ b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_div.py @@ -36,12 +36,15 @@ nodes.append(div2) gathernd2 = helper.make_node( "GatherND", ["div_2", "unsqueezed_masked_lm_positions"], - ["output2"], + ["gathernd_out"], name="gathernd_2", batch_dims=1, ) nodes.append(gathernd2) +identity = helper.make_node("Identity", ["gathernd_out"], ["output2"], name="identity") +nodes.append(identity) + graph_def = helper.make_graph( nodes, "test-model", diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_gelu.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_gelu.onnx similarity index 69% rename from onnxruntime/test/testdata/transform/computation_reduction/gathernd_gelu.onnx rename to onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_gelu.onnx index fcbab3838d..94e9bb142b 100644 
Binary files a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_gelu.onnx and b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_gelu.onnx differ diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_gelu.py b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_gelu.py similarity index 91% rename from onnxruntime/test/testdata/transform/computation_reduction/gathernd_gelu.py rename to onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_gelu.py index ea97a62886..eade1b868b 100755 --- a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_gelu.py +++ b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_gelu.py @@ -18,12 +18,15 @@ nodes.append(gelu1) gathernd1 = helper.make_node( "GatherND", ["gelu_1", "unsqueezed_masked_lm_positions"], - ["output"], + ["gathernd_out"], name="gathernd_1", batch_dims=1, ) nodes.append(gathernd1) +identity = helper.make_node("Identity", ["gathernd_out"], ["output"], name="identity") +nodes.append(identity) + graph_def = helper.make_graph(nodes, "test-model", [X, unsqueezed_masked_lm_positions], [Y]) opsets = [] diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_layernormalization.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_layernormalization.onnx new file mode 100644 index 0000000000..b817a0e296 Binary files /dev/null and b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_layernormalization.onnx differ diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_layernormalization.py b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_layernormalization.py similarity index 94% rename from onnxruntime/test/testdata/transform/computation_reduction/gathernd_layernormalization.py rename to 
onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_layernormalization.py index eb63c76902..9473d05010 100755 --- a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_layernormalization.py +++ b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_layernormalization.py @@ -34,12 +34,15 @@ nodes.append(layer_norm1) gathernd1 = helper.make_node( "GatherND", ["layer_norm1", "unsqueezed_masked_lm_positions"], - ["output"], + ["gathernd_out"], name="gathernd_1", batch_dims=1, ) nodes.append(gathernd1) +identity = helper.make_node("Identity", ["gathernd_out"], ["output"], name="identity") +nodes.append(identity) + graph_def = helper.make_graph( nodes, "test-model", diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_matmul.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_matmul.onnx new file mode 100644 index 0000000000..f683cb73cc Binary files /dev/null and b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_matmul.onnx differ diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_matmul.py b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_matmul.py similarity index 92% rename from onnxruntime/test/testdata/transform/computation_reduction/gathernd_matmul.py rename to onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_matmul.py index 2b7ea7127d..50167bbd0a 100755 --- a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_matmul.py +++ b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_matmul.py @@ -20,12 +20,15 @@ nodes.append(matmul1) gathernd1 = helper.make_node( "GatherND", ["matmul1", "unsqueezed_masked_lm_positions"], - ["output"], + ["gathernd_out"], name="gathernd_1", batch_dims=1, ) nodes.append(gathernd1) +identity = helper.make_node("Identity", ["gathernd_out"], ["output"], 
name="identity") +nodes.append(identity) + initializers = [matmul1_initializer] graph_def = helper.make_graph(nodes, "test-model", [X, unsqueezed_masked_lm_positions], [Y], initializers) diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_add.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gathernd_add.onnx deleted file mode 100644 index a54466163b..0000000000 Binary files a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_add.onnx and /dev/null differ diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_div.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gathernd_div.onnx deleted file mode 100644 index 21ee3ad27e..0000000000 Binary files a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_div.onnx and /dev/null differ diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_layernormalization.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gathernd_layernormalization.onnx deleted file mode 100644 index 4098cdd8a0..0000000000 Binary files a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_layernormalization.onnx and /dev/null differ diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_matmul.onnx b/onnxruntime/test/testdata/transform/computation_reduction/gathernd_matmul.onnx deleted file mode 100644 index 0295c60154..0000000000 Binary files a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_matmul.onnx and /dev/null differ diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_config.h b/orttraining/orttraining/core/optimizer/graph_transformer_config.h index 3b1bf22ae2..755d878029 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_config.h +++ b/orttraining/orttraining/core/optimizer/graph_transformer_config.h @@ -21,6 +21,9 @@ struct TrainingGraphTransformerConfiguration : public GraphTransformerConfigurat 
bool transformer_layer_recompute{false}; // Number of layers to apply recompute int number_recompute_layers{0}; + + // Enable compute optimizer. + bool enable_compute_optimizer{false}; }; } // namespace training diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 9745a1b36d..61b3fd1453 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -10,7 +10,7 @@ #include "core/optimizer/bias_softmax_fusion.h" #include "core/optimizer/cast_elimination.h" #include "core/optimizer/common_subexpression_elimination.h" -#include "core/optimizer/computation_reduction.h" +#include "core/optimizer/compute_optimizer/compute_optimizer.h" #include "core/optimizer/concat_slice_elimination.h" #include "core/optimizer/constant_folding.h" #include "core/optimizer/conv_activation_fusion.h" @@ -94,7 +94,11 @@ std::vector> GeneratePreTrainingTransformers( ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique())); // Remove duplicate nodes. Must be applied before any recompute transformations. 
- transformers.emplace_back(std::make_unique(compatible_eps)); + if (config.gelu_recompute || config.attn_dropout_recompute || config.transformer_layer_recompute) { + transformers.emplace_back(std::make_unique(compatible_eps)); + } else { + transformers.emplace_back(std::make_unique(compatible_eps)); + } transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); @@ -120,9 +124,10 @@ std::vector> GeneratePreTrainingTransformers( execution_provider, false /*skip_dequantize_linear*/, compatible_eps, excluded_initializers)); transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); -#if defined(USE_CUDA) || defined(USE_ROCM) - transformers.emplace_back(std::make_unique(compatible_eps)); -#endif + + if (config.enable_compute_optimizer) { + transformers.emplace_back(std::make_unique(compatible_eps)); + } if (config.gelu_recompute) { transformers.emplace_back(std::make_unique()); } @@ -195,11 +200,11 @@ InlinedVector> GenerateTransformers( case TransformerLevel::Level1: { InlinedHashSet l1_execution_providers = {}; InlinedHashSet cuda_rocm_execution_providers = {onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider}; + onnxruntime::kRocmExecutionProvider}; // TODO hack - constant folding currently doesn't work after mixed precision transformation so it's disabled for now // ORT uses CPU kernels to evaluate constant values but some of them don't support fp16 - //transformers.emplace_back(std::make_unique(l1_execution_providers)); + // transformers.emplace_back(std::make_unique(l1_execution_providers)); transformers.emplace_back(std::make_unique(l1_execution_providers)); transformers.emplace_back(std::make_unique(free_dimension_overrides)); transformers.emplace_back(std::make_unique(cuda_rocm_execution_providers)); diff --git a/orttraining/orttraining/models/runner/training_runner.cc 
b/orttraining/orttraining/models/runner/training_runner.cc index ef875a5f1c..4eada66cd1 100644 --- a/orttraining/orttraining/models/runner/training_runner.cc +++ b/orttraining/orttraining/models/runner/training_runner.cc @@ -192,6 +192,7 @@ Status TrainingRunner::Initialize() { gt_config.gelu_recompute = params_.gelu_recompute; gt_config.transformer_layer_recompute = params_.transformer_layer_recompute; gt_config.number_recompute_layers = params_.number_recompute_layers; + gt_config.enable_compute_optimizer = true; config.graph_transformer_config = gt_config; } @@ -580,7 +581,7 @@ void TrainingRunner::RunWithUpdate(VectorString& feed_names, ORT_THROW_IF_ERROR(status); } catch (std::exception&) { - // If exception happens during worker execution, propogate the exception to main thread. + // If exception happens during worker execution, propagate the exception to main thread. pipeline_worker_pool_.worker_states[worker_id].execution_exception = std::current_exception(); } }, diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index fe245bdd5d..1dbeb6e6ad 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -711,6 +711,7 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn .def_readwrite("gelu_recompute", &TrainingGraphTransformerConfiguration::gelu_recompute) .def_readwrite("transformer_layer_recompute", &TrainingGraphTransformerConfiguration::transformer_layer_recompute) .def_readwrite("number_recompute_layers", &TrainingGraphTransformerConfiguration::number_recompute_layers) + .def_readwrite("enable_compute_optimizer", &TrainingGraphTransformerConfiguration::enable_compute_optimizer) .def_readwrite("propagate_cast_ops_config", &TrainingGraphTransformerConfiguration::GraphTransformerConfiguration::propagate_cast_ops_config); py::class_ module_graph_builder_config( diff 
--git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 82c00bc47d..453bf64efe 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -177,6 +177,12 @@ class GraphExecutionManager(GraphExecutionInterface): # Memory aware gradient builder. self._use_memory_efficient_gradient = False + # Enable compute optimizer by default. Allowed to be disabled via environment variable for + # convergence parity investigation. + self._enable_compute_optimizer = ( + ortmodule._defined_from_envvar("ORTMODULE_ENABLE_COMPUTE_OPTIMIZER", 1, warn=True) == 1 + ) + # Flag to re-export the model due to attribute change on original module. # Re-export will be avoided if _skip_check is enabled. self._original_model_has_changed = False @@ -445,6 +451,7 @@ class GraphExecutionManager(GraphExecutionInterface): graph_transformer_config.propagate_cast_ops_config.level = self._propagate_cast_ops_level graph_transformer_config.propagate_cast_ops_config.allow = self._propagate_cast_ops_allow graph_transformer_config.propagate_cast_ops_config.strategy = self._propagate_cast_ops_strategy + graph_transformer_config.enable_compute_optimizer = self._enable_compute_optimizer return graph_transformer_config def _initialize_graph_builder(self): diff --git a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc index 51c6392ff5..a9af9a8e8b 100644 --- a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc +++ b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc @@ -48,12 +48,12 @@ namespace { * Then load it into ORT, compare with the initial parameter values. 
*/ TEST(CheckpointApiTest, SaveOnnxModelAsCheckpoint_ThenLoad_CPU) { - /// Phase 1 - Test Preparison + /// Phase 1 - Test Preparation /// Prepare the data and dest folder for saving checkpoint. /// Also cooked the data for test result comparison. // Model path and trainable parameter name definitions. - auto model_uri = MODEL_FOLDER "transform/computation_reduction/e2e.onnx"; + auto model_uri = MODEL_FOLDER "transform/computation_reduction/gathernd/e2e.onnx"; std::vector expected_trainable_param_names{ "bert.encoder.layer.2.output.LayerNorm.weight", "bert.encoder.layer.2.output.LayerNorm.bias", @@ -88,7 +88,7 @@ TEST(CheckpointApiTest, SaveOnnxModelAsCheckpoint_ThenLoad_CPU) { ORT_ENFORCE(CreateOrtValuesFromTensorProtos(trainable_param_values, expected_trainable_param_name_to_ort_value) .IsOK()); - // Remove the tempoprary directory if it already exists. + // Remove the temporary directory if it already exists. auto ckpt_test_root_dir = ORT_TSTR("checkpointing_api_test_dir"); if (Env::Default().FolderExists(ckpt_test_root_dir)) { ORT_ENFORCE(Env::Default().DeleteFolder(ckpt_test_root_dir).IsOK()); @@ -120,7 +120,7 @@ TEST(CheckpointApiTest, SaveOnnxModelAsCheckpoint_ThenLoad_CPU) { ASSERT_EQ(expected_file_names, valid_file_names); /// Phase 3 - Run load checkpoint APIs. - /// And check the result comparible with initial parameter values. + /// And check the result comparable with initial parameter values. // Call Load APIs CheckpointState checkpoint_state_to_load; @@ -199,7 +199,7 @@ TEST(CheckpointApiTest, LoadCheckpointToModel) { #if defined(USE_CUDA) || defined(USE_ROCM) TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CUDA) { - /// Phase 1 - Test Preparison + /// Phase 1 - Test Preparation /// Prepare the data and dest folder for saving checkpoint. /// Also cooked the data for test result comparison. 
auto model_uri = MODEL_FOLDER "training_api/training_model.onnx"; @@ -252,7 +252,7 @@ TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CUDA) { CheckpointState checkpoint_state; ORT_ENFORCE(optimizer->GetStateDict(checkpoint_state.optimizer_checkpoint_state).IsOK()); - // Remove the tempoprary directory if it already exists. + // Remove the temporary directory if it already exists. auto ckpt_test_root_dir = ORT_TSTR("checkpointing_api_test_dir"); if (Env::Default().FolderExists(ckpt_test_root_dir)) { ORT_ENFORCE(Env::Default().DeleteFolder(ckpt_test_root_dir).IsOK()); @@ -287,7 +287,7 @@ TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CUDA) { ASSERT_EQ(expected_file_names, valid_file_names); /// Phase 3 - Run load checkpoint APIs. - /// And check the result comparible with initial optimizer state values. + /// And check the result comparable with initial optimizer state values. // Call Load APIs CheckpointState checkpoint_state_to_load; @@ -334,7 +334,7 @@ TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CUDA) { * Then load it into ORT, compare with the initial properties' values. */ TEST(CheckpointApiTest, SaveCustomPropertyAsCheckpoint_ThenLoad_CPU) { - /// Phase 1 - Test Preparison + /// Phase 1 - Test Preparation /// Prepare the data and dest folder for saving checkpoint. CheckpointState checkpoint_state; @@ -352,7 +352,7 @@ TEST(CheckpointApiTest, SaveCustomPropertyAsCheckpoint_ThenLoad_CPU) { std::string s_property_name("train_data_path"); property_bag.AddProperty(s_property_name, s_data); - // Remove the tempoprary directory if it already exists. + // Remove the temporary directory if it already exists. 
auto ckpt_test_root_dir = ORT_TSTR("checkpointing_api_test_dir"); if (Env::Default().FolderExists(ckpt_test_root_dir)) { ORT_ENFORCE(Env::Default().DeleteFolder(ckpt_test_root_dir).IsOK()); diff --git a/orttraining/orttraining/training_ops/cuda/communication/nccl_service.h b/orttraining/orttraining/training_ops/cuda/communication/nccl_service.h index f9f2d0a5f9..601d8ccb31 100644 --- a/orttraining/orttraining/training_ops/cuda/communication/nccl_service.h +++ b/orttraining/orttraining/training_ops/cuda/communication/nccl_service.h @@ -187,7 +187,7 @@ class NcclService final : public INcclService { // Search the next unfinished communication group to work on. int FindNextCommunicationTime() const; - // Mutex to gurantee thread-safe access to this class. + // Mutex to guarantee thread-safe access to this class. std::mutex mutex_; // Conditional variable used to wait for the mutex. std::condition_variable cv_;