mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
Quantization support for split operator with its NHWC support (#6107)
* Make the Split operator work for quantization. * Add NHWC transformer support for the Split operator. * Refactor according to feedback; test cases will be added soon. * Fix build error on Windows. * Add test case for uint8_t support in the Split op. * Add nhwc_transformer_test for Split uint8_t support. * Further changes according to PR feedback.
This commit is contained in:
parent
6b73bae035
commit
f77ff1bc3d
8 changed files with 210 additions and 27 deletions
|
|
@ -592,6 +592,13 @@ size_t RemoveNodeOutputEdges(Graph& graph, Node& node) {
|
|||
return output_edges.size();
|
||||
}
|
||||
|
||||
/** Removes the edges that leave output `output_idx` of `node`.
@param graph The graph that owns `node`.
@param node The node whose output edges are removed.
@param output_idx Index of the output whose edges are removed.
@returns The number of edges collected for that output. */
size_t RemoveNodeOutputEdges(Graph& graph, Node& node, int output_idx) {
  // Snapshot the edges first so the count reflects what was found before removal.
  std::vector<GraphEdge> edges_for_output = GetNodeOutputEdges(node, output_idx);
  const size_t removed_count = edges_for_output.size();

  RemoveGraphEdges(graph, edges_for_output);

  return removed_count;
}
|
||||
|
||||
void ReplaceDownstreamNodeInput(Graph& graph, Node& node, int output_idx, Node& replacement, int replacement_output_idx) {
|
||||
// get the output edges from node for output_idx
|
||||
std::vector<GraphEdge> output_edges = GetNodeOutputEdges(node, output_idx);
|
||||
|
|
|
|||
|
|
@ -148,6 +148,9 @@ bool ReplaceNodeWithInitializer(Graph& graph, Node& node, NodeArg& replacement);
|
|||
This should probably be elevated to the Graph API eventually. */
|
||||
size_t RemoveNodeOutputEdges(Graph& graph, Node& node);
|
||||
|
||||
/** Removes output edges from the specific output_idx for the given Node of the Graph. */
|
||||
size_t RemoveNodeOutputEdges(Graph& graph, Node& node, int output_idx);
|
||||
|
||||
/** Replaces the input to nodes that are downstream from 'node', which was being provided by an output of 'node',
|
||||
with an output from a different node. Moves the output edges from 'node' for 'output_idx' to the replacement node.
|
||||
@param replacement The node providing the replacement output.
|
||||
|
|
|
|||
|
|
@ -40,7 +40,8 @@ class NhwcTransformerImpl {
|
|||
return (it != nhwc_args_.end()) ? it->second.get() : nullptr;
|
||||
}
|
||||
|
||||
size_t RemoveOutputEdges(Node& node);
|
||||
size_t RemoveOutputEdge(Node& node, size_t output_index);
|
||||
void CreateNhwcArgument(Node& node, Node& nhwc_node, int rank, size_t output_index);
|
||||
void CreateNhwcArgument(Node& node, Node& nhwc_node, int rank);
|
||||
void InsertReorderInput(Node& node, int rank);
|
||||
|
||||
|
|
@ -48,6 +49,7 @@ class NhwcTransformerImpl {
|
|||
void TransformQLinearBinary(Node& node);
|
||||
void TransformQLinearActivation(Node& node);
|
||||
void TransformQLinearGlobalAveragePool(Node& node);
|
||||
void TransformSplit(Node& node);
|
||||
|
||||
Graph& graph_;
|
||||
|
||||
|
|
@ -63,30 +65,40 @@ class NhwcTransformerImpl {
|
|||
std::deque<NodeIndex> removed_nodes_;
|
||||
};
|
||||
|
||||
size_t NhwcTransformerImpl::RemoveOutputEdges(Node& node) {
|
||||
size_t output_edges_count = node.GetOutputEdgesCount();
|
||||
if (output_edges_count > 0) {
|
||||
graph_utils::RemoveNodeOutputEdges(graph_, node);
|
||||
}
|
||||
// Bias the edge count to handle the case of a node that produces a graph
|
||||
// output.
|
||||
if (!graph_.GetNodeOutputsInGraphOutputs(node).empty()) {
|
||||
output_edges_count++;
|
||||
// Remove node's output edge starting from specified index, return number of edges removed.
|
||||
// If output at specified index for the node is graph output, inc the count returned.
|
||||
size_t NhwcTransformerImpl::RemoveOutputEdge(Node& node, size_t output_index) {
|
||||
size_t output_edges_count = graph_utils::RemoveNodeOutputEdges(graph_, node, static_cast<int>(output_index));
|
||||
|
||||
// Bias the edge count to if the node produces a graph output at output_index.
|
||||
auto node_outputs_for_graph = graph_.GetNodeOutputsInGraphOutputs(node);
|
||||
for (auto idx : node_outputs_for_graph) {
|
||||
if (idx == static_cast<int>(output_index)) {
|
||||
output_edges_count++;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return output_edges_count;
|
||||
}
|
||||
|
||||
void NhwcTransformerImpl::CreateNhwcArgument(Node& node, Node& nhwc_node, int rank) {
|
||||
size_t original_uses = RemoveOutputEdges(node);
|
||||
void NhwcTransformerImpl::CreateNhwcArgument(Node& node, Node& nhwc_node, int rank, size_t output_index) {
|
||||
size_t original_uses = RemoveOutputEdge(node, output_index);
|
||||
|
||||
// Create a new NodeArg to track the output from the NHWC node.
|
||||
auto& output_defs = nhwc_node.MutableOutputDefs();
|
||||
auto* output_original_arg = output_defs[0];
|
||||
auto* output_original_arg = output_defs[output_index];
|
||||
std::string output_reorder_def_name = graph_.GenerateNodeArgName("reorder");
|
||||
auto* output_nhwc_arg = &graph_.GetOrCreateNodeArg(output_reorder_def_name, nullptr);
|
||||
nhwc_args_[output_original_arg] =
|
||||
onnxruntime::make_unique<NhwcArgument>(nhwc_node, output_nhwc_arg, original_uses, rank);
|
||||
output_defs[0] = output_nhwc_arg;
|
||||
output_defs[output_index] = output_nhwc_arg;
|
||||
}
|
||||
|
||||
// Convenience overload: calls the per-output CreateNhwcArgument once for every
// output definition of `node`, in ascending output index order.
void NhwcTransformerImpl::CreateNhwcArgument(Node& node, Node& nhwc_node, int rank) {
  // The output count is read once, before any per-output call runs.
  const size_t num_outputs = node.OutputDefs().size();
  size_t current_output = 0;
  while (current_output < num_outputs) {
    CreateNhwcArgument(node, nhwc_node, rank, current_output);
    ++current_output;
  }
}
|
||||
|
||||
void NhwcTransformerImpl::InsertReorderInput(Node& node, int rank) {
|
||||
|
|
@ -237,6 +249,39 @@ void NhwcTransformerImpl::TransformQLinearGlobalAveragePool(Node& node) {
|
|||
CreateNhwcArgument(node, node, nhwc_input->rank_);
|
||||
}
|
||||
|
||||
// Rewrites a Split node whose first input has a tracked NHWC argument so that it
// consumes that NHWC argument directly. The "axis" attribute, when present as an
// int, is remapped from its NCHW position to the matching NHWC position, and each
// output of the node is registered as an NHWC argument.
void NhwcTransformerImpl::TransformSplit(Node& node) {
  auto& input_defs = node.MutableInputDefs();

  auto* nhwc_input = LookupNhwcArgument(input_defs[0]);
  if (nhwc_input == nullptr) {
    // No NHWC argument is tracked for the input: leave the node untouched.
    return;
  }

  const auto* axis_attr = graph_utils::GetNodeAttribute(node, "axis");
  if (axis_attr != nullptr && utils::HasInt(*axis_attr)) {
    const int64_t nchw_axis = axis_attr->i();

    // Leave the node untouched when the axis lies outside [-rank, rank).
    const bool axis_in_range = nchw_axis >= -nhwc_input->rank_ && nchw_axis < nhwc_input->rank_;
    if (!axis_in_range) {
      return;
    }

    // Normalize a negative axis to its non-negative equivalent.
    const int64_t positive_axis = (nchw_axis < 0) ? nchw_axis + nhwc_input->rank_ : nchw_axis;

    // Axis 0 stays in place, axis 1 becomes the last axis, and every
    // axis after 1 shifts down by one position.
    int64_t nhwc_axis = positive_axis;
    if (positive_axis == 1) {
      nhwc_axis = nhwc_input->rank_ - 1;
    } else if (positive_axis > 1) {
      nhwc_axis = positive_axis - 1;
    }

    node.AddAttribute("axis", nhwc_axis);
  }

  // Consume the NHWC argument instead of the original one.
  input_defs[0] = nhwc_input->nhwc_arg_;
  nhwc_input->remaining_original_uses_--;

  CreateNhwcArgument(node, node, nhwc_input->rank_);
}
|
||||
|
||||
void NhwcTransformerImpl::Transform(Node& node) {
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "QLinearConv", {10})) {
|
||||
TransformQLinearConv(node);
|
||||
|
|
@ -248,6 +293,8 @@ void NhwcTransformerImpl::Transform(Node& node) {
|
|||
TransformQLinearActivation(node);
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "QLinearGlobalAveragePool", {1}, kMSDomain)) {
|
||||
TransformQLinearGlobalAveragePool(node);
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Split", {2, 11, 13})) {
|
||||
TransformSplit(node);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
|||
DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>(),
|
||||
DataTypeImpl::GetTensorType<uint8_t>(),
|
||||
DataTypeImpl::GetTensorType<std::string>()}),
|
||||
Split);
|
||||
|
||||
|
|
@ -32,6 +33,7 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
|||
DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>(),
|
||||
DataTypeImpl::GetTensorType<uint8_t>(),
|
||||
DataTypeImpl::GetTensorType<std::string>()}),
|
||||
Split);
|
||||
|
||||
|
|
@ -44,6 +46,7 @@ ONNX_CPU_OPERATOR_KERNEL(
|
|||
DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>(),
|
||||
DataTypeImpl::GetTensorType<uint8_t>(),
|
||||
DataTypeImpl::GetTensorType<std::string>()}),
|
||||
Split);
|
||||
|
||||
|
|
@ -97,6 +100,8 @@ Status Split::Compute(OpKernelContext* context) const {
|
|||
status = ComputeImpl<int32_t>(*context, input);
|
||||
else if (input.IsDataType<int64_t>())
|
||||
status = ComputeImpl<int64_t>(*context, input);
|
||||
else if (input.IsDataType<uint8_t>())
|
||||
status = ComputeImpl<uint8_t>(*context, input);
|
||||
else if (input.IsDataTypeString())
|
||||
status = ComputeImpl<std::string>(*context, input);
|
||||
else
|
||||
|
|
|
|||
35
onnxruntime/python/tools/quantization/operators/split.py
Normal file
35
onnxruntime/python/tools/quantization/operators/split.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
import onnx
|
||||
from .base_operator import QuantOperatorBase
|
||||
from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg
|
||||
from onnx import onnx_pb as onnx_proto
|
||||
|
||||
|
||||
class QSplit(QuantOperatorBase):
    """Quantization handler for the Split operator.

    The first input is quantized; every output reuses that input's scale and
    zero point. Any inputs after the first are passed through to the new node
    unchanged.
    """

    def __init__(self, onnx_quantizer, onnx_node):
        super().__init__(onnx_quantizer, onnx_node)

    def quantize(self):
        """Emit a Split node that operates on the quantized first input.

        Appends the nodes produced by quantizing the input, followed by the new
        Split node, to ``self.quantizer.new_nodes`` and records a
        ``QuantizedValue`` for each output in
        ``self.quantizer.quantized_value_map``.
        """
        node = self.node
        quantized_input_names, zero_point_names, scale_names, nodes = self.quantizer.quantize_inputs(node, [0])

        # Always bind the name so an unnamed node does not raise UnboundLocalError below.
        quantized_node_name = ""
        if node.name != "":
            quantized_node_name = node.name + "_quant"

        kwargs = {}
        for attribute in node.attribute:
            kwargs.update(attribute_to_kwarg(attribute))

        # Outputs simply derive their scale/zero point from the input.
        quantized_output_names = []
        for output_name in node.output:
            quantized_output_name = output_name + "quantized"
            quantized_output_names.append(quantized_output_name)
            q_output = QuantizedValue(output_name, quantized_output_name, scale_names[0],
                                      zero_point_names[0], QuantizedValueType.Input)
            self.quantizer.quantized_value_map[output_name] = q_output

        # Forward any remaining inputs as-is. list.extend() returns None, so build
        # the combined list explicitly rather than assigning the result of extend().
        node_input_names = list(quantized_input_names)
        if len(node.input) > 1:
            node_input_names.extend(node.input[1:])

        quantized_node = onnx.helper.make_node(
            node.op_type, node_input_names, quantized_output_names, quantized_node_name, **kwargs)

        nodes.append(quantized_node)
        self.quantizer.new_nodes += nodes
|
||||
|
|
@ -9,7 +9,8 @@ from .operators.activation import QLinearActivation
|
|||
from .operators.binary_op import QLinearBinaryOp
|
||||
from .operators.maxpool import QMaxPool
|
||||
from .operators.gavgpool import QGlobalAveragePool
|
||||
from. operators.lstm import LSTMQuant
|
||||
from .operators.lstm import LSTMQuant
|
||||
from .operators.split import QSplit
|
||||
|
||||
CommonOpsRegistry = {"Gather": GatherQuant, "EmbedLayerNormalization": EmbedLayerNormalizationQuant}
|
||||
|
||||
|
|
@ -32,6 +33,7 @@ QLinearOpsRegistry = {
|
|||
"Sigmoid" : QLinearActivation,
|
||||
"MaxPool": QMaxPool,
|
||||
"GlobalAveragePool": QGlobalAveragePool,
|
||||
"Split" : QSplit,
|
||||
}
|
||||
QLinearOpsRegistry.update(CommonOpsRegistry)
|
||||
|
||||
|
|
|
|||
|
|
@ -419,6 +419,44 @@ TEST(NhwcTransformerTests, ConvGlobalAveragePool) {
|
|||
NhwcTransformerTester(build_test_case, check_nhwc_graph);
|
||||
}
|
||||
|
||||
// Builds QLinearConv -> Split -> QLinearAdd -> QLinearConv for every axis value in
// [-4, 4) and checks the op counts of the resulting graph.
TEST(NhwcTransformerTests, ConvSplit) {
  for (int64_t axis = -4LL; axis < 4; axis++) {
    auto build_test_case = [&, axis](NhwcTestHelper& helper) {
      auto* input_arg = helper.MakeInput<uint8_t>({2, 23, 16, 16});
      auto* conv_output_arg = helper.MakeIntermediate();
      auto* split_output1_arg = helper.MakeIntermediate();
      auto* split_output2_arg = helper.MakeIntermediate();
      auto* qladd_output_arg = helper.MakeIntermediate();
      auto* output_arg = helper.MakeOutput();

      const int64_t conv1_output_channels = 32;

      // First convolution produces the tensor that is split.
      Node& first_conv = helper.AddQLinearConvNode<uint8_t>(input_arg, .01f, 135,
                                                            {conv1_output_channels, 23, 3, 3}, .02f, 126,
                                                            conv_output_arg, .37f, 131);
      first_conv.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});

      Node& split = helper.AddNode("Split", {conv_output_arg}, {split_output1_arg, split_output2_arg});
      split.AddAttribute("axis", static_cast<int64_t>(axis));

      // Recombine both halves so that every Split output is consumed.
      helper.AddQLinearBinaryNode("QLinearAdd",
                                  split_output1_arg, .37f, 131,
                                  split_output2_arg, .37f, 131,
                                  qladd_output_arg, .43f, 126);

      // Splitting along the channel axis (1, or -3 for rank 4) halves the channel count.
      const bool splits_channel_axis = (axis == 1 || axis == -3);
      int64_t channels_after_split = conv1_output_channels;
      if (splits_channel_axis) {
        channels_after_split = conv1_output_channels / 2;
      }

      helper.AddQLinearConvNode<uint8_t>(qladd_output_arg, .43f, 126,
                                         {17, channels_after_split, 3, 3}, .02f, 126,
                                         output_arg, .37f, 131);
    };

    auto check_nhwc_graph = [&](InferenceSessionWrapper& session) {
      auto op_counts = CountOpsInGraph(session.GetGraph());
      EXPECT_EQ(op_counts["com.microsoft.QLinearConv"], 2);
      EXPECT_EQ(op_counts["Transpose"], 2);
    };

    NhwcTransformerTester(build_test_case, check_nhwc_graph);
  }
}
|
||||
|
||||
TEST(NhwcTransformerTests, ConvBlockActivation) {
|
||||
auto test_case = [&](uint32_t extra_edges) {
|
||||
auto build_test_case = [&](NhwcTestHelper& helper) {
|
||||
|
|
|
|||
|
|
@ -251,24 +251,23 @@ TEST(SplitOperatorTest, Axis1UnequalSplitString) {
|
|||
RunTest<std::string>(axis, splits, input, outputs, false);
|
||||
}
|
||||
|
||||
ShapeAndFloatData CreateInput(std::vector<int64_t> shape) {
|
||||
template <typename T>
|
||||
ShapeAndData<T> CreateInput(std::vector<int64_t> shape) {
|
||||
auto size = TensorShape(shape).Size();
|
||||
|
||||
float i = 0.f, increment = 1.f;
|
||||
// generate the elements for the data starting at 1.f
|
||||
std::vector<float> data;
|
||||
T i = static_cast<T>(0), increment = static_cast<T>(1);
|
||||
// generate the elements for the data starting at 1
|
||||
std::vector<T> data;
|
||||
std::generate_n(std::back_inserter(data), size, [&]() { return i += increment; });
|
||||
|
||||
ShapeAndFloatData input = {shape, data};
|
||||
|
||||
return input;
|
||||
return ShapeAndData<T>{shape, data};
|
||||
}
|
||||
|
||||
TEST(SplitOperatorTest, Axis2EqualSplit) {
|
||||
const int64_t axis = 2;
|
||||
std::vector<ShapeAndFloatData> outputs;
|
||||
|
||||
ShapeAndFloatData input = CreateInput({2, 2, 6});
|
||||
ShapeAndFloatData input = CreateInput<float>({2, 2, 6});
|
||||
|
||||
outputs.push_back({{2, 2, 2},
|
||||
{1.f, 2.f,
|
||||
|
|
@ -298,7 +297,7 @@ TEST(SplitOperatorTest, Axis2UnequalSplit) {
|
|||
const int64_t axis = 2;
|
||||
std::vector<ShapeAndFloatData> outputs;
|
||||
|
||||
ShapeAndFloatData input = CreateInput({2, 2, 6});
|
||||
ShapeAndFloatData input = CreateInput<float>({2, 2, 6});
|
||||
|
||||
std::vector<int64_t> splits{1, 2, 3};
|
||||
|
||||
|
|
@ -330,7 +329,7 @@ TEST(SplitOperatorTest, ZeroSizeInput) {
|
|||
const int64_t axis = -1;
|
||||
std::vector<ShapeAndFloatData> outputs{{{0, 1}, {}}, {{0, 1}, {}}};
|
||||
|
||||
ShapeAndFloatData input = CreateInput({0, 2});
|
||||
ShapeAndFloatData input = CreateInput<float>({0, 2});
|
||||
|
||||
RunTest<float>(axis, {}, input, outputs, false);
|
||||
}
|
||||
|
|
@ -340,7 +339,7 @@ TEST(SplitOperatorTest, Axis1SplitMiddleDimensionEqually) {
|
|||
const int64_t axis = 1;
|
||||
std::vector<ShapeAndFloatData> outputs;
|
||||
|
||||
ShapeAndFloatData input = CreateInput({2, 4, 4});
|
||||
ShapeAndFloatData input = CreateInput<float>({2, 4, 4});
|
||||
|
||||
outputs.push_back({{2, 2, 4},
|
||||
{1.f, 2.f, 3.f, 4.f,
|
||||
|
|
@ -364,7 +363,7 @@ TEST(SplitOperatorTest, Axis1SplitMiddleDimensionUnequally) {
|
|||
const int64_t axis = 1;
|
||||
std::vector<ShapeAndFloatData> outputs;
|
||||
|
||||
ShapeAndFloatData input = CreateInput({2, 4, 4});
|
||||
ShapeAndFloatData input = CreateInput<float>({2, 4, 4});
|
||||
|
||||
std::vector<int64_t> splits{1, 3};
|
||||
|
||||
|
|
@ -567,5 +566,52 @@ SplitAxis2()
|
|||
SplitMiddleDimension()
|
||||
|
||||
*/
|
||||
|
||||
// test split for uint8_t data that has leading and trailing dimensions
|
||||
TEST(SplitOperatorTest, Uint8Axis1SplitMiddleDimensionUnequally) {
|
||||
const int64_t axis = 1;
|
||||
std::vector<ShapeAndData<uint8_t>> outputs;
|
||||
|
||||
ShapeAndData<uint8_t> input = CreateInput<uint8_t>({2, 4, 4});
|
||||
|
||||
std::vector<int64_t> splits{1, 3};
|
||||
|
||||
outputs.push_back({{2, 1, 4},
|
||||
{1, 2, 3, 4,
|
||||
|
||||
17, 18, 19, 20}});
|
||||
|
||||
outputs.push_back({{2, 3, 4},
|
||||
{5, 6, 7, 8,
|
||||
9, 10, 11, 12,
|
||||
13, 14, 15, 16,
|
||||
|
||||
21, 22, 23, 24,
|
||||
25, 26, 27, 28,
|
||||
29, 30, 31, 32}});
|
||||
|
||||
RunTest<uint8_t>(axis, splits, input, outputs, false);
|
||||
}
|
||||
|
||||
// test split for uint8_t data on the last axis equally
|
||||
TEST(SplitOperatorTest, Uint8NegativeAxis) {
|
||||
const int64_t axis = -1;
|
||||
std::vector<ShapeAndData<uint8_t>> outputs;
|
||||
|
||||
ShapeAndData<uint8_t> input = {{2, 4},
|
||||
{1, 2, 3, 4,
|
||||
5, 6, 7, 8}};
|
||||
|
||||
outputs.push_back({{2, 2},
|
||||
{1, 2,
|
||||
5, 6}});
|
||||
|
||||
outputs.push_back({{2, 2},
|
||||
{3, 4,
|
||||
7, 8}});
|
||||
|
||||
RunTest<uint8_t>(axis, {}, input, outputs, false);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue