mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
Slice elimination should not be triggered when starts or ends is negative; small fix in op set domain validation. Fixes issue #885.
This commit is contained in:
parent
90544ed766
commit
7a42ffd15f
4 changed files with 40 additions and 22 deletions
|
|
@ -265,17 +265,24 @@ const std::string& GetNodeOutputName(const Node& node, int index) {
|
|||
return outputs[index]->Name();
|
||||
}
|
||||
|
||||
// fusion is only done for ONNX domain ops
|
||||
bool IsSupportedOptypeVersionAndDomain(const Node& node,
|
||||
const std::string& op_type,
|
||||
ONNX_NAMESPACE::OperatorSetVersion version,
|
||||
const std::string& domain) {
|
||||
if (node.OpType() != op_type ||
|
||||
node.Op()->Deprecated() || node.Op()->SinceVersion() != version ||
|
||||
(!node.Domain().empty() && node.Domain() != domain)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
return (node.OpType() == op_type && !node.Op()->Deprecated() &&
|
||||
MatchesOpSinceVersion(node, version) && MatchesOpSetDomain(node, domain));
|
||||
}
|
||||
|
||||
bool MatchesOpSinceVersion(const Node& node, ONNX_NAMESPACE::OperatorSetVersion version) {
|
||||
return node.Op()->SinceVersion() == version;
|
||||
}
|
||||
|
||||
bool MatchesOpSetDomain(const Node& node, const std::string& domain) {
|
||||
const auto& node_domain = node.Domain();
|
||||
// We do a special check for the ONNX domain, as it has two aliases.
|
||||
return node_domain == domain ||
|
||||
((node_domain == kOnnxDomain || node_domain == kOnnxDomainAlias) &&
|
||||
(domain == kOnnxDomain || domain == kOnnxDomainAlias));
|
||||
}
|
||||
|
||||
bool IsSupportedProvider(const Node& node,
|
||||
|
|
|
|||
|
|
@ -10,18 +10,24 @@ namespace onnxruntime {
|
|||
|
||||
namespace graph_utils {
|
||||
|
||||
/** Checks if the operator's type, version, and domain of the given node match the given values. */
|
||||
bool IsSupportedOptypeVersionAndDomain(const Node& node,
|
||||
const std::string& op_type,
|
||||
ONNX_NAMESPACE::OperatorSetVersion version,
|
||||
const std::string& domain = kOnnxDomainAlias);
|
||||
|
||||
/* Returns true if the execution provider assigned to current node is present in the compatible providers list
|
||||
* or if the compatible_providers list is empty
|
||||
*/
|
||||
/** Checks if the node has the same operator since version as the given one. */
|
||||
bool MatchesOpSinceVersion(const Node& node, ONNX_NAMESPACE::OperatorSetVersion version);
|
||||
|
||||
/** Checks if the node has the same op set domain as the given one. */
|
||||
bool MatchesOpSetDomain(const Node& node, const std::string& domain);
|
||||
|
||||
/** Returns true if the execution provider assigned to current node is present in the compatible providers list
|
||||
or if the compatible_providers list is empty. */
|
||||
bool IsSupportedProvider(const Node& node,
|
||||
const std::unordered_set<std::string>& compatible_providers);
|
||||
|
||||
/** Check whether the node has a single input and a single output. The single input can be either the output of
|
||||
/** Checks whether the node has a single input and a single output. The single input can be either the output of
|
||||
another node or an initializer, but not an implicit input from a parent subgraph. The single output can be
|
||||
fed to multiple downstream operators, i.e., it can have multiple output edges. */
|
||||
bool IsSingleInSingleOutNode(const Node& node);
|
||||
|
|
@ -32,16 +38,16 @@ bool IsGraphInput(const Graph& graph, const NodeArg* input);
|
|||
/** Checks if the given node has only constant inputs (initializers). */
|
||||
bool AllNodeInputsAreConstant(const Graph& graph, const Node& node);
|
||||
|
||||
/** Get the name of the incoming NodeArg with the specified index for the given node. */
|
||||
/** Gets the name of the incoming NodeArg with the specified index for the given node. */
|
||||
const std::string& GetNodeInputName(const Node& node, int index);
|
||||
|
||||
/** Get the name of the outgoing NodeArg with the specified index for the given node. */
|
||||
/** Gets the name of the outgoing NodeArg with the specified index for the given node. */
|
||||
const std::string& GetNodeOutputName(const Node& node, int index);
|
||||
|
||||
/** Return the attribute of a Node with a given name. */
|
||||
/** Returns the attribute of a Node with a given name. */
|
||||
const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name);
|
||||
|
||||
/** Retrieve the values for a repeated attribute of a node and place them to the values vector. */
|
||||
/** Retrieves the values for a repeated attribute of a node and place them to the values vector. */
|
||||
template <typename T>
|
||||
bool GetRepeatedNodeAttributeValues(const Node& node,
|
||||
const std::string& attr_name,
|
||||
|
|
@ -58,12 +64,12 @@ bool GetRepeatedNodeAttributeValues(const Node& node,
|
|||
Status ForAllMutableSubgraphs(Graph& main_graph, std::function<Status(Graph&)> func);
|
||||
Status ForAllSubgraphs(const Graph& main_graph, std::function<Status(const Graph&)> func);
|
||||
|
||||
/** Remove the given single-input Node from the Graph. The single input might be either
|
||||
/** Removes the given single-input Node from the Graph. The single input might be either
|
||||
another node or an initializer, but not an implicit input. The node should have a single
|
||||
output but can have multiple output edges. */
|
||||
bool RemoveSingleInputNode(Graph& graph, Node& node);
|
||||
|
||||
/** Remove all output edges from the given Node of the Graph.
|
||||
/** Removes all output edges from the given Node of the Graph.
|
||||
This should probably be elevated to the Graph API eventually. */
|
||||
size_t RemoveNodeOutputEdges(Graph& graph, Node& node);
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,12 @@ Status EliminateSlice::Apply(Graph& graph, Node& node, bool& modified, bool& rem
|
|||
}
|
||||
|
||||
bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node) {
|
||||
// We currently support elimination for Slice operator v1.
|
||||
// TODO Extend to support Slice operator v10, which includes "steps" and all attributes are now given as inputs.
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", 1)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!graph_utils::IsSingleInSingleOutNode(node) ||
|
||||
graph.IsNodeOutputsInGraphOutputs(node)) {
|
||||
return false;
|
||||
|
|
@ -34,15 +40,14 @@ bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node) {
|
|||
for (int i = 0; (size_t)i < starts.size(); ++i) {
|
||||
axes.push_back(i);
|
||||
}
|
||||
} else if (axes.size() != starts.size() || axes.size() != ends.size()) {
|
||||
} else if (axes.size() != starts.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// For now eliminate slice operators if starts=0 and ends=MAX_INT or -1.
|
||||
// For now eliminate slice operators if starts=0 and ends=MAX_INT.
|
||||
// TODO: Take into account the input's shape to get a tighter bound for the ends.
|
||||
for (size_t i = 0; i < axes.size(); ++i) {
|
||||
if (starts[i] > 0 || starts[i] < 0 ||
|
||||
(ends[i] > 0 && ends[i] < INT64_MAX)) {
|
||||
if (starts[i] != 0 || ends[i] < INT64_MAX) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ TEST(GraphTransformationTests, SliceElimination) {
|
|||
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK());
|
||||
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Slice"] == 3);
|
||||
ASSERT_TRUE(op_to_count["Slice"] == 4);
|
||||
}
|
||||
|
||||
TEST(GraphTransformationTests, ConstantFolding1) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue