From 7a42ffd15fab88ed1337359358ff3703f4cdf82e Mon Sep 17 00:00:00 2001 From: Konstantinos Karanasos Date: Fri, 26 Apr 2019 21:04:29 -0700 Subject: [PATCH] Fix in Slice Elimination (issue #885) (#918) Slice elimination should not be triggered when starts or ends is negative; small fix in op set domain validation. Fixes issue #885. --- onnxruntime/core/graph/graph_utils.cc | 21 ++++++++++----- onnxruntime/core/graph/graph_utils.h | 26 ++++++++++++------- .../core/optimizer/slice_elimination.cc | 13 +++++++--- .../test/optimizer/graph_transform_test.cc | 2 +- 4 files changed, 40 insertions(+), 22 deletions(-) diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index f82962174d..269e981dd9 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -265,17 +265,24 @@ const std::string& GetNodeOutputName(const Node& node, int index) { return outputs[index]->Name(); } -// fusion is only done for ONNX domain ops bool IsSupportedOptypeVersionAndDomain(const Node& node, const std::string& op_type, ONNX_NAMESPACE::OperatorSetVersion version, const std::string& domain) { - if (node.OpType() != op_type || - node.Op()->Deprecated() || node.Op()->SinceVersion() != version || - (!node.Domain().empty() && node.Domain() != domain)) { - return false; - } - return true; + return (node.OpType() == op_type && !node.Op()->Deprecated() && + MatchesOpSinceVersion(node, version) && MatchesOpSetDomain(node, domain)); +} + +bool MatchesOpSinceVersion(const Node& node, ONNX_NAMESPACE::OperatorSetVersion version) { + return node.Op()->SinceVersion() == version; +} + +bool MatchesOpSetDomain(const Node& node, const std::string& domain) { + const auto& node_domain = node.Domain(); + // We do a special check for the ONNX domain, as it has two aliases. + return node_domain == domain || + ((node_domain == kOnnxDomain || node_domain == kOnnxDomainAlias) && + (domain == kOnnxDomain || domain == kOnnxDomainAlias)); } bool IsSupportedProvider(const Node& node, diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h index 15d0e19465..1f3593b013 100644 --- a/onnxruntime/core/graph/graph_utils.h +++ b/onnxruntime/core/graph/graph_utils.h @@ -10,18 +10,24 @@ namespace onnxruntime { namespace graph_utils { +/** Checks if the operator's type, version, and domain of the given node match the given values. */ bool IsSupportedOptypeVersionAndDomain(const Node& node, const std::string& op_type, ONNX_NAMESPACE::OperatorSetVersion version, const std::string& domain = kOnnxDomainAlias); -/* Returns true if the execution provider assigned to current node is present in the compatible providers list - * or if the compatible_providers list is empty - */ +/** Checks if the node has the same operator since version as the given one. */ +bool MatchesOpSinceVersion(const Node& node, ONNX_NAMESPACE::OperatorSetVersion version); + +/** Checks if the node has the same op set domain as the given one. */ +bool MatchesOpSetDomain(const Node& node, const std::string& domain); + +/** Returns true if the execution provider assigned to current node is present in the compatible providers list + or if the compatible_providers list is empty. */ bool IsSupportedProvider(const Node& node, const std::unordered_set& compatible_providers); -/** Check whether the node has a single input and a single output. The single input can be either the output of +/** Checks whether the node has a single input and a single output. The single input can be either the output of another node or an initializer, but not an implicit input from a parent subgraph. The single output can be fed to multiple downstream operators, i.e., it can have multiple output edges. */ bool IsSingleInSingleOutNode(const Node& node); @@ -32,16 +38,16 @@ bool IsGraphInput(const Graph& graph, const NodeArg* input); /** Checks if the given node has only constant inputs (initializers). */ bool AllNodeInputsAreConstant(const Graph& graph, const Node& node); -/** Get the name of the incoming NodeArg with the specified index for the given node. */ +/** Gets the name of the incoming NodeArg with the specified index for the given node. */ const std::string& GetNodeInputName(const Node& node, int index); -/** Get the name of the outgoing NodeArg with the specified index for the given node. */ +/** Gets the name of the outgoing NodeArg with the specified index for the given node. */ const std::string& GetNodeOutputName(const Node& node, int index); -/** Return the attribute of a Node with a given name. */ +/** Returns the attribute of a Node with a given name. */ const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name); -/** Retrieve the values for a repeated attribute of a node and place them to the values vector. */ +/** Retrieves the values for a repeated attribute of a node and place them to the values vector. */ template bool GetRepeatedNodeAttributeValues(const Node& node, const std::string& attr_name, @@ -58,12 +64,12 @@ bool GetRepeatedNodeAttributeValues(const Node& node, Status ForAllMutableSubgraphs(Graph& main_graph, std::function func); Status ForAllSubgraphs(const Graph& main_graph, std::function func); -/** Remove the given single-input Node from the Graph. The single input might be either +/** Removes the given single-input Node from the Graph. The single input might be either another node or an initializer, but not an implicit input. The node should have a single output but can have multiple output edges. */ bool RemoveSingleInputNode(Graph& graph, Node& node); -/** Remove all output edges from the given Node of the Graph. +/** Removes all output edges from the given Node of the Graph. This should probably be elevated to the Graph API eventually. */ size_t RemoveNodeOutputEdges(Graph& graph, Node& node); diff --git a/onnxruntime/core/optimizer/slice_elimination.cc b/onnxruntime/core/optimizer/slice_elimination.cc index b7a195b571..f3ab618618 100644 --- a/onnxruntime/core/optimizer/slice_elimination.cc +++ b/onnxruntime/core/optimizer/slice_elimination.cc @@ -17,6 +17,12 @@ Status EliminateSlice::Apply(Graph& graph, Node& node, bool& modified, bool& rem } bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node) { + // We currently support elimination for Slice operator v1. + // TODO Extend to support Slice operator v10, which includes "steps" and all attributes are now given as inputs. + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", 1)) { + return false; + } + if (!graph_utils::IsSingleInSingleOutNode(node) || graph.IsNodeOutputsInGraphOutputs(node)) { return false; @@ -34,15 +40,14 @@ bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node) { for (int i = 0; (size_t)i < starts.size(); ++i) { axes.push_back(i); } - } else if (axes.size() != starts.size() || axes.size() != ends.size()) { + } else if (axes.size() != starts.size()) { return false; } - // For now eliminate slice operators if starts=0 and ends=MAX_INT or -1. + // For now eliminate slice operators if starts=0 and ends=MAX_INT. // TODO: Take into account the input's shape to get a tighter bound for the ends. for (size_t i = 0; i < axes.size(); ++i) { - if (starts[i] > 0 || starts[i] < 0 || - (ends[i] > 0 && ends[i] < INT64_MAX)) { + if (starts[i] != 0 || ends[i] < INT64_MAX) { return false; } } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index b3edbf106b..039a4bac94 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -77,7 +77,7 @@ TEST(GraphTransformationTests, SliceElimination) { ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["Slice"] == 3); + ASSERT_TRUE(op_to_count["Slice"] == 4); } TEST(GraphTransformationTests, ConstantFolding1) {