diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.h b/onnxruntime/test/optimizer/graph_transform_test_builder.h index 3e42bc456b..43c81d1c52 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.h +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.h @@ -107,6 +107,29 @@ class ModelTestBuilder { return &graph_.GetOrCreateNodeArg(name, &type_proto); } + template + NodeArg* MakeSymbolicInput(const std::vector>& shape) { + ONNX_NAMESPACE::TypeProto type_proto; + type_proto.mutable_tensor_type()->set_elem_type(utils::ToTensorProtoElementType()); + type_proto.mutable_tensor_type()->mutable_shape(); + for (auto& d : shape) { + auto dim = type_proto.mutable_tensor_type()->mutable_shape()->add_dim(); + std::visit([&dim](auto&& arg) -> void { + using V = std::decay_t; + if constexpr (std::is_same_v) { + ORT_ENFORCE(arg >= 0, "Negative dimension is not allowed in symbolic shape"); + dim->set_dim_value(arg); + } else { + dim->set_dim_param(arg); + } + }, + d); + } + + std::string name = graph_.GenerateNodeArgName("symbolic_input"); + return &graph_.GetOrCreateNodeArg(name, &type_proto); + } + NodeArg* MakeOutput() { std::string name = graph_.GenerateNodeArgName("output"); output_names_.push_back(name); diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 16103f0059..6cf850c57e 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -59,6 +59,8 @@ #include "orttraining/core/optimizer/lstm_replacement.h" #include "orttraining/core/optimizer/transformer_layer_recompute.h" #include "orttraining/core/optimizer/qdq_fusion.h" +#include "orttraining/core/optimizer/shape_optimizer.h" +#include "orttraining/core/optimizer/transformer_layer_recompute.h" // Only enabled in full training build. 
Not in on device training builds #ifdef ENABLE_TRAINING @@ -145,6 +147,10 @@ std::vector> GeneratePreTrainingTransformers( transformers.emplace_back(std::make_unique( execution_provider, false /*skip_dequantize_linear*/, compatible_eps, excluded_initializers)); transformers.emplace_back(std::make_unique(compatible_eps)); + // Put fine-grained optimizers (e.g. ShapeOptimizer) after ReshapeFusion so that they do not break the strong patterns + // it relies on. ReshapeFusion depends on subgraph pattern matching and does replacement accordingly; ShapeOptimizer + // could optimize out some nodes that are part of those subgraph patterns, so we run it after ReshapeFusion. + transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); if (config.gelu_recompute) { diff --git a/orttraining/orttraining/core/optimizer/shape_optimizer.cc b/orttraining/orttraining/core/optimizer/shape_optimizer.cc new file mode 100644 index 0000000000..2af0ef86a6 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/shape_optimizer.cc @@ -0,0 +1,369 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/core/optimizer/shape_optimizer.h" + +#include + +#include "core/common/inlined_containers.h" +#include "core/framework/tensorprotoutils.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/utils.h" + +using namespace onnxruntime::common; + +namespace onnxruntime { + +// TODO(pengwa): better way (instead of defining MACROs locally) to enable detailed debug logs +// for some specific graph transformers. +// Uncomment to log debug info for SO(Shape Optimizer). 
+// #define NEED_SO_LOG_DEBUG_INFO 1 + +#ifndef SO_LOG_DEBUG_INFO +#ifdef NEED_SO_LOG_DEBUG_INFO +#define SO_LOG_DEBUG_INFO(logger, message) LOGS(logger, WARNING) << message +#else +#define SO_LOG_DEBUG_INFO(logger, message) \ + do { \ + ORT_UNUSED_PARAMETER(logger); \ + } while (0) +#endif +#endif + +// Put utilities into an anonymous namespace. +namespace { + +constexpr int64_t NormalizeIndex(int64_t initial_index_value, int64_t rank) { + // Negative handling + int64_t non_negative_index = initial_index_value < 0 ? initial_index_value + rank : initial_index_value; + + // Clamp to [0, rank]. + if (non_negative_index < 0) { + non_negative_index = 0; + } else if (non_negative_index > rank) { + non_negative_index = rank; + } + + return non_negative_index; +} + +bool IsSingleValue1DShape(const ONNX_NAMESPACE::TensorShapeProto* input_shape) { + if (input_shape == nullptr) { + return false; + } + + size_t dim_size = static_cast(input_shape->dim_size()); + if (dim_size == 1 && utils::HasDimValue(input_shape->dim(0)) && input_shape->dim(0).dim_value() == 1) { + return true; + } + + return false; +} + +bool CanShapeNodeBeReplacedWithConstant(const Node& shape_node, const TensorShapeVector& dim_values, + TensorShapeVector& fold_values) { + int64_t data_rank = static_cast(dim_values.size()); + int64_t start = 0; + int64_t end = data_rank; // end is exclusive + if (graph_utils::IsSupportedOptypeVersionAndDomain(shape_node, "Shape", {15})) { + // Opset-15 Shape supports slicing using a 'start' and 'end' attribute + const auto& shape_attributes = shape_node.GetAttributes(); + for (const auto& attr : shape_attributes) { + if (attr.first == "start") { + start = attr.second.i(); + } else if (attr.first == "end") { + end = attr.second.i(); + } + } + } + + int64_t start_index_normalized = NormalizeIndex(start, data_rank); + int64_t end_index_normalized = NormalizeIndex(end, data_rank); + + int64_t slice_length = end_index_normalized - start_index_normalized; + slice_length = 
slice_length < 0 ? 0 : slice_length; + + fold_values.clear(); + fold_values.reserve(slice_length); + for (int64_t i = start_index_normalized; i < end_index_normalized; ++i) { + if (dim_values[i] == -1) { + // Return false if it contains symbolic dim values. + return false; + } else { + fold_values.push_back(dim_values[i]); + } + } + + return true; +} + +bool CanSliceNodeBeReplacedWithConstant(const Graph& graph, const Node& slice_node, + const TensorShapeVector& dim_values, + TensorShapeVector& fold_values) { + const NodeArg* starts_input = slice_node.InputDefs()[1]; + const NodeArg* ends_input = slice_node.InputDefs()[2]; + const NodeArg* axes_input = slice_node.InputDefs().size() > 3 ? slice_node.InputDefs()[3] : nullptr; + const NodeArg* steps_input = slice_node.InputDefs().size() > 4 ? slice_node.InputDefs()[4] : nullptr; + + // TODO: We support with some constraints currently, can be extended further to support other cases. + // Support cases: + // 1. starts/ends/axes/steps are all single-value 1D tensors, axes=[0] and steps=[1]. + // 2. starts/ends are single-value 1D tensors, axes/steps are not provided, (default value: axes=[0] and steps=[1]). + if (!IsSingleValue1DShape(starts_input->Shape()) || + !IsSingleValue1DShape(ends_input->Shape()) || + (axes_input && !IsSingleValue1DShape(axes_input->Shape())) || + (steps_input && !IsSingleValue1DShape(steps_input->Shape()))) { + return false; + } + + // Try to parse the value and double-check. 
+ InlinedVector starts_values, ends_values, axes_values, steps_values; + if (!(optimizer_utils::AppendTensorFromInitializer(graph, *starts_input, starts_values, true) && + starts_values.size() == 1)) { + return false; + } + if (!(optimizer_utils::AppendTensorFromInitializer(graph, *ends_input, ends_values, true) && + ends_values.size() == 1)) { + return false; + } + if (axes_input && !(optimizer_utils::AppendTensorFromInitializer(graph, *axes_input, axes_values, true) && + axes_values.size() == 1 && axes_values[0] == 0)) { + return false; + } + if (steps_input && !(optimizer_utils::AppendTensorFromInitializer(graph, *steps_input, steps_values, true) && + steps_values.size() == 1 && steps_values[0] == 1)) { + return false; + } + + int64_t start = starts_values[0]; + int64_t end = ends_values[0]; + + int64_t data_rank = static_cast(dim_values.size()); + int64_t start_index_normalized = NormalizeIndex(start, data_rank); + int64_t end_index_normalized = NormalizeIndex(end, data_rank); + + int64_t slice_length = end_index_normalized - start_index_normalized; + slice_length = slice_length < 0 ? 0 : slice_length; + fold_values.clear(); + fold_values.reserve(slice_length); + for (int64_t i = start_index_normalized; i < end_index_normalized; ++i) { + if (dim_values[i] == -1) { + // Return false if it contains symbolic dim values. + return false; + } else { + fold_values.push_back(dim_values[i]); + } + } + + return true; +} + +bool CanGatherNodeBeReplacedWithConstant(const Graph& graph, const Node& gather_node, + const TensorShapeVector& dim_values, + TensorShapeVector& fold_values, int& gather_output_rank) { + const NodeArg* data_input = gather_node.InputDefs()[0]; + + // TODO: Only a constrained set of cases is supported currently; this can be extended to support other cases. + // Supported cases: + // 1. data is 1D tensor, indices is a scalar, axis=0. + // 2. data is 1D tensor, indices is a scalar, axis is not provided (default value: axis=0). + // 3. 
data is 1D tensor, indices is 1D tensor with single element, axis=0. + // 4. data is 1D tensor, indices is 1D tensor with single element, axis is not provided (default value: axis=0). + + // Gather's input MUST be 1D tensor. + if (!data_input->Shape() || data_input->Shape()->dim_size() != 1) { + return false; + } + + const NodeArg* indices_input = gather_node.InputDefs()[1]; + auto indices_shape = indices_input->Shape(); + // Indices can be 1D tensor or scalar. + if (!indices_shape || !(indices_shape->dim_size() == 0 || IsSingleValue1DShape(indices_shape))) { + // If the indices did not contain one single element, then skip it. + return false; + } + + // Try to parse int64 type constant initializers. + InlinedVector indices_values; + if (!(optimizer_utils::AppendTensorFromInitializer(graph, *indices_input, indices_values, true) && + indices_values.size() == 1)) { + return false; + } + + const ONNX_NAMESPACE::AttributeProto* axis_attr = graph_utils::GetNodeAttribute(gather_node, "axis"); + if (axis_attr && static_cast(axis_attr->i()) != 0) { + return false; + } + + int64_t start = indices_values[0]; + int64_t data_rank = static_cast(dim_values.size()); + int64_t start_index_normalized = NormalizeIndex(start, data_rank); + + if (start < -data_rank || start >= data_rank || dim_values[static_cast(start_index_normalized)] == -1) { + // Return false if the index is out of range (NormalizeIndex clamps to [0, rank]) or the selected dim is symbolic. + return false; + } else { + fold_values.push_back(dim_values[static_cast(start_index_normalized)]); + } + + gather_output_rank = data_input->Shape()->dim_size() + indices_shape->dim_size() - 1; + return true; +} + +void UpdateNodeArgToConstant(Graph& graph, NodeArg* arg_to_update, const TensorShapeVector& values, + bool create_scalar_for_single_value = false) { + size_t length = values.size(); + bool is_scalar = length == 1 && create_scalar_for_single_value; + + // Create new TensorProto. 
+ ONNX_NAMESPACE::TensorProto constant_tensor_proto; + constant_tensor_proto.set_name(arg_to_update->Name()); + constant_tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + if (!is_scalar) { + constant_tensor_proto.add_dims(length); + } + constant_tensor_proto.set_raw_data(values.data(), length * sizeof(int64_t)); + + // Add initializer into Graph. + graph.AddInitializedTensor(constant_tensor_proto); + + // Update the output arg shape. + ONNX_NAMESPACE::TensorShapeProto new_shape; + if (!is_scalar) { + new_shape.add_dim()->set_dim_value(length); + } + arg_to_update->SetShape(new_shape); +} + +} // namespace + +Status ShapeOptimizer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) + const { + GraphViewer graph_viewer(graph); + auto& order = graph_viewer.GetNodesInTopologicalOrder(); + + for (NodeIndex i : order) { + auto* node = graph.GetNode(i); + if (!node) { + continue; + } + + if (!graph_utils::IsSupportedOptypeVersionAndDomain(*node, "Shape", {1, 13, 15})) { + continue; + } + + ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger)); + + auto data_shape = node->MutableInputDefs()[0]->Shape(); + if (data_shape == nullptr) { + SO_LOG_DEBUG_INFO(logger, "Shape node's data input shape is missing." + node->Name()); + continue; + } + + // Parse data input shape, fill -1 for symbolic dimensions. + TensorShapeVector dim_values; + dim_values.reserve(data_shape->dim_size()); + bool has_concrete_dim = false; + for (int dim_index = 0; dim_index < data_shape->dim_size(); dim_index++) { + auto dim = data_shape->dim(dim_index); + if (utils::HasDimValue(dim)) { + dim_values.push_back(dim.dim_value()); + has_concrete_dim = true; + } else { + // Fill with -1 for symbolic dimension. + dim_values.push_back(-1); + } + } + + if (!has_concrete_dim) { + SO_LOG_DEBUG_INFO(logger, "No concrete dimension found, don't need try further." 
+ node->Name()); + continue; + } + + InlinedVector nodes_to_remove; + TensorShapeVector fold_values; + // Short path - check if the shape node can be constant folded. + if (CanShapeNodeBeReplacedWithConstant(*node, dim_values, fold_values)) { + SO_LOG_DEBUG_INFO(logger, "Shape node can be constant folded." + node->Name()); + UpdateNodeArgToConstant(graph, node->MutableOutputDefs()[0], fold_values); + nodes_to_remove.push_back(node); + } else { + // Check consumers of Shape node, try best effort to constant fold them if possible. + // Currently support Gather and Slice in some cases. + auto p_ip_node = node->OutputNodesBegin(); + const auto p_ip_node_end = node->OutputNodesEnd(); + InlinedHashSet visited_nodes; + while (p_ip_node != p_ip_node_end) { + if (visited_nodes.find(&(*p_ip_node)) != visited_nodes.end()) { + // Already handled, skip the node. + ++p_ip_node; + continue; + } + + auto& output_node = *graph.GetNode(p_ip_node->Index()); + visited_nodes.insert(&output_node); + ++p_ip_node; + + NodeArg* data_input = output_node.MutableInputDefs()[0]; + // Skip when shape is not used as sliced data. + if (data_input != node->MutableOutputDefs()[0]) { + continue; + } + + TensorShapeVector slice_fold_values; + if (graph_utils::IsSupportedOptypeVersionAndDomain(output_node, "Slice", {10, 11, 13}) && + CanSliceNodeBeReplacedWithConstant(graph, output_node, dim_values, slice_fold_values)) { + SO_LOG_DEBUG_INFO(logger, "Slice node can be constant folded." + output_node.Name()); + UpdateNodeArgToConstant(graph, output_node.MutableOutputDefs()[0], slice_fold_values); + nodes_to_remove.push_back(&output_node); + continue; + } + + int gather_output_rank = 0; + TensorShapeVector gather_fold_values; + if (graph_utils::IsSupportedOptypeVersionAndDomain(output_node, "Gather", {1, 11, 13}) && + CanGatherNodeBeReplacedWithConstant(graph, output_node, dim_values, gather_fold_values, + gather_output_rank)) { + SO_LOG_DEBUG_INFO(logger, "Gather node can be constant folded." 
+ output_node.Name()); + UpdateNodeArgToConstant(graph, output_node.MutableOutputDefs()[0], gather_fold_values, gather_output_rank == 0); + nodes_to_remove.push_back(&output_node); + continue; + } + } + } + + for (Node* node_to_remove : nodes_to_remove) { + // Remove single-output node chain for inputs of the node + auto p_ip_node = node_to_remove->InputNodesBegin(); + const auto p_ip_node_end = node_to_remove->InputNodesEnd(); + while (p_ip_node != p_ip_node_end) { + const auto& input_node = *p_ip_node; + // Update the node iterator before removing the corresponding node because removing + // the node will invalidate the node iterator + ++p_ip_node; + + // Remove the node only when there is a single output edge or the node does not produce graph output. + if (input_node.GetOutputEdgesCount() > 1 || graph.NodeProducesGraphOutput(input_node)) { + SO_LOG_DEBUG_INFO(logger, "Skip removing node: " + input_node.Name() + "(" + input_node.OpType() + ")"); + continue; + } + SO_LOG_DEBUG_INFO(logger, "Removing node: " + input_node.Name() + "(" + input_node.OpType() + ")"); + graph_utils::RemoveNodesWithOneOutputBottomUp(graph, input_node); + } + + // Remove the output edges of the constant node and then remove the node itself. + graph_utils::RemoveNodeOutputEdges(graph, *node_to_remove); + + SO_LOG_DEBUG_INFO(logger, "Removing trigger node: " + node_to_remove->Name()); + graph.RemoveNode(node_to_remove->Index()); + modified = true; + } + } + + return Status::OK(); +} + +#undef NEED_SO_LOG_DEBUG_INFO +#undef SO_LOG_DEBUG_INFO + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/shape_optimizer.h b/orttraining/orttraining/core/optimizer/shape_optimizer.h new file mode 100644 index 0000000000..aae7cd7b4d --- /dev/null +++ b/orttraining/orttraining/core/optimizer/shape_optimizer.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +/** +@class ShapeOptimizer + +Transformer that traverses the graph top-down and performs shape optimizations. + +Makes a best effort to constant fold the output of the Shape node: + 1. Shape generates 1D tensor [12, 128, 512] (all dimensions have concrete dim value), which can be constant folded + to an initializer including 1D tensor values [12, 128, 512]. (Some logic of ConstantFolding also does the same thing.) + + 2. Shape generates 1D tensor [batch_size, 128, 512] -> Slice(starts=[1], ends=[3]), we can constant fold the Shape->Slice to + an initializer including 1D tensor values [128, 512]. + + 3. Shape generates 1D tensor [batch_size, 128, 512] -> Gather(axis=0, indices=[2]), we can constant fold the + Shape->Gather to an initializer including 1D tensor values [512]. + + 4. Shape since OPSET 15 takes input of shape [batch_size, 128, 512], slicing from 1 to 2(exclusive), + we can constant fold the Shape(start=1,end=2) to an initializer including 1D tensor values [128]. + +This would help clean up the graph, and combined with ConstantFolding, the graph would be much more simplified. + +*/ +class ShapeOptimizer : public GraphTransformer { + public: + ShapeOptimizer( + const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("ShapeOptimizer", compatible_execution_providers) { + } + + private: + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc b/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc new file mode 100644 index 0000000000..ea05b29c86 --- /dev/null +++ b/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc @@ -0,0 +1,1043 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/graph/model.h" +#include "test/framework/test_utils.h" +#include "test/test_environment.h" + +#include "gtest/gtest.h" +#include "core/optimizer/utils.h" +#include "test/optimizer/graph_transform_test_builder.h" +#include "test/optimizer/graph_transform_test_fixture.h" +#include "test/util/include/asserts.h" +#include "orttraining/core/optimizer/shape_optimizer.h" + +using namespace std; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace test { + +#ifndef DISABLE_CONTRIB_OPS + +TEST(ShapeOptimizerTests, Shape15CannotFold) { + /* + [attention_mask1_dim0,512,1536] + | + Identity + | + [attention_mask1_dim0,512,1536] + | + Shape15 + | + [2]: (attention_mask1_dim0,512) + | + Identity + | + [2] + */ + + std::string identity_output_name; + + auto pre_graph_checker = [&](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Identity"] == 2); + TEST_RETURN_IF_NOT(op_to_count["Shape"] == 1); + + identity_output_name = ""; + for (auto& node : graph.Nodes()) { + if (node.OpType().compare("Identity") == 0) + identity_output_name = node.MutableOutputDefs()[0]->Name(); + } + + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Identity"] == 2); + TEST_RETURN_IF_NOT(op_to_count["Shape"] == 1); + + TEST_RETURN_IF_NOT(!identity_output_name.empty()); + auto input_arg = graph.GetNodeArg(identity_output_name); + // Try to parse int64 type constant initializers. 
+ InlinedVector shape_values; + TEST_RETURN_IF_NOT(!optimizer_utils::AppendTensorFromInitializer(graph, *input_arg, shape_values, true)); + + return Status::OK(); + }; + + std::vector opset_candidates{15}; + for (auto opset : opset_candidates) { + auto build_test_case = [&](ModelTestBuilder& builder) { + std::vector> identity_input_shape; + identity_input_shape.reserve(3); + identity_input_shape.push_back("attention_mask1_dim0"); + identity_input_shape.push_back(512); + identity_input_shape.push_back(1536); + + auto* identity_input_arg = builder.MakeSymbolicInput(identity_input_shape); + auto* identity_out_arg = builder.MakeIntermediate(); + builder.AddNode("Identity", {identity_input_arg}, {identity_out_arg}); + + auto* shape_out_arg = builder.MakeIntermediate(); + Node& shape_node = builder.AddNode("Shape", {identity_out_arg}, {shape_out_arg}); + shape_node.AddAttribute("start", static_cast(0)); + shape_node.AddAttribute("end", static_cast(2)); + + auto* identity_out_arg_1 = builder.MakeOutput(); + builder.AddNode("Identity", {shape_out_arg}, {identity_out_arg_1}); + }; + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + InlinedHashSet compatible_eps; + std::unique_ptr transformer = std::make_unique(compatible_eps); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } +} + +TEST(ShapeOptimizerTests, Shape15) { + /* + [attention_mask1_dim0,512,1536] + | + Identity + | + [attention_mask1_dim0,512,1536] + | + Shape15 + | + [2]: (512,1536) + | + Identity + | + [2] + */ + + std::string shape_output_name; + + auto pre_graph_checker = [&](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Identity"] == 2); + TEST_RETURN_IF_NOT(op_to_count["Shape"] == 1); + + shape_output_name = ""; + for (auto& node : graph.Nodes()) { + if (node.OpType().compare("Shape") == 0) + 
shape_output_name = node.MutableOutputDefs()[0]->Name(); + } + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Identity"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Shape"] == 0); + + TEST_RETURN_IF_NOT(!shape_output_name.empty()); + auto input_arg = graph.GetNodeArg(shape_output_name); + // Try to parse int64 type constant initializers. + InlinedVector shape_values; + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *input_arg, shape_values, true)); + TEST_RETURN_IF_NOT(shape_values.size() == 2U); + TEST_RETURN_IF_NOT(shape_values[0] == 512); + TEST_RETURN_IF_NOT(shape_values[1] == 1536); + return Status::OK(); + }; + + std::vector opset_candidates{15}; + for (auto opset : opset_candidates) { + auto build_test_case = [&](ModelTestBuilder& builder) { + std::vector> identity_input_shape; + identity_input_shape.reserve(3); + identity_input_shape.push_back("attention_mask1_dim0"); + identity_input_shape.push_back(512); + identity_input_shape.push_back(1536); + + auto* identity_input_arg = builder.MakeSymbolicInput(identity_input_shape); + auto* identity_out_arg = builder.MakeIntermediate(); + builder.AddNode("Identity", {identity_input_arg}, {identity_out_arg}); + + auto* shape_out_arg = builder.MakeIntermediate(); + builder.AddNode("Shape", {identity_out_arg}, {shape_out_arg}) + .AddAttribute("start", static_cast(-2)); + + auto* identity_out_arg_1 = builder.MakeOutput(); + builder.AddNode("Identity", {shape_out_arg}, {identity_out_arg_1}); + }; + + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + InlinedHashSet compatible_eps; + std::unique_ptr transformer = std::make_unique(compatible_eps); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } +} + +TEST(ShapeOptimizerTests, 
Shape15TakesGraphInput) { + /* + [attention_mask1_dim0,512,1536] + | + Shape15 + | + [2]: (512,1536) + | + Identity + | + [2] + */ + + std::string shape_output_name; + auto pre_graph_checker = [&](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Identity"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Shape"] == 1); + + shape_output_name = ""; + for (auto& node : graph.Nodes()) { + if (node.OpType().compare("Shape") == 0) + shape_output_name = node.MutableOutputDefs()[0]->Name(); + } + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Identity"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Shape"] == 0); + + TEST_RETURN_IF_NOT(!shape_output_name.empty()); + auto input_arg = graph.GetNodeArg(shape_output_name); + // Try to parse int64 type constant initializers. + InlinedVector shape_values; + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *input_arg, shape_values, true)); + TEST_RETURN_IF_NOT(shape_values.size() == 2U); + TEST_RETURN_IF_NOT(shape_values[0] == 512); + TEST_RETURN_IF_NOT(shape_values[1] == 1536); + return Status::OK(); + }; + + std::vector opset_candidates{15}; + for (auto opset : opset_candidates) { + auto build_test_case = [&](ModelTestBuilder& builder) { + std::vector> shape_input_shape; + shape_input_shape.reserve(3); + shape_input_shape.push_back("attention_mask1_dim0"); + shape_input_shape.push_back(512); + shape_input_shape.push_back(1536); + + auto* shape_input_arg = builder.MakeSymbolicInput(shape_input_shape); + auto* shape_out_arg = builder.MakeIntermediate(); + builder.AddNode("Shape", {shape_input_arg}, {shape_out_arg}) + .AddAttribute("start", static_cast(-2)); + + auto* identity_out_arg_1 = builder.MakeOutput(); + builder.AddNode("Identity", {shape_out_arg}, {identity_out_arg_1}); + }; + + const logging::Logger* logger = 
&logging::LoggingManager::DefaultLogger(); + InlinedHashSet compatible_eps; + std::unique_ptr transformer = std::make_unique(compatible_eps); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } +} + +TEST(ShapeOptimizerTests, Shape15GeneratesGraphOutput) { + /* + [attention_mask1_dim0,512,1536] + | + Identity + | + [attention_mask1_dim0,512,1536] + | + Shape15 + | + [2]: (512,1536) + | + [2] + */ + std::string shape_output_name; + auto pre_graph_checker = [&](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Identity"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Shape"] == 1); + + shape_output_name = ""; + for (auto& node : graph.Nodes()) { + if (node.OpType().compare("Shape") == 0) + shape_output_name = node.MutableOutputDefs()[0]->Name(); + } + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Identity"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Shape"] == 0); + + TEST_RETURN_IF_NOT(!shape_output_name.empty()); + auto input_arg = graph.GetNodeArg(shape_output_name); + // Try to parse int64 type constant initializers. 
+ InlinedVector shape_values; + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *input_arg, shape_values, true)); + TEST_RETURN_IF_NOT(shape_values.size() == 2U); + TEST_RETURN_IF_NOT(shape_values[0] == 512); + TEST_RETURN_IF_NOT(shape_values[1] == 1536); + return Status::OK(); + }; + + std::vector opset_candidates{15}; + for (auto opset : opset_candidates) { + auto build_test_case = [&](ModelTestBuilder& builder) { + std::vector> identity_input_shape; + identity_input_shape.reserve(3); + identity_input_shape.push_back("attention_mask1_dim0"); + identity_input_shape.push_back(512); + identity_input_shape.push_back(1536); + + auto* identity_input_arg = builder.MakeSymbolicInput(identity_input_shape); + auto* identity_out_arg = builder.MakeIntermediate(); + builder.AddNode("Identity", {identity_input_arg}, {identity_out_arg}); + + auto* shape_out_arg = builder.MakeOutput(); + builder.AddNode("Shape", {identity_out_arg}, {shape_out_arg}) + .AddAttribute("start", static_cast(-2)); + }; + + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + InlinedHashSet compatible_eps; + std::unique_ptr transformer = std::make_unique(compatible_eps); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } +} + +TEST(ShapeOptimizerTests, Slice) { + /* + [attention_mask1_dim0,512,1536] + | + Shape + | + [3]: (attention_mask1_dim0,512,1536) + | + Slice + | + [2]: (512, 1536) + | + Identity + | + [2] + */ + + std::string slice_output_name; + + auto pre_graph_checker = [&](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Identity"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Shape"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Slice"] == 1); + + slice_output_name = ""; + for (auto& node : graph.Nodes()) { + if (node.OpType().compare("Slice") == 0) + slice_output_name = 
node.MutableOutputDefs()[0]->Name(); + } + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Identity"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Shape"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Slice"] == 0); + + TEST_RETURN_IF_NOT(!slice_output_name.empty()); + auto input_arg = graph.GetNodeArg(slice_output_name); + // Try to parse int64 type constant initializers. + InlinedVector shape_values; + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *input_arg, shape_values, true)); + TEST_RETURN_IF_NOT(shape_values.size() == 2U); + TEST_RETURN_IF_NOT(shape_values[0] == 512); + TEST_RETURN_IF_NOT(shape_values[1] == 1536); + return Status::OK(); + }; + + std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + for (auto opset : opset_candidates) { + auto build_test_case = [&](ModelTestBuilder& builder) { + std::vector> shape_input_shape; + shape_input_shape.reserve(3); + shape_input_shape.push_back("attention_mask1_dim0"); + shape_input_shape.push_back(512); + shape_input_shape.push_back(1536); + + auto* shape_input_arg = builder.MakeSymbolicInput(shape_input_shape); + auto* shape_out_arg = builder.MakeIntermediate(); + builder.AddNode("Shape", {shape_input_arg}, {shape_out_arg}); + + // Slice after opset 1 have such schema. 
+ auto* slice_out_arg = builder.MakeIntermediate(); + auto* starts_input_arg = builder.MakeInitializer({1}, {-2}); + auto* ends_input_arg = builder.MakeInitializer({1}, {3}); + auto* axes_input_arg = builder.MakeInitializer({1}, {0}); + builder.AddNode("Slice", {shape_out_arg, starts_input_arg, ends_input_arg, axes_input_arg}, {slice_out_arg}); + + auto* identity_out_arg_1 = builder.MakeOutput(); + builder.AddNode("Identity", {slice_out_arg}, {identity_out_arg_1}); + }; + + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + InlinedHashSet compatible_eps; + std::unique_ptr transformer = std::make_unique(compatible_eps); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } +} + +TEST(ShapeOptimizerTests, SliceGeneratesGraphOutput) { + /* + [attention_mask1_dim0,512,1536] + | + Shape + | + [3]: (attention_mask1_dim0,512,1536) + | + Slice + | + [2]: (512, 1536) + | + [2] + This test also covers the case where the axes and steps inputs are missing. + */ + + std::string slice_output_name; + auto pre_graph_checker = [&](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Shape"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Slice"] == 1); + + slice_output_name = ""; + for (auto& node : graph.Nodes()) { + if (node.OpType().compare("Slice") == 0) + slice_output_name = node.MutableOutputDefs()[0]->Name(); + } + + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Shape"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Slice"] == 0); + + TEST_RETURN_IF_NOT(!slice_output_name.empty()); + auto input_arg = graph.GetNodeArg(slice_output_name); + // Try to parse int64 type constant initializers. 
+ InlinedVector shape_values; + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *input_arg, shape_values, true)); + TEST_RETURN_IF_NOT(shape_values.size() == 2U); + TEST_RETURN_IF_NOT(shape_values[0] == 512); + TEST_RETURN_IF_NOT(shape_values[1] == 1536); + return Status::OK(); + }; + + std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + for (auto opset : opset_candidates) { + auto build_test_case = [&](ModelTestBuilder& builder) { + std::vector> shape_input_shape; + shape_input_shape.reserve(3); + shape_input_shape.push_back("attention_mask1_dim0"); + shape_input_shape.push_back(512); + shape_input_shape.push_back(1536); + + auto* shape_input_arg = builder.MakeSymbolicInput(shape_input_shape); + auto* shape_out_arg = builder.MakeIntermediate(); + builder.AddNode("Shape", {shape_input_arg}, {shape_out_arg}); + + // Slice after opset 1 have such schema. + auto* slice_out_arg = builder.MakeOutput(); + auto* starts_input_arg = builder.MakeInitializer({1}, {-2}); + auto* ends_input_arg = builder.MakeInitializer({1}, {3}); + builder.AddNode("Slice", {shape_out_arg, starts_input_arg, ends_input_arg}, {slice_out_arg}); + }; + + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + InlinedHashSet compatible_eps; + std::unique_ptr transformer = std::make_unique(compatible_eps); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } +} + +TEST(ShapeOptimizerTests, Gather) { + /* + [attention_mask1_dim0,512,24,64] + | + Shape + | + [4] + / \ + Gather Gather + | | + []: (attention_mask1_dim0,) [1]: (24,) + | | + [] means a shape for scalar. 
+ */ + + std::vector gather_output_names; + gather_output_names.reserve(2); + auto pre_graph_checker = [&](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Shape"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Gather"] == 2); + + gather_output_names.clear(); + for (auto& node : graph.Nodes()) { + if (node.OpType().compare("Gather") == 0) { + gather_output_names.push_back(node.MutableOutputDefs()[0]->Name()); + } + } + + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Shape"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Gather"] == 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Gather") { + for (auto& gather_output_name : gather_output_names) { + if (gather_output_name.compare(node.MutableOutputDefs()[0]->Name()) != 0) { + // Try to parse int64 type constant initializers. + InlinedVector shape_values; + auto input_arg = graph.GetNodeArg(gather_output_name); + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *input_arg, shape_values, true)); + TEST_RETURN_IF_NOT(shape_values.size() == 1U); + TEST_RETURN_IF_NOT(shape_values[0] == 24); + } + } + } + } + + return Status::OK(); + }; + + std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + for (auto opset : opset_candidates) { + auto build_test_case = [&](ModelTestBuilder& builder) { + std::vector> shape_input_shape; + shape_input_shape.reserve(4); + shape_input_shape.push_back("attention_mask1_dim0"); + shape_input_shape.push_back(512); + shape_input_shape.push_back(24); + shape_input_shape.push_back(64); + + auto* shape_input_arg = builder.MakeSymbolicInput(shape_input_shape); + auto* shape_out_arg = builder.MakeIntermediate(); + // Shape before opset 15 have such schema, the test schema did not cover opset 15. 
+ builder.AddNode("Shape", {shape_input_arg}, {shape_out_arg}); + + auto* indices_input_arg = builder.MakeScalarInitializer(0); + auto* gather_out_arg = builder.MakeOutput(); + builder.AddNode("Gather", {shape_out_arg, indices_input_arg}, {gather_out_arg}) + .AddAttribute("axis", static_cast(0)); + + auto* indices_input_arg_1 = builder.MakeInitializer({1}, {2}); + auto* gather_out_arg_1 = builder.MakeOutput(); + builder.AddNode("Gather", {shape_out_arg, indices_input_arg_1}, {gather_out_arg_1}) + .AddAttribute("axis", static_cast(0)); + }; + + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + InlinedHashSet compatible_eps; + std::unique_ptr e = std::make_unique(CPUExecutionProviderInfo()); + std::unique_ptr transformer = std::make_unique(compatible_eps); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } +} + +TEST(ShapeOptimizerTests, ConcreteDimUsedBySlice) { + /* + [attention_mask1_dim0,24,512,512] + | + Dropout + / \ + [attention_mask1_dim0,24,512,512] [attention_mask1_dim0,24,512,512] + | | + Shape | + | | + [4] | + / \ | + Slice Slice | + | | | + [1]: (512,) [1]: (512,) | + | | | + Squeeze Squeeze | + | | | + [1]: -1 Unsqueeze Unsqueeze / + \ \ / / + ConcatTraining / + | / + | / + [3]: (-1, 512, 512) / + \ / + Reshape + | + */ + + std::vector slice_output_names; + auto pre_graph_checker = [&slice_output_names](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Dropout"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Slice"] == 2); + TEST_RETURN_IF_NOT(op_to_count["Squeeze"] == 2); + TEST_RETURN_IF_NOT(op_to_count["Unsqueeze"] == 2); + TEST_RETURN_IF_NOT(op_to_count["com.microsoft.ConcatTraining"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Shape"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 1); + + slice_output_names.clear(); + for (auto& node : 
graph.Nodes()) { + if (node.OpType().compare("Slice") == 0) { + slice_output_names.push_back(node.OutputDefs()[0]->Name()); + } + } + + return Status::OK(); + }; + + auto post_graph_checker = [&slice_output_names](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Dropout"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Slice"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Squeeze"] == 2); + TEST_RETURN_IF_NOT(op_to_count["Unsqueeze"] == 2); + TEST_RETURN_IF_NOT(op_to_count["com.microsoft.ConcatTraining"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Shape"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 1); + + for (auto slice_output_name : slice_output_names) { + auto input_arg = graph.GetNodeArg(slice_output_name); + TEST_RETURN_IF_NOT(input_arg != nullptr); + InlinedVector shape_values; + // Try to parse int64 type constant initializers. + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *input_arg, shape_values, true)); + TEST_RETURN_IF_NOT(shape_values.size() == 1U); + TEST_RETURN_IF_NOT(shape_values[0] == 512); + } + + return Status::OK(); + }; + + std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + for (auto opset : opset_candidates) { + auto build_test_case = [&](ModelTestBuilder& builder) { + std::vector> dropout_input_shape; + dropout_input_shape.reserve(4); + dropout_input_shape.push_back("attention_mask1_dim0"); + dropout_input_shape.push_back(24); + dropout_input_shape.push_back(512); + dropout_input_shape.push_back(512); + + auto* dropout_input_arg = builder.MakeSymbolicInput(dropout_input_shape); + auto* dropout_out_arg = builder.MakeIntermediate(); + auto* mask_out_arg = builder.MakeIntermediate(); + constexpr float ratio = 0.10000000149011612f; + if (opset < 12) { + builder.AddNode("Dropout", {dropout_input_arg}, {dropout_out_arg, mask_out_arg}) + .AddAttribute("ratio", ratio); + } else { + auto* ratio_input_arg = builder.MakeScalarInitializer(ratio); + auto* mode_input_arg = 
builder.MakeInitializerBool({}, std::vector{true}); + builder.AddNode("Dropout", {dropout_input_arg, ratio_input_arg, mode_input_arg}, + {dropout_out_arg, mask_out_arg}); + } + + auto* shape_out_arg = builder.MakeIntermediate(); + // Shape before opset 15 have such schema, the test schema did not cover opset 15. + builder.AddNode("Shape", {dropout_out_arg}, {shape_out_arg}); + + // Slice after opset 1 have such schema. + auto* slice_out_arg = builder.MakeIntermediate(); + auto* starts_input_arg = builder.MakeInitializer({1}, {-2}); + auto* ends_input_arg = builder.MakeInitializer({1}, {-1}); + auto* axes_input_arg = builder.MakeInitializer({1}, {0}); + builder.AddNode("Slice", {shape_out_arg, starts_input_arg, ends_input_arg, axes_input_arg}, {slice_out_arg}); + + auto* starts_input_arg_1 = builder.MakeInitializer({1}, {-1}); + auto* ends_input_arg_1 = builder.MakeInitializer({1}, {9223372036854775807}); + auto* axes_input_arg_1 = builder.MakeInitializer({1}, {0}); + auto* slice_out_arg_1 = builder.MakeIntermediate(); + builder.AddNode("Slice", {shape_out_arg, starts_input_arg_1, ends_input_arg_1, axes_input_arg_1}, + {slice_out_arg_1}); + + auto* squeeze_out_arg = builder.MakeIntermediate(); + auto* squeeze_out_arg_1 = builder.MakeIntermediate(); + const std::vector squeeze_axes{0}; + if (opset < 13) { + builder.AddNode("Squeeze", {slice_out_arg}, {squeeze_out_arg}).AddAttribute("axes", squeeze_axes); + builder.AddNode("Squeeze", {slice_out_arg_1}, {squeeze_out_arg_1}).AddAttribute("axes", squeeze_axes); + } else { + auto* squeeze_axes_input_arg = builder.MakeInitializer({1}, squeeze_axes); + builder.AddNode("Squeeze", {slice_out_arg, squeeze_axes_input_arg}, {squeeze_out_arg}); + auto* squeeze_axes_input_arg_1 = builder.MakeInitializer({1}, squeeze_axes); + builder.AddNode("Squeeze", {slice_out_arg_1, squeeze_axes_input_arg_1}, {squeeze_out_arg_1}); + } + + auto* unsqueeze_out_arg = builder.MakeIntermediate(); + auto* unsqueeze_out_arg_1 = 
builder.MakeIntermediate(); + const std::vector unsqueeze_axes{0}; + if (opset < 13) { + builder.AddNode("Unsqueeze", {squeeze_out_arg}, {unsqueeze_out_arg}).AddAttribute("axes", unsqueeze_axes); + builder.AddNode("Unsqueeze", {squeeze_out_arg_1}, {unsqueeze_out_arg_1}).AddAttribute("axes", unsqueeze_axes); + } else { + auto* unsqueeze_axes_input_arg = builder.MakeInitializer({1}, unsqueeze_axes); + builder.AddNode("Unsqueeze", {squeeze_out_arg, unsqueeze_axes_input_arg}, {unsqueeze_out_arg}); + auto* unsqueeze_axes_input_arg_1 = builder.MakeInitializer({1}, unsqueeze_axes); + builder.AddNode("Unsqueeze", {squeeze_out_arg_1, unsqueeze_axes_input_arg_1}, {unsqueeze_out_arg_1}); + } + + auto* concat_training_out_arg = builder.MakeIntermediate(); + auto* concat_input_arg = builder.MakeInitializer({1}, {-1}); + builder.AddNode("ConcatTraining", {concat_input_arg, unsqueeze_out_arg, unsqueeze_out_arg_1}, + {concat_training_out_arg}, kMSDomain) + .AddAttribute("axis", static_cast(0)); + + auto* reshape_out_arg = builder.MakeOutput(); + builder.AddNode("Reshape", {dropout_out_arg, concat_training_out_arg}, {reshape_out_arg}); + }; + + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + InlinedHashSet compatible_eps; + std::unique_ptr e = std::make_unique(CPUExecutionProviderInfo()); + std::unique_ptr transformer = std::make_unique(compatible_eps); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } +} + +TEST(ShapeOptimizerTests, ConcreteDimUsedByGatherSlice) { + /* + [attention_mask1_dim0,512,1536] [4]: (0, 0, 24, -1) + \ / + Reshape + / + [attention_mask1_dim0,512,24,64] + | \ + Shape Transpose + | | + [4] [attention_mask1_dim0,24,512,64] + / \ \ + Gather Slice \ + | | \ + []: (512,) [1]: (64,) | + | | | + | Squeeze | + | | | + [1]: -1 Unsqueeze Unsqueeze / + \ \ / / + ConcatTraining / + | / + | / + [3]: (-1, 512, 64) / + \ 
/ + Reshape + | + [] means a shape for scalar. + */ + + std::string gather_output_name, slice_output_name; + auto pre_graph_checker = [&](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Shape"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Transpose"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Gather"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Slice"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Squeeze"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Unsqueeze"] == 2); + TEST_RETURN_IF_NOT(op_to_count["com.microsoft.ConcatTraining"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 2); + + for (auto& node : graph.Nodes()) { + if (node.OpType().compare("Gather") == 0) { + gather_output_name = node.OutputDefs()[0]->Name(); + } else if (node.OpType().compare("Slice") == 0) { + slice_output_name = node.OutputDefs()[0]->Name(); + } + } + + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Shape"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Transpose"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Gather"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Slice"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Squeeze"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Unsqueeze"] == 2); + TEST_RETURN_IF_NOT(op_to_count["com.microsoft.ConcatTraining"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 2); + + auto gather_output_arg = graph.GetNodeArg(gather_output_name); + TEST_RETURN_IF_NOT(gather_output_arg != nullptr); + // Try to parse int64 type constant initializers. 
+ InlinedVector gather_out_values; + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *gather_output_arg, gather_out_values, true)); + TEST_RETURN_IF_NOT(gather_out_values.size() == 1U); + TEST_RETURN_IF_NOT(gather_out_values[0] == 512); + + auto slice_out_arg = graph.GetNodeArg(slice_output_name); + TEST_RETURN_IF_NOT(slice_out_arg != nullptr); + // Try to parse int64 type constant initializers. + InlinedVector slice_out_values; + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *slice_out_arg, slice_out_values, true)); + TEST_RETURN_IF_NOT(slice_out_values.size() == 1U); + TEST_RETURN_IF_NOT(slice_out_values[0] == 64); + + return Status::OK(); + }; + + std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + for (auto opset : opset_candidates) { + auto build_test_case = [&](ModelTestBuilder& builder) { + std::vector> reshape_input_shape; + reshape_input_shape.reserve(3); + reshape_input_shape.push_back("attention_mask1_dim0"); + reshape_input_shape.push_back(512); + reshape_input_shape.push_back(1536); + + auto* reshape_input_arg = builder.MakeSymbolicInput(reshape_input_shape); + auto* target_shape_input_arg = builder.MakeInitializer({4}, {0, 0, 24, -1}); + auto* reshape_out_arg = builder.MakeIntermediate(); + builder.AddNode("Reshape", {reshape_input_arg, target_shape_input_arg}, {reshape_out_arg}); + + auto* shape_out_arg = builder.MakeIntermediate(); + // Shape before opset 15 have such schema, the test schema did not cover opset 15. 
+ builder.AddNode("Shape", {reshape_out_arg}, {shape_out_arg}); + auto* transpose_out_arg = builder.MakeIntermediate(); + builder.AddNode("Transpose", {reshape_out_arg}, {transpose_out_arg}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + + auto* indices_input_arg = builder.MakeScalarInitializer(1); + auto* gather_out_arg = builder.MakeIntermediate(); + builder.AddNode("Gather", {shape_out_arg, indices_input_arg}, {gather_out_arg}) + .AddAttribute("axis", static_cast(0)); + + auto* starts_input_arg_1 = builder.MakeInitializer({1}, {-1}); + auto* ends_input_arg_1 = builder.MakeInitializer({1}, {9223372036854775807}); + auto* axes_input_arg_1 = builder.MakeInitializer({1}, {0}); + auto* slice_out_arg_1 = builder.MakeIntermediate(); + builder.AddNode("Slice", {shape_out_arg, starts_input_arg_1, ends_input_arg_1, axes_input_arg_1}, + {slice_out_arg_1}); + + auto* squeeze_out_arg_1 = builder.MakeIntermediate(); + const std::vector squeeze_axes{0}; + if (opset < 13) { + builder.AddNode("Squeeze", {slice_out_arg_1}, {squeeze_out_arg_1}).AddAttribute("axes", squeeze_axes); + } else { + auto* squeeze_axes_input_arg_1 = builder.MakeInitializer({1}, squeeze_axes); + builder.AddNode("Squeeze", {slice_out_arg_1, squeeze_axes_input_arg_1}, {squeeze_out_arg_1}); + } + + auto* unsqueeze_out_arg = builder.MakeIntermediate(); + auto* unsqueeze_out_arg_1 = builder.MakeIntermediate(); + const std::vector unsqueeze_axes{0}; + if (opset < 13) { + builder.AddNode("Unsqueeze", {gather_out_arg}, {unsqueeze_out_arg}).AddAttribute("axes", unsqueeze_axes); + builder.AddNode("Unsqueeze", {squeeze_out_arg_1}, {unsqueeze_out_arg_1}).AddAttribute("axes", unsqueeze_axes); + } else { + auto* unsqueeze_axes_input_arg = builder.MakeInitializer({1}, unsqueeze_axes); + builder.AddNode("Unsqueeze", {gather_out_arg, unsqueeze_axes_input_arg}, {unsqueeze_out_arg}); + auto* unsqueeze_axes_input_arg_1 = builder.MakeInitializer({1}, unsqueeze_axes); + builder.AddNode("Unsqueeze", {squeeze_out_arg_1, 
unsqueeze_axes_input_arg_1}, {unsqueeze_out_arg_1}); + } + + auto* concat_training_out_arg = builder.MakeIntermediate(); + auto* concat_input_arg = builder.MakeInitializer({1}, {-1}); + builder.AddNode("ConcatTraining", {concat_input_arg, unsqueeze_out_arg, unsqueeze_out_arg_1}, + {concat_training_out_arg}, kMSDomain) + .AddAttribute("axis", static_cast(0)); + + auto* reshape_out_arg_1 = builder.MakeOutput(); + builder.AddNode("Reshape", {transpose_out_arg, concat_training_out_arg}, {reshape_out_arg_1}); + }; + + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + InlinedHashSet compatible_eps; + std::unique_ptr e = std::make_unique(CPUExecutionProviderInfo()); + std::unique_ptr transformer = std::make_unique(compatible_eps); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } +} + +TEST(ShapeOptimizerTests, SymbolicDimUsedByGather_ConcreteDimUsedByGather) { + /* + [attention_mask1_dim0,512,1536] [4]: (0, 0, 24, -1) + \ / + Reshape + / + [attention_mask1_dim0,512,24,64] + | \ + Shape Transpose + | | + [4] [attention_mask1_dim0,24,512,64] + / | | + Gather Gather | + | | | + []: (attention_mask1_dim0,) [1]: (24,) | + | | | + | | | + | | / + Unsqueeze | [1]: -1 / + \ | / / + ConcatTraining / + | / + | / + [3]: (attention_mask1_dim0, 24, -1) / + \ / + Reshape + | + [] means a shape for scalar. 
+ */ + auto pre_graph_checker = [&](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Shape"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Transpose"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Gather"] == 2); + TEST_RETURN_IF_NOT(op_to_count["Unsqueeze"] == 1); + TEST_RETURN_IF_NOT(op_to_count["com.microsoft.ConcatTraining"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 2); + + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Shape"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Transpose"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Gather"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Unsqueeze"] == 1); + TEST_RETURN_IF_NOT(op_to_count["com.microsoft.ConcatTraining"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 2); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Reshape") { + NodeArg* shape_input = node.MutableInputDefs()[1]; + auto p_output_node = node.OutputNodesBegin(); + const auto p_output_node_end = node.OutputNodesEnd(); + bool find_transpose = false; + while (p_output_node != p_output_node_end) { + const auto& output_node = *p_output_node; + if (output_node.OpType().compare("Transpose") == 0) { + find_transpose = true; + break; + } + ++p_output_node; + } + + if (find_transpose) { + // Ignore the first Reshape node. + continue; + } + + // Try to parse int64 type constant initializers. + InlinedVector shape_values; + TEST_RETURN_IF_NOT(!optimizer_utils::AppendTensorFromInitializer(graph, *shape_input, shape_values, true)); + TEST_RETURN_IF_NOT(graph.GetProducerNode( + node.MutableInputDefs()[1]->Name()) + ->OpType() + .compare("ConcatTraining") == 0); + } else if (node.OpType() == "ConcatTraining") { + NodeArg* shape_input = node.MutableInputDefs()[1]; + + // Try to parse int64 type constant initializers. 
+ InlinedVector shape_values; + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *shape_input, shape_values, true)); + TEST_RETURN_IF_NOT(shape_values.size() == 1U); + TEST_RETURN_IF_NOT(shape_values[0] == 24); + } + } + + return Status::OK(); + }; + + std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + for (auto opset : opset_candidates) { + auto build_test_case = [&](ModelTestBuilder& builder) { + std::vector> reshape_input_shape; + reshape_input_shape.reserve(3); + reshape_input_shape.push_back("attention_mask1_dim0"); + reshape_input_shape.push_back(512); + reshape_input_shape.push_back(1536); + + auto* reshape_input_arg = builder.MakeSymbolicInput(reshape_input_shape); + auto* target_shape_input_arg = builder.MakeInitializer({4}, {0, 0, 24, -1}); + auto* reshape_out_arg = builder.MakeIntermediate(); + builder.AddNode("Reshape", {reshape_input_arg, target_shape_input_arg}, {reshape_out_arg}); + + auto* shape_out_arg = builder.MakeIntermediate(); + // Shape before opset 15 have such schema, the test schema did not cover opset 15. 
+ builder.AddNode("Shape", {reshape_out_arg}, {shape_out_arg}); + auto* transpose_out_arg = builder.MakeIntermediate(); + builder.AddNode("Transpose", {reshape_out_arg}, {transpose_out_arg}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + + auto* indices_input_arg = builder.MakeScalarInitializer(0); + auto* gather_out_arg = builder.MakeIntermediate(); + builder.AddNode("Gather", {shape_out_arg, indices_input_arg}, {gather_out_arg}) + .AddAttribute("axis", static_cast(0)); + + auto* indices_input_arg_1 = builder.MakeInitializer({1}, {2}); + auto* gather_out_arg_1 = builder.MakeIntermediate(); + builder.AddNode("Gather", {shape_out_arg, indices_input_arg_1}, {gather_out_arg_1}) + .AddAttribute("axis", static_cast(0)); + + auto* unsqueeze_out_arg = builder.MakeIntermediate(); + const std::vector unsqueeze_axes{0}; + if (opset < 13) { + builder.AddNode("Unsqueeze", {gather_out_arg}, {unsqueeze_out_arg}).AddAttribute("axes", unsqueeze_axes); + } else { + auto* unsqueeze_axes_input_arg = builder.MakeInitializer({1}, unsqueeze_axes); + builder.AddNode("Unsqueeze", {gather_out_arg, unsqueeze_axes_input_arg}, {unsqueeze_out_arg}); + } + + auto* concat_training_out_arg = builder.MakeIntermediate(); + auto* concat_input_arg = builder.MakeInitializer({1}, {-1}); + builder.AddNode("ConcatTraining", {unsqueeze_out_arg, gather_out_arg_1, concat_input_arg}, + {concat_training_out_arg}, kMSDomain) + .AddAttribute("axis", static_cast(0)); + + auto* reshape_out_arg_1 = builder.MakeOutput(); + builder.AddNode("Reshape", {transpose_out_arg, concat_training_out_arg}, {reshape_out_arg_1}); + }; + + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + InlinedHashSet compatible_eps; + std::unique_ptr e = std::make_unique(CPUExecutionProviderInfo()); + std::unique_ptr transformer = std::make_unique(compatible_eps); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, 
post_graph_checker)); + } +} + +// end of DISABLE_CONTRIB_OPS +#endif + +} // namespace test +} // namespace onnxruntime