diff --git a/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc b/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc
index fab1dfdfa3..287c5f4148 100644
--- a/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc
+++ b/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc
@@ -79,6 +79,69 @@ bool DnnlDefaultMultiInputNodeCapability::IsTypeSupported(const Node* node) cons
   return all_inputs_supported;
 }
 
+// DnnlDefaultOptionalMultiInputNodeCapability
+//-------------------------------------
+DnnlDefaultOptionalMultiInputNodeCapability::DnnlDefaultOptionalMultiInputNodeCapability(rule_map op_rules) {
+  // Store every rule and count how many of them describe a mandatory input
+  for (const auto& rule : op_rules) {
+    op_rules_[rule.first] = rule.second;
+
+    // rule.second.first is the "mandatory" flag of the rule
+    if (rule.second.first) {
+      ++num_mandatory;
+    }
+  }
+}
+
+bool DnnlDefaultOptionalMultiInputNodeCapability::Supported(const Node* node, const GraphViewer& graph_viewer) const {
+  // The decision depends only on the datatypes of the node inputs
+  ORT_UNUSED_PARAMETER(graph_viewer);
+  return IsTypeSupported(node);
+}
+
+unsigned int DnnlDefaultOptionalMultiInputNodeCapability::GetNumMandatoryInputs() const {
+  return num_mandatory;
+}
+
+bool DnnlDefaultOptionalMultiInputNodeCapability::IsTypeSupported(const Node* node) const {
+  // Get the node inputs and how many of them the node declares
+  auto node_inputs = node->InputDefs();
+  auto num_inputs = node_inputs.size();
+
+  // The node must declare at least the mandatory inputs
+  if (num_inputs < num_mandatory) {
+    return false;
+  }
+
+  // Check every rule whose input position is present on the node
+  for (const auto& rule : op_rules_) {
+    if (rule.first >= num_inputs) {
+      continue;
+    }
+
+    const bool is_mandatory = rule.second.first;
+    const auto& supported_types = rule.second.second;
+    const auto* type_proto = node_inputs[rule.first]->TypeAsProto();
+
+    // Without type information an omitted optional input is skipped, but a
+    // mandatory input cannot be validated so the node is rejected
+    if (type_proto == nullptr) {
+      if (is_mandatory) {
+        return false;
+      }
+      continue;
+    }
+
+    // Reject the node if the input datatype is NOT on the supported list
+    auto node_datatype = static_cast<ORT_DataType>(type_proto->tensor_type().elem_type());
+    if (supported_types.find(node_datatype) == supported_types.end()) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
 // DnnlPoolNodeCapability class
 //-------------------------------------
 bool DnnlPoolNodeCapability::Supported(const Node* node, const GraphViewer& graph_viewer) const {
@@ -920,4 +983,9 @@ bool DnnlCastNodeCapability::IsCastSupported(const Node* node) const {
   return false;
 }
 
+bool DnnlDequantizeLinearNodeCapability::Supported(const Node* node, const GraphViewer& graph_viewer) const {
+  ORT_UNUSED_PARAMETER(graph_viewer);
+  return IsTypeSupported(node);
+}
+
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/dnnl/dnnl_node_capability.h b/onnxruntime/core/providers/dnnl/dnnl_node_capability.h
index 36adcb0d34..bb3916f12c 100644
--- a/onnxruntime/core/providers/dnnl/dnnl_node_capability.h
+++ b/onnxruntime/core/providers/dnnl/dnnl_node_capability.h
@@ -112,6 +112,45 @@ class DnnlDefaultMultiInputNodeCapability : public DnnlNodeCapability {
  private:
   std::vector<std::unordered_set<ORT_DataType>> inputTypes_;
 };
+/*
+ * Works similar to the `DnnlDefaultMultiInputNodeCapability` class except that this
+ * will check the datatype of every input and supports optional inputs with different
+ * supported datatypes by using the input position to evaluate if the node is supported.
+ *
+ * Example usage:
+ * rule_map rules = {
+ *   {0, {true, {type_uint8, type_int8, type_int32}}},
+ *   {1, {true, {type_float32}}},
+ *   {2, {false, {type_uint8, type_int8, type_int32}}},
+ * };
+ * DnnlDefaultOptionalMultiInputNodeCapability(rules)
+ *
+ * In general each rule consists of:
+ * (INPUT_POS, {IS_MANDATORY, {list, of, supported, types}})
+ *
+ * We always assume that the input 0 of the map corresponds to the node input with index 0,
+ * so if a node has M mandatory inputs and O optional, the map should contain the first
+ * M mandatory inputs followed by the other O optional.
+ *
+ * The evaluation will be done in order of appearance, so if we have K number of rules
+ * with K = M + O, and the evaluated node has M+1 inputs, only the first optional rule (O[0])
+ * will be evaluated.
+ */
+class DnnlDefaultOptionalMultiInputNodeCapability : public DnnlNodeCapability {
+ public:
+  using rule_map = std::map<size_t, std::pair<bool, std::unordered_set<ORT_DataType>>>;
+  explicit DnnlDefaultOptionalMultiInputNodeCapability(rule_map op_rules);
+
+  bool Supported(const Node* node, const GraphViewer& graph_viewer) const override;
+
+ protected:
+  unsigned int num_mandatory = 0;
+  unsigned int GetNumMandatoryInputs() const;
+  bool IsTypeSupported(const Node* node) const;
+
+ private:
+  rule_map op_rules_;
+};
 
 /**
  * Decide if a Pool op is supported by DnnlExecutionProvider
@@ -358,5 +397,27 @@ class DnnlCastNodeCapability : public DnnlDefaultNodeCapability {
   bool IsCastSupported(const Node* node) const;
 };
 
+/**
+ * Decide if a DequantizeLinear op is supported by DnnlExecutionProvider
+ */
+class DnnlDequantizeLinearNodeCapability : public DnnlDefaultOptionalMultiInputNodeCapability {
+ public:
+  enum InputTensors : int {
+    IN_X = 0,
+    IN_X_SCALE = 1,
+    IN_X_ZERO_POINT = 2,  // Optional
+  };
+  DnnlDequantizeLinearNodeCapability()
+      // ONNX spec requires x and zp to support int32 but
+      // OneDNN doesn't support it on GPU
+      : DnnlDefaultOptionalMultiInputNodeCapability({
+            {IN_X, {true, {type_uint8, type_int8}}},
+            {IN_X_SCALE, {true, {type_float32}}},
+            {IN_X_ZERO_POINT, {false, {type_uint8, type_int8}}},
+        }) {}
+
+  bool Supported(const Node* node, const GraphViewer& graph_viewer) const override;
+};
+
 }  // namespace onnxruntime
 
diff --git a/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc b/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc
index 3ee03d3047..0302febf96 100644
--- a/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc
+++ b/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc
@@ -13,6 +13,7 @@ DnnlOpManager::DnnlOpManager() {
   dnnl_ops_map_.emplace(std::make_pair("BiasGelu", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
   dnnl_ops_map_.emplace(std::make_pair("Cast", std::unique_ptr<DnnlNodeCapability>(new DnnlCastNodeCapability())));
   dnnl_ops_map_.emplace(std::make_pair("Conv", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
+  dnnl_ops_map_.emplace(std::make_pair("DequantizeLinear", std::unique_ptr<DnnlNodeCapability>(new DnnlDequantizeLinearNodeCapability())));
   dnnl_ops_map_.emplace(std::make_pair("Div", std::unique_ptr<DnnlNodeCapability>(new DnnlBinaryNodeCapability())));
   dnnl_ops_map_.emplace(std::make_pair("DynamicQuantizeLinear", std::unique_ptr<DnnlNodeCapability>(new DnnlDynamicQuantizeLinearNodeCapability())));
   dnnl_ops_map_.emplace(std::make_pair("Elu", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_dequantizelinear.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_dequantizelinear.cc
new file mode 100644
index 0000000000..cbd92f4a21
--- /dev/null
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_dequantizelinear.cc
@@ -0,0 +1,184 @@
+// Copyright(C) 2022 Intel Corporation
+// Licensed under the MIT License
+
+#include "dnnl_dequantizelinear.h"
+#include "dnnl_subgraph.h"
+#include "dnnl_subgraph_primitive.h"
+
+namespace onnxruntime {
+namespace ort_dnnl {
+
+/*
+  y = (x - x_zero_point) * x_scale.
+  'x_scale' and 'x_zero_point' must have same shape, and can be either a scalar
+  for per-tensor or per layer quantization, or a 1-D tensor for per-axis quantization.
+  'x_zero_point' and 'x' must have same type. 'x' and 'y' must have same shape.
+  In the case of dequantizing int32, there's no zero point (zero point is supposed to be 0).
+
+  The op is built as one binary primitive: `x - x_zero_point` with a binary
+  multiplication post-op for the scale, or just `x * x_scale` when no zero
+  point has to be applied.
+*/
+void DnnlDequantizeLinear::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
+  // Get engine
+  auto dnnl_engine = sp.GetEngine();
+
+  // Validate dims and datatypes
+  ValidateDims(sp, node);
+  ValidateType(sp, node);
+
+  // Check if scale is a scalar (zp, when present, was validated to have the same dims)
+  bool isScalar = sp.IsScalar(node.Input(IN_X_SCALE));
+
+  // Get the x and scale mem
+  auto x_mem = sp.GetMemory(node.Input(IN_X));
+  auto x_scale_mem = sp.GetMemory(node.Input(IN_X_SCALE));
+  // Move to GPU if available
+  x_mem = sp.GetMemoryAndReshape(node.Input(IN_X), x_mem.get_desc(), dnnl_engine);
+  x_scale_mem = sp.GetMemoryAndReshape(node.Input(IN_X_SCALE), x_scale_mem.get_desc(), dnnl_engine);
+  // Get descs
+  auto x_md = x_mem.get_desc();
+  auto x_scale_md = x_scale_mem.get_desc();
+  auto x_dims = x_md.dims().size();
+
+  // Fix scale dims
+  int64_t axis = GetAxis(node, x_dims);
+  // Check if axis is negative and fix it
+  if (axis < 0) {
+    axis += x_dims;
+  }
+  // The binary primitive requires both operands to have the same number of
+  // dimensions, so the scale is always padded with 1s up to the rank of x
+  if (isScalar) {
+    // For scalar scale
+    Padd(&x_scale_md, x_dims, x_dims);
+  } else {
+    // For 1-D scale, keep its values on `axis`
+    Padd(&x_scale_md, static_cast<size_t>(axis) + 1, x_dims);
+  }
+
+  // Create dst mem
+  auto dst_md = dnnl::memory::desc(x_md.dims(), node.Output(OUT_Y).Type(), dnnl::memory::format_tag::any);
+  dnnl::memory dst_mem;
+
+  // If zero point exists and we are NOT dequantizing int32, then subtract zp from x and scale
+  if (node.Input(IN_X_ZERO_POINT).Exists() &&
+      (x_md.data_type() != dnnl::memory::data_type::s32)) {
+    // Get Zero point
+    auto x_zp_mem = sp.GetMemory(node.Input(IN_X_ZERO_POINT));
+    // Get mds for operands
+    auto x_zp_md = x_zp_mem.get_desc();
+
+    // Prepare the zp to prevent broadcasting errors
+    if (isScalar) {
+      // For scalar zp
+      Padd(&x_zp_md, x_dims, x_dims);
+    } else {
+      // For 1-D zp, keep its values on `axis`
+      Padd(&x_zp_md, static_cast<size_t>(axis) + 1, x_dims);
+    }
+
+    // Create binary desc
+    auto binary_d = dnnl::binary::desc(dnnl::algorithm::binary_sub, x_md, x_zp_md, dst_md);
+    // Add post op scale
+    dnnl::post_ops binary_ops;
+    dnnl::primitive_attr binary_attr;
+    binary_ops.append_binary(dnnl::algorithm::binary_mul, x_scale_md);
+    binary_attr.set_post_ops(binary_ops);
+    // Add post op to scale result
+    auto binary_pd = dnnl::binary::primitive_desc(binary_d, binary_attr, dnnl_engine);
+    // Move to GPU if available
+    x_zp_mem = sp.GetMemoryAndReshape(node.Input(IN_X_ZERO_POINT), x_zp_md, dnnl_engine);
+    // Create primitive and set dst mem
+    dst_mem = dnnl::memory(binary_pd.dst_desc(), dnnl_engine);
+    auto binary_prim = dnnl::binary(binary_pd);
+
+    sp.AddPrimitive(binary_prim, {{DNNL_ARG_SRC_0, x_mem},
+                                  {DNNL_ARG_SRC_1, x_zp_mem},
+                                  {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, x_scale_mem},
+                                  {DNNL_ARG_DST, dst_mem}});
+
+    // If zp doesn't exist or we are dequantizing from int32, only need to scale
+  } else {
+    // Create binary and primitive desc
+    auto binary_d = dnnl::binary::desc(dnnl::algorithm::binary_mul, x_md, x_scale_md, dst_md);
+    auto binary_pd = dnnl::binary::primitive_desc(binary_d, dnnl_engine);
+
+    // Create primitive
+    dst_mem = dnnl::memory(binary_pd.dst_desc(), dnnl_engine);
+    auto binary_prim = dnnl::binary(binary_pd);
+
+    // We recycle the x_mem
+    sp.AddPrimitive(binary_prim, {{DNNL_ARG_SRC_0, x_mem},
+                                  {DNNL_ARG_SRC_1, x_scale_mem},
+                                  {DNNL_ARG_DST, dst_mem}});
+  }
+
+  // Set the output mem
+  if (sp.IsScalar(node.Input(IN_X))) {
+    sp.SetMemory(node.Output(OUT_Y), dst_mem, false, true);
+  } else {
+    sp.SetMemory(node.Output(OUT_Y), dst_mem);
+  }
+}
+
+void DnnlDequantizeLinear::Padd(dnnl::memory::desc* target_md, size_t front_pad, size_t back_pad) {
+  // Pads an input to broadcast the op correctly: 1s are prepended until the
+  // rank reaches front_pad and then appended until it reaches back_pad
+  auto target_dims = target_md->dims();
+
+  // Add front padding
+  while (target_dims.size() < front_pad) {
+    target_dims.insert(target_dims.begin(), 1);
+  }
+  // Add back padding
+  while (target_dims.size() < back_pad) {
+    target_dims.insert(target_dims.end(), 1);
+  }
+
+  *target_md = target_md->reshape(target_dims);
+}
+
+int64_t DnnlDequantizeLinear::GetAxis(DnnlNode& node, size_t x_dims) {
+  // We need to do sign comparisons so we have to cast
+  const int64_t sig_x_dims = static_cast<int64_t>(x_dims);
+  auto attr = node.Attributes().find("axis");
+  // Use the provided axis only if it is an integer
+  // within the range of [-r, r - 1]
+  if (attr != node.Attributes().end() &&
+      attr->second().type() == ::ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INT) {
+    const int64_t axis = attr->second().i();
+    if (axis >= -sig_x_dims && axis < sig_x_dims) {
+      return axis;
+    }
+  }
+  // Return the default value
+  return 1;
+}
+
+void DnnlDequantizeLinear::ValidateDims(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
+  // We only need to validate when zp is provided
+  if (node.Input(IN_X_ZERO_POINT).Exists()) {
+    auto x_scale_dims = sp.GetMemory(node.Input(IN_X_SCALE)).get_desc().dims();
+    auto x_zp_dims = sp.GetMemory(node.Input(IN_X_ZERO_POINT)).get_desc().dims();
+
+    if (x_zp_dims != x_scale_dims) {
+      ORT_THROW("x_scale and x_zero_point dimensions does not match");
+    }
+  }
+}
+
+void DnnlDequantizeLinear::ValidateType(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
+  // If zp exists check its datatype
+  if (node.Input(IN_X_ZERO_POINT).Exists()) {
+    auto x_md = sp.GetMemory(node.Input(IN_X)).get_desc();
+    auto x_zp_md = sp.GetMemory(node.Input(IN_X_ZERO_POINT)).get_desc();
+
+    if (x_md.data_type() != x_zp_md.data_type()) {
+      ORT_THROW("x and x_zero_point have different datatypes");
+    }
+  }
+}
+
+}  // namespace ort_dnnl
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_dequantizelinear.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_dequantizelinear.h
new file mode 100644
index 0000000000..3609e23214
--- /dev/null
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_dequantizelinear.h
@@ -0,0 +1,40 @@
+// Copyright(C) 2022 Intel Corporation
+// Licensed under the MIT License
+
+#pragma once
+#include "dnnl_subgraph.h"
+#include "dnnl_subgraph_primitive.h"
+
+namespace onnxruntime {
+namespace ort_dnnl {
+
+// Builds the oneDNN primitives for the DequantizeLinear op:
+// y = (x - x_zero_point) * x_scale
+class DnnlDequantizeLinear {
+ public:
+  enum InputTensors : int {
+    IN_X = 0,
+    IN_X_SCALE = 1,
+    IN_X_ZERO_POINT = 2,  // Optional
+  };
+
+  enum OutputTensors : int {
+    OUT_Y = 0,
+  };
+
+  DnnlDequantizeLinear() = default;
+  void CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node);
+
+ private:
+  // Returns the "axis" attribute when it is valid for a tensor of rank x_dims, 1 otherwise
+  int64_t GetAxis(DnnlNode& node, size_t x_dims);
+  // Reshapes target with leading 1s up to rank front_pad, then trailing 1s up to rank back_pad
+  void Padd(dnnl::memory::desc* target, size_t front_pad, size_t back_pad);
+  // Throws if x_zero_point is provided and its dims differ from x_scale
+  void ValidateDims(DnnlSubgraphPrimitive& sp, DnnlNode& node);
+  // Throws if x_zero_point is provided and its datatype differs from x
+  void ValidateType(DnnlSubgraphPrimitive& sp, DnnlNode& node);
+};
+
+}  // namespace ort_dnnl
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc
index 00ec3c8f89..2ef3c08af4 100644
--- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc
+++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc
@@ -7,6 +7,7 @@
 #include "dnnl_binary.h"
 #include "dnnl_cast.h"
 #include "dnnl_conv.h"
+#include "dnnl_dequantizelinear.h"
 #include "dnnl_dynamicquantizelinear.h"
 #include "dnnl_elementwise.h"
 #include "dnnl_gelu.h"
@@ -140,6 +141,8 @@ void DnnlSubgraphPrimitive::AddKernels() {
       DnnlCast().CreatePrimitive(*this, node);
     } else if (node.OpType() == "Conv" || node.OpType() == "ConvRelu") {
       DnnlConv().CreatePrimitive(*this, node);
+    } else if (node.OpType() == "DequantizeLinear") {
+      DnnlDequantizeLinear().CreatePrimitive(*this, node);
     } else if (node.OpType() == "DynamicQuantizeLinear") {
       DnnlDynamicQuantizeLinear().CreatePrimitive(*this, node);
     } else if (elementwise_ops.count(node.OpType())) {