mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
Enabled fp16 for input types (#5878)
Signed-off-by: MaajidKhan <n.maajidkhan@gmail.com> Co-authored-by: S. Manohar Karlapalem <manohar.karlapalem@intel.com>
This commit is contained in:
parent
1068f3eb87
commit
b057b3d36e
1 changed files with 29 additions and 5 deletions
|
|
@ -492,7 +492,7 @@ static bool IsUnsupportedOpMode(const Node* node, const onnxruntime::GraphViewer
|
|||
return false;
|
||||
} else if(optype == "Gather") {
|
||||
|
||||
if(device_id == "GPU"){
|
||||
if(device_id.find("GPU") != std::string::npos){
|
||||
const auto& input = node->InputDefs()[0];
|
||||
auto graph_inputs = graph_viewer.GetInputs();
|
||||
auto it = find(graph_inputs.begin(), graph_inputs.end(), input);
|
||||
|
|
@ -544,6 +544,7 @@ static bool IsTypeSupported(const NodeArg* node_arg, bool is_initializer, const
|
|||
switch (type_proto->tensor_type().elem_type()) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32:
|
||||
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64:
|
||||
return true;
|
||||
|
|
@ -557,6 +558,17 @@ static bool IsTypeSupported(const NodeArg* node_arg, bool is_initializer, const
|
|||
return false;
|
||||
}
|
||||
} else {
|
||||
std::set<int> supported_types_vpu = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL,
|
||||
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16,
|
||||
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32,
|
||||
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16,
|
||||
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8,
|
||||
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8,
|
||||
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64,
|
||||
};
|
||||
|
||||
std::set<int> supported_types_cpu = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL,
|
||||
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT,
|
||||
|
|
@ -569,12 +581,24 @@ static bool IsTypeSupported(const NodeArg* node_arg, bool is_initializer, const
|
|||
|
||||
std::set<int> supported_types_gpu = {
|
||||
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16,
|
||||
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32,
|
||||
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64,
|
||||
};
|
||||
auto dtype = type_proto->tensor_type().elem_type();
|
||||
|
||||
if (device_id == "CPU" || device_id == "MYRIAD" || device_id == "HDDL" || device_id.find("HETERO") != std::string::npos) {
|
||||
if (device_id == "MYRIAD" || device_id == "HDDL" || device_id.find("HETERO") != std::string::npos || device_id.find("MULTI") != std::string::npos) {
|
||||
if (supported_types_vpu.find(dtype) != supported_types_vpu.end())
|
||||
return true;
|
||||
else {
|
||||
#ifndef NDEBUG
|
||||
if (openvino_ep::backend_utils::IsDebugEnabled()) {
|
||||
std::cout << "I/O data type is not supported" << std::endl;
|
||||
}
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
} else if (device_id == "CPU") {
|
||||
if (supported_types_cpu.find(dtype) != supported_types_cpu.end())
|
||||
return true;
|
||||
else {
|
||||
|
|
@ -805,7 +829,7 @@ GetCapability_2021_1(const onnxruntime::GraphViewer& graph_viewer, std::string d
|
|||
return result;
|
||||
} else if ((node->OpType() == "Greater") || (node->OpType() == "Less")) {
|
||||
|
||||
if (device_id == "MYRIAD") {
|
||||
if (device_id.find("MYRIAD") != std::string::npos) {
|
||||
|
||||
auto input_0_data_type = (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
|
||||
auto input_1_data_type = (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[1]->TypeAsProto()->tensor_type().elem_type();
|
||||
|
|
@ -852,7 +876,7 @@ GetCapability_2021_1(const onnxruntime::GraphViewer& graph_viewer, std::string d
|
|||
modified_unsupported_nodes.push_back(node_idx);
|
||||
}
|
||||
if(optype == "Gather"){
|
||||
if(device_id == "MYRIAD"){
|
||||
if(device_id.find("MYRIAD") != std::string::npos){
|
||||
auto input_data_type = node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
|
||||
if(input_data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8){
|
||||
modified_unsupported_nodes.push_back(node_idx);
|
||||
|
|
@ -895,7 +919,7 @@ GetCapability_2021_1(const onnxruntime::GraphViewer& graph_viewer, std::string d
|
|||
optype == "Cast" || optype == "Concat" || optype == "Gather" ||
|
||||
optype == "Div" || optype == "Sub" || optype == "Identity") {
|
||||
|
||||
if(optype == "Identity" && device_id != "CPU")
|
||||
if(optype == "Identity" && device_id.find("CPU") == std::string::npos)
|
||||
continue;
|
||||
|
||||
if((optype == "Div" || optype == "Sub") && (device_id.find("MYRIAD") == std::string::npos && device_id.find("GPU") == std::string::npos))
|
||||
|
|
|
|||
Loading…
Reference in a new issue