diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index f7a694d9f0..923a6e49dc 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -592,6 +592,13 @@ size_t RemoveNodeOutputEdges(Graph& graph, Node& node) { return output_edges.size(); } +size_t RemoveNodeOutputEdges(Graph& graph, Node& node, int output_idx) { + std::vector output_edges = GetNodeOutputEdges(node, output_idx); + RemoveGraphEdges(graph, output_edges); + + return output_edges.size(); +} + void ReplaceDownstreamNodeInput(Graph& graph, Node& node, int output_idx, Node& replacement, int replacement_output_idx) { // get the output edges from node for output_idx std::vector output_edges = GetNodeOutputEdges(node, output_idx); diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h index f0fbc98cc0..edd1768da7 100644 --- a/onnxruntime/core/graph/graph_utils.h +++ b/onnxruntime/core/graph/graph_utils.h @@ -148,6 +148,9 @@ bool ReplaceNodeWithInitializer(Graph& graph, Node& node, NodeArg& replacement); This should probably be elevated to the Graph API eventually. */ size_t RemoveNodeOutputEdges(Graph& graph, Node& node); +/** Removes output edges from the specific output_idx for the given Node of the Graph. */ +size_t RemoveNodeOutputEdges(Graph& graph, Node& node, int output_idx); + /** Replaces the input to nodes that are downstream from 'node', which was being provided by an output of 'node', with an output from a different node. Moves the output edges from 'node' for 'output_idx' to the replacement node. @param replacement The node providing the replacement output. diff --git a/onnxruntime/core/optimizer/nhwc_transformer.cc b/onnxruntime/core/optimizer/nhwc_transformer.cc index c1245335d4..67952b0e03 100644 --- a/onnxruntime/core/optimizer/nhwc_transformer.cc +++ b/onnxruntime/core/optimizer/nhwc_transformer.cc @@ -40,7 +40,8 @@ class NhwcTransformerImpl { return (it != nhwc_args_.end()) ? it->second.get() : nullptr; } - size_t RemoveOutputEdges(Node& node); + size_t RemoveOutputEdge(Node& node, size_t output_index); + void CreateNhwcArgument(Node& node, Node& nhwc_node, int rank, size_t output_index); void CreateNhwcArgument(Node& node, Node& nhwc_node, int rank); void InsertReorderInput(Node& node, int rank); @@ -48,6 +49,7 @@ class NhwcTransformerImpl { void TransformQLinearBinary(Node& node); void TransformQLinearActivation(Node& node); void TransformQLinearGlobalAveragePool(Node& node); + void TransformSplit(Node& node); Graph& graph_; @@ -63,30 +65,40 @@ class NhwcTransformerImpl { std::deque removed_nodes_; }; -size_t NhwcTransformerImpl::RemoveOutputEdges(Node& node) { - size_t output_edges_count = node.GetOutputEdgesCount(); - if (output_edges_count > 0) { - graph_utils::RemoveNodeOutputEdges(graph_, node); - } - // Bias the edge count to handle the case of a node that produces a graph - // output. - if (!graph_.GetNodeOutputsInGraphOutputs(node).empty()) { - output_edges_count++; +// Remove node's output edge starting from specified index, return number of edges removed. +// If output at specified index for the node is graph output, inc the count returned. +size_t NhwcTransformerImpl::RemoveOutputEdge(Node& node, size_t output_index) { + size_t output_edges_count = graph_utils::RemoveNodeOutputEdges(graph_, node, static_cast(output_index)); + + // Bias the edge count to if the node produces a graph output at output_index. + auto node_outputs_for_graph = graph_.GetNodeOutputsInGraphOutputs(node); + for (auto idx : node_outputs_for_graph) { + if (idx == static_cast(output_index)) { + output_edges_count++; + break; + } } return output_edges_count; } -void NhwcTransformerImpl::CreateNhwcArgument(Node& node, Node& nhwc_node, int rank) { - size_t original_uses = RemoveOutputEdges(node); +void NhwcTransformerImpl::CreateNhwcArgument(Node& node, Node& nhwc_node, int rank, size_t output_index) { + size_t original_uses = RemoveOutputEdge(node, output_index); // Create a new NodeArg to track the output from the NHWC node. auto& output_defs = nhwc_node.MutableOutputDefs(); - auto* output_original_arg = output_defs[0]; + auto* output_original_arg = output_defs[output_index]; std::string output_reorder_def_name = graph_.GenerateNodeArgName("reorder"); auto* output_nhwc_arg = &graph_.GetOrCreateNodeArg(output_reorder_def_name, nullptr); nhwc_args_[output_original_arg] = onnxruntime::make_unique(nhwc_node, output_nhwc_arg, original_uses, rank); - output_defs[0] = output_nhwc_arg; + output_defs[output_index] = output_nhwc_arg; +} + +void NhwcTransformerImpl::CreateNhwcArgument(Node& node, Node& nhwc_node, int rank) { + size_t output_count = node.OutputDefs().size(); + for (size_t output_index = 0; output_index < output_count; ++output_index) { + CreateNhwcArgument(node, nhwc_node, rank, output_index); + } } void NhwcTransformerImpl::InsertReorderInput(Node& node, int rank) { @@ -237,6 +249,39 @@ void NhwcTransformerImpl::TransformQLinearGlobalAveragePool(Node& node) { CreateNhwcArgument(node, node, nhwc_input->rank_); } +void NhwcTransformerImpl::TransformSplit(Node& node) { + auto& input_defs = node.MutableInputDefs(); + + auto* nhwc_input = LookupNhwcArgument(input_defs[0]); + if (nhwc_input == nullptr) { + return; + } + + // Change the axis attribute accordingly for NCHW to NHWC model. + const auto* axis_attr = graph_utils::GetNodeAttribute(node, "axis"); + if (axis_attr != nullptr && utils::HasInt(*axis_attr)) { + int64_t axis = axis_attr->i(); + if (axis < -nhwc_input->rank_ || axis >= nhwc_input->rank_) { + // direct return on invalid axis + return; + } + if (axis < 0) { + axis = axis + nhwc_input->rank_; + } + if (axis == 1) { + axis = nhwc_input->rank_ - 1; + } else if (axis > 1) { + axis = axis - 1; + } + node.AddAttribute("axis", axis); + } + + input_defs[0] = nhwc_input->nhwc_arg_; + nhwc_input->remaining_original_uses_--; + + CreateNhwcArgument(node, node, nhwc_input->rank_); +} + void NhwcTransformerImpl::Transform(Node& node) { if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "QLinearConv", {10})) { TransformQLinearConv(node); @@ -248,6 +293,8 @@ void NhwcTransformerImpl::Transform(Node& node) { TransformQLinearActivation(node); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "QLinearGlobalAveragePool", {1}, kMSDomain)) { TransformQLinearGlobalAveragePool(node); + } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Split", {2, 11, 13})) { + TransformSplit(node); } } diff --git a/onnxruntime/core/providers/cpu/tensor/split.cc b/onnxruntime/core/providers/cpu/tensor/split.cc index 0d3b7ca54c..72f0df3138 100644 --- a/onnxruntime/core/providers/cpu/tensor/split.cc +++ b/onnxruntime/core/providers/cpu/tensor/split.cc @@ -19,6 +19,7 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Split); @@ -32,6 +33,7 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Split); @@ -44,6 +46,7 @@ ONNX_CPU_OPERATOR_KERNEL( DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Split); @@ -97,6 +100,8 @@ Status Split::Compute(OpKernelContext* context) const { status = ComputeImpl(*context, input); else if (input.IsDataType()) status = ComputeImpl(*context, input); + else if (input.IsDataType()) + status = ComputeImpl(*context, input); else if (input.IsDataTypeString()) status = ComputeImpl(*context, input); else diff --git a/onnxruntime/python/tools/quantization/operators/split.py b/onnxruntime/python/tools/quantization/operators/split.py new file mode 100644 index 0000000000..f02033e753 --- /dev/null +++ b/onnxruntime/python/tools/quantization/operators/split.py @@ -0,0 +1,35 @@ +import onnx +from .base_operator import QuantOperatorBase +from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg +from onnx import onnx_pb as onnx_proto + + +class QSplit(QuantOperatorBase): + def __init__(self, onnx_quantizer, onnx_node): + super().__init__(onnx_quantizer, onnx_node) + + def quantize(self): + node = self.node + quantized_input_names, zero_point_names, scale_names, nodes = self.quantizer.quantize_inputs(node, [0]) + if node.name != "": + quantized_node_name = node.name + "_quant" + kwargs = {} + for attribute in node.attribute: + kwargs.update(attribute_to_kwarg(attribute)) + + # Output just derive the scale/zero from input + quantized_output_names = [] + for output_name in node.output: + quantized_output_name = output_name + "quantized" + quantized_output_names.append(quantized_output_name) + q_output = QuantizedValue(output_name, quantized_output_name, scale_names[0], + zero_point_names[0], QuantizedValueType.Input) + self.quantizer.quantized_value_map[output_name] = q_output + + if len(node.input) > 1: + quantized_input_names = quantized_input_names.extend(node.input[1:]) + quantized_node = onnx.helper.make_node( + node.op_type, quantized_input_names, quantized_output_names, quantized_node_name, **kwargs) + + nodes.append(quantized_node) + self.quantizer.new_nodes += nodes diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py index 3cec776ecf..5cb7ad7ad1 100644 --- a/onnxruntime/python/tools/quantization/registry.py +++ b/onnxruntime/python/tools/quantization/registry.py @@ -9,7 +9,8 @@ from .operators.activation import QLinearActivation from .operators.binary_op import QLinearBinaryOp from .operators.maxpool import QMaxPool from .operators.gavgpool import QGlobalAveragePool -from. operators.lstm import LSTMQuant +from .operators.lstm import LSTMQuant +from .operators.split import QSplit CommonOpsRegistry = {"Gather": GatherQuant, "EmbedLayerNormalization": EmbedLayerNormalizationQuant} @@ -32,6 +33,7 @@ QLinearOpsRegistry = { "Sigmoid" : QLinearActivation, "MaxPool": QMaxPool, "GlobalAveragePool": QGlobalAveragePool, + "Split" : QSplit, } QLinearOpsRegistry.update(CommonOpsRegistry) diff --git a/onnxruntime/test/optimizer/nhwc_transformer_test.cc b/onnxruntime/test/optimizer/nhwc_transformer_test.cc index 52fab4c06e..e0f19626e5 100644 --- a/onnxruntime/test/optimizer/nhwc_transformer_test.cc +++ b/onnxruntime/test/optimizer/nhwc_transformer_test.cc @@ -419,6 +419,44 @@ TEST(NhwcTransformerTests, ConvGlobalAveragePool) { NhwcTransformerTester(build_test_case, check_nhwc_graph); } +TEST(NhwcTransformerTests, ConvSplit) { + for (int64_t axis = -4LL; axis < 4; axis++) { + auto build_test_case = [&, axis](NhwcTestHelper& helper) { + auto* input_arg = helper.MakeInput({2, 23, 16, 16}); + auto* conv_output_arg = helper.MakeIntermediate(); + auto* split_output1_arg = helper.MakeIntermediate(); + auto* split_output2_arg = helper.MakeIntermediate(); + auto* qladd_output_arg = helper.MakeIntermediate(); + auto* output_arg = helper.MakeOutput(); + + const int64_t conv1_output_channels = 32; + Node& conv_node = helper.AddQLinearConvNode(input_arg, .01f, 135, + {conv1_output_channels, 23, 3, 3}, .02f, 126, + conv_output_arg, .37f, 131); + conv_node.AddAttribute("pads", std::vector{1, 1, 1, 1}); + Node& split_node = helper.AddNode("Split", {conv_output_arg}, {split_output1_arg, split_output2_arg}); + split_node.AddAttribute("axis", static_cast(axis)); + helper.AddQLinearBinaryNode("QLinearAdd", + split_output1_arg, .37f, 131, + split_output2_arg, .37f, 131, + qladd_output_arg, .43f, 126); + const int64_t channels_after_split = + (axis == 1 || axis == -3) ? conv1_output_channels / 2 : conv1_output_channels; + helper.AddQLinearConvNode(qladd_output_arg, .43f, 126, + {17, channels_after_split, 3, 3}, .02f, 126, + output_arg, .37f, 131); + }; + + auto check_nhwc_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["com.microsoft.QLinearConv"], 2); + EXPECT_EQ(op_to_count["Transpose"], 2); + }; + + NhwcTransformerTester(build_test_case, check_nhwc_graph); + } +} + TEST(NhwcTransformerTests, ConvBlockActivation) { auto test_case = [&](uint32_t extra_edges) { auto build_test_case = [&](NhwcTestHelper& helper) { diff --git a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc index da52288b29..a402b13fd6 100644 --- a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc @@ -251,24 +251,23 @@ TEST(SplitOperatorTest, Axis1UnequalSplitString) { RunTest(axis, splits, input, outputs, false); } -ShapeAndFloatData CreateInput(std::vector shape) { +template +ShapeAndData CreateInput(std::vector shape) { auto size = TensorShape(shape).Size(); - float i = 0.f, increment = 1.f; - // generate the elements for the data starting at 1.f - std::vector data; + T i = static_cast(0), increment = static_cast(1); + // generate the elements for the data starting at 1 + std::vector data; std::generate_n(std::back_inserter(data), size, [&]() { return i += increment; }); - ShapeAndFloatData input = {shape, data}; - - return input; + return ShapeAndData{shape, data}; } TEST(SplitOperatorTest, Axis2EqualSplit) { const int64_t axis = 2; std::vector outputs; - ShapeAndFloatData input = CreateInput({2, 2, 6}); + ShapeAndFloatData input = CreateInput({2, 2, 6}); outputs.push_back({{2, 2, 2}, {1.f, 2.f, @@ -298,7 +297,7 @@ TEST(SplitOperatorTest, Axis2UnequalSplit) { const int64_t axis = 2; std::vector outputs; - ShapeAndFloatData input = CreateInput({2, 2, 6}); + ShapeAndFloatData input = CreateInput({2, 2, 6}); std::vector splits{1, 2, 3}; @@ -330,7 +329,7 @@ TEST(SplitOperatorTest, ZeroSizeInput) { const int64_t axis = -1; std::vector outputs{{{0, 1}, {}}, {{0, 1}, {}}}; - ShapeAndFloatData input = CreateInput({0, 2}); + ShapeAndFloatData input = CreateInput({0, 2}); RunTest(axis, {}, input, outputs, false); } @@ -340,7 +339,7 @@ TEST(SplitOperatorTest, Axis1SplitMiddleDimensionEqually) { const int64_t axis = 1; std::vector outputs; - ShapeAndFloatData input = CreateInput({2, 4, 4}); + ShapeAndFloatData input = CreateInput({2, 4, 4}); outputs.push_back({{2, 2, 4}, {1.f, 2.f, 3.f, 4.f, @@ -364,7 +363,7 @@ TEST(SplitOperatorTest, Axis1SplitMiddleDimensionUnequally) { const int64_t axis = 1; std::vector outputs; - ShapeAndFloatData input = CreateInput({2, 4, 4}); + ShapeAndFloatData input = CreateInput({2, 4, 4}); std::vector splits{1, 3}; @@ -567,5 +566,52 @@ SplitAxis2() SplitMiddleDimension() */ + +// test split for uint8_t data that has leading and trailing dimensions +TEST(SplitOperatorTest, Uint8Axis1SplitMiddleDimensionUnequally) { + const int64_t axis = 1; + std::vector> outputs; + + ShapeAndData input = CreateInput({2, 4, 4}); + + std::vector splits{1, 3}; + + outputs.push_back({{2, 1, 4}, + {1, 2, 3, 4, + + 17, 18, 19, 20}}); + + outputs.push_back({{2, 3, 4}, + {5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + + 21, 22, 23, 24, + 25, 26, 27, 28, + 29, 30, 31, 32}}); + + RunTest(axis, splits, input, outputs, false); +} + +// test split for uint8_t data on the last axis equally +TEST(SplitOperatorTest, Uint8NegativeAxis) { + const int64_t axis = -1; + std::vector> outputs; + + ShapeAndData input = {{2, 4}, + {1, 2, 3, 4, + 5, 6, 7, 8}}; + + outputs.push_back({{2, 2}, + {1, 2, + 5, 6}}); + + outputs.push_back({{2, 2}, + {3, 4, + 7, 8}}); + + RunTest(axis, {}, input, outputs, false); +} + } // namespace test } // namespace onnxruntime