Fix in Slice Elimination (issue #885) (#918)

Slice elimination should not be triggered when starts or ends is negative; small fix in op set domain validation. Fixes issue #885.
This commit is contained in:
Konstantinos Karanasos 2019-04-26 21:04:29 -07:00 committed by GitHub
parent 90544ed766
commit 7a42ffd15f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 40 additions and 22 deletions

View file

@ -265,17 +265,24 @@ const std::string& GetNodeOutputName(const Node& node, int index) {
return outputs[index]->Name();
}
// fusion is only done for ONNX domain ops
bool IsSupportedOptypeVersionAndDomain(const Node& node,
const std::string& op_type,
ONNX_NAMESPACE::OperatorSetVersion version,
const std::string& domain) {
if (node.OpType() != op_type ||
node.Op()->Deprecated() || node.Op()->SinceVersion() != version ||
(!node.Domain().empty() && node.Domain() != domain)) {
return false;
}
return true;
return (node.OpType() == op_type && !node.Op()->Deprecated() &&
MatchesOpSinceVersion(node, version) && MatchesOpSetDomain(node, domain));
}
bool MatchesOpSinceVersion(const Node& node, ONNX_NAMESPACE::OperatorSetVersion version) {
return node.Op()->SinceVersion() == version;
}
bool MatchesOpSetDomain(const Node& node, const std::string& domain) {
const auto& node_domain = node.Domain();
// We do a special check for the ONNX domain, as it has two aliases.
return node_domain == domain ||
((node_domain == kOnnxDomain || node_domain == kOnnxDomainAlias) &&
(domain == kOnnxDomain || domain == kOnnxDomainAlias));
}
bool IsSupportedProvider(const Node& node,

View file

@ -10,18 +10,24 @@ namespace onnxruntime {
namespace graph_utils {
/** Checks if the operator's type, version, and domain of the given node match the given values. */
bool IsSupportedOptypeVersionAndDomain(const Node& node,
const std::string& op_type,
ONNX_NAMESPACE::OperatorSetVersion version,
const std::string& domain = kOnnxDomainAlias);
/* Returns true if the execution provider assigned to current node is present in the compatible providers list
* or if the compatible_providers list is empty
*/
/** Checks if the node has the same operator since version as the given one. */
bool MatchesOpSinceVersion(const Node& node, ONNX_NAMESPACE::OperatorSetVersion version);
/** Checks if the node has the same op set domain as the given one. */
bool MatchesOpSetDomain(const Node& node, const std::string& domain);
/** Returns true if the execution provider assigned to current node is present in the compatible providers list
or if the compatible_providers list is empty. */
bool IsSupportedProvider(const Node& node,
const std::unordered_set<std::string>& compatible_providers);
/** Check whether the node has a single input and a single output. The single input can be either the output of
/** Checks whether the node has a single input and a single output. The single input can be either the output of
another node or an initializer, but not an implicit input from a parent subgraph. The single output can be
fed to multiple downstream operators, i.e., it can have multiple output edges. */
bool IsSingleInSingleOutNode(const Node& node);
@ -32,16 +38,16 @@ bool IsGraphInput(const Graph& graph, const NodeArg* input);
/** Checks if the given node has only constant inputs (initializers). */
bool AllNodeInputsAreConstant(const Graph& graph, const Node& node);
/** Get the name of the incoming NodeArg with the specified index for the given node. */
/** Gets the name of the incoming NodeArg with the specified index for the given node. */
const std::string& GetNodeInputName(const Node& node, int index);
/** Get the name of the outgoing NodeArg with the specified index for the given node. */
/** Gets the name of the outgoing NodeArg with the specified index for the given node. */
const std::string& GetNodeOutputName(const Node& node, int index);
/** Return the attribute of a Node with a given name. */
/** Returns the attribute of a Node with a given name. */
const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name);
/** Retrieve the values for a repeated attribute of a node and place them to the values vector. */
/** Retrieves the values for a repeated attribute of a node and place them to the values vector. */
template <typename T>
bool GetRepeatedNodeAttributeValues(const Node& node,
const std::string& attr_name,
@ -58,12 +64,12 @@ bool GetRepeatedNodeAttributeValues(const Node& node,
Status ForAllMutableSubgraphs(Graph& main_graph, std::function<Status(Graph&)> func);
Status ForAllSubgraphs(const Graph& main_graph, std::function<Status(const Graph&)> func);
/** Remove the given single-input Node from the Graph. The single input might be either
/** Removes the given single-input Node from the Graph. The single input might be either
another node or an initializer, but not an implicit input. The node should have a single
output but can have multiple output edges. */
bool RemoveSingleInputNode(Graph& graph, Node& node);
/** Remove all output edges from the given Node of the Graph.
/** Removes all output edges from the given Node of the Graph.
This should probably be elevated to the Graph API eventually. */
size_t RemoveNodeOutputEdges(Graph& graph, Node& node);

View file

@ -17,6 +17,12 @@ Status EliminateSlice::Apply(Graph& graph, Node& node, bool& modified, bool& rem
}
bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node) {
// We currently support elimination for Slice operator v1.
// TODO Extend to support Slice operator v10, which includes "steps" and all attributes are now given as inputs.
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", 1)) {
return false;
}
if (!graph_utils::IsSingleInSingleOutNode(node) ||
graph.IsNodeOutputsInGraphOutputs(node)) {
return false;
@ -34,15 +40,14 @@ bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node) {
for (int i = 0; (size_t)i < starts.size(); ++i) {
axes.push_back(i);
}
} else if (axes.size() != starts.size() || axes.size() != ends.size()) {
} else if (axes.size() != starts.size()) {
return false;
}
// For now eliminate slice operators if starts=0 and ends=MAX_INT or -1.
// For now eliminate slice operators if starts=0 and ends=MAX_INT.
// TODO: Take into account the input's shape to get a tighter bound for the ends.
for (size_t i = 0; i < axes.size(); ++i) {
if (starts[i] > 0 || starts[i] < 0 ||
(ends[i] > 0 && ends[i] < INT64_MAX)) {
if (starts[i] != 0 || ends[i] < INT64_MAX) {
return false;
}
}

View file

@ -77,7 +77,7 @@ TEST(GraphTransformationTests, SliceElimination) {
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK());
op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Slice"] == 3);
ASSERT_TRUE(op_to_count["Slice"] == 4);
}
TEST(GraphTransformationTests, ConstantFolding1) {