Enabled Cast operator on OneDNN EP (#11023)

This commit is contained in:
Erick Alejandro Muñoz Alvarado 2022-03-29 09:16:01 -06:00 committed by GitHub
parent 6a6840d5c6
commit 6c005bfdbc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 236 additions and 0 deletions

View file

@ -849,5 +849,75 @@ bool DnnlQAttentionNodeCapability::IsDimensionSupported(const Node* node) const
return true;
}
// DnnlCastNodeCapability class
//-------------------------------------
// Forwards the list of valid data types to DnnlDefaultNodeCapability and keeps
// a copy of the same list in validTypes_ so IsCastSupported() can validate the
// cast target type against it.
DnnlCastNodeCapability::DnnlCastNodeCapability(std::vector<ORT_DataType> validTypes)
    : DnnlDefaultNodeCapability(validTypes),
      validTypes_(validTypes) {}
// A Cast node is supported only when the generic type check passes and the
// input-type/target-type combination is accepted by IsCastSupported().
// The graph viewer is not consulted.
bool DnnlCastNodeCapability::Supported(const Node* node, const GraphViewer& graph_viewer) const {
  ORT_UNUSED_PARAMETER(graph_viewer);
  return IsTypeSupported(node) && IsCastSupported(node);
}
bool DnnlCastNodeCapability::IsCastSupported(const Node* node) const {
// Get input and attributes
const NodeAttributes& node_attr = node->GetAttributes();
auto node_input = node->InputDefs();
auto attr_to = node_attr.find("to");
// If we have valid results
if (!node_input.empty() &&
node_input[0]->TypeAsProto() != nullptr &&
attr_to != node_attr.end()) {
// Get the input and cast target type
auto input_type = node_input[0]->TypeAsProto()->tensor_type().elem_type();
auto cast_type = attr_to->second().i();
// Some FP16 operations are not supported yet on CPU
#if defined(DNNL_CPU_RUNTIME)
// From uint8 and int8 => To Float16
if ((input_type == type_uint8 ||
input_type == type_int8) &&
cast_type == type_float16) {
return false;
}
// From uint32 => To Float16 and BFloat16
if (input_type == type_int32 &&
(cast_type == type_float16 ||
cast_type == type_bfloat16)) {
return false;
}
// From Float16 => To int(uint8, int8 and int32) and BFloat16
if (input_type == type_float16 &&
(cast_type == type_uint8 ||
cast_type == type_int8 ||
cast_type == type_int32 ||
cast_type == type_bfloat16)) {
return false;
}
// From BFloat16 => To int32 and BFloat16
if (input_type == type_bfloat16 &&
(cast_type == type_int32 ||
cast_type == type_float16)) {
return false;
}
#endif // defined(DNNL_CPU_RUNTIME)
// Check if the cast type is supported
for (auto validType : validTypes_) {
if (validType == cast_type) {
return true;
}
}
}
return false;
}
} // namespace onnxruntime

View file

@ -337,5 +337,26 @@ class DnnlQAttentionNodeCapability : public DnnlDefaultNodeCapability {
private:
bool IsDimensionSupported(const Node* node) const;
};
/**
 * Capability check for the Cast operator.
 *
 * Besides the generic type check, the cast target (the "to" attribute) has to
 * be validated against the list of valid data types. The base class keeps its
 * list (inputTypes_) private, so this class stores its own copy in validTypes_.
 */
class DnnlCastNodeCapability : public DnnlDefaultNodeCapability {
 public:
  // validTypes is forwarded to DnnlDefaultNodeCapability and is also the set
  // of accepted cast target types.
  DnnlCastNodeCapability(std::vector<ORT_DataType> validTypes = {type_float32, type_float16, type_bfloat16,
                                                                 type_int32, type_int8, type_uint8});

  // True when both the generic type check and the cast-specific check pass.
  bool Supported(const Node* node, const GraphViewer& graph_viewer) const override;

 private:
  // Validates the input-type/target-type combination of a Cast node.
  bool IsCastSupported(const Node* node) const;

  // Private copy of the valid types, used to validate the cast target.
  std::vector<ORT_DataType> validTypes_;
};
} // namespace onnxruntime

View file

@ -11,6 +11,7 @@ DnnlOpManager::DnnlOpManager() {
dnnl_ops_map_.emplace(std::make_pair("AveragePool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("BatchNormalization", std::unique_ptr<DnnlNodeCapability>(new DnnlBatchNormalizationNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("BiasGelu", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Cast", std::unique_ptr<DnnlNodeCapability>(new DnnlCastNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Conv", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Div", std::unique_ptr<DnnlNodeCapability>(new DnnlBinaryNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("DynamicQuantizeLinear", std::unique_ptr<DnnlNodeCapability>(new DnnlDynamicQuantizeLinearNodeCapability())));

View file

@ -0,0 +1,112 @@
// Copyright(C) 2021 Intel Corporation
// Licensed under the MIT License
#include "dnnl_cast.h"
#include "dnnl_subgraph.h"
#include "dnnl_subgraph_primitive.h"
namespace onnxruntime {
namespace ort_dnnl {
// DnnlCast needs no construction-time setup.
DnnlCast::DnnlCast() = default;
// Builds the OneDNN primitive that implements Cast for `node` and registers it
// with the subgraph primitive `sp`.
//
// The cast is expressed as a OneDNN reorder from the source memory descriptor
// to a destination descriptor with the same dims but the data type requested
// by the node's "to" attribute.
//
// Throws (via ORT_THROW) when the requested target data type is not one of
// FLOAT, FLOAT16, BFLOAT16, INT32, INT8 or UINT8, or (from GetTo) when the
// "to" attribute is missing.
void DnnlCast::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
  // Get the DNNL engine
  auto dnnl_engine = sp.GetEngine();

  // Get the memory from the input node
  auto src_mem = sp.GetMemory(node.Input(IN_INPUT));
  auto src_tag = node.Input(IN_INPUT).Format();
  auto src_md = src_mem.get_desc();
  auto src_dims = src_md.dims();

  // dst characteristics; both are always overwritten below before being used
  dnnl::memory::data_type dst_type = dnnl::memory::data_type::undef;
  dnnl::memory::format_tag dst_tag = dnnl::memory::format_tag::undef;

  // Get the target data type
  auto dst_type_desc = GetTo(node);

  // Map the ONNX target data type to the OneDNN data type
  switch (dst_type_desc) {
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
      dst_type = dnnl::memory::data_type::f32;
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: {
      dst_type = dnnl::memory::data_type::f16;
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: {
      dst_type = dnnl::memory::data_type::bf16;
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_INT32: {
      dst_type = dnnl::memory::data_type::s32;
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
      dst_type = dnnl::memory::data_type::s8;
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
      dst_type = dnnl::memory::data_type::u8;
      break;
    }
    default:
      ORT_THROW("Unsupported data type: ", dst_type_desc);
      break;
  }

  // Be aware that the output memory will be in plain format
  // and depending on the operation you do next, this won't be as
  // efficient as you'd like
  if (src_tag == dnnl::memory::format_tag::any) {
    // The source has no fixed format: define a plain ND format
    dst_tag = sp.GetDnnlFormat(src_dims.size());
  } else {
    // Else use the same format as the source
    dst_tag = src_tag;
  }

  // Generate the dst memory descriptor
  auto dst_md = dnnl::memory::desc(src_dims, dst_type, dst_tag);

  // Create the reorder primitive descriptor.
  auto reorder_pd = dnnl::reorder::primitive_desc(dnnl_engine, src_md, dnnl_engine, dst_md);

  // Get the dst memory
  auto dst_mem = dnnl::memory(reorder_pd.dst_desc(), dnnl_engine);

  // If using GPU this will move the memory from the CPU to the GPU.
  src_mem = sp.GetMemoryAndReshape(node.Input(IN_INPUT), reorder_pd.src_desc(), dnnl_engine);

  // OneDNN uses reorder to cast the src_md data to the dst_md data type
  auto reorder = dnnl::reorder(reorder_pd);

  // Add primitive to the graph
  sp.AddPrimitive(reorder, {{DNNL_ARG_SRC, src_mem},
                            {DNNL_ARG_DST, dst_mem}});

  // Support scalar return values. The check is made on the input tensor, so
  // it is indexed with the input enum (IN_INPUT) rather than the output enum.
  if (sp.IsScalar(node.Input(IN_INPUT))) {
    sp.SetMemory(node.Output(OUT_OUTPUT), dst_mem, false, true);
  } else {
    sp.SetMemory(node.Output(OUT_OUTPUT), dst_mem);
  }
}
// Returns the integer value of the node's "to" attribute (the cast target
// data type). Throws via ORT_THROW when the attribute is absent.
int64_t DnnlCast::GetTo(DnnlNode& node) {
  auto to_attr = node.Attributes().find("to");
  // The "to" attribute must always exist in order to cast
  if (to_attr == node.Attributes().end()) {
    ORT_THROW("TO(CAST TARGET DATA TYPE) DOES NOT EXIST");
  }
  return to_attr->second().i();
}
} // namespace ort_dnnl
} // namespace onnxruntime

View file

@ -0,0 +1,29 @@
// Copyright(C) 2021 Intel Corporation
// Licensed under the MIT License
#pragma once
#include "dnnl_subgraph.h"
#include "dnnl_subgraph_primitive.h"
namespace onnxruntime {
namespace ort_dnnl {
// Creates the OneDNN primitive used for the Cast operator inside a
// DnnlSubgraphPrimitive.
class DnnlCast {
 public:
  // Position of the (only) input tensor of the node.
  enum InputTensors : int { IN_INPUT = 0 };

  // Position of the (only) output tensor of the node.
  enum OutputTensors : int { OUT_OUTPUT = 0 };

  DnnlCast();

  // Builds the cast primitive for `node` and adds it to `sp`.
  void CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node);

 private:
  // Reads the "to" attribute (cast target data type) of `node`.
  int64_t GetTo(DnnlNode& node);
};
} // namespace ort_dnnl
} // namespace onnxruntime

View file

@ -5,6 +5,7 @@
#include "dnnl_batchnorm.h"
#include "dnnl_binary.h"
#include "dnnl_cast.h"
#include "dnnl_conv.h"
#include "dnnl_dynamicquantizelinear.h"
#include "dnnl_elementwise.h"
@ -135,6 +136,8 @@ void DnnlSubgraphPrimitive::AddKernels() {
DnnlBatchNorm().CreatePrimitive(*this, node);
} else if (binary_ops.count(node.OpType())) {
DnnlBinary().CreatePrimitive(*this, node);
} else if (node.OpType() == "Cast") {
DnnlCast().CreatePrimitive(*this, node);
} else if (node.OpType() == "Conv" || node.OpType() == "ConvRelu") {
DnnlConv().CreatePrimitive(*this, node);
} else if (node.OpType() == "DynamicQuantizeLinear") {