Quantization support for split operator with its NHWC support (#6107)

* Make the Split operator work with quantization.

* NHWC transformer support for split operator

* Refactor according to review feedback. Test cases will be added soon.

* Fix build error on windows.

* Add test case for split op on uint8_t support

* Add nhwc_transformer_test for split uint8_t support

* Make some changes according to PR feedback.
This commit is contained in:
Zhang Lei 2021-01-13 10:05:34 -08:00 committed by GitHub
parent 6b73bae035
commit f77ff1bc3d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 210 additions and 27 deletions

View file

@ -592,6 +592,13 @@ size_t RemoveNodeOutputEdges(Graph& graph, Node& node) {
return output_edges.size();
}
// Removes the edges that leave `node` through its output `output_idx` and
// returns how many edges were collected for that output.
size_t RemoveNodeOutputEdges(Graph& graph, Node& node, int output_idx) {
  auto edges_for_output = GetNodeOutputEdges(node, output_idx);
  RemoveGraphEdges(graph, edges_for_output);
  return edges_for_output.size();
}
void ReplaceDownstreamNodeInput(Graph& graph, Node& node, int output_idx, Node& replacement, int replacement_output_idx) {
// get the output edges from node for output_idx
std::vector<GraphEdge> output_edges = GetNodeOutputEdges(node, output_idx);

View file

@ -148,6 +148,9 @@ bool ReplaceNodeWithInitializer(Graph& graph, Node& node, NodeArg& replacement);
This should probably be elevated to the Graph API eventually. */
size_t RemoveNodeOutputEdges(Graph& graph, Node& node);
/** Removes the output edges that leave the given Node of the Graph through one specific output.
@param output_idx Index of the node output whose edges are removed; edges of other outputs are left alone.
@returns The number of edges that were found for that output. */
size_t RemoveNodeOutputEdges(Graph& graph, Node& node, int output_idx);
/** Replaces the input to nodes that are downstream from 'node', which was being provided by an output of 'node',
with an output from a different node. Moves the output edges from 'node' for 'output_idx' to the replacement node.
@param replacement The node providing the replacement output.

View file

@ -40,7 +40,8 @@ class NhwcTransformerImpl {
return (it != nhwc_args_.end()) ? it->second.get() : nullptr;
}
size_t RemoveOutputEdges(Node& node);
size_t RemoveOutputEdge(Node& node, size_t output_index);
void CreateNhwcArgument(Node& node, Node& nhwc_node, int rank, size_t output_index);
void CreateNhwcArgument(Node& node, Node& nhwc_node, int rank);
void InsertReorderInput(Node& node, int rank);
@ -48,6 +49,7 @@ class NhwcTransformerImpl {
void TransformQLinearBinary(Node& node);
void TransformQLinearActivation(Node& node);
void TransformQLinearGlobalAveragePool(Node& node);
void TransformSplit(Node& node);
Graph& graph_;
@ -63,30 +65,40 @@ class NhwcTransformerImpl {
std::deque<NodeIndex> removed_nodes_;
};
size_t NhwcTransformerImpl::RemoveOutputEdges(Node& node) {
size_t output_edges_count = node.GetOutputEdgesCount();
if (output_edges_count > 0) {
graph_utils::RemoveNodeOutputEdges(graph_, node);
}
// Bias the edge count to handle the case of a node that produces a graph
// output.
if (!graph_.GetNodeOutputsInGraphOutputs(node).empty()) {
output_edges_count++;
// Removes the edges leaving `node` through the output at `output_index` and
// returns the number of uses that output had: one per removed edge, plus one
// more when that output is also a graph output.
size_t NhwcTransformerImpl::RemoveOutputEdge(Node& node, size_t output_index) {
  const int slot = static_cast<int>(output_index);
  size_t use_count = graph_utils::RemoveNodeOutputEdges(graph_, node, slot);

  // A graph output has no edge of its own, so account for it explicitly.
  bool slot_is_graph_output = false;
  for (auto graph_output_slot : graph_.GetNodeOutputsInGraphOutputs(node)) {
    if (graph_output_slot == slot) {
      slot_is_graph_output = true;
      break;
    }
  }
  if (slot_is_graph_output) {
    ++use_count;
  }

  return use_count;
}
void NhwcTransformerImpl::CreateNhwcArgument(Node& node, Node& nhwc_node, int rank) {
size_t original_uses = RemoveOutputEdges(node);
void NhwcTransformerImpl::CreateNhwcArgument(Node& node, Node& nhwc_node, int rank, size_t output_index) {
size_t original_uses = RemoveOutputEdge(node, output_index);
// Create a new NodeArg to track the output from the NHWC node.
auto& output_defs = nhwc_node.MutableOutputDefs();
auto* output_original_arg = output_defs[0];
auto* output_original_arg = output_defs[output_index];
std::string output_reorder_def_name = graph_.GenerateNodeArgName("reorder");
auto* output_nhwc_arg = &graph_.GetOrCreateNodeArg(output_reorder_def_name, nullptr);
nhwc_args_[output_original_arg] =
onnxruntime::make_unique<NhwcArgument>(nhwc_node, output_nhwc_arg, original_uses, rank);
output_defs[0] = output_nhwc_arg;
output_defs[output_index] = output_nhwc_arg;
}
// Convenience overload: creates an NHWC argument for every output of `node`
// by delegating to the per-output overload, one output index at a time.
void NhwcTransformerImpl::CreateNhwcArgument(Node& node, Node& nhwc_node, int rank) {
  // The number of outputs is read once, before any output def is replaced.
  const size_t num_outputs = node.OutputDefs().size();
  size_t slot = 0;
  while (slot < num_outputs) {
    CreateNhwcArgument(node, nhwc_node, rank, slot);
    ++slot;
  }
}
void NhwcTransformerImpl::InsertReorderInput(Node& node, int rank) {
@ -237,6 +249,39 @@ void NhwcTransformerImpl::TransformQLinearGlobalAveragePool(Node& node) {
CreateNhwcArgument(node, node, nhwc_input->rank_);
}
// Rewires a Split node whose first input has an NHWC counterpart so that it
// consumes the NHWC tensor directly. The "axis" attribute, when present, is
// translated from the NCHW layout to the NHWC layout, and every output of the
// node gets an NHWC argument. Nodes with an out-of-range axis are left as is.
void NhwcTransformerImpl::TransformSplit(Node& node) {
  auto& split_inputs = node.MutableInputDefs();
  auto* nhwc_input = LookupNhwcArgument(split_inputs[0]);
  if (nhwc_input == nullptr) {
    // The data input was not produced in NHWC form; nothing to do.
    return;
  }

  const auto* axis_attr = graph_utils::GetNodeAttribute(node, "axis");
  const bool has_axis = (axis_attr != nullptr) && utils::HasInt(*axis_attr);
  if (has_axis) {
    const auto rank = nhwc_input->rank_;
    int64_t axis = axis_attr->i();

    // Bail out on an invalid axis and leave the node untouched.
    const bool axis_in_range = !(axis < -rank || axis >= rank);
    if (!axis_in_range) {
      return;
    }

    // Normalize a negative axis so it counts from the front.
    if (axis < 0) {
      axis = axis + rank;
    }

    // NCHW -> NHWC: the channel axis (1) moves to the end, the axes after it
    // shift down by one, and axis 0 stays in place.
    if (axis == 1) {
      axis = rank - 1;
    } else if (axis > 1) {
      axis = axis - 1;
    }

    node.AddAttribute("axis", axis);
  }

  // Consume the NHWC tensor instead of the original one.
  split_inputs[0] = nhwc_input->nhwc_arg_;
  nhwc_input->remaining_original_uses_--;

  CreateNhwcArgument(node, node, nhwc_input->rank_);
}
void NhwcTransformerImpl::Transform(Node& node) {
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "QLinearConv", {10})) {
TransformQLinearConv(node);
@ -248,6 +293,8 @@ void NhwcTransformerImpl::Transform(Node& node) {
TransformQLinearActivation(node);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "QLinearGlobalAveragePool", {1}, kMSDomain)) {
TransformQLinearGlobalAveragePool(node);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Split", {2, 11, 13})) {
TransformSplit(node);
}
}

View file

@ -19,6 +19,7 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>(),
DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<std::string>()}),
Split);
@ -32,6 +33,7 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>(),
DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<std::string>()}),
Split);
@ -44,6 +46,7 @@ ONNX_CPU_OPERATOR_KERNEL(
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>(),
DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<std::string>()}),
Split);
@ -97,6 +100,8 @@ Status Split::Compute(OpKernelContext* context) const {
status = ComputeImpl<int32_t>(*context, input);
else if (input.IsDataType<int64_t>())
status = ComputeImpl<int64_t>(*context, input);
else if (input.IsDataType<uint8_t>())
status = ComputeImpl<uint8_t>(*context, input);
else if (input.IsDataTypeString())
status = ComputeImpl<std::string>(*context, input);
else

View file

@ -0,0 +1,35 @@
import onnx
from .base_operator import QuantOperatorBase
from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg
from onnx import onnx_pb as onnx_proto
class QSplit(QuantOperatorBase):
    """Quantized replacement for a Split node.

    Only the first input (the data tensor) is quantized. Every output reuses
    the scale and zero point of that input. Any further inputs of the original
    node are passed to the new node unchanged.
    """

    def __init__(self, onnx_quantizer, onnx_node):
        super().__init__(onnx_quantizer, onnx_node)

    def quantize(self):
        """Emit a node of the same op type that operates on the quantized input.

        Appends the nodes produced while quantizing the input, followed by the
        new node, to ``self.quantizer.new_nodes`` and records one
        ``QuantizedValue`` per output in ``self.quantizer.quantized_value_map``.
        """
        node = self.node
        quantized_input_names, zero_point_names, scale_names, nodes = self.quantizer.quantize_inputs(node, [0])

        # Default to an empty name so an unnamed node does not leave the
        # variable unbound.
        quantized_node_name = ""
        if node.name != "":
            quantized_node_name = node.name + "_quant"

        # Carry every attribute of the original node over to the new node.
        kwargs = {}
        for attribute in node.attribute:
            kwargs.update(attribute_to_kwarg(attribute))

        # Output just derive the scale/zero from input
        quantized_output_names = []
        for output_name in node.output:
            quantized_output_name = output_name + "quantized"
            quantized_output_names.append(quantized_output_name)
            q_output = QuantizedValue(output_name, quantized_output_name, scale_names[0],
                                      zero_point_names[0], QuantizedValueType.Input)
            self.quantizer.quantized_value_map[output_name] = q_output

        if len(node.input) > 1:
            # list.extend() mutates in place and returns None, so its result
            # must not be assigned back to the list variable.
            quantized_input_names.extend(node.input[1:])

        quantized_node = onnx.helper.make_node(
            node.op_type, quantized_input_names, quantized_output_names, quantized_node_name, **kwargs)
        nodes.append(quantized_node)

        self.quantizer.new_nodes += nodes

View file

@ -9,7 +9,8 @@ from .operators.activation import QLinearActivation
from .operators.binary_op import QLinearBinaryOp
from .operators.maxpool import QMaxPool
from .operators.gavgpool import QGlobalAveragePool
from. operators.lstm import LSTMQuant
from .operators.lstm import LSTMQuant
from .operators.split import QSplit
CommonOpsRegistry = {"Gather": GatherQuant, "EmbedLayerNormalization": EmbedLayerNormalizationQuant}
@ -32,6 +33,7 @@ QLinearOpsRegistry = {
"Sigmoid" : QLinearActivation,
"MaxPool": QMaxPool,
"GlobalAveragePool": QGlobalAveragePool,
"Split" : QSplit,
}
QLinearOpsRegistry.update(CommonOpsRegistry)

View file

@ -419,6 +419,44 @@ TEST(NhwcTransformerTests, ConvGlobalAveragePool) {
NhwcTransformerTester(build_test_case, check_nhwc_graph);
}
// Builds QLinearConv -> Split -> QLinearAdd -> QLinearConv once for every Split
// axis in [-4, 3] and checks the operator counts of the resulting graph.
TEST(NhwcTransformerTests, ConvSplit) {
for (int64_t axis = -4LL; axis < 4; axis++) {
// axis is captured by value so each builder keeps the value of its own iteration.
auto build_test_case = [&, axis](NhwcTestHelper& helper) {
auto* input_arg = helper.MakeInput<uint8_t>({2, 23, 16, 16});
auto* conv_output_arg = helper.MakeIntermediate();
auto* split_output1_arg = helper.MakeIntermediate();
auto* split_output2_arg = helper.MakeIntermediate();
auto* qladd_output_arg = helper.MakeIntermediate();
auto* output_arg = helper.MakeOutput();
const int64_t conv1_output_channels = 32;
Node& conv_node = helper.AddQLinearConvNode<uint8_t>(input_arg, .01f, 135,
{conv1_output_channels, 23, 3, 3}, .02f, 126,
conv_output_arg, .37f, 131);
conv_node.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});
// Split the convolution result into two outputs along the axis under test.
Node& split_node = helper.AddNode("Split", {conv_output_arg}, {split_output1_arg, split_output2_arg});
split_node.AddAttribute("axis", static_cast<int64_t>(axis));
// Both halves are added back together, using the scale/zero point of the first convolution's output.
helper.AddQLinearBinaryNode("QLinearAdd",
split_output1_arg, .37f, 131,
split_output2_arg, .37f, 131,
qladd_output_arg, .43f, 126);
// Splitting along the channel axis (1, or -3 counted from the back of the 4-D shape)
// halves the channel count seen by the second convolution; any other axis keeps it.
const int64_t channels_after_split =
(axis == 1 || axis == -3) ? conv1_output_channels / 2 : conv1_output_channels;
helper.AddQLinearConvNode<uint8_t>(qladd_output_arg, .43f, 126,
{17, channels_after_split, 3, 3}, .02f, 126,
output_arg, .37f, 131);
};
// Expect two com.microsoft QLinearConv nodes and exactly two Transpose nodes.
// NOTE(review): presumably one Transpose into NHWC and one back out - confirm against NhwcTransformerTester.
auto check_nhwc_graph = [&](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
EXPECT_EQ(op_to_count["com.microsoft.QLinearConv"], 2);
EXPECT_EQ(op_to_count["Transpose"], 2);
};
NhwcTransformerTester(build_test_case, check_nhwc_graph);
}
}
TEST(NhwcTransformerTests, ConvBlockActivation) {
auto test_case = [&](uint32_t extra_edges) {
auto build_test_case = [&](NhwcTestHelper& helper) {

View file

@ -251,24 +251,23 @@ TEST(SplitOperatorTest, Axis1UnequalSplitString) {
RunTest<std::string>(axis, splits, input, outputs, false);
}
ShapeAndFloatData CreateInput(std::vector<int64_t> shape) {
template <typename T>
ShapeAndData<T> CreateInput(std::vector<int64_t> shape) {
auto size = TensorShape(shape).Size();
float i = 0.f, increment = 1.f;
// generate the elements for the data starting at 1.f
std::vector<float> data;
T i = static_cast<T>(0), increment = static_cast<T>(1);
// generate the elements for the data starting at 1
std::vector<T> data;
std::generate_n(std::back_inserter(data), size, [&]() { return i += increment; });
ShapeAndFloatData input = {shape, data};
return input;
return ShapeAndData<T>{shape, data};
}
TEST(SplitOperatorTest, Axis2EqualSplit) {
const int64_t axis = 2;
std::vector<ShapeAndFloatData> outputs;
ShapeAndFloatData input = CreateInput({2, 2, 6});
ShapeAndFloatData input = CreateInput<float>({2, 2, 6});
outputs.push_back({{2, 2, 2},
{1.f, 2.f,
@ -298,7 +297,7 @@ TEST(SplitOperatorTest, Axis2UnequalSplit) {
const int64_t axis = 2;
std::vector<ShapeAndFloatData> outputs;
ShapeAndFloatData input = CreateInput({2, 2, 6});
ShapeAndFloatData input = CreateInput<float>({2, 2, 6});
std::vector<int64_t> splits{1, 2, 3};
@ -330,7 +329,7 @@ TEST(SplitOperatorTest, ZeroSizeInput) {
const int64_t axis = -1;
std::vector<ShapeAndFloatData> outputs{{{0, 1}, {}}, {{0, 1}, {}}};
ShapeAndFloatData input = CreateInput({0, 2});
ShapeAndFloatData input = CreateInput<float>({0, 2});
RunTest<float>(axis, {}, input, outputs, false);
}
@ -340,7 +339,7 @@ TEST(SplitOperatorTest, Axis1SplitMiddleDimensionEqually) {
const int64_t axis = 1;
std::vector<ShapeAndFloatData> outputs;
ShapeAndFloatData input = CreateInput({2, 4, 4});
ShapeAndFloatData input = CreateInput<float>({2, 4, 4});
outputs.push_back({{2, 2, 4},
{1.f, 2.f, 3.f, 4.f,
@ -364,7 +363,7 @@ TEST(SplitOperatorTest, Axis1SplitMiddleDimensionUnequally) {
const int64_t axis = 1;
std::vector<ShapeAndFloatData> outputs;
ShapeAndFloatData input = CreateInput({2, 4, 4});
ShapeAndFloatData input = CreateInput<float>({2, 4, 4});
std::vector<int64_t> splits{1, 3};
@ -567,5 +566,52 @@ SplitAxis2()
SplitMiddleDimension()
*/
// Splits a uint8_t tensor of shape {2, 4, 4} into pieces of size 1 and 3 along
// the middle axis, so the split axis has both a leading and a trailing dimension.
TEST(SplitOperatorTest, Uint8Axis1SplitMiddleDimensionUnequally) {
  const int64_t axis = 1;
  std::vector<int64_t> splits{1, 3};

  // The input holds the consecutive values 1..32.
  ShapeAndData<uint8_t> input = CreateInput<uint8_t>({2, 4, 4});

  std::vector<ShapeAndData<uint8_t>> outputs{
      // first row of each of the two leading slices
      {{2, 1, 4},
       {1, 2, 3, 4,
        17, 18, 19, 20}},
      // remaining three rows of each leading slice
      {{2, 3, 4},
       {5, 6, 7, 8,
        9, 10, 11, 12,
        13, 14, 15, 16,
        21, 22, 23, 24,
        25, 26, 27, 28,
        29, 30, 31, 32}}};

  RunTest<uint8_t>(axis, splits, input, outputs, false);
}
// Splits a uint8_t tensor of shape {2, 4} along its last axis, addressed as -1.
// No split sizes are given and the expected outputs are two equal halves.
TEST(SplitOperatorTest, Uint8NegativeAxis) {
  const int64_t axis = -1;

  ShapeAndData<uint8_t> input{{2, 4},
                              {1, 2, 3, 4,
                               5, 6, 7, 8}};

  std::vector<ShapeAndData<uint8_t>> outputs{
      // first two columns of each row
      {{2, 2},
       {1, 2,
        5, 6}},
      // last two columns of each row
      {{2, 2},
       {3, 4,
        7, 8}}};

  RunTest<uint8_t>(axis, {}, input, outputs, false);
}
} // namespace test
} // namespace onnxruntime