Add Dequantize Linear operator on OneDNN EP (#11036)

This commit is contained in:
Erick Muñoz 2022-04-05 09:32:26 -06:00 committed by GitHub
parent 8db180c245
commit 25fdf8b167
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 334 additions and 0 deletions

View file

@ -79,6 +79,63 @@ bool DnnlDefaultMultiInputNodeCapability::IsTypeSupported(const Node* node) cons
return all_inputs_supported;
}
// DnnlDefaultOptionalMultiInputNodeCapability
//-------------------------------------
// Stores the per-input rules and counts how many of them are flagged as mandatory.
DnnlDefaultOptionalMultiInputNodeCapability::DnnlDefaultOptionalMultiInputNodeCapability(rule_map op_rules) {
  for (const auto& entry : op_rules) {
    const auto input_pos = entry.first;
    const auto& input_rule = entry.second;
    // Keep a copy of the rule, keyed by the input position
    op_rules_[input_pos] = input_rule;
    // The first element of the rule is the "mandatory" flag
    num_mandatory += input_rule.first ? 1 : 0;
  }
}
// The node is supported when the data types of its inputs pass IsTypeSupported();
// the graph viewer is not consulted.
bool DnnlDefaultOptionalMultiInputNodeCapability::Supported(const Node* node, const GraphViewer& graph_viewer) const {
  ORT_UNUSED_PARAMETER(graph_viewer);
  return IsTypeSupported(node);
}
// Returns the current value of the mandatory-input counter.
unsigned int DnnlDefaultOptionalMultiInputNodeCapability::GetNumMandatoryInputs() {
  const unsigned int mandatory_count = num_mandatory;
  return mandatory_count;
}
bool DnnlDefaultOptionalMultiInputNodeCapability::IsTypeSupported(const Node* node) const {
// Get the node list its size
auto node_inputs = node->InputDefs();
auto num_nodes = node_inputs.size();
// We need to make sure that we have at least the mandatory inputs
if (num_mandatory <= num_nodes) {
// Iterate over each entry to check if the input is available, optional and supported
for (const auto& rule : op_rules_) {
if (rule.first < num_nodes) {
// Get the target node
auto node_input = node_inputs[rule.first];
// If we found the node we want and it is valid
if (node_input->TypeAsProto() != nullptr) {
// Get its datatype
ORT_DataType node_datatype = static_cast<ORT_DataType>(node_input->TypeAsProto()->tensor_type().elem_type());
// If the datatype is NOT on the supported list, we dont support the op
if (rule.second.second.find(node_datatype) == rule.second.second.end()) {
return false;
}
}
}
}
// If the node doesn't have the minimum mandatory inputs
} else {
return false;
}
return true;
}
// DnnlPoolNodeCapability class
//-------------------------------------
bool DnnlPoolNodeCapability::Supported(const Node* node, const GraphViewer& graph_viewer) const {
@ -920,4 +977,10 @@ bool DnnlCastNodeCapability::IsCastSupported(const Node* node) const {
return false;
}
// DequantizeLinear is supported when the input data types pass IsTypeSupported();
// the graph viewer is not consulted.
bool DnnlDequantizeLinearNodeCapability::Supported(const Node* node, const GraphViewer& graph_viewer) const {
  ORT_UNUSED_PARAMETER(graph_viewer);
  return IsTypeSupported(node);
}
} // namespace onnxruntime

View file

@ -112,6 +112,45 @@ class DnnlDefaultMultiInputNodeCapability : public DnnlNodeCapability {
private:
std::vector<std::unordered_set<ORT_DataType>> inputTypes_;
};
/*
* Works similar to the `DnnlDefaultMultiInputNodeCapability` class except that this
* will check the input of all input nodes and supports optional inputs with different
* supported datatypes by using the input position to evaluate if the node is supported.
*
* Example usage:
* rule_map rules = {
* {0, {true, {type_uint8, type_int8, type_int32}}},
* {1, {true, {type_float32}}},
* {2, {false, {type_uint8, type_int8, type_int32}}},
* };
* DnnlDefaultOptionalMultiInputNodeCapability(rules)
*
* In general node_data consists of a tuple that has:
* (INPUT_POS, <IsThisInputMandatory?(true or false)>, {list, of, supported, types})
*
* We always assume that the input 0 of the map corresponds to the node input with index 0,
* so if a node has M mandatory inputs and O optional, the map should contain the first
* M mandatory inputs followed by the other O optional.
*
* The evaluation will be done in order of appearance, so if we have K number of rules
* with K = M + O, and the evaluated node has M+1 inputs, only the first optional rule (O[0])
* will be evaluated.
*/
class DnnlDefaultOptionalMultiInputNodeCapability : public DnnlNodeCapability {
public:
typedef std::map<size_t, std::pair<bool, std::unordered_set<ORT_DataType>>> rule_map;
DnnlDefaultOptionalMultiInputNodeCapability(rule_map op_rules);
bool Supported(const Node* node, const GraphViewer& graph_viewer) const override;
protected:
unsigned int num_mandatory = 0;
unsigned int GetNumMandatoryInputs();
bool IsTypeSupported(const Node* node) const;
private:
rule_map op_rules_;
};
/**
* Decide if a Pool op is supported by DnnlExecutionProvider
@ -358,5 +397,24 @@ class DnnlCastNodeCapability : public DnnlDefaultNodeCapability {
bool IsCastSupported(const Node* node) const;
};
// Capability rules for DequantizeLinear: x and the optional zero point accept
// uint8/int8, the scale accepts float32.
class DnnlDequantizeLinearNodeCapability : public DnnlDefaultOptionalMultiInputNodeCapability {
 public:
  // Positions of the node inputs
  enum InputTensors : int {
    IN_X = 0,
    IN_X_SCALE = 1,
    IN_X_ZERO_POINT = 2,  // Optional
  };

  // ONNX spec requires x and zp to support int32 but
  // OneDNN doesn't support it on GPU
  DnnlDequantizeLinearNodeCapability()
      : DnnlDefaultOptionalMultiInputNodeCapability({
            {IN_X, {true, {type_uint8, type_int8}}},
            {IN_X_SCALE, {true, {type_float32}}},
            {IN_X_ZERO_POINT, {false, {type_uint8, type_int8}}},
        }) {}

  bool Supported(const Node* node, const GraphViewer& graph_viewer) const override;
};
} // namespace onnxruntime

View file

@ -13,6 +13,7 @@ DnnlOpManager::DnnlOpManager() {
dnnl_ops_map_.emplace(std::make_pair("BiasGelu", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Cast", std::unique_ptr<DnnlNodeCapability>(new DnnlCastNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Conv", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("DequantizeLinear", std::unique_ptr<DnnlNodeCapability>(new DnnlDequantizeLinearNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Div", std::unique_ptr<DnnlNodeCapability>(new DnnlBinaryNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("DynamicQuantizeLinear", std::unique_ptr<DnnlNodeCapability>(new DnnlDynamicQuantizeLinearNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Elu", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));

View file

@ -0,0 +1,175 @@
// Copyright(C) 2022 Intel Corporation
// Licensed under the MIT License
#include "dnnl_dequantizelinear.h"
#include "dnnl_subgraph.h"
#include "dnnl_subgraph_primitive.h"
namespace onnxruntime {
namespace ort_dnnl {
/*
  y = (x - x_zero_point) * x_scale.
  'x_scale' and 'x_zero_point' must have same shape, and can be either a scalar
  for per-tensor or per layer quantization, or a 1-D tensor for per-axis quantization.
  'x_zero_point' and 'x' must have same type. 'x' and 'y' must have same shape.
  In the case of dequantizing int32, there's no zero point (zero point is supposed to be 0).

  Builds the primitive for one DequantizeLinear node and registers it on the subgraph:
   - with a zero point (and x not int32): binary_sub(x, zp) with a binary_mul(scale) post-op
   - otherwise: binary_mul(x, scale)
  ValidateDims/ValidateType throw when scale/zp shapes or x/zp data types do not match.
*/
void DnnlDequantizeLinear::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
  // Get engine
  auto dnnl_engine = sp.GetEngine();

  // Validate dims and datatypes
  ValidateDims(sp, node);
  ValidateType(sp, node);

  // Check if scale and zp are scalars
  bool isScalar = sp.IsScalar(node.Input(IN_X_SCALE));

  // Get the x and scale mem
  auto x_mem = sp.GetMemory(node.Input(IN_X));
  auto x_scale_mem = sp.GetMemory(node.Input(IN_X_SCALE));
  // Move to GPU if available
  x_mem = sp.GetMemoryAndReshape(node.Input(IN_X), x_mem.get_desc(), dnnl_engine);
  x_scale_mem = sp.GetMemoryAndReshape(node.Input(IN_X_SCALE), x_scale_mem.get_desc(), dnnl_engine);

  // Get descs
  auto x_md = x_mem.get_desc();
  auto x_scale_md = x_scale_mem.get_desc();
  auto x_dims = x_md.dims().size();

  // Fix scale dims
  int64_t axis = GetAxis(node, x_dims);

  // Check if axis is negative and fix it
  if (axis < 0) {
    axis += x_dims;
  }

  // Prepare the scale to prevent broadcasting errors. The scale descriptor is brought
  // to the rank of x in both cases, the same way it is done for the zero point below,
  // so that every operand of the binary primitive has the same number of dimensions.
  if (isScalar) {
    // For scalar scale: only front padding up to the rank of x
    Padd(&x_scale_md, x_dims, 0);
  } else {
    // For 1-D scale: place the vector on `axis` and pad the remaining dims with 1
    Padd(&x_scale_md, static_cast<uint64_t>(axis + 1), x_dims);
  }

  // Create dst mem
  auto dst_md = dnnl::memory::desc(x_md.dims(), node.Output(OUT_Y).Type(), dnnl::memory::format_tag::any);
  dnnl::memory dst_mem;

  // If zero point exists and we are NOT dequantizing int32, then subtract zp from x and scale
  if (node.Input(IN_X_ZERO_POINT).Exists() &&
      (x_mem.get_desc().data_type() != dnnl::memory::data_type::s32)) {
    // Get Zero point
    auto x_zp_mem = sp.GetMemory(node.Input(IN_X_ZERO_POINT));
    // Get mds for operands
    auto x_zp_md = x_zp_mem.get_desc();

    // Prepare the zp to prevent broadcasting errors
    if (isScalar) {
      // For scalar zp: only front padding up to the rank of x
      Padd(&x_zp_md, x_dims, 0);
    } else {
      // For N-D zp
      Padd(&x_zp_md, static_cast<uint64_t>(axis) + 1, x_md.dims().size());
    }

    // Create binary desc
    auto binary_d = dnnl::binary::desc(dnnl::algorithm::binary_sub, x_md, x_zp_md, dst_md);

    // Add post op scale
    dnnl::post_ops binary_ops;
    dnnl::primitive_attr binary_attr;
    binary_ops.append_binary(dnnl::algorithm::binary_mul, x_scale_md);
    binary_attr.set_post_ops(binary_ops);

    // Add post op to scale result
    auto binary_pd = dnnl::binary::primitive_desc(binary_d, binary_attr, dnnl_engine);

    // Move to GPU if available
    x_zp_mem = sp.GetMemoryAndReshape(node.Input(IN_X_ZERO_POINT), x_zp_md, dnnl_engine);

    // Create primitive and set dst mem
    dst_mem = dnnl::memory(binary_pd.dst_desc(), dnnl_engine);
    auto binary_prim = dnnl::binary(binary_pd);

    sp.AddPrimitive(binary_prim, {{DNNL_ARG_SRC_0, x_mem},
                                  {DNNL_ARG_SRC_1, x_zp_mem},
                                  {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, x_scale_mem},
                                  {DNNL_ARG_DST, dst_mem}});

    // If zp doesn't exist or we are dequantizing from int32, only need to scale
  } else {
    // Create binary and primitive desc
    auto binary_d = dnnl::binary::desc(dnnl::algorithm::binary_mul, x_md, x_scale_md, dst_md);
    auto binary_pd = dnnl::binary::primitive_desc(binary_d, dnnl_engine);

    // Create primitive
    dst_mem = dnnl::memory(binary_pd.dst_desc(), dnnl_engine);
    auto binary_prim = dnnl::binary(binary_pd);

    // We recycle the x_mem
    sp.AddPrimitive(binary_prim, {{DNNL_ARG_SRC_0, x_mem},
                                  {DNNL_ARG_SRC_1, x_scale_mem},
                                  {DNNL_ARG_DST, dst_mem}});
  }

  // Set the output mem (a scalar x uses the SetMemory overload with extra flags)
  if (sp.IsScalar(node.Input(IN_X))) {
    sp.SetMemory(node.Output(OUT_Y), dst_mem, false, true);
  } else {
    sp.SetMemory(node.Output(OUT_Y), dst_mem);
  }
}
// Pads the dims of `target_md` with 1s so the op broadcasts correctly.
// First 1s are prepended until the rank reaches `front_pad`, then 1s are appended
// until the rank reaches `back_pad`. A pad value not larger than the current rank
// adds nothing. The descriptor is replaced by its reshaped version.
void DnnlDequantizeLinear::Padd(dnnl::memory::desc* target_md, size_t front_pad, size_t back_pad) {
  auto padded_dims = target_md->dims();

  // Front padding
  if (padded_dims.size() < front_pad) {
    padded_dims.insert(padded_dims.begin(), front_pad - padded_dims.size(), 1);
  }

  // Back padding
  if (padded_dims.size() < back_pad) {
    padded_dims.resize(back_pad, 1);
  }

  *target_md = target_md->reshape(padded_dims);
}
// Reads the "axis" attribute of the node.
// Returns the attribute value when it is present, is an integer and lies within
// the accepted range for a tensor of rank `x_dims`; returns the default value 1
// in every other case. A negative value is returned as is (not normalized).
int64_t DnnlDequantizeLinear::GetAxis(DnnlNode& node, size_t x_dims) {
  // Signed copy of the rank, needed for the sign comparisons below
  const int64_t rank = static_cast<int64_t>(x_dims);

  auto attr = node.Attributes().find("axis");
  // Attribute not provided: use the default
  if (attr == node.Attributes().end()) {
    return 1;
  }

  const int64_t requested = attr->second().i();
  const bool is_int =
      attr->second().type() == ::ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INT;
  const bool in_negative_range = (requested <= 0) && (requested >= -rank);
  const bool in_positive_range = (requested >= 0) && (requested <= (rank - 1));

  if (is_int && (in_negative_range || in_positive_range)) {
    return requested;
  }

  // Wrong attribute type or value out of range: use the default
  return 1;
}
// Throws when a zero point is provided and its dims differ from the scale dims.
// Nothing is checked when the node has no zero point.
void DnnlDequantizeLinear::ValidateDims(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
  if (!node.Input(IN_X_ZERO_POINT).Exists()) {
    return;
  }

  auto scale_shape = sp.GetMemory(node.Input(IN_X_SCALE)).get_desc().dims();
  auto zp_shape = sp.GetMemory(node.Input(IN_X_ZERO_POINT)).get_desc().dims();

  if (zp_shape != scale_shape) {
    ORT_THROW("x_scale and x_zero_point dimensions does not match");
  }
}
// Throws when a zero point is provided and its data type differs from the one of x.
// Nothing is checked when the node has no zero point.
void DnnlDequantizeLinear::ValidateType(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
  if (!node.Input(IN_X_ZERO_POINT).Exists()) {
    return;
  }

  auto input_desc = sp.GetMemory(node.Input(IN_X)).get_desc();
  auto zp_desc = sp.GetMemory(node.Input(IN_X_ZERO_POINT)).get_desc();

  if (input_desc.data_type() != zp_desc.data_type()) {
    ORT_THROW("x and x_zero_point have different datatypes");
  }
}
} // namespace ort_dnnl
} // namespace onnxruntime

View file

@ -0,0 +1,34 @@
// Copyright(C) 2022 Intel Corporation
// Licensed under the MIT License
#pragma once
#include "dnnl_subgraph.h"
#include "dnnl_subgraph_primitive.h"
namespace onnxruntime {
namespace ort_dnnl {
// Builds the OneDNN primitive for a DequantizeLinear node.
class DnnlDequantizeLinear {
 public:
  // Positions of the node inputs
  enum InputTensors : int {
    IN_X = 0,
    IN_X_SCALE = 1,
    IN_X_ZERO_POINT = 2,  // Optional
  };

  // Positions of the node outputs
  enum OutputTensors : int {
    OUT_Y = 0,
  };

  DnnlDequantizeLinear() = default;

  // Creates the primitive for `node` and adds it to the subgraph primitive `sp`
  void CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node);

 private:
  // Returns the "axis" attribute of the node, or the default value when it is not usable
  int64_t GetAxis(DnnlNode& node, size_t x_dims);
  // Pads the dims of `target` with 1s at the front and at the back
  void Padd(dnnl::memory::desc* target, size_t front_pad, size_t back_pad);
  // Throws when scale and zero point dims do not match
  void ValidateDims(DnnlSubgraphPrimitive& sp, DnnlNode& node);
  // Throws when x and zero point data types do not match
  void ValidateType(DnnlSubgraphPrimitive& sp, DnnlNode& node);
};
} // namespace ort_dnnl
} // namespace onnxruntime

View file

@ -7,6 +7,7 @@
#include "dnnl_binary.h"
#include "dnnl_cast.h"
#include "dnnl_conv.h"
#include "dnnl_dequantizelinear.h"
#include "dnnl_dynamicquantizelinear.h"
#include "dnnl_elementwise.h"
#include "dnnl_gelu.h"
@ -140,6 +141,8 @@ void DnnlSubgraphPrimitive::AddKernels() {
DnnlCast().CreatePrimitive(*this, node);
} else if (node.OpType() == "Conv" || node.OpType() == "ConvRelu") {
DnnlConv().CreatePrimitive(*this, node);
} else if (node.OpType() == "DequantizeLinear") {
DnnlDequantizeLinear().CreatePrimitive(*this, node);
} else if (node.OpType() == "DynamicQuantizeLinear") {
DnnlDynamicQuantizeLinear().CreatePrimitive(*this, node);
} else if (elementwise_ops.count(node.OpType())) {