mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
[VitisAI] Align TensorProto_DataType with onnx1.16 (#21067)
### Description Vitis AI EP synchronously supports the TensorProto data types supported by ONNX 1.16. Add error message show when graph resolve fail for troubleshooting. ### Motivation and Context ONNX 1.15 & 1.16 add support some new TensorProto DataType , such as - FLOAT8E4M3FN - FLOAT8E4M3FNUZ - FLOAT8E5M2 - FLOAT8E5M2FNUZ - UINT4 - INT4 --------- Co-authored-by: liumingyue <mingyue@xilinx.com>
This commit is contained in:
parent
6baaaf5165
commit
7e93cd7f8b
4 changed files with 35 additions and 2 deletions
|
|
@ -270,6 +270,9 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
|
|||
graph.SetGraphResolveNeeded();
|
||||
}
|
||||
auto status = graph.Resolve();
|
||||
if (!status.IsOK()) {
|
||||
std::cerr << "graph resolve error:" << status.ErrorMessage() << std::endl;
|
||||
}
|
||||
return status.Code();
|
||||
};
|
||||
the_global_api.graph_get_consumer_nodes_unsafe = [](const Graph& graph, const std::string& node_arg_name) -> auto {
|
||||
|
|
|
|||
|
|
@ -38,7 +38,13 @@ enum TensorProto_DataType : int {
|
|||
TensorProto_DataType_UINT64 = 13,
|
||||
TensorProto_DataType_COMPLEX64 = 14,
|
||||
TensorProto_DataType_COMPLEX128 = 15,
|
||||
TensorProto_DataType_BFLOAT16 = 16
|
||||
TensorProto_DataType_BFLOAT16 = 16,
|
||||
TensorProto_DataType_FLOAT8E4M3FN = 17,
|
||||
TensorProto_DataType_FLOAT8E4M3FNUZ = 18,
|
||||
TensorProto_DataType_FLOAT8E5M2 = 19,
|
||||
TensorProto_DataType_FLOAT8E5M2FNUZ = 20,
|
||||
TensorProto_DataType_UINT4 = 21,
|
||||
TensorProto_DataType_INT4 = 22
|
||||
};
|
||||
enum AttributeProto_AttributeType : int {
|
||||
AttributeProto_AttributeType_UNDEFINED = 0,
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ struct OrtApi;
|
|||
namespace vaip_core {
|
||||
|
||||
#define VAIP_ORT_API_MAJOR (3u)
|
||||
#define VAIP_ORT_API_MINOR (0u)
|
||||
#define VAIP_ORT_API_MINOR (1u)
|
||||
#define VAIP_ORT_API_PATCH (0u)
|
||||
struct OrtApiForVaip {
|
||||
uint32_t magic; // 'VAIP' or something else to make sure the following field
|
||||
|
|
|
|||
|
|
@ -613,8 +613,12 @@ struct ProviderHostImpl : ProviderHost {
|
|||
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT8;
|
||||
} else if (data_type->s() == "int32") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_INT32;
|
||||
} else if (data_type->s() == "uint32") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT32;
|
||||
} else if (data_type->s() == "int64") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_INT64;
|
||||
} else if (data_type->s() == "uint64") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT64;
|
||||
} else if (data_type->s() == "int1") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_BOOL;
|
||||
} else if (data_type->s() == "bfloat16") {
|
||||
|
|
@ -625,6 +629,26 @@ struct ProviderHostImpl : ProviderHost {
|
|||
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT16;
|
||||
} else if (data_type->s() == "int16") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_INT16;
|
||||
} else if (data_type->s() == "double") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
|
||||
} else if (data_type->s() == "string") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_STRING;
|
||||
} else if (data_type->s() == "complex64") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64;
|
||||
} else if (data_type->s() == "complex128") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128;
|
||||
} else if (data_type->s() == "float8e4m3fn") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN;
|
||||
} else if (data_type->s() == "float8e4m3fnuz") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ;
|
||||
} else if (data_type->s() == "float8e5m2") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2;
|
||||
} else if (data_type->s() == "float8e5m2funz") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ;
|
||||
} else if (data_type->s() == "uint4") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT4;
|
||||
} else if (data_type->s() == "int4") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_INT4;
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue