mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-20 02:07:56 +00:00
[NNAPI EP] add uint8 support for Transpose/Concat/Maxpool, add support of QLinearSigmoid (#6534)
* Init change * Add QlinearSigmoid support * Update tests * Add resize int8 support * Add version check for resize linear uint8 and add scale/zero point check for concat uint8 * Address CR comments * minor fix and add test for uint8 handling * Address CR comments * Fixed an existing bug * Fix the new UT break, due to different rounding of 0.5 in device and emulator
This commit is contained in:
parent
6cb8f8c812
commit
464dbef143
9 changed files with 499 additions and 69 deletions
|
|
@ -62,6 +62,8 @@ QLinearOpType GetQLinearOpType(const onnxruntime::Node& node) {
|
|||
return QLinearOpType::QLinearMatMul;
|
||||
else if (op_type == "QLinearAdd")
|
||||
return QLinearOpType::QLinearAdd;
|
||||
else if (op_type == "QLinearSigmoid")
|
||||
return QLinearOpType::QLinearSigmoid;
|
||||
|
||||
return QLinearOpType::Unknown;
|
||||
}
|
||||
|
|
@ -232,8 +234,10 @@ bool HasValidQuantizationZeroPoints(const InitializedTensorSet& initializers, co
|
|||
|
||||
std::unique_ptr<uint8_t[]> unpacked_tensor;
|
||||
size_t tensor_byte_size;
|
||||
auto status = onnxruntime::utils::UnpackInitializerData(zero_tensor, node.ModelPath(),
|
||||
unpacked_tensor, tensor_byte_size);
|
||||
auto status = onnxruntime::utils::UnpackInitializerData(
|
||||
zero_tensor,
|
||||
node.ModelPath(),
|
||||
unpacked_tensor, tensor_byte_size);
|
||||
if (!status.IsOK()) {
|
||||
LOGS_DEFAULT(ERROR) << "QLinearConv erro when unpack zero tensor:" << status.ErrorMessage();
|
||||
return false;
|
||||
|
|
@ -264,6 +268,24 @@ bool HasValidQuantizationZeroPoints(const InitializedTensorSet& initializers, co
|
|||
return true;
|
||||
}
|
||||
|
||||
float GetQuantizationScale(const InitializedTensorSet& initializers, const Node& node, size_t idx) {
|
||||
const auto& scale_tensor = *initializers.at(node.InputDefs()[idx]->Name());
|
||||
return GetTensorFloatData(scale_tensor)[0];
|
||||
}
|
||||
|
||||
common::Status GetQuantizationZeroPoint(const InitializedTensorSet& initializers,
|
||||
const Node& node, size_t idx, int32_t& zero_point) {
|
||||
std::unique_ptr<uint8_t[]> unpacked_tensor;
|
||||
size_t tensor_byte_size;
|
||||
const auto& zero_point_tensor = *initializers.at(node.InputDefs()[idx]->Name());
|
||||
ORT_RETURN_IF_ERROR(
|
||||
onnxruntime::utils::UnpackInitializerData(zero_point_tensor, node.ModelPath(),
|
||||
unpacked_tensor, tensor_byte_size));
|
||||
// Onnx quantization uses uint8 [int8 not yet supported], need to cast to int32_t used by NNAPI
|
||||
zero_point = static_cast<int32_t>(unpacked_tensor.get()[0]);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define GET_TENSOR_DATA(FUNC_NAME, ELEMENT_TYPE, DATA) \
|
||||
const ELEMENT_TYPE* GetTensor##FUNC_NAME(const ONNX_NAMESPACE::TensorProto& tensor) { \
|
||||
return tensor.DATA().empty() \
|
||||
|
|
@ -348,13 +370,13 @@ void GetFlattenOutputShape(const Node& node, const Shape& input_shape, int32_t&
|
|||
dim_2 = std::accumulate(input_shape.cbegin() + axis, input_shape.cend(), 1, std::multiplies<int32_t>());
|
||||
}
|
||||
|
||||
bool IsValidSupportedNodesVec(const std::vector<size_t>& supported_node_vec, const GraphViewer& graph_viewer) {
|
||||
if (supported_node_vec.empty())
|
||||
bool IsValidSupportedNodesGroup(const std::vector<size_t>& supported_node_group, const GraphViewer& graph_viewer) {
|
||||
if (supported_node_group.empty())
|
||||
return false;
|
||||
|
||||
if (supported_node_vec.size() == 1) {
|
||||
if (supported_node_group.size() == 1) {
|
||||
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
|
||||
const auto* node(graph_viewer.GetNode(node_indices[supported_node_vec[0]]));
|
||||
const auto* node(graph_viewer.GetNode(node_indices[supported_node_group[0]]));
|
||||
const auto& op = node->OpType();
|
||||
// It is not worth it to perform a single Reshape/Flatten/Identity operator
|
||||
// which is only copying the data in NNAPI
|
||||
|
|
@ -368,49 +390,116 @@ bool IsValidSupportedNodesVec(const std::vector<size_t>& supported_node_vec, con
|
|||
return true;
|
||||
}
|
||||
|
||||
bool IsInternalQuantizedNode(const Node& node) {
|
||||
// These operators can use uint8 input without specific QLinear version of it
|
||||
// However, the mode has to be internal to the graph/partition (they cannot consume graph inputs)
|
||||
static const std::unordered_set<std::string> internal_quantized_op_types =
|
||||
{
|
||||
"Transpose",
|
||||
"Resize",
|
||||
"Concat",
|
||||
"MaxPool",
|
||||
};
|
||||
|
||||
if (!Contains(internal_quantized_op_types, node.OpType()))
|
||||
return false;
|
||||
|
||||
int32_t input_type;
|
||||
ORT_ENFORCE(GetType(*node.InputDefs()[0], input_type));
|
||||
|
||||
return input_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8;
|
||||
}
|
||||
|
||||
// We support some operators running using uint8 internally
|
||||
// These nodes cannot use a graph input as input since onnx graph input does not carry scale/zero point info
|
||||
bool IsInternalQuantizationSupported(const Node& node, const std::unordered_set<std::string>& node_outputs_in_group) {
|
||||
const auto& op_type = node.OpType();
|
||||
|
||||
// The node's input(s) have to be an output of node(s) within the group
|
||||
// If not, then this node is using graph/partition input(s) as input(s)
|
||||
const auto& input_defs = node.InputDefs();
|
||||
|
||||
// We only need to check input0 for all operators except "Concat"
|
||||
bool check_all_inputs = op_type == "Concat";
|
||||
|
||||
for (size_t i = 0; i < (check_all_inputs ? input_defs.size() : 1); i++) {
|
||||
if (!Contains(node_outputs_in_group, input_defs[i]->Name())) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Node [" << node.Name() << "] type: [" << op_type
|
||||
<< "] has input [" << input_defs[i]->Name()
|
||||
<< "] does not support using graph input(quantized) as node input";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const OpSupportCheckParams& params) {
|
||||
const auto& op_support_checkers = GetOpSupportCheckers();
|
||||
if (Contains(op_support_checkers, node.OpType())) {
|
||||
const auto* op_support_checker = op_support_checkers.at(node.OpType());
|
||||
return op_support_checker->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, params);
|
||||
} else {
|
||||
if (!Contains(op_support_checkers, node.OpType()))
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto* op_support_checker = op_support_checkers.at(node.OpType());
|
||||
return op_support_checker->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, params);
|
||||
}
|
||||
|
||||
bool IsNodeSupportedInternal(const Node& node, const GraphViewer& graph_viewer,
|
||||
const OpSupportCheckParams& params,
|
||||
const std::unordered_set<std::string>& node_outputs_in_group) {
|
||||
if (!IsNodeSupported(node, graph_viewer, params))
|
||||
return false;
|
||||
|
||||
// We also want to check if the node is supported as an internal quantized node
|
||||
if (IsInternalQuantizedNode(node))
|
||||
return IsInternalQuantizationSupported(node, node_outputs_in_group);
|
||||
else // This is not a internal quantized node, it is supported
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::vector<size_t>> GetSupportedNodes(const GraphViewer& graph_viewer, const OpSupportCheckParams& params) {
|
||||
std::vector<std::vector<size_t>> supported_node_vecs;
|
||||
std::vector<std::vector<size_t>> supported_node_groups;
|
||||
if (params.android_sdk_ver < ORT_NNAPI_MIN_API_LEVEL) {
|
||||
LOGS_DEFAULT(WARNING) << "All ops will fallback to CPU EP, because Android API level [" << params.android_sdk_ver
|
||||
<< "] is lower than minimal supported API level [" << ORT_NNAPI_MIN_API_LEVEL
|
||||
<< "] of this build for NNAPI";
|
||||
return supported_node_vecs;
|
||||
return supported_node_groups;
|
||||
}
|
||||
|
||||
std::vector<size_t> supported_node_vec;
|
||||
// This holds the supported node's topological index
|
||||
std::vector<size_t> supported_node_group;
|
||||
// This holds the NodeIndex of the nodes in the above group
|
||||
std::unordered_set<std::string> node_outputs_in_group;
|
||||
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
|
||||
for (size_t i = 0; i < node_indices.size(); i++) {
|
||||
const auto* node(graph_viewer.GetNode(node_indices[i]));
|
||||
bool supported = IsNodeSupported(*node, graph_viewer, params);
|
||||
bool supported = IsNodeSupportedInternal(*node, graph_viewer, params, node_outputs_in_group);
|
||||
LOGS_DEFAULT(VERBOSE) << "Operator type: [" << node->OpType()
|
||||
<< "] index: [" << i
|
||||
<< "] name: [" << node->Name()
|
||||
<< "] supported: [" << supported
|
||||
<< "]";
|
||||
if (supported) {
|
||||
supported_node_vec.push_back(i);
|
||||
} else {
|
||||
if (IsValidSupportedNodesVec(supported_node_vec, graph_viewer)) {
|
||||
supported_node_vecs.push_back(supported_node_vec);
|
||||
supported_node_vec.clear();
|
||||
supported_node_group.push_back(i);
|
||||
|
||||
// We want to put all the output names of nodes in the current group for easy query
|
||||
// See IsInternalQuantizationSupported()
|
||||
for (const auto* output : node->OutputDefs()) {
|
||||
node_outputs_in_group.insert(output->Name());
|
||||
}
|
||||
} else {
|
||||
if (IsValidSupportedNodesGroup(supported_node_group, graph_viewer)) {
|
||||
supported_node_groups.push_back(supported_node_group);
|
||||
}
|
||||
|
||||
supported_node_group.clear();
|
||||
node_outputs_in_group.clear();
|
||||
}
|
||||
}
|
||||
|
||||
if (IsValidSupportedNodesVec(supported_node_vec, graph_viewer))
|
||||
supported_node_vecs.push_back(supported_node_vec);
|
||||
if (IsValidSupportedNodesGroup(supported_node_group, graph_viewer))
|
||||
supported_node_groups.push_back(supported_node_group);
|
||||
|
||||
return supported_node_vecs;
|
||||
return supported_node_groups;
|
||||
}
|
||||
|
||||
std::string Shape2String(const std::vector<uint32_t>& shape) {
|
||||
|
|
|
|||
|
|
@ -76,6 +76,7 @@ enum class QLinearOpType : uint8_t {
|
|||
QLinearConv,
|
||||
QLinearMatMul,
|
||||
QLinearAdd,
|
||||
QLinearSigmoid,
|
||||
// Not yet supported
|
||||
// QLinearAveragePool,
|
||||
// QLinearMul,
|
||||
|
|
@ -107,6 +108,11 @@ bool HasValidQuantizationScales(const InitializedTensorSet& initializers, const
|
|||
bool HasValidQuantizationZeroPoints(const InitializedTensorSet& initializers, const Node& node,
|
||||
const std::vector<size_t>& indices);
|
||||
|
||||
float GetQuantizationScale(const InitializedTensorSet& initializers, const Node& node, size_t idx);
|
||||
|
||||
common::Status GetQuantizationZeroPoint(const InitializedTensorSet& initializers,
|
||||
const Node& node, size_t idx, int32_t& zero_point) ORT_MUST_USE_RESULT;
|
||||
|
||||
// Get initialize tensort float/int32/int64 data without unpacking
|
||||
// TODO, move to ort framework
|
||||
const float* GetTensorFloatData(const ONNX_NAMESPACE::TensorProto& tensor);
|
||||
|
|
|
|||
|
|
@ -143,7 +143,13 @@ std::unordered_map<std::string, vector<const Node*>> GetAllQuantizedOpInputs(con
|
|||
for (const auto& node_idx : node_indices) {
|
||||
const auto* node(graph_viewer.GetNode(node_idx));
|
||||
auto qlinear_op_type = GetQLinearOpType(*node);
|
||||
if (qlinear_op_type == QLinearOpType::DequantizeLinear || IsQLinearBinaryOp(qlinear_op_type)) {
|
||||
|
||||
// Not a qlinear op
|
||||
if (qlinear_op_type == QLinearOpType::Unknown)
|
||||
continue;
|
||||
|
||||
// All qlinear ops EXCEPT QuantizeLinear has quantized input
|
||||
if (qlinear_op_type != QLinearOpType::QuantizeLinear) {
|
||||
const auto& input_name = node->InputDefs()[0]->Name();
|
||||
if (Contains(all_quantized_op_inputs, input_name))
|
||||
all_quantized_op_inputs.at(input_name).push_back(node);
|
||||
|
|
@ -293,7 +299,7 @@ Status ModelBuilder::RegisterModelInputs() {
|
|||
if (!Contains(all_quantized_op_inputs, input_name)) {
|
||||
// We current do not support uint8 input if it is not a quantized input
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"The input of graph doesn't have valid type, name: ", input_name,
|
||||
"The input of graph has unsupported quantized type, name: ", input_name,
|
||||
" type: ", type_proto->tensor_type().elem_type());
|
||||
}
|
||||
|
||||
|
|
@ -305,7 +311,7 @@ Status ModelBuilder::RegisterModelInputs() {
|
|||
default: {
|
||||
// TODO: support other type
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"The input of graph doesn't have valid type, name: ", input_name,
|
||||
"The input of graph has unsupported type, name: ", input_name,
|
||||
" type: ", type_proto->tensor_type().elem_type());
|
||||
}
|
||||
}
|
||||
|
|
@ -369,6 +375,7 @@ void ModelBuilder::RegisterModelShaper() {
|
|||
Status ModelBuilder::AddNewOperand(const std::string& name,
|
||||
const OperandType& operand_type,
|
||||
bool is_nhwc, uint32_t& index) {
|
||||
LOGS_DEFAULT(VERBOSE) << "operand name: " << name;
|
||||
ORT_RETURN_IF_ERROR(AddNewNNAPIOperand(operand_type, index));
|
||||
RegisterOperand(name, index, operand_type, is_nhwc);
|
||||
return Status::OK();
|
||||
|
|
@ -535,6 +542,12 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
|
|||
|
||||
int32_t ModelBuilder::FindActivation(const Node& node, const NodeArg& output) {
|
||||
int32_t fuse_code = ANEURALNETWORKS_FUSED_NONE;
|
||||
|
||||
// We do not support activation fusion for quantized operators for now
|
||||
auto qlinear_op_type = GetQLinearOpType(node);
|
||||
if (qlinear_op_type != QLinearOpType::Unknown)
|
||||
return fuse_code;
|
||||
|
||||
for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
|
||||
const auto& dst_node = it->GetNode();
|
||||
const auto* dst_input = dst_node.InputDefs()[it->GetDstArgIndex()];
|
||||
|
|
|
|||
|
|
@ -495,24 +495,6 @@ static Status HandleAutoPad(const Shape& input_shape,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
static float GetQuantizationScale(const ModelBuilder& model_builder, const Node& node, size_t idx) {
|
||||
const auto& scale_tensor = *model_builder.GetInitializerTensors().at(node.InputDefs()[idx]->Name());
|
||||
return GetTensorFloatData(scale_tensor)[0];
|
||||
}
|
||||
|
||||
static Status GetQuantizationZeroPoint(const ModelBuilder& model_builder, const Node& node, size_t idx, int32_t& zero_point)
|
||||
ORT_MUST_USE_RESULT;
|
||||
static Status GetQuantizationZeroPoint(const ModelBuilder& model_builder, const Node& node, size_t idx, int32_t& zero_point) {
|
||||
std::unique_ptr<uint8_t[]> unpacked_tensor;
|
||||
size_t tensor_byte_size;
|
||||
const auto& zero_point_tensor = *model_builder.GetInitializerTensors().at(node.InputDefs()[idx]->Name());
|
||||
ORT_RETURN_IF_ERROR(
|
||||
onnxruntime::utils::UnpackInitializerData(zero_point_tensor, model_builder.GetGraphViewer().ModelPath(),
|
||||
unpacked_tensor, tensor_byte_size));
|
||||
zero_point = static_cast<int32_t>(unpacked_tensor.get()[0]);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get scales and zero points for the qlinear binary ops (which has 2 input and 1 output)
|
||||
// QLinearConv, QLinearMatmul, QLinearAdd
|
||||
// a, b are inputs, and y is output
|
||||
|
|
@ -524,13 +506,14 @@ static Status GetBinaryOpQuantizationScaleAndZeroPoint(
|
|||
const ModelBuilder& model_builder, const Node& node,
|
||||
float& a_scale, float& b_scale, float& y_scale,
|
||||
int32_t& a_zero_point, int32_t& b_zero_point, int32_t& y_zero_point) {
|
||||
a_scale = GetQuantizationScale(model_builder, node, 1);
|
||||
b_scale = GetQuantizationScale(model_builder, node, 4);
|
||||
y_scale = GetQuantizationScale(model_builder, node, 6);
|
||||
const auto& initializers = model_builder.GetInitializerTensors();
|
||||
a_scale = GetQuantizationScale(initializers, node, 1);
|
||||
b_scale = GetQuantizationScale(initializers, node, 4);
|
||||
y_scale = GetQuantizationScale(initializers, node, 6);
|
||||
|
||||
ORT_RETURN_IF_ERROR(GetQuantizationZeroPoint(model_builder, node, 2, a_zero_point));
|
||||
ORT_RETURN_IF_ERROR(GetQuantizationZeroPoint(model_builder, node, 5, b_zero_point));
|
||||
ORT_RETURN_IF_ERROR(GetQuantizationZeroPoint(model_builder, node, 7, y_zero_point));
|
||||
ORT_RETURN_IF_ERROR(GetQuantizationZeroPoint(initializers, node, 2, a_zero_point));
|
||||
ORT_RETURN_IF_ERROR(GetQuantizationZeroPoint(initializers, node, 5, b_zero_point));
|
||||
ORT_RETURN_IF_ERROR(GetQuantizationZeroPoint(initializers, node, 7, y_zero_point));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -660,7 +643,8 @@ Status GetQuantizedInputScaleAndZeroPoint(const ModelBuilder& model_builder,
|
|||
qlinear_op_type != QLinearOpType::QuantizeLinear);
|
||||
|
||||
size_t scale_idx, zero_point_idx;
|
||||
if (qlinear_op_type == QLinearOpType::DequantizeLinear) {
|
||||
if (qlinear_op_type == QLinearOpType::DequantizeLinear ||
|
||||
qlinear_op_type == QLinearOpType::QLinearSigmoid) {
|
||||
scale_idx = 1;
|
||||
zero_point_idx = 2;
|
||||
} else if (IsQLinearBinaryOp(qlinear_op_type)) {
|
||||
|
|
@ -679,10 +663,10 @@ Status GetQuantizedInputScaleAndZeroPoint(const ModelBuilder& model_builder,
|
|||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported op: ", op_type);
|
||||
}
|
||||
|
||||
scale = GetQuantizationScale(model_builder, node, scale_idx);
|
||||
scale = GetQuantizationScale(model_builder.GetInitializerTensors(), node, scale_idx);
|
||||
zero_point = 0;
|
||||
if (node.InputDefs().size() > 2) {
|
||||
ORT_RETURN_IF_ERROR(GetQuantizationZeroPoint(model_builder, node, zero_point_idx, zero_point));
|
||||
ORT_RETURN_IF_ERROR(GetQuantizationZeroPoint(model_builder.GetInitializerTensors(), node, zero_point_idx, zero_point));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
@ -811,7 +795,7 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
a_zero_point, b_zero_point, y_zero_point));
|
||||
}
|
||||
|
||||
// Verify if the scale and zero point matchs from onnx input and nnapi input
|
||||
// Verify if the scale and zero point matchs from onnx input and nnapi input match
|
||||
if (op_is_qlinear) {
|
||||
ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input1, a_scale, a_zero_point));
|
||||
ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input2, b_scale, b_zero_point));
|
||||
|
|
@ -1260,7 +1244,8 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
|
|||
onnx_pads, onnx_strides, kernel_shape,
|
||||
use_nchw,
|
||||
output));
|
||||
const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
|
||||
OperandType output_operand_type = operand_types.at(input);
|
||||
output_operand_type.SetDimensions(shaper[output]);
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperation(op_code, input_indices,
|
||||
{output}, {output_operand_type}, {output_is_nhwc}));
|
||||
return Status::OK();
|
||||
|
|
@ -1802,12 +1787,29 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
|
|||
|
||||
class UnaryOpBuilder : public BaseOpBuilder {
|
||||
public:
|
||||
void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
|
||||
static void CreateSharedOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
|
||||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) const override ORT_MUST_USE_RESULT;
|
||||
};
|
||||
|
||||
void UnaryOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
|
||||
const auto& op = node.OpType();
|
||||
if (op != "QLinearSigmoid")
|
||||
return;
|
||||
|
||||
const auto input_defs = node.InputDefs();
|
||||
|
||||
// skip input/output scales and zeropoints
|
||||
model_builder.AddInitializerToSkip(input_defs[1]->Name()); // X_scale
|
||||
model_builder.AddInitializerToSkip(input_defs[2]->Name()); // X_zero_point
|
||||
model_builder.AddInitializerToSkip(input_defs[3]->Name()); // Y_scale
|
||||
|
||||
if (input_defs.size() == 5) // has Y_zero_point input
|
||||
model_builder.AddInitializerToSkip(input_defs[4]->Name()); // Y_zero_point
|
||||
}
|
||||
|
||||
/* static */ void UnaryOpBuilder::CreateSharedOpBuilder(
|
||||
const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
CreateSharedOpBuilderImpl<UnaryOpBuilder>(
|
||||
|
|
@ -1822,6 +1824,7 @@ class UnaryOpBuilder : public BaseOpBuilder {
|
|||
"Sin",
|
||||
"Sqrt",
|
||||
"Tanh",
|
||||
"QLinearSigmoid",
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -1836,7 +1839,7 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
|
||||
|
||||
ORT_RETURN_IF_ERROR(shaper.Identity(input, output));
|
||||
const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
|
||||
bool is_qlinear_sigmoid = op_type == "QLinearSigmoid";
|
||||
|
||||
int32_t op_code;
|
||||
if (op_type == "Abs")
|
||||
|
|
@ -1847,7 +1850,7 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
op_code = ANEURALNETWORKS_FLOOR;
|
||||
else if (op_type == "Log")
|
||||
op_code = ANEURALNETWORKS_LOG;
|
||||
else if (op_type == "Sigmoid")
|
||||
else if (op_type == "Sigmoid" || is_qlinear_sigmoid)
|
||||
op_code = ANEURALNETWORKS_LOGISTIC;
|
||||
else if (op_type == "Neg")
|
||||
op_code = ANEURALNETWORKS_NEG;
|
||||
|
|
@ -1860,8 +1863,26 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "UnaryOpBuilder, unknown op: ", op_type);
|
||||
}
|
||||
|
||||
float y_scale = 0.0f;
|
||||
int32_t y_zero_point = 0;
|
||||
if (is_qlinear_sigmoid) {
|
||||
const auto& initializers = model_builder.GetInitializerTensors();
|
||||
float x_scale = GetQuantizationScale(initializers, node, 1);
|
||||
int32_t x_zero_point = 0;
|
||||
ORT_RETURN_IF_ERROR(GetQuantizationZeroPoint(initializers, node, 2, x_zero_point));
|
||||
|
||||
// Verify if the scale and zero point values from onnx input and nnapi input match
|
||||
ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, x_scale, x_zero_point));
|
||||
|
||||
// We already verified this in UnaryOpSupportChecker::IsOpSupportedImpl
|
||||
y_scale = 1.f / 256;
|
||||
y_zero_point = 0;
|
||||
}
|
||||
|
||||
std::vector<uint32_t> input_indices;
|
||||
input_indices.push_back(operand_indices.at(input));
|
||||
const OperandType output_operand_type(operand_types.at(input).type, shaper[output], y_scale, y_zero_point);
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperation(op_code, input_indices,
|
||||
{output}, {output_operand_type}, {output_is_nhwc}));
|
||||
return Status::OK();
|
||||
|
|
@ -1888,6 +1909,24 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
bool output_is_nhwc = false;
|
||||
const auto node_input_size = node.InputDefs().size();
|
||||
|
||||
// First if the inputs are uint8, we need verify all the inputs have same scale and zero points
|
||||
if (operand_types.at(input0).type == android::nn::wrapper::Type::TENSOR_QUANT8_ASYMM) {
|
||||
auto scale = operand_types.at(input0).operandType.scale;
|
||||
auto zero_point = operand_types.at(input0).operandType.zeroPoint;
|
||||
|
||||
// Compare scale and zp of input0 to input1~n
|
||||
for (size_t i = 1; i < node_input_size; i++) {
|
||||
const auto& type = operand_types.at(node.InputDefs()[i]->Name());
|
||||
ORT_RETURN_IF_NOT(scale == type.operandType.scale,
|
||||
"Input[", i, "]'s scale: ", type.operandType.scale,
|
||||
" is different than input[0]'s scale: ", scale);
|
||||
|
||||
ORT_RETURN_IF_NOT(zero_point == type.operandType.zeroPoint,
|
||||
"Input[", i, "]'s zero_point: ", type.operandType.zeroPoint,
|
||||
" is different than input[0]'s zero_point: ", zero_point);
|
||||
}
|
||||
}
|
||||
|
||||
// First we want to see if all the input are same layout
|
||||
for (size_t i = 0; i < node_input_size - 1; i++) {
|
||||
all_input_have_same_layout =
|
||||
|
|
@ -1934,7 +1973,8 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
|
||||
const auto& output = node.OutputDefs()[0]->Name();
|
||||
ORT_RETURN_IF_ERROR(shaper.Concat(inputs, axis, output));
|
||||
const OperandType output_operand_type(operand_types.at(input0).type, shaper[output]);
|
||||
OperandType output_operand_type = operand_types.at(input0);
|
||||
output_operand_type.SetDimensions(shaper[output]);
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_CONCATENATION, input_indices,
|
||||
{output}, {output_operand_type}, {output_is_nhwc}));
|
||||
return Status::OK();
|
||||
|
|
@ -2023,12 +2063,12 @@ Status QuantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builde
|
|||
const auto& output = node.OutputDefs()[0]->Name();
|
||||
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
|
||||
|
||||
float scale = GetQuantizationScale(model_builder, node, 1);
|
||||
float scale = GetQuantizationScale(model_builder.GetInitializerTensors(), node, 1);
|
||||
int32_t zero_point = 0;
|
||||
Type output_type = Type::TENSOR_QUANT8_ASYMM;
|
||||
|
||||
if (input_defs.size() == 3) { // Get zero point
|
||||
ORT_RETURN_IF_ERROR(GetQuantizationZeroPoint(model_builder, node, 2, zero_point));
|
||||
ORT_RETURN_IF_ERROR(GetQuantizationZeroPoint(model_builder.GetInitializerTensors(), node, 2, zero_point));
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(shaper.Identity(input, output));
|
||||
|
|
@ -2070,10 +2110,10 @@ Status DequantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil
|
|||
const auto& output = node.OutputDefs()[0]->Name();
|
||||
bool output_is_nhwc = model_builder.IsOperandNHWC(input);
|
||||
|
||||
float scale = GetQuantizationScale(model_builder, node, 1);
|
||||
float scale = GetQuantizationScale(model_builder.GetInitializerTensors(), node, 1);
|
||||
int32_t zero_point = 0;
|
||||
if (input_defs.size() == 3) { // Get zero point
|
||||
ORT_RETURN_IF_ERROR(GetQuantizationZeroPoint(model_builder, node, 2, zero_point));
|
||||
ORT_RETURN_IF_ERROR(GetQuantizationZeroPoint(model_builder.GetInitializerTensors(), node, 2, zero_point));
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, scale, zero_point));
|
||||
|
|
@ -2296,7 +2336,8 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
}
|
||||
}
|
||||
|
||||
const OperandType output_operand_type(operand_types.at(input).type, output_shape);
|
||||
OperandType output_operand_type = operand_types.at(input);
|
||||
output_operand_type.SetDimensions(output_shape);
|
||||
ORT_RETURN_IF_ERROR(model_builder.AddOperation(operationCode, input_indices,
|
||||
{output}, {output_operand_type}, {output_is_nhwc}));
|
||||
|
||||
|
|
@ -2468,6 +2509,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
|
|||
NNAPI_EP_ADD_SHARED_OP_BUILDER("Sin", UnaryOpBuilder);
|
||||
NNAPI_EP_ADD_SHARED_OP_BUILDER("Sqrt", UnaryOpBuilder);
|
||||
NNAPI_EP_ADD_SHARED_OP_BUILDER("Tanh", UnaryOpBuilder);
|
||||
NNAPI_EP_ADD_SHARED_OP_BUILDER("QLinearSigmoid", UnaryOpBuilder);
|
||||
}
|
||||
|
||||
NNAPI_EP_ADD_SINGLE_OP_BUILDER("Concat", ConcatOpBuilder);
|
||||
|
|
|
|||
|
|
@ -321,6 +321,8 @@ class TransposeOpSupportChecker : public BaseOpSupportChecker {
|
|||
int32_t GetMinSupportedSdkVer(const Node& /* node */, const OpSupportCheckParams& /* params */) const override {
|
||||
return 28;
|
||||
}
|
||||
|
||||
bool HasSupportedInputsImpl(const Node& node) const override;
|
||||
};
|
||||
|
||||
bool TransposeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
|
|
@ -339,6 +341,22 @@ bool TransposeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /*
|
|||
return true;
|
||||
}
|
||||
|
||||
bool TransposeOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
|
||||
int32_t input_type;
|
||||
if (!GetType(*node.InputDefs()[0], input_type))
|
||||
return false;
|
||||
|
||||
if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
|
||||
input_type != ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
|
||||
LOGS_DEFAULT(VERBOSE) << "[" << node.OpType()
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
#pragma endregion
|
||||
|
||||
#pragma region op_reshape
|
||||
|
|
@ -465,6 +483,8 @@ class PoolOpSupportChecker : public BaseOpSupportChecker {
|
|||
int32_t GetMinSupportedSdkVer(const Node& /* node */, const OpSupportCheckParams& params) const override {
|
||||
return params.use_nchw ? 29 : 28;
|
||||
}
|
||||
|
||||
bool HasSupportedInputsImpl(const Node& node) const override;
|
||||
};
|
||||
|
||||
/* static */ void PoolOpSupportChecker::CreateSharedOpSupportChecker(
|
||||
|
|
@ -537,6 +557,25 @@ bool PoolOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* init
|
|||
return true;
|
||||
}
|
||||
|
||||
bool PoolOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
|
||||
if (node.OpType() != "MaxPool")
|
||||
return BaseOpSupportChecker::HasSupportedInputsImpl(node);
|
||||
|
||||
int32_t input_type;
|
||||
if (!GetType(*node.InputDefs()[0], input_type))
|
||||
return false;
|
||||
|
||||
if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
|
||||
input_type != ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
|
||||
LOGS_DEFAULT(VERBOSE) << "[" << node.OpType()
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
#pragma endregion op_pool
|
||||
|
||||
#pragma region op_conv
|
||||
|
|
@ -917,11 +956,17 @@ class UnaryOpSupportChecker : public BaseOpSupportChecker {
|
|||
const std::string& op_type, OpSupportCheckerRegistrations& op_registrations);
|
||||
|
||||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const OpSupportCheckParams& params) const override;
|
||||
|
||||
int32_t GetMinSupportedSdkVer(const Node& node, const OpSupportCheckParams& params) const override;
|
||||
|
||||
// All ops except "Sin" opset 5- uses consumed_inputs attribute which is not supported for now
|
||||
// "Sin" op has support from opset 7, return 6 here for all ops
|
||||
int GetMinSupportedOpSet(const Node& /* node */) const override { return 6; }
|
||||
bool HasSupportedInputsImpl(const Node& node) const override;
|
||||
|
||||
int GetMinSupportedOpSet(const Node& node) const override;
|
||||
|
||||
static bool IsQuantizedOpSupported(const InitializedTensorSet& initializers, const Node& node,
|
||||
const OpSupportCheckParams& params);
|
||||
};
|
||||
|
||||
/* static */ void UnaryOpSupportChecker::CreateSharedOpSupportChecker(
|
||||
|
|
@ -938,9 +983,18 @@ class UnaryOpSupportChecker : public BaseOpSupportChecker {
|
|||
"Sin",
|
||||
"Sqrt",
|
||||
"Tanh",
|
||||
"QLinearSigmoid",
|
||||
});
|
||||
}
|
||||
|
||||
bool UnaryOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const OpSupportCheckParams& params) const {
|
||||
if (node.OpType() == "QLinearSigmoid")
|
||||
return IsQuantizedOpSupported(initializers, node, params);
|
||||
else // Everything except "QLinearSigmoid" are by default supported
|
||||
return true;
|
||||
}
|
||||
|
||||
int32_t UnaryOpSupportChecker::GetMinSupportedSdkVer(
|
||||
const Node& node, const OpSupportCheckParams& /* params */) const {
|
||||
const auto& op(node.OpType());
|
||||
|
|
@ -956,6 +1010,86 @@ int32_t UnaryOpSupportChecker::GetMinSupportedSdkVer(
|
|||
return 27;
|
||||
}
|
||||
|
||||
bool UnaryOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
|
||||
// We only need to override input check for QLinearSigmoid
|
||||
if (node.OpType() != "QLinearSigmoid")
|
||||
return BaseOpSupportChecker::HasSupportedInputsImpl(node);
|
||||
|
||||
int32_t input_type;
|
||||
if (!GetType(*node.InputDefs()[0], input_type))
|
||||
return false;
|
||||
|
||||
if (input_type != ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
|
||||
LOGS_DEFAULT(VERBOSE) << "[" << node.OpType()
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// All ops except "Sin" opset 5- uses consumed_inputs attribute which is not supported for now
|
||||
// "Sin" op has support from opset 7, return 6 here for all ops
|
||||
// "QLinearSigmoid" is a contrib op, OpSet will always be 1
|
||||
int UnaryOpSupportChecker::GetMinSupportedOpSet(const Node& node) const {
|
||||
if (node.OpType() == "QLinearSigmoid")
|
||||
return 1;
|
||||
|
||||
return 6;
|
||||
}
|
||||
|
||||
/* static */ bool UnaryOpSupportChecker::IsQuantizedOpSupported(
|
||||
const InitializedTensorSet& initializers, const Node& node, const OpSupportCheckParams& params) {
|
||||
const auto& op_type = node.OpType();
|
||||
ORT_ENFORCE(op_type == "QLinearSigmoid");
|
||||
|
||||
const auto& op_name = node.Name();
|
||||
const auto input_defs(node.InputDefs());
|
||||
// const auto output_defs(node.OutputDefs());
|
||||
|
||||
if (input_defs.size() < 4)
|
||||
return false;
|
||||
|
||||
bool has_output_zp = input_defs.size() == 5;
|
||||
|
||||
if (!HasValidQuantizationScales(initializers, node, {1, 3}, params))
|
||||
return false;
|
||||
|
||||
if (!HasValidQuantizationZeroPoints(initializers, node,
|
||||
has_output_zp
|
||||
? std::vector<size_t>{2}
|
||||
: std::vector<size_t>{2, 4}))
|
||||
return false;
|
||||
|
||||
// NNAPI requires the scale be 1.f/256 and zero point to be 0
|
||||
// See https://android.googlesource.com/platform/frameworks/ml/+/refs/heads/android10-c2f2-release/nn/common/operations/Activation.cpp#180
|
||||
auto output_scale = GetQuantizationScale(initializers, node, 3);
|
||||
if (output_scale != 1.f / 256) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Op [" << op_type << "] name [" << op_name
|
||||
<< "] output scale can only be 1.f/256, actual scale: " << output_scale;
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t output_zp;
|
||||
if (has_output_zp) {
|
||||
auto status = GetQuantizationZeroPoint(initializers, node, 4, output_zp);
|
||||
if (!status.IsOK()) {
|
||||
LOGS_DEFAULT(ERROR) << "Op [" << op_type << "] name [" << op_name
|
||||
<< "] GetQuantizationZeroPoint failed, message: " << status.ErrorMessage();
|
||||
return false;
|
||||
}
|
||||
|
||||
if (output_zp != 0) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Op [" << op_type << "] name [" << op_name
|
||||
<< "] output zero point can only be 0, actual zero point: " << output_scale;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
#pragma endregion
|
||||
|
||||
#pragma region op_concat
|
||||
|
|
@ -964,6 +1098,8 @@ class ConcatOpSupportChecker : public BaseOpSupportChecker {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const OpSupportCheckParams& params) const override;
|
||||
|
||||
bool HasSupportedInputsImpl(const Node& node) const override;
|
||||
};
|
||||
|
||||
bool ConcatOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
|
|
@ -982,6 +1118,22 @@ bool ConcatOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* in
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ConcatOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
|
||||
int32_t input_type;
|
||||
if (!GetType(*node.InputDefs()[0], input_type))
|
||||
return false;
|
||||
|
||||
if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
|
||||
input_type != ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
|
||||
LOGS_DEFAULT(VERBOSE) << "[" << node.OpType()
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
#pragma endregion
|
||||
|
||||
#pragma region op_squeeze
|
||||
|
|
@ -1172,13 +1324,13 @@ class ResizeOpSupportChecker : public BaseOpSupportChecker {
|
|||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const OpSupportCheckParams& params) const override;
|
||||
|
||||
int32_t GetMinSupportedSdkVer(const Node& /* node */, const OpSupportCheckParams& /* params */) const override {
|
||||
return 28;
|
||||
}
|
||||
int32_t GetMinSupportedSdkVer(const Node& /* node */, const OpSupportCheckParams& /* params */) const override;
|
||||
|
||||
// Resize opset 10- is very different than Resize opset 11+, with many key attributes missing
|
||||
// We only support Resize opset 11+ here
|
||||
int GetMinSupportedOpSet(const Node& /* node */) const override { return 11; }
|
||||
|
||||
bool HasSupportedInputsImpl(const Node& node) const override;
|
||||
};
|
||||
|
||||
bool ResizeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
|
|
@ -1291,6 +1443,35 @@ bool ResizeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initi
|
|||
return true;
|
||||
}
|
||||
|
||||
int32_t ResizeOpSupportChecker::GetMinSupportedSdkVer(const Node& node, const OpSupportCheckParams& /* params */) const {
|
||||
int32_t input_type;
|
||||
|
||||
// This should not happen, but if it happens make sure this will require an impossible version
|
||||
if (!GetType(*node.InputDefs()[0], input_type))
|
||||
return std::numeric_limits<int32_t>::max();
|
||||
|
||||
if (input_type != ONNX_NAMESPACE::TensorProto_DataType_UINT8)
|
||||
return 29;
|
||||
|
||||
return 28;
|
||||
}
|
||||
|
||||
bool ResizeOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
|
||||
int32_t input_type;
|
||||
if (!GetType(*node.InputDefs()[0], input_type))
|
||||
return false;
|
||||
|
||||
if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
|
||||
input_type != ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
|
||||
LOGS_DEFAULT(VERBOSE) << "[" << node.OpType()
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
#pragma endregion
|
||||
|
||||
#pragma region op_flatten
|
||||
|
|
@ -1439,6 +1620,7 @@ static OpSupportCheckerRegistrations CreateOpSupportCheckerRegistrations() {
|
|||
NNAPI_EP_ADD_SHARED_OP_SUPPORT_CHECKER("Sin", UnaryOpSupportChecker);
|
||||
NNAPI_EP_ADD_SHARED_OP_SUPPORT_CHECKER("Sqrt", UnaryOpSupportChecker);
|
||||
NNAPI_EP_ADD_SHARED_OP_SUPPORT_CHECKER("Tanh", UnaryOpSupportChecker);
|
||||
NNAPI_EP_ADD_SHARED_OP_SUPPORT_CHECKER("QLinearSigmoid", UnaryOpSupportChecker);
|
||||
}
|
||||
|
||||
NNAPI_EP_ADD_SINGLE_OP_SUPPORT_CHECKER("Concat", ConcatOpSupportChecker);
|
||||
|
|
|
|||
|
|
@ -89,5 +89,31 @@ TEST(QLinearLookupTableBasedOperatorTests, QLinearSigmoid_UInt8) {
|
|||
std::fesetround(origin_round_mode);
|
||||
}
|
||||
|
||||
// NNAPI can only take 0 as Y_zero_point
|
||||
TEST(QLinearLookupTableBasedOperatorTests, QLinearSigmoid_UInt8_0_Y_ZP) {
|
||||
auto run_test = [](bool scales_and_zp_are_initializers) {
|
||||
OpTester test("QLinearSigmoid", 1, onnxruntime::kMSDomain);
|
||||
float X_scale = 0.025f;
|
||||
uint8_t X_zero_point = 128;
|
||||
float Y_scale = 1.0f / 256.0f;
|
||||
uint8_t Y_zero_point = 0;
|
||||
|
||||
std::vector<int64_t> dims = {16};
|
||||
test.AddInput<uint8_t>("X", dims, {0, 16, 17, 18, 19, 90, 91, 127, 128, 136, 137, 138, 216, 217, 218, 255});
|
||||
test.AddInput<float>("X_scale", {}, {X_scale}, scales_and_zp_are_initializers);
|
||||
test.AddInput<uint8_t>("X_zero_point", {}, {X_zero_point}, scales_and_zp_are_initializers);
|
||||
test.AddInput<float>("Y_scale", {}, {Y_scale}, scales_and_zp_are_initializers);
|
||||
test.AddInput<uint8_t>("Y_zero_point", {}, {Y_zero_point}, scales_and_zp_are_initializers);
|
||||
test.AddOutput<uint8_t>("Y", dims, {10, 15, 15, 15, 16, 71, 73, 126, 128, 141, 142, 144, 230, 231, 232, 246});
|
||||
auto origin_round_mode = std::fegetround();
|
||||
std::fesetround(FE_TONEAREST);
|
||||
test.Run();
|
||||
std::fesetround(origin_round_mode);
|
||||
};
|
||||
|
||||
run_test(false);
|
||||
run_test(true);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -72,6 +72,39 @@ TEST(NnapiExecutionProviderTest, ReshapeFlattenTest) {
|
|||
#endif
|
||||
}
|
||||
|
||||
// This is to test the uint8 handling of operators without "QLinear" such as Concat and Transpose
|
||||
// NNAPI will require scale and zero point for inputs of all quantized operations
|
||||
// For these operators without "Qlinear", there is no information about the scale and zero point, we can
|
||||
// only fetch these from the output of the previous node
|
||||
// So uint8 support of these operators will only be enabled when they are internal to the graph
|
||||
// by not consuming graph inputs
|
||||
TEST(NnapiExecutionProviderTest, InternalUint8SupportTest) {
|
||||
const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/nnapi_internal_uint8_support.onnx");
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
std::vector<int64_t> dims_x = {1, 3};
|
||||
std::vector<float> values_x = {0.0f, 256.0f, 512.0f};
|
||||
OrtValue ml_value_x;
|
||||
CreateMLValue<float>(TestNnapiExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_x, values_x,
|
||||
&ml_value_x);
|
||||
NameMLValMap feeds;
|
||||
feeds.insert(std::make_pair("X", ml_value_x));
|
||||
|
||||
RunAndVerifyOutputsWithEP(model_file_name, "NnapiExecutionProviderTest.InternalUint8SupportTest",
|
||||
onnxruntime::make_unique<NnapiExecutionProvider>(0),
|
||||
feeds);
|
||||
#else
|
||||
// test load only
|
||||
SessionOptions so;
|
||||
InferenceSessionWrapper session_object{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique<NnapiExecutionProvider>(0)));
|
||||
ASSERT_STATUS_OK(session_object.Load(model_file_name));
|
||||
ASSERT_STATUS_OK(session_object.Initialize());
|
||||
ASSERT_GT(CountAssignedNodes(session_object.GetGraph(), kNnapiExecutionProvider), 0)
|
||||
<< "Some nodes should have been taken by the NNAPI EP";
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
// This is to verify the op_builders and op_support_checkers are consistent
|
||||
TEST(NnapiExecutionProviderTest, CreateOpBuilderAndOpSupportCheckerTest) {
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/nnapi_internal_uint8_support.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/nnapi_internal_uint8_support.onnx
vendored
Normal file
Binary file not shown.
39
onnxruntime/test/testdata/nnapi_internal_uint8_support.py
vendored
Normal file
39
onnxruntime/test/testdata/nnapi_internal_uint8_support.py
vendored
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import TensorProto
|
||||
|
||||
|
||||
# This is to test the operators without "Qlinear" support but still support uint8 input
|
||||
# These operators need to be internal to a graph/partition
|
||||
# def GenerateModel(model_name):
|
||||
def GenerateModel(model_name):
|
||||
nodes = [
|
||||
helper.make_node("QuantizeLinear", ["X", "Scale", "Zero_point"], ["X_quantized"], "quantize"),
|
||||
helper.make_node("Concat", ["X_quantized", "X_quantized"], ["X_concat"], axis=0, name="concat"),
|
||||
helper.make_node("Transpose", ["X_concat"], ["X_transposed"], "transpose"),
|
||||
helper.make_node("DequantizeLinear", ["X_transposed", "Scale", "Zero_point"], ["Y"], "dequantize"),
|
||||
]
|
||||
|
||||
initializers = [
|
||||
helper.make_tensor('Scale', TensorProto.FLOAT, [1], [256.0]),
|
||||
helper.make_tensor('Zero_point', TensorProto.UINT8, [1], [0]),
|
||||
]
|
||||
|
||||
inputs = [
|
||||
helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 3]),
|
||||
]
|
||||
|
||||
graph = helper.make_graph(
|
||||
nodes,
|
||||
"NNAPI_Internal_uint8_Test",
|
||||
inputs,
|
||||
[helper.make_tensor_value_info('Y', TensorProto.FLOAT, [3, 2])],
|
||||
initializers
|
||||
)
|
||||
|
||||
model = helper.make_model(graph)
|
||||
onnx.save(model, model_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
GenerateModel('nnapi_internal_uint8_support.onnx')
|
||||
Loading…
Reference in a new issue