From 6c005bfdbc28a04c521c364bc6b1df0f26e6acbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erick=20Alejandro=20Mu=C3=B1oz=20Alvarado?= Date: Tue, 29 Mar 2022 09:16:01 -0600 Subject: [PATCH] Enabled Cast operator on OneDNN EP (#11023) --- .../providers/dnnl/dnnl_node_capability.cc | 70 +++++++++++ .../providers/dnnl/dnnl_node_capability.h | 21 ++++ .../core/providers/dnnl/dnnl_op_manager.cc | 1 + .../core/providers/dnnl/subgraph/dnnl_cast.cc | 112 ++++++++++++++++++ .../core/providers/dnnl/subgraph/dnnl_cast.h | 29 +++++ .../dnnl/subgraph/dnnl_subgraph_primitive.cc | 3 + 6 files changed, 236 insertions(+) create mode 100644 onnxruntime/core/providers/dnnl/subgraph/dnnl_cast.cc create mode 100644 onnxruntime/core/providers/dnnl/subgraph/dnnl_cast.h diff --git a/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc b/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc index a2e68774cd..fab1dfdfa3 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc @@ -849,5 +849,75 @@ bool DnnlQAttentionNodeCapability::IsDimensionSupported(const Node* node) const return true; } +// DnnlCastNodeCapability class +//------------------------------------- +DnnlCastNodeCapability::DnnlCastNodeCapability(std::vector validTypes) + : DnnlDefaultNodeCapability(validTypes) { + for (ORT_DataType datatype : validTypes) + validTypes_.push_back(datatype); +} + +bool DnnlCastNodeCapability::Supported(const Node* node, const GraphViewer& graph_viewer) const { + ORT_UNUSED_PARAMETER(graph_viewer); + if (!IsTypeSupported(node) || !IsCastSupported(node)) return false; + return true; +} + +bool DnnlCastNodeCapability::IsCastSupported(const Node* node) const { + // Get input and attributes + const NodeAttributes& node_attr = node->GetAttributes(); + auto node_input = node->InputDefs(); + auto attr_to = node_attr.find("to"); + + // If we have valid results + if (!node_input.empty() && + node_input[0]->TypeAsProto() != nullptr 
&& + attr_to != node_attr.end()) { + + // Get the input and cast target type + auto input_type = node_input[0]->TypeAsProto()->tensor_type().elem_type(); + auto cast_type = attr_to->second().i(); + + // Some FP16 operations are not supported yet on CPU +#if defined(DNNL_CPU_RUNTIME) + // From uint8 and int8 => To Float16 + if ((input_type == type_uint8 || + input_type == type_int8) && + cast_type == type_float16) { + return false; + } + // From int32 => To Float16 and BFloat16 + if (input_type == type_int32 && + (cast_type == type_float16 || + cast_type == type_bfloat16)) { + return false; + } + // From Float16 => To int(uint8, int8 and int32) and BFloat16 + if (input_type == type_float16 && + (cast_type == type_uint8 || + cast_type == type_int8 || + cast_type == type_int32 || + cast_type == type_bfloat16)) { + return false; + } + // From BFloat16 => To int32 and Float16 + if (input_type == type_bfloat16 && + (cast_type == type_int32 || + cast_type == type_float16)) { + return false; + } +#endif // defined(DNNL_CPU_RUNTIME) + + // Check if the cast type is supported + for (auto validType : validTypes_) { + if (validType == cast_type) { + return true; + } + } + + } + + return false; +} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/dnnl_node_capability.h b/onnxruntime/core/providers/dnnl/dnnl_node_capability.h index 1605475d70..36adcb0d34 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_node_capability.h +++ b/onnxruntime/core/providers/dnnl/dnnl_node_capability.h @@ -337,5 +337,26 @@ class DnnlQAttentionNodeCapability : public DnnlDefaultNodeCapability { private: bool IsDimensionSupported(const Node* node) const; }; + +/** + * We need to access the valid input types in order to check if the cast target + * is valid, and since inputTypes_ is private, we generate our own copy + */ +class DnnlCastNodeCapability : public DnnlDefaultNodeCapability { + public: + DnnlCastNodeCapability(std::vector validTypes = {type_float32, + type_float16, 
type_bfloat16, + type_int32, + type_int8, + type_uint8}); + + bool Supported(const Node* node, const GraphViewer& graph_viewer) const override; + + private: + std::vector validTypes_; + bool IsCastSupported(const Node* node) const; +}; + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc b/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc index 43e3dc1508..39fb8401f9 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_op_manager.cc @@ -11,6 +11,7 @@ DnnlOpManager::DnnlOpManager() { dnnl_ops_map_.emplace(std::make_pair("AveragePool", std::unique_ptr(new DnnlPoolNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("BatchNormalization", std::unique_ptr(new DnnlBatchNormalizationNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("BiasGelu", std::unique_ptr(new DnnlDefaultNodeCapability()))); + dnnl_ops_map_.emplace(std::make_pair("Cast", std::unique_ptr(new DnnlCastNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("Conv", std::unique_ptr(new DnnlDefaultNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("Div", std::unique_ptr(new DnnlBinaryNodeCapability()))); dnnl_ops_map_.emplace(std::make_pair("DynamicQuantizeLinear", std::unique_ptr(new DnnlDynamicQuantizeLinearNodeCapability()))); diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_cast.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_cast.cc new file mode 100644 index 0000000000..1a21d290e2 --- /dev/null +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_cast.cc @@ -0,0 +1,112 @@ +// Copyright(C) 2021 Intel Corporation +// Licensed under the MIT License + +#include "dnnl_cast.h" +#include "dnnl_subgraph.h" +#include "dnnl_subgraph_primitive.h" + +namespace onnxruntime { +namespace ort_dnnl { + +DnnlCast::DnnlCast() {} + +void DnnlCast::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { + + // Get the DNNL engine + auto dnnl_engine = sp.GetEngine(); + + // Get the memory 
from the input node + auto src_mem = sp.GetMemory(node.Input(IN_INPUT)); + auto src_tag = node.Input(IN_INPUT).Format(); + auto src_md = src_mem.get_desc(); + auto src_dims = src_md.dims(); + + // dst characteristics + dnnl::memory::data_type dst_type; + dnnl::memory::format_tag dst_tag; + + // Get the target data type + auto dst_type_desc = GetTo(node); + + // Check for the target data type + switch (dst_type_desc) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { + dst_type = dnnl::memory::data_type::f32; + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { + dst_type = dnnl::memory::data_type::f16; + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: { + dst_type = dnnl::memory::data_type::bf16; + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT32: { + dst_type = dnnl::memory::data_type::s32; + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT8: { + dst_type = dnnl::memory::data_type::s8; + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { + dst_type = dnnl::memory::data_type::u8; + break; + } + default: + ORT_THROW("Unsupported data type: ", dst_type_desc); + break; + } + // Be aware that the output memory will be in plain format + // and depending on the operation you do next, this won't be as + // efficient as you'd like + // If the format tag is any + if (src_tag == dnnl::memory::format_tag::any) { + // Define a plain data ND format + dst_tag = sp.GetDnnlFormat(src_dims.size()); + } else { + // Else use the same as the source + dst_tag = src_tag; + } + + // Generate the dst memory descriptor + auto dst_md = dnnl::memory::desc(src_md.dims(), dst_type, dst_tag); + + // Create the reorder primitive descriptor. + auto reorder_pd = dnnl::reorder::primitive_desc(dnnl_engine, src_md, dnnl_engine, dst_md); + // Get the dst memory + auto dst_mem = dnnl::memory(reorder_pd.dst_desc(), dnnl_engine); + + // If using GPU this will move the memory from the CPU to the GPU. 
+ src_mem = sp.GetMemoryAndReshape(node.Input(IN_INPUT), reorder_pd.src_desc(), dnnl_engine); + + // OneDNN uses reorder to cast the src_md data to the dst_md data type + auto reorder = dnnl::reorder(reorder_pd); + + // Add primitive to the graph + sp.AddPrimitive(reorder, {{DNNL_ARG_SRC, src_mem}, + {DNNL_ARG_DST, dst_mem}}); + + // Support scalar return values + if (sp.IsScalar(node.Input(OUT_OUTPUT))) { + sp.SetMemory(node.Output(OUT_OUTPUT), dst_mem, false, true); + } else { + sp.SetMemory(node.Output(OUT_OUTPUT), dst_mem); + } + +} + +int64_t DnnlCast::GetTo(DnnlNode& node) { + // Get the attribute + auto attr = node.Attributes().find("to"); + if (attr != node.Attributes().end()) { + return attr->second().i(); + } else { + // to attribute should always exist in order to cast + ORT_THROW("TO(CAST TARGET DATA TYPE) DOES NOT EXIST"); + } +} + +} // namespace ort_dnnl +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_cast.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_cast.h new file mode 100644 index 0000000000..8c3f804973 --- /dev/null +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_cast.h @@ -0,0 +1,29 @@ +// Copyright(C) 2021 Intel Corporation +// Licensed under the MIT License + +#pragma once +#include "dnnl_subgraph.h" +#include "dnnl_subgraph_primitive.h" + +namespace onnxruntime { +namespace ort_dnnl { + +class DnnlCast { + public: + enum InputTensors : int { + IN_INPUT = 0 + }; + + enum OutputTensors : int { + OUT_OUTPUT = 0 + }; + + DnnlCast(); + void CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node); + + private: + int64_t GetTo(DnnlNode& node); +}; + +} // namespace ort_dnnl +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc index 8cd2fe14cb..dbf0104f07 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc +++ 
b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc @@ -5,6 +5,7 @@ #include "dnnl_batchnorm.h" #include "dnnl_binary.h" +#include "dnnl_cast.h" #include "dnnl_conv.h" #include "dnnl_dynamicquantizelinear.h" #include "dnnl_elementwise.h" @@ -135,6 +136,8 @@ void DnnlSubgraphPrimitive::AddKernels() { DnnlBatchNorm().CreatePrimitive(*this, node); } else if (binary_ops.count(node.OpType())) { DnnlBinary().CreatePrimitive(*this, node); + } else if (node.OpType() == "Cast") { + DnnlCast().CreatePrimitive(*this, node); } else if (node.OpType() == "Conv" || node.OpType() == "ConvRelu") { DnnlConv().CreatePrimitive(*this, node); } else if (node.OpType() == "DynamicQuantizeLinear") {