diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc
index 319b61fa32..9ad5edf4f2 100644
--- a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc
+++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc
@@ -345,6 +345,65 @@ std::optional<SliceInfo> IsSupportedShrunkenGather(Graph& graph, Node& node,
   return SliceInfo(graph, &node, false /*is_slice_scalar*/, "axis", axis, true);
 }
 
+/**
+ * @brief Check if the Slice node can be up-streamed to the previous node.
+ *
+ * If "Slice" node is operating on one single axis, then it is supported.
+ * @return std::optional<SliceInfo>
+ */
+std::optional<SliceInfo> IsSupportedSlice(Graph& graph, Node& node,
+                                          const InlinedHashSet<std::string_view>&
+                                              compatible_execution_providers,
+                                          const logging::Logger& logger) {
+  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {10, 11, 13}) ||
+      !graph_utils::IsSupportedProvider(node, compatible_execution_providers)) {
+    return std::nullopt;
+  }
+
+  const NodeArg* data_input = node.InputDefs()[0];
+  const NodeArg* starts_input = node.InputDefs()[1];
+  const NodeArg* ends_input = node.InputDefs()[2];
+  const NodeArg* axes_input = node.InputDefs().size() > 3 ? node.InputDefs()[3] : nullptr;
+
+  if (data_input->Shape() == nullptr || starts_input->Shape() == nullptr || ends_input->Shape() == nullptr ||
+      (axes_input && axes_input->Shape() == nullptr)) {
+    LOG_DEBUG_INFO(logger, "Skip Slice node " + node.Name() + " due to undefined shape.");
+    return std::nullopt;
+  }
+
+  // Make sure starts/ends/axes/steps are all 1D tensors, since we only support single-dimension slicing.
+  if (starts_input->Shape()->dim_size() != 1 || ends_input->Shape()->dim_size() != 1 ||
+      (axes_input && axes_input->Shape()->dim_size() != 1)) {
+    LOG_DEBUG_INFO(logger, "Skip Slice node " + node.Name() + " due to unsupported dim size: " +
+                               std::to_string(starts_input->Shape()->dim_size()) + ", " +
+                               std::to_string(ends_input->Shape()->dim_size()) + ", " +
+                               std::to_string(axes_input ? axes_input->Shape()->dim_size() : 0));
+    return std::nullopt;
+  }
+
+  // Try to parse the 'axes' value.
+  int axis = 0;
+  if (axes_input) {
+    InlinedVector<int64_t> axes_values;
+    if (!graph_utils::IsConstantInitializer(graph, axes_input->Name()) ||
+        !optimizer_utils::AppendTensorFromInitializer(graph, *axes_input, axes_values, true) ||
+        axes_values.size() != 1) {
+      return std::nullopt;
+    }
+    axis = static_cast<int>(axes_values[0]);
+  } else {
+    // If 'axes' is not specified, then it is [0, .., r-1], so we force data rank to be 1.
+    if (data_input->Shape()->dim_size() != 1) {
+      return std::nullopt;
+    }
+  }
+
+  if (axis < 0)
+    axis += data_input->Shape()->dim_size();
+
+  return SliceInfo(graph, &node, false /*is_slice_scalar*/, "axis", axis, true);
+}
+
 }  // namespace
 
 std::optional<SliceInfo> UpStreamGatherGraphTransformer::IsSupportedForUpstream(
@@ -358,6 +417,9 @@ std::optional<SliceInfo> UpStreamGatherGraphTransformer::IsSupportedForUpstream(
   if (!gather_info.has_value()) {
     gather_info = IsSupportedShrunkenGather(graph, node, GetCompatibleExecutionProviders(), logger);
   }
+  if (!gather_info.has_value()) {
+    gather_info = IsSupportedSlice(graph, node, GetCompatibleExecutionProviders(), logger);
+  }
 
   return gather_info;
 }
diff --git a/onnxruntime/test/optimizer/compute_optimizer_test.cc b/onnxruntime/test/optimizer/compute_optimizer_test.cc
index d374492057..55a7864820 100644
--- a/onnxruntime/test/optimizer/compute_optimizer_test.cc
+++ b/onnxruntime/test/optimizer/compute_optimizer_test.cc
@@ -1576,6 +1576,182 @@ TEST(ComputeOptimizerTests, ShrunkenGatherElementwiseOps_PropagationOnTwoBranche
                                         1, pre_graph_checker,
                                         post_graph_checker));
 }
 
+/*
+Test graph includes multiple equivalent subgraphs as below.
+      graph input [4, 32, 256] (float)       graph input [4, 32, 256] (float)
+                   |                                      |
+                    \_____________   ______________/
+                                  \ /
+         Add    starts:(0)  ends: (-1)  axes: (1)  steps: (1)
+           \       \           |           /          /
+            \       \          |          /          /
+             \       \         |         /          /
+              \       \        |        /          /
+               \       \       |       /          /
+                            Slice
+                              |
+                          Identity
+                              |
+             graph output [4, 31, 256] (float)
+
+Add an Identity node because currently we don't allow Slice generates graph output.
+*/
+TEST(ComputeOptimizerTests, SliceElementwiseOps_PropagationOnTwoBranches) {
+  const logging::Logger* logger = &logging::LoggingManager::DefaultLogger();
+  InlinedVector<int64_t> starts_indices;
+  auto pre_graph_checker = [&starts_indices](Graph& graph) -> Status {
+    auto op_count_pre = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_count_pre.size() == 3U);
+    TEST_RETURN_IF_NOT(op_count_pre["Add"] == 1);
+    TEST_RETURN_IF_NOT(op_count_pre["Slice"] == 1);
+    TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1);
+
+    for (Node& node : graph.Nodes()) {
+      if (node.OpType() == "Slice") {
+        TEST_RETURN_IF_NOT(starts_indices.empty());
+        constexpr bool require_constant = true;
+        NodeArg* initializer_node_arg = graph.GetNodeArg(node.InputDefs()[1]->Name());
+        TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, starts_indices,
+                                                                        require_constant));
+      }
+    }
+    return Status::OK();
+  };
+
+  auto post_graph_checker = [&starts_indices](Graph& graph) {
+    auto op_count_post = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_count_post.size() == 3U);
+    TEST_RETURN_IF_NOT(op_count_post["Add"] == 1);
+    TEST_RETURN_IF_NOT(op_count_post["Slice"] == 2);
+    TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1);
+
+    for (Node& node : graph.Nodes()) {
+      if (node.OpType() == "Add") {
+        const auto& input_defs = node.InputDefs();
+
+        {
+          auto producer_node = graph.GetProducerNode(input_defs[0]->Name());
+          TEST_RETURN_IF_NOT(producer_node != nullptr);
+          TEST_RETURN_IF_NOT(producer_node->OpType() == "Slice");
+
+          InlinedVector<int64_t> values;
+          constexpr bool require_constant = true;
+          NodeArg* initializer_node_arg = graph.GetNodeArg(producer_node->InputDefs()[1]->Name());
+          TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, values,
+                                                                          require_constant));
+          for (size_t i = 0; i < values.size(); i++) {
+            TEST_RETURN_IF_NOT(values[i] == starts_indices[i]);
+          }
+        }
+
+        {
+          auto producer_node = graph.GetProducerNode(input_defs[1]->Name());
+          TEST_RETURN_IF_NOT(producer_node != nullptr);
+          TEST_RETURN_IF_NOT(producer_node->OpType() == "Slice");
+
+          InlinedVector<int64_t> values;
+          constexpr bool require_constant = true;
+          NodeArg* initializer_node_arg = graph.GetNodeArg(producer_node->InputDefs()[1]->Name());
+          TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, values, require_constant));
+          for (size_t i = 0; i < values.size(); i++) {
+            TEST_RETURN_IF_NOT(values[i] == starts_indices[i]);
+          }
+        }
+      }
+    }
+    return Status::OK();
+  };
+
+  auto build_test_case = [](ModelTestBuilder& builder) {
+    auto* input1_arg = builder.MakeInput<float>({{4, 32, 256}});
+    auto* input2_arg = builder.MakeInput<float>({{4, 32, 256}});
+    auto* add_out = builder.MakeIntermediate();
+    builder.AddNode("Add", {input1_arg, input2_arg}, {add_out});
+
+    auto* starts_initializer = builder.MakeInitializer<int64_t>({1}, {0});
+    auto* ends_initializer = builder.MakeInitializer<int64_t>({1}, {-1});
+    auto* axes_initializer = builder.MakeInitializer<int64_t>({1}, {1});
+    auto* steps_initializer = builder.MakeInitializer<int64_t>({1}, {1});
+    auto* slice_out = builder.MakeIntermediate();
+    builder.AddNode("Slice", {add_out, starts_initializer, ends_initializer, axes_initializer, steps_initializer},
+                    {slice_out});
+
+    auto* identity_out = builder.MakeOutput();
+    builder.AddNode("Identity", {slice_out}, {identity_out});
+  };
+
+  std::unique_ptr<GraphTransformer> transformer = std::make_unique<UpStreamGatherGraphTransformer>();
+  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger, std::move(transformer),
+                                        TransformerLevel::Level1,
+                                        1, pre_graph_checker, post_graph_checker));
+}
+
+/*
+Test graph includes multiple equivalent subgraphs as below.
+      graph input [4, 32, 256] (float)       graph input [4, 32, 256] (float)
+                   |                                      |
+                    \_____________   ______________/
+                                  \ /
+         Add    starts:(0,0)  ends: (-1,-1)  axes: (0,1)  steps: (1,1)
+           \       \           |           /          /
+            \       \          |          /          /
+             \       \         |         /          /
+              \       \        |        /          /
+               \       \       |       /          /
+                            Slice
+                              |
+                          Identity
+                              |
+             graph output [3, 31, 256] (float)
+
+Add an Identity node because currently we don't allow Slice generates graph output.
+*/
+TEST(ComputeOptimizerTests, SliceElementwiseOps_NoPropagationForMutipleAxesSlice) {
+  const logging::Logger* logger = &logging::LoggingManager::DefaultLogger();
+  auto pre_graph_checker = [](Graph& graph) -> Status {
+    auto op_count_pre = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_count_pre.size() == 3U);
+    TEST_RETURN_IF_NOT(op_count_pre["Add"] == 1);
+    TEST_RETURN_IF_NOT(op_count_pre["Slice"] == 1);
+    TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1);
+
+    return Status::OK();
+  };
+
+  auto post_graph_checker = [](Graph& graph) {
+    auto op_count_post = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_count_post.size() == 3U);
+    TEST_RETURN_IF_NOT(op_count_post["Add"] == 1);
+    TEST_RETURN_IF_NOT(op_count_post["Slice"] == 1);
+    TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1);
+
+    return Status::OK();
+  };
+
+  auto build_test_case = [](ModelTestBuilder& builder) {
+    auto* input1_arg = builder.MakeInput<float>({{4, 32, 256}});
+    auto* input2_arg = builder.MakeInput<float>({{4, 32, 256}});
+    auto* add_out = builder.MakeIntermediate();
+    builder.AddNode("Add", {input1_arg, input2_arg}, {add_out});
+
+    auto* starts_initializer = builder.MakeInitializer<int64_t>({2}, {0, 0});
+    auto* ends_initializer = builder.MakeInitializer<int64_t>({2}, {-1, -1});
+    auto* axes_initializer = builder.MakeInitializer<int64_t>({2}, {0, 1});
+    auto* steps_initializer = builder.MakeInitializer<int64_t>({2}, {1, 1});
+    auto* slice_out = builder.MakeIntermediate();
+    builder.AddNode("Slice", {add_out, starts_initializer, ends_initializer,
+                              axes_initializer, steps_initializer},
+                    {slice_out});
+
+    auto* identity_out = builder.MakeOutput();
+    builder.AddNode("Identity", {slice_out}, {identity_out});
+  };
+
+  std::unique_ptr<GraphTransformer> transformer = std::make_unique<UpStreamGatherGraphTransformer>();
+  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger, std::move(transformer),
+                                        TransformerLevel::Level1,
+                                        1, pre_graph_checker, post_graph_checker));
+}
+
 /*
 Test graph include multiple equivalent subgraphs as below.
            graph input [4, 32, 256] (int64_t)            graph input [4, 32, 256] (int64_t)