[tvm] Add support for int8 models, update TVM revision (#13519)

### Description
In the TVM EP, this adds more entries to the conversion from
`ONNXTensorElementDataType` to `DLDataType`. Additionally, it removes an
unused function and updates the TVM revision to allow running models
from recent revisions of TVM.

### Motivation and Context
In the TVM EP, the mapping from `ONNXTensorElementDataType` to
`DLDataType` was incomplete and neglected several integer types (in
particular `ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8` and
`ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8`) which prevented some models from
running.

Co-authored-by: Peter Salas <psalas@octoml.ai>
This commit is contained in:
Peter Salas 2022-11-08 11:28:32 -08:00 committed by GitHub
parent 9e65f3bfdb
commit b383312f4c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 20 deletions

View file

@@ -37,7 +37,7 @@
"component": {
"type": "git",
"git": {
"commitHash": "3425ed846308a456f98404c79f6df1693bed6377",
"commitHash": "2379917985919ed3918dc12cad47f469f245be7a",
"repositoryUrl": "https://github.com/apache/tvm.git"
},
"comments": "needed for TVM EP"

View file

@@ -4,7 +4,7 @@ if (onnxruntime_USE_TVM)
# Fetch the TVM sources pinned to the revision required by the TVM EP.
# NOTE(review): this hash should stay in sync with the "commitHash" entry
# for apache/tvm in cgmanifest.json — confirm both match when bumping.
FetchContent_Declare(
  tvm
  GIT_REPOSITORY https://github.com/apache/tvm.git
  GIT_TAG 2379917985919ed3918dc12cad47f469f245be7a
)
FetchContent_GetProperties(tvm)

View file

@@ -17,27 +17,36 @@ namespace onnxruntime {
namespace tvm {
/// Converts an ONNX Runtime tensor element type to the equivalent DLPack
/// DLDataType so tensors can be exchanged with TVM.
///
/// @param type ONNX tensor element data type to convert.
/// @return the matching {type code, bits, lanes} DLDataType; lanes is
///         always 1 (scalar, non-vectorized elements).
/// @throws via ORT_NOT_IMPLEMENTED for element types with no mapping here
///         (e.g. string / complex / bfloat16 — none appear below).
inline DLDataType GetDataType(ONNXTensorElementDataType type) {
  switch (type) {
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
      return {kDLUInt, 8, 1};
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
      return {kDLInt, 8, 1};
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
      return {kDLUInt, 16, 1};
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
      return {kDLInt, 16, 1};
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
      return {kDLUInt, 32, 1};
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
      return {kDLInt, 32, 1};
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
      return {kDLUInt, 64, 1};
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
      return {kDLInt, 64, 1};
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
      return {kDLFloat, 16, 1};
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
      return {kDLFloat, 32, 1};
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
      return {kDLFloat, 64, 1};
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
      // BOOL is represented as a 1-bit unsigned integer.
      return {kDLUInt, 1, 1};
    default:
      // ORT_NOT_IMPLEMENTED does not return, so no fall-through value is
      // needed after the switch.
      ORT_NOT_IMPLEMENTED("Unsupported data type");
  }
}
/// Returns the dtype used when type information would come from the model
/// proto: always a single-lane 32-bit float (no proto field is inspected).
inline DLDataType GetDataTypeFromProto() {
  const DLDataType float32_dtype{kDLFloat, 32, 1};
  return float32_dtype;
}
inline DLDevice GetDLDevice(OrtMemoryInfoDeviceType device_type) {
DLDevice context;
switch (device_type) {