mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
cleanup NCHWc transformer (#8479)
This commit is contained in:
parent
3850755feb
commit
7d47175f76
1 changed files with 49 additions and 63 deletions
|
|
@ -93,7 +93,7 @@ class NchwcTransformerImpl {
|
|||
// Stores the proto shape for the NCHWc output.
|
||||
const NchwcArgument::Shape shape_;
|
||||
|
||||
NchwcArgument(Node& output_node, NodeArg* output_nchwc_arg, size_t original_uses, size_t channels, const NchwcArgument::Shape& shape)
|
||||
NchwcArgument(Node& output_node, NodeArg* output_nchwc_arg, size_t original_uses, int64_t channels, const NchwcArgument::Shape& shape)
|
||||
: output_node_(output_node),
|
||||
nchwc_arg_(output_nchwc_arg),
|
||||
starting_original_uses_(original_uses),
|
||||
|
|
@ -103,8 +103,13 @@ class NchwcTransformerImpl {
|
|||
}
|
||||
};
|
||||
|
||||
// Looks up `arg` in nchwc_args_ and returns the raw pointer to the stored
// NchwcArgument, or nullptr when `arg` has no entry in the map. The map keeps
// ownership of the entry; the returned pointer is non-owning.
NchwcArgument* LookupNchwcArgument(NodeArg* arg) {
|
||||
auto it = nchwc_args_.find(arg);
|
||||
return (it != nchwc_args_.end()) ? it->second.get() : nullptr;
|
||||
}
|
||||
|
||||
size_t RemoveOutputEdges(Node& node);
|
||||
void CreateNchwcArgument(Node& node, Node& nchwc_node, size_t channels, const NchwcArgument::Shape& shape);
|
||||
void CreateNchwcArgument(Node& node, Node& nchwc_node, int64_t channels, const NchwcArgument::Shape& shape);
|
||||
void FuseNchwcArgument(Node& node, const NchwcArgument& nchwc_arg);
|
||||
void InsertReorderInput(Node& node);
|
||||
|
||||
|
|
@ -112,7 +117,7 @@ class NchwcTransformerImpl {
|
|||
const NchwcArgument::Shape& input_shape,
|
||||
NchwcArgument::Shape& output_shape,
|
||||
const ONNX_NAMESPACE::TensorProto* filter_shape);
|
||||
Node& InsertReshape(NodeArg* input_arg, NodeArg* output_arg, int64_t channels, bool split_channels);
|
||||
Node& InsertReshape(NodeArg* input_arg, NodeArg* output_arg, bool split_channels);
|
||||
|
||||
void TransformConv(Node& node);
|
||||
void TransformPool(Node& node);
|
||||
|
|
@ -146,10 +151,10 @@ class NchwcTransformerImpl {
|
|||
// NCHWc block size, so multiple nodes can share the NCHWc biases.
|
||||
std::unordered_map<NodeArg*, NodeArg*> aligned_biases_;
|
||||
|
||||
// Stores a mapping of shape initializers for use by Reshape when splitting
|
||||
// or unsplitting the channels dimension of a tensor.
|
||||
std::unordered_map<int64_t, NodeArg*> reshape_split_;
|
||||
std::unordered_map<int64_t, NodeArg*> reshape_unsplit_;
|
||||
// Stores the shape initializers for Reshape to split or unsplit the channels
|
||||
// dimension of a tensor.
|
||||
NodeArg* reshape_split_arg_{nullptr};
|
||||
NodeArg* reshape_unsplit_arg_{nullptr};
|
||||
|
||||
// Tracks the last Transpose node and output NodeArg that transposed from
|
||||
// NHWC to NCHW format.
|
||||
|
|
@ -172,7 +177,7 @@ size_t NchwcTransformerImpl::RemoveOutputEdges(Node& node) {
|
|||
|
||||
void NchwcTransformerImpl::CreateNchwcArgument(Node& node,
|
||||
Node& nchwc_node,
|
||||
size_t channels,
|
||||
int64_t channels,
|
||||
const NchwcArgument::Shape& shape) {
|
||||
size_t original_uses = RemoveOutputEdges(node);
|
||||
|
||||
|
|
@ -194,7 +199,7 @@ void NchwcTransformerImpl::FuseNchwcArgument(Node& node, const NchwcArgument& nc
|
|||
auto& nchwc_node = nchwc_arg.output_node_;
|
||||
auto* output_nchwc_arg = nchwc_node.MutableOutputDefs()[0];
|
||||
nchwc_args_[output_original_arg] =
|
||||
std::make_unique<NchwcArgument>(nchwc_node, output_nchwc_arg, original_uses, static_cast<size_t>(nchwc_arg.channels_), nchwc_arg.shape_);
|
||||
std::make_unique<NchwcArgument>(nchwc_node, output_nchwc_arg, original_uses, nchwc_arg.channels_, nchwc_arg.shape_);
|
||||
}
|
||||
|
||||
void NchwcTransformerImpl::InsertReorderInput(Node& node) {
|
||||
|
|
@ -484,18 +489,17 @@ void NchwcTransformerImpl::TransformConv(Node& node) {
|
|||
NchwcArgument::Shape output_shape(output_defs[0]);
|
||||
|
||||
if (do_reorder_input) {
|
||||
auto it = nchwc_args_.find(input_defs[0]);
|
||||
if (it == nchwc_args_.end()) {
|
||||
auto* nchwc_input = LookupNchwcArgument(input_defs[0]);
|
||||
if (nchwc_input == nullptr) {
|
||||
InsertReorderInput(nchwc_node);
|
||||
} else {
|
||||
auto* nchwc_input = it->second.get();
|
||||
nchwc_node.MutableInputDefs()[0] = nchwc_input->nchwc_arg_;
|
||||
nchwc_input->remaining_original_uses_--;
|
||||
ConvPoolShapeInference(node, nchwc_input->shape_, output_shape, conv_W_tensor_proto);
|
||||
}
|
||||
}
|
||||
|
||||
CreateNchwcArgument(node, nchwc_node, static_cast<size_t>(output_channels), output_shape);
|
||||
CreateNchwcArgument(node, nchwc_node, output_channels, output_shape);
|
||||
removed_nodes_.push_front(node.Index());
|
||||
}
|
||||
|
||||
|
|
@ -540,54 +544,40 @@ void NchwcTransformerImpl::TransformPool(Node& node) {
|
|||
|
||||
NchwcArgument::Shape output_shape(output_defs[0]);
|
||||
|
||||
auto it = nchwc_args_.find(input_defs[0]);
|
||||
if (it == nchwc_args_.end()) {
|
||||
auto* nchwc_input = LookupNchwcArgument(input_defs[0]);
|
||||
if (nchwc_input == nullptr) {
|
||||
InsertReorderInput(nchwc_node);
|
||||
} else {
|
||||
auto* nchwc_input = it->second.get();
|
||||
nchwc_node.MutableInputDefs()[0] = nchwc_input->nchwc_arg_;
|
||||
nchwc_input->remaining_original_uses_--;
|
||||
ConvPoolShapeInference(node, nchwc_input->shape_, output_shape, nullptr);
|
||||
}
|
||||
|
||||
CreateNchwcArgument(node, nchwc_node, static_cast<size_t>(channels), output_shape);
|
||||
CreateNchwcArgument(node, nchwc_node, channels, output_shape);
|
||||
removed_nodes_.push_front(node.Index());
|
||||
}
|
||||
|
||||
Node& NchwcTransformerImpl::InsertReshape(NodeArg* input_arg,
|
||||
NodeArg* output_arg,
|
||||
int64_t channels,
|
||||
bool split_channels) {
|
||||
const int64_t nchwc_block_size = static_cast<int64_t>(MlasNchwcGetBlockSize());
|
||||
const int64_t nchwc_channels = (channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);
|
||||
|
||||
// Reuse the shape initializer across reshapes for the same channel configuration.
|
||||
auto& shape_arg_map = split_channels ? reshape_split_ : reshape_unsplit_;
|
||||
NodeArg* shape_arg = shape_arg_map[nchwc_channels];
|
||||
auto& shape_arg = split_channels ? reshape_split_arg_ : reshape_unsplit_arg_;
|
||||
if (shape_arg == nullptr) {
|
||||
ONNX_NAMESPACE::TensorProto shape_tensor_proto;
|
||||
shape_tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
shape_tensor_proto.set_name(graph_.GenerateNodeArgName("Reshape"));
|
||||
// Passthrough the batch dimension.
|
||||
// Passthrough the batch and spatial dimensions. Compute the channel count based
|
||||
// on the remaining tensor size and whether this is a split or unsplit.
|
||||
shape_tensor_proto.add_int64_data(0);
|
||||
if (split_channels) {
|
||||
shape_tensor_proto.add_int64_data(nchwc_channels / nchwc_block_size);
|
||||
} else {
|
||||
shape_tensor_proto.add_int64_data(nchwc_channels);
|
||||
}
|
||||
// Passthrough the spatial dimensions.
|
||||
shape_tensor_proto.add_int64_data(-1);
|
||||
for (int i = 0; i < kNchwcSpatialDims; i++) {
|
||||
shape_tensor_proto.add_int64_data(0);
|
||||
}
|
||||
if (split_channels) {
|
||||
shape_tensor_proto.add_int64_data(nchwc_block_size);
|
||||
shape_tensor_proto.add_dims(kNchwcDims + 1);
|
||||
} else {
|
||||
shape_tensor_proto.add_dims(kNchwcDims);
|
||||
shape_tensor_proto.add_int64_data(static_cast<int64_t>(MlasNchwcGetBlockSize()));
|
||||
}
|
||||
shape_tensor_proto.add_dims(split_channels ? kNchwcDims + 1 : kNchwcDims);
|
||||
|
||||
shape_arg = &graph_utils::AddInitializer(graph_, shape_tensor_proto);
|
||||
shape_arg_map[nchwc_channels] = shape_arg;
|
||||
}
|
||||
|
||||
Node& reshape_node = graph_.AddNode(graph_.GenerateNodeName("Reshape"),
|
||||
|
|
@ -609,11 +599,11 @@ void NchwcTransformerImpl::TransformBinary(Node& node, bool add_node) {
|
|||
size_t input_defs_count = input_defs.size();
|
||||
nchwc_inputs.reserve(input_defs_count);
|
||||
for (size_t i = 0; i < input_defs_count; i++) {
|
||||
auto it = nchwc_args_.find(input_defs[i]);
|
||||
if (it == nchwc_args_.end()) {
|
||||
auto* nchwc_input = LookupNchwcArgument(input_defs[i]);
|
||||
if (nchwc_input == nullptr) {
|
||||
return;
|
||||
}
|
||||
nchwc_inputs.push_back(it->second.get());
|
||||
nchwc_inputs.push_back(nchwc_input);
|
||||
}
|
||||
|
||||
auto* nchwc_input_0 = nchwc_inputs[0];
|
||||
|
|
@ -697,7 +687,7 @@ void NchwcTransformerImpl::TransformBinary(Node& node, bool add_node) {
|
|||
}
|
||||
}
|
||||
|
||||
CreateNchwcArgument(node, node, static_cast<size_t>(nchwc_input_0->channels_), nchwc_input_0->shape_);
|
||||
CreateNchwcArgument(node, node, nchwc_input_0->channels_, nchwc_input_0->shape_);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -719,7 +709,7 @@ void NchwcTransformerImpl::TransformBinary(Node& node, bool add_node) {
|
|||
for (size_t n = 0; n < input_defs_count; n++) {
|
||||
std::string reshape_input_def_name = graph_.GenerateNodeArgName("reshape");
|
||||
auto* reshape_input_arg = &graph_.GetOrCreateNodeArg(reshape_input_def_name, nullptr);
|
||||
InsertReshape(nchwc_inputs[n]->nchwc_arg_, reshape_input_arg, channels, true);
|
||||
InsertReshape(nchwc_inputs[n]->nchwc_arg_, reshape_input_arg, true);
|
||||
|
||||
input_defs[n] = reshape_input_arg;
|
||||
nchwc_inputs[n]->remaining_original_uses_--;
|
||||
|
|
@ -727,11 +717,11 @@ void NchwcTransformerImpl::TransformBinary(Node& node, bool add_node) {
|
|||
|
||||
std::string output_reshaped_def_name = graph_.GenerateNodeArgName("reshape");
|
||||
auto* output_reshaped_arg = &graph_.GetOrCreateNodeArg(output_reshaped_def_name, nullptr);
|
||||
Node& nchwc_node = InsertReshape(output_reshaped_arg, output_defs[0], channels, false);
|
||||
Node& nchwc_node = InsertReshape(output_reshaped_arg, output_defs[0], false);
|
||||
|
||||
NchwcArgument::Shape output_shape(output_defs[0]);
|
||||
|
||||
CreateNchwcArgument(node, nchwc_node, static_cast<size_t>(channels), output_shape);
|
||||
CreateNchwcArgument(node, nchwc_node, channels, output_shape);
|
||||
output_defs[0] = output_reshaped_arg;
|
||||
return;
|
||||
}
|
||||
|
|
@ -755,17 +745,17 @@ void NchwcTransformerImpl::TransformConcat(Node& node) {
|
|||
nchwc_inputs.reserve(input_defs_count);
|
||||
int64_t total_channels = 0;
|
||||
for (size_t i = 0; i < input_defs_count; i++) {
|
||||
auto it = nchwc_args_.find(input_defs[i]);
|
||||
if (it == nchwc_args_.end()) {
|
||||
auto* nchwc_input = LookupNchwcArgument(input_defs[i]);
|
||||
if (nchwc_input == nullptr) {
|
||||
return;
|
||||
}
|
||||
// Verify that the logical number of channels is block aligned.
|
||||
int64_t input_channels = it->second->channels_;
|
||||
int64_t input_channels = nchwc_input->channels_;
|
||||
if ((input_channels % nchwc_block_size) != 0) {
|
||||
return;
|
||||
}
|
||||
total_channels += input_channels;
|
||||
nchwc_inputs.push_back(it->second.get());
|
||||
nchwc_inputs.push_back(nchwc_input);
|
||||
}
|
||||
|
||||
// Update the node to use the NCHWc inputs directly and decrement
|
||||
|
|
@ -780,7 +770,7 @@ void NchwcTransformerImpl::TransformConcat(Node& node) {
|
|||
NchwcArgument::Shape output_shape = nchwc_inputs[0]->shape_;
|
||||
output_shape.dims_[1] = output_defs[0];
|
||||
|
||||
CreateNchwcArgument(node, node, static_cast<size_t>(total_channels), output_shape);
|
||||
CreateNchwcArgument(node, node, total_channels, output_shape);
|
||||
}
|
||||
|
||||
// After doing a Conv/Add fusion, there may be an activation node that could now
|
||||
|
|
@ -789,9 +779,8 @@ void NchwcTransformerImpl::TransformConcat(Node& node) {
|
|||
void NchwcTransformerImpl::TransformActivation(Node& node) {
|
||||
auto& input_defs = node.MutableInputDefs();
|
||||
|
||||
auto it = nchwc_args_.find(input_defs[0]);
|
||||
if (it != nchwc_args_.end()) {
|
||||
auto& nchwc_input = it->second;
|
||||
auto* nchwc_input = LookupNchwcArgument(input_defs[0]);
|
||||
if (nchwc_input != nullptr) {
|
||||
input_defs[0] = nchwc_input->nchwc_arg_;
|
||||
nchwc_input->remaining_original_uses_--;
|
||||
|
||||
|
|
@ -805,7 +794,7 @@ void NchwcTransformerImpl::TransformActivation(Node& node) {
|
|||
FuseNchwcArgument(node, *nchwc_input);
|
||||
removed_nodes_.push_front(node.Index());
|
||||
} else {
|
||||
CreateNchwcArgument(node, node, static_cast<size_t>(nchwc_input->channels_), nchwc_input->shape_);
|
||||
CreateNchwcArgument(node, node, nchwc_input->channels_, nchwc_input->shape_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -823,11 +812,10 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) {
|
|||
}
|
||||
|
||||
// Don't transform the node if the input is not already in NCHWc format.
|
||||
auto it = nchwc_args_.find(input_defs[0]);
|
||||
if (it == nchwc_args_.end()) {
|
||||
auto* nchwc_input = LookupNchwcArgument(input_defs[0]);
|
||||
if (nchwc_input == nullptr) {
|
||||
return;
|
||||
}
|
||||
auto* nchwc_input = it->second.get();
|
||||
|
||||
// Require that BatchNormalization-7 uses spatial normalization.
|
||||
const auto* spatial_attr = graph_utils::GetNodeAttribute(node, "spatial");
|
||||
|
|
@ -926,7 +914,7 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) {
|
|||
|
||||
nchwc_input->remaining_original_uses_--;
|
||||
|
||||
CreateNchwcArgument(node, nchwc_node, static_cast<size_t>(channels), nchwc_input->shape_);
|
||||
CreateNchwcArgument(node, nchwc_node, channels, nchwc_input->shape_);
|
||||
removed_nodes_.push_front(node.Index());
|
||||
}
|
||||
|
||||
|
|
@ -935,11 +923,10 @@ void NchwcTransformerImpl::TransformTransposeToNhwc(Node& node) {
|
|||
auto& output_defs = node.MutableOutputDefs();
|
||||
|
||||
// Don't transform the node if the input is not already in NCHWc format.
|
||||
auto it = nchwc_args_.find(input_defs[0]);
|
||||
if (it == nchwc_args_.end()) {
|
||||
auto* nchwc_input = LookupNchwcArgument(input_defs[0]);
|
||||
if (nchwc_input == nullptr) {
|
||||
return;
|
||||
}
|
||||
auto* nchwc_input = it->second.get();
|
||||
|
||||
const auto* perm_attr = graph_utils::GetNodeAttribute(node, "perm");
|
||||
if (perm_attr == nullptr || perm_attr->ints_size() != 4) {
|
||||
|
|
@ -976,11 +963,10 @@ void NchwcTransformerImpl::TransformResize(Node& node) {
|
|||
auto& output_defs = node.MutableOutputDefs();
|
||||
|
||||
// Don't transform the node if the input is not already in NCHWc format.
|
||||
auto it = nchwc_args_.find(input_defs[0]);
|
||||
if (it == nchwc_args_.end()) {
|
||||
auto* nchwc_input = LookupNchwcArgument(input_defs[0]);
|
||||
if (nchwc_input == nullptr) {
|
||||
return;
|
||||
}
|
||||
auto* nchwc_input = it->second.get();
|
||||
|
||||
// Support nearest (default) and linear modes.
|
||||
const auto* mode_attr = graph_utils::GetNodeAttribute(node, "mode");
|
||||
|
|
@ -1127,7 +1113,7 @@ void NchwcTransformerImpl::TransformResize(Node& node) {
|
|||
|
||||
NchwcArgument::Shape output_shape(output_defs[0]);
|
||||
|
||||
CreateNchwcArgument(node, nchwc_node, static_cast<size_t>(nchwc_input->channels_), output_shape);
|
||||
CreateNchwcArgument(node, nchwc_node, nchwc_input->channels_, output_shape);
|
||||
removed_nodes_.push_front(node.Index());
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue