Enabled fp16 for input types (#5878)

Signed-off-by: MaajidKhan <n.maajidkhan@gmail.com>

Co-authored-by: S. Manohar Karlapalem <manohar.karlapalem@intel.com>
This commit is contained in:
Maajid khan 2020-11-20 13:49:49 +05:30 committed by GitHub
parent 1068f3eb87
commit b057b3d36e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -492,7 +492,7 @@ static bool IsUnsupportedOpMode(const Node* node, const onnxruntime::GraphViewer
return false;
} else if(optype == "Gather") {
if(device_id == "GPU"){
if(device_id.find("GPU") != std::string::npos){
const auto& input = node->InputDefs()[0];
auto graph_inputs = graph_viewer.GetInputs();
auto it = find(graph_inputs.begin(), graph_inputs.end(), input);
@ -544,6 +544,7 @@ static bool IsTypeSupported(const NodeArg* node_arg, bool is_initializer, const
switch (type_proto->tensor_type().elem_type()) {
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64:
return true;
@ -557,6 +558,17 @@ static bool IsTypeSupported(const NodeArg* node_arg, bool is_initializer, const
return false;
}
} else {
std::set<int> supported_types_vpu = {
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL,
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16,
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32,
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16,
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8,
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8,
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64,
};
std::set<int> supported_types_cpu = {
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL,
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT,
@ -569,12 +581,24 @@ static bool IsTypeSupported(const NodeArg* node_arg, bool is_initializer, const
std::set<int> supported_types_gpu = {
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16,
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32,
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64,
};
auto dtype = type_proto->tensor_type().elem_type();
if (device_id == "CPU" || device_id == "MYRIAD" || device_id == "HDDL" || device_id.find("HETERO") != std::string::npos) {
if (device_id == "MYRIAD" || device_id == "HDDL" || device_id.find("HETERO") != std::string::npos || device_id.find("MULTI") != std::string::npos) {
if (supported_types_vpu.find(dtype) != supported_types_vpu.end())
return true;
else {
#ifndef NDEBUG
if (openvino_ep::backend_utils::IsDebugEnabled()) {
std::cout << "I/O data type is not supported" << std::endl;
}
#endif
return false;
}
} else if (device_id == "CPU") {
if (supported_types_cpu.find(dtype) != supported_types_cpu.end())
return true;
else {
@ -805,7 +829,7 @@ GetCapability_2021_1(const onnxruntime::GraphViewer& graph_viewer, std::string d
return result;
} else if ((node->OpType() == "Greater") || (node->OpType() == "Less")) {
if (device_id == "MYRIAD") {
if (device_id.find("MYRIAD") != std::string::npos) {
auto input_0_data_type = (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
auto input_1_data_type = (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[1]->TypeAsProto()->tensor_type().elem_type();
@ -852,7 +876,7 @@ GetCapability_2021_1(const onnxruntime::GraphViewer& graph_viewer, std::string d
modified_unsupported_nodes.push_back(node_idx);
}
if(optype == "Gather"){
if(device_id == "MYRIAD"){
if(device_id.find("MYRIAD") != std::string::npos){
auto input_data_type = node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
if(input_data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8){
modified_unsupported_nodes.push_back(node_idx);
@ -895,7 +919,7 @@ GetCapability_2021_1(const onnxruntime::GraphViewer& graph_viewer, std::string d
optype == "Cast" || optype == "Concat" || optype == "Gather" ||
optype == "Div" || optype == "Sub" || optype == "Identity") {
if(optype == "Identity" && device_id != "CPU")
if(optype == "Identity" && device_id.find("CPU") == std::string::npos)
continue;
if((optype == "Div" || optype == "Sub") && (device_id.find("MYRIAD") == std::string::npos && device_id.find("GPU") == std::string::npos))