diff --git a/cgmanifests/cgmanifest.json b/cgmanifests/cgmanifest.json index 22407e398d..64159368f3 100644 --- a/cgmanifests/cgmanifest.json +++ b/cgmanifests/cgmanifest.json @@ -37,7 +37,7 @@ "component": { "type": "git", "git": { - "commitHash": "3425ed846308a456f98404c79f6df1693bed6377", + "commitHash": "2379917985919ed3918dc12cad47f469f245be7a", "repositoryUrl": "https://github.com/apache/tvm.git" }, "comments": "needed for TVM EP" diff --git a/cmake/external/tvm.cmake b/cmake/external/tvm.cmake index 61db398df6..1e224a2dad 100644 --- a/cmake/external/tvm.cmake +++ b/cmake/external/tvm.cmake @@ -4,7 +4,7 @@ if (onnxruntime_USE_TVM) FetchContent_Declare( tvm GIT_REPOSITORY https://github.com/apache/tvm.git - GIT_TAG 3425ed846308a456f98404c79f6df1693bed6377 + GIT_TAG 2379917985919ed3918dc12cad47f469f245be7a ) FetchContent_GetProperties(tvm) diff --git a/onnxruntime/core/providers/tvm/tvm_utils.h b/onnxruntime/core/providers/tvm/tvm_utils.h index 39e9c75110..e9b98e3b1e 100644 --- a/onnxruntime/core/providers/tvm/tvm_utils.h +++ b/onnxruntime/core/providers/tvm/tvm_utils.h @@ -17,27 +17,36 @@ namespace onnxruntime { namespace tvm { inline DLDataType GetDataType(ONNXTensorElementDataType type) { - if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { - return {kDLFloat, 64, 1}; - } else if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { - return {kDLFloat, 16, 1}; - } else if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { - return {kDLFloat, 32, 1}; - } else if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { - return {kDLInt, 64, 1}; - } else if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { - return {kDLInt, 32, 1}; - } else if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL) { - return {kDLUInt, 1, 1}; - } else { - ORT_NOT_IMPLEMENTED("Unsupported data type"); + switch (type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return {kDLUInt, 8, 1}; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + return {kDLInt, 8, 1}; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + return {kDLUInt, 16, 1}; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + return {kDLInt, 16, 1}; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return {kDLUInt, 32, 1}; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return {kDLInt, 32, 1}; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + return {kDLUInt, 64, 1}; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + return {kDLInt, 64, 1}; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return {kDLFloat, 16, 1}; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return {kDLFloat, 32, 1}; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + return {kDLFloat, 64, 1}; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + return {kDLUInt, 1, 1}; + default: + ORT_NOT_IMPLEMENTED("Unsupported data type"); } } -inline DLDataType GetDataTypeFromProto() { - return {kDLFloat, 32, 1}; -} - inline DLDevice GetDLDevice(OrtMemoryInfoDeviceType device_type) { DLDevice context; switch (device_type) {