mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
init
This commit is contained in:
parent
981d95b81e
commit
9235ae569e
2 changed files with 188 additions and 0 deletions
146
onnxruntime/core/optimizer/map_to_four_dimension.cc
Normal file
146
onnxruntime/core/optimizer/map_to_four_dimension.cc
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "core/optimizer/map_to_four_dimension.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/optimizer/optimizer_execution_frame.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/framework/data_types.h"
|
||||
|
||||
using namespace onnxruntime::common;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
MapToFourDimensions::MapToFourDimensions(const IExecutionProvider& execution_provider,
|
||||
bool skip_dequantize_linear,
|
||||
bool dequantize_initializer_for_dequantize_linear,
|
||||
const ConfigOptions& config_options,
|
||||
const InlinedHashSet<std::string_view>& compatible_execution_providers,
|
||||
const InlinedHashSet<std::string>& excluded_initializers) noexcept
|
||||
: GraphTransformer("MapToFourDimensions", compatible_execution_providers),
|
||||
skip_dequantize_linear_(skip_dequantize_linear),
|
||||
dequantize_initializer_for_dequantize_linear_(dequantize_initializer_for_dequantize_linear),
|
||||
config_options_(config_options),
|
||||
excluded_initializers_(excluded_initializers),
|
||||
execution_provider_(execution_provider) {
|
||||
}
|
||||
|
||||
onnxruntime::NodeArg* AddSliceReduceConcatNodes(onnxruntime::Graph& graph,
|
||||
onnxruntime::Node& reshape,
|
||||
onnxruntime::Node& reduce_sum,
|
||||
onnxruntime::NodeArg* old_arg,
|
||||
ONNX_NAMESPACE::TypeProto* new_type,
|
||||
bool new_on_input,
|
||||
int64_t to_type,
|
||||
onnxruntime::ProviderType providerType) {
|
||||
// Insert 2 Slice nodes, 2 ReduceSum nodes and 1 Concat node.
|
||||
|
||||
// Create 2 Slice nodes
|
||||
std::string slice_node_0_name = graph.GenerateNodeName(reshape.Name() + "_slice_0");
|
||||
std::string slice_node_1_name = graph.GenerateNodeName(reshape.Name() + "_slice_1");
|
||||
|
||||
// The Slice node type should be the same as the Reshape node (going to be removed) type.
|
||||
auto* slice_node_0_arg = &graph.GetOrCreateNodeArg(slice_node_0_name, reshape.OutputDefs()[0]->TypeAsProto());
|
||||
auto* slice_node_1_arg = &graph.GetOrCreateNodeArg(slice_node_1_name, reshape.OutputDefs()[0]->TypeAsProto());
|
||||
|
||||
std::vector<onnxruntime::NodeArg*> slice_node_0_output_defs = {slice_node_0_arg};
|
||||
std::vector<onnxruntime::NodeArg*> slice_node_1_output_defs = {slice_node_1_arg};
|
||||
|
||||
auto& slice_node_0 = graph.AddNode(slice_node_0_name, "Slice", "Map 5D/6D to 4D",
|
||||
reshape.MutableInputDefs(), slice_node_0_output_defs);
|
||||
auto& slice_node_1 = graph.AddNode(slice_node_1_name, "Slice", "Map 5D/6D to 4D",
|
||||
reshape.MutableInputDefs(), slice_node_1_output_defs);
|
||||
|
||||
// Create 2 ReduceSum nodes
|
||||
std::string reduce_sum_node_0_name = graph.GenerateNodeName(reduce_sum.Name() + "_0");
|
||||
std::string reduce_sum_node_1_name = graph.GenerateNodeName(reduce_sum.Name() + "_1");
|
||||
|
||||
// The ReduceSum node type should be the same as the original ReduceSum node (going to be removed) type.
|
||||
auto* reduce_sum_node_0_arg = &graph.GetOrCreateNodeArg(reduce_sum_node_0_name, reduce_sum.OutputDefs()[0]->TypeAsProto());
|
||||
auto* reduce_sum_node_1_arg = &graph.GetOrCreateNodeArg(reduce_sum_node_1_name, reduce_sum.OutputDefs()[0]->TypeAsProto());
|
||||
|
||||
std::vector<onnxruntime::NodeArg*> reduce_sum_node_0_output_defs = {reduce_sum_node_0_arg};
|
||||
std::vector<onnxruntime::NodeArg*> reduce_sum_node_1_output_defs = {reduce_sum_node_1_arg};
|
||||
|
||||
auto& reduce_sum_node_0 = graph.AddNode(reduce_sum_node_0_name, "ReduceSum", "Map 5D/6D to 4D",
|
||||
slice_node_0_output_defs, reduce_sum_node_0_output_defs);
|
||||
auto& reduce_sum_node_1 = graph.AddNode(reduce_sum_node_1_name, "ReduceSum", "Map 5D/6D to 4D",
|
||||
slice_node_1_output_defs, reduce_sum_node_1_output_defs);
|
||||
|
||||
// Create 1 Concat
|
||||
std::string concat_node_name = graph.GenerateNodeName(reduce_sum_node_0_name + "_concat");
|
||||
auto* concat_node_arg = &graph.GetOrCreateNodeArg(concat_node_name, reduce_sum.OutputDefs()[0]->TypeAsProto());
|
||||
std::vector<onnxruntime::NodeArg*> concat_node_arg_input_defs = {reduce_sum_node_0_arg, reduce_sum_node_1_arg};
|
||||
std::vector<onnxruntime::NodeArg*> concat_node_arg_output_defs = {concat_node_arg};
|
||||
auto& concat_node = graph.AddNode(concat_node_name, "Concat", "Map 5D/6D to 4D",
|
||||
concat_node_arg_input_defs, concat_node_arg_output_defs);
|
||||
|
||||
return concat_node_arg;
|
||||
}
|
||||
|
||||
Status MapToFourDimensions::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
bool have_updated_nodes = false;
|
||||
GraphViewer graph_viewer(graph);
|
||||
auto& order = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
for (NodeIndex i : order) {
|
||||
auto* node = graph.GetNode(i);
|
||||
if (!node) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool map_tensor_to_4d = false;
|
||||
|
||||
// Requirements:
|
||||
// 1. Map 2D/3D to 4D. Replace 2D Gemms with Transpose/Reshape and 1x1 Conv.
|
||||
// 2. Map 5D/6D to 4D
|
||||
if (node->OpType() == "MatMul") {
|
||||
const auto* input_0 = node->InputDefs()[0];
|
||||
const auto* input_1 = node->InputDefs()[1];
|
||||
if ((input_0->Shape()->dim_size() == 2 || input_0->Shape()->dim_size() == 3) &&
|
||||
(input_1->Shape()->dim_size() == 2 || input_1->Shape()->dim_size() == 3)) {
|
||||
//map_tensor_to_4d = true;
|
||||
}
|
||||
} else if (node->OpType() == "ReduceSum") {
|
||||
// assume Reshape -> Q -> DQ -> ReduceSum since we don't remove Q/DQ for now
|
||||
// TODO: Make sure Reshape, Q and DQ does exist
|
||||
const Node& node_x = *node->InputNodesBegin(); // Q
|
||||
const Node& node_y = *node_x.InputNodesBegin(); // DQ
|
||||
const Node& node_z = *node_y.InputNodesBegin(); // Reshape
|
||||
if (node_z.OpType() == "Reshape") {
|
||||
const auto* output_0 = node_z.OutputDefs()[0];
|
||||
if (output_0->Shape()->dim_size() == 5) {
|
||||
map_tensor_to_4d = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!map_tensor_to_4d) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (node->OpType() == "MatMul") {
|
||||
const auto* input_0 = node->InputDefs()[0]; // X
|
||||
const auto* input_1 = node->InputDefs()[1]; // W
|
||||
} else if (node->OpType() == "ReduceSum") {
|
||||
// assume Reshape -> Q -> DQ -> ReduceSum since we don't remove Q/DQ for now
|
||||
// TODO: Make sure Reshape, Q and DQ does exist
|
||||
const Node& node_x = *node->InputNodesBegin(); // Q
|
||||
const Node& node_y = *node_x.InputNodesBegin(); // DQ
|
||||
const Node& node_z = *node_y.InputNodesBegin(); // Reshape
|
||||
|
||||
const auto* input_0 = node_z.InputDefs()[0];
|
||||
const auto* output_0 = node_z.OutputDefs()[0];
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
42
onnxruntime/core/optimizer/map_to_four_dimension.h
Normal file
42
onnxruntime/core/optimizer/map_to_four_dimension.h
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
#include "core/framework/ort_value.h"
|
||||
#include <memory>
|
||||
#include "core/framework/execution_provider.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
@class ConstantFolding
|
||||
|
||||
Transformer that traverses the graph top-down and performs constant folding, i.e.,
|
||||
it statically computes parts of the graph that rely only on constant initializers.
|
||||
*/
|
||||
class MapToFourDimensions : public GraphTransformer {
|
||||
public:
|
||||
/*! Constant folding will not be applied to nodes that have one of initializers from excluded_initializers as input.
|
||||
For pre-training, the trainable weights are those initializers to be excluded.
|
||||
\param execution_provider Execution provider instance to execute constant folding.
|
||||
*/
|
||||
MapToFourDimensions(const IExecutionProvider& execution_provider,
|
||||
bool skip_dequantize_linear,
|
||||
bool dequantize_initializer_for_dequantize_linear,
|
||||
const ConfigOptions& config_options,
|
||||
const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
|
||||
const InlinedHashSet<std::string>& excluded_initializers = {}) noexcept;
|
||||
|
||||
private:
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
|
||||
bool skip_dequantize_linear_;
|
||||
bool dequantize_initializer_for_dequantize_linear_;
|
||||
const ConfigOptions& config_options_;
|
||||
const InlinedHashSet<std::string> excluded_initializers_;
|
||||
const IExecutionProvider& execution_provider_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue