[VitisAI] Align TensorProto_DataType with onnx1.16 (#21067)

### Description
Vitis AI EP synchronously supports the TensorProto data types supported
by ONNX 1.16.
Add error message show when graph resolve fail for troubleshooting.


### Motivation and Context
ONNX 1.15 & 1.16 add support some new TensorProto DataType , such as 
- FLOAT8E4M3FN
- FLOAT8E4M3FNUZ
- FLOAT8E5M2
- FLOAT8E5M2FNUZ
- UINT4
- INT4

---------

Co-authored-by: liumingyue <mingyue@xilinx.com>
This commit is contained in:
mingyueliuh 2024-06-28 20:19:20 -04:00 committed by GitHub
parent 6baaaf5165
commit 7e93cd7f8b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 35 additions and 2 deletions

View file

@ -270,6 +270,9 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
graph.SetGraphResolveNeeded();
}
auto status = graph.Resolve();
if (!status.IsOK()) {
std::cerr << "graph resolve error:" << status.ErrorMessage() << std::endl;
}
return status.Code();
};
the_global_api.graph_get_consumer_nodes_unsafe = [](const Graph& graph, const std::string& node_arg_name) -> auto {

View file

@ -38,7 +38,13 @@ enum TensorProto_DataType : int {
TensorProto_DataType_UINT64 = 13,
TensorProto_DataType_COMPLEX64 = 14,
TensorProto_DataType_COMPLEX128 = 15,
TensorProto_DataType_BFLOAT16 = 16
TensorProto_DataType_BFLOAT16 = 16,
TensorProto_DataType_FLOAT8E4M3FN = 17,
TensorProto_DataType_FLOAT8E4M3FNUZ = 18,
TensorProto_DataType_FLOAT8E5M2 = 19,
TensorProto_DataType_FLOAT8E5M2FNUZ = 20,
TensorProto_DataType_UINT4 = 21,
TensorProto_DataType_INT4 = 22
};
enum AttributeProto_AttributeType : int {
AttributeProto_AttributeType_UNDEFINED = 0,

View file

@ -13,7 +13,7 @@ struct OrtApi;
namespace vaip_core {
#define VAIP_ORT_API_MAJOR (3u)
#define VAIP_ORT_API_MINOR (0u)
#define VAIP_ORT_API_MINOR (1u)
#define VAIP_ORT_API_PATCH (0u)
struct OrtApiForVaip {
uint32_t magic; // 'VAIP' or something else to make sure the following field

View file

@ -613,8 +613,12 @@ struct ProviderHostImpl : ProviderHost {
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT8;
} else if (data_type->s() == "int32") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_INT32;
} else if (data_type->s() == "uint32") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT32;
} else if (data_type->s() == "int64") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_INT64;
} else if (data_type->s() == "uint64") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT64;
} else if (data_type->s() == "int1") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_BOOL;
} else if (data_type->s() == "bfloat16") {
@ -625,6 +629,26 @@ struct ProviderHostImpl : ProviderHost {
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT16;
} else if (data_type->s() == "int16") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_INT16;
} else if (data_type->s() == "double") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
} else if (data_type->s() == "string") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_STRING;
} else if (data_type->s() == "complex64") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64;
} else if (data_type->s() == "complex128") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128;
} else if (data_type->s() == "float8e4m3fn") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN;
} else if (data_type->s() == "float8e4m3fnuz") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ;
} else if (data_type->s() == "float8e5m2") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2;
} else if (data_type->s() == "float8e5m2funz") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ;
} else if (data_type->s() == "uint4") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT4;
} else if (data_type->s() == "int4") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_INT4;
} else {
return;
}