mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-26 03:00:54 +00:00
Enabled Cast operator on OneDNN EP (#11023)
This commit is contained in:
parent
6a6840d5c6
commit
6c005bfdbc
6 changed files with 236 additions and 0 deletions
|
|
@ -849,5 +849,75 @@ bool DnnlQAttentionNodeCapability::IsDimensionSupported(const Node* node) const
|
|||
return true;
|
||||
}
|
||||
|
||||
// DnnlCastNodeCapability class
|
||||
//-------------------------------------
|
||||
// Forwards the accepted types to the base class and keeps a second copy in
// validTypes_ so the cast target type can be checked against the same list.
DnnlCastNodeCapability::DnnlCastNodeCapability(std::vector<ORT_DataType> validTypes)
    : DnnlDefaultNodeCapability(validTypes), validTypes_(validTypes) {}
|
||||
|
||||
bool DnnlCastNodeCapability::Supported(const Node* node, const GraphViewer& graph_viewer) const {
|
||||
ORT_UNUSED_PARAMETER(graph_viewer);
|
||||
if (!IsTypeSupported(node) || !IsCastSupported(node)) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DnnlCastNodeCapability::IsCastSupported(const Node* node) const {
|
||||
// Get input and attributes
|
||||
const NodeAttributes& node_attr = node->GetAttributes();
|
||||
auto node_input = node->InputDefs();
|
||||
auto attr_to = node_attr.find("to");
|
||||
|
||||
// If we have valid results
|
||||
if (!node_input.empty() &&
|
||||
node_input[0]->TypeAsProto() != nullptr &&
|
||||
attr_to != node_attr.end()) {
|
||||
|
||||
// Get the input and cast target type
|
||||
auto input_type = node_input[0]->TypeAsProto()->tensor_type().elem_type();
|
||||
auto cast_type = attr_to->second().i();
|
||||
|
||||
// Some FP16 operations are not supported yet on CPU
|
||||
#if defined(DNNL_CPU_RUNTIME)
|
||||
// From uint8 and int8 => To Float16
|
||||
if ((input_type == type_uint8 ||
|
||||
input_type == type_int8) &&
|
||||
cast_type == type_float16) {
|
||||
return false;
|
||||
}
|
||||
// From uint32 => To Float16 and BFloat16
|
||||
if (input_type == type_int32 &&
|
||||
(cast_type == type_float16 ||
|
||||
cast_type == type_bfloat16)) {
|
||||
return false;
|
||||
}
|
||||
// From Float16 => To int(uint8, int8 and int32) and BFloat16
|
||||
if (input_type == type_float16 &&
|
||||
(cast_type == type_uint8 ||
|
||||
cast_type == type_int8 ||
|
||||
cast_type == type_int32 ||
|
||||
cast_type == type_bfloat16)) {
|
||||
return false;
|
||||
}
|
||||
// From BFloat16 => To int32 and BFloat16
|
||||
if (input_type == type_bfloat16 &&
|
||||
(cast_type == type_int32 ||
|
||||
cast_type == type_float16)) {
|
||||
return false;
|
||||
}
|
||||
#endif // defined(DNNL_CPU_RUNTIME)
|
||||
|
||||
// Check if the cast type is supported
|
||||
for (auto validType : validTypes_) {
|
||||
if (validType == cast_type) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -337,5 +337,26 @@ class DnnlQAttentionNodeCapability : public DnnlDefaultNodeCapability {
|
|||
private:
|
||||
bool IsDimensionSupported(const Node* node) const;
|
||||
};
|
||||
|
||||
/**
|
||||
* We need to access the valid input types in order to check if the cast target
|
||||
* is valid, and since inputTypes_ is private, we generate our own copy
|
||||
*/
|
||||
class DnnlCastNodeCapability : public DnnlDefaultNodeCapability {
|
||||
public:
|
||||
DnnlCastNodeCapability(std::vector<ORT_DataType> validTypes = {type_float32,
|
||||
type_float16,
|
||||
type_bfloat16,
|
||||
type_int32,
|
||||
type_int8,
|
||||
type_uint8});
|
||||
|
||||
bool Supported(const Node* node, const GraphViewer& graph_viewer) const override;
|
||||
|
||||
private:
|
||||
std::vector<ORT_DataType> validTypes_;
|
||||
bool IsCastSupported(const Node* node) const;
|
||||
};
|
||||
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ DnnlOpManager::DnnlOpManager() {
|
|||
dnnl_ops_map_.emplace(std::make_pair("AveragePool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("BatchNormalization", std::unique_ptr<DnnlNodeCapability>(new DnnlBatchNormalizationNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("BiasGelu", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Cast", std::unique_ptr<DnnlNodeCapability>(new DnnlCastNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Conv", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Div", std::unique_ptr<DnnlNodeCapability>(new DnnlBinaryNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("DynamicQuantizeLinear", std::unique_ptr<DnnlNodeCapability>(new DnnlDynamicQuantizeLinearNodeCapability())));
|
||||
|
|
|
|||
112
onnxruntime/core/providers/dnnl/subgraph/dnnl_cast.cc
Normal file
112
onnxruntime/core/providers/dnnl/subgraph/dnnl_cast.cc
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
// Copyright(C) 2021 Intel Corporation
|
||||
// Licensed under the MIT License
|
||||
|
||||
#include "dnnl_cast.h"
|
||||
#include "dnnl_subgraph.h"
|
||||
#include "dnnl_subgraph_primitive.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace ort_dnnl {
|
||||
|
||||
// Nothing to set up at construction time.
DnnlCast::DnnlCast() = default;
|
||||
|
||||
/**
 * Lowers a Cast node to a OneDNN reorder primitive.
 *
 * The destination memory descriptor reuses the source dimensions and differs
 * only in data type (and in layout when the source format tag is `any`); the
 * reorder between the two descriptors performs the type conversion.
 *
 * @param sp   Subgraph primitive that supplies the engine and tensor memory and
 *             receives the created primitive.
 * @param node Cast node; its "to" attribute selects the destination data type.
 *
 * Raises through ORT_THROW when the "to" attribute is missing or names a data
 * type that has no mapping in the switch below.
 */
void DnnlCast::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
  // Get the DNNL engine
  auto dnnl_engine = sp.GetEngine();

  // Get the memory from the input node
  auto src_mem = sp.GetMemory(node.Input(IN_INPUT));
  auto src_tag = node.Input(IN_INPUT).Format();
  auto src_md = src_mem.get_desc();
  auto src_dims = src_md.dims();

  // dst characteristics
  dnnl::memory::data_type dst_type = dnnl::memory::data_type::undef;
  dnnl::memory::format_tag dst_tag = dnnl::memory::format_tag::undef;

  // Get the target data type
  auto dst_type_desc = GetTo(node);

  // Map the ONNX target data type to the OneDNN data type
  switch (dst_type_desc) {
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
      dst_type = dnnl::memory::data_type::f32;
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: {
      dst_type = dnnl::memory::data_type::f16;
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: {
      dst_type = dnnl::memory::data_type::bf16;
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_INT32: {
      dst_type = dnnl::memory::data_type::s32;
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
      dst_type = dnnl::memory::data_type::s8;
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
      dst_type = dnnl::memory::data_type::u8;
      break;
    }
    default:
      ORT_THROW("Unsupported data type: ", dst_type_desc);
  }

  // Be aware that the output memory will be in plain format
  // and depending on the operation you do next, this won't be as
  // efficient as you'd like
  if (src_tag == dnnl::memory::format_tag::any) {
    // Define a plain data ND format
    dst_tag = sp.GetDnnlFormat(src_dims.size());
  } else {
    // Else use the same as the source
    dst_tag = src_tag;
  }

  // Generate the dst memory descriptor
  auto dst_md = dnnl::memory::desc(src_dims, dst_type, dst_tag);

  // Create the reorder primitive descriptor.
  auto reorder_pd = dnnl::reorder::primitive_desc(dnnl_engine, src_md, dnnl_engine, dst_md);
  // Get the dst memory
  auto dst_mem = dnnl::memory(reorder_pd.dst_desc(), dnnl_engine);

  // If using GPU this will move the memory from the CPU to the GPU.
  src_mem = sp.GetMemoryAndReshape(node.Input(IN_INPUT), reorder_pd.src_desc(), dnnl_engine);

  // OneDNN uses reorder to cast the src_md data to the dst_md data type
  auto reorder = dnnl::reorder(reorder_pd);

  // Add primitive to the graph
  sp.AddPrimitive(reorder, {{DNNL_ARG_SRC, src_mem},
                            {DNNL_ARG_DST, dst_mem}});

  // Support scalar return values. The check is made on the input tensor, so
  // it is indexed with the input enum (IN_INPUT), not the output one.
  if (sp.IsScalar(node.Input(IN_INPUT))) {
    sp.SetMemory(node.Output(OUT_OUTPUT), dst_mem, false, true);
  } else {
    sp.SetMemory(node.Output(OUT_OUTPUT), dst_mem);
  }
}
|
||||
|
||||
// Reads the integer value of the node's "to" attribute (the cast target data
// type). The attribute is required for a cast, so its absence raises through
// ORT_THROW.
int64_t DnnlCast::GetTo(DnnlNode& node) {
  const auto& attributes = node.Attributes();
  auto to_attr = attributes.find("to");
  if (to_attr == attributes.end()) {
    ORT_THROW("TO(CAST TARGET DATA TYPE) DOES NOT EXIST");
  }
  return to_attr->second().i();
}
|
||||
|
||||
} // namespace ort_dnnl
|
||||
} // namespace onnxruntime
|
||||
29
onnxruntime/core/providers/dnnl/subgraph/dnnl_cast.h
Normal file
29
onnxruntime/core/providers/dnnl/subgraph/dnnl_cast.h
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
// Copyright(C) 2021 Intel Corporation
|
||||
// Licensed under the MIT License
|
||||
|
||||
#pragma once
|
||||
#include "dnnl_subgraph.h"
|
||||
#include "dnnl_subgraph_primitive.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace ort_dnnl {
|
||||
|
||||
class DnnlCast {
|
||||
public:
|
||||
enum InputTensors : int {
|
||||
IN_INPUT = 0
|
||||
};
|
||||
|
||||
enum OutputTensors : int {
|
||||
OUT_OUTPUT = 0
|
||||
};
|
||||
|
||||
DnnlCast();
|
||||
void CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node);
|
||||
|
||||
private:
|
||||
int64_t GetTo(DnnlNode& node);
|
||||
};
|
||||
|
||||
} // namespace ort_dnnl
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
#include "dnnl_batchnorm.h"
|
||||
#include "dnnl_binary.h"
|
||||
#include "dnnl_cast.h"
|
||||
#include "dnnl_conv.h"
|
||||
#include "dnnl_dynamicquantizelinear.h"
|
||||
#include "dnnl_elementwise.h"
|
||||
|
|
@ -135,6 +136,8 @@ void DnnlSubgraphPrimitive::AddKernels() {
|
|||
DnnlBatchNorm().CreatePrimitive(*this, node);
|
||||
} else if (binary_ops.count(node.OpType())) {
|
||||
DnnlBinary().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "Cast") {
|
||||
DnnlCast().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "Conv" || node.OpType() == "ConvRelu") {
|
||||
DnnlConv().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "DynamicQuantizeLinear") {
|
||||
|
|
|
|||
Loading…
Reference in a new issue