mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-27 03:11:28 +00:00
Work on minimizing memory management calls by reducing the number of allocations and copies. Replace std::unordered_set with InlinedHashSet and add usage of InlinedVector. Employ std::move() to minimize copying and memory allocations. Remove copying of the const shared data into each of the PropagateCast transformer instances. Move the inlined_containers.h header to include/common. Adjust the AsSpan implementation for C++ < 17.
128 lines
5 KiB
C++
128 lines
5 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include "core/optimizer/slice_elimination.h"
|
|
#include "core/graph/graph.h"
|
|
#include "core/graph/graph_utils.h"
|
|
#include "core/graph/op.h"
|
|
#include "core/optimizer/initializer.h"
|
|
|
|
namespace onnxruntime {
|
|
|
|
// Attempts to remove the given Slice node from the graph via graph_utils::RemoveNode.
// When the removal succeeds, rule_effect is set to kRemovedCurrentNode; otherwise it is left untouched.
// The logger argument is unused. Always returns Status::OK().
Status EliminateSlice::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
  const bool node_was_removed = graph_utils::RemoveNode(graph, node);

  // Only report an effect when the graph was actually modified.
  if (node_was_removed) {
    rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
  }

  return Status::OK();
}
|
|
|
|
// Decides whether the given Slice node can be eliminated.
//
// The node qualifies only if all of the following hold:
//  - it is a Slice of opset version 1, 10, 11 or 13 and graph_utils::CanRemoveNode accepts it;
//  - starts/ends are statically known (node attributes for opset 1, constant initializers for
//    opset >= 10) and have the same length;
//  - axes, when present, has the same length as starts/ends;
//  - steps, when present (opset >= 10), is a constant initializer of the same length that holds only 1s;
//  - every start is 0 and every end is INT64_MAX.
//
// Returns true if the node satisfies these conditions, false otherwise.
bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
  // Elimination is supported for Slice operators of opset versions 1, 10, 11 and 13.
  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11, 13})) {
    return false;
  }

  if (!graph_utils::CanRemoveNode(graph, node, logger)) {
    return false;
  }

  InlinedVector<int64_t> starts;
  InlinedVector<int64_t> ends;

  if (graph_utils::MatchesOpSinceVersion(node, {1})) {
    // If it is a Slice operator of opset version 1, starts/ends/axes are provided as node attributes.
    if (!graph_utils::GetRepeatedNodeAttributeValues(node, "starts", starts) ||
        !graph_utils::GetRepeatedNodeAttributeValues(node, "ends", ends) ||
        starts.size() != ends.size()) {
      return false;
    }

    // If there is an axes attribute, it has to be the same size as the starts and ends.
    InlinedVector<int64_t> axes;
    if (graph_utils::GetRepeatedNodeAttributeValues(node, "axes", axes) && (axes.size() != starts.size())) {
      return false;
    }
  } else if (graph_utils::MatchesOpSinceVersion(node, {10, 11, 13})) {
    // If it is a Slice operator of opset version >= 10, starts/ends/axes/steps are provided as node inputs.

    // Returns a pointer to the corresponding NodeArg if input of the node at this index exists; otherwise, a nullptr.
    auto get_input_if_exists = [&node](size_t input_idx) -> const NodeArg* {
      const auto& input_defs = node.InputDefs();
      const NodeArg* input = (input_defs.size() > input_idx) ? input_defs[input_idx] : nullptr;
      return (input == nullptr || !input->Exists()) ? nullptr : input;
    };

    // Returns a pointer to the initializer if it is constant; otherwise, a nullptr.
    auto get_initializer_if_constant =
        [&graph, get_input_if_exists](size_t input_idx) -> const ONNX_NAMESPACE::TensorProto* {
      const NodeArg* input = get_input_if_exists(input_idx);
      return input ? graph_utils::GetConstantInitializer(graph, input->Name()) : nullptr;
    };

    // Copies the contents of an INT32 or INT64 initializer into a vector of int64_t.
    // Any other data type yields an empty vector.
    auto get_initializer_data =
        [&graph](const ONNX_NAMESPACE::TensorProto* initializer) -> InlinedVector<int64_t> {
      Initializer init(*initializer, graph.ModelPath());
      if (initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT32) {
        int32_t* init_data = init.data<int32_t>();
        return InlinedVector<int64_t>(init_data, init_data + init.size());
      } else if (initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT64) {
        int64_t* init_data = init.data<int64_t>();
        return InlinedVector<int64_t>(init_data, init_data + init.size());
      }
      return {};
    };

    // Starts and ends inputs have to exist, be constant, and be of the same size.
    const ONNX_NAMESPACE::TensorProto* starts_init = get_initializer_if_constant(1);
    const ONNX_NAMESPACE::TensorProto* ends_init = get_initializer_if_constant(2);
    if (!starts_init || !ends_init) {
      // Without constant starts/ends we cannot tell statically whether the slice is a no-op.
      return false;
    }

    starts = get_initializer_data(starts_init);
    ends = get_initializer_data(ends_init);

    // An empty result also covers initializers of an unsupported data type.
    if (starts.empty() || ends.empty() || starts.size() != ends.size()) {
      return false;
    }

    // If axes input exists, it should be constant and of the same size as starts/ends.
    if (get_input_if_exists(3)) {
      const ONNX_NAMESPACE::TensorProto* axes_init = get_initializer_if_constant(3);
      if (!axes_init || axes_init->dims_size() != 1 ||
          static_cast<size_t>(axes_init->dims().Get(0)) != starts.size()) {
        return false;
      }
    }

    // If steps input exists, it should be constant and all values should be 1.
    // This is checked independently of axes: both inputs are optional, so steps may be supplied
    // while axes is omitted, and a step other than 1 means the slice is not a no-op.
    if (get_input_if_exists(4)) {
      const ONNX_NAMESPACE::TensorProto* steps_init = get_initializer_if_constant(4);
      if (!steps_init) {
        return false;
      }
      const InlinedVector<int64_t> steps = get_initializer_data(steps_init);
      if (steps.size() != starts.size()) {
        return false;
      }
      for (int64_t step : steps) {
        if (step != 1) {
          return false;
        }
      }
    }
  } else {
    // Should be unreachable, but just to be safe in case a new op version is added.
    return false;
  }

  // For now eliminate slice operators if starts=0 and ends=MAX_INT.
  // Note that ends read from an INT32 initializer can never reach INT64_MAX, so such nodes are kept.
  // TODO: Take into account the input's shape to get a tighter bound for the ends.
  for (size_t i = 0; i < starts.size(); ++i) {
    if (starts[i] != 0 || ends[i] < INT64_MAX) {
      return false;
    }
  }

  return true;
}
|
|
|
|
} // namespace onnxruntime
|