mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[tvm] Add support for int8 models, update TVM revision (#13519)
### Description

In the TVM EP, this adds more entries to the conversion from `ONNXTensorElementDataType` to `DLDataType`. Additionally, it removes an unused function and updates the TVM revision to allow running models from recent revisions of TVM.

### Motivation and Context

In the TVM EP, the mapping from `ONNXTensorElementDataType` to `DLDataType` was incomplete and neglected several integer types (in particular `ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8` and `ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8`), which prevented some models from running.

Co-authored-by: Peter Salas <psalas@octoml.ai>
This commit is contained in:
parent
9e65f3bfdb
commit
b383312f4c
3 changed files with 29 additions and 20 deletions
|
|
@ -37,7 +37,7 @@
|
|||
"component": {
|
||||
"type": "git",
|
||||
"git": {
|
||||
"commitHash": "3425ed846308a456f98404c79f6df1693bed6377",
|
||||
"commitHash": "2379917985919ed3918dc12cad47f469f245be7a",
|
||||
"repositoryUrl": "https://github.com/apache/tvm.git"
|
||||
},
|
||||
"comments": "needed for TVM EP"
|
||||
|
|
|
|||
2
cmake/external/tvm.cmake
vendored
2
cmake/external/tvm.cmake
vendored
|
|
@ -4,7 +4,7 @@ if (onnxruntime_USE_TVM)
|
|||
# Fetch the Apache TVM sources pinned to the revision required by the TVM EP.
# NOTE(review): the pinned commit hash must stay in sync with the
# "commitHash" entry for tvm in the cgmanifest — confirm when bumping.
FetchContent_Declare(
  tvm
  GIT_REPOSITORY https://github.com/apache/tvm.git
  GIT_TAG 2379917985919ed3918dc12cad47f469f245be7a
)

FetchContent_GetProperties(tvm)
|
||||
|
|
|
|||
|
|
@ -17,27 +17,36 @@ namespace onnxruntime {
|
|||
namespace tvm {
|
||||
|
||||
// Maps an ONNX tensor element type to the equivalent DLPack DLDataType
// ({type_code, bits, lanes}). Covers the 8/16/32/64-bit signed and
// unsigned integer types, float16/float/double, and bool (represented
// as a 1-bit unsigned value). Any other element type aborts via
// ORT_NOT_IMPLEMENTED.
inline DLDataType GetDataType(ONNXTensorElementDataType type) {
  switch (type) {
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
      return {kDLUInt, 8, 1};
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
      return {kDLInt, 8, 1};
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
      return {kDLUInt, 16, 1};
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
      return {kDLInt, 16, 1};
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
      return {kDLUInt, 32, 1};
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
      return {kDLInt, 32, 1};
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
      return {kDLUInt, 64, 1};
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
      return {kDLInt, 64, 1};
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
      return {kDLFloat, 16, 1};
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
      return {kDLFloat, 32, 1};
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
      return {kDLFloat, 64, 1};
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
      return {kDLUInt, 1, 1};
    default:
      ORT_NOT_IMPLEMENTED("Unsupported data type");
  }
}
|
||||
|
||||
// Fallback DLPack data type used when no element type is taken from the
// model proto: a single-lane 32-bit float.
// NOTE(review): always returns float32 regardless of the proto contents —
// confirm callers only rely on this as a default.
inline DLDataType GetDataTypeFromProto() {
  return {kDLFloat, 32, 1};
}
|
||||
|
||||
inline DLDevice GetDLDevice(OrtMemoryInfoDeviceType device_type) {
|
||||
DLDevice context;
|
||||
switch (device_type) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue